From 833da5bd08f989073bb6583eaa5f735a0fe541a4 Mon Sep 17 00:00:00 2001 From: Arnab Chatterjee Date: Thu, 29 Jan 2026 17:02:44 +0530 Subject: [PATCH] feat: allow kit init to work with remote repositories (#1055) Signed-off-by: Arnab Chatterjee --- pkg/cmd/kitimport/hfimport.go | 8 +- pkg/cmd/kitinit/cmd.go | 260 +++++++++++++++++++++---- pkg/cmd/kitinit/cmd_test.go | 153 +++++++++++++++ pkg/lib/hf/list.go | 15 +- pkg/lib/hf/parse.go | 96 ++++++++++ pkg/lib/hf/parse_test.go | 351 ++++++++++++++++++++++++++++++++++ pkg/lib/hf/types.go | 25 +++ 7 files changed, 867 insertions(+), 41 deletions(-) create mode 100644 pkg/cmd/kitinit/cmd_test.go create mode 100644 pkg/lib/hf/parse.go create mode 100644 pkg/lib/hf/parse_test.go create mode 100644 pkg/lib/hf/types.go diff --git a/pkg/cmd/kitimport/hfimport.go b/pkg/cmd/kitimport/hfimport.go index daa80435..d0dcfed6 100644 --- a/pkg/cmd/kitimport/hfimport.go +++ b/pkg/cmd/kitimport/hfimport.go @@ -36,10 +36,10 @@ import ( ) func importUsingHF(ctx context.Context, opts *importOptions) error { - // Handle full HF URLs by extracting repository name from URL - repo, err := extractRepoFromURL(opts.repo) + // Parse HuggingFace repository URL to extract repo name and type + repo, repoType, err := hf.ParseHuggingFaceRepo(opts.repo) if err != nil { - return fmt.Errorf("could not process URL %s: %w", opts.repo, err) + return fmt.Errorf("could not process repository %s: %w", opts.repo, err) } tmpDir, cleanupTmp, err := cache.MkCacheDir("import", "") @@ -53,7 +53,7 @@ func importUsingHF(ctx context.Context, opts *importOptions) error { } }() - dirListing, err := hf.ListFiles(ctx, repo, opts.repoRef, opts.token) + dirListing, err := hf.ListFiles(ctx, repo, opts.repoRef, opts.token, repoType) if err != nil { return fmt.Errorf("failed to list files from HuggingFace API: %w", err) } diff --git a/pkg/cmd/kitinit/cmd.go b/pkg/cmd/kitinit/cmd.go index 781cdb77..94c14374 100644 --- a/pkg/cmd/kitinit/cmd.go +++ b/pkg/cmd/kitinit/cmd.go @@ -23,23 +23,32 @@ import ( "io/fs" "os" "path/filepath" + "strings" + "syscall" "github.com/kitops-ml/kitops/pkg/artifact" "github.com/kitops-ml/kitops/pkg/lib/constants" + "github.com/kitops-ml/kitops/pkg/lib/hf" kfgen "github.com/kitops-ml/kitops/pkg/lib/kitfile/generate" "github.com/kitops-ml/kitops/pkg/lib/util" "github.com/kitops-ml/kitops/pkg/output" "github.com/spf13/cobra" + "golang.org/x/term" ) const ( - shortDesc = `Generate a Kitfile for the contents of a directory` - longDesc = `Examine the contents of a directory and attempt to generate a basic Kitfile -based on common file formats. Any files whose type (i.e. model, dataset, etc.) -cannot be determined will be included in a code layer. + shortDesc = `Generate a Kitfile for the contents of a directory or remote repository` + longDesc = `Examine the contents of a directory or remote repository and attempt to generate +a basic Kitfile based on common file formats. Any files whose type (i.e. model, +dataset, etc.) cannot be determined will be included in a code layer. + +For local directories, the generated Kitfile is saved in the target directory. +For remote repositories (e.g. HuggingFace), the Kitfile is printed to stdout +or saved to a path specified with --output. + +By default the command will prompt for input for a name and description for the Kitfile.` -By default the command will prompt for input for a name and description for the Kitfile` example = `# Generate a Kitfile for the current directory: kit init . @@ -47,7 +56,19 @@ kit init . kit init ./my-model --name "mymodel" --desc "This is my model's description" # Generate a Kitfile, overwriting any existing Kitfile: -kit init ./my-model --force` +kit init ./my-model --force + +# Generate a Kitfile for a remote HuggingFace model: +kit init https://huggingface.co/myorg/mymodel + +# Generate a Kitfile for a HuggingFace dataset: +kit init huggingface.co/datasets/myorg/mydataset + +# Generate a Kitfile for a remote repository with a specific ref: +kit init myorg/mymodel --ref v1.0 + +# Save the generated Kitfile to a specific path: +kit init myorg/mymodel --output ./Kitfile` ) type initOptions struct { @@ -57,6 +78,14 @@ type initOptions struct { modelkitDescription string modelkitAuthor string overwrite bool + // Remote repository options + repoRef string + token string + outputPath string + // Computed fields + isRemote bool + repo string + repoType hf.RepositoryType } func InitCommand() *cobra.Command { @@ -75,6 +104,9 @@ func InitCommand() *cobra.Command { cmd.Flags().StringVar(&opts.modelkitDescription, "desc", "", "Description for the ModelKit") cmd.Flags().StringVar(&opts.modelkitAuthor, "author", "", "Author for the ModelKit") cmd.Flags().BoolVarP(&opts.overwrite, "force", "f", false, "Overwrite existing Kitfile if present") + cmd.Flags().StringVar(&opts.repoRef, "ref", "main", "Branch or tag to use for remote repositories") + cmd.Flags().StringVar(&opts.token, "token", "", "Token for authentication with remote repositories") + cmd.Flags().StringVarP(&opts.outputPath, "output", "o", "", "Output path for generated Kitfile (default: stdout for remotes, Kitfile in directory for local)") cmd.Flags().SortFlags = false return cmd } @@ -85,46 +117,126 @@ func runCommand(opts *initOptions) func(*cobra.Command, []string) error { return output.Fatalf("Invalid arguments: %s", err) } - var modelPackage *artifact.Package - if opts.modelkitName != "" || opts.modelkitDescription != "" { - modelPackage = &artifact.Package{ - Name: opts.modelkitName, - Description: opts.modelkitDescription, - } + if opts.isRemote { + return runRemoteInit(cmd.Context(), opts) } - if opts.modelkitAuthor != "" { - modelPackage.Authors = append(modelPackage.Authors, opts.modelkitAuthor) + return runLocalInit(opts) + } +} + +func runLocalInit(opts *initOptions) error { + var modelPackage *artifact.Package + if opts.modelkitName != "" || opts.modelkitDescription != "" { + modelPackage = &artifact.Package{ + Name: opts.modelkitName, + Description: opts.modelkitDescription, } + } + if opts.modelkitAuthor != "" { + if modelPackage == nil { + modelPackage = &artifact.Package{} + } + modelPackage.Authors = append(modelPackage.Authors, opts.modelkitAuthor) + } + + kitfilePath := opts.outputPath + if kitfilePath == "" { + kitfilePath = filepath.Join(opts.path, constants.DefaultKitfileName) + } + + if _, err := os.Stat(kitfilePath); err == nil { + if !opts.overwrite { + return output.Fatalf("Kitfile already exists at %s. Use '--force' to overwrite", kitfilePath) + } + } else if !errors.Is(err, fs.ErrNotExist) { + return output.Fatalf("Error checking for existing Kitfile: %s", err) + } + + dirContents, err := kfgen.DirectoryListingFromFS(opts.path) + if err != nil { + return output.Fatalf("Error processing directory: %s", err) + } + kitfile, err := kfgen.GenerateKitfile(dirContents, modelPackage) + if err != nil { + return output.Fatalf("Error generating Kitfile: %s", err) + } + bytes, err := kitfile.MarshalToYAML() + if err != nil { + return output.Fatalf("Error formatting Kitfile: %s", err) + } + if err := os.WriteFile(kitfilePath, bytes, 0644); err != nil { + return output.Fatalf("Failed to write Kitfile: %s", err) + } + output.Infof("Generated Kitfile:\n\n%s", string(bytes)) + output.Infof("Saved to path '%s'", kitfilePath) + return nil +} + +func runRemoteInit(ctx context.Context, opts *initOptions) error { + if opts.outputPath == "" { + output.SystemInfof("Fetching file listing from remote repository %s (ref: %s)", opts.repo, opts.repoRef) + } else { + output.Infof("Fetching file listing from remote repository %s (ref: %s)", opts.repo, opts.repoRef) + } - // Check for existing Kitfile - kitfilePath := filepath.Join(opts.path, constants.DefaultKitfileName) - if _, err := os.Stat(kitfilePath); err == nil { + dirContents, err := hf.ListFiles(ctx, opts.repo, opts.repoRef, opts.token, opts.repoType) + if err != nil { + return output.Fatalf("Error fetching remote repository: %s", err) + } + + modelPackage := buildPackageFromRepo(opts.repo, opts.modelkitName, opts.modelkitDescription, opts.modelkitAuthor) + + kitfile, err := kfgen.GenerateKitfile(dirContents, modelPackage) + if err != nil { + return output.Fatalf("Error generating Kitfile: %s", err) + } + bytes, err := kitfile.MarshalToYAML() + if err != nil { + return output.Fatalf("Error formatting Kitfile: %s", err) + } + + if opts.outputPath != "" { + if _, err := os.Stat(opts.outputPath); err == nil { if !opts.overwrite { - return output.Fatalf("Kitfile already exists at %s. Use '--force' to overwrite", kitfilePath) + return output.Fatalf("File already exists at %s. Use '--force' to overwrite", opts.outputPath) } } else if !errors.Is(err, fs.ErrNotExist) { - return output.Fatalf("Error checking for existing Kitfile: %s", err) + return output.Fatalf("Error checking for existing file: %s", err) } - dirContents, err := kfgen.DirectoryListingFromFS(opts.path) - if err != nil { - return output.Fatalf("Error processing directory: %s", err) - } - kitfile, err := kfgen.GenerateKitfile(dirContents, modelPackage) - if err != nil { - return output.Fatalf("Error generating Kitfile: %s", err) - } - bytes, err := kitfile.MarshalToYAML() - if err != nil { - return output.Fatalf("Error formatting Kitfile: %s", err) - } - if err := os.WriteFile(kitfilePath, bytes, 0644); err != nil { + if err := os.WriteFile(opts.outputPath, bytes, 0644); err != nil { return output.Fatalf("Failed to write Kitfile: %s", err) } output.Infof("Generated Kitfile:\n\n%s", string(bytes)) - output.Infof("Saved to path '%s'", kitfilePath) - return nil + output.Infof("Saved to path '%s'", opts.outputPath) + } else { + fmt.Print(string(bytes)) } + + return nil +} + +func buildPackageFromRepo(repo, name, description, author string) *artifact.Package { + sections := strings.Split(repo, "/") + modelPackage := &artifact.Package{} + + if name != "" { + modelPackage.Name = name + } else if len(sections) >= 2 { + modelPackage.Name = sections[len(sections)-1] + } + + if description != "" { + modelPackage.Description = description + } + + if author != "" { + modelPackage.Authors = append(modelPackage.Authors, author) + } else if len(sections) >= 2 { + modelPackage.Authors = append(modelPackage.Authors, sections[len(sections)-2]) + } + + return modelPackage } func (opts *initOptions) complete(ctx context.Context, args []string) error { @@ -134,7 +246,27 @@ func (opts *initOptions) complete(ctx context.Context, args []string) error { } opts.configHome = configHome opts.path = args[0] + if opts.path == "~" || strings.HasPrefix(opts.path, "~/") || strings.HasPrefix(opts.path, "~\\") { + home, err := os.UserHomeDir() + if err != nil { + return err + } + if opts.path == "~" { + opts.path = home + } else { + opts.path = filepath.Join(home, opts.path[2:]) + } + } + + opts.isRemote, opts.repo, opts.repoType = detectRemoteRepo(opts.path) + + if opts.isRemote { + return opts.completeRemote() + } + return opts.completeLocal() +} +func (opts *initOptions) completeLocal() error { if util.IsInteractiveSession() { if opts.modelkitName == "" { name, err := util.PromptForInput("Enter a name for the ModelKit: ", false) @@ -158,6 +290,66 @@ func (opts *initOptions) complete(ctx context.Context, args []string) error { opts.modelkitAuthor = author } } + return nil +} +func (opts *initOptions) completeRemote() error { + // For remote repos, only prompt if: + // 1. stdin is a terminal (interactive session) + // 2. stdout is a terminal (not redirected with >) + // 3. no output path specified (output goes to stdout) + stdoutIsTerminal := term.IsTerminal(int(syscall.Stdout)) + if util.IsInteractiveSession() && stdoutIsTerminal && opts.outputPath == "" { + if opts.modelkitDescription == "" { + desc, err := util.PromptForInput("Enter a short description for the ModelKit: ", false) + if err != nil { + return err + } + opts.modelkitDescription = desc + } + } return nil } + +func detectRemoteRepo(path string) (isRemote bool, repo string, repoType hf.RepositoryType) { + // First, check if this looks like an explicit local filesystem path + // Local paths include: ".", "..", paths starting with "./", "../", "/", or containing backslashes + if isLocalPath(path) { + return false, "", hf.RepoTypeUnknown + } + + // Check if the path exists on the filesystem + // This handles cases like "models/my-model" which could be either local or remote + if _, err := os.Stat(path); err == nil { + // Path exists locally - treat as local + return false, "", hf.RepoTypeUnknown + } + + // Path doesn't exist locally - try to parse as a HuggingFace URL/repo + if repo, repoType, err := hf.ParseHuggingFaceRepo(path); err == nil { + return true, repo, repoType + } + + return false, "", hf.RepoTypeUnknown +} + +func isLocalPath(path string) bool { + // Check for common local path patterns + if path == "." || path == ".." { + return true + } + if strings.HasPrefix(path, "./") || strings.HasPrefix(path, "../") { + return true + } + if path == "~" || strings.HasPrefix(path, "~/") || strings.HasPrefix(path, "~\\") { + return true + } + if strings.HasPrefix(path, "/") { + return true + } + // Windows-style paths + if strings.HasPrefix(path, "\\") || (len(path) >= 2 && path[1] == ':') { + return true + } + return false +} diff --git a/pkg/cmd/kitinit/cmd_test.go b/pkg/cmd/kitinit/cmd_test.go new file mode 100644 index 00000000..48828a75 --- /dev/null +++ b/pkg/cmd/kitinit/cmd_test.go @@ -0,0 +1,153 @@ +// Copyright 2024 The KitOps Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package kitinit + +import ( + "context" + "fmt" + "os" + "path/filepath" + "testing" + + "github.com/kitops-ml/kitops/pkg/lib/constants" + "github.com/kitops-ml/kitops/pkg/lib/hf" + "github.com/stretchr/testify/assert" +) + +func TestDetectRemoteRepo(t *testing.T) { + // Create a temporary local directory for testing filesystem detection + tmpDir := t.TempDir() + localNestedDir := filepath.Join(tmpDir, "models", "my-model") + err := os.MkdirAll(localNestedDir, 0755) + assert.NoError(t, err) + + testcases := []struct { + input string + expectedIsRemote bool + expectedRepo string + expectedType hf.RepositoryType + }{ + // HuggingFace URLs should be detected as remote + {input: "https://huggingface.co/org/repo", expectedIsRemote: true, expectedRepo: "org/repo", expectedType: hf.RepoTypeModel}, + {input: "https://huggingface.co/datasets/org/repo", expectedIsRemote: true, expectedRepo: "org/repo", expectedType: hf.RepoTypeDataset}, + {input: "huggingface.co/org/repo", expectedIsRemote: true, expectedRepo: "org/repo", expectedType: hf.RepoTypeModel}, + {input: "huggingface.co/datasets/org/repo", expectedIsRemote: true, expectedRepo: "org/repo", expectedType: hf.RepoTypeDataset}, + {input: "http://huggingface.co/org/repo", expectedIsRemote: true, expectedRepo: "org/repo", expectedType: hf.RepoTypeModel}, + + // org/repo pattern that doesn't exist locally should be treated as remote + {input: "nonexistent-org/nonexistent-repo", expectedIsRemote: true, expectedRepo: "nonexistent-org/nonexistent-repo", expectedType: hf.RepoTypeModel}, + {input: "datasets/org/repo", expectedIsRemote: true, expectedRepo: "org/repo", expectedType: hf.RepoTypeDataset}, + + // Explicit local path patterns should not be detected as remote + {input: ".", expectedIsRemote: false, expectedRepo: "", expectedType: hf.RepoTypeUnknown}, + {input: "./my-model", expectedIsRemote: false, expectedRepo: "", expectedType: hf.RepoTypeUnknown}, + {input: "/absolute/path/to/model", expectedIsRemote: false, expectedRepo: "", expectedType: hf.RepoTypeUnknown}, + {input: "../relative/path", expectedIsRemote: false, expectedRepo: "", expectedType: hf.RepoTypeUnknown}, + + // CRITICAL: Local nested directories that exist should be treated as local, not remote + {input: localNestedDir, expectedIsRemote: false, expectedRepo: "", expectedType: hf.RepoTypeUnknown}, + + // Non-HuggingFace URLs should not be detected as remote (for now) + {input: "https://github.com/org/repo", expectedIsRemote: false, expectedRepo: "", expectedType: hf.RepoTypeUnknown}, + {input: "https://example.com/org/repo", expectedIsRemote: false, expectedRepo: "", expectedType: hf.RepoTypeUnknown}, + + // Malicious URLs should not be detected as remote + {input: "https://huggingface.co.evil.com/org/repo", expectedIsRemote: false, expectedRepo: "", expectedType: hf.RepoTypeUnknown}, + } + + for _, tt := range testcases { + t.Run(fmt.Sprintf("handles %s", tt.input), func(t *testing.T) { + isRemote, repo, repoType := detectRemoteRepo(tt.input) + assert.Equal(t, tt.expectedIsRemote, isRemote, "isRemote mismatch") + assert.Equal(t, tt.expectedRepo, repo, "repo mismatch") + assert.Equal(t, tt.expectedType, repoType, "repoType mismatch") + }) + } +} + +func TestBuildPackageFromRepo(t *testing.T) { + testcases := []struct { + name string + repo string + inputName string + inputDesc string + inputAuthor string + expectedName string + expectedDesc string + expectedAuthors []string + }{ + { + name: "extracts name and author from repo", + repo: "myorg/mymodel", + expectedName: "mymodel", + expectedAuthors: []string{"myorg"}, + }, + { + name: "user-provided name overrides repo name", + repo: "myorg/mymodel", + inputName: "custom-name", + expectedName: "custom-name", + expectedAuthors: []string{"myorg"}, + }, + { + name: "user-provided author overrides repo org", + repo: "myorg/mymodel", + inputAuthor: "custom-author", + expectedName: "mymodel", + expectedAuthors: []string{"custom-author"}, + }, + { + name: "user-provided description is used", + repo: "myorg/mymodel", + inputDesc: "My model description", + expectedName: "mymodel", + expectedDesc: "My model description", + expectedAuthors: []string{"myorg"}, + }, + { + name: "all user-provided values override defaults", + repo: "myorg/mymodel", + inputName: "custom-name", + inputDesc: "Custom description", + inputAuthor: "custom-author", + expectedName: "custom-name", + expectedDesc: "Custom description", + expectedAuthors: []string{"custom-author"}, + }, + } + + for _, tt := range testcases { + t.Run(tt.name, func(t *testing.T) { + pkg := buildPackageFromRepo(tt.repo, tt.inputName, tt.inputDesc, tt.inputAuthor) + assert.Equal(t, tt.expectedName, pkg.Name) + assert.Equal(t, tt.expectedDesc, pkg.Description) + assert.Equal(t, tt.expectedAuthors, pkg.Authors) + }) + } +} + +func TestCompleteExpandsTilde(t *testing.T) { + home, err := os.UserHomeDir() + if err != nil { + t.Fatalf("failed to get home dir: %v", err) + } + opts := &initOptions{} + ctx := context.WithValue(context.Background(), constants.ConfigKey{}, "/tmp") + err = opts.complete(ctx, []string{"~/model"}) + assert.NoError(t, err) + assert.Equal(t, filepath.Join(home, "model"), opts.path) +} diff --git a/pkg/lib/hf/list.go b/pkg/lib/hf/list.go index 05a360df..e0b78e77 100644 --- a/pkg/lib/hf/list.go +++ b/pkg/lib/hf/list.go @@ -31,7 +31,8 @@ import ( ) const ( - treeURLFmt = "https://huggingface.co/api/models/%s/tree/%s" + treeURLFmt = "https://huggingface.co/api/models/%s/tree/%s" + datasetTreeURLFmt = "https://huggingface.co/api/datasets/%s/tree/%s" ) type hfTreeResponse []struct { @@ -45,11 +46,19 @@ type hfErrorResponse struct { Error string `json:"error"` } -func ListFiles(ctx context.Context, modelRepo, ref string, token string) (*kfgen.DirectoryListing, error) { +func ListFiles(ctx context.Context, repo, ref, token string, repoType RepositoryType) (*kfgen.DirectoryListing, error) { client := &http.Client{ Timeout: 10 * time.Second, } - baseURL, err := url.Parse(fmt.Sprintf(treeURLFmt, modelRepo, ref)) + + var urlFmt string + if repoType == RepoTypeDataset { + urlFmt = datasetTreeURLFmt + } else { + urlFmt = treeURLFmt + } + + baseURL, err := url.Parse(fmt.Sprintf(urlFmt, repo, ref)) if err != nil { return nil, fmt.Errorf("failed to parse URL: %w", err) } diff --git a/pkg/lib/hf/parse.go b/pkg/lib/hf/parse.go new file mode 100644 index 00000000..161bf74a --- /dev/null +++ b/pkg/lib/hf/parse.go @@ -0,0 +1,96 @@ +// Copyright 2024 The KitOps Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package hf + +import ( + "fmt" + "net/url" + "strings" +) + +// ParseHuggingFaceRepo parses a HuggingFace repository path or URL and returns +// the normalized repository name and type. +// +// Supported formats: +// - Full URL: https://huggingface.co/org/repo +// - Full URL (dataset): https://huggingface.co/datasets/org/repo +// - Short form: org/repo (defaults to model) +// - Dataset short form: datasets/org/repo +func ParseHuggingFaceRepo(path string) (repo string, repoType RepositoryType, err error) { + // Handle "datasets/" prefix (short form) + if strings.HasPrefix(path, "datasets/") { + repo = strings.TrimPrefix(path, "datasets/") + if !isValidRepoFormat(repo) { + return "", RepoTypeUnknown, fmt.Errorf("invalid dataset repository format: %s", path) + } + return repo, RepoTypeDataset, nil + } + + // Check if this looks like a URL (contains :// or starts with domain) + if strings.Contains(path, "://") || strings.HasPrefix(path, "huggingface.co/") { + // Normalize URL by adding scheme if missing + urlStr := path + if !strings.Contains(path, "://") { + urlStr = "https://" + path + } + + parsedURL, err := url.Parse(urlStr) + if err != nil { + return "", RepoTypeUnknown, fmt.Errorf("invalid URL format: %w", err) + } + + // Security: validate hostname is exactly huggingface.co + if parsedURL.Hostname() != "huggingface.co" { + return "", RepoTypeUnknown, fmt.Errorf("unsupported hostname: %s (only huggingface.co is supported)", parsedURL.Hostname()) + } + + // Parse path segments + pathSegments := strings.Split(strings.Trim(parsedURL.Path, "/"), "/") + + // Check for dataset URLs + if len(pathSegments) >= 1 && pathSegments[0] == "datasets" { + // Dataset URLs must have format: datasets/org/repo + if len(pathSegments) >= 3 { + repo = strings.Join(pathSegments[1:3], "/") + return repo, RepoTypeDataset, nil + } + // Invalid dataset URL (not enough segments) + return "", RepoTypeUnknown, fmt.Errorf("invalid dataset URL: expected datasets/org/repo, got '%s'", parsedURL.Path) + } + + // Model URLs: should be exactly 2 segments (org/repo) + if len(pathSegments) == 2 { + repo = strings.Join(pathSegments, "/") + return repo, RepoTypeModel, nil + } + + return "", RepoTypeUnknown, fmt.Errorf("unrecognized HuggingFace URL pattern: %s", path) + } + + // Handle short form (org/repo) + if isValidRepoFormat(path) { + return path, RepoTypeModel, nil + } + + return "", RepoTypeUnknown, fmt.Errorf("invalid repository format: %s", path) +} + +// isValidRepoFormat checks if a string is in "org/repo" format +func isValidRepoFormat(s string) bool { + parts := strings.Split(s, "/") + return len(parts) == 2 && parts[0] != "" && parts[1] != "" +} diff --git a/pkg/lib/hf/parse_test.go b/pkg/lib/hf/parse_test.go new file mode 100644 index 00000000..34f420ed --- /dev/null +++ b/pkg/lib/hf/parse_test.go @@ -0,0 +1,351 @@ +// Copyright 2024 The KitOps Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package hf + +import ( + "testing" +) + +func TestParseHuggingFaceRepo(t *testing.T) { + tests := []struct { + name string + input string + wantRepo string + wantRepoType RepositoryType + wantErr bool + }{ + // Valid model URLs + { + name: "full HTTPS model URL", + input: "https://huggingface.co/myorg/mymodel", + wantRepo: "myorg/mymodel", + wantRepoType: RepoTypeModel, + wantErr: false, + }, + { + name: "full HTTP model URL", + input: "http://huggingface.co/myorg/mymodel", + wantRepo: "myorg/mymodel", + wantRepoType: RepoTypeModel, + wantErr: false, + }, + { + name: "scheme-less model URL", + input: "huggingface.co/myorg/mymodel", + wantRepo: "myorg/mymodel", + wantRepoType: RepoTypeModel, + wantErr: false, + }, + { + name: "short form model", + input: "myorg/mymodel", + wantRepo: "myorg/mymodel", + wantRepoType: RepoTypeModel, + wantErr: false, + }, + + // Valid dataset URLs + { + name: "full HTTPS dataset URL", + input: "https://huggingface.co/datasets/myorg/mydataset", + wantRepo: "myorg/mydataset", + wantRepoType: RepoTypeDataset, + wantErr: false, + }, + { + name: "full HTTP dataset URL", + input: "http://huggingface.co/datasets/myorg/mydataset", + wantRepo: "myorg/mydataset", + wantRepoType: RepoTypeDataset, + wantErr: false, + }, + { + name: "scheme-less dataset URL", + input: "huggingface.co/datasets/myorg/mydataset", + wantRepo: "myorg/mydataset", + wantRepoType: RepoTypeDataset, + wantErr: false, + }, + { + name: "short form dataset", + input: "datasets/myorg/mydataset", + wantRepo: "myorg/mydataset", + wantRepoType: RepoTypeDataset, + wantErr: false, + }, + + // URLs with trailing slashes + { + name: "model URL with trailing slash", + input: "https://huggingface.co/myorg/mymodel/", + wantRepo: "myorg/mymodel", + wantRepoType: RepoTypeModel, + wantErr: false, + }, + { + name: "dataset URL with trailing slash", + input: "https://huggingface.co/datasets/myorg/mydataset/", + wantRepo: "myorg/mydataset", + wantRepoType: RepoTypeDataset, + wantErr: false, + }, + + // Security: Malicious URLs that should be rejected + { + name: "subdomain attack", + input: "https://huggingface.co.evil.com/myorg/mymodel", + wantRepo: "", + wantRepoType: RepoTypeUnknown, + wantErr: true, + }, + { + name: "prefix attack", + input: "https://evilhuggingface.co/myorg/mymodel", + wantRepo: "", + wantRepoType: RepoTypeUnknown, + wantErr: true, + }, + { + name: "suffix attack", + input: "https://huggingface.co.attacker.com/myorg/mymodel", + wantRepo: "", + wantRepoType: RepoTypeUnknown, + wantErr: true, + }, + + // Non-HuggingFace URLs + { + name: "GitHub URL", + input: "https://github.com/myorg/myrepo", + wantRepo: "", + wantRepoType: RepoTypeUnknown, + wantErr: true, + }, + { + name: "GitLab URL", + input: "https://gitlab.com/myorg/myrepo", + wantRepo: "", + wantRepoType: RepoTypeUnknown, + wantErr: true, + }, + { + name: "random domain", + input: "https://example.com/myorg/myrepo", + wantRepo: "", + wantRepoType: RepoTypeUnknown, + wantErr: true, + }, + + // Invalid formats + { + name: "single segment", + input: "myrepo", + wantRepo: "", + wantRepoType: RepoTypeUnknown, + wantErr: true, + }, + { + name: "too many segments", + input: "org/repo/extra", + wantRepo: "", + wantRepoType: RepoTypeUnknown, + wantErr: true, + }, + { + name: "empty string", + input: "", + wantRepo: "", + wantRepoType: RepoTypeUnknown, + wantErr: true, + }, + { + name: "only slashes", + input: "///", + wantRepo: "", + wantRepoType: RepoTypeUnknown, + wantErr: true, + }, + { + name: "incomplete dataset short form", + input: "datasets/myorg", + wantRepo: "", + wantRepoType: RepoTypeUnknown, + wantErr: true, + }, + { + name: "incomplete dataset URL", + input: "https://huggingface.co/datasets/myorg", + wantRepo: "", + wantRepoType: RepoTypeUnknown, + wantErr: true, + }, + + // Edge cases with special characters + { + name: "repo with hyphens", + input: "my-org/my-model", + wantRepo: "my-org/my-model", + wantRepoType: RepoTypeModel, + wantErr: false, + }, + { + name: "repo with underscores", + input: "my_org/my_model", + wantRepo: "my_org/my_model", + wantRepoType: RepoTypeModel, + wantErr: false, + }, + { + name: "repo with numbers", + input: "org123/model456", + wantRepo: "org123/model456", + wantRepoType: RepoTypeModel, + wantErr: false, + }, + { + name: "repo with dots", + input: "my.org/my.model", + wantRepo: "my.org/my.model", + wantRepoType: RepoTypeModel, + wantErr: false, + }, + + // Real-world examples + { + name: "GPT-2 model", + input: "https://huggingface.co/gpt2", + wantRepo: "", + wantRepoType: RepoTypeUnknown, + wantErr: true, // Single segment is invalid + }, + { + name: "BERT base model", + input: "https://huggingface.co/bert-base-uncased", + wantRepo: "", + wantRepoType: RepoTypeUnknown, + wantErr: true, // Single segment is invalid + }, + { + name: "OpenAI GPT-2", + input: "openai/gpt2", + wantRepo: "openai/gpt2", + wantRepoType: RepoTypeModel, + wantErr: false, + }, + { + name: "SQuAD dataset", + input: "datasets/squad", + wantRepo: "", + wantRepoType: RepoTypeUnknown, + wantErr: true, // Need org/repo format + }, + { + name: "SQuAD dataset full form", + input: "datasets/rajpurkar/squad", + wantRepo: "rajpurkar/squad", + wantRepoType: RepoTypeDataset, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotRepo, gotRepoType, err := ParseHuggingFaceRepo(tt.input) + + if (err != nil) != tt.wantErr { + t.Errorf("ParseHuggingFaceRepo() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if gotRepo != tt.wantRepo { + t.Errorf("ParseHuggingFaceRepo() repo = %v, want %v", gotRepo, tt.wantRepo) + } + + if gotRepoType != tt.wantRepoType { + t.Errorf("ParseHuggingFaceRepo() repoType = %v, want %v", gotRepoType, tt.wantRepoType) + } + }) + } +} + +func TestIsValidRepoFormat(t *testing.T) { + tests := []struct { + name string + input string + want bool + }{ + { + name: "valid org/repo", + input: "myorg/myrepo", + want: true, + }, + { + name: "valid with hyphens", + input: "my-org/my-repo", + want: true, + }, + { + name: "valid with underscores", + input: "my_org/my_repo", + want: true, + }, + { + name: "single segment", + input: "myrepo", + want: false, + }, + { + name: "three segments", + input: "org/repo/extra", + want: false, + }, + { + name: "empty string", + input: "", + want: false, + }, + { + name: "only slash", + input: "/", + want: false, + }, + { + name: "empty org", + input: "/myrepo", + want: false, + }, + { + name: "empty repo", + input: "myorg/", + want: false, + }, + { + name: "double slash", + input: "myorg//myrepo", + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := isValidRepoFormat(tt.input) + if got != tt.want { + t.Errorf("isValidRepoFormat() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/pkg/lib/hf/types.go b/pkg/lib/hf/types.go new file mode 100644 index 00000000..2ea7ade5 --- /dev/null +++ b/pkg/lib/hf/types.go @@ -0,0 +1,25 @@ +// Copyright 2024 The KitOps Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package hf + +type RepositoryType int + +const ( + RepoTypeUnknown RepositoryType = iota + RepoTypeModel + RepoTypeDataset +)