-
Notifications
You must be signed in to change notification settings - Fork 666
feat(go/plugins/anthropic): support native structured output #4701
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -72,6 +72,8 @@ func Generate( | |
|
|
||
| req.Model = anthropic.Model(model) | ||
|
|
||
| isStructured := input.Output != nil && input.Output.Format == "json" && input.Output.Schema != nil && input.Output.Constrained | ||
|
|
||
| // no streaming | ||
| if cb == nil { | ||
| msg, err := client.Messages.New(ctx, *req) | ||
|
|
@@ -85,6 +87,9 @@ func Generate( | |
| } | ||
|
|
||
| r.Request = input | ||
| if isStructured { | ||
| handleStructuredOutput(r) | ||
| } | ||
| return r, nil | ||
| } else { | ||
| stream := client.Messages.NewStreaming(ctx, *req) | ||
|
|
@@ -101,6 +106,8 @@ func Generate( | |
| case anthropic.ContentBlockDeltaEvent: | ||
| if event.Delta.Type == "thinking_delta" { | ||
| content = append(content, ai.NewReasoningPart(event.Delta.Thinking, []byte(event.Delta.Signature))) | ||
| } else if isStructured && event.Delta.Type == "input_json_delta" { | ||
| content = append(content, ai.NewTextPart(event.Delta.PartialJSON)) | ||
| } else { | ||
| content = append(content, ai.NewTextPart(event.Delta.Text)) | ||
| } | ||
|
|
@@ -116,6 +123,9 @@ func Generate( | |
| return nil, err | ||
| } | ||
| r.Request = input | ||
| if isStructured { | ||
| handleStructuredOutput(r) | ||
| } | ||
| return r, nil | ||
| } | ||
| } | ||
|
|
@@ -127,6 +137,17 @@ func Generate( | |
| return nil, nil | ||
| } | ||
|
|
||
| func handleStructuredOutput(r *ai.ModelResponse) { | ||
| for i, part := range r.Message.Content { | ||
| if part.IsToolRequest() && part.ToolRequest.Name == "return_json_output" { | ||
| // Convert input to JSON | ||
| jsonBytes, _ := json.Marshal(part.ToolRequest.Input) | ||
| r.Message.Content[i] = ai.NewTextPart(string(jsonBytes)) | ||
| r.FinishReason = ai.FinishReasonStop | ||
| } | ||
| } | ||
| } | ||
|
|
||
| func toAnthropicRole(role ai.Role) (anthropic.MessageParamRole, error) { | ||
| switch role { | ||
| case ai.RoleUser: | ||
|
|
@@ -193,6 +214,26 @@ func toAnthropicRequest(i *ai.ModelRequest) (*anthropic.MessageNewParams, error) | |
| } | ||
| req.Tools = tools | ||
|
|
||
| if i.Output != nil && i.Output.Format == "json" && i.Output.Schema != nil && i.Output.Constrained { | ||
| schema, err := base.MapToStruct[anthropic.ToolInputSchemaParam](i.Output.Schema) | ||
| if err != nil { | ||
| return nil, fmt.Errorf("unable to parse output schema: %w", err) | ||
| } | ||
| req.Tools = append(req.Tools, anthropic.ToolUnionParam{ | ||
| OfTool: &anthropic.ToolParam{ | ||
| Name: "return_json_output", | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The string For example: const structuredOutputToolName = "return_json_output"Then use this constant here and in other places where this string appears. |
||
| Description: anthropic.String("Return the output in JSON format"), | ||
| InputSchema: schema, | ||
| }, | ||
| }) | ||
| req.ToolChoice = anthropic.ToolChoiceUnionParam{ | ||
| OfTool: &anthropic.ToolChoiceToolParam{ | ||
| Name: "return_json_output", | ||
| Type: "tool", | ||
| }, | ||
| } | ||
| } | ||
|
|
||
| return req, nil | ||
| } | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,6 +17,7 @@ | |
| package anthropic | ||
|
|
||
| import ( | ||
| "encoding/json" | ||
| "reflect" | ||
| "strings" | ||
| "testing" | ||
|
|
@@ -417,3 +418,63 @@ func TestToAnthropicRequest(t *testing.T) { | |
| }) | ||
| } | ||
| } | ||
|
|
||
| func TestToAnthropicRequest_StructuredOutput(t *testing.T) { | ||
| schema := map[string]any{ | ||
| "type": "object", | ||
| "properties": map[string]any{ | ||
| "answer": map[string]any{"type": "string"}, | ||
| }, | ||
| "required": []string{"answer"}, | ||
| } | ||
|
|
||
| req := &ai.ModelRequest{ | ||
| Messages: []*ai.Message{ | ||
| { | ||
| Role: ai.RoleUser, | ||
| Content: []*ai.Part{ai.NewTextPart("hello")}, | ||
| }, | ||
| }, | ||
| Config: map[string]any{ | ||
| "max_tokens": 100, | ||
| }, | ||
| Output: &ai.ModelOutputConfig{ | ||
| Format: "json", | ||
| Schema: schema, | ||
| }, | ||
| } | ||
|
|
||
| got, err := toAnthropicRequest(req) | ||
| if err != nil { | ||
| t.Fatalf("unexpected error: %v", err) | ||
| } | ||
|
|
||
| // We expect a tool named "return_json_output" (or similar) with the schema | ||
| var foundTool *anthropic.ToolParam | ||
| for _, toolUnion := range got.Tools { | ||
| if toolUnion.OfTool != nil && toolUnion.OfTool.Name == "return_json_output" { | ||
| foundTool = toolUnion.OfTool | ||
| break | ||
| } | ||
| } | ||
|
|
||
| if foundTool == nil { | ||
| t.Errorf("expected tool 'return_json_output' not found in tools: %+v", got.Tools) | ||
| } else { | ||
| // Verify schema | ||
| inputSchemaBytes, _ := json.Marshal(foundTool.InputSchema) | ||
| expectedSchemaBytes, _ := json.Marshal(schema) | ||
| if len(inputSchemaBytes) == 0 { | ||
| t.Errorf("tool input schema is empty") | ||
| } | ||
| t.Logf("Schema found: %s", string(inputSchemaBytes)) | ||
| t.Logf("Expected: %s", string(expectedSchemaBytes)) | ||
| } | ||
|
Comment on lines
+464
to
+472
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The schema verification in this test is incomplete. It logs the schemas but doesn't actually assert that they are equal. Additionally, errors from A more robust test would be to unmarshal the schemas into comparable maps and then use // Verify schema
inputSchemaBytes, err := json.Marshal(foundTool.InputSchema)
if err != nil {
t.Fatalf("failed to marshal found tool schema: %v", err)
}
expectedSchemaBytes, err := json.Marshal(schema)
if err != nil {
t.Fatalf("failed to marshal expected schema: %v", err)
}
var inputSchemaMap, expectedSchemaMap map[string]any
if err := json.Unmarshal(inputSchemaBytes, &inputSchemaMap); err != nil {
t.Fatalf("failed to unmarshal input schema: %v", err)
}
if err := json.Unmarshal(expectedSchemaBytes, &expectedSchemaMap); err != nil {
t.Fatalf("failed to unmarshal expected schema: %v", err)
}
if !reflect.DeepEqual(inputSchemaMap, expectedSchemaMap) {
t.Errorf("schema mismatch:\ngot: %s\nwant: %s", string(inputSchemaBytes), string(expectedSchemaBytes))
}
} |
||
|
|
||
| // We expect ToolChoice to be set to force this tool | ||
| if got.ToolChoice.OfTool == nil { | ||
| t.Errorf("expected ToolChoice to be set to specific tool, got nil or auto") | ||
| } else if got.ToolChoice.OfTool.Name != "return_json_output" { | ||
| t.Errorf("expected ToolChoice name to be 'return_json_output', got %q", got.ToolChoice.OfTool.Name) | ||
| } | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The error returned by
json.Marshalon line 144 is being ignored. This could lead to silent failures where an invalid tool input results in an empty text part without any indication of an error. The function should be modified to return an error, which should then be handled in theGeneratefunction.For example, you could update the call sites in
Generatelike this: