feat: add support for media tool response (#91)

Kujtim Hoxha created

Change summary

agent.go                                     |  13 +
agent_test.go                                | 233 ++++++++++++++++++++++
content.go                                   |   5 
providers/anthropic/anthropic.go             |  10 
providers/openai/responses_language_model.go |  23 --
tool.go                                      |  30 ++
tool_test.go                                 |  26 ++
7 files changed, 310 insertions(+), 30 deletions(-)

Detailed changes

agent.go 🔗

@@ -709,7 +709,20 @@ func (a *agent) executeTools(ctx context.Context, allTools []AgentTool, toolCall
 				ClientMetadata:   toolResult.Metadata,
 				ProviderExecuted: false,
 			}
+		} else if toolResult.Type == "image" || toolResult.Type == "media" {
+			result = ToolResultContent{
+				ToolCallID: toolCall.ToolCallID,
+				ToolName:   toolCall.ToolName,
+				Result: ToolResultOutputContentMedia{
+					Data:      string(toolResult.Data),
+					MediaType: toolResult.MediaType,
+					Text:      toolResult.Content,
+				},
+				ClientMetadata:   toolResult.Metadata,
+				ProviderExecuted: false,
+			}
 		} else {
+			// Default to text response
 			result = ToolResultContent{
 				ToolCallID: toolCall.ToolCallID,
 				ToolName:   toolCall.ToolName,

agent_test.go 🔗

@@ -1535,3 +1535,236 @@ func TestToolCallRepair(t *testing.T) {
 		require.Contains(t, toolCalls[0].ValidationError.Error(), "invalid JSON input")
 	})
 }
+
+// Test media and image tool responses
+func TestAgent_MediaToolResponses(t *testing.T) {
+	t.Parallel()
+
+	imageData := []byte{0x89, 0x50, 0x4E, 0x47} // PNG header bytes
+	audioData := []byte{0x52, 0x49, 0x46, 0x46} // RIFF header bytes
+
+	t.Run("Image tool response", func(t *testing.T) {
+		t.Parallel()
+
+		imageTool := &mockTool{
+			name:        "generate_image",
+			description: "Generates an image",
+			executeFunc: func(ctx context.Context, call ToolCall) (ToolResponse, error) {
+				return NewImageResponse(imageData, "image/png"), nil
+			},
+		}
+
+		model := &mockLanguageModel{
+			generateFunc: func(ctx context.Context, call Call) (*Response, error) {
+				if len(call.Prompt) == 1 {
+					// First call - request image tool
+					return &Response{
+						Content: []Content{
+							ToolCallContent{
+								ToolCallID: "img-1",
+								ToolName:   "generate_image",
+								Input:      `{}`,
+							},
+						},
+						Usage:        Usage{TotalTokens: 10},
+						FinishReason: FinishReasonToolCalls,
+					}, nil
+				}
+				// Second call - after tool execution
+				return &Response{
+					Content:      []Content{TextContent{Text: "Image generated"}},
+					Usage:        Usage{TotalTokens: 20},
+					FinishReason: FinishReasonStop,
+				}, nil
+			},
+		}
+
+		agent := NewAgent(model, WithTools(imageTool), WithStopConditions(StepCountIs(3)))
+
+		result, err := agent.Generate(context.Background(), AgentCall{
+			Prompt: "Generate an image",
+		})
+
+		require.NoError(t, err)
+		require.NotNil(t, result)
+		require.Len(t, result.Steps, 2) // Tool call step + final response
+
+		// Check tool results in first step
+		toolResults := result.Steps[0].Content.ToolResults()
+		require.Len(t, toolResults, 1)
+
+		mediaResult, ok := toolResults[0].Result.(ToolResultOutputContentMedia)
+		require.True(t, ok, "Expected media result")
+		require.Equal(t, string(imageData), mediaResult.Data)
+		require.Equal(t, "image/png", mediaResult.MediaType)
+	})
+
+	t.Run("Media tool response (audio)", func(t *testing.T) {
+		t.Parallel()
+
+		audioTool := &mockTool{
+			name:        "generate_audio",
+			description: "Generates audio",
+			executeFunc: func(ctx context.Context, call ToolCall) (ToolResponse, error) {
+				return NewMediaResponse(audioData, "audio/wav"), nil
+			},
+		}
+
+		model := &mockLanguageModel{
+			generateFunc: func(ctx context.Context, call Call) (*Response, error) {
+				if len(call.Prompt) == 1 {
+					return &Response{
+						Content: []Content{
+							ToolCallContent{
+								ToolCallID: "audio-1",
+								ToolName:   "generate_audio",
+								Input:      `{}`,
+							},
+						},
+						Usage:        Usage{TotalTokens: 10},
+						FinishReason: FinishReasonToolCalls,
+					}, nil
+				}
+				return &Response{
+					Content:      []Content{TextContent{Text: "Audio generated"}},
+					Usage:        Usage{TotalTokens: 20},
+					FinishReason: FinishReasonStop,
+				}, nil
+			},
+		}
+
+		agent := NewAgent(model, WithTools(audioTool), WithStopConditions(StepCountIs(3)))
+
+		result, err := agent.Generate(context.Background(), AgentCall{
+			Prompt: "Generate audio",
+		})
+
+		require.NoError(t, err)
+		require.NotNil(t, result)
+
+		toolResults := result.Steps[0].Content.ToolResults()
+		require.Len(t, toolResults, 1)
+
+		mediaResult, ok := toolResults[0].Result.(ToolResultOutputContentMedia)
+		require.True(t, ok, "Expected media result")
+		require.Equal(t, string(audioData), mediaResult.Data)
+		require.Equal(t, "audio/wav", mediaResult.MediaType)
+	})
+
+	t.Run("Media response with text", func(t *testing.T) {
+		t.Parallel()
+
+		imageTool := &mockTool{
+			name:        "screenshot",
+			description: "Takes a screenshot",
+			executeFunc: func(ctx context.Context, call ToolCall) (ToolResponse, error) {
+				resp := NewImageResponse(imageData, "image/png")
+				resp.Content = "Screenshot captured successfully"
+				return resp, nil
+			},
+		}
+
+		model := &mockLanguageModel{
+			generateFunc: func(ctx context.Context, call Call) (*Response, error) {
+				if len(call.Prompt) == 1 {
+					return &Response{
+						Content: []Content{
+							ToolCallContent{
+								ToolCallID: "screen-1",
+								ToolName:   "screenshot",
+								Input:      `{}`,
+							},
+						},
+						Usage:        Usage{TotalTokens: 10},
+						FinishReason: FinishReasonToolCalls,
+					}, nil
+				}
+				return &Response{
+					Content:      []Content{TextContent{Text: "Done"}},
+					Usage:        Usage{TotalTokens: 20},
+					FinishReason: FinishReasonStop,
+				}, nil
+			},
+		}
+
+		agent := NewAgent(model, WithTools(imageTool), WithStopConditions(StepCountIs(3)))
+
+		result, err := agent.Generate(context.Background(), AgentCall{
+			Prompt: "Take a screenshot",
+		})
+
+		require.NoError(t, err)
+		require.NotNil(t, result)
+
+		toolResults := result.Steps[0].Content.ToolResults()
+		require.Len(t, toolResults, 1)
+
+		mediaResult, ok := toolResults[0].Result.(ToolResultOutputContentMedia)
+		require.True(t, ok, "Expected media result")
+		require.Equal(t, string(imageData), mediaResult.Data)
+		require.Equal(t, "image/png", mediaResult.MediaType)
+		require.Equal(t, "Screenshot captured successfully", mediaResult.Text)
+	})
+
+	t.Run("Media response preserves metadata", func(t *testing.T) {
+		t.Parallel()
+
+		type ImageMetadata struct {
+			Width  int `json:"width"`
+			Height int `json:"height"`
+		}
+
+		imageTool := &mockTool{
+			name:        "generate_image",
+			description: "Generates an image",
+			executeFunc: func(ctx context.Context, call ToolCall) (ToolResponse, error) {
+				resp := NewImageResponse(imageData, "image/png")
+				return WithResponseMetadata(resp, ImageMetadata{Width: 800, Height: 600}), nil
+			},
+		}
+
+		model := &mockLanguageModel{
+			generateFunc: func(ctx context.Context, call Call) (*Response, error) {
+				if len(call.Prompt) == 1 {
+					return &Response{
+						Content: []Content{
+							ToolCallContent{
+								ToolCallID: "img-1",
+								ToolName:   "generate_image",
+								Input:      `{}`,
+							},
+						},
+						Usage:        Usage{TotalTokens: 10},
+						FinishReason: FinishReasonToolCalls,
+					}, nil
+				}
+				return &Response{
+					Content:      []Content{TextContent{Text: "Done"}},
+					Usage:        Usage{TotalTokens: 20},
+					FinishReason: FinishReasonStop,
+				}, nil
+			},
+		}
+
+		agent := NewAgent(model, WithTools(imageTool), WithStopConditions(StepCountIs(3)))
+
+		result, err := agent.Generate(context.Background(), AgentCall{
+			Prompt: "Generate image",
+		})
+
+		require.NoError(t, err)
+		require.NotNil(t, result)
+
+		toolResults := result.Steps[0].Content.ToolResults()
+		require.Len(t, toolResults, 1)
+
+		// Check metadata was preserved
+		require.NotEmpty(t, toolResults[0].ClientMetadata)
+
+		var metadata ImageMetadata
+		err = json.Unmarshal([]byte(toolResults[0].ClientMetadata), &metadata)
+		require.NoError(t, err)
+		require.Equal(t, 800, metadata.Width)
+		require.Equal(t, 600, metadata.Height)
+	})
+}

content.go 🔗

@@ -306,8 +306,9 @@ func (t ToolResultOutputContentError) GetType() ToolResultContentType {
 
 // ToolResultOutputContentMedia represents media output content of a tool result.
 type ToolResultOutputContentMedia struct {
-	Data      string `json:"data"`       // for media type (base64)
-	MediaType string `json:"media_type"` // for media type
+	Data      string `json:"data"`           // for media type (base64)
+	MediaType string `json:"media_type"`     // for media type
+	Text      string `json:"text,omitempty"` // optional text content accompanying the media
 }
 
 // GetType returns the type of the tool result output content media.

providers/anthropic/anthropic.go 🔗

@@ -591,11 +591,19 @@ func toPrompt(prompt fantasy.Prompt, sendReasoningData bool) ([]anthropic.TextBl
 							if !ok {
 								continue
 							}
-							toolResultBlock.Content = []anthropic.ToolResultBlockParamContentUnion{
+							contentBlocks := []anthropic.ToolResultBlockParamContentUnion{
 								{
 									OfImage: anthropic.NewImageBlockBase64(content.MediaType, content.Data).OfImage,
 								},
 							}
+							if content.Text != "" {
+								contentBlocks = append(contentBlocks, anthropic.ToolResultBlockParamContentUnion{
+									OfText: &anthropic.TextBlockParam{
+										Text: content.Text,
+									},
+								})
+							}
+							toolResultBlock.Content = contentBlocks
 						case fantasy.ToolResultContentTypeError:
 							content, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Output)
 							if !ok {

providers/openai/responses_language_model.go 🔗

@@ -554,29 +554,6 @@ func toResponsesPrompt(prompt fantasy.Prompt, systemMessageMode string) (respons
 						continue
 					}
 					outputStr = output.Error.Error()
-				case fantasy.ToolResultContentTypeMedia:
-					output, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](toolResultPart.Output)
-					if !ok {
-						warnings = append(warnings, fantasy.CallWarning{
-							Type:    fantasy.CallWarningTypeOther,
-							Message: "tool result output does not have the right type",
-						})
-						continue
-					}
-					// For media content, encode as JSON with data and media type
-					mediaContent := map[string]string{
-						"data":       output.Data,
-						"media_type": output.MediaType,
-					}
-					jsonBytes, err := json.Marshal(mediaContent)
-					if err != nil {
-						warnings = append(warnings, fantasy.CallWarning{
-							Type:    fantasy.CallWarningTypeOther,
-							Message: fmt.Sprintf("failed to marshal tool result: %v", err),
-						})
-						continue
-					}
-					outputStr = string(jsonBytes)
 				}
 
 				input = append(input, responses.ResponseInputItemParamOfFunctionCallOutput(toolResultPart.ToolCallID, outputStr))

tool.go 🔗

@@ -30,10 +30,14 @@ type ToolCall struct {
 
 // ToolResponse represents the response from a tool execution, matching the existing pattern.
 type ToolResponse struct {
-	Type     string `json:"type"`
-	Content  string `json:"content"`
-	Metadata string `json:"metadata,omitempty"`
-	IsError  bool   `json:"is_error"`
+	Type    string `json:"type"`
+	Content string `json:"content"`
+	// Data contains binary data for image/media responses (e.g., image bytes, audio data).
+	Data []byte `json:"data,omitempty"`
+	// MediaType specifies the MIME type of the media (e.g., "image/png", "audio/wav").
+	MediaType string `json:"media_type,omitempty"`
+	Metadata  string `json:"metadata,omitempty"`
+	IsError   bool   `json:"is_error"`
 }
 
 // NewTextResponse creates a text response.
@@ -53,6 +57,24 @@ func NewTextErrorResponse(content string) ToolResponse {
 	}
 }
 
+// NewImageResponse creates an image response with binary data.
+func NewImageResponse(data []byte, mediaType string) ToolResponse {
+	return ToolResponse{
+		Type:      "image",
+		Data:      data,
+		MediaType: mediaType,
+	}
+}
+
+// NewMediaResponse creates a media response with binary data (e.g., audio, video).
+func NewMediaResponse(data []byte, mediaType string) ToolResponse {
+	return ToolResponse{
+		Type:      "media",
+		Data:      data,
+		MediaType: mediaType,
+	}
+}
+
 // WithResponseMetadata adds metadata to a response.
 func WithResponseMetadata(response ToolResponse, metadata any) ToolResponse {
 	if metadata != nil {

tool_test.go 🔗

@@ -83,3 +83,29 @@ func TestEnumToolExample(t *testing.T) {
 	require.Contains(t, result.Content, "San Francisco")
 	require.Contains(t, result.Content, "72°F")
 }
+
+func TestNewImageResponse(t *testing.T) {
+	imageData := []byte{0x89, 0x50, 0x4E, 0x47} // PNG header bytes
+	mediaType := "image/png"
+
+	resp := NewImageResponse(imageData, mediaType)
+
+	require.Equal(t, "image", resp.Type)
+	require.Equal(t, imageData, resp.Data)
+	require.Equal(t, mediaType, resp.MediaType)
+	require.False(t, resp.IsError)
+	require.Empty(t, resp.Content)
+}
+
+func TestNewMediaResponse(t *testing.T) {
+	audioData := []byte{0x52, 0x49, 0x46, 0x46} // RIFF header bytes
+	mediaType := "audio/wav"
+
+	resp := NewMediaResponse(audioData, mediaType)
+
+	require.Equal(t, "media", resp.Type)
+	require.Equal(t, audioData, resp.Data)
+	require.Equal(t, mediaType, resp.MediaType)
+	require.False(t, resp.IsError)
+	require.Empty(t, resp.Content)
+}