Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 68 additions & 27 deletions runner/internal/model_hub/model_hub.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package model_hub
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"log/slog"
Expand All @@ -35,6 +36,8 @@ import (
"github.com/NexaAI/nexa-sdk/runner/internal/types"
)

const ProgressSuffix = ".progress"

type ModelFileInfo struct {
Name string `json:"name"`
Size int64 `json:"size"`
Expand Down Expand Up @@ -129,11 +132,12 @@ func GetFileContent(ctx context.Context, modelName, fileName string) ([]byte, er

// downloadTask describes a single byte-range (chunk) of a model file to
// fetch. Completed chunks are recorded in the marker file at MarkerPath
// (one byte per chunk; 0x01 = done) so an interrupted download can be
// resumed without re-fetching finished chunks.
type downloadTask struct {
	OutputPath string // directory the model is being downloaded into

	ModelName  string // model the file belongs to
	FileName   string // file (relative to OutputPath) this chunk writes into
	Offset     int64  // byte offset of the chunk within the file
	Limit      int64  // number of bytes to download starting at Offset
	MarkerPath string // progress marker file (FileName + ProgressSuffix)
	ChunkIndex int    // index of this chunk's status byte in the marker file
}

const (
Expand All @@ -144,7 +148,6 @@ func StartDownload(ctx context.Context, modelName, outputPath string, files []Mo
slog.Info("Starting download", "model", modelName, "outputPath", outputPath, "files", files)

hub, err := getHub(ctx, modelName)

if err != nil {
resCh := make(chan types.DownloadInfo)
errCh := make(chan error, 1)
Expand All @@ -157,20 +160,18 @@ func StartDownload(ctx context.Context, modelName, outputPath string, files []Mo
maxConcurrency := hub.MaxConcurrency()
resCh := make(chan types.DownloadInfo)
errCh := make(chan error, maxConcurrency)

slog.Info("GetHub", "hub", reflect.TypeOf(hub), "maxConcurrency", maxConcurrency)

go func() {
defer close(errCh)
defer close(resCh)

var downloaded int64
var totalSize int64
for _, f := range files {
totalSize += f.Size
}

// create tasks
var downloaded int64
var markerPaths []string
g, gctx := errgroup.WithContext(ctx)
g.SetLimit(maxConcurrency)

Expand All @@ -179,28 +180,62 @@ func StartDownload(ctx context.Context, modelName, outputPath string, files []Mo
errCh <- fmt.Errorf("failed to create directory: %v, %s", err, f.Name)
return
}

// create download tasks for each chunk
chunkSize := max(minChunkSize, f.Size/128)
nChunks := int((f.Size + chunkSize - 1) / chunkSize)
outPath := filepath.Join(outputPath, f.Name)
markerPath := filepath.Join(outputPath, f.Name+ProgressSuffix)

markers, err := os.ReadFile(markerPath)
if err != nil && !errors.Is(err, os.ErrNotExist) {
errCh <- err
return
}
if err != nil || len(markers) != nChunks {
markers = make([]byte, nChunks)
if err := os.WriteFile(markerPath, markers, 0o644); err != nil {
errCh <- err
return
}
}
file, err := os.OpenFile(outPath, os.O_RDWR|os.O_CREATE, 0o644)
if err != nil {
errCh <- err
return
}
if fi, _ := file.Stat(); fi == nil || fi.Size() < f.Size {
if err := file.Truncate(f.Size); err != nil {
file.Close()
errCh <- err
return
}
}
file.Close()
markerPaths = append(markerPaths, markerPath)

slog.Info("Download file", "name", f.Name, "size", f.Size, "chunkSize", chunkSize)

for offset := int64(0); offset < f.Size; offset += chunkSize {
task := downloadTask{
for i, marker := range markers {
if marker == 0x01 {
downloaded += min(chunkSize, f.Size-int64(i)*chunkSize)
continue
}
offset := int64(i) * chunkSize
t := downloadTask{
OutputPath: outputPath,
ModelName: modelName,
FileName: f.Name,
Offset: offset,
Limit: min(chunkSize, f.Size-offset),
MarkerPath: markerPath,
ChunkIndex: i,
}

g.Go(func() error {
if err := doTask(gctx, hub, task); err != nil {
slog.Error("Download task failed", "task", task, "error", err)
if err := doTask(gctx, hub, t); err != nil {
slog.Error("Download task failed", "task", t, "error", err)
return err
}

resCh <- types.DownloadInfo{
TotalDownloaded: atomic.AddInt64(&downloaded, task.Limit),
TotalDownloaded: atomic.AddInt64(&downloaded, t.Limit),
TotalSize: totalSize,
}
return nil
Expand All @@ -210,7 +245,12 @@ func StartDownload(ctx context.Context, modelName, outputPath string, files []Mo

if err := g.Wait(); err != nil {
Copy link

Copilot AI Feb 5, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When an error occurs in g.Wait() at line 398, the function returns immediately without saving the progress files one final time. This means any chunks that were marked done since the last ticker save (up to 2 seconds worth) will be lost, requiring those chunks to be re-downloaded on the next attempt.

Consider adding a final save loop before returning on error to minimize data loss and reduce unnecessary re-downloads.

Suggested change
if err := g.Wait(); err != nil {
if err := g.Wait(); err != nil {
// Ensure we persist the latest progress state before returning on error.
for _, w := range workList {
if w.progress == nil {
continue
}
// Only attempt to save if there is any completed chunk.
if !w.progress.anyDone() {
continue
}
// Use the same persistence mechanism as the periodic/ticker save loop.
if errSave := w.progress.save(); errSave != nil {
slog.Debug("save progress file on error", "path", w.progress.path, "error", errSave)
}
}

Copilot uses AI. Check for mistakes.
errCh <- err
return
}
for _, p := range markerPaths {
_ = os.Remove(p)
}
slog.Info("download complete", "model", modelName, "outputPath", outputPath)
}()

return resCh, errCh
Expand Down Expand Up @@ -290,19 +330,20 @@ func doTask(ctx context.Context, hub ModelHub, task downloadTask) error {
if err != nil {
return err
}

_, err = file.Seek(task.Offset, io.SeekStart)
if err != nil {
defer file.Close()
if _, err := file.Seek(task.Offset, io.SeekStart); err != nil {
return err
}

err = hub.GetFileContent(ctx, task.ModelName, task.FileName, task.Offset, task.Limit, file)
if err := hub.GetFileContent(ctx, task.ModelName, task.FileName, task.Offset, task.Limit, file); err != nil {
return err
}
marker, err := os.OpenFile(task.MarkerPath, os.O_WRONLY, 0o644)
if err != nil {
file.Close()
return err
}

return file.Close()
defer marker.Close()
_, _ = marker.WriteAt([]byte{0x01}, int64(task.ChunkIndex))
return nil
}

func code2error(client *resty.Client, response *resty.Response) error {
Expand Down
21 changes: 15 additions & 6 deletions runner/internal/store/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,13 @@ import (
"log/slog"
"os"
"path/filepath"
"strings"
"sync"

"github.com/gofrs/flock"

"github.com/NexaAI/nexa-sdk/runner/internal/config"
"github.com/NexaAI/nexa-sdk/runner/internal/model_hub"
)

type Store struct {
Expand Down Expand Up @@ -111,12 +113,19 @@ func (s *Store) cleanCorruptedDirectories() {

func (s *Store) isCorruptedModelDirectory(name string) bool {
manifestPath := s.ModelfilePath(name, "nexa.manifest")
if _, err := os.Stat(manifestPath); err != nil {
slog.Info("Cleaning corrupted model directory", "name", err)
if _, err := os.Stat(manifestPath); err == nil {
return false
}
dir := s.ModelfilePath(name, "")
entries, err := os.ReadDir(dir)
if err != nil {
return true
}

// TDOD: Check Manifest file should be valid JSON and parseable

return false
for _, e := range entries {
if !e.IsDir() && strings.HasSuffix(e.Name(), model_hub.ProgressSuffix) {
return false
}
}
slog.Info("Cleaning corrupted model directory", "name", name)
return true
}
22 changes: 15 additions & 7 deletions runner/internal/store/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,10 +149,21 @@ func (s *Store) Pull(ctx context.Context, mf types.ModelManifest) (infoCh <-chan
return
}

// clean before
if err := s.Remove(mf.Name); err != nil {
errC <- err
return
modelDir := filepath.Join(s.home, "models", mf.Name)
hasProgress := false
if entries, _ := os.ReadDir(modelDir); entries != nil {
for _, e := range entries {
if !e.IsDir() && strings.HasSuffix(e.Name(), model_hub.ProgressSuffix) {
hasProgress = true
break
}
}
}
if !hasProgress {
if err := s.Remove(mf.Name); err != nil {
errC <- err
return
}
}

if err := s.LockModel(mf.Name); err != nil {
Expand Down Expand Up @@ -230,9 +241,6 @@ func (s *Store) Pull(ctx context.Context, mf types.ModelManifest) (infoCh <-chan
return
}

// Pull downloads a model from HuggingFace and stores it locally
// It fetches the model tree, finds .gguf files, downloads them, and saves metadata
// if model not specify, all is set true, and autodetect true
func (s *Store) PullExtraQuant(ctx context.Context, omf, nmf types.ModelManifest) (infoCh <-chan types.DownloadInfo, errCh <-chan error) {
infoC := make(chan types.DownloadInfo, 10)
infoCh = infoC
Expand Down