diff --git a/agent.go b/agent.go index aa11f7e6f49ee7c1a3a4db764b5bae6f822b854c..b7435ef28ee42e58aad28b91ce6eab5e78bdf00a 100644 --- a/agent.go +++ b/agent.go @@ -709,7 +709,20 @@ func (a *agent) executeTools(ctx context.Context, allTools []AgentTool, toolCall ClientMetadata: toolResult.Metadata, ProviderExecuted: false, } + } else if toolResult.Type == "image" || toolResult.Type == "media" { + result = ToolResultContent{ + ToolCallID: toolCall.ToolCallID, + ToolName: toolCall.ToolName, + Result: ToolResultOutputContentMedia{ + Data: string(toolResult.Data), + MediaType: toolResult.MediaType, + Text: toolResult.Content, + }, + ClientMetadata: toolResult.Metadata, + ProviderExecuted: false, + } } else { + // Default to text response result = ToolResultContent{ ToolCallID: toolCall.ToolCallID, ToolName: toolCall.ToolName, diff --git a/agent_test.go b/agent_test.go index 8074b519b167b6878fea712ffdd8ab4704444caf..33c137a1362c11865ec7353137d8febe515c4273 100644 --- a/agent_test.go +++ b/agent_test.go @@ -1535,3 +1535,236 @@ func TestToolCallRepair(t *testing.T) { require.Contains(t, toolCalls[0].ValidationError.Error(), "invalid JSON input") }) } + +// Test media and image tool responses +func TestAgent_MediaToolResponses(t *testing.T) { + t.Parallel() + + imageData := []byte{0x89, 0x50, 0x4E, 0x47} // PNG header bytes + audioData := []byte{0x52, 0x49, 0x46, 0x46} // RIFF header bytes + + t.Run("Image tool response", func(t *testing.T) { + t.Parallel() + + imageTool := &mockTool{ + name: "generate_image", + description: "Generates an image", + executeFunc: func(ctx context.Context, call ToolCall) (ToolResponse, error) { + return NewImageResponse(imageData, "image/png"), nil + }, + } + + model := &mockLanguageModel{ + generateFunc: func(ctx context.Context, call Call) (*Response, error) { + if len(call.Prompt) == 1 { + // First call - request image tool + return &Response{ + Content: []Content{ + ToolCallContent{ + ToolCallID: "img-1", + ToolName: "generate_image", + Input: `{}`, + }, + }, + Usage: Usage{TotalTokens: 10}, + FinishReason: FinishReasonToolCalls, + }, nil + } + // Second call - after tool execution + return &Response{ + Content: []Content{TextContent{Text: "Image generated"}}, + Usage: Usage{TotalTokens: 20}, + FinishReason: FinishReasonStop, + }, nil + }, + } + + agent := NewAgent(model, WithTools(imageTool), WithStopConditions(StepCountIs(3))) + + result, err := agent.Generate(context.Background(), AgentCall{ + Prompt: "Generate an image", + }) + + require.NoError(t, err) + require.NotNil(t, result) + require.Len(t, result.Steps, 2) // Tool call step + final response + + // Check tool results in first step + toolResults := result.Steps[0].Content.ToolResults() + require.Len(t, toolResults, 1) + + mediaResult, ok := toolResults[0].Result.(ToolResultOutputContentMedia) + require.True(t, ok, "Expected media result") + require.Equal(t, string(imageData), mediaResult.Data) + require.Equal(t, "image/png", mediaResult.MediaType) + }) + + t.Run("Media tool response (audio)", func(t *testing.T) { + t.Parallel() + + audioTool := &mockTool{ + name: "generate_audio", + description: "Generates audio", + executeFunc: func(ctx context.Context, call ToolCall) (ToolResponse, error) { + return NewMediaResponse(audioData, "audio/wav"), nil + }, + } + + model := &mockLanguageModel{ + generateFunc: func(ctx context.Context, call Call) (*Response, error) { + if len(call.Prompt) == 1 { + return &Response{ + Content: []Content{ + ToolCallContent{ + ToolCallID: "audio-1", + ToolName: "generate_audio", + Input: `{}`, + }, + }, + Usage: Usage{TotalTokens: 10}, + FinishReason: FinishReasonToolCalls, + }, nil + } + return &Response{ + Content: []Content{TextContent{Text: "Audio generated"}}, + Usage: Usage{TotalTokens: 20}, + FinishReason: FinishReasonStop, + }, nil + }, + } + + agent := NewAgent(model, WithTools(audioTool), WithStopConditions(StepCountIs(3))) + + result, err := agent.Generate(context.Background(), AgentCall{ + Prompt: "Generate audio", + }) + + require.NoError(t, err) + require.NotNil(t, result) + + toolResults := result.Steps[0].Content.ToolResults() + require.Len(t, toolResults, 1) + + mediaResult, ok := toolResults[0].Result.(ToolResultOutputContentMedia) + require.True(t, ok, "Expected media result") + require.Equal(t, string(audioData), mediaResult.Data) + require.Equal(t, "audio/wav", mediaResult.MediaType) + }) + + t.Run("Media response with text", func(t *testing.T) { + t.Parallel() + + imageTool := &mockTool{ + name: "screenshot", + description: "Takes a screenshot", + executeFunc: func(ctx context.Context, call ToolCall) (ToolResponse, error) { + resp := NewImageResponse(imageData, "image/png") + resp.Content = "Screenshot captured successfully" + return resp, nil + }, + } + + model := &mockLanguageModel{ + generateFunc: func(ctx context.Context, call Call) (*Response, error) { + if len(call.Prompt) == 1 { + return &Response{ + Content: []Content{ + ToolCallContent{ + ToolCallID: "screen-1", + ToolName: "screenshot", + Input: `{}`, + }, + }, + Usage: Usage{TotalTokens: 10}, + FinishReason: FinishReasonToolCalls, + }, nil + } + return &Response{ + Content: []Content{TextContent{Text: "Done"}}, + Usage: Usage{TotalTokens: 20}, + FinishReason: FinishReasonStop, + }, nil + }, + } + + agent := NewAgent(model, WithTools(imageTool), WithStopConditions(StepCountIs(3))) + + result, err := agent.Generate(context.Background(), AgentCall{ + Prompt: "Take a screenshot", + }) + + require.NoError(t, err) + require.NotNil(t, result) + + toolResults := result.Steps[0].Content.ToolResults() + require.Len(t, toolResults, 1) + + mediaResult, ok := toolResults[0].Result.(ToolResultOutputContentMedia) + require.True(t, ok, "Expected media result") + require.Equal(t, string(imageData), mediaResult.Data) + require.Equal(t, "image/png", mediaResult.MediaType) + require.Equal(t, "Screenshot captured successfully", mediaResult.Text) + }) + + t.Run("Media response preserves metadata", func(t *testing.T) { + t.Parallel() + + type ImageMetadata struct { + Width int `json:"width"` + Height int `json:"height"` + } + + imageTool := &mockTool{ + name: "generate_image", + description: "Generates an image", + executeFunc: func(ctx context.Context, call ToolCall) (ToolResponse, error) { + resp := NewImageResponse(imageData, "image/png") + return WithResponseMetadata(resp, ImageMetadata{Width: 800, Height: 600}), nil + }, + } + + model := &mockLanguageModel{ + generateFunc: func(ctx context.Context, call Call) (*Response, error) { + if len(call.Prompt) == 1 { + return &Response{ + Content: []Content{ + ToolCallContent{ + ToolCallID: "img-1", + ToolName: "generate_image", + Input: `{}`, + }, + }, + Usage: Usage{TotalTokens: 10}, + FinishReason: FinishReasonToolCalls, + }, nil + } + return &Response{ + Content: []Content{TextContent{Text: "Done"}}, + Usage: Usage{TotalTokens: 20}, + FinishReason: FinishReasonStop, + }, nil + }, + } + + agent := NewAgent(model, WithTools(imageTool), WithStopConditions(StepCountIs(3))) + + result, err := agent.Generate(context.Background(), AgentCall{ + Prompt: "Generate image", + }) + + require.NoError(t, err) + require.NotNil(t, result) + + toolResults := result.Steps[0].Content.ToolResults() + require.Len(t, toolResults, 1) + + // Check metadata was preserved + require.NotEmpty(t, toolResults[0].ClientMetadata) + + var metadata ImageMetadata + err = json.Unmarshal([]byte(toolResults[0].ClientMetadata), &metadata) + require.NoError(t, err) + require.Equal(t, 800, metadata.Width) + require.Equal(t, 600, metadata.Height) + }) +} diff --git a/content.go b/content.go index 93dc7d1ca87962a9351062def5fb616d022c3c36..5bd15e6bf4b599d056a7d82de445d3a00bbf0f2a 100644 --- a/content.go +++ b/content.go @@ -306,8 +306,9 @@ func (t ToolResultOutputContentError) GetType() ToolResultContentType { // ToolResultOutputContentMedia represents media output content of a tool result. type ToolResultOutputContentMedia struct { - Data string `json:"data"` // for media type (base64) - MediaType string `json:"media_type"` // for media type + Data string `json:"data"` // for media type (base64) + MediaType string `json:"media_type"` // for media type + Text string `json:"text,omitempty"` // optional text content accompanying the media } // GetType returns the type of the tool result output content media. diff --git a/providers/anthropic/anthropic.go b/providers/anthropic/anthropic.go index afc8dcc5dd2c92596eadf730105b298c48a1e325..39f1c9b603625451fed4cae1f2392f997618014b 100644 --- a/providers/anthropic/anthropic.go +++ b/providers/anthropic/anthropic.go @@ -591,11 +591,19 @@ func toPrompt(prompt fantasy.Prompt, sendReasoningData bool) ([]anthropic.TextBl if !ok { continue } - toolResultBlock.Content = []anthropic.ToolResultBlockParamContentUnion{ + contentBlocks := []anthropic.ToolResultBlockParamContentUnion{ { OfImage: anthropic.NewImageBlockBase64(content.MediaType, content.Data).OfImage, }, } + if content.Text != "" { + contentBlocks = append(contentBlocks, anthropic.ToolResultBlockParamContentUnion{ + OfText: &anthropic.TextBlockParam{ + Text: content.Text, + }, + }) + } + toolResultBlock.Content = contentBlocks case fantasy.ToolResultContentTypeError: content, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Output) if !ok { diff --git a/providers/openai/responses_language_model.go b/providers/openai/responses_language_model.go index 7a30139ac31080566fa43393512bcc3b8ba085b6..85a86513d73f56de56b20f9a968fb97ef1b1ab7a 100644 --- a/providers/openai/responses_language_model.go +++ b/providers/openai/responses_language_model.go @@ -554,29 +554,6 @@ func toResponsesPrompt(prompt fantasy.Prompt, systemMessageMode string) (respons continue } outputStr = output.Error.Error() - case fantasy.ToolResultContentTypeMedia: - output, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](toolResultPart.Output) - if !ok { - warnings = append(warnings, fantasy.CallWarning{ - Type: fantasy.CallWarningTypeOther, - Message: "tool result output does not have the right type", - }) - continue - } - // For media content, encode as JSON with data and media type - mediaContent := map[string]string{ - "data": output.Data, - "media_type": output.MediaType, - } - jsonBytes, err := json.Marshal(mediaContent) - if err != nil { - warnings = append(warnings, fantasy.CallWarning{ - Type: fantasy.CallWarningTypeOther, - Message: fmt.Sprintf("failed to marshal tool result: %v", err), - }) - continue - } - outputStr = string(jsonBytes) } input = append(input, responses.ResponseInputItemParamOfFunctionCallOutput(toolResultPart.ToolCallID, outputStr)) diff --git a/tool.go b/tool.go index e6823062419756d7313d3dce9fb19036775fb9ec..be146ef3fce25310be2a2a569ab4c2965c40eaa7 100644 --- a/tool.go +++ b/tool.go @@ -30,10 +30,14 @@ type ToolCall struct { // ToolResponse represents the response from a tool execution, matching the existing pattern. type ToolResponse struct { - Type string `json:"type"` - Content string `json:"content"` - Metadata string `json:"metadata,omitempty"` - IsError bool `json:"is_error"` + Type string `json:"type"` + Content string `json:"content"` + // Data contains binary data for image/media responses (e.g., image bytes, audio data). + Data []byte `json:"data,omitempty"` + // MediaType specifies the MIME type of the media (e.g., "image/png", "audio/wav"). + MediaType string `json:"media_type,omitempty"` + Metadata string `json:"metadata,omitempty"` + IsError bool `json:"is_error"` } // NewTextResponse creates a text response. @@ -53,6 +57,24 @@ func NewTextErrorResponse(content string) ToolResponse { } } +// NewImageResponse creates an image response with binary data. +func NewImageResponse(data []byte, mediaType string) ToolResponse { + return ToolResponse{ + Type: "image", + Data: data, + MediaType: mediaType, + } +} + +// NewMediaResponse creates a media response with binary data (e.g., audio, video). +func NewMediaResponse(data []byte, mediaType string) ToolResponse { + return ToolResponse{ + Type: "media", + Data: data, + MediaType: mediaType, + } +} + // WithResponseMetadata adds metadata to a response. func WithResponseMetadata(response ToolResponse, metadata any) ToolResponse { if metadata != nil { diff --git a/tool_test.go b/tool_test.go index 9f7db1d56905102d911fc69d7f62063bc0295c01..7567c7cd45a38504c42d4e27da6942f51ca5805d 100644 --- a/tool_test.go +++ b/tool_test.go @@ -83,3 +83,29 @@ func TestEnumToolExample(t *testing.T) { require.Contains(t, result.Content, "San Francisco") require.Contains(t, result.Content, "72°F") } + +func TestNewImageResponse(t *testing.T) { + imageData := []byte{0x89, 0x50, 0x4E, 0x47} // PNG header bytes + mediaType := "image/png" + + resp := NewImageResponse(imageData, mediaType) + + require.Equal(t, "image", resp.Type) + require.Equal(t, imageData, resp.Data) + require.Equal(t, mediaType, resp.MediaType) + require.False(t, resp.IsError) + require.Empty(t, resp.Content) +} + +func TestNewMediaResponse(t *testing.T) { + audioData := []byte{0x52, 0x49, 0x46, 0x46} // RIFF header bytes + mediaType := "audio/wav" + + resp := NewMediaResponse(audioData, mediaType) + + require.Equal(t, "media", resp.Type) + require.Equal(t, audioData, resp.Data) + require.Equal(t, mediaType, resp.MediaType) + require.False(t, resp.IsError) + require.Empty(t, resp.Content) +}