Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 59 additions & 10 deletions go/plugins/anthropic/anthropic_live_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -358,8 +358,65 @@ func TestAnthropicLive(t *testing.T) {
t.Fatalf("empty usage stats: %#v", *final.Usage)
}
})
t.Run("constrained generation", func(t *testing.T) {
m := anthropicPlugin.Model(g, "claude-sonnet-4-5-20250929")
type outFormat struct {
Country string `json:"country"`
}
resp, err := genkit.Generate(ctx, g,
ai.WithModel(m),
ai.WithConfig(&anthropic.MessageNewParams{MaxTokens: 1024}),
ai.WithPrompt("Which country was Napoleon the emperor of?"),
ai.WithOutputType(outFormat{}),
)
if err != nil {
t.Fatal(err)
}

var ans outFormat
if err := resp.Output(&ans); err != nil {
t.Fatal(err)
}
const want = "France"
if ans.Country != want {
t.Errorf("got %q, expecting %q", ans.Country, want)
}
})
t.Run("streaming constrained generation", func(t *testing.T) {
m := anthropicPlugin.Model(g, "claude-sonnet-4-5-20250929")
type outFormat struct {
Country string `json:"country"`
}

var streamedContent strings.Builder
resp, err := genkit.Generate(ctx, g,
ai.WithModel(m),
ai.WithConfig(&anthropic.MessageNewParams{MaxTokens: 1024}),
ai.WithPrompt("Which country was Napoleon the emperor of?"),
ai.WithOutputType(outFormat{}),
ai.WithStreaming(func(ctx context.Context, chunk *ai.ModelResponseChunk) error {
streamedContent.WriteString(chunk.Text())
return nil
}),
)
if err != nil {
t.Fatal(err)
}

var ans outFormat
if err := resp.Output(&ans); err != nil {
t.Fatal(err)
}
const want = "France"
if ans.Country != want {
t.Errorf("got %q, expecting %q", ans.Country, want)
}

if streamedContent.Len() == 0 {
t.Error("expected streamed content, got empty")
}
})
t.Run("tools streaming with constrained gen", func(t *testing.T) {
t.Skip("skipped until issue #3851 gets resolved")
m := anthropicPlugin.Model(g, "claude-sonnet-4-5-20250929")
answerOfEverythingTool := genkit.DefineTool(
g,
Expand All @@ -378,12 +435,7 @@ func TestAnthropicLive(t *testing.T) {
ai.WithModel(m),
ai.WithConfig(&anthropic.MessageNewParams{
Temperature: anthropic.Float(1),
Thinking: anthropic.ThinkingConfigParamUnion{
OfEnabled: &anthropic.ThinkingConfigEnabledParam{
BudgetTokens: 1024,
},
},
MaxTokens: 2048,
MaxTokens: 2048,
}),
ai.WithOutputType(Output{}),
ai.WithStreaming(func(ctx context.Context, c *ai.ModelResponseChunk) error {
Expand All @@ -394,9 +446,6 @@ func TestAnthropicLive(t *testing.T) {
if err != nil {
t.Fatal(err)
}
if resp.Reasoning() == "" {
t.Fatal("empty reasoning found")
}

var out Output
err = resp.Output(&out)
Expand Down
2 changes: 1 addition & 1 deletion go/plugins/anthropic/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import (
)

var defaultClaudeOpts = ai.ModelOptions{
Supports: &internal.MultimodalNoConstrained,
Supports: &internal.Multimodal,
Versions: []string{},
Stage: ai.ModelStageStable,
}
Expand Down
41 changes: 41 additions & 0 deletions go/plugins/internal/anthropic/anthropic.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ func Generate(

req.Model = anthropic.Model(model)

isStructured := input.Output != nil && input.Output.Format == "json" && input.Output.Schema != nil && input.Output.Constrained

// no streaming
if cb == nil {
msg, err := client.Messages.New(ctx, *req)
Expand All @@ -85,6 +87,9 @@ func Generate(
}

r.Request = input
if isStructured {
handleStructuredOutput(r)
}
return r, nil
} else {
stream := client.Messages.NewStreaming(ctx, *req)
Expand All @@ -101,6 +106,8 @@ func Generate(
case anthropic.ContentBlockDeltaEvent:
if event.Delta.Type == "thinking_delta" {
content = append(content, ai.NewReasoningPart(event.Delta.Thinking, []byte(event.Delta.Signature)))
} else if isStructured && event.Delta.Type == "input_json_delta" {
content = append(content, ai.NewTextPart(event.Delta.PartialJSON))
} else {
content = append(content, ai.NewTextPart(event.Delta.Text))
}
Expand All @@ -116,6 +123,9 @@ func Generate(
return nil, err
}
r.Request = input
if isStructured {
handleStructuredOutput(r)
}
return r, nil
}
}
Expand All @@ -127,6 +137,17 @@ func Generate(
return nil, nil
}

func handleStructuredOutput(r *ai.ModelResponse) {
for i, part := range r.Message.Content {
if part.IsToolRequest() && part.ToolRequest.Name == "return_json_output" {
// Convert input to JSON
jsonBytes, _ := json.Marshal(part.ToolRequest.Input)
r.Message.Content[i] = ai.NewTextPart(string(jsonBytes))
r.FinishReason = ai.FinishReasonStop
}
}
}
Comment on lines +140 to +149
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The error returned by json.Marshal on line 144 is being ignored. This could lead to silent failures where an invalid tool input results in an empty text part without any indication of an error. The function should be modified to return an error, which should then be handled in the Generate function.

For example, you could update the call sites in Generate like this:

// In Generate function (non-streaming):
if isStructured {
    if err := handleStructuredOutput(r); err != nil {
        return nil, err
    }
}
return r, nil
func handleStructuredOutput(r *ai.ModelResponse) error {
	for i, part := range r.Message.Content {
		if part.IsToolRequest() && part.ToolRequest.Name == "return_json_output" {
			// Convert input to JSON
			jsonBytes, err := json.Marshal(part.ToolRequest.Input)
			if err != nil {
				return fmt.Errorf("failed to marshal structured output: %w", err)
			}
			r.Message.Content[i] = ai.NewTextPart(string(jsonBytes))
			r.FinishReason = ai.FinishReasonStop
		}
	}
	return nil
}


func toAnthropicRole(role ai.Role) (anthropic.MessageParamRole, error) {
switch role {
case ai.RoleUser:
Expand Down Expand Up @@ -193,6 +214,26 @@ func toAnthropicRequest(i *ai.ModelRequest) (*anthropic.MessageNewParams, error)
}
req.Tools = tools

if i.Output != nil && i.Output.Format == "json" && i.Output.Schema != nil && i.Output.Constrained {
schema, err := base.MapToStruct[anthropic.ToolInputSchemaParam](i.Output.Schema)
if err != nil {
return nil, fmt.Errorf("unable to parse output schema: %w", err)
}
req.Tools = append(req.Tools, anthropic.ToolUnionParam{
OfTool: &anthropic.ToolParam{
Name: "return_json_output",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The string "return_json_output" is used in multiple places in this file (lines 142, 231) and in the test file. It would be better to define it as a constant to avoid typos and improve maintainability.

For example:

const structuredOutputToolName = "return_json_output"

Then use this constant here and in other places where this string appears.

Description: anthropic.String("Return the output in JSON format"),
InputSchema: schema,
},
})
req.ToolChoice = anthropic.ToolChoiceUnionParam{
OfTool: &anthropic.ToolChoiceToolParam{
Name: "return_json_output",
Type: "tool",
},
}
}

return req, nil
}

Expand Down
61 changes: 61 additions & 0 deletions go/plugins/internal/anthropic/anthropic_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package anthropic

import (
"encoding/json"
"reflect"
"strings"
"testing"
Expand Down Expand Up @@ -417,3 +418,63 @@ func TestToAnthropicRequest(t *testing.T) {
})
}
}

func TestToAnthropicRequest_StructuredOutput(t *testing.T) {
schema := map[string]any{
"type": "object",
"properties": map[string]any{
"answer": map[string]any{"type": "string"},
},
"required": []string{"answer"},
}

req := &ai.ModelRequest{
Messages: []*ai.Message{
{
Role: ai.RoleUser,
Content: []*ai.Part{ai.NewTextPart("hello")},
},
},
Config: map[string]any{
"max_tokens": 100,
},
Output: &ai.ModelOutputConfig{
Format: "json",
Schema: schema,
},
}

got, err := toAnthropicRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}

// We expect a tool named "return_json_output" (or similar) with the schema
var foundTool *anthropic.ToolParam
for _, toolUnion := range got.Tools {
if toolUnion.OfTool != nil && toolUnion.OfTool.Name == "return_json_output" {
foundTool = toolUnion.OfTool
break
}
}

if foundTool == nil {
t.Errorf("expected tool 'return_json_output' not found in tools: %+v", got.Tools)
} else {
// Verify schema
inputSchemaBytes, _ := json.Marshal(foundTool.InputSchema)
expectedSchemaBytes, _ := json.Marshal(schema)
if len(inputSchemaBytes) == 0 {
t.Errorf("tool input schema is empty")
}
t.Logf("Schema found: %s", string(inputSchemaBytes))
t.Logf("Expected: %s", string(expectedSchemaBytes))
}
Comment on lines +464 to +472
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The schema verification in this test is incomplete. It logs the schemas but doesn't actually assert that they are equal. Additionally, errors from json.Marshal are ignored.

A more robust test would be to unmarshal the schemas into comparable maps and then use reflect.DeepEqual for comparison.

		// Verify schema
		inputSchemaBytes, err := json.Marshal(foundTool.InputSchema)
		if err != nil {
			t.Fatalf("failed to marshal found tool schema: %v", err)
		}
		expectedSchemaBytes, err := json.Marshal(schema)
		if err != nil {
			t.Fatalf("failed to marshal expected schema: %v", err)
		}

		var inputSchemaMap, expectedSchemaMap map[string]any
		if err := json.Unmarshal(inputSchemaBytes, &inputSchemaMap); err != nil {
			t.Fatalf("failed to unmarshal input schema: %v", err)
		}
		if err := json.Unmarshal(expectedSchemaBytes, &expectedSchemaMap); err != nil {
			t.Fatalf("failed to unmarshal expected schema: %v", err)
		}

		if !reflect.DeepEqual(inputSchemaMap, expectedSchemaMap) {
			t.Errorf("schema mismatch:\ngot: %s\nwant: %s", string(inputSchemaBytes), string(expectedSchemaBytes))
		}
	}


// We expect ToolChoice to be set to force this tool
if got.ToolChoice.OfTool == nil {
t.Errorf("expected ToolChoice to be set to specific tool, got nil or auto")
} else if got.ToolChoice.OfTool.Name != "return_json_output" {
t.Errorf("expected ToolChoice name to be 'return_json_output', got %q", got.ToolChoice.OfTool.Name)
}
}
Loading