@@ -1535,3 +1535,236 @@ func TestToolCallRepair(t *testing.T) {
require.Contains(t, toolCalls[0].ValidationError.Error(), "invalid JSON input")
})
}
+
+// Test media and image tool responses
+func TestAgent_MediaToolResponses(t *testing.T) {
+ t.Parallel()
+
+ imageData := []byte{0x89, 0x50, 0x4E, 0x47} // PNG header bytes
+ audioData := []byte{0x52, 0x49, 0x46, 0x46} // RIFF header bytes
+
+ t.Run("Image tool response", func(t *testing.T) {
+ t.Parallel()
+
+ imageTool := &mockTool{
+ name: "generate_image",
+ description: "Generates an image",
+ executeFunc: func(ctx context.Context, call ToolCall) (ToolResponse, error) {
+ return NewImageResponse(imageData, "image/png"), nil
+ },
+ }
+
+ model := &mockLanguageModel{
+ generateFunc: func(ctx context.Context, call Call) (*Response, error) {
+ if len(call.Prompt) == 1 {
+ // First call - request image tool
+ return &Response{
+ Content: []Content{
+ ToolCallContent{
+ ToolCallID: "img-1",
+ ToolName: "generate_image",
+ Input: `{}`,
+ },
+ },
+ Usage: Usage{TotalTokens: 10},
+ FinishReason: FinishReasonToolCalls,
+ }, nil
+ }
+ // Second call - after tool execution
+ return &Response{
+ Content: []Content{TextContent{Text: "Image generated"}},
+ Usage: Usage{TotalTokens: 20},
+ FinishReason: FinishReasonStop,
+ }, nil
+ },
+ }
+
+ agent := NewAgent(model, WithTools(imageTool), WithStopConditions(StepCountIs(3)))
+
+ result, err := agent.Generate(context.Background(), AgentCall{
+ Prompt: "Generate an image",
+ })
+
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ require.Len(t, result.Steps, 2) // Tool call step + final response
+
+ // Check tool results in first step
+ toolResults := result.Steps[0].Content.ToolResults()
+ require.Len(t, toolResults, 1)
+
+ mediaResult, ok := toolResults[0].Result.(ToolResultOutputContentMedia)
+ require.True(t, ok, "Expected media result")
+ require.Equal(t, string(imageData), mediaResult.Data)
+ require.Equal(t, "image/png", mediaResult.MediaType)
+ })
+
+ t.Run("Media tool response (audio)", func(t *testing.T) {
+ t.Parallel()
+
+ audioTool := &mockTool{
+ name: "generate_audio",
+ description: "Generates audio",
+ executeFunc: func(ctx context.Context, call ToolCall) (ToolResponse, error) {
+ return NewMediaResponse(audioData, "audio/wav"), nil
+ },
+ }
+
+ model := &mockLanguageModel{
+ generateFunc: func(ctx context.Context, call Call) (*Response, error) {
+ if len(call.Prompt) == 1 {
+ return &Response{
+ Content: []Content{
+ ToolCallContent{
+ ToolCallID: "audio-1",
+ ToolName: "generate_audio",
+ Input: `{}`,
+ },
+ },
+ Usage: Usage{TotalTokens: 10},
+ FinishReason: FinishReasonToolCalls,
+ }, nil
+ }
+ return &Response{
+ Content: []Content{TextContent{Text: "Audio generated"}},
+ Usage: Usage{TotalTokens: 20},
+ FinishReason: FinishReasonStop,
+ }, nil
+ },
+ }
+
+ agent := NewAgent(model, WithTools(audioTool), WithStopConditions(StepCountIs(3)))
+
+ result, err := agent.Generate(context.Background(), AgentCall{
+ Prompt: "Generate audio",
+ })
+
+ require.NoError(t, err)
+ require.NotNil(t, result)
+
+ toolResults := result.Steps[0].Content.ToolResults()
+ require.Len(t, toolResults, 1)
+
+ mediaResult, ok := toolResults[0].Result.(ToolResultOutputContentMedia)
+ require.True(t, ok, "Expected media result")
+ require.Equal(t, string(audioData), mediaResult.Data)
+ require.Equal(t, "audio/wav", mediaResult.MediaType)
+ })
+
+ t.Run("Media response with text", func(t *testing.T) {
+ t.Parallel()
+
+ imageTool := &mockTool{
+ name: "screenshot",
+ description: "Takes a screenshot",
+ executeFunc: func(ctx context.Context, call ToolCall) (ToolResponse, error) {
+ resp := NewImageResponse(imageData, "image/png")
+ resp.Content = "Screenshot captured successfully"
+ return resp, nil
+ },
+ }
+
+ model := &mockLanguageModel{
+ generateFunc: func(ctx context.Context, call Call) (*Response, error) {
+ if len(call.Prompt) == 1 {
+ return &Response{
+ Content: []Content{
+ ToolCallContent{
+ ToolCallID: "screen-1",
+ ToolName: "screenshot",
+ Input: `{}`,
+ },
+ },
+ Usage: Usage{TotalTokens: 10},
+ FinishReason: FinishReasonToolCalls,
+ }, nil
+ }
+ return &Response{
+ Content: []Content{TextContent{Text: "Done"}},
+ Usage: Usage{TotalTokens: 20},
+ FinishReason: FinishReasonStop,
+ }, nil
+ },
+ }
+
+ agent := NewAgent(model, WithTools(imageTool), WithStopConditions(StepCountIs(3)))
+
+ result, err := agent.Generate(context.Background(), AgentCall{
+ Prompt: "Take a screenshot",
+ })
+
+ require.NoError(t, err)
+ require.NotNil(t, result)
+
+ toolResults := result.Steps[0].Content.ToolResults()
+ require.Len(t, toolResults, 1)
+
+ mediaResult, ok := toolResults[0].Result.(ToolResultOutputContentMedia)
+ require.True(t, ok, "Expected media result")
+ require.Equal(t, string(imageData), mediaResult.Data)
+ require.Equal(t, "image/png", mediaResult.MediaType)
+ require.Equal(t, "Screenshot captured successfully", mediaResult.Text)
+ })
+
+ t.Run("Media response preserves metadata", func(t *testing.T) {
+ t.Parallel()
+
+ type ImageMetadata struct {
+ Width int `json:"width"`
+ Height int `json:"height"`
+ }
+
+ imageTool := &mockTool{
+ name: "generate_image",
+ description: "Generates an image",
+ executeFunc: func(ctx context.Context, call ToolCall) (ToolResponse, error) {
+ resp := NewImageResponse(imageData, "image/png")
+ return WithResponseMetadata(resp, ImageMetadata{Width: 800, Height: 600}), nil
+ },
+ }
+
+ model := &mockLanguageModel{
+ generateFunc: func(ctx context.Context, call Call) (*Response, error) {
+ if len(call.Prompt) == 1 {
+ return &Response{
+ Content: []Content{
+ ToolCallContent{
+ ToolCallID: "img-1",
+ ToolName: "generate_image",
+ Input: `{}`,
+ },
+ },
+ Usage: Usage{TotalTokens: 10},
+ FinishReason: FinishReasonToolCalls,
+ }, nil
+ }
+ return &Response{
+ Content: []Content{TextContent{Text: "Done"}},
+ Usage: Usage{TotalTokens: 20},
+ FinishReason: FinishReasonStop,
+ }, nil
+ },
+ }
+
+ agent := NewAgent(model, WithTools(imageTool), WithStopConditions(StepCountIs(3)))
+
+ result, err := agent.Generate(context.Background(), AgentCall{
+ Prompt: "Generate image",
+ })
+
+ require.NoError(t, err)
+ require.NotNil(t, result)
+
+ toolResults := result.Steps[0].Content.ToolResults()
+ require.Len(t, toolResults, 1)
+
+ // Check metadata was preserved
+ require.NotEmpty(t, toolResults[0].ClientMetadata)
+
+ var metadata ImageMetadata
+ err = json.Unmarshal([]byte(toolResults[0].ClientMetadata), &metadata)
+ require.NoError(t, err)
+ require.Equal(t, 800, metadata.Width)
+ require.Equal(t, 600, metadata.Height)
+ })
+}
@@ -83,3 +83,29 @@ func TestEnumToolExample(t *testing.T) {
require.Contains(t, result.Content, "San Francisco")
require.Contains(t, result.Content, "72°F")
}
+
+func TestNewImageResponse(t *testing.T) {
+ imageData := []byte{0x89, 0x50, 0x4E, 0x47} // PNG header bytes
+ mediaType := "image/png"
+
+ resp := NewImageResponse(imageData, mediaType)
+
+ require.Equal(t, "image", resp.Type)
+ require.Equal(t, imageData, resp.Data)
+ require.Equal(t, mediaType, resp.MediaType)
+ require.False(t, resp.IsError)
+ require.Empty(t, resp.Content)
+}
+
+func TestNewMediaResponse(t *testing.T) {
+ audioData := []byte{0x52, 0x49, 0x46, 0x46} // RIFF header bytes
+ mediaType := "audio/wav"
+
+ resp := NewMediaResponse(audioData, mediaType)
+
+ require.Equal(t, "media", resp.Type)
+ require.Equal(t, audioData, resp.Data)
+ require.Equal(t, mediaType, resp.MediaType)
+ require.False(t, resp.IsError)
+ require.Empty(t, resp.Content)
+}