Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,6 @@
# Dependency directories (remove the comment below to include it)
# vendor/
.env

# For now ignore local VS Code settings
.vscode/
24 changes: 24 additions & 0 deletions .golangci.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
version: 2

# See https://golangci-lint.run/usage/configuration/ for all options

run:
timeout: 5m

linters:
disable-all: true
enable:
- govet
- errcheck
- staticcheck
- unused
- ineffassign
- misspell

formatters:
enable:
- gofmt
- goimports

issues:
exclude-use-default: false
77 changes: 77 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# Default target
default: build

# Build the application
.PHONY: build

build:
go build ./...

# Test target to run all tests in the project
.PHONY: test

test: build
go test ./... -v

# Lint target to run the linter
.PHONY: lint

lint:
golangci-lint run

# Run linter with fixes
.PHONY: lint-fix

lint-fix:
golangci-lint run --fix

# Check Go code formatting without modifying files
.PHONY: fmt-check

fmt-check:
@echo "Checking Go formatting..."
@if [ -n "$$(gofmt -l .)" ]; then \
echo "The following files need formatting:"; \
gofmt -l .; \
exit 1; \
fi
@if [ -n "$$(goimports -l .)" ]; then \
echo "The following files need import formatting:"; \
goimports -l .; \
exit 1; \
fi
@echo "All files are properly formatted!"

# Format Go code
.PHONY: fmt

fmt:
gofmt -w .
goimports -w .

# Install the application
.PHONY: install

install: build
go install

# Check for Go modules updates
.PHONY: update

update:
go get -u ./...
go mod tidy


# Ensure pre-commit hook is executable
.PHONY: config-pre-commit

config-pre-commit:
@echo "Setting up pre-commit hook..."
git config --local core.hooksPath .githooks/

# Pre-commit checks (format, lint, test)
.PHONY: pre-commit

pre-commit: fmt lint test
@echo "Pre-commit checks passed!"
10 changes: 8 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ Or pass in the required secrets directly:

```go
client, err := wx.NewClient(
wx.WithClientRetryConfig(wx.NewRetryConfig(
wx.WithReturnHTTPStatusAsErr(false)),
),
wx.WithWatsonxAPIKey(apiKey),
wx.WithWatsonxProjectID(projectID),
)
Expand All @@ -44,7 +47,7 @@ Generation:

```go
result, _ := client.GenerateText(
"meta-llama/llama-3-1-8b-instruct",
"meta-llama/llama-3-3-70b-instruct",
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: the initial models seem to no longer be available by default

"Hi, who are you?",
wx.WithTemperature(0.4),
wx.WithMaxNewTokens(512),
Expand All @@ -57,7 +60,7 @@ Stream Generation:

```go
dataChan, _ := client.GenerateTextStream(
"meta-llama/llama-3-1-8b-instruct",
"meta-llama/llama-3-3-70b-instruct",
"Hi, who are you?",
wx.WithTemperature(0.4),
wx.WithMaxNewTokens(512),
Expand Down Expand Up @@ -142,6 +145,9 @@ Specify the Watsonx URL and IAM endpoint through the parameters of the NewClient

```go
client, err := wx.NewClient(
wx.WithClientRetryConfig(wx.NewRetryConfig(
wx.WithReturnHTTPStatusAsErr(false)),
),
wx.WithURL("us-south.ml.test.cloud.ibm.com"),
wx.WithIAM("iam.test.cloud.ibm.com"),
wx.WithWatsonxAPIKey(apiKey),
Expand Down
3 changes: 2 additions & 1 deletion pkg/internal/tests/models/embedding_test.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
package test

import (
wx "github.com/IBM/watsonx-go/pkg/models"
"reflect"
"testing"

wx "github.com/IBM/watsonx-go/pkg/models"
)

const (
Expand Down
20 changes: 14 additions & 6 deletions pkg/internal/tests/models/generate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@ import (
wx "github.com/IBM/watsonx-go/pkg/models"
)

const (
modelLlama3 = "meta-llama/llama-3-3-70b-instruct"
modelFlanUL2 = "google/flan-ul2"
)

func TestClientCreationWithEnvVars(t *testing.T) {
_, err := wx.NewClient()

Expand All @@ -26,6 +31,9 @@ func TestClientCreationWithPassing(t *testing.T) {
}

_, err := wx.NewClient(
wx.WithClientRetryConfig(wx.NewRetryConfig(
wx.WithReturnHTTPStatusAsErr(false),
)),
wx.WithWatsonxAPIKey(apiKey),
wx.WithWatsonxProjectID(projectID),
)
Expand All @@ -51,7 +59,7 @@ func TestNilOptions(t *testing.T) {
client := getClient(t)

_, err := client.GenerateText(
"meta-llama/llama-3-70b-instruct",
modelLlama3,
"What day is it?",
nil,
)
Expand All @@ -64,7 +72,7 @@ func TestValidPrompt(t *testing.T) {
client := getClient(t)

_, err := client.GenerateText(
"meta-llama/llama-3-70b-instruct",
modelLlama3,
"Test prompt",
)
if err != nil {
Expand All @@ -76,7 +84,7 @@ func TestGenerateText(t *testing.T) {
client := getClient(t)

result, err := client.GenerateText(
"meta-llama/llama-3-70b-instruct",
modelLlama3,
"Hi, who are you?",
wx.WithTemperature(0.9),
wx.WithTopP(.5),
Expand All @@ -95,7 +103,7 @@ func TestGenerateTextStream(t *testing.T) {
client := getClient(t)

dataChan, err := client.GenerateTextStream(
"google/flan-ul2",
modelFlanUL2,
"Hi, who are you?",
wx.WithTemperature(0.9),
wx.WithTopP(.5),
Expand Down Expand Up @@ -126,7 +134,7 @@ func TestGenerateTextWithNoPrompt(t *testing.T) {
client := getClient(t)

dataChan, err := client.GenerateTextStream(
"google/flan-ul2",
modelFlanUL2,
"",
wx.WithTemperature(0.9),
wx.WithTopP(.5),
Expand Down Expand Up @@ -158,7 +166,7 @@ func TestGenerateTextWithNilOptions(t *testing.T) {
client := getClient(t)

result, err := client.GenerateText(
"meta-llama/llama-3-70b-instruct",
modelLlama3,
"Who are you?",
nil,
)
Expand Down
95 changes: 84 additions & 11 deletions pkg/internal/tests/models/retry_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,12 @@ func TestRetryWithSuccessOnFirstRequest(t *testing.T) {

resp, err := wx.Retry(
sendRequest,
wx.WithOnRetry(func(n uint, err error) {
retryCount = n
log.Printf("Retrying request after error: %v", err)
}),
wx.NewRetryConfig(
wx.WithOnRetry(func(n uint, err error) {
retryCount = n
log.Printf("Retrying request after error: %v", err)
}),
),
)

if err != nil {
Expand All @@ -60,8 +62,8 @@ func TestRetryWithSuccessOnFirstRequest(t *testing.T) {
}
}

// TestRetryWithNoSuccessStatusOnAnyRequest tests the retry mechanism with a server that always returns a 429 status code.
func TestRetryWithNoSuccessStatusOnAnyRequest(t *testing.T) {
// TestLegacyRetryWithNoSuccessStatusOnAnyRequest tests the retry mechanism with a server that always returns a 429 status code.
func TestLegacyRetryWithNoSuccessStatusOnAnyRequest(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusTooManyRequests)
}))
Expand All @@ -79,11 +81,13 @@ func TestRetryWithNoSuccessStatusOnAnyRequest(t *testing.T) {

resp, err := wx.Retry(
sendRequest,
wx.WithBackoff(backoffTime),
wx.WithOnRetry(func(n uint, err error) {
retryCount = n
log.Printf("Retrying request after error: %v", err)
}),
wx.NewRetryConfig(
wx.WithBackoff(backoffTime),
wx.WithOnRetry(func(n uint, err error) {
retryCount = n
log.Printf("Retrying request after error: %v", err)
}),
),
)

endTime := time.Now()
Expand All @@ -108,3 +112,72 @@ func TestRetryWithNoSuccessStatusOnAnyRequest(t *testing.T) {
t.Errorf("Expected minimum time of %v, but got %v", expectedMinimumTime, elapsedTime)
}
}

// TestRetryWithNoSuccessStatusOnAnyRequest tests the retry mechanism with a server that always returns a 429 status code.
func TestRetryWithNoSuccessStatusOnAnyRequest(t *testing.T) {
expectedStatusCode := http.StatusTooManyRequests

server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(expectedStatusCode)
}))
defer server.Close()

var backoffTime = 2 * time.Second
var retryCount uint = 0
var expectedRetries uint = 3

sendRequest := func() (*http.Response, error) {
return http.Get(server.URL + "/notfound")
}

startTime := time.Now()

resp, err := wx.Retry(
sendRequest,
wx.NewRetryConfig(
wx.WithReturnHTTPStatusAsErr(false), // Use new behavior: only return actual network errors
wx.WithBackoff(backoffTime),
wx.WithOnRetryV2(func(n uint, resp *http.Response, err error) {
retryCount = n
if err != nil {
t.Errorf("In OnRetry, expected nil, got error: %v", err)
}

if resp == nil {
t.Errorf("In OnRetry, expected non-nil response, got nil")
}

if resp != nil && resp.StatusCode != expectedStatusCode {
t.Errorf("Expected status code %d, got %d", expectedStatusCode, resp.StatusCode)
}

log.Printf("Retrying request after response with status code: %d", resp.StatusCode)
}),
),
)

endTime := time.Now()

elapsedTime := endTime.Sub(startTime)
expectedMinimumTime := backoffTime * time.Duration(expectedRetries)

if err != nil {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: The new behavior is that the error will be nil and the response will be non-nil.

This will give the user the opportunity to react to the response and use the response contents for troubleshooting purposes by logging/printing them to the end user.

t.Errorf("Expected nil, got error: %v", err)
}

if resp == nil {
t.Errorf("Expected non-nil response, got nil")
}

if resp != nil && resp.StatusCode != expectedStatusCode {
t.Errorf("Expected status code %d, got %d", expectedStatusCode, resp.StatusCode)
}

if retryCount != expectedRetries {
t.Errorf("Expected 3 retries, but got %d", retryCount)
}

if elapsedTime < expectedMinimumTime {
t.Errorf("Expected minimum time of %v, but got %v", expectedMinimumTime, elapsedTime)
}
}
6 changes: 5 additions & 1 deletion pkg/internal/tests/models/test_utils.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
package test

import (
wx "github.com/IBM/watsonx-go/pkg/models"
"os"
"testing"

wx "github.com/IBM/watsonx-go/pkg/models"
)

func getClient(t *testing.T) *wx.Client {
Expand All @@ -17,6 +18,9 @@ func getClient(t *testing.T) *wx.Client {
}

client, err := wx.NewClient(
wx.WithClientRetryConfig(wx.NewRetryConfig(
wx.WithReturnHTTPStatusAsErr(false)),
),
wx.WithWatsonxAPIKey(apiKey),
wx.WithWatsonxProjectID(projectID),
)
Expand Down
Loading