From d41b51e989df711159a28e6b719cd01d9eb7f03d Mon Sep 17 00:00:00 2001 From: Emily Casey Date: Tue, 1 Jul 2025 09:38:30 -0600 Subject: [PATCH 1/5] wip --- cmd/mdltool/main.go | 24 ++++++++++++++++++----- tar/target.go | 47 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 66 insertions(+), 5 deletions(-) create mode 100644 tar/target.go diff --git a/cmd/mdltool/main.go b/cmd/mdltool/main.go index ed692f8..3fae486 100644 --- a/cmd/mdltool/main.go +++ b/cmd/mdltool/main.go @@ -11,6 +11,7 @@ import ( "github.com/docker/model-distribution/builder" "github.com/docker/model-distribution/distribution" "github.com/docker/model-distribution/registry" + "github.com/docker/model-distribution/tar" ) // stringSliceFlag is a flag that can be specified multiple times to collect multiple string values @@ -153,11 +154,17 @@ func cmdPull(client *distribution.Client, args []string) int { func cmdPackage(args []string) int { fs := flag.NewFlagSet("package", flag.ExitOnError) - var licensePaths stringSliceFlag - var contextSize uint64 + var ( + licensePaths stringSliceFlag + contextSize uint64 + file string + ) fs.Var(&licensePaths, "licenses", "Paths to license files (can be specified multiple times)") fs.Uint64Var(&contextSize, "context-size", 0, "Context size in tokens") + fs.StringVar(&file, "file", "", "Write model to the given file instead of pushing to a registry") + fs.Parse(args) + fs.Usage = func() { fmt.Fprintf(os.Stderr, "Usage: model-distribution-tool package [OPTIONS] \n\n") fmt.Fprintf(os.Stderr, "Options:\n") @@ -207,10 +214,17 @@ func cmdPackage(args []string) int { // Create registry client once with all options registryClient := registry.NewClient(registryClientOpts...) - // Parse the reference - target, err := registryClient.NewTarget(reference) + var ( + target builder.Target + err error + ) + if file != "" { + target, err = tar.NewTarget(reference, file) + } else { + target, err = registryClient.NewTarget(reference) + } if err != nil { - fmt.Fprintf(os.Stderr, "Error parsing reference: %v\n", err) + fmt.Fprintf(os.Stderr, "Create packaging target: %v\n", err) return 1 } diff --git a/tar/target.go b/tar/target.go new file mode 100644 index 0000000..c86235a --- /dev/null +++ b/tar/target.go @@ -0,0 +1,47 @@ +package tar + +import ( + "context" + "fmt" + "io" + "os" + + "github.com/docker/model-distribution/internal/progress" + "github.com/docker/model-distribution/types" + "github.com/google/go-containerregistry/pkg/name" + "github.com/google/go-containerregistry/pkg/v1/tarball" +) + +type Target struct { + reference name.Reference + writer io.WriteCloser +} + +func (t *Target) Write(ctx context.Context, mdl types.ModelArtifact, progressWriter io.Writer) error { + defer t.writer.Close() + + pr := progress.NewProgressReporter(progressWriter, progress.PushMsg, nil) + defer pr.Wait() + + if err := tarball.Write(t.reference, mdl, t.writer, + tarball.WithProgress(pr.Updates()), + ); err != nil { + return fmt.Errorf("write to tarball %q: %w", t.reference.String(), err) + } + return nil +} + +func NewTarget(tag string, path string) (*Target, error) { + ref, err := name.NewTag(tag) + if err != nil { + return nil, fmt.Errorf("invalid tag: %q: %w", ref, err) + } + f, err := os.Create(path) + if err != nil { + return nil, fmt.Errorf("error creating tar archive at path %q: %w", path, err) + } + return &Target{ + reference: ref, + writer: f, + }, nil +} From cd4b309af0ba7944e752c0e13a12d38ce2e1caf4 Mon Sep 17 00:00:00 2001 From: Emily Casey Date: Tue, 15 Jul 2025 17:14:01 -0600 Subject: [PATCH 2/5] wip --- assets/dummy.tar | Bin 0 -> 6656 bytes cmd/mdltool/main.go | 35 ++++++++- distribution/client.go | 63 +++++++++++++++- distribution/import_test.go | 54 ++++++++++++++ internal/store/blobs.go | 38 ++++++++++ internal/store/manifests.go | 4 +- internal/store/store.go | 87 +++++++++++++++++++++- internal/store/store_test.go | 29 ++++++++ tar/target.go | 47 ------------ tarball/target.go | 135 +++++++++++++++++++++++++++++++++++ tarball/target_test.go | 109 ++++++++++++++++++++++++++++ 11 files changed, 545 insertions(+), 56 deletions(-) create mode 100644 assets/dummy.tar create mode 100644 distribution/import_test.go delete mode 100644 tar/target.go create mode 100644 tarball/target.go create mode 100644 tarball/target_test.go diff --git a/assets/dummy.tar b/assets/dummy.tar new file mode 100644 index 0000000000000000000000000000000000000000..3f51d627fa49aabcdad00aedf10d370ab32ca526 GIT binary patch literal 6656 zcmeHL&1>976kn(5cOj5|oY%T`SDr78q=gcbLQL;TE+x3L(rB$lE4h-^X<~ft|B(Nn zm!3*bh4xsQQc5YI6k2>P1PU!Zwb1vpl38n4N{qcp!TVq|nt8wXF>glV%^RyY40M*8 zA7)A1Sq>2(DP;$Hpn6XvJDQvFDYYr-bjfnuIj(h*~aVOXCOeFS= z%-hjpbyKHVn8e0_^^CFhLv{V`t~LgS+TD#pPvyYd*o%EL@j^2ktC2Rx)ydhTZ{y)H zdWM(8K{zr7`>iVc#JjCiQ-$U@@pWWQrOZ8I9x2MgdtltLXwUG&5gbqOB->HUk^^E< z%XK{8w|pwNOLX86A%$RyTfP#KGZ{F-XO5&W=Mpsx_)svPa1hslZyfZDNZp0y8iV(r z)X~vs5}11&mN@l{5h~|h^K2XJ5oH6H26PzsLQpTDmdz>Gt`3Gy;Q2Hl4yV-Rlu%y> zgj-tqj1M_84sIPh?L#n(l3})FqxDatHlBX-#1C6hgS}&FT8)cs*V{PuLRE>2G zKkZ2v=MJ(izrGlN{54=bdUG)!6;OErnBsDSjRJu_J!n(H*+qbSe6vjrXU_xXlTRN! zdJB4_+H!&@Q8{NX0O$9I|Gd-UvlJMiKYsi2e!EF2UIpBx^>_8-%u4YZAh-X%X8(Hs zt1AuR941NhBEZ%fT+8w_j7Kj40?i`r=qO6Mcak*n8y2n7)C{S78Ne823?tH7=K_o` zc3PwY&QJ=&g_200ZN0bo#h?AnzyAIFN{b};iIV7OC5c;uR5)A>3cvCEAgv=DDc5n9 zq}v1;cy#Ji!%@UlbBOKY1E`uS^)b|ii$SV-eBZ&Td;3xmmfx#d0=IqfZqbmV=rf4(cjZ9378Y{gsn7ITvMS z*SUF2sIpq?zq#7!U~8gpf!Ef)1OET1)z|So`1dT^4`6?p?MKjGwffp|{1focQ>(|| zA*>%;6a5UlleD^?UqF7g)%oq?WBK1{?^X?_9+Ur3i}9)^fY1(MRN8cv|22g3OsDr} zF6}{*jul?D@wy9l-H#?J&cl1f`9X(58$Bg8~L=PRZd|2CJ44e9}7WMb?LM>VB^-oQ!nA1 R>HlD*Y!!i31kO(c{sT1%c6|T< literal 0 HcmV?d00001 diff --git a/cmd/mdltool/main.go b/cmd/mdltool/main.go index 3fae486..c1c1e56 100644 --- a/cmd/mdltool/main.go +++ b/cmd/mdltool/main.go @@ -11,7 +11,7 @@ import ( "github.com/docker/model-distribution/builder" "github.com/docker/model-distribution/distribution" "github.com/docker/model-distribution/registry" - "github.com/docker/model-distribution/tar" + "github.com/docker/model-distribution/tarball" ) // stringSliceFlag is a flag that can be specified multiple times to collect multiple string values @@ -104,6 +104,8 @@ func main() { exitCode = cmdRm(client, args) case "tag": exitCode = cmdTag(client, args) + case "import": + exitCode = cmdImport(client, args) default: fmt.Fprintf(os.Stderr, "Unknown command: %s\n", command) printUsage() @@ -158,11 +160,13 @@ func cmdPackage(args []string) int { licensePaths stringSliceFlag contextSize uint64 file string + load bool ) fs.Var(&licensePaths, "licenses", "Paths to license files (can be specified multiple times)") fs.Uint64Var(&contextSize, "context-size", 0, "Context size in tokens") fs.StringVar(&file, "file", "", "Write model to the given file instead of pushing to a registry") + fs.BoolVar(&load, "load", false, "Load the model to the store instead of pushing to a registry") fs.Parse(args) fs.Usage = func() { @@ -219,7 +223,13 @@ func cmdPackage(args []string) int { err error ) if file != "" { - target, err = tar.NewTarget(reference, file) + target, err = tarball.NewFileTarget(reference, file) + } else if load { + //target, err = distribution.NewClient() + //if err != nil { + // fmt.Fprintf(os.Stderr, "Failed creating distribution client: %v\n", err) + // return 1 + //} } else { target, err = registryClient.NewTarget(reference) } @@ -258,6 +268,27 @@ func cmdPackage(args []string) int { return 0 } +func cmdImport(client *distribution.Client, args []string) int { + if len(args) < 1 { + fmt.Fprintf(os.Stderr, "Error: missing argument\n") + fmt.Fprintf(os.Stderr, "Usage: model-distribution-tool import \n") + return 1 + } + f, err := os.Open(args[0]) + if err != nil { + fmt.Fprintf(os.Stderr, "Error opening model: %v\n", err) + return 1 + } + defer f.Close() + ctx := context.Background() + + if err := client.ImportModel(ctx, "", f, os.Stdout); err != nil { + fmt.Fprintf(os.Stderr, "Error importing model: %v\n", err) + return 1 + } + return 0 +} + func cmdPush(client *distribution.Client, args []string) int { if len(args) < 1 { fmt.Fprintf(os.Stderr, "Error: missing tag argument\n") diff --git a/distribution/client.go b/distribution/client.go index 1057944..5910c7e 100644 --- a/distribution/client.go +++ b/distribution/client.go @@ -3,12 +3,11 @@ package distribution import ( "context" "fmt" + "github.com/sirupsen/logrus" "io" "net/http" "os" - "github.com/sirupsen/logrus" - "github.com/docker/model-distribution/internal/progress" "github.com/docker/model-distribution/internal/store" "github.com/docker/model-distribution/registry" @@ -206,6 +205,66 @@ func (c *Client) PullModel(ctx context.Context, reference string, progressWriter return nil } +//func (c *Client) ImportModel(ctx context.Context, reference string, rc io.ReadCloser, progressWriter io.Writer) error { +// //tag, err := name.NewTag(reference) +// //if err != nil { +// // return fmt.Errorf("parsing reference: %w", err) +// //} +// mdl, err := tarball.Image(func() (io.ReadCloser, error) { +// return rc, nil +// }, nil) +// if err != nil { +// return fmt.Errorf("reading inpute: %w", err) +// } +// tr := tar.NewReader(rc) +// hdr, err := tr.Next() +// if err == io.EOF { +// return nil +// } else if err != nil { +// return fmt.Errorf("reading tarball: %w", err) +// } +// if hdr.Name != "manifest.json" { +// return fmt.Errorf("expected manfifest as first entry got %q", hdr.Name) +// } +// c.store.WriteManifest() +// else { +// hash := strings.TrimSuffix() +// } +// +// // Model doesn't exist in local store or digests don't match, pull from remote +// +// if err = c.store.Write(mdl, []string{}, progressWriter); err != nil { +// if writeErr := progress.WriteError(progressWriter, fmt.Sprintf("Error: %s", err.Error())); writeErr != nil { +// c.log.Warnf("Failed to write error message: %v", writeErr) +// // If we fail to write error message, don't try again +// progressWriter = nil +// } +// return fmt.Errorf("writing image to store: %w", err) +// } +// +// if err := progress.WriteSuccess(progressWriter, "Model pulled successfully"); err != nil { +// c.log.Warnf("Failed to write success message: %v", err) +// // If we fail to write success message, don't try again +// progressWriter = nil +// } +// +// return nil +//} + +func (c *Client) ImportModel(ctx context.Context, reference string, rc io.ReadCloser, progressWriter io.Writer) error { + if err := c.store.Stream(rc, []string{reference}, nil); err != nil { + return err + } + + if err := progress.WriteSuccess(progressWriter, "Model pulled successfully"); err != nil { + c.log.Warnf("Failed to write success message: %v", err) + // If we fail to write success message, don't try again + progressWriter = nil + } + + return nil +} + // ListModels returns all available models func (c *Client) ListModels() ([]types.Model, error) { c.log.Infoln("Listing available models") diff --git a/distribution/import_test.go b/distribution/import_test.go new file mode 100644 index 0000000..88abdd5 --- /dev/null +++ b/distribution/import_test.go @@ -0,0 +1,54 @@ +package distribution + +import ( + "github.com/docker/model-distribution/builder" + "github.com/docker/model-distribution/tarball" + "io" + "os" + "testing" +) + +func TestImportModel(t *testing.T) { + // Create temp directory for store + tempDir, err := os.MkdirTemp("", "model-distribution-test-*") + if err != nil { + t.Fatalf("Failed to create temp directory: %v", err) + } + defer os.RemoveAll(tempDir) + + // Create client + client, err := NewClient(WithStoreRootPath(tempDir)) + if err != nil { + t.Fatalf("Failed to create client: %v", err) + } + + pr, pw := io.Pipe() + target, err := tarball.NewTarget(pw) + if err != nil { + t.Fatalf("Failed to create target: %v", err) + } + done := make(chan error) + go func() { + done <- client.ImportModel(t.Context(), "some/model", pr, nil) + }() + // Create model archive + bldr, err := builder.FromGGUF(testGGUFFile) + if err != nil { + t.Fatalf("Failed to create builder: %v", err) + } + err = bldr.Build(t.Context(), target, nil) + if err != nil { + t.Fatalf("Failed to build model: %v", err) + } + select { + case err := <-done: + if err != nil { + t.Fatalf("Failed to import model: %v", err) + } + case <-t.Context().Done(): + } + + if _, err := client.GetModel("some/model"); err != nil { + t.Fatalf("Failed to get model: %v", err) + } +} diff --git a/internal/store/blobs.go b/internal/store/blobs.go index 3a75d4b..52fe5d7 100644 --- a/internal/store/blobs.go +++ b/internal/store/blobs.go @@ -66,6 +66,44 @@ func (s *LocalStore) writeBlob(layer blob, progress chan<- v1.Update) error { return nil } +// writeBlob write the blob to the store, reporting progress to the given channel. +// If the blob is already in the store, it is a no-op. +//func (s *LocalStore) importBlobs(tr tar.Reader error { +// hash, err := layer.DiffID() +// if err != nil { +// return fmt.Errorf("get file hash: %w", err) +// } +// if s.hasBlob(hash) { +// // todo: write something to the progress channel (we probably need to redo progress reporting a little bit) +// return nil +// } +// +// path := s.blobPath(hash) +// lr, err := layer.Uncompressed() +// if err != nil { +// return fmt.Errorf("get blob contents: %w", err) +// } +// defer lr.Close() +// r := withProgress(lr, progress) +// +// f, err := createFile(incompletePath(path)) +// if err != nil { +// return fmt.Errorf("create blob file: %w", err) +// } +// defer os.Remove(incompletePath(path)) +// defer f.Close() +// +// if _, err := io.Copy(f, r); err != nil { +// return fmt.Errorf("copy blob %q to store: %w", hash.String(), err) +// } +// +// f.Close() // Rename will fail on Windows if the file is still open. +// if err := os.Rename(incompletePath(path), path); err != nil { +// return fmt.Errorf("rename blob file: %w", err) +// } +// return nil +//} + // removeBlob removes the blob with the given hash from the store. func (s *LocalStore) removeBlob(hash v1.Hash) error { return os.Remove(s.blobPath(hash)) diff --git a/internal/store/manifests.go b/internal/store/manifests.go index b8b458f..04940a3 100644 --- a/internal/store/manifests.go +++ b/internal/store/manifests.go @@ -17,8 +17,8 @@ func (s *LocalStore) manifestPath(hash v1.Hash) string { return filepath.Join(s.rootPath, manifestsDir, hash.Algorithm, hash.Hex) } -// writeManifest writes the model's manifest to the store -func (s *LocalStore) writeManifest(mdl v1.Image) error { +// WriteManifest writes the model's manifest to the store +func (s *LocalStore) WriteManifest(mdl v1.Image) error { digest, err := mdl.Digest() if err != nil { return fmt.Errorf("get digest: %w", err) diff --git a/internal/store/store.go b/internal/store/store.go index 3ba8ec5..31da4ff 100644 --- a/internal/store/store.go +++ b/internal/store/store.go @@ -1,13 +1,14 @@ package store import ( + "archive/tar" + "bytes" "fmt" + "github.com/docker/model-distribution/internal/progress" "io" "os" "path/filepath" - "github.com/docker/model-distribution/internal/progress" - v1 "github.com/google/go-containerregistry/pkg/v1" ) @@ -217,7 +218,7 @@ func (s *LocalStore) Write(mdl v1.Image, tags []string, w io.Writer) error { } // Write the manifest - if err := s.writeManifest(mdl); err != nil { + if err := s.WriteManifest(mdl); err != nil { return fmt.Errorf("writing manifest: %w", err) } @@ -245,6 +246,86 @@ func (s *LocalStore) Write(mdl v1.Image, tags []string, w io.Writer) error { return s.writeIndex(idx) } +var _ blob = &streamBlob{} + +type streamBlob struct { + diffID v1.Hash + rc io.ReadCloser +} + +func (s streamBlob) DiffID() (v1.Hash, error) { + return s.diffID, nil +} + +func (s streamBlob) Uncompressed() (io.ReadCloser, error) { + return s.rc, nil +} + +// Write writes a model to the store +func (s *LocalStore) Stream(rc io.ReadCloser, tags []string, w io.Writer) error { + tr := tar.NewReader(rc) + + entry := IndexEntry{ + Tags: tags, + } + + for hdr, err := tr.Next(); err != io.EOF; hdr, err = tr.Next() { + fi := hdr.FileInfo() + if fi.IsDir() { + continue + } + fmt.Println("processing", hdr.Name) + if filepath.Dir(filepath.Dir(hdr.Name)) == "blobs" { + fmt.Println("writing blob", fi.Name()) + diffID := v1.Hash{ + Algorithm: "sha256", + Hex: filepath.Base(fi.Name()), + } + if err := s.writeBlob(&streamBlob{ + diffID: diffID, + rc: io.NopCloser(tr), + }, nil); err != nil { + return fmt.Errorf("writing blob: %w", err) + } + entry.Files = append(entry.Files, s.blobPath(diffID)) + } + if hdr.Name == "manifest.json" { + rm, err := io.ReadAll(tr) + if err != nil { + return fmt.Errorf("reading manifest: %w", err) + } + digest, _, err := v1.SHA256(bytes.NewBuffer(rm)) + if err != nil { + return fmt.Errorf("digest: %w", err) + } + fmt.Println("writing manifest", fi.Name()) + if err := writeFile(s.manifestPath(digest), rm); err != nil { + return fmt.Errorf("writing manifest: %w", err) + } + entry.ID = digest.String() + } + } + + // Add the model to the index + idx, err := s.readIndex() + if err != nil { + return fmt.Errorf("reading models: %w", err) + } + + // Add the model tags + idx = idx.Add(entry) + for _, tag := range tags { + updatedIdx, err := idx.Tag(entry.ID, tag) + if err != nil { + fmt.Printf("Warning: failed to tag model %q with tag %q: %v\n", entry.ID, tag, err) + continue + } + idx = updatedIdx + } + + return s.writeIndex(idx) +} + // Read reads a model from the store by reference (either tag or ID) func (s *LocalStore) Read(reference string) (*Model, error) { models, err := s.List() diff --git a/internal/store/store_test.go b/internal/store/store_test.go index b580aef..1222362 100644 --- a/internal/store/store_test.go +++ b/internal/store/store_test.go @@ -455,6 +455,35 @@ func TestIncompleteFileHandling(t *testing.T) { } } +func TestStream(t *testing.T) { + tempDir, err := os.MkdirTemp("", "incomplete-file-test") + if err != nil { + t.Fatalf("Failed to create temp directory: %v", err) + } + defer os.RemoveAll(tempDir) + + // Create store + storePath := filepath.Join(tempDir, "incomplete-model-store") + s, err := store.New(store.Options{ + RootPath: storePath, + }) + if err != nil { + t.Fatalf("Failed to create store: %v", err) + } + + f, err := os.Open("/tmp/alpine.tar") + if err != nil { + t.Fatalf("Open failed: %v", err) + } + defer f.Close() + + if err := s.Stream(f, []string{}, nil); err != nil { + t.Fatalf("Stream failed: %v", err) + } + fmt.Println("done") + +} + // Helper function to check if a tag is in a slice of tags func containsTag(tags []string, tag string) bool { for _, t := range tags { diff --git a/tar/target.go b/tar/target.go deleted file mode 100644 index c86235a..0000000 --- a/tar/target.go +++ /dev/null @@ -1,47 +0,0 @@ -package tar - -import ( - "context" - "fmt" - "io" - "os" - - "github.com/docker/model-distribution/internal/progress" - "github.com/docker/model-distribution/types" - "github.com/google/go-containerregistry/pkg/name" - "github.com/google/go-containerregistry/pkg/v1/tarball" -) - -type Target struct { - reference name.Reference - writer io.WriteCloser -} - -func (t *Target) Write(ctx context.Context, mdl types.ModelArtifact, progressWriter io.Writer) error { - defer t.writer.Close() - - pr := progress.NewProgressReporter(progressWriter, progress.PushMsg, nil) - defer pr.Wait() - - if err := tarball.Write(t.reference, mdl, t.writer, - tarball.WithProgress(pr.Updates()), - ); err != nil { - return fmt.Errorf("write to tarball %q: %w", t.reference.String(), err) - } - return nil -} - -func NewTarget(tag string, path string) (*Target, error) { - ref, err := name.NewTag(tag) - if err != nil { - return nil, fmt.Errorf("invalid tag: %q: %w", ref, err) - } - f, err := os.Create(path) - if err != nil { - return nil, fmt.Errorf("error creating tar archive at path %q: %w", path, err) - } - return &Target{ - reference: ref, - writer: f, - }, nil -} diff --git a/tarball/target.go b/tarball/target.go new file mode 100644 index 0000000..826fb9a --- /dev/null +++ b/tarball/target.go @@ -0,0 +1,135 @@ +package tarball + +import ( + "archive/tar" + "context" + "fmt" + "github.com/docker/model-distribution/types" + "github.com/google/go-containerregistry/pkg/name" + "io" + "os" + "path/filepath" +) + +type Target struct { + reference name.Tag + writer io.Writer +} + +func (t *Target) Write(ctx context.Context, mdl types.ModelArtifact, progressWriter io.Writer) error { + //defer t.writer.Close() + + //pr := progress.NewProgressReporter(progressWriter, progress.PushMsg, nil) + //defer pr.Wait() + + tw := tar.NewWriter(t.writer) + defer tw.Close() + + rm, err := mdl.RawManifest() + if err != nil { + return err + } + + if err := tw.WriteHeader(&tar.Header{ + Name: "blobs", + Typeflag: tar.TypeDir, + }); err != nil { + return err + } + + ls, err := mdl.Layers() + if err != nil { + return err + } + algDirs := map[string]struct{}{} + for _, layer := range ls { + dgst, err := layer.Digest() + if err != nil { + return err + } + _, ok := algDirs[dgst.Algorithm] + if !ok { + if err = tw.WriteHeader(&tar.Header{ + Name: filepath.Join("blobs", dgst.Algorithm), + Typeflag: tar.TypeDir, + }); err != nil { + return err + } + algDirs[dgst.Algorithm] = struct{}{} + } + sz, err := layer.Size() + if err != nil { + return err + } + if err = tw.WriteHeader(&tar.Header{ + Name: filepath.Join("blobs", dgst.Algorithm, dgst.Hex), + Mode: 0666, + Size: sz, + }); err != nil { + return err + } + rc, err := layer.Uncompressed() + if err != nil { + return err + } + defer rc.Close() + if _, err = io.Copy(tw, rc); err != nil { + return err + } + } + rcf, err := mdl.RawConfigFile() + if err != nil { + return err + } + cn, err := mdl.ConfigName() + if err != nil { + return err + } + if err = tw.WriteHeader(&tar.Header{ + Name: filepath.Join("blobs", cn.Algorithm, cn.Hex), + Mode: 0666, + Size: int64(len(rcf)), + }); err != nil { + return err + } + if _, err = tw.Write(rcf); err != nil { + return fmt.Errorf("write config blob contents: %w", err) + } + + if err := tw.WriteHeader(&tar.Header{ + Name: "manifest.json", + Size: int64(len(rm)), + Mode: 0666, + }); err != nil { + return fmt.Errorf("write manifest.json header: %w", err) + } + if _, err = tw.Write(rm); err != nil { + return fmt.Errorf("write manifest.json contents: %w", err) + } + + return nil +} + +func NewFileTarget(tag string, path string) (*Target, error) { + var ref name.Tag + if tag != "" { + ref, err := name.NewTag(tag) + if err != nil { + return nil, fmt.Errorf("invalid tag: %q: %w", ref, err) + } + } + f, err := os.Create(path) + if err != nil { + return nil, fmt.Errorf("error creating tar archive at path %q: %w", path, err) + } + return &Target{ + reference: ref, + writer: f, + }, nil +} + +func NewTarget(w io.Writer) (*Target, error) { + return &Target{ + writer: w, + }, nil +} diff --git a/tarball/target_test.go b/tarball/target_test.go new file mode 100644 index 0000000..15a7c4c --- /dev/null +++ b/tarball/target_test.go @@ -0,0 +1,109 @@ +package tarball_test + +import ( + "archive/tar" + "bytes" + "github.com/docker/model-distribution/internal/gguf" + "github.com/docker/model-distribution/tarball" + v1 "github.com/google/go-containerregistry/pkg/v1" + "io" + "os" + "path/filepath" + "testing" +) + +func TestTarget(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "tar-test") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + path := filepath.Join(tmpDir, "result.tar") + + target, err := tarball.NewFileTarget("", path) + if err != nil { + t.Fatalf("Failed to create tar target: %v", err) + } + + mdl, err := gguf.NewModel(filepath.Join("..", "assets", "dummy.gguf")) + if err != nil { + t.Fatalf("Failed to create model: %v", err) + } + + f, err := os.Open(filepath.Join("..", "assets", "dummy.gguf")) + if err != nil { + t.Fatalf("Failed to open file: %v", err) + } + blobContents, err := io.ReadAll(f) + if err != nil { + t.Fatalf("Failed to read file: %v", err) + } + blobHash, _, err := v1.SHA256(bytes.NewReader(blobContents)) + if err != nil { + t.Fatalf("Failed to calculate hash: %v", err) + } + configDigest, err := mdl.ConfigName() + if err != nil { + t.Fatalf("Failed to get raw config: %v", err) + } + configContents, err := mdl.RawConfigFile() + if err != nil { + t.Fatalf("Failed to get raw config: %v", err) + } + manifestContents, err := mdl.RawManifest() + if err != nil { + t.Fatalf("Failed to get raw manifest contents: %v", err) + } + + if err := target.Write(t.Context(), mdl, nil); err != nil { + t.Fatalf("Failed to write model to tar file: %v", err) + } + + tf, err := os.Open(path) + if err != nil { + t.Fatalf("Failed to open file: %v", err) + } + tr := tar.NewReader(tf) + hasDir(t, tr, "blobs") + hasDir(t, tr, "blobs/sha256") + hasFile(t, tr, "blobs/sha256/"+blobHash.Hex, blobContents) + hasFile(t, tr, "blobs/sha256/"+configDigest.Hex, configContents) + hasFile(t, tr, "manifest.json", manifestContents) + hasFile(t, tr, "manifest.json", manifestContents) +} + +func hasFile(t *testing.T, tr *tar.Reader, name string, contents []byte) { + hdr, err := tr.Next() + if err != nil { + t.Fatalf("Failed to read header: %v", err) + } + if hdr.Name != name { + t.Fatalf("Unexpected next entry with name %q got %q", name, hdr.Name) + } + if hdr.Typeflag != tar.TypeReg { + t.Fatalf("Unexpected entry with name %q to be a file got type %v", name, hdr.Typeflag) + } + if hdr.Size != int64(len(contents)) { + t.Fatalf("Unexpected entry with name %q size %d got %d", name, hdr.Size, hdr.Size) + } + c, err := io.ReadAll(tr) + if err != nil { + t.Fatalf("Failed to read contents: %v", err) + } + if !bytes.Equal(contents, c) { + t.Fatalf("Unexpected contents for file %q", name) + } +} + +func hasDir(t *testing.T, tr *tar.Reader, name string) { + hdr, err := tr.Next() + if err != nil { + t.Fatalf("Failed to read header: %v", err) + } + if hdr.Name != name { + t.Fatalf("Unexpected next entry with name %q got %q", name, hdr.Name) + } + if hdr.Typeflag != tar.TypeDir { + t.Fatalf("Unexpected entry with name %q to be a directory got type %v", name, hdr.Typeflag) + } +} From 07e8611b6d44d14950a08057aea6bc32d9080655 Mon Sep 17 00:00:00 2001 From: Emily Casey Date: Wed, 16 Jul 2025 08:13:53 -0600 Subject: [PATCH 3/5] file target Signed-off-by: Emily Casey --- cmd/mdltool/main.go | 16 +++++----------- tarball/file.go | 32 ++++++++++++++++++++++++++++++++ tarball/target.go | 19 ------------------- tarball/target_test.go | 18 +++++++----------- 4 files changed, 44 insertions(+), 41 deletions(-) create mode 100644 tarball/file.go diff --git a/cmd/mdltool/main.go b/cmd/mdltool/main.go index c1c1e56..39665b1 100644 --- a/cmd/mdltool/main.go +++ b/cmd/mdltool/main.go @@ -223,19 +223,13 @@ func cmdPackage(args []string) int { err error ) if file != "" { - target, err = tarball.NewFileTarget(reference, file) - } else if load { - //target, err = distribution.NewClient() - //if err != nil { - // fmt.Fprintf(os.Stderr, "Failed creating distribution client: %v\n", err) - // return 1 - //} + target = tarball.NewFileTarget(file) } else { target, err = registryClient.NewTarget(reference) - } - if err != nil { - fmt.Fprintf(os.Stderr, "Create packaging target: %v\n", err) - return 1 + if err != nil { + fmt.Fprintf(os.Stderr, "Create packaging target: %v\n", err) + return 1 + } } // Create image with layer diff --git a/tarball/file.go b/tarball/file.go new file mode 100644 index 0000000..6b0132a --- /dev/null +++ b/tarball/file.go @@ -0,0 +1,32 @@ +package tarball + +import ( + "context" + "fmt" + "github.com/docker/model-distribution/types" + "io" + "os" +) + +type FileTarget struct { + path string +} + +func (t *FileTarget) Write(ctx context.Context, mdl types.ModelArtifact, pw io.Writer) error { + f, err := os.Create(t.path) + if err != nil { + return fmt.Errorf("create file for archive: %w", err) + } + defer f.Close() + target, err := NewTarget(f) + if err != nil { + return fmt.Errorf("create target: %w", err) + } + return target.Write(ctx, mdl, pw) +} + +func NewFileTarget(path string) *FileTarget { + return &FileTarget{ + path: path, + } +} diff --git a/tarball/target.go b/tarball/target.go index 826fb9a..a705bb5 100644 --- a/tarball/target.go +++ b/tarball/target.go @@ -7,7 +7,6 @@ import ( "github.com/docker/model-distribution/types" "github.com/google/go-containerregistry/pkg/name" "io" - "os" "path/filepath" ) @@ -110,24 +109,6 @@ func (t *Target) Write(ctx context.Context, mdl types.ModelArtifact, progressWri return nil } -func NewFileTarget(tag string, path string) (*Target, error) { - var ref name.Tag - if tag != "" { - ref, err := name.NewTag(tag) - if err != nil { - return nil, fmt.Errorf("invalid tag: %q: %w", ref, err) - } - } - f, err := os.Create(path) - if err != nil { - return nil, fmt.Errorf("error creating tar archive at path %q: %w", path, err) - } - return &Target{ - reference: ref, - writer: f, - }, nil -} - func NewTarget(w io.Writer) (*Target, error) { return &Target{ writer: w, diff --git a/tarball/target_test.go b/tarball/target_test.go index 15a7c4c..5d341a2 100644 --- a/tarball/target_test.go +++ b/tarball/target_test.go @@ -13,14 +13,15 @@ import ( ) func TestTarget(t *testing.T) { - tmpDir, err := os.MkdirTemp("", "tar-test") + f, err := os.CreateTemp("", "tar-test") if err != nil { - t.Fatalf("Failed to create temp dir: %v", err) + t.Fatalf("Failed to file for tar: %v", err) } - defer os.RemoveAll(tmpDir) - path := filepath.Join(tmpDir, "result.tar") + path := f.Name() + defer os.Remove(f.Name()) + defer f.Close() - target, err := tarball.NewFileTarget("", path) + target, err := tarball.NewTarget(f) if err != nil { t.Fatalf("Failed to create tar target: %v", err) } @@ -30,11 +31,7 @@ func TestTarget(t *testing.T) { t.Fatalf("Failed to create model: %v", err) } - f, err := os.Open(filepath.Join("..", "assets", "dummy.gguf")) - if err != nil { - t.Fatalf("Failed to open file: %v", err) - } - blobContents, err := io.ReadAll(f) + blobContents, err := os.ReadFile(filepath.Join("..", "assets", "dummy.gguf")) if err != nil { t.Fatalf("Failed to read file: %v", err) } @@ -69,7 +66,6 @@ func TestTarget(t *testing.T) { hasFile(t, tr, "blobs/sha256/"+blobHash.Hex, blobContents) hasFile(t, tr, "blobs/sha256/"+configDigest.Hex, configContents) hasFile(t, tr, "manifest.json", manifestContents) - hasFile(t, tr, "manifest.json", manifestContents) } func hasFile(t *testing.T, tr *tar.Reader, name string, contents []byte) { From 8a7c37dada1d8a75d6617302891b0517690a38c0 Mon Sep 17 00:00:00 2001 From: Emily Casey Date: Wed, 16 Jul 2025 08:26:31 -0600 Subject: [PATCH 4/5] rename to load --- cmd/mdltool/main.go | 25 ++++++---- distribution/client.go | 50 +------------------ distribution/{import_test.go => load_test.go} | 12 ++--- 3 files changed, 22 insertions(+), 65 deletions(-) rename distribution/{import_test.go => load_test.go} (81%) diff --git a/cmd/mdltool/main.go b/cmd/mdltool/main.go index 39665b1..d745e80 100644 --- a/cmd/mdltool/main.go +++ b/cmd/mdltool/main.go @@ -104,8 +104,8 @@ func main() { exitCode = cmdRm(client, args) case "tag": exitCode = cmdTag(client, args) - case "import": - exitCode = cmdImport(client, args) + case "load": + exitCode = cmdLoad(client, args) default: fmt.Fprintf(os.Stderr, "Unknown command: %s\n", command) printUsage() @@ -160,13 +160,11 @@ func cmdPackage(args []string) int { licensePaths stringSliceFlag contextSize uint64 file string - load bool ) fs.Var(&licensePaths, "licenses", "Paths to license files (can be specified multiple times)") fs.Uint64Var(&contextSize, "context-size", 0, "Context size in tokens") fs.StringVar(&file, "file", "", "Write model to the given file instead of pushing to a registry") - fs.BoolVar(&load, "load", false, "Load the model to the store instead of pushing to a registry") fs.Parse(args) fs.Usage = func() { @@ -262,22 +260,31 @@ func cmdPackage(args []string) int { return 0 } -func cmdImport(client *distribution.Client, args []string) int { +func cmdLoad(client *distribution.Client, args []string) int { + fs := flag.NewFlagSet("load", flag.ExitOnError) + var ( + tag string + ) + fs.StringVar(&tag, "tag", "", "Apply the tag to the loaded model") + fs.Parse(args) + args = fs.Args() + if len(args) < 1 { fmt.Fprintf(os.Stderr, "Error: missing argument\n") - fmt.Fprintf(os.Stderr, "Usage: model-distribution-tool import \n") + fmt.Fprintf(os.Stderr, "Usage: model-distribution-tool load \n") return 1 } + f, err := os.Open(args[0]) if err != nil { - fmt.Fprintf(os.Stderr, "Error opening model: %v\n", err) + fmt.Fprintf(os.Stderr, "Error opening model file: %v\n", err) return 1 } defer f.Close() ctx := context.Background() - if err := client.ImportModel(ctx, "", f, os.Stdout); err != nil { - fmt.Fprintf(os.Stderr, "Error importing model: %v\n", err) + if err := client.LoadModel(ctx, tag, f, os.Stdout); err != nil { + fmt.Fprintf(os.Stderr, "Error loading model: %v\n", err) return 1 } return 0 diff --git a/distribution/client.go b/distribution/client.go index 5910c7e..357d4c0 100644 --- a/distribution/client.go +++ b/distribution/client.go @@ -205,58 +205,12 @@ func (c *Client) PullModel(ctx context.Context, reference string, progressWriter return nil } -//func (c *Client) ImportModel(ctx context.Context, reference string, rc io.ReadCloser, progressWriter io.Writer) error { -// //tag, err := name.NewTag(reference) -// //if err != nil { -// // return fmt.Errorf("parsing reference: %w", err) -// //} -// mdl, err := tarball.Image(func() (io.ReadCloser, error) { -// return rc, nil -// }, nil) -// if err != nil { -// return fmt.Errorf("reading inpute: %w", err) -// } -// tr := tar.NewReader(rc) -// hdr, err := tr.Next() -// if err == io.EOF { -// return nil -// } else if err != nil { -// return fmt.Errorf("reading tarball: %w", err) -// } -// if hdr.Name != "manifest.json" { -// return fmt.Errorf("expected manfifest as first entry got %q", hdr.Name) -// } -// c.store.WriteManifest() -// else { -// hash := strings.TrimSuffix() -// } -// -// // Model doesn't exist in local store or digests don't match, pull from remote -// -// if err = c.store.Write(mdl, []string{}, progressWriter); err != nil { -// if writeErr := progress.WriteError(progressWriter, fmt.Sprintf("Error: %s", err.Error())); writeErr != nil { -// c.log.Warnf("Failed to write error message: %v", writeErr) -// // If we fail to write error message, don't try again -// progressWriter = nil -// } -// return fmt.Errorf("writing image to store: %w", err) -// } -// -// if err := progress.WriteSuccess(progressWriter, "Model pulled successfully"); err != nil { -// c.log.Warnf("Failed to write success message: %v", err) -// // If we fail to write success message, don't try again -// progressWriter = nil -// } -// -// return nil -//} - -func (c *Client) ImportModel(ctx context.Context, reference string, rc io.ReadCloser, progressWriter io.Writer) error { +func (c *Client) LoadModel(ctx context.Context, reference string, rc io.ReadCloser, progressWriter io.Writer) error { if err := c.store.Stream(rc, []string{reference}, nil); err != nil { return err } - if err := progress.WriteSuccess(progressWriter, "Model pulled successfully"); err != nil { + if err := progress.WriteSuccess(progressWriter, "Model loaded successfully"); err != nil { c.log.Warnf("Failed to write success message: %v", err) // If we fail to write success message, don't try again progressWriter = nil diff --git a/distribution/import_test.go b/distribution/load_test.go similarity index 81% rename from distribution/import_test.go rename to distribution/load_test.go index 88abdd5..3ab5a17 100644 --- a/distribution/import_test.go +++ b/distribution/load_test.go @@ -8,7 +8,7 @@ import ( "testing" ) -func TestImportModel(t *testing.T) { +func TestLoadModel(t *testing.T) { // Create temp directory for store tempDir, err := os.MkdirTemp("", "model-distribution-test-*") if err != nil { @@ -29,7 +29,7 @@ func TestImportModel(t *testing.T) { } done := make(chan error) go func() { - done <- client.ImportModel(t.Context(), "some/model", pr, nil) + done <- client.LoadModel(t.Context(), "some/model", pr, nil) }() // Create model archive bldr, err := builder.FromGGUF(testGGUFFile) @@ -40,12 +40,8 @@ func TestImportModel(t *testing.T) { if err != nil { t.Fatalf("Failed to build model: %v", err) } - select { - case err := <-done: - if err != nil { - t.Fatalf("Failed to import model: %v", err) - } - case <-t.Context().Done(): + if err := <-done; err != nil { + t.Fatalf("LoadModel exited with error: %v", err) } if _, err := client.GetModel("some/model"); err != nil { From ed48d68a706d7d3e2bb1b824528694501c59a6fa Mon Sep 17 00:00:00 2001 From: Emily Casey Date: Wed, 16 Jul 2025 15:00:50 -0600 Subject: [PATCH 5/5] adds testdata --- tarball/reader.go | 39 ++++++++++++++++++ tarball/reader_test.go | 1 + tarball/testdata/archive.tar | Bin 0 -> 18432 bytes ...948ac0d8348b1138b628ba4946fbb8bd77828fe3bc | 1 + ...f53354b168041c352e750c001165f3848ee9fc509f | 1 + 5 files changed, 42 insertions(+) create mode 100644 tarball/reader.go create mode 100644 tarball/reader_test.go create mode 100644 tarball/testdata/archive.tar create mode 100644 tarball/testdata/archive/blobs/sha256/e3e972b5cbbd6aace0837af6e18be2daba89e5c3099ecddfb6cc3e7fc1bae145b72a499f29d0b191ec4668948ac0d8348b1138b628ba4946fbb8bd77828fe3bc create mode 100644 tarball/testdata/archive/blobs/sha512/38e52851c619572fa0c879f53354b168041c352e750c001165f3848ee9fc509f diff --git a/tarball/reader.go b/tarball/reader.go new file mode 100644 index 0000000..eb221b7 --- /dev/null +++ b/tarball/reader.go @@ -0,0 +1,39 @@ +package tarball + +import ( + "archive/tar" + v1 "github.com/google/go-containerregistry/pkg/v1" + "io" + "log" +) + +type Reader struct { + tr tar.Reader +} + +type Blob struct { + diffID v1.Hash + rc io.ReadCloser +} + +func (r Reader) NextBlob() (*Blob, error) { + hdr, err := r.tr.Next() + if err != nil { + return nil, err + } + for { + hdr, err := r.tr.Next() + if err == io.EOF { + break // End of archive + } + if err != nil { + log.Fatalf("Error reading tar entry: %v", err) + } + } +} + +func NewReader(rc io.ReadCloser) *Reader { + return &Reader{ + tr: tar.NewReader(rc), + } +} diff --git a/tarball/reader_test.go b/tarball/reader_test.go new file mode 100644 index 0000000..7ab3c4b --- /dev/null +++ b/tarball/reader_test.go @@ -0,0 +1 @@ +package tarball diff --git a/tarball/testdata/archive.tar b/tarball/testdata/archive.tar new file mode 100644 index 0000000000000000000000000000000000000000..c111f61623cd56b1307ec82512c1df5d77a5ac87 GIT binary patch literal 18432 zcmeHOOK;;g5Y{eG6iv_Vp(g=7ugur*amXQU7f68gwVN*5UILP)ngDj3wY^35(nJ44 zfdUKk(p&#lul*6F>?pEqZ?sv6iXBr1A(<8*o|%vHksM~HSrYm}2nm2RDej!l?ufAY z-RJpdnqokdBkVskMVN3PNs2;utf>FhWU9x$Yi@Wxcy0B>y;L{J2$4TA4-#)#=b#QX>?E1B$#GgKW zyj1K_mH&wy4KsatIdIwKc=W5g&==Mvj~*w#fBwV2pZ)j8o28=hYnu}IPG225-FM@j zeKj7ti)sIS+-ChnT^<*=;~$A7{BuS<{;%5Z$h!J=9REb6!|D0Z?PCFiE2)^s7*>>U zO7=jc$46frJp1zS=*83QN>8WbY#mzn557M-_~rbY7wl?08T>riV=8^_3myM-TcHJr z@2%5J%cWvkWf7Z)R&0u6E0cG8!|BC-1js^Gpj|vg_ zUqFp4R2})>f-rXPG3!+fFw{8Eh<{v;|0Vtbh#&tW{|j%A6xH~jP?BUqP^<`miUcsc z0`9qmtT#h_P!`yP|BC;W_@|sk{@00Rl={z3CubUPvD(x6)ujJA|0~x2Smb}LYvC;w zM*jB}!R$tMiT{cPwm%Lu;=jiK3jF)!|H%J3F~Al6Csf8U2nig?$p1Rg!F@JT;(w0% zFf_0U|8@R1^Z$|mbz~H!{%8D;01we@lm6@cuUP*BzbPE|e?#bErOwFzRz$K(872N} z{IBJ4pb`Hy{#W3iLoxr4>wn=bafScgtp72=BmWDpfO~GC#J}f%9`&2R{|fvw&f@xC zCzer-``M|ZP6=Yb8q?=mTZbeSI_C~c<49i{EtLaHky6g~`?=xPQYUhYjdqxU5kxbk za-#akV1=D!oXc7K4`=guzxDBd0$=a{BL;+|h*IEU%g6sZusYG@`!s@mVCJQsp6%EB z>!Sj~Sc76xThy17NrN$!h7)PL)i9nLBTZijNo4M*v3tv|2@8}dL6Pw9h!`>SKH6R_ z(}A9g_`lone-r^#5Jpe>6RF<0nN7&=WhlnEC@ClZS5dU5IsS+E}2eKKjQ``>1q4Bw)+20zq<( ztQ0B-KPP7fbBP$X6o?amEJD~X)#g+(>72?fAhqs@++FxTV~`>ONNhX*-xdpq{C`_C z6{Md__^+k|Esg_?_`f;-hu9qdoDV{!d^68Ti<5fhfD%_D^Kl?Q17| zjZ64%5%pEyoA8fFdH%n|KjYx{f8zQ7N`|4HCZnN?^gonzW4%!VQ36o{n=66;0cN;> AmjD0& literal 0 HcmV?d00001 diff --git a/tarball/testdata/archive/blobs/sha256/e3e972b5cbbd6aace0837af6e18be2daba89e5c3099ecddfb6cc3e7fc1bae145b72a499f29d0b191ec4668948ac0d8348b1138b628ba4946fbb8bd77828fe3bc b/tarball/testdata/archive/blobs/sha256/e3e972b5cbbd6aace0837af6e18be2daba89e5c3099ecddfb6cc3e7fc1bae145b72a499f29d0b191ec4668948ac0d8348b1138b628ba4946fbb8bd77828fe3bc new file mode 100644 index 0000000..34e9ef6 --- /dev/null +++ b/tarball/testdata/archive/blobs/sha256/e3e972b5cbbd6aace0837af6e18be2daba89e5c3099ecddfb6cc3e7fc1bae145b72a499f29d0b191ec4668948ac0d8348b1138b628ba4946fbb8bd77828fe3bc @@ -0,0 +1 @@ +other-blob-contents diff --git a/tarball/testdata/archive/blobs/sha512/38e52851c619572fa0c879f53354b168041c352e750c001165f3848ee9fc509f b/tarball/testdata/archive/blobs/sha512/38e52851c619572fa0c879f53354b168041c352e750c001165f3848ee9fc509f new file mode 100644 index 0000000..ec9277e --- /dev/null +++ b/tarball/testdata/archive/blobs/sha512/38e52851c619572fa0c879f53354b168041c352e750c001165f3848ee9fc509f @@ -0,0 +1 @@ +some-blob-contents