1// Package agent is the core orchestration layer for Crush AI agents.
2//
3// It provides session-based AI agent functionality for managing
4// conversations, tool execution, and message handling. It coordinates
5// interactions between language models, messages, sessions, and tools while
6// handling features like automatic summarization, queuing, and token
7// management.
8package agent
9
10import (
11 "cmp"
12 "context"
13 _ "embed"
14 "encoding/base64"
15 "errors"
16 "fmt"
17 "log/slog"
18 "net/http"
19 "os"
20 "regexp"
21 "strconv"
22 "strings"
23 "sync"
24 "time"
25
26 "charm.land/catwalk/pkg/catwalk"
27 "charm.land/fantasy"
28 "charm.land/fantasy/providers/anthropic"
29 "charm.land/fantasy/providers/bedrock"
30 "charm.land/fantasy/providers/google"
31 "charm.land/fantasy/providers/openai"
32 "charm.land/fantasy/providers/openrouter"
33 "charm.land/fantasy/providers/vercel"
34 "charm.land/lipgloss/v2"
35 "github.com/charmbracelet/crush/internal/agent/hyper"
36 "github.com/charmbracelet/crush/internal/agent/notify"
37 "github.com/charmbracelet/crush/internal/agent/tools"
38 "github.com/charmbracelet/crush/internal/agent/tools/mcp"
39 "github.com/charmbracelet/crush/internal/config"
40 "github.com/charmbracelet/crush/internal/csync"
41 "github.com/charmbracelet/crush/internal/message"
42 "github.com/charmbracelet/crush/internal/pubsub"
43 "github.com/charmbracelet/crush/internal/session"
44 "github.com/charmbracelet/crush/internal/stringext"
45 "github.com/charmbracelet/crush/internal/version"
46 "github.com/charmbracelet/x/exp/charmtone"
47)
48
const (
	// DefaultSessionName is the placeholder session title.
	// NOTE(review): presumably assigned to sessions before a title has been
	// generated — confirm against callers outside this file section.
	DefaultSessionName = "Untitled Session"

	// Constants for auto-summarization thresholds

	// largeContextWindowThreshold is the context-window size (in tokens)
	// above which a fixed token buffer is used instead of a ratio.
	largeContextWindowThreshold = 200_000
	// largeContextWindowBuffer is the remaining-token count at or below which
	// summarization is triggered for models whose context window exceeds
	// largeContextWindowThreshold.
	largeContextWindowBuffer = 20_000
	// smallContextWindowRatio is the fraction of the context window used as
	// the remaining-token threshold for all other models.
	smallContextWindowRatio = 0.2
)
57
// userAgent is the User-Agent string handed to every fantasy agent created in
// this package; it embeds the build version.
var userAgent = fmt.Sprintf("Charm-Crush/%s (https://charm.land/crush)", version.Version)
59
// titlePrompt holds the embedded contents of templates/title.md.
// NOTE(review): presumably the system prompt for title generation — the
// consumer is outside this file section.
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt holds the embedded contents of templates/summary.md; it is
// used as the system prompt when summarizing a session.
//
//go:embed templates/summary.md
var summaryPrompt []byte
65
// Used to remove <think> tags from generated titles.
var (
	// thinkTagRegex matches a complete <think>...</think> block, including
	// blocks that span multiple lines ((?s) lets '.' match newlines; the match
	// is non-greedy).
	thinkTagRegex = regexp.MustCompile(`(?s)<think>.*?</think>`)
	// orphanThinkTagRegex matches a lone opening or closing think tag.
	orphanThinkTagRegex = regexp.MustCompile(`</?think>`)
)
71
// SessionAgentCall describes a single prompt submission handled by
// SessionAgent.Run.
type SessionAgentCall struct {
	// SessionID identifies the target session. Run rejects an empty value.
	SessionID string
	// Prompt is the user's text. Run rejects an empty prompt unless
	// Attachments contains a text attachment.
	Prompt string
	// ProviderOptions is forwarded unchanged to the model stream call.
	ProviderOptions fantasy.ProviderOptions
	// Attachments are stored with the user message; non-text attachments are
	// also sent to the model as files.
	Attachments []message.Attachment
	// MaxOutputTokens is forwarded to the provider only when greater than
	// zero.
	MaxOutputTokens int64
	// The sampling parameters below are forwarded to the provider as-is.
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
	// NonInteractive suppresses the "agent finished" notification that Run
	// otherwise publishes after a successful turn.
	NonInteractive bool
}
85
// SessionAgent is the session-scoped agent API returned by NewSessionAgent.
//
// NOTE(review): only Run and Summarize are implemented in the visible part of
// this file; the notes on the other methods are inferred from their names and
// should be confirmed against their implementations.
type SessionAgent interface {
	// Run sends a prompt to the session's agent, or queues it when the
	// session is busy (in which case it returns a nil result and nil error).
	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
	// Configuration setters.
	SetModels(large Model, small Model)
	SetTools(tools []fantasy.AgentTool)
	SetSystemPrompt(systemPrompt string)
	// Cancellation and busy-state queries.
	Cancel(sessionID string)
	CancelAll()
	IsSessionBusy(sessionID string) bool
	IsBusy() bool
	// Per-session prompt queue inspection and control.
	QueuedPrompts(sessionID string) int
	QueuedPromptsList(sessionID string) []string
	ClearQueue(sessionID string)
	// Summarize replaces the session's working history with a generated
	// summary message; it returns ErrSessionBusy when the session is busy.
	Summarize(context.Context, string, fantasy.ProviderOptions) error
	Model() Model
}
101
// Model bundles a language model with the two configuration records the agent
// consults while using it.
type Model struct {
	// Model is the language model used to build fantasy agents.
	Model fantasy.LanguageModel
	// CatwalkCfg supplies model capabilities read by the agent: context
	// window size, image support and display name.
	CatwalkCfg catwalk.Model
	// ModelCfg supplies the selected model and provider identifiers recorded
	// on assistant messages.
	ModelCfg config.SelectedModel
}
107
// sessionAgent is the SessionAgent implementation returned by
// NewSessionAgent.
type sessionAgent struct {
	// Mutable configuration, held in csync containers and copied at the start
	// of each run to avoid races with the setters.
	largeModel         *csync.Value[Model]
	smallModel         *csync.Value[Model]
	systemPromptPrefix *csync.Value[string]
	systemPrompt       *csync.Value[string]
	tools              *csync.Slice[fantasy.AgentTool]

	// isSubAgent disables the todo-list reminder injected at the start of the
	// prompt history.
	isSubAgent bool
	sessions   session.Service
	messages   message.Service
	// disableAutoSummarize turns off the context-window stop condition that
	// triggers summarization.
	disableAutoSummarize bool
	isYolo               bool
	// notify may be nil; it is checked before every publish.
	notify pubsub.Publisher[notify.Notification]

	// messageQueue holds calls received per session while that session was
	// busy.
	messageQueue *csync.Map[string, []SessionAgentCall]
	// activeRequests maps a session ID to the cancel function of its
	// in-flight request.
	activeRequests *csync.Map[string, context.CancelFunc]
}
125
// SessionAgentOptions carries the dependencies and initial configuration
// consumed by NewSessionAgent. Each field is copied into the corresponding
// field of the agent.
type SessionAgentOptions struct {
	LargeModel Model
	SmallModel Model
	// SystemPromptPrefix, when non-empty, is prepended as a separate system
	// message ahead of the prepared messages on every step.
	SystemPromptPrefix string
	SystemPrompt       string
	// IsSubAgent disables the todo-list reminder in the prompt history.
	IsSubAgent bool
	// DisableAutoSummarize turns off automatic summarization when the context
	// window is nearly full.
	DisableAutoSummarize bool
	IsYolo               bool
	Sessions             session.Service
	Messages             message.Service
	Tools                []fantasy.AgentTool
	// Notify is optional; a nil publisher disables notifications.
	Notify pubsub.Publisher[notify.Notification]
}
139
140func NewSessionAgent(
141 opts SessionAgentOptions,
142) SessionAgent {
143 return &sessionAgent{
144 largeModel: csync.NewValue(opts.LargeModel),
145 smallModel: csync.NewValue(opts.SmallModel),
146 systemPromptPrefix: csync.NewValue(opts.SystemPromptPrefix),
147 systemPrompt: csync.NewValue(opts.SystemPrompt),
148 isSubAgent: opts.IsSubAgent,
149 sessions: opts.Sessions,
150 messages: opts.Messages,
151 disableAutoSummarize: opts.DisableAutoSummarize,
152 tools: csync.NewSliceFrom(opts.Tools),
153 isYolo: opts.IsYolo,
154 notify: opts.Notify,
155 messageQueue: csync.NewMap[string, []SessionAgentCall](),
156 activeRequests: csync.NewMap[string, context.CancelFunc](),
157 }
158}
159
// Run submits call to the agent for the session named by call.SessionID.
//
// It returns ErrEmptyPrompt when the prompt is empty and no text attachment is
// present, and ErrSessionMissing when the session ID is empty. If the session
// is already busy the call is appended to the session's queue and Run returns
// (nil, nil) immediately.
//
// Otherwise Run stores the user message, streams the model response while
// persisting the assistant message, tool calls and tool results as they
// arrive, and updates the session's usage after each step. On a stream error
// it closes any unfinished tool calls, records synthetic error results for
// tool calls that have none, stores the finish reason on the assistant message
// and returns the error. On success it publishes an "agent finished"
// notification (unless call.NonInteractive is set or no publisher is
// configured), summarizes the session if the context-window stop condition
// fired, and then processes the next queued call, if any, by calling itself.
func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
	if call.Prompt == "" && !message.ContainsTextAttachment(call.Attachments) {
		return nil, ErrEmptyPrompt
	}
	if call.SessionID == "" {
		return nil, ErrSessionMissing
	}

	// Queue the message if busy
	if a.IsSessionBusy(call.SessionID) {
		existing, ok := a.messageQueue.Get(call.SessionID)
		if !ok {
			existing = []SessionAgentCall{}
		}
		existing = append(existing, call)
		a.messageQueue.Set(call.SessionID, existing)
		return nil, nil
	}

	// Copy mutable fields under lock to avoid races with SetTools/SetModels.
	agentTools := a.tools.Copy()
	largeModel := a.largeModel.Get()
	systemPrompt := a.systemPrompt.Get()
	promptPrefix := a.systemPromptPrefix.Get()
	var instructions strings.Builder

	// Collect the instructions advertised by every connected MCP server.
	for _, server := range mcp.GetStates() {
		if server.State != mcp.StateConnected {
			continue
		}
		if s := server.Client.InitializeResult().Instructions; s != "" {
			instructions.WriteString(s)
			instructions.WriteString("\n\n")
		}
	}

	if s := instructions.String(); s != "" {
		systemPrompt += "\n\n<mcp-instructions>\n" + s + "\n</mcp-instructions>"
	}

	if len(agentTools) > 0 {
		// Add Anthropic caching to the last tool.
		agentTools[len(agentTools)-1].SetProviderOptions(a.getCacheControlOptions())
	}

	agent := fantasy.NewAgent(
		largeModel.Model,
		fantasy.WithSystemPrompt(systemPrompt),
		fantasy.WithTools(agentTools...),
		fantasy.WithUserAgent(userAgent),
	)

	sessionLock := sync.Mutex{}
	currentSession, err := a.sessions.Get(ctx, call.SessionID)
	if err != nil {
		return nil, fmt.Errorf("failed to get session: %w", err)
	}

	msgs, err := a.getSessionMessages(ctx, currentSession)
	if err != nil {
		return nil, fmt.Errorf("failed to get session messages: %w", err)
	}

	var wg sync.WaitGroup
	// Generate title if first message.
	if len(msgs) == 0 {
		titleCtx := ctx // Copy to avoid race with ctx reassignment below.
		wg.Go(func() {
			a.generateTitle(titleCtx, call.SessionID, call.Prompt)
		})
	}
	// Do not return before the background title generation has finished.
	defer wg.Wait()

	// Add the user message to the session.
	_, err = a.createUserMessage(ctx, call)
	if err != nil {
		return nil, err
	}

	// Add the session to the context.
	ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)

	// Register the cancel function so the in-flight request can be found by
	// session ID; it is removed again when Run returns.
	genCtx, cancel := context.WithCancel(ctx)
	a.activeRequests.Set(call.SessionID, cancel)

	defer cancel()
	defer a.activeRequests.Del(call.SessionID)

	history, files := a.preparePrompt(msgs, largeModel.CatwalkCfg.SupportsImages, call.Attachments...)

	startTime := time.Now()
	a.eventPromptSent(call.SessionID)

	// currentAssistant is set by PrepareStep at the start of every step and
	// is the message all streaming callbacks below write into.
	var currentAssistant *message.Message
	var shouldSummarize bool
	// Don't send MaxOutputTokens if 0 — some providers (e.g. LM Studio) reject it
	var maxOutputTokens *int64
	if call.MaxOutputTokens > 0 {
		maxOutputTokens = &call.MaxOutputTokens
	}
	result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
		Prompt:           message.PromptWithTextAttachments(call.Prompt, call.Attachments),
		Files:            files,
		Messages:         history,
		ProviderOptions:  call.ProviderOptions,
		MaxOutputTokens:  maxOutputTokens,
		TopP:             call.TopP,
		Temperature:      call.Temperature,
		PresencePenalty:  call.PresencePenalty,
		TopK:             call.TopK,
		FrequencyPenalty: call.FrequencyPenalty,
		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
			prepared.Messages = options.Messages
			// Clear per-message provider options; cache control is
			// re-applied below.
			for i := range prepared.Messages {
				prepared.Messages[i].ProviderOptions = nil
			}

			// Use latest tools (updated by SetTools when MCP tools change).
			prepared.Tools = a.tools.Copy()

			// Drain prompts queued while this run was in progress and append
			// them to the conversation as user messages.
			queuedCalls, _ := a.messageQueue.Get(call.SessionID)
			a.messageQueue.Del(call.SessionID)
			for _, queued := range queuedCalls {
				userMessage, createErr := a.createUserMessage(callContext, queued)
				if createErr != nil {
					return callContext, prepared, createErr
				}
				prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
			}

			prepared.Messages = a.workaroundProviderMediaLimitations(prepared.Messages, largeModel)

			lastSystemRoleInx := 0
			systemMessageUpdated := false
			for i, msg := range prepared.Messages {
				// Add cache control to the last message of the leading run
				// of system messages; this is applied once, when the first
				// non-system message is reached.
				if msg.Role == fantasy.MessageRoleSystem {
					lastSystemRoleInx = i
				} else if !systemMessageUpdated {
					prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
					systemMessageUpdated = true
				}
				// Then add cache control to the last 2 messages.
				if i > len(prepared.Messages)-3 {
					prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
				}
			}

			if promptPrefix != "" {
				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(promptPrefix)}, prepared.Messages...)
			}

			// Create the (initially empty) assistant message for this step.
			var assistantMsg message.Message
			assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
				Role:     message.Assistant,
				Parts:    []message.ContentPart{},
				Model:    largeModel.ModelCfg.Model,
				Provider: largeModel.ModelCfg.Provider,
			})
			if err != nil {
				return callContext, prepared, err
			}
			callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
			callContext = context.WithValue(callContext, tools.SupportsImagesContextKey, largeModel.CatwalkCfg.SupportsImages)
			callContext = context.WithValue(callContext, tools.ModelNameContextKey, largeModel.CatwalkCfg.Name)
			currentAssistant = &assistantMsg
			return callContext, prepared, err
		},
		OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
			currentAssistant.AppendReasoningContent(reasoning.Text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnReasoningDelta: func(id string, text string) error {
			currentAssistant.AppendReasoningContent(text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
			// handle anthropic signature
			if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
				if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
					currentAssistant.AppendReasoningSignature(reasoning.Signature)
				}
			}
			// handle google thought signature
			if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
				if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
					currentAssistant.AppendThoughtSignature(reasoning.Signature, reasoning.ToolID)
				}
			}
			// handle openai responses reasoning data
			if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
				if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
					currentAssistant.SetReasoningResponsesData(reasoning)
				}
			}
			currentAssistant.FinishThinking()
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnTextDelta: func(id string, text string) error {
			// Strip leading newline from initial text content. This is
			// particularly important in non-interactive mode where leading
			// newlines are very visible.
			if len(currentAssistant.Parts) == 0 {
				text = strings.TrimPrefix(text, "\n")
			}

			currentAssistant.AppendContent(text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnToolInputStart: func(id string, toolName string) error {
			toolCall := message.ToolCall{
				ID:               id,
				Name:             toolName,
				ProviderExecuted: false,
				Finished:         false,
			}
			currentAssistant.AddToolCall(toolCall)
			// Use parent ctx instead of genCtx to ensure the update succeeds
			// even if the request is canceled mid-stream
			return a.messages.Update(ctx, *currentAssistant)
		},
		OnRetry: func(err *fantasy.ProviderError, delay time.Duration) {
			slog.Warn("Provider request failed, retrying", providerRetryLogFields(err, delay)...)
		},
		OnToolCall: func(tc fantasy.ToolCallContent) error {
			toolCall := message.ToolCall{
				ID:               tc.ToolCallID,
				Name:             tc.ToolName,
				Input:            tc.Input,
				ProviderExecuted: false,
				Finished:         true,
			}
			currentAssistant.AddToolCall(toolCall)
			// Use parent ctx instead of genCtx to ensure the update succeeds
			// even if the request is canceled mid-stream
			return a.messages.Update(ctx, *currentAssistant)
		},
		OnToolResult: func(result fantasy.ToolResultContent) error {
			toolResult := a.convertToToolResult(result)
			// Use parent ctx instead of genCtx to ensure the message is created
			// even if the request is canceled mid-stream
			_, createMsgErr := a.messages.Create(ctx, currentAssistant.SessionID, message.CreateMessageParams{
				Role: message.Tool,
				Parts: []message.ContentPart{
					toolResult,
				},
			})
			return createMsgErr
		},
		OnStepFinish: func(stepResult fantasy.StepResult) error {
			// Map the provider finish reason onto the message finish reason;
			// anything unrecognized stays FinishReasonUnknown.
			finishReason := message.FinishReasonUnknown
			switch stepResult.FinishReason {
			case fantasy.FinishReasonLength:
				finishReason = message.FinishReasonMaxTokens
			case fantasy.FinishReasonStop:
				finishReason = message.FinishReasonEndTurn
			case fantasy.FinishReasonToolCalls:
				finishReason = message.FinishReasonToolUse
			}
			// If a tool result halted the turn (e.g. a hook halt or a
			// permission denial), the step ends on FinishReasonToolCalls but
			// the model will not be called again. Treat it as the end of the
			// turn so the UI can render the assistant footer.
			if finishReason == message.FinishReasonToolUse {
				for _, tr := range stepResult.Content.ToolResults() {
					if tr.StopTurn {
						finishReason = message.FinishReasonEndTurn
						break
					}
				}
			}
			currentAssistant.AddFinish(finishReason, "", "")
			sessionLock.Lock()
			defer sessionLock.Unlock()

			// Re-read the session before applying this step's usage and
			// saving it.
			updatedSession, getSessionErr := a.sessions.Get(ctx, call.SessionID)
			if getSessionErr != nil {
				return getSessionErr
			}
			a.updateSessionUsage(largeModel, &updatedSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
			_, sessionErr := a.sessions.Save(ctx, updatedSession)
			if sessionErr != nil {
				return sessionErr
			}
			currentSession = updatedSession
			return a.messages.Update(genCtx, *currentAssistant)
		},
		StopWhen: []fantasy.StopCondition{
			// Stop (and request summarization) when the remaining context
			// window drops to the threshold.
			func(_ []fantasy.StepResult) bool {
				cw := int64(largeModel.CatwalkCfg.ContextWindow)
				// If context window is unknown (0), skip auto-summarize
				// to avoid immediately truncating custom/local models.
				if cw == 0 {
					return false
				}
				tokens := currentSession.CompletionTokens + currentSession.PromptTokens
				remaining := cw - tokens
				var threshold int64
				if cw > largeContextWindowThreshold {
					threshold = largeContextWindowBuffer
				} else {
					threshold = int64(float64(cw) * smallContextWindowRatio)
				}
				if (remaining <= threshold) && !a.disableAutoSummarize {
					shouldSummarize = true
					return true
				}
				return false
			},
			// Stop when the recent steps repeat the same tool calls.
			func(steps []fantasy.StepResult) bool {
				return hasRepeatedToolCalls(steps, loopDetectionWindowSize, loopDetectionMaxRepeats)
			},
		},
	})

	a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))

	if err != nil {
		isHyper := largeModel.ModelCfg.Provider == hyper.Name
		isCancelErr := errors.Is(err, context.Canceled)
		// No assistant message was created, so there is nothing to repair.
		if currentAssistant == nil {
			return result, err
		}
		// Ensure we finish thinking on error to close the reasoning state.
		currentAssistant.FinishThinking()
		toolCalls := currentAssistant.ToolCalls()
		// INFO: we use the parent context here because the genCtx has been cancelled.
		msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
		if createErr != nil {
			return nil, createErr
		}
		for _, tc := range toolCalls {
			// Close tool calls whose input was still streaming, using an
			// empty JSON object as input.
			if !tc.Finished {
				tc.Finished = true
				tc.Input = "{}"
				currentAssistant.AddToolCall(tc)
				updateErr := a.messages.Update(ctx, *currentAssistant)
				if updateErr != nil {
					return nil, updateErr
				}
			}

			// Skip tool calls that already have a stored result.
			found := false
			for _, msg := range msgs {
				if msg.Role == message.Tool {
					for _, tr := range msg.ToolResults() {
						if tr.ToolCallID == tc.ID {
							found = true
							break
						}
					}
				}
				if found {
					break
				}
			}
			if found {
				continue
			}
			// Store an error result so the tool call is not left without one.
			content := "There was an error while executing the tool"
			if isCancelErr {
				content = "Error: user cancelled assistant tool calling"
			}
			toolResult := message.ToolResult{
				ToolCallID: tc.ID,
				Name:       tc.Name,
				Content:    content,
				IsError:    true,
			}
			_, createErr = a.messages.Create(ctx, currentAssistant.SessionID, message.CreateMessageParams{
				Role: message.Tool,
				Parts: []message.ContentPart{
					toolResult,
				},
			})
			if createErr != nil {
				return nil, createErr
			}
		}
		// Translate the error into a finish reason, title and detail on the
		// assistant message.
		var fantasyErr *fantasy.Error
		var providerErr *fantasy.ProviderError
		const defaultTitle = "Provider Error"
		linkStyle := lipgloss.NewStyle().Foreground(charmtone.Guac).Underline(true)
		if isCancelErr {
			currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
		} else if isHyper && errors.As(err, &providerErr) && providerErr.StatusCode == http.StatusUnauthorized {
			currentAssistant.AddFinish(message.FinishReasonError, "Unauthorized", `Please re-authenticate with Hyper. You can also run "crush auth" to re-authenticate.`)
			if a.notify != nil {
				a.notify.Publish(pubsub.CreatedEvent, notify.Notification{
					SessionID:    call.SessionID,
					SessionTitle: currentSession.Title,
					Type:         notify.TypeReAuthenticate,
					ProviderID:   largeModel.ModelCfg.Provider,
				})
			}
		} else if isHyper && errors.As(err, &providerErr) && providerErr.StatusCode == http.StatusPaymentRequired {
			url := hyper.BaseURL()
			link := linkStyle.Hyperlink(url, "id=hyper").Render(url)
			currentAssistant.AddFinish(message.FinishReasonError, "No credits", "You're out of credits. Add more at "+link)
		} else if errors.As(err, &providerErr) {
			if providerErr.Message == "The requested model is not supported." {
				url := "https://github.com/settings/copilot/features"
				link := linkStyle.Hyperlink(url, "id=copilot").Render(url)
				currentAssistant.AddFinish(
					message.FinishReasonError,
					"Copilot model not enabled",
					fmt.Sprintf("%q is not enabled in Copilot. Go to the following page to enable it. Then, wait 5 minutes before trying again. %s", largeModel.CatwalkCfg.Name, link),
				)
			} else {
				currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message)
			}
		} else if errors.As(err, &fantasyErr) {
			currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(fantasyErr.Title), defaultTitle), fantasyErr.Message)
		} else {
			currentAssistant.AddFinish(message.FinishReasonError, defaultTitle, err.Error())
		}
		// Note: we use the parent context here because the genCtx has been
		// cancelled.
		updateErr := a.messages.Update(ctx, *currentAssistant)
		if updateErr != nil {
			return nil, updateErr
		}
		return nil, err
	}

	// Send notification that agent has finished its turn (skip for
	// nested/non-interactive sessions).
	if !call.NonInteractive && a.notify != nil {
		a.notify.Publish(pubsub.CreatedEvent, notify.Notification{
			SessionID:    call.SessionID,
			SessionTitle: currentSession.Title,
			Type:         notify.TypeAgentFinished,
		})
	}

	if shouldSummarize {
		// Summarize refuses to run while the session is busy, so release the
		// active request first.
		a.activeRequests.Del(call.SessionID)
		if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
			return nil, summarizeErr
		}
		// If the agent wasn't done...
		if len(currentAssistant.ToolCalls()) > 0 {
			// ...re-queue the original request so work resumes after the
			// summary.
			existing, ok := a.messageQueue.Get(call.SessionID)
			if !ok {
				existing = []SessionAgentCall{}
			}
			call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
			existing = append(existing, call)
			a.messageQueue.Set(call.SessionID, existing)
		}
	}

	// Release active request before processing queued messages.
	a.activeRequests.Del(call.SessionID)
	cancel()

	queuedMessages, ok := a.messageQueue.Get(call.SessionID)
	if !ok || len(queuedMessages) == 0 {
		return result, err
	}
	// There are queued messages restart the loop.
	firstQueuedMessage := queuedMessages[0]
	a.messageQueue.Set(call.SessionID, queuedMessages[1:])
	return a.Run(ctx, firstQueuedMessage)
}
623
624func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
625 if a.IsSessionBusy(sessionID) {
626 return ErrSessionBusy
627 }
628
629 // Copy mutable fields under lock to avoid races with SetModels.
630 largeModel := a.largeModel.Get()
631 systemPromptPrefix := a.systemPromptPrefix.Get()
632
633 currentSession, err := a.sessions.Get(ctx, sessionID)
634 if err != nil {
635 return fmt.Errorf("failed to get session: %w", err)
636 }
637 msgs, err := a.getSessionMessages(ctx, currentSession)
638 if err != nil {
639 return err
640 }
641 if len(msgs) == 0 {
642 // Nothing to summarize.
643 return nil
644 }
645
646 aiMsgs, _ := a.preparePrompt(msgs, largeModel.CatwalkCfg.SupportsImages)
647
648 genCtx, cancel := context.WithCancel(ctx)
649 a.activeRequests.Set(sessionID, cancel)
650 defer a.activeRequests.Del(sessionID)
651 defer cancel()
652
653 agent := fantasy.NewAgent(largeModel.Model,
654 fantasy.WithSystemPrompt(string(summaryPrompt)),
655 fantasy.WithUserAgent(userAgent),
656 )
657 summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
658 Role: message.Assistant,
659 Model: largeModel.Model.Model(),
660 Provider: largeModel.Model.Provider(),
661 IsSummaryMessage: true,
662 })
663 if err != nil {
664 return err
665 }
666
667 summaryPromptText := buildSummaryPrompt(currentSession.Todos)
668
669 resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
670 Prompt: summaryPromptText,
671 Messages: aiMsgs,
672 ProviderOptions: opts,
673 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
674 prepared.Messages = options.Messages
675 if systemPromptPrefix != "" {
676 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(systemPromptPrefix)}, prepared.Messages...)
677 }
678 return callContext, prepared, nil
679 },
680 OnReasoningDelta: func(id string, text string) error {
681 summaryMessage.AppendReasoningContent(text)
682 return a.messages.Update(genCtx, summaryMessage)
683 },
684 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
685 // Handle anthropic signature.
686 if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
687 if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
688 summaryMessage.AppendReasoningSignature(signature.Signature)
689 }
690 }
691 summaryMessage.FinishThinking()
692 return a.messages.Update(genCtx, summaryMessage)
693 },
694 OnTextDelta: func(id, text string) error {
695 summaryMessage.AppendContent(text)
696 return a.messages.Update(genCtx, summaryMessage)
697 },
698 })
699 if err != nil {
700 isCancelErr := errors.Is(err, context.Canceled)
701 if isCancelErr {
702 // User cancelled summarize we need to remove the summary message.
703 deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
704 return deleteErr
705 }
706 // Mark the summary message as finished with an error so the UI
707 // stops spinning.
708 summaryMessage.AddFinish(message.FinishReasonError, "Summarization Error", err.Error())
709 if updateErr := a.messages.Update(ctx, summaryMessage); updateErr != nil {
710 return updateErr
711 }
712 return err
713 }
714
715 summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
716 err = a.messages.Update(genCtx, summaryMessage)
717 if err != nil {
718 return err
719 }
720
721 var openrouterCost *float64
722 for _, step := range resp.Steps {
723 stepCost := a.openrouterCost(step.ProviderMetadata)
724 if stepCost != nil {
725 newCost := *stepCost
726 if openrouterCost != nil {
727 newCost += *openrouterCost
728 }
729 openrouterCost = &newCost
730 }
731 }
732
733 a.updateSessionUsage(largeModel, ¤tSession, resp.TotalUsage, openrouterCost)
734
735 // Just in case, get just the last usage info.
736 usage := resp.Response.Usage
737 currentSession.SummaryMessageID = summaryMessage.ID
738 currentSession.CompletionTokens = usage.OutputTokens
739 currentSession.PromptTokens = 0
740 _, err = a.sessions.Save(genCtx, currentSession)
741 if err != nil {
742 return err
743 }
744
745 // Release the active request before processing queued messages so that
746 // Run() does not see the session as busy.
747 a.activeRequests.Del(sessionID)
748 cancel()
749
750 // Process any messages that were queued while summarizing.
751 queuedMessages, ok := a.messageQueue.Get(sessionID)
752 if !ok || len(queuedMessages) == 0 {
753 return nil
754 }
755 firstQueuedMessage := queuedMessages[0]
756 a.messageQueue.Set(sessionID, queuedMessages[1:])
757 _, qErr := a.Run(ctx, firstQueuedMessage)
758 return qErr
759}
760
761func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
762 if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
763 return fantasy.ProviderOptions{}
764 }
765 return fantasy.ProviderOptions{
766 anthropic.Name: &anthropic.ProviderCacheControlOptions{
767 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
768 },
769 bedrock.Name: &anthropic.ProviderCacheControlOptions{
770 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
771 },
772 vercel.Name: &anthropic.ProviderCacheControlOptions{
773 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
774 },
775 }
776}
777
778func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
779 parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
780 var attachmentParts []message.ContentPart
781 for _, attachment := range call.Attachments {
782 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
783 }
784 parts = append(parts, attachmentParts...)
785 msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
786 Role: message.User,
787 Parts: parts,
788 })
789 if err != nil {
790 return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
791 }
792 return msg, nil
793}
794
795func (a *sessionAgent) preparePrompt(msgs []message.Message, supportsImages bool, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
796 var history []fantasy.Message
797 if !a.isSubAgent {
798 history = append(history, fantasy.NewUserMessage(
799 fmt.Sprintf("<system_reminder>%s</system_reminder>",
800 `This is a reminder that your todo list is currently empty. DO NOT mention this to the user explicitly because they are already aware.
801If you are working on tasks that would benefit from a todo list please use the "todos" tool to create one.
802If not, please feel free to ignore. Again do not mention this message to the user.`,
803 ),
804 ))
805 }
806 // Collect all tool call IDs present in assistant messages and all tool
807 // result IDs present in tool messages. This lets us detect both orphaned
808 // tool results (result without a call) and orphaned tool calls (call
809 // without a result).
810 knownToolCallIDs := make(map[string]struct{})
811 knownToolResultIDs := make(map[string]struct{})
812 for _, m := range msgs {
813 switch m.Role {
814 case message.Assistant:
815 for _, tc := range m.ToolCalls() {
816 knownToolCallIDs[tc.ID] = struct{}{}
817 }
818 case message.Tool:
819 for _, tr := range m.ToolResults() {
820 knownToolResultIDs[tr.ToolCallID] = struct{}{}
821 }
822 }
823 }
824
825 for _, m := range msgs {
826 if len(m.Parts) == 0 {
827 continue
828 }
829 // Assistant message without content or tool calls (cancelled before it returned anything).
830 if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
831 continue
832 }
833 if m.Role == message.Tool {
834 if msg, ok := filterOrphanedToolResults(m, knownToolCallIDs); ok {
835 history = append(history, msg)
836 }
837 continue
838 }
839 aiMsgs := m.ToAIMessage()
840 if !supportsImages {
841 for i := range aiMsgs {
842 if aiMsgs[i].Role == fantasy.MessageRoleUser {
843 aiMsgs[i].Content = filterFileParts(aiMsgs[i].Content)
844 }
845 }
846 }
847 history = append(history, aiMsgs...)
848
849 if m.Role == message.Assistant {
850 if msg, ok := syntheticToolResultsForOrphanedCalls(m, knownToolResultIDs); ok {
851 history = append(history, msg)
852 }
853 }
854 }
855
856 var files []fantasy.FilePart
857 for _, attachment := range attachments {
858 if attachment.IsText() {
859 continue
860 }
861 files = append(files, fantasy.FilePart{
862 Filename: attachment.FileName,
863 Data: attachment.Content,
864 MediaType: attachment.MimeType,
865 })
866 }
867
868 return history, files
869}
870
871// filterFileParts removes fantasy.FilePart entries from a slice of message
872// parts. Used to strip image attachments from historical user messages when
873// the current model does not support them.
874func filterFileParts(parts []fantasy.MessagePart) []fantasy.MessagePart {
875 filtered := make([]fantasy.MessagePart, 0, len(parts))
876 for _, part := range parts {
877 if _, ok := fantasy.AsMessagePart[fantasy.FilePart](part); ok {
878 continue
879 }
880 filtered = append(filtered, part)
881 }
882 return filtered
883}
884
885// filterOrphanedToolResults converts a tool message to a fantasy.Message,
886// dropping any tool result parts whose tool_call_id has no matching tool call
887// in the known set. An orphaned result causes API validation to fail on every
888// subsequent turn, permanently locking the session. Returns the filtered
889// message and true if at least one valid part remains.
890func filterOrphanedToolResults(m message.Message, knownToolCallIDs map[string]struct{}) (fantasy.Message, bool) {
891 aiMsgs := m.ToAIMessage()
892 if len(aiMsgs) == 0 {
893 return fantasy.Message{}, false
894 }
895 var validParts []fantasy.MessagePart
896 for _, part := range aiMsgs[0].Content {
897 tr, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
898 if !ok {
899 validParts = append(validParts, part)
900 continue
901 }
902 if _, known := knownToolCallIDs[tr.ToolCallID]; known {
903 validParts = append(validParts, part)
904 } else {
905 slog.Warn("Dropping orphaned tool result with no matching tool call",
906 "tool_call_id", tr.ToolCallID,
907 )
908 }
909 }
910 if len(validParts) == 0 {
911 return fantasy.Message{}, false
912 }
913 msg := aiMsgs[0]
914 msg.Content = validParts
915 return msg, true
916}
917
918// syntheticToolResultsForOrphanedCalls returns a tool message containing
919// synthetic tool results for any tool calls in the assistant message that
920// have no matching result in knownToolResultIDs. LLM APIs require every
921// tool_use to be immediately followed by a tool_result; an interrupted
922// session can leave orphaned tool_use blocks that permanently lock the
923// conversation. Returns the message and true if any synthetic results were
924// produced.
925func syntheticToolResultsForOrphanedCalls(m message.Message, knownToolResultIDs map[string]struct{}) (fantasy.Message, bool) {
926 var syntheticParts []fantasy.MessagePart
927 for _, tc := range m.ToolCalls() {
928 if _, hasResult := knownToolResultIDs[tc.ID]; hasResult {
929 continue
930 }
931 slog.Warn("Injecting synthetic tool result for orphaned tool call",
932 "tool_call_id", tc.ID,
933 "tool_name", tc.Name,
934 )
935 syntheticParts = append(syntheticParts, fantasy.ToolResultPart{
936 ToolCallID: tc.ID,
937 Output: fantasy.ToolResultOutputContentError{
938 Error: errors.New("tool call was interrupted and did not produce a result, you may retry this call if the result is still needed"),
939 },
940 })
941 }
942 if len(syntheticParts) == 0 {
943 return fantasy.Message{}, false
944 }
945 return fantasy.Message{
946 Role: fantasy.MessageRoleTool,
947 Content: syntheticParts,
948 }, true
949}
950
951func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
952 msgs, err := a.messages.List(ctx, session.ID)
953 if err != nil {
954 return nil, fmt.Errorf("failed to list messages: %w", err)
955 }
956
957 if session.SummaryMessageID != "" {
958 summaryMsgIndex := -1
959 for i, msg := range msgs {
960 if msg.ID == session.SummaryMessageID {
961 summaryMsgIndex = i
962 break
963 }
964 }
965 if summaryMsgIndex != -1 {
966 msgs = msgs[summaryMsgIndex:]
967 msgs[0].Role = message.User
968 }
969 }
970 return msgs, nil
971}
972
973// generateTitle generates a session titled based on the initial prompt.
974func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, userPrompt string) {
975 if userPrompt == "" {
976 return
977 }
978
979 smallModel := a.smallModel.Get()
980 largeModel := a.largeModel.Get()
981 systemPromptPrefix := a.systemPromptPrefix.Get()
982
983 var maxOutputTokens int64 = 40
984 if smallModel.CatwalkCfg.CanReason {
985 maxOutputTokens = smallModel.CatwalkCfg.DefaultMaxTokens
986 }
987
988 newAgent := func(m fantasy.LanguageModel, p []byte, tok int64) fantasy.Agent {
989 return fantasy.NewAgent(m,
990 fantasy.WithSystemPrompt(string(p)+"\n /no_think"),
991 fantasy.WithMaxOutputTokens(tok),
992 fantasy.WithUserAgent(userAgent),
993 )
994 }
995
996 streamCall := fantasy.AgentStreamCall{
997 Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", userPrompt),
998 PrepareStep: func(callCtx context.Context, opts fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
999 prepared.Messages = opts.Messages
1000 if systemPromptPrefix != "" {
1001 prepared.Messages = append([]fantasy.Message{
1002 fantasy.NewSystemMessage(systemPromptPrefix),
1003 }, prepared.Messages...)
1004 }
1005 return callCtx, prepared, nil
1006 },
1007 }
1008
1009 // Use the small model to generate the title.
1010 model := smallModel
1011 agent := newAgent(model.Model, titlePrompt, maxOutputTokens)
1012 resp, err := agent.Stream(ctx, streamCall)
1013 if err == nil {
1014 // We successfully generated a title with the small model.
1015 slog.Debug("Generated title with small model")
1016 } else {
1017 // It didn't work. Let's try with the big model.
1018 slog.Error("Error generating title with small model; trying big model", "err", err)
1019 model = largeModel
1020 agent = newAgent(model.Model, titlePrompt, maxOutputTokens)
1021 resp, err = agent.Stream(ctx, streamCall)
1022 if err == nil {
1023 slog.Debug("Generated title with large model")
1024 } else {
1025 // Welp, the large model didn't work either. Use the default
1026 // session name and return.
1027 slog.Error("Error generating title with large model", "err", err)
1028 saveErr := a.sessions.Rename(ctx, sessionID, DefaultSessionName)
1029 if saveErr != nil {
1030 slog.Error("Failed to save session title", "error", saveErr)
1031 }
1032 return
1033 }
1034 }
1035
1036 if resp == nil {
1037 // Actually, we didn't get a response so we can't. Use the default
1038 // session name and return.
1039 slog.Error("Response is nil; can't generate title")
1040 saveErr := a.sessions.Rename(ctx, sessionID, DefaultSessionName)
1041 if saveErr != nil {
1042 slog.Error("Failed to save session title", "error", saveErr)
1043 }
1044 return
1045 }
1046
1047 // Clean up title.
1048 var title string
1049 title = strings.ReplaceAll(resp.Response.Content.Text(), "\n", " ")
1050
1051 // Remove thinking tags if present.
1052 title = thinkTagRegex.ReplaceAllString(title, "")
1053 title = orphanThinkTagRegex.ReplaceAllString(title, "")
1054
1055 title = strings.TrimSpace(title)
1056 title = cmp.Or(title, DefaultSessionName)
1057
1058 // Calculate usage and cost.
1059 var openrouterCost *float64
1060 for _, step := range resp.Steps {
1061 stepCost := a.openrouterCost(step.ProviderMetadata)
1062 if stepCost != nil {
1063 newCost := *stepCost
1064 if openrouterCost != nil {
1065 newCost += *openrouterCost
1066 }
1067 openrouterCost = &newCost
1068 }
1069 }
1070
1071 modelConfig := model.CatwalkCfg
1072 cost := modelConfig.CostPer1MInCached/1e6*float64(resp.TotalUsage.CacheCreationTokens) +
1073 modelConfig.CostPer1MOutCached/1e6*float64(resp.TotalUsage.CacheReadTokens) +
1074 modelConfig.CostPer1MIn/1e6*float64(resp.TotalUsage.InputTokens) +
1075 modelConfig.CostPer1MOut/1e6*float64(resp.TotalUsage.OutputTokens)
1076
1077 // Use override cost if available (e.g., from OpenRouter).
1078 if openrouterCost != nil {
1079 cost = *openrouterCost
1080 }
1081
1082 promptTokens := resp.TotalUsage.InputTokens + resp.TotalUsage.CacheCreationTokens
1083 completionTokens := resp.TotalUsage.OutputTokens
1084
1085 // Atomically update only title and usage fields to avoid overriding other
1086 // concurrent session updates.
1087 saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, title, promptTokens, completionTokens, cost)
1088 if saveErr != nil {
1089 slog.Error("Failed to save session title and usage", "error", saveErr)
1090 return
1091 }
1092}
1093
1094func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
1095 openrouterMetadata, ok := metadata[openrouter.Name]
1096 if !ok {
1097 return nil
1098 }
1099
1100 opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
1101 if !ok {
1102 return nil
1103 }
1104 return &opts.Usage.Cost
1105}
1106
1107func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
1108 modelConfig := model.CatwalkCfg
1109 cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
1110 modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
1111 modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
1112 modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
1113
1114 a.eventTokensUsed(session.ID, model, usage, cost)
1115
1116 if overrideCost != nil {
1117 session.Cost += *overrideCost
1118 } else {
1119 session.Cost += cost
1120 }
1121
1122 session.CompletionTokens = usage.OutputTokens
1123 session.PromptTokens = usage.InputTokens + usage.CacheReadTokens
1124}
1125
1126func (a *sessionAgent) Cancel(sessionID string) {
1127 // Cancel regular requests. Don't use Take() here - we need the entry to
1128 // remain in activeRequests so IsBusy() returns true until the goroutine
1129 // fully completes (including error handling that may access the DB).
1130 // The defer in processRequest will clean up the entry.
1131 if cancel, ok := a.activeRequests.Get(sessionID); ok && cancel != nil {
1132 slog.Debug("Request cancellation initiated", "session_id", sessionID)
1133 cancel()
1134 }
1135
1136 // Also check for summarize requests.
1137 if cancel, ok := a.activeRequests.Get(sessionID + "-summarize"); ok && cancel != nil {
1138 slog.Debug("Summarize cancellation initiated", "session_id", sessionID)
1139 cancel()
1140 }
1141
1142 if a.QueuedPrompts(sessionID) > 0 {
1143 slog.Debug("Clearing queued prompts", "session_id", sessionID)
1144 a.messageQueue.Del(sessionID)
1145 }
1146}
1147
1148func (a *sessionAgent) ClearQueue(sessionID string) {
1149 if a.QueuedPrompts(sessionID) > 0 {
1150 slog.Debug("Clearing queued prompts", "session_id", sessionID)
1151 a.messageQueue.Del(sessionID)
1152 }
1153}
1154
1155func (a *sessionAgent) CancelAll() {
1156 if !a.IsBusy() {
1157 return
1158 }
1159 for key := range a.activeRequests.Seq2() {
1160 a.Cancel(key) // key is sessionID
1161 }
1162
1163 timeout := time.After(5 * time.Second)
1164 for a.IsBusy() {
1165 select {
1166 case <-timeout:
1167 return
1168 default:
1169 time.Sleep(200 * time.Millisecond)
1170 }
1171 }
1172}
1173
1174func (a *sessionAgent) IsBusy() bool {
1175 var busy bool
1176 for cancelFunc := range a.activeRequests.Seq() {
1177 if cancelFunc != nil {
1178 busy = true
1179 break
1180 }
1181 }
1182 return busy
1183}
1184
1185func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
1186 _, busy := a.activeRequests.Get(sessionID)
1187 return busy
1188}
1189
1190func (a *sessionAgent) QueuedPrompts(sessionID string) int {
1191 l, ok := a.messageQueue.Get(sessionID)
1192 if !ok {
1193 return 0
1194 }
1195 return len(l)
1196}
1197
1198func (a *sessionAgent) QueuedPromptsList(sessionID string) []string {
1199 l, ok := a.messageQueue.Get(sessionID)
1200 if !ok {
1201 return nil
1202 }
1203 prompts := make([]string, len(l))
1204 for i, call := range l {
1205 prompts[i] = call.Prompt
1206 }
1207 return prompts
1208}
1209
1210func (a *sessionAgent) SetModels(large Model, small Model) {
1211 a.largeModel.Set(large)
1212 a.smallModel.Set(small)
1213}
1214
1215func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
1216 a.tools.SetSlice(tools)
1217}
1218
1219func (a *sessionAgent) SetSystemPrompt(systemPrompt string) {
1220 a.systemPrompt.Set(systemPrompt)
1221}
1222
1223func (a *sessionAgent) Model() Model {
1224 return a.largeModel.Get()
1225}
1226
1227// convertToToolResult converts a fantasy tool result to a message tool result.
1228func (a *sessionAgent) convertToToolResult(result fantasy.ToolResultContent) message.ToolResult {
1229 baseResult := message.ToolResult{
1230 ToolCallID: result.ToolCallID,
1231 Name: result.ToolName,
1232 Metadata: result.ClientMetadata,
1233 }
1234
1235 switch result.Result.GetType() {
1236 case fantasy.ToolResultContentTypeText:
1237 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result); ok {
1238 baseResult.Content = r.Text
1239 }
1240 case fantasy.ToolResultContentTypeError:
1241 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result); ok {
1242 baseResult.Content = r.Error.Error()
1243 baseResult.IsError = true
1244 }
1245 case fantasy.ToolResultContentTypeMedia:
1246 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](result.Result); ok {
1247 if !stringext.IsValidBase64(r.Data) {
1248 slog.Warn("Tool returned media with invalid base64 data, discarding image",
1249 "tool", result.ToolName,
1250 "tool_call_id", result.ToolCallID,
1251 )
1252 baseResult.Content = "Tool returned image data with invalid encoding"
1253 baseResult.IsError = true
1254 } else {
1255 content := r.Text
1256 if content == "" {
1257 content = fmt.Sprintf("Loaded %s content", r.MediaType)
1258 }
1259 baseResult.Content = content
1260 baseResult.Data = r.Data
1261 baseResult.MIMEType = r.MediaType
1262 }
1263 }
1264 }
1265
1266 return baseResult
1267}
1268
1269// workaroundProviderMediaLimitations converts media content in tool results to
1270// user messages for providers that don't natively support images in tool results.
1271//
1272// Problem: OpenAI, Google, OpenRouter, and other OpenAI-compatible providers
1273// don't support sending images/media in tool result messages - they only accept
1274// text in tool results. However, they DO support images in user messages.
1275//
1276// If we send media in tool results to these providers, the API returns an error.
1277//
1278// Solution: For these providers, we:
1279// 1. Replace the media in the tool result with a text placeholder
1280// 2. Inject a user message immediately after with the image as a file attachment
1281// 3. This maintains the tool execution flow while working around API limitations
1282//
1283// Anthropic and Bedrock support images natively in tool results, so we skip
1284// this workaround for them.
1285//
1286// Example transformation:
1287//
1288// BEFORE: [tool result: image data]
1289// AFTER: [tool result: "Image loaded - see attached"], [user: image attachment]
1290func (a *sessionAgent) workaroundProviderMediaLimitations(messages []fantasy.Message, largeModel Model) []fantasy.Message {
1291 providerSupportsMedia := largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderAnthropic) ||
1292 largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderBedrock)
1293
1294 if providerSupportsMedia {
1295 return messages
1296 }
1297
1298 convertedMessages := make([]fantasy.Message, 0, len(messages))
1299
1300 for _, msg := range messages {
1301 if msg.Role != fantasy.MessageRoleTool {
1302 convertedMessages = append(convertedMessages, msg)
1303 continue
1304 }
1305
1306 textParts := make([]fantasy.MessagePart, 0, len(msg.Content))
1307 var mediaFiles []fantasy.FilePart
1308
1309 for _, part := range msg.Content {
1310 toolResult, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
1311 if !ok {
1312 textParts = append(textParts, part)
1313 continue
1314 }
1315
1316 if media, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](toolResult.Output); ok {
1317 decoded, err := base64.StdEncoding.DecodeString(media.Data)
1318 if err != nil {
1319 slog.Warn("Failed to decode media data", "error", err)
1320 textParts = append(textParts, part)
1321 continue
1322 }
1323
1324 mediaFiles = append(mediaFiles, fantasy.FilePart{
1325 Data: decoded,
1326 MediaType: media.MediaType,
1327 Filename: fmt.Sprintf("tool-result-%s", toolResult.ToolCallID),
1328 })
1329
1330 textParts = append(textParts, fantasy.ToolResultPart{
1331 ToolCallID: toolResult.ToolCallID,
1332 Output: fantasy.ToolResultOutputContentText{
1333 Text: "[Image/media content loaded - see attached file]",
1334 },
1335 ProviderOptions: toolResult.ProviderOptions,
1336 })
1337 } else {
1338 textParts = append(textParts, part)
1339 }
1340 }
1341
1342 convertedMessages = append(convertedMessages, fantasy.Message{
1343 Role: fantasy.MessageRoleTool,
1344 Content: textParts,
1345 })
1346
1347 if len(mediaFiles) > 0 {
1348 convertedMessages = append(convertedMessages, fantasy.NewUserMessage(
1349 "Here is the media content from the tool result:",
1350 mediaFiles...,
1351 ))
1352 }
1353 }
1354
1355 return convertedMessages
1356}
1357
1358// buildSummaryPrompt constructs the prompt text for session summarization.
1359func buildSummaryPrompt(todos []session.Todo) string {
1360 var sb strings.Builder
1361 sb.WriteString("Provide a detailed summary of our conversation above.")
1362 if len(todos) > 0 {
1363 sb.WriteString("\n\n## Current Todo List\n\n")
1364 for _, t := range todos {
1365 fmt.Fprintf(&sb, "- [%s] %s\n", t.Status, t.Content)
1366 }
1367 sb.WriteString("\nInclude these tasks and their statuses in your summary. ")
1368 sb.WriteString("Instruct the resuming assistant to use the `todos` tool to continue tracking progress on these tasks.")
1369 }
1370 return sb.String()
1371}
1372
1373func providerRetryLogFields(err *fantasy.ProviderError, delay time.Duration) []any {
1374 fields := []any{
1375 "retry_delay", delay.String(),
1376 }
1377 if err == nil {
1378 return fields
1379 }
1380 fields = append(fields, "status_code", err.StatusCode)
1381 if err.Title != "" {
1382 fields = append(fields, "title", err.Title)
1383 }
1384 if err.Message != "" {
1385 fields = append(fields, "message", err.Message)
1386 }
1387 return fields
1388}