1// Package agent is the core orchestration layer for Crush AI agents.
2//
3// It provides session-based AI agent functionality for managing
4// conversations, tool execution, and message handling. It coordinates
5// interactions between language models, messages, sessions, and tools while
6// handling features like automatic summarization, queuing, and token
7// management.
8package agent
9
10import (
11 "cmp"
12 "context"
13 _ "embed"
14 "encoding/base64"
15 "errors"
16 "fmt"
17 "log/slog"
18 "net/http"
19 "os"
20 "regexp"
21 "strconv"
22 "strings"
23 "sync"
24 "time"
25
26 "charm.land/catwalk/pkg/catwalk"
27 "charm.land/fantasy"
28 "charm.land/fantasy/providers/anthropic"
29 "charm.land/fantasy/providers/bedrock"
30 "charm.land/fantasy/providers/google"
31 "charm.land/fantasy/providers/openai"
32 "charm.land/fantasy/providers/openrouter"
33 "charm.land/fantasy/providers/vercel"
34 "charm.land/lipgloss/v2"
35 "github.com/charmbracelet/crush/internal/agent/hyper"
36 "github.com/charmbracelet/crush/internal/agent/notify"
37 "github.com/charmbracelet/crush/internal/agent/tools"
38 "github.com/charmbracelet/crush/internal/agent/tools/mcp"
39 "github.com/charmbracelet/crush/internal/config"
40 "github.com/charmbracelet/crush/internal/csync"
41 "github.com/charmbracelet/crush/internal/message"
42 "github.com/charmbracelet/crush/internal/pubsub"
43 "github.com/charmbracelet/crush/internal/session"
44 "github.com/charmbracelet/crush/internal/stringext"
45 "github.com/charmbracelet/crush/internal/version"
46 "github.com/charmbracelet/x/exp/charmtone"
47)
48
// DefaultSessionName is the placeholder name for a session that has not been
// given a title (presumably replaced once a title is generated).
const DefaultSessionName = "Untitled Session"

// Auto-summarization thresholds, consulted by Run's StopWhen condition to
// decide when the remaining context window is too small to continue.
const (
	// Context windows above this many tokens reserve a fixed buffer.
	largeContextWindowThreshold = 200000
	// Tokens kept in reserve for large context windows.
	largeContextWindowBuffer = 20000
	// Fraction of the window kept in reserve for smaller context windows.
	smallContextWindowRatio = 0.2
)
57
// userAgent is the User-Agent passed to fantasy.WithUserAgent whenever this
// package builds a fantasy agent; it embeds the running Crush version.
var userAgent = fmt.Sprintf("Charm-Crush/%s (https://charm.land/crush)", version.Version)

// titlePrompt holds the contents of templates/title.md, embedded at build
// time. NOTE(review): presumably the system prompt for title generation
// (generateTitle is not visible here) — confirm.
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt holds the contents of templates/summary.md, embedded at build
// time; Summarize uses it as the system prompt of the summarization agent.
//
//go:embed templates/summary.md
var summaryPrompt []byte
65
// thinkTagRegex matches a whole <think>...</think> block, non-greedily and
// across newlines ((?s) lets "." match "\n"). Together with
// orphanThinkTagRegex it is used to strip <think> tags from generated titles.
var thinkTagRegex = regexp.MustCompile(`(?s)<think>.*?</think>`)

// orphanThinkTagRegex matches a lone opening or closing think tag, i.e. one
// that is not part of a complete block.
var orphanThinkTagRegex = regexp.MustCompile(`</?think>`)
71
// SessionAgentCall describes one prompt submitted to a SessionAgent together
// with its per-call generation settings.
type SessionAgentCall struct {
	// SessionID is the session the prompt belongs to; Run rejects an empty
	// value with ErrSessionMissing.
	SessionID string
	// Prompt is the user's text. Run rejects an empty Prompt with
	// ErrEmptyPrompt unless Attachments contains a text attachment.
	Prompt string
	// ProviderOptions are forwarded unchanged to the model stream call.
	ProviderOptions fantasy.ProviderOptions
	// Attachments are files sent with the prompt. The prompt text is built
	// with message.PromptWithTextAttachments; non-text attachments are sent
	// as file parts.
	Attachments []message.Attachment
	// MaxOutputTokens caps the response length. Zero means "not set": no
	// limit is sent, because some providers reject it.
	MaxOutputTokens int64
	// Sampling parameters, forwarded as-is. NOTE(review): nil presumably
	// leaves the provider default in place.
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
	// NonInteractive suppresses the "agent finished" notification Run
	// publishes at the end of a successful turn.
	NonInteractive bool
}
85
// SessionAgent runs prompts against language models on behalf of sessions,
// queuing prompts for sessions that already have an active request.
type SessionAgent interface {
	// Run processes a prompt for call.SessionID. If the session is busy the
	// call is queued and (nil, nil) is returned.
	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
	// SetModels replaces the large and small models. NOTE(review):
	// presumably affects only runs started afterwards, since Run copies the
	// model at start.
	SetModels(large Model, small Model)
	// SetTools replaces the tool set; an in-flight Run picks up the new tools
	// at its next step.
	SetTools(tools []fantasy.AgentTool)
	// SetSystemPrompt replaces the system prompt. NOTE(review): presumably
	// used by runs started afterwards.
	SetSystemPrompt(systemPrompt string)
	// Cancel presumably cancels the active request of sessionID.
	Cancel(sessionID string)
	// CancelAll presumably cancels every active request.
	CancelAll()
	// IsSessionBusy reports whether sessionID has an active request.
	IsSessionBusy(sessionID string) bool
	// IsBusy presumably reports whether any session is busy.
	IsBusy() bool
	// QueuedPrompts presumably returns how many prompts are queued for
	// sessionID.
	QueuedPrompts(sessionID string) int
	// QueuedPromptsList presumably returns the queued prompt texts.
	QueuedPromptsList(sessionID string) []string
	// ClearQueue presumably discards queued prompts for sessionID.
	ClearQueue(sessionID string)
	// Summarize condenses a session's history into a summary message; it
	// returns ErrSessionBusy if the session has an active request.
	Summarize(context.Context, string, fantasy.ProviderOptions) error
	// Model presumably returns the currently configured large model.
	Model() Model
}
101
// Model bundles a language model with the configuration Crush keeps for it.
type Model struct {
	// Model is the fantasy language model used for generation.
	Model fantasy.LanguageModel
	// CatwalkCfg is the model's catalog entry; the agent reads its
	// ContextWindow, SupportsImages and Name fields.
	CatwalkCfg catwalk.Model
	// ModelCfg is the selected model configuration; its Model and Provider
	// fields are stored on assistant messages.
	ModelCfg config.SelectedModel
	// FlatRate is not read in the visible part of this file. NOTE(review):
	// presumably marks flat-rate (not per-token) billing — confirm.
	FlatRate bool
}
108
// sessionAgent is the default SessionAgent implementation.
type sessionAgent struct {
	// Mutable configuration, wrapped in csync containers so setters can run
	// concurrently with Run/Summarize, which copy the values at start.
	largeModel         *csync.Value[Model]
	smallModel         *csync.Value[Model]
	systemPromptPrefix *csync.Value[string]
	systemPrompt       *csync.Value[string]
	tools              *csync.Slice[fantasy.AgentTool]

	// isSubAgent disables the todo-list reminder added by preparePrompt.
	isSubAgent bool
	sessions   session.Service
	messages   message.Service
	// disableAutoSummarize turns off the context-window StopWhen trigger.
	disableAutoSummarize bool
	isYolo               bool
	// notify may be nil; Run checks before publishing.
	notify pubsub.Publisher[notify.Notification]

	// messageQueue holds prompts submitted while a session was busy.
	messageQueue *csync.Map[string, []SessionAgentCall]
	// activeRequests maps a session ID to the cancel func of its running
	// request; an entry marks the session as busy.
	activeRequests *csync.Map[string, context.CancelFunc]
}
126
// SessionAgentOptions configures NewSessionAgent.
type SessionAgentOptions struct {
	LargeModel Model
	SmallModel Model
	// SystemPromptPrefix, when non-empty, is prepended to the conversation
	// as an extra system message.
	SystemPromptPrefix string
	SystemPrompt       string
	// IsSubAgent omits the todo-list reminder from prepared history.
	IsSubAgent bool
	// DisableAutoSummarize prevents automatic summarization when the
	// context window runs low.
	DisableAutoSummarize bool
	IsYolo               bool
	Sessions             session.Service
	Messages             message.Service
	Tools                []fantasy.AgentTool
	// Notify is optional; nil disables notifications.
	Notify pubsub.Publisher[notify.Notification]
}
140
141func NewSessionAgent(
142 opts SessionAgentOptions,
143) SessionAgent {
144 return &sessionAgent{
145 largeModel: csync.NewValue(opts.LargeModel),
146 smallModel: csync.NewValue(opts.SmallModel),
147 systemPromptPrefix: csync.NewValue(opts.SystemPromptPrefix),
148 systemPrompt: csync.NewValue(opts.SystemPrompt),
149 isSubAgent: opts.IsSubAgent,
150 sessions: opts.Sessions,
151 messages: opts.Messages,
152 disableAutoSummarize: opts.DisableAutoSummarize,
153 tools: csync.NewSliceFrom(opts.Tools),
154 isYolo: opts.IsYolo,
155 notify: opts.Notify,
156 messageQueue: csync.NewMap[string, []SessionAgentCall](),
157 activeRequests: csync.NewMap[string, context.CancelFunc](),
158 }
159}
160
// Run executes one user turn for call.SessionID and streams the model's
// response, persisting user, assistant and tool messages as it goes.
//
// Behavior:
//   - Returns ErrEmptyPrompt when the prompt is empty and there is no text
//     attachment, and ErrSessionMissing when call.SessionID is empty.
//   - If the session already has an active request, the call is appended to
//     the session's queue and Run returns (nil, nil).
//   - For the first message of a session a title is generated concurrently;
//     Run waits for that goroutine before returning.
//   - Prompts queued while streaming are injected into the conversation at
//     the start of the next step (PrepareStep).
//   - When the remaining context falls below the auto-summarization
//     threshold, streaming stops, the session is summarized, and if the
//     assistant was mid tool use a continuation prompt is queued.
//   - On stream errors, unfinished tool calls are closed with synthetic error
//     results and the assistant message is finished with an error derived
//     from the error type.
//   - After a successful turn, the next queued prompt (if any) is processed
//     by a recursive call to Run.
func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
	if call.Prompt == "" && !message.ContainsTextAttachment(call.Attachments) {
		return nil, ErrEmptyPrompt
	}
	if call.SessionID == "" {
		return nil, ErrSessionMissing
	}

	// Queue the message if busy.
	// NOTE(review): this check and the activeRequests.Set below are not
	// atomic, so two concurrent Run calls for the same session could both
	// pass it.
	if a.IsSessionBusy(call.SessionID) {
		existing, ok := a.messageQueue.Get(call.SessionID)
		if !ok {
			existing = []SessionAgentCall{}
		}
		existing = append(existing, call)
		a.messageQueue.Set(call.SessionID, existing)
		return nil, nil
	}

	// Copy mutable fields under lock to avoid races with SetTools/SetModels.
	agentTools := a.tools.Copy()
	largeModel := a.largeModel.Get()
	systemPrompt := a.systemPrompt.Get()
	promptPrefix := a.systemPromptPrefix.Get()
	var instructions strings.Builder

	// Collect instructions published by connected MCP servers.
	for _, server := range mcp.GetStates() {
		if server.State != mcp.StateConnected {
			continue
		}
		if s := server.Client.InitializeResult().Instructions; s != "" {
			instructions.WriteString(s)
			instructions.WriteString("\n\n")
		}
	}

	if s := instructions.String(); s != "" {
		systemPrompt += "\n\n<mcp-instructions>\n" + s + "\n</mcp-instructions>"
	}

	if len(agentTools) > 0 {
		// Add Anthropic caching to the last tool.
		agentTools[len(agentTools)-1].SetProviderOptions(a.getCacheControlOptions())
	}

	agent := fantasy.NewAgent(
		largeModel.Model,
		fantasy.WithSystemPrompt(systemPrompt),
		fantasy.WithTools(agentTools...),
		fantasy.WithUserAgent(userAgent),
	)

	// sessionLock guards stepMessages (written in PrepareStep, read in
	// OnStepFinish) and the session refresh in OnStepFinish. NOTE(review):
	// presumably needed in case fantasy invokes the callbacks from different
	// goroutines.
	sessionLock := sync.Mutex{}
	currentSession, err := a.sessions.Get(ctx, call.SessionID)
	if err != nil {
		return nil, fmt.Errorf("failed to get session: %w", err)
	}

	msgs, err := a.getSessionMessages(ctx, currentSession)
	if err != nil {
		return nil, fmt.Errorf("failed to get session messages: %w", err)
	}

	var wg sync.WaitGroup
	// Generate title if first message.
	if len(msgs) == 0 {
		titleCtx := ctx // Copy to avoid race with ctx reassignment below.
		wg.Go(func() {
			a.generateTitle(titleCtx, call.SessionID, call.Prompt)
		})
	}
	defer wg.Wait()

	// Add the user message to the session.
	_, err = a.createUserMessage(ctx, call)
	if err != nil {
		return nil, err
	}

	// Add the session to the context.
	ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)

	// Registering the cancel func marks the session busy and lets the
	// request be canceled.
	genCtx, cancel := context.WithCancel(ctx)
	a.activeRequests.Set(call.SessionID, cancel)

	defer cancel()
	defer a.activeRequests.Del(call.SessionID)
	// Drain any debounced message updates before returning. message.Service
	// already flushes synchronously on terminal updates, but a defer here
	// guarantees the contract at every Run exit (success, error, panic
	// recovery upstream) without callers needing to know.
	defer func() {
		if flushErr := a.messages.FlushAll(ctx); flushErr != nil {
			slog.Error("Failed to flush pending message updates after run", "error", flushErr)
		}
	}()

	history, files := a.preparePrompt(msgs, largeModel.CatwalkCfg.SupportsImages, call.Attachments...)

	startTime := time.Now()
	a.eventPromptSent(call.SessionID)

	// currentAssistant is the assistant message of the current step; it is
	// created in PrepareStep and mutated by the streaming callbacks.
	var currentAssistant *message.Message
	var stepMessages []fantasy.Message
	var shouldSummarize bool
	// Don't send MaxOutputTokens if 0 — some providers (e.g. LM Studio) reject it
	var maxOutputTokens *int64
	if call.MaxOutputTokens > 0 {
		maxOutputTokens = &call.MaxOutputTokens
	}
	result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
		Prompt:           message.PromptWithTextAttachments(call.Prompt, call.Attachments),
		Files:            files,
		Messages:         history,
		ProviderOptions:  call.ProviderOptions,
		MaxOutputTokens:  maxOutputTokens,
		TopP:             call.TopP,
		Temperature:      call.Temperature,
		PresencePenalty:  call.PresencePenalty,
		TopK:             call.TopK,
		FrequencyPenalty: call.FrequencyPenalty,
		// PrepareStep runs before each model step: it refreshes tools,
		// injects queued prompts, applies cache-control markers and creates
		// the assistant message for the step.
		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
			prepared.Messages = options.Messages
			// Clear cache markers from the previous step before re-applying.
			for i := range prepared.Messages {
				prepared.Messages[i].ProviderOptions = nil
			}

			// Use latest tools (updated by SetTools when MCP tools change).
			prepared.Tools = a.tools.Copy()

			// Inject prompts queued while this run was busy.
			queuedCalls, _ := a.messageQueue.Get(call.SessionID)
			a.messageQueue.Del(call.SessionID)
			for _, queued := range queuedCalls {
				userMessage, createErr := a.createUserMessage(callContext, queued)
				if createErr != nil {
					return callContext, prepared, createErr
				}
				prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
			}

			prepared.Messages = a.workaroundProviderMediaLimitations(prepared.Messages, largeModel)

			lastSystemRoleInx := 0
			systemMessageUpdated := false
			for i, msg := range prepared.Messages {
				// Add cache control to the last system message before the
				// first non-system message.
				if msg.Role == fantasy.MessageRoleSystem {
					lastSystemRoleInx = i
				} else if !systemMessageUpdated {
					prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
					systemMessageUpdated = true
				}
				// Then add cache control to the last 2 messages.
				if i > len(prepared.Messages)-3 {
					prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
				}
			}

			if promptPrefix != "" {
				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(promptPrefix)}, prepared.Messages...)
			}

			// Snapshot the outgoing messages so OnStepFinish can estimate
			// usage if the provider does not report it.
			sessionLock.Lock()
			stepMessages = cloneFantasyMessages(prepared.Messages)
			sessionLock.Unlock()

			var assistantMsg message.Message
			assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
				Role:     message.Assistant,
				Parts:    []message.ContentPart{},
				Model:    largeModel.ModelCfg.Model,
				Provider: largeModel.ModelCfg.Provider,
			})
			if err != nil {
				return callContext, prepared, err
			}
			// Expose step metadata to tools through the context.
			callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
			callContext = context.WithValue(callContext, tools.SupportsImagesContextKey, largeModel.CatwalkCfg.SupportsImages)
			callContext = context.WithValue(callContext, tools.ModelNameContextKey, largeModel.CatwalkCfg.Name)
			currentAssistant = &assistantMsg
			return callContext, prepared, err
		},
		OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
			currentAssistant.AppendReasoningContent(reasoning.Text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnReasoningDelta: func(id string, text string) error {
			currentAssistant.AppendReasoningContent(text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
			// handle anthropic signature
			if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
				if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
					currentAssistant.AppendReasoningSignature(reasoning.Signature)
				}
			}
			// Google thought signature, tied to a tool call ID.
			if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
				if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
					currentAssistant.AppendThoughtSignature(reasoning.Signature, reasoning.ToolID)
				}
			}
			// OpenAI Responses API reasoning data.
			if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
				if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
					currentAssistant.SetReasoningResponsesData(reasoning)
				}
			}
			currentAssistant.FinishThinking()
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnTextDelta: func(id string, text string) error {
			// Strip leading newline from initial text content. This is
			// particularly important in non-interactive mode where leading
			// newlines are very visible.
			if len(currentAssistant.Parts) == 0 {
				text = strings.TrimPrefix(text, "\n")
			}

			currentAssistant.AppendContent(text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnToolInputStart: func(id string, toolName string) error {
			toolCall := message.ToolCall{
				ID:               id,
				Name:             toolName,
				ProviderExecuted: false,
				Finished:         false,
			}
			currentAssistant.AddToolCall(toolCall)
			// Use parent ctx instead of genCtx to ensure the update succeeds
			// even if the request is canceled mid-stream
			return a.messages.Update(ctx, *currentAssistant)
		},
		OnRetry: func(err *fantasy.ProviderError, delay time.Duration) {
			slog.Warn("Provider request failed, retrying", providerRetryLogFields(err, delay)...)
		},
		OnToolCall: func(tc fantasy.ToolCallContent) error {
			toolCall := message.ToolCall{
				ID:               tc.ToolCallID,
				Name:             tc.ToolName,
				Input:            tc.Input,
				ProviderExecuted: false,
				Finished:         true,
			}
			currentAssistant.AddToolCall(toolCall)
			// Use parent ctx instead of genCtx to ensure the update succeeds
			// even if the request is canceled mid-stream
			return a.messages.Update(ctx, *currentAssistant)
		},
		OnToolResult: func(result fantasy.ToolResultContent) error {
			toolResult := a.convertToToolResult(result)
			// Use parent ctx instead of genCtx to ensure the message is created
			// even if the request is canceled mid-stream
			_, createMsgErr := a.messages.Create(ctx, currentAssistant.SessionID, message.CreateMessageParams{
				Role: message.Tool,
				Parts: []message.ContentPart{
					toolResult,
				},
			})
			return createMsgErr
		},
		// OnStepFinish records the finish reason on the assistant message and
		// adds the step's token usage to the session.
		OnStepFinish: func(stepResult fantasy.StepResult) error {
			finishReason := message.FinishReasonUnknown
			switch stepResult.FinishReason {
			case fantasy.FinishReasonLength:
				finishReason = message.FinishReasonMaxTokens
			case fantasy.FinishReasonStop:
				finishReason = message.FinishReasonEndTurn
			case fantasy.FinishReasonToolCalls:
				finishReason = message.FinishReasonToolUse
			}
			// If a tool result halted the turn (e.g. a hook halt or a
			// permission denial), the step ends on FinishReasonToolCalls but
			// the model will not be called again. Treat it as the end of the
			// turn so the UI can render the assistant footer.
			if finishReason == message.FinishReasonToolUse {
				for _, tr := range stepResult.Content.ToolResults() {
					if tr.StopTurn {
						finishReason = message.FinishReasonEndTurn
						break
					}
				}
			}
			currentAssistant.AddFinish(finishReason, "", "")
			sessionLock.Lock()
			defer sessionLock.Unlock()

			// Re-read the session so usage accumulates on its latest state.
			updatedSession, getSessionErr := a.sessions.Get(ctx, call.SessionID)
			if getSessionErr != nil {
				return getSessionErr
			}
			usage, estimated := fallbackStepUsage(stepMessages, stepResult)
			a.updateSessionUsage(largeModel, &updatedSession, usage, a.openrouterCost(stepResult.ProviderMetadata), estimated)
			_, sessionErr := a.sessions.Save(ctx, updatedSession)
			if sessionErr != nil {
				return sessionErr
			}
			currentSession = updatedSession
			return a.messages.Update(genCtx, *currentAssistant)
		},
		StopWhen: []fantasy.StopCondition{
			// Stop for auto-summarization when the remaining context window
			// drops to the reserve threshold.
			func(_ []fantasy.StepResult) bool {
				cw := int64(largeModel.CatwalkCfg.ContextWindow)
				// If context window is unknown (0), skip auto-summarize
				// to avoid immediately truncating custom/local models.
				if cw == 0 {
					return false
				}
				tokens := currentSession.CompletionTokens + currentSession.PromptTokens
				remaining := cw - tokens
				var threshold int64
				if cw > largeContextWindowThreshold {
					threshold = largeContextWindowBuffer
				} else {
					threshold = int64(float64(cw) * smallContextWindowRatio)
				}
				if (remaining <= threshold) && !a.disableAutoSummarize {
					shouldSummarize = true
					return true
				}
				return false
			},
			// Stop when the model keeps repeating the same tool calls.
			func(steps []fantasy.StepResult) bool {
				return hasRepeatedToolCalls(steps, loopDetectionWindowSize, loopDetectionMaxRepeats)
			},
		},
	})

	a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))

	if err != nil {
		isHyper := largeModel.ModelCfg.Provider == hyper.Name
		isCancelErr := errors.Is(err, context.Canceled)
		// The stream failed before PrepareStep created an assistant message.
		if currentAssistant == nil {
			return result, err
		}
		// Ensure we finish thinking on error to close the reasoning state.
		currentAssistant.FinishThinking()
		toolCalls := currentAssistant.ToolCalls()
		// INFO: we use the parent context here because the genCtx has been cancelled.
		msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
		if createErr != nil {
			return nil, createErr
		}
		// Close out every tool call and give each one without a stored result
		// a synthetic error result, so the history stays well-formed.
		for _, tc := range toolCalls {
			if !tc.Finished {
				tc.Finished = true
				tc.Input = "{}"
				currentAssistant.AddToolCall(tc)
				updateErr := a.messages.Update(ctx, *currentAssistant)
				if updateErr != nil {
					return nil, updateErr
				}
			}

			found := false
			for _, msg := range msgs {
				if msg.Role == message.Tool {
					for _, tr := range msg.ToolResults() {
						if tr.ToolCallID == tc.ID {
							found = true
							break
						}
					}
				}
				if found {
					break
				}
			}
			if found {
				continue
			}
			content := "There was an error while executing the tool"
			if isCancelErr {
				content = "Error: user cancelled assistant tool calling"
			}
			toolResult := message.ToolResult{
				ToolCallID: tc.ID,
				Name:       tc.Name,
				Content:    content,
				IsError:    true,
			}
			_, createErr = a.messages.Create(ctx, currentAssistant.SessionID, message.CreateMessageParams{
				Role: message.Tool,
				Parts: []message.ContentPart{
					toolResult,
				},
			})
			if createErr != nil {
				return nil, createErr
			}
		}
		// Translate the error into a finish reason and a user-facing title
		// and message on the assistant message.
		var fantasyErr *fantasy.Error
		var providerErr *fantasy.ProviderError
		const defaultTitle = "Provider Error"
		linkStyle := lipgloss.NewStyle().Foreground(charmtone.Guac).Underline(true)
		if isCancelErr {
			currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
		} else if isHyper && errors.As(err, &providerErr) && providerErr.StatusCode == http.StatusUnauthorized {
			currentAssistant.AddFinish(message.FinishReasonError, "Unauthorized", `Please re-authenticate with Hyper. You can also run "crush auth" to re-authenticate.`)
			if a.notify != nil {
				a.notify.Publish(pubsub.CreatedEvent, notify.Notification{
					SessionID:    call.SessionID,
					SessionTitle: currentSession.Title,
					Type:         notify.TypeReAuthenticate,
					ProviderID:   largeModel.ModelCfg.Provider,
				})
			}
		} else if isHyper && errors.As(err, &providerErr) && providerErr.StatusCode == http.StatusPaymentRequired {
			url := hyper.BaseURL()
			link := linkStyle.Hyperlink(url, "id=hyper").Render(url)
			currentAssistant.AddFinish(message.FinishReasonError, "No credits", "You're out of credits. Add more at "+link)
		} else if errors.As(err, &providerErr) {
			// This exact provider message is special-cased as a Copilot
			// "model not enabled" error.
			if providerErr.Message == "The requested model is not supported." {
				url := "https://github.com/settings/copilot/features"
				link := linkStyle.Hyperlink(url, "id=copilot").Render(url)
				currentAssistant.AddFinish(
					message.FinishReasonError,
					"Copilot model not enabled",
					fmt.Sprintf("%q is not enabled in Copilot. Go to the following page to enable it. Then, wait 5 minutes before trying again. %s", largeModel.CatwalkCfg.Name, link),
				)
			} else {
				currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message)
			}
		} else if errors.As(err, &fantasyErr) {
			currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(fantasyErr.Title), defaultTitle), fantasyErr.Message)
		} else {
			currentAssistant.AddFinish(message.FinishReasonError, defaultTitle, err.Error())
		}
		// Note: we use the parent context here because the genCtx has been
		// cancelled.
		updateErr := a.messages.Update(ctx, *currentAssistant)
		if updateErr != nil {
			return nil, updateErr
		}
		return nil, err
	}

	if shouldSummarize {
		// Summarize refuses to run while the session is busy, so release the
		// active request first.
		a.activeRequests.Del(call.SessionID)
		// NOTE(review): Summarize returns the delete error (usually nil) when
		// it is cancelled, so a nil error here does not guarantee a summary
		// was written; the continuation below still gets queued.
		if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
			return nil, summarizeErr
		}
		// If the agent wasn't done...
		if len(currentAssistant.ToolCalls()) > 0 {
			existing, ok := a.messageQueue.Get(call.SessionID)
			if !ok {
				existing = []SessionAgentCall{}
			}
			call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
			existing = append(existing, call)
			a.messageQueue.Set(call.SessionID, existing)
		}
	}

	// Release active request before publishing the notification.
	// TUI handlers poll IsSessionBusy() and only re-evaluate when a
	// tea.Msg arrives, so the cleanup must precede the notify or
	// subscribers see stale busy state at the moment of receipt.
	a.activeRequests.Del(call.SessionID)
	cancel()

	// Send notification that agent has finished its turn (skip for
	// nested/non-interactive sessions).
	if !call.NonInteractive && a.notify != nil {
		a.notify.Publish(pubsub.CreatedEvent, notify.Notification{
			SessionID:    call.SessionID,
			SessionTitle: currentSession.Title,
			Type:         notify.TypeAgentFinished,
		})
	}

	queuedMessages, ok := a.messageQueue.Get(call.SessionID)
	if !ok || len(queuedMessages) == 0 {
		return result, err
	}
	// There are queued messages restart the loop.
	firstQueuedMessage := queuedMessages[0]
	a.messageQueue.Set(call.SessionID, queuedMessages[1:])
	return a.Run(ctx, firstQueuedMessage)
}
642
643func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
644 if a.IsSessionBusy(sessionID) {
645 return ErrSessionBusy
646 }
647
648 // Copy mutable fields under lock to avoid races with SetModels.
649 largeModel := a.largeModel.Get()
650 systemPromptPrefix := a.systemPromptPrefix.Get()
651
652 currentSession, err := a.sessions.Get(ctx, sessionID)
653 if err != nil {
654 return fmt.Errorf("failed to get session: %w", err)
655 }
656 msgs, err := a.getSessionMessages(ctx, currentSession)
657 if err != nil {
658 return err
659 }
660 if len(msgs) == 0 {
661 // Nothing to summarize.
662 return nil
663 }
664
665 aiMsgs, _ := a.preparePrompt(msgs, largeModel.CatwalkCfg.SupportsImages)
666
667 genCtx, cancel := context.WithCancel(ctx)
668 a.activeRequests.Set(sessionID, cancel)
669 defer a.activeRequests.Del(sessionID)
670 defer cancel()
671 defer func() {
672 if flushErr := a.messages.FlushAll(ctx); flushErr != nil {
673 slog.Error("Failed to flush pending message updates after summarize", "error", flushErr)
674 }
675 }()
676
677 agent := fantasy.NewAgent(
678 largeModel.Model,
679 fantasy.WithSystemPrompt(string(summaryPrompt)),
680 fantasy.WithUserAgent(userAgent),
681 )
682 summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
683 Role: message.Assistant,
684 Model: largeModel.Model.Model(),
685 Provider: largeModel.Model.Provider(),
686 IsSummaryMessage: true,
687 })
688 if err != nil {
689 return err
690 }
691
692 summaryPromptText := buildSummaryPrompt(currentSession.Todos)
693
694 resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
695 Prompt: summaryPromptText,
696 Messages: aiMsgs,
697 ProviderOptions: opts,
698 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
699 prepared.Messages = options.Messages
700 if systemPromptPrefix != "" {
701 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(systemPromptPrefix)}, prepared.Messages...)
702 }
703 return callContext, prepared, nil
704 },
705 OnReasoningDelta: func(id string, text string) error {
706 summaryMessage.AppendReasoningContent(text)
707 return a.messages.Update(genCtx, summaryMessage)
708 },
709 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
710 // Handle anthropic signature.
711 if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
712 if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
713 summaryMessage.AppendReasoningSignature(signature.Signature)
714 }
715 }
716 summaryMessage.FinishThinking()
717 return a.messages.Update(genCtx, summaryMessage)
718 },
719 OnTextDelta: func(id, text string) error {
720 summaryMessage.AppendContent(text)
721 return a.messages.Update(genCtx, summaryMessage)
722 },
723 })
724 if err != nil {
725 isCancelErr := errors.Is(err, context.Canceled)
726 if isCancelErr {
727 // User cancelled summarize we need to remove the summary message.
728 deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
729 return deleteErr
730 }
731 // Mark the summary message as finished with an error so the UI
732 // stops spinning.
733 summaryMessage.AddFinish(message.FinishReasonError, "Summarization Error", err.Error())
734 if updateErr := a.messages.Update(ctx, summaryMessage); updateErr != nil {
735 return updateErr
736 }
737 return err
738 }
739
740 summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
741 err = a.messages.Update(genCtx, summaryMessage)
742 if err != nil {
743 return err
744 }
745
746 var openrouterCost *float64
747 for _, step := range resp.Steps {
748 stepCost := a.openrouterCost(step.ProviderMetadata)
749 if stepCost != nil {
750 newCost := *stepCost
751 if openrouterCost != nil {
752 newCost += *openrouterCost
753 }
754 openrouterCost = &newCost
755 }
756 }
757
758 a.updateSessionUsage(largeModel, ¤tSession, resp.TotalUsage, openrouterCost, false)
759
760 // Just in case, get just the last usage info.
761 usage := resp.Response.Usage
762 currentSession.SummaryMessageID = summaryMessage.ID
763 currentSession.CompletionTokens = usage.OutputTokens
764 currentSession.PromptTokens = 0
765 _, err = a.sessions.Save(genCtx, currentSession)
766 if err != nil {
767 return err
768 }
769
770 // Release the active request before processing queued messages so that
771 // Run() does not see the session as busy.
772 a.activeRequests.Del(sessionID)
773 cancel()
774
775 // Process any messages that were queued while summarizing.
776 queuedMessages, ok := a.messageQueue.Get(sessionID)
777 if !ok || len(queuedMessages) == 0 {
778 return nil
779 }
780 firstQueuedMessage := queuedMessages[0]
781 a.messageQueue.Set(sessionID, queuedMessages[1:])
782 _, qErr := a.Run(ctx, firstQueuedMessage)
783 return qErr
784}
785
786func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
787 if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
788 return fantasy.ProviderOptions{}
789 }
790 return fantasy.ProviderOptions{
791 anthropic.Name: &anthropic.ProviderCacheControlOptions{
792 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
793 },
794 bedrock.Name: &anthropic.ProviderCacheControlOptions{
795 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
796 },
797 vercel.Name: &anthropic.ProviderCacheControlOptions{
798 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
799 },
800 }
801}
802
803func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
804 parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
805 var attachmentParts []message.ContentPart
806 for _, attachment := range call.Attachments {
807 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
808 }
809 parts = append(parts, attachmentParts...)
810 msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
811 Role: message.User,
812 Parts: parts,
813 })
814 if err != nil {
815 return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
816 }
817 return msg, nil
818}
819
820func (a *sessionAgent) preparePrompt(msgs []message.Message, supportsImages bool, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
821 var history []fantasy.Message
822 if !a.isSubAgent {
823 history = append(history, fantasy.NewUserMessage(
824 fmt.Sprintf(
825 "<system_reminder>%s</system_reminder>",
826 `This is a reminder that your todo list is currently empty. DO NOT mention this to the user explicitly because they are already aware.
827If you are working on tasks that would benefit from a todo list please use the "todos" tool to create one.
828If not, please feel free to ignore. Again do not mention this message to the user.`,
829 ),
830 ))
831 }
832 // Collect all tool call IDs present in assistant messages and all tool
833 // result IDs present in tool messages. This lets us detect both orphaned
834 // tool results (result without a call) and orphaned tool calls (call
835 // without a result).
836 knownToolCallIDs := make(map[string]struct{})
837 knownToolResultIDs := make(map[string]struct{})
838 for _, m := range msgs {
839 switch m.Role {
840 case message.Assistant:
841 for _, tc := range m.ToolCalls() {
842 knownToolCallIDs[tc.ID] = struct{}{}
843 }
844 case message.Tool:
845 for _, tr := range m.ToolResults() {
846 knownToolResultIDs[tr.ToolCallID] = struct{}{}
847 }
848 }
849 }
850
851 for _, m := range msgs {
852 if len(m.Parts) == 0 {
853 continue
854 }
855 // Assistant message without content or tool calls (cancelled before it returned anything).
856 if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
857 continue
858 }
859 if m.Role == message.Tool {
860 if msg, ok := filterOrphanedToolResults(m, knownToolCallIDs); ok {
861 history = append(history, msg)
862 }
863 continue
864 }
865 aiMsgs := m.ToAIMessage()
866 if !supportsImages {
867 for i := range aiMsgs {
868 if aiMsgs[i].Role == fantasy.MessageRoleUser {
869 aiMsgs[i].Content = filterFileParts(aiMsgs[i].Content)
870 }
871 }
872 }
873 history = append(history, aiMsgs...)
874
875 if m.Role == message.Assistant {
876 if msg, ok := syntheticToolResultsForOrphanedCalls(m, knownToolResultIDs); ok {
877 history = append(history, msg)
878 }
879 }
880 }
881
882 var files []fantasy.FilePart
883 for _, attachment := range attachments {
884 if attachment.IsText() {
885 continue
886 }
887 files = append(files, fantasy.FilePart{
888 Filename: attachment.FileName,
889 Data: attachment.Content,
890 MediaType: attachment.MimeType,
891 })
892 }
893
894 return history, files
895}
896
897// filterFileParts removes fantasy.FilePart entries from a slice of message
898// parts. Used to strip image attachments from historical user messages when
899// the current model does not support them.
900func filterFileParts(parts []fantasy.MessagePart) []fantasy.MessagePart {
901 filtered := make([]fantasy.MessagePart, 0, len(parts))
902 for _, part := range parts {
903 if _, ok := fantasy.AsMessagePart[fantasy.FilePart](part); ok {
904 continue
905 }
906 filtered = append(filtered, part)
907 }
908 return filtered
909}
910
911// filterOrphanedToolResults converts a tool message to a fantasy.Message,
912// dropping any tool result parts whose tool_call_id has no matching tool call
913// in the known set. An orphaned result causes API validation to fail on every
914// subsequent turn, permanently locking the session. Returns the filtered
915// message and true if at least one valid part remains.
916func filterOrphanedToolResults(m message.Message, knownToolCallIDs map[string]struct{}) (fantasy.Message, bool) {
917 aiMsgs := m.ToAIMessage()
918 if len(aiMsgs) == 0 {
919 return fantasy.Message{}, false
920 }
921 var validParts []fantasy.MessagePart
922 for _, part := range aiMsgs[0].Content {
923 tr, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
924 if !ok {
925 validParts = append(validParts, part)
926 continue
927 }
928 if _, known := knownToolCallIDs[tr.ToolCallID]; known {
929 validParts = append(validParts, part)
930 } else {
931 slog.Warn(
932 "Dropping orphaned tool result with no matching tool call",
933 "tool_call_id", tr.ToolCallID,
934 )
935 }
936 }
937 if len(validParts) == 0 {
938 return fantasy.Message{}, false
939 }
940 msg := aiMsgs[0]
941 msg.Content = validParts
942 return msg, true
943}
944
945// syntheticToolResultsForOrphanedCalls returns a tool message containing
946// synthetic tool results for any tool calls in the assistant message that
947// have no matching result in knownToolResultIDs. LLM APIs require every
948// tool_use to be immediately followed by a tool_result; an interrupted
949// session can leave orphaned tool_use blocks that permanently lock the
950// conversation. Returns the message and true if any synthetic results were
951// produced.
952func syntheticToolResultsForOrphanedCalls(m message.Message, knownToolResultIDs map[string]struct{}) (fantasy.Message, bool) {
953 var syntheticParts []fantasy.MessagePart
954 for _, tc := range m.ToolCalls() {
955 if _, hasResult := knownToolResultIDs[tc.ID]; hasResult {
956 continue
957 }
958 slog.Warn(
959 "Injecting synthetic tool result for orphaned tool call",
960 "tool_call_id", tc.ID,
961 "tool_name", tc.Name,
962 )
963 syntheticParts = append(syntheticParts, fantasy.ToolResultPart{
964 ToolCallID: tc.ID,
965 Output: fantasy.ToolResultOutputContentError{
966 Error: errors.New("tool call was interrupted and did not produce a result, you may retry this call if the result is still needed"),
967 },
968 })
969 }
970 if len(syntheticParts) == 0 {
971 return fantasy.Message{}, false
972 }
973 return fantasy.Message{
974 Role: fantasy.MessageRoleTool,
975 Content: syntheticParts,
976 }, true
977}
978
979func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
980 msgs, err := a.messages.List(ctx, session.ID)
981 if err != nil {
982 return nil, fmt.Errorf("failed to list messages: %w", err)
983 }
984
985 if session.SummaryMessageID != "" {
986 summaryMsgIndex := -1
987 for i, msg := range msgs {
988 if msg.ID == session.SummaryMessageID {
989 summaryMsgIndex = i
990 break
991 }
992 }
993 if summaryMsgIndex != -1 {
994 msgs = msgs[summaryMsgIndex:]
995 msgs[0].Role = message.User
996 }
997 }
998 return msgs, nil
999}
1000
1001// generateTitle generates a session titled based on the initial prompt.
1002func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, userPrompt string) {
1003 if userPrompt == "" {
1004 return
1005 }
1006
1007 smallModel := a.smallModel.Get()
1008 largeModel := a.largeModel.Get()
1009 systemPromptPrefix := a.systemPromptPrefix.Get()
1010
1011 var maxOutputTokens int64 = 40
1012 if smallModel.CatwalkCfg.CanReason {
1013 maxOutputTokens = smallModel.CatwalkCfg.DefaultMaxTokens
1014 }
1015
1016 newAgent := func(m fantasy.LanguageModel, p []byte, tok int64) fantasy.Agent {
1017 return fantasy.NewAgent(
1018 m,
1019 fantasy.WithSystemPrompt(string(p)+"\n /no_think"),
1020 fantasy.WithMaxOutputTokens(tok),
1021 fantasy.WithUserAgent(userAgent),
1022 )
1023 }
1024
1025 streamCall := fantasy.AgentStreamCall{
1026 Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", userPrompt),
1027 PrepareStep: func(callCtx context.Context, opts fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
1028 prepared.Messages = opts.Messages
1029 if systemPromptPrefix != "" {
1030 prepared.Messages = append([]fantasy.Message{
1031 fantasy.NewSystemMessage(systemPromptPrefix),
1032 }, prepared.Messages...)
1033 }
1034 return callCtx, prepared, nil
1035 },
1036 }
1037
1038 // Use the small model to generate the title.
1039 model := smallModel
1040 agent := newAgent(model.Model, titlePrompt, maxOutputTokens)
1041 resp, err := agent.Stream(ctx, streamCall)
1042 if err == nil {
1043 // We successfully generated a title with the small model.
1044 slog.Debug("Generated title with small model")
1045 } else {
1046 // It didn't work. Let's try with the big model.
1047 slog.Error("Error generating title with small model; trying big model", "err", err)
1048 model = largeModel
1049 agent = newAgent(model.Model, titlePrompt, maxOutputTokens)
1050 resp, err = agent.Stream(ctx, streamCall)
1051 if err == nil {
1052 slog.Debug("Generated title with large model")
1053 } else {
1054 // Welp, the large model didn't work either. Use the default
1055 // session name and return.
1056 slog.Error("Error generating title with large model", "err", err)
1057 saveErr := a.sessions.Rename(ctx, sessionID, DefaultSessionName)
1058 if saveErr != nil {
1059 slog.Error("Failed to save session title", "error", saveErr)
1060 }
1061 return
1062 }
1063 }
1064
1065 if resp == nil {
1066 // Actually, we didn't get a response so we can't. Use the default
1067 // session name and return.
1068 slog.Error("Response is nil; can't generate title")
1069 saveErr := a.sessions.Rename(ctx, sessionID, DefaultSessionName)
1070 if saveErr != nil {
1071 slog.Error("Failed to save session title", "error", saveErr)
1072 }
1073 return
1074 }
1075
1076 // Clean up title.
1077 var title string
1078 title = strings.ReplaceAll(resp.Response.Content.Text(), "\n", " ")
1079
1080 // Remove thinking tags if present.
1081 title = thinkTagRegex.ReplaceAllString(title, "")
1082 title = orphanThinkTagRegex.ReplaceAllString(title, "")
1083
1084 title = strings.TrimSpace(title)
1085 title = cmp.Or(title, DefaultSessionName)
1086
1087 // Calculate usage and cost.
1088 var openrouterCost *float64
1089 for _, step := range resp.Steps {
1090 stepCost := a.openrouterCost(step.ProviderMetadata)
1091 if stepCost != nil {
1092 newCost := *stepCost
1093 if openrouterCost != nil {
1094 newCost += *openrouterCost
1095 }
1096 openrouterCost = &newCost
1097 }
1098 }
1099
1100 modelConfig := model.CatwalkCfg
1101 cost := modelConfig.CostPer1MInCached/1e6*float64(resp.TotalUsage.CacheCreationTokens) +
1102 modelConfig.CostPer1MOutCached/1e6*float64(resp.TotalUsage.CacheReadTokens) +
1103 modelConfig.CostPer1MIn/1e6*float64(resp.TotalUsage.InputTokens) +
1104 modelConfig.CostPer1MOut/1e6*float64(resp.TotalUsage.OutputTokens)
1105
1106 // Use override cost if available (e.g., from OpenRouter).
1107 if openrouterCost != nil {
1108 cost = *openrouterCost
1109 }
1110
1111 // Skip cost accumulation
1112 if model.FlatRate {
1113 cost = 0
1114 }
1115
1116 promptTokens := resp.TotalUsage.InputTokens + resp.TotalUsage.CacheCreationTokens
1117 completionTokens := resp.TotalUsage.OutputTokens
1118
1119 // Atomically update only title and usage fields to avoid overriding other
1120 // concurrent session updates.
1121 saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, title, promptTokens, completionTokens, cost)
1122 if saveErr != nil {
1123 slog.Error("Failed to save session title and usage", "error", saveErr)
1124 return
1125 }
1126}
1127
1128func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
1129 openrouterMetadata, ok := metadata[openrouter.Name]
1130 if !ok {
1131 return nil
1132 }
1133
1134 opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
1135 if !ok {
1136 return nil
1137 }
1138 return &opts.Usage.Cost
1139}
1140
1141func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64, estimated bool) {
1142 modelConfig := model.CatwalkCfg
1143 cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
1144 modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
1145 modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
1146 modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
1147
1148 eventCost := cost
1149 if estimated {
1150 eventCost = 0
1151 }
1152 a.eventTokensUsed(session.ID, model, usage, eventCost)
1153
1154 if estimated {
1155 cost = 0
1156 } else {
1157 // Use override cost if available (e.g., from OpenRouter).
1158 if overrideCost != nil {
1159 cost = *overrideCost
1160 }
1161
1162 // Skip cost accumulation
1163 if model.FlatRate {
1164 cost = 0
1165 }
1166 }
1167
1168 session.Cost += cost
1169 if !usageIsZero(usage) {
1170 session.CompletionTokens = usage.OutputTokens
1171 session.PromptTokens = usage.InputTokens + usage.CacheReadTokens
1172 }
1173}
1174
1175func (a *sessionAgent) Cancel(sessionID string) {
1176 // Cancel regular requests. Don't use Take() here - we need the entry to
1177 // remain in activeRequests so IsBusy() returns true until the goroutine
1178 // fully completes (including error handling that may access the DB).
1179 // The defer in processRequest will clean up the entry.
1180 if cancel, ok := a.activeRequests.Get(sessionID); ok && cancel != nil {
1181 slog.Debug("Request cancellation initiated", "session_id", sessionID)
1182 cancel()
1183 }
1184
1185 // Also check for summarize requests.
1186 if cancel, ok := a.activeRequests.Get(sessionID + "-summarize"); ok && cancel != nil {
1187 slog.Debug("Summarize cancellation initiated", "session_id", sessionID)
1188 cancel()
1189 }
1190
1191 if a.QueuedPrompts(sessionID) > 0 {
1192 slog.Debug("Clearing queued prompts", "session_id", sessionID)
1193 a.messageQueue.Del(sessionID)
1194 }
1195}
1196
1197func (a *sessionAgent) ClearQueue(sessionID string) {
1198 if a.QueuedPrompts(sessionID) > 0 {
1199 slog.Debug("Clearing queued prompts", "session_id", sessionID)
1200 a.messageQueue.Del(sessionID)
1201 }
1202}
1203
1204func (a *sessionAgent) CancelAll() {
1205 if !a.IsBusy() {
1206 return
1207 }
1208 for key := range a.activeRequests.Seq2() {
1209 a.Cancel(key) // key is sessionID
1210 }
1211
1212 timeout := time.After(5 * time.Second)
1213 for a.IsBusy() {
1214 select {
1215 case <-timeout:
1216 return
1217 default:
1218 time.Sleep(200 * time.Millisecond)
1219 }
1220 }
1221}
1222
1223func (a *sessionAgent) IsBusy() bool {
1224 var busy bool
1225 for cancelFunc := range a.activeRequests.Seq() {
1226 if cancelFunc != nil {
1227 busy = true
1228 break
1229 }
1230 }
1231 return busy
1232}
1233
1234func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
1235 _, busy := a.activeRequests.Get(sessionID)
1236 return busy
1237}
1238
1239func (a *sessionAgent) QueuedPrompts(sessionID string) int {
1240 l, ok := a.messageQueue.Get(sessionID)
1241 if !ok {
1242 return 0
1243 }
1244 return len(l)
1245}
1246
1247func (a *sessionAgent) QueuedPromptsList(sessionID string) []string {
1248 l, ok := a.messageQueue.Get(sessionID)
1249 if !ok {
1250 return nil
1251 }
1252 prompts := make([]string, len(l))
1253 for i, call := range l {
1254 prompts[i] = call.Prompt
1255 }
1256 return prompts
1257}
1258
1259func (a *sessionAgent) SetModels(large Model, small Model) {
1260 a.largeModel.Set(large)
1261 a.smallModel.Set(small)
1262}
1263
1264func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
1265 a.tools.SetSlice(tools)
1266}
1267
1268func (a *sessionAgent) SetSystemPrompt(systemPrompt string) {
1269 a.systemPrompt.Set(systemPrompt)
1270}
1271
1272func (a *sessionAgent) Model() Model {
1273 return a.largeModel.Get()
1274}
1275
1276// convertToToolResult converts a fantasy tool result to a message tool result.
1277func (a *sessionAgent) convertToToolResult(result fantasy.ToolResultContent) message.ToolResult {
1278 baseResult := message.ToolResult{
1279 ToolCallID: result.ToolCallID,
1280 Name: result.ToolName,
1281 Metadata: result.ClientMetadata,
1282 }
1283
1284 switch result.Result.GetType() {
1285 case fantasy.ToolResultContentTypeText:
1286 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result); ok {
1287 baseResult.Content = r.Text
1288 }
1289 case fantasy.ToolResultContentTypeError:
1290 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result); ok {
1291 baseResult.Content = r.Error.Error()
1292 baseResult.IsError = true
1293 }
1294 case fantasy.ToolResultContentTypeMedia:
1295 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](result.Result); ok {
1296 if !stringext.IsValidBase64(r.Data) {
1297 slog.Warn(
1298 "Tool returned media with invalid base64 data, discarding image",
1299 "tool", result.ToolName,
1300 "tool_call_id", result.ToolCallID,
1301 )
1302 baseResult.Content = "Tool returned image data with invalid encoding"
1303 baseResult.IsError = true
1304 } else {
1305 content := r.Text
1306 if content == "" {
1307 content = fmt.Sprintf("Loaded %s content", r.MediaType)
1308 }
1309 baseResult.Content = content
1310 baseResult.Data = r.Data
1311 baseResult.MIMEType = r.MediaType
1312 }
1313 }
1314 }
1315
1316 return baseResult
1317}
1318
1319// workaroundProviderMediaLimitations converts media content in tool results to
1320// user messages for providers that don't natively support images in tool results.
1321//
1322// Problem: OpenAI, Google, OpenRouter, and other OpenAI-compatible providers
1323// don't support sending images/media in tool result messages - they only accept
1324// text in tool results. However, they DO support images in user messages.
1325//
1326// If we send media in tool results to these providers, the API returns an error.
1327//
1328// Solution: For these providers, we:
1329// 1. Replace the media in the tool result with a text placeholder
1330// 2. Inject a user message immediately after with the image as a file attachment
1331// 3. This maintains the tool execution flow while working around API limitations
1332//
1333// Anthropic and Bedrock support images natively in tool results, so we skip
1334// this workaround for them.
1335//
1336// Example transformation:
1337//
1338// BEFORE: [tool result: image data]
1339// AFTER: [tool result: "Image loaded - see attached"], [user: image attachment]
1340func (a *sessionAgent) workaroundProviderMediaLimitations(messages []fantasy.Message, largeModel Model) []fantasy.Message {
1341 providerSupportsMedia := largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderAnthropic) ||
1342 largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderBedrock)
1343
1344 if providerSupportsMedia {
1345 return messages
1346 }
1347
1348 convertedMessages := make([]fantasy.Message, 0, len(messages))
1349
1350 for _, msg := range messages {
1351 if msg.Role != fantasy.MessageRoleTool {
1352 convertedMessages = append(convertedMessages, msg)
1353 continue
1354 }
1355
1356 textParts := make([]fantasy.MessagePart, 0, len(msg.Content))
1357 var mediaFiles []fantasy.FilePart
1358
1359 for _, part := range msg.Content {
1360 toolResult, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
1361 if !ok {
1362 textParts = append(textParts, part)
1363 continue
1364 }
1365
1366 if media, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](toolResult.Output); ok {
1367 decoded, err := base64.StdEncoding.DecodeString(media.Data)
1368 if err != nil {
1369 slog.Warn("Failed to decode media data", "error", err)
1370 textParts = append(textParts, part)
1371 continue
1372 }
1373
1374 mediaFiles = append(mediaFiles, fantasy.FilePart{
1375 Data: decoded,
1376 MediaType: media.MediaType,
1377 Filename: fmt.Sprintf("tool-result-%s", toolResult.ToolCallID),
1378 })
1379
1380 textParts = append(textParts, fantasy.ToolResultPart{
1381 ToolCallID: toolResult.ToolCallID,
1382 Output: fantasy.ToolResultOutputContentText{
1383 Text: "[Image/media content loaded - see attached file]",
1384 },
1385 ProviderOptions: toolResult.ProviderOptions,
1386 })
1387 } else {
1388 textParts = append(textParts, part)
1389 }
1390 }
1391
1392 convertedMessages = append(convertedMessages, fantasy.Message{
1393 Role: fantasy.MessageRoleTool,
1394 Content: textParts,
1395 })
1396
1397 if len(mediaFiles) > 0 {
1398 convertedMessages = append(convertedMessages, fantasy.NewUserMessage(
1399 "Here is the media content from the tool result:",
1400 mediaFiles...,
1401 ))
1402 }
1403 }
1404
1405 return convertedMessages
1406}
1407
1408// buildSummaryPrompt constructs the prompt text for session summarization.
1409func buildSummaryPrompt(todos []session.Todo) string {
1410 var sb strings.Builder
1411 sb.WriteString("Provide a detailed summary of our conversation above.")
1412 if len(todos) > 0 {
1413 sb.WriteString("\n\n## Current Todo List\n\n")
1414 for _, t := range todos {
1415 fmt.Fprintf(&sb, "- [%s] %s\n", t.Status, t.Content)
1416 }
1417 sb.WriteString("\nInclude these tasks and their statuses in your summary. ")
1418 sb.WriteString("Instruct the resuming assistant to use the `todos` tool to continue tracking progress on these tasks.")
1419 }
1420 return sb.String()
1421}
1422
1423func providerRetryLogFields(err *fantasy.ProviderError, delay time.Duration) []any {
1424 fields := []any{
1425 "retry_delay", delay.String(),
1426 }
1427 if err == nil {
1428 return fields
1429 }
1430 fields = append(fields, "status_code", err.StatusCode)
1431 if err.Title != "" {
1432 fields = append(fields, "title", err.Title)
1433 }
1434 if err.Message != "" {
1435 fields = append(fields, "message", err.Message)
1436 }
1437 return fields
1438}