diff --git a/internal/agent/agent.go b/internal/agent/agent.go index 90eabe638f5fe167a5cdedc9e9cf1e5b02a9125a..6e659bbfc4cfc9f697012f4a289f3b345905a696 100644 --- a/internal/agent/agent.go +++ b/internal/agent/agent.go @@ -245,7 +245,7 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy defer cancel() defer a.activeRequests.Del(call.SessionID) - history, files := a.preparePrompt(msgs, call.Attachments...) + history, files := a.preparePrompt(msgs, largeModel.CatwalkCfg.SupportsImages, call.Attachments...) startTime := time.Now() a.eventPromptSent(call.SessionID) @@ -643,7 +643,7 @@ func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fan return nil } - aiMsgs, _ := a.preparePrompt(msgs) + aiMsgs, _ := a.preparePrompt(msgs, largeModel.CatwalkCfg.SupportsImages) genCtx, cancel := context.WithCancel(ctx) a.activeRequests.Set(sessionID, cancel) @@ -786,7 +786,7 @@ func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentC return msg, nil } -func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) { +func (a *sessionAgent) preparePrompt(msgs []message.Message, supportsImages bool, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) { var history []fantasy.Message if !a.isSubAgent { history = append(history, fantasy.NewUserMessage( @@ -830,7 +830,15 @@ If not, please feel free to ignore. Again do not mention this message to the use } continue } - history = append(history, m.ToAIMessage()...) + aiMsgs := m.ToAIMessage() + if !supportsImages { + for i := range aiMsgs { + if aiMsgs[i].Role == fantasy.MessageRoleUser { + aiMsgs[i].Content = filterFileParts(aiMsgs[i].Content) + } + } + } + history = append(history, aiMsgs...) if m.Role == message.Assistant { if msg, ok := syntheticToolResultsForOrphanedCalls(m, knownToolResultIDs); ok { @@ -854,6 +862,20 @@ If not, please feel free to ignore. Again do not mention this message to the use return history, files } +// filterFileParts removes fantasy.FilePart entries from a slice of message +// parts. Used to strip image attachments from historical user messages when +// the current model does not support them. +func filterFileParts(parts []fantasy.MessagePart) []fantasy.MessagePart { + filtered := make([]fantasy.MessagePart, 0, len(parts)) + for _, part := range parts { + if _, ok := fantasy.AsMessagePart[fantasy.FilePart](part); ok { + continue + } + filtered = append(filtered, part) + } + return filtered +} + // filterOrphanedToolResults converts a tool message to a fantasy.Message, // dropping any tool result parts whose tool_call_id has no matching tool call // in the known set. An orphaned result causes API validation to fail on every diff --git a/internal/agent/agent_test.go b/internal/agent/agent_test.go index 263a0691ada91d5c44fa419eaee8f6ad02891cf9..50c3ffcca67c51310094b218d74b2a6e2f820b70 100644 --- a/internal/agent/agent_test.go +++ b/internal/agent/agent_test.go @@ -656,6 +656,51 @@ func BenchmarkBuildSummaryPrompt(b *testing.B) { } } +func TestPreparePrompt_FiltersImageAttachments(t *testing.T) { + env := testEnv(t) + sa := testSessionAgent(env, nil, nil, "test prompt") + agent := sa.(*sessionAgent) + + ctx := t.Context() + sess, err := env.sessions.Create(ctx, "test") + require.NoError(t, err) + + // User message with text, a text attachment, and an image attachment. + _, err = env.messages.Create(ctx, sess.ID, message.CreateMessageParams{ + Role: message.User, + Parts: []message.ContentPart{ + message.TextContent{Text: "hello world"}, + message.BinaryContent{Path: "notes.txt", MIMEType: "text/plain", Data: []byte("important notes")}, + message.BinaryContent{Path: "image.png", MIMEType: "image/png", Data: []byte("fake-image-data")}, + }, + }) + require.NoError(t, err) + + msgs, err := env.messages.List(ctx, sess.ID) + require.NoError(t, err) + + // When supportsImages is false, image attachments should be stripped. + history, _ := agent.preparePrompt(msgs, false) + // First message is the system reminder, second is the user message. + require.Len(t, history, 2) + require.Len(t, history[1].Content, 1) + text, ok := fantasy.AsMessagePart[fantasy.TextPart](history[1].Content[0]) + require.True(t, ok) + require.Contains(t, text.Text, "hello world") + require.Contains(t, text.Text, "important notes") + + // When supportsImages is true, image attachments should remain. + history, _ = agent.preparePrompt(msgs, true) + require.Len(t, history, 2) + require.Len(t, history[1].Content, 2) + text, ok = fantasy.AsMessagePart[fantasy.TextPart](history[1].Content[0]) + require.True(t, ok) + require.Contains(t, text.Text, "hello world") + file, ok := fantasy.AsMessagePart[fantasy.FilePart](history[1].Content[1]) + require.True(t, ok) + require.Equal(t, "image.png", file.Filename) +} + func TestPreparePrompt_OrphanedToolUse(t *testing.T) { env := testEnv(t) sa := testSessionAgent(env, nil, nil, "test prompt") @@ -702,7 +747,7 @@ func TestPreparePrompt_OrphanedToolUse(t *testing.T) { msgs, err := env.messages.List(ctx, sess.ID) require.NoError(t, err) - history, _ := agent.preparePrompt(msgs) + history, _ := agent.preparePrompt(msgs, true) // The history must contain a synthetic tool result for the orphaned call. found := false @@ -776,7 +821,7 @@ func TestPreparePrompt_OrphanedToolUseMixed(t *testing.T) { msgs, err := env.messages.List(ctx, sess.ID) require.NoError(t, err) - history, _ := agent.preparePrompt(msgs) + history, _ := agent.preparePrompt(msgs, true) // Should have a synthetic result only for the orphaned call. var syntheticCount int