Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,6 @@
# Dependency directories (remove the comment below to include it)
# vendor/
.env

# For now ignore local VS Code settings
.vscode/
24 changes: 24 additions & 0 deletions .golangci.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
version: 2

# See https://golangci-lint.run/usage/configuration/ for all options

run:
timeout: 5m

linters:
disable-all: true
enable:
- govet
- errcheck
- staticcheck
- unused
- ineffassign
- misspell

formatters:
enable:
- gofmt
- goimports

issues:
exclude-use-default: false
77 changes: 77 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# Default target
default: build

# Build the application
.PHONY: build

build:
go build ./...

# Test target to run all tests in the project
.PHONY: test

test: build
go test ./... -v

# Lint target to run the linter
.PHONY: lint

lint:
golangci-lint run

# Run linter with fixes
.PHONY: lint-fix

lint-fix:
golangci-lint run --fix

# Check Go code formatting without modifying files
.PHONY: fmt-check

fmt-check:
@echo "Checking Go formatting..."
@if [ -n "$$(gofmt -l .)" ]; then \
echo "The following files need formatting:"; \
gofmt -l .; \
exit 1; \
fi
@if [ -n "$$(goimports -l .)" ]; then \
echo "The following files need import formatting:"; \
goimports -l .; \
exit 1; \
fi
@echo "All files are properly formatted!"

# Format Go code
.PHONY: fmt

fmt:
gofmt -w .
goimports -w .

# Install the application
.PHONY: install

install: build
go install

# Check for Go modules updates
.PHONY: update

update:
go get -u ./...
go mod tidy


# Ensure pre-commit hook is executable
.PHONY: config-pre-commit

config-pre-commit:
@echo "Setting up pre-commit hook..."
git config --local core.hooksPath .githooks/

# Pre-commit checks (format, lint, test)
.PHONY: pre-commit

pre-commit: fmt lint test
@echo "Pre-commit checks passed!"
10 changes: 8 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ Or pass in the required secrets directly:

```go
client, err := wx.NewClient(
wx.WithClientRetryConfig(wx.NewRetryConfig(
wx.WithReturnHTTPStatusAsErr(false)),
),
wx.WithWatsonxAPIKey(apiKey),
wx.WithWatsonxProjectID(projectID),
)
Expand All @@ -44,7 +47,7 @@ Generation:

```go
result, _ := client.GenerateText(
"meta-llama/llama-3-1-8b-instruct",
"meta-llama/llama-3-3-70b-instruct",
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: the initial models seem to no longer be available by default

"Hi, who are you?",
wx.WithTemperature(0.4),
wx.WithMaxNewTokens(512),
Expand All @@ -57,7 +60,7 @@ Stream Generation:

```go
dataChan, _ := client.GenerateTextStream(
"meta-llama/llama-3-1-8b-instruct",
"meta-llama/llama-3-3-70b-instruct",
"Hi, who are you?",
wx.WithTemperature(0.4),
wx.WithMaxNewTokens(512),
Expand Down Expand Up @@ -195,6 +198,9 @@ Specify the Watsonx URL and IAM endpoint through the parameters of the NewClient

```go
client, err := wx.NewClient(
wx.WithClientRetryConfig(wx.NewRetryConfig(
wx.WithReturnHTTPStatusAsErr(false)),
),
wx.WithURL("us-south.ml.test.cloud.ibm.com"),
wx.WithIAM("iam.test.cloud.ibm.com"),
wx.WithWatsonxAPIKey(apiKey),
Expand Down
36 changes: 32 additions & 4 deletions pkg/internal/tests/models/embedding_test.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
package test

import (
wx "github.com/IBM/watsonx-go/pkg/models"
"fmt"
"reflect"
"testing"

wx "github.com/IBM/watsonx-go/pkg/models"
)

const (
EmbeddingModelId = "ibm/slate-30m-english-rtrvr"
EmbeddingModelDimension = 384
EmbeddingModelId = "ibm/slate-30m-english-rtrvr"
modelLlama3DoesNotSupportEmbedding = "meta-llama/llama-3-3-70b-instruct"
EmbeddingModelDimension = 384
)

func TestEmbeddingSingleQuery(t *testing.T) {
Expand Down Expand Up @@ -78,7 +81,7 @@ func TestEmbeddingSingleQueryWithOptions(t *testing.T) {
}

if reflect.DeepEqual(response.Results[0].Embedding, responseNoOptions.Results[0].Embedding) {
t.Fatalf("Expected different embeddings with and without options, but got the same")
t.Fatal("Expected different embeddings with and without options, but got the same")
}
}

Expand Down Expand Up @@ -115,3 +118,28 @@ func TestEmbeddingMultipleQueries(t *testing.T) {
t.Fatalf("Expected model to be %s, but got %s", EmbeddingModelId, response.Model)
}
}

func TestEmbeddingModelDoesNotSupportEmbedding(t *testing.T) {
client := getClient(t)

text := "Hello, world!"

_, err := client.EmbedQuery(modelLlama3DoesNotSupportEmbedding, text)

if err == nil {
t.Fatal("Expected error for invalid model but got nil")
}

errorMessage := parseResponseErrMessage(t, err)
if errorMessage == nil {
t.Fatalf("Expected JSON error message, but got: %v", err)
}

if errorMessage.StatusCode != 400 {
t.Fatalf("Expected status code 400, but got: %d", errorMessage.StatusCode)
}
expectedErrText := fmt.Sprintf("Model '%s' does not support function 'function_embedding'", modelLlama3DoesNotSupportEmbedding)
if len(errorMessage.Errors) == 0 || errorMessage.Errors[0].Message != expectedErrText {
t.Fatalf("Expected error message to be %s, but got: %v", expectedErrText, errorMessage.Errors[0].Message)
}
}
Loading