1// Package agent is the core orchestration layer for Crush AI agents.
2//
3// It provides session-based AI agent functionality for managing
4// conversations, tool execution, and message handling. It coordinates
5// interactions between language models, messages, sessions, and tools while
6// handling features like automatic summarization, queuing, and token
7// management.
8package agent
9
10import (
11 "cmp"
12 "context"
13 _ "embed"
14 "encoding/base64"
15 "errors"
16 "fmt"
17 "log/slog"
18 "net/http"
19 "os"
20 "regexp"
21 "strconv"
22 "strings"
23 "sync"
24 "time"
25
26 "charm.land/catwalk/pkg/catwalk"
27 "charm.land/fantasy"
28 "charm.land/fantasy/providers/anthropic"
29 "charm.land/fantasy/providers/bedrock"
30 "charm.land/fantasy/providers/google"
31 "charm.land/fantasy/providers/openai"
32 "charm.land/fantasy/providers/openrouter"
33 "charm.land/fantasy/providers/vercel"
34 "charm.land/lipgloss/v2"
35 "github.com/charmbracelet/crush/internal/agent/hyper"
36 "github.com/charmbracelet/crush/internal/agent/notify"
37 "github.com/charmbracelet/crush/internal/agent/tools"
38 "github.com/charmbracelet/crush/internal/agent/tools/mcp"
39 "github.com/charmbracelet/crush/internal/config"
40 "github.com/charmbracelet/crush/internal/csync"
41 "github.com/charmbracelet/crush/internal/message"
42 "github.com/charmbracelet/crush/internal/pubsub"
43 "github.com/charmbracelet/crush/internal/session"
44 "github.com/charmbracelet/crush/internal/stringext"
45 "github.com/charmbracelet/crush/internal/version"
46 "github.com/charmbracelet/x/exp/charmtone"
47)
48
// DefaultSessionName is the default name given to a session.
const DefaultSessionName = "Untitled Session"

// Auto-summarization thresholds used by Run's stop condition. For context
// windows larger than largeContextWindowThreshold tokens, summarization
// triggers once at most largeContextWindowBuffer tokens remain; for smaller
// windows it triggers once the remaining tokens fall to
// smallContextWindowRatio of the window.
const (
	largeContextWindowThreshold = 200000
	largeContextWindowBuffer    = 20000
	smallContextWindowRatio     = 0.2
)
57
// userAgent is passed to every fantasy agent built by Run and Summarize
// (via fantasy.WithUserAgent) and embeds the running Crush version.
var userAgent = fmt.Sprintf("Charm-Crush/%s (https://charm.land/crush)", version.Version)

// titlePrompt is the contents of templates/title.md, embedded at build time.
// It is not referenced in this part of the file; presumably it is the system
// prompt for title generation — TODO confirm against generateTitle.
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt is the contents of templates/summary.md, embedded at build
// time; Summarize uses it as the system prompt for the summarization agent.
//
//go:embed templates/summary.md
var summaryPrompt []byte
65
// thinkTagRegex matches a complete <think>...</think> span. The (?s) flag lets
// "." cross newlines and the lazy quantifier stops at the first closing tag.
// Used to remove <think> tags from generated titles.
var thinkTagRegex = regexp.MustCompile(`(?s)<think>.*?</think>`)

// orphanThinkTagRegex matches a stray opening or closing think tag on its
// own. Used to remove <think> tags from generated titles.
var orphanThinkTagRegex = regexp.MustCompile(`</?think>`)
71
// SessionAgentCall is a single prompt submitted to a SessionAgent via Run.
type SessionAgentCall struct {
	// SessionID identifies the session to run in; Run rejects an empty ID
	// with ErrSessionMissing.
	SessionID string
	// Prompt is the user's text. Run rejects an empty Prompt with
	// ErrEmptyPrompt unless one of the attachments carries text.
	Prompt string
	// ProviderOptions are forwarded to the model stream call, and to
	// Summarize when auto-summarization triggers.
	ProviderOptions fantasy.ProviderOptions
	// Attachments are files sent with the prompt. Text attachments are folded
	// into the prompt text; the rest are sent as file parts.
	Attachments []message.Attachment
	// MaxOutputTokens caps the response length. Values <= 0 are not sent at
	// all, because some providers (e.g. LM Studio) reject them.
	MaxOutputTokens int64
	// Sampling parameters, forwarded unchanged to the stream call.
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
	// NonInteractive suppresses the "agent finished" notification Run
	// publishes at the end of a successful turn.
	NonInteractive bool
}
85
86type SessionAgent interface {
87 Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
88 SetModels(large Model, small Model)
89 SetTools(tools []fantasy.AgentTool)
90 SetSystemPrompt(systemPrompt string)
91 Cancel(sessionID string)
92 CancelAll()
93 IsSessionBusy(sessionID string) bool
94 IsBusy() bool
95 QueuedPrompts(sessionID string) int
96 QueuedPromptsList(sessionID string) []string
97 ClearQueue(sessionID string)
98 Summarize(context.Context, string, fantasy.ProviderOptions) error
99 Model() Model
100}
101
// Model bundles a language model client with its catalog metadata and the
// user's selected configuration.
type Model struct {
	// Model is the client used to issue requests.
	Model fantasy.LanguageModel
	// CatwalkCfg is the catalog entry for the model; this file reads its
	// ContextWindow, SupportsImages and Name.
	CatwalkCfg catwalk.Model
	// ModelCfg is the user's selection; this file reads its Model and
	// Provider fields.
	ModelCfg config.SelectedModel
	// FlatRate is not read in this part of the file; presumably it marks
	// flat-rate (non-metered) billing — TODO confirm.
	FlatRate bool
}
108
// sessionAgent is the SessionAgent implementation. Settings that can change
// after construction (models, prompts, tools) live in csync containers so the
// Set* methods do not race with Run/Summarize, which copy them on entry.
type sessionAgent struct {
	largeModel         *csync.Value[Model]
	smallModel         *csync.Value[Model]
	systemPromptPrefix *csync.Value[string]
	systemPrompt       *csync.Value[string]
	tools              *csync.Slice[fantasy.AgentTool]

	// isSubAgent omits the todo-list reminder from prepared history.
	isSubAgent bool
	sessions   session.Service
	messages   message.Service
	// disableAutoSummarize turns off the context-window stop condition in Run.
	disableAutoSummarize bool
	// isYolo is not read in this part of the file — TODO document its use.
	isYolo bool
	// notify may be nil; every publish site checks for that.
	notify pubsub.Publisher[notify.Notification]

	// messageQueue holds calls submitted while a session was busy, keyed by
	// session ID.
	messageQueue *csync.Map[string, []SessionAgentCall]
	// activeRequests holds the cancel func of each session's in-flight
	// request; an entry here is what makes the session report as busy.
	activeRequests *csync.Map[string, context.CancelFunc]
}
126
// SessionAgentOptions configures NewSessionAgent. Each field initializes the
// sessionAgent field of the same name.
type SessionAgentOptions struct {
	LargeModel Model
	SmallModel Model
	// SystemPromptPrefix, when non-empty, is prepended as a separate system
	// message on every step.
	SystemPromptPrefix string
	SystemPrompt       string
	// IsSubAgent omits the todo-list reminder from prepared history.
	IsSubAgent bool
	// DisableAutoSummarize turns off automatic summarization near the
	// context-window limit.
	DisableAutoSummarize bool
	IsYolo               bool
	Sessions             session.Service
	Messages             message.Service
	Tools                []fantasy.AgentTool
	// Notify is optional; nil disables notifications.
	Notify pubsub.Publisher[notify.Notification]
}
140
141func NewSessionAgent(
142 opts SessionAgentOptions,
143) SessionAgent {
144 return &sessionAgent{
145 largeModel: csync.NewValue(opts.LargeModel),
146 smallModel: csync.NewValue(opts.SmallModel),
147 systemPromptPrefix: csync.NewValue(opts.SystemPromptPrefix),
148 systemPrompt: csync.NewValue(opts.SystemPrompt),
149 isSubAgent: opts.IsSubAgent,
150 sessions: opts.Sessions,
151 messages: opts.Messages,
152 disableAutoSummarize: opts.DisableAutoSummarize,
153 tools: csync.NewSliceFrom(opts.Tools),
154 isYolo: opts.IsYolo,
155 notify: opts.Notify,
156 messageQueue: csync.NewMap[string, []SessionAgentCall](),
157 activeRequests: csync.NewMap[string, context.CancelFunc](),
158 }
159}
160
// Run sends call.Prompt (plus attachments) to the large model for the
// session call.SessionID and streams the response. The user message,
// per-step assistant messages and tool results are persisted as they arrive.
//
// If the session already has an active request, the call is appended to the
// session's queue and Run returns (nil, nil) immediately. Queued calls are
// injected into the next step of the active request (see PrepareStep) or are
// run once it finishes.
//
// Run returns ErrEmptyPrompt when there is neither a prompt nor a text
// attachment, and ErrSessionMissing when call.SessionID is empty. On a
// stream error, unfinished tool calls are closed with synthetic error
// results. The assistant message is then finished with a canceled/error
// reason before the error is returned.
func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
	if call.Prompt == "" && !message.ContainsTextAttachment(call.Attachments) {
		return nil, ErrEmptyPrompt
	}
	if call.SessionID == "" {
		return nil, ErrSessionMissing
	}

	// Queue the message if busy; the active request picks it up in
	// PrepareStep or after it finishes.
	if a.IsSessionBusy(call.SessionID) {
		existing, ok := a.messageQueue.Get(call.SessionID)
		if !ok {
			existing = []SessionAgentCall{}
		}
		existing = append(existing, call)
		a.messageQueue.Set(call.SessionID, existing)
		return nil, nil
	}

	// Copy mutable fields under lock to avoid races with SetTools/SetModels.
	agentTools := a.tools.Copy()
	largeModel := a.largeModel.Get()
	systemPrompt := a.systemPrompt.Get()
	promptPrefix := a.systemPromptPrefix.Get()
	var instructions strings.Builder

	// Collect instructions advertised by connected MCP servers.
	for _, server := range mcp.GetStates() {
		if server.State != mcp.StateConnected {
			continue
		}
		if s := server.Client.InitializeResult().Instructions; s != "" {
			instructions.WriteString(s)
			instructions.WriteString("\n\n")
		}
	}

	if s := instructions.String(); s != "" {
		systemPrompt += "\n\n<mcp-instructions>\n" + s + "\n</mcp-instructions>"
	}

	if len(agentTools) > 0 {
		// Add Anthropic caching to the last tool.
		agentTools[len(agentTools)-1].SetProviderOptions(a.getCacheControlOptions())
	}

	agent := fantasy.NewAgent(
		largeModel.Model,
		fantasy.WithSystemPrompt(systemPrompt),
		fantasy.WithTools(agentTools...),
		fantasy.WithUserAgent(userAgent),
	)

	// sessionLock serializes access to stepMessages and currentSession from
	// the stream callbacks below.
	sessionLock := sync.Mutex{}
	currentSession, err := a.sessions.Get(ctx, call.SessionID)
	if err != nil {
		return nil, fmt.Errorf("failed to get session: %w", err)
	}

	// History is loaded before the new user message is stored; the new
	// prompt is sent separately via AgentStreamCall.Prompt.
	msgs, err := a.getSessionMessages(ctx, currentSession)
	if err != nil {
		return nil, fmt.Errorf("failed to get session messages: %w", err)
	}

	var wg sync.WaitGroup
	// Generate title if first message.
	if len(msgs) == 0 {
		titleCtx := ctx // Copy to avoid race with ctx reassignment below.
		wg.Go(func() {
			a.generateTitle(titleCtx, call.SessionID, call.Prompt)
		})
	}
	// Title generation runs concurrently with the stream; wait for it on
	// every return path.
	defer wg.Wait()

	// Add the user message to the session.
	_, err = a.createUserMessage(ctx, call)
	if err != nil {
		return nil, err
	}

	// Add the session to the context.
	ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)

	// Registering the cancel func marks the session busy and lets
	// Cancel/CancelAll stop the stream.
	genCtx, cancel := context.WithCancel(ctx)
	a.activeRequests.Set(call.SessionID, cancel)

	defer cancel()
	defer a.activeRequests.Del(call.SessionID)
	// Drain any debounced message updates before returning. message.Service
	// already flushes synchronously on terminal updates, but a defer here
	// guarantees the contract at every Run exit (success, error, panic
	// recovery upstream) without callers needing to know.
	defer func() {
		if flushErr := a.messages.FlushAll(ctx); flushErr != nil {
			slog.Error("Failed to flush pending message updates after run", "error", flushErr)
		}
	}()

	history, files := a.preparePrompt(msgs, largeModel.CatwalkCfg.SupportsImages, call.Attachments...)

	startTime := time.Now()
	a.eventPromptSent(call.SessionID)

	// State shared with the stream callbacks. currentAssistant is the
	// assistant message of the current step, created in PrepareStep.
	var currentAssistant *message.Message
	var stepMessages []fantasy.Message
	var shouldSummarize bool
	// Don't send MaxOutputTokens if 0 — some providers (e.g. LM Studio) reject it
	var maxOutputTokens *int64
	if call.MaxOutputTokens > 0 {
		maxOutputTokens = &call.MaxOutputTokens
	}
	result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
		Prompt:           message.PromptWithTextAttachments(call.Prompt, call.Attachments),
		Files:            files,
		Messages:         history,
		ProviderOptions:  call.ProviderOptions,
		MaxOutputTokens:  maxOutputTokens,
		TopP:             call.TopP,
		Temperature:      call.Temperature,
		PresencePenalty:  call.PresencePenalty,
		TopK:             call.TopK,
		FrequencyPenalty: call.FrequencyPenalty,
		// PrepareStep runs before each step. It refreshes tools, injects
		// queued prompts, applies cache-control markers and creates the
		// step's assistant message.
		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
			prepared.Messages = options.Messages
			// Clear markers from earlier steps; they are re-applied below.
			for i := range prepared.Messages {
				prepared.Messages[i].ProviderOptions = nil
			}

			// Use latest tools (updated by SetTools when MCP tools change).
			prepared.Tools = a.tools.Copy()

			// Inject prompts that were queued while this request was running.
			queuedCalls, _ := a.messageQueue.Get(call.SessionID)
			a.messageQueue.Del(call.SessionID)
			for _, queued := range queuedCalls {
				userMessage, createErr := a.createUserMessage(callContext, queued)
				if createErr != nil {
					return callContext, prepared, createErr
				}
				prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
			}

			prepared.Messages = a.workaroundProviderMediaLimitations(prepared.Messages, largeModel)

			lastSystemRoleInx := 0
			systemMessageUpdated := false
			for i, msg := range prepared.Messages {
				// Add cache control to the last message of the leading run of
				// system messages (index 0 if the first message is not a
				// system message).
				if msg.Role == fantasy.MessageRoleSystem {
					lastSystemRoleInx = i
				} else if !systemMessageUpdated {
					prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
					systemMessageUpdated = true
				}
				// Then add cache control to the last 2 messages.
				if i > len(prepared.Messages)-3 {
					prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
				}
			}

			// The prefix is prepended after the cache-control pass, so it
			// never carries a marker itself.
			if promptPrefix != "" {
				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(promptPrefix)}, prepared.Messages...)
			}

			// Snapshot the prepared prompt; OnStepFinish passes it to
			// fallbackStepUsage.
			sessionLock.Lock()
			stepMessages = cloneFantasyMessages(prepared.Messages)
			sessionLock.Unlock()

			var assistantMsg message.Message
			assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
				Role:     message.Assistant,
				Parts:    []message.ContentPart{},
				Model:    largeModel.ModelCfg.Model,
				Provider: largeModel.ModelCfg.Provider,
			})
			if err != nil {
				return callContext, prepared, err
			}
			// Expose step details to tools via the context.
			callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
			callContext = context.WithValue(callContext, tools.SupportsImagesContextKey, largeModel.CatwalkCfg.SupportsImages)
			callContext = context.WithValue(callContext, tools.ModelNameContextKey, largeModel.CatwalkCfg.Name)
			currentAssistant = &assistantMsg
			return callContext, prepared, err
		},
		OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
			currentAssistant.AppendReasoningContent(reasoning.Text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnReasoningDelta: func(id string, text string) error {
			currentAssistant.AppendReasoningContent(text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
			// Persist provider-specific reasoning metadata (signatures,
			// response data) alongside the reasoning text.
			// handle anthropic signature
			if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
				if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
					currentAssistant.AppendReasoningSignature(reasoning.Signature)
				}
			}
			if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
				if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
					currentAssistant.AppendThoughtSignature(reasoning.Signature, reasoning.ToolID)
				}
			}
			if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
				if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
					currentAssistant.SetReasoningResponsesData(reasoning)
				}
			}
			currentAssistant.FinishThinking()
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnTextDelta: func(id string, text string) error {
			// Strip leading newline from initial text content. This is
			// particularly important in non-interactive mode where leading
			// newlines are very visible.
			if len(currentAssistant.Parts) == 0 {
				text = strings.TrimPrefix(text, "\n")
			}

			currentAssistant.AppendContent(text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnToolInputStart: func(id string, toolName string) error {
			toolCall := message.ToolCall{
				ID:               id,
				Name:             toolName,
				ProviderExecuted: false,
				Finished:         false,
			}
			currentAssistant.AddToolCall(toolCall)
			// Use parent ctx instead of genCtx to ensure the update succeeds
			// even if the request is canceled mid-stream
			return a.messages.Update(ctx, *currentAssistant)
		},
		OnRetry: func(err *fantasy.ProviderError, delay time.Duration) {
			slog.Warn("Provider request failed, retrying", providerRetryLogFields(err, delay)...)
		},
		OnToolCall: func(tc fantasy.ToolCallContent) error {
			toolCall := message.ToolCall{
				ID:               tc.ToolCallID,
				Name:             tc.ToolName,
				Input:            tc.Input,
				ProviderExecuted: false,
				Finished:         true,
			}
			currentAssistant.AddToolCall(toolCall)
			// Use parent ctx instead of genCtx to ensure the update succeeds
			// even if the request is canceled mid-stream
			return a.messages.Update(ctx, *currentAssistant)
		},
		OnToolResult: func(result fantasy.ToolResultContent) error {
			toolResult := a.convertToToolResult(result)
			// Use parent ctx instead of genCtx to ensure the message is created
			// even if the request is canceled mid-stream
			_, createMsgErr := a.messages.Create(ctx, currentAssistant.SessionID, message.CreateMessageParams{
				Role: message.Tool,
				Parts: []message.ContentPart{
					toolResult,
				},
			})
			return createMsgErr
		},
		OnStepFinish: func(stepResult fantasy.StepResult) error {
			// Map the provider finish reason onto the stored message's.
			finishReason := message.FinishReasonUnknown
			switch stepResult.FinishReason {
			case fantasy.FinishReasonLength:
				finishReason = message.FinishReasonMaxTokens
			case fantasy.FinishReasonStop:
				finishReason = message.FinishReasonEndTurn
			case fantasy.FinishReasonToolCalls:
				finishReason = message.FinishReasonToolUse
			}
			// If a tool result halted the turn (e.g. a hook halt or a
			// permission denial), the step ends on FinishReasonToolCalls but
			// the model will not be called again. Treat it as the end of the
			// turn so the UI can render the assistant footer.
			if finishReason == message.FinishReasonToolUse {
				for _, tr := range stepResult.Content.ToolResults() {
					if tr.StopTurn {
						finishReason = message.FinishReasonEndTurn
						break
					}
				}
			}
			currentAssistant.AddFinish(finishReason, "", "")
			sessionLock.Lock()
			defer sessionLock.Unlock()

			// Re-read the session so usage is accumulated on its latest
			// persisted state, then save it back.
			updatedSession, getSessionErr := a.sessions.Get(ctx, call.SessionID)
			if getSessionErr != nil {
				return getSessionErr
			}
			usage, estimated := fallbackStepUsage(stepMessages, stepResult)
			a.updateSessionUsage(largeModel, &updatedSession, usage, a.openrouterCost(stepResult.ProviderMetadata), estimated)
			_, sessionErr := a.sessions.Save(ctx, updatedSession)
			if sessionErr != nil {
				return sessionErr
			}
			currentSession = updatedSession
			return a.messages.Update(genCtx, *currentAssistant)
		},
		StopWhen: []fantasy.StopCondition{
			// Stop (and summarize afterwards) when the session's token
			// usage is close to the model's context window.
			func(_ []fantasy.StepResult) bool {
				cw := int64(largeModel.CatwalkCfg.ContextWindow)
				// If context window is unknown (0), skip auto-summarize
				// to avoid immediately truncating custom/local models.
				if cw == 0 {
					return false
				}
				tokens := currentSession.CompletionTokens + currentSession.PromptTokens
				remaining := cw - tokens
				var threshold int64
				if cw > largeContextWindowThreshold {
					threshold = largeContextWindowBuffer
				} else {
					threshold = int64(float64(cw) * smallContextWindowRatio)
				}
				if (remaining <= threshold) && !a.disableAutoSummarize {
					shouldSummarize = true
					return true
				}
				return false
			},
			// Stop when the model keeps repeating the same tool calls.
			func(steps []fantasy.StepResult) bool {
				return hasRepeatedToolCalls(steps, loopDetectionWindowSize, loopDetectionMaxRepeats)
			},
		},
	})

	a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))

	if err != nil {
		isHyper := largeModel.ModelCfg.Provider == hyper.Name
		isCancelErr := errors.Is(err, context.Canceled)
		// No step started, so there is no assistant message to finalize.
		if currentAssistant == nil {
			return result, err
		}
		// Ensure we finish thinking on error to close the reasoning state.
		currentAssistant.FinishThinking()
		toolCalls := currentAssistant.ToolCalls()
		// INFO: we use the parent context here because the genCtx has been cancelled.
		// (This msgs shadows the history loaded before the stream.)
		msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
		if createErr != nil {
			return nil, createErr
		}
		// Close every tool call of the interrupted step: mark unfinished
		// calls finished, and give any call without a stored result a
		// synthetic error result.
		for _, tc := range toolCalls {
			if !tc.Finished {
				tc.Finished = true
				tc.Input = "{}"
				currentAssistant.AddToolCall(tc)
				updateErr := a.messages.Update(ctx, *currentAssistant)
				if updateErr != nil {
					return nil, updateErr
				}
			}

			found := false
			for _, msg := range msgs {
				if msg.Role == message.Tool {
					for _, tr := range msg.ToolResults() {
						if tr.ToolCallID == tc.ID {
							found = true
							break
						}
					}
				}
				if found {
					break
				}
			}
			if found {
				continue
			}
			content := "There was an error while executing the tool"
			if isCancelErr {
				content = "Error: user cancelled assistant tool calling"
			}
			toolResult := message.ToolResult{
				ToolCallID: tc.ID,
				Name:       tc.Name,
				Content:    content,
				IsError:    true,
			}
			_, createErr = a.messages.Create(ctx, currentAssistant.SessionID, message.CreateMessageParams{
				Role: message.Tool,
				Parts: []message.ContentPart{
					toolResult,
				},
			})
			if createErr != nil {
				return nil, createErr
			}
		}
		// Record a user-facing finish reason for the error.
		var fantasyErr *fantasy.Error
		var providerErr *fantasy.ProviderError
		const defaultTitle = "Provider Error"
		linkStyle := lipgloss.NewStyle().Foreground(charmtone.Guac).Underline(true)
		if isCancelErr {
			currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
		} else if isHyper && errors.As(err, &providerErr) && providerErr.StatusCode == http.StatusUnauthorized {
			currentAssistant.AddFinish(message.FinishReasonError, "Unauthorized", `Please re-authenticate with Hyper. You can also run "crush auth" to re-authenticate.`)
			if a.notify != nil {
				a.notify.Publish(pubsub.CreatedEvent, notify.Notification{
					SessionID:    call.SessionID,
					SessionTitle: currentSession.Title,
					Type:         notify.TypeReAuthenticate,
					ProviderID:   largeModel.ModelCfg.Provider,
				})
			}
		} else if isHyper && errors.As(err, &providerErr) && providerErr.StatusCode == http.StatusPaymentRequired {
			url := hyper.BaseURL()
			link := linkStyle.Hyperlink(url, "id=hyper").Render(url)
			currentAssistant.AddFinish(message.FinishReasonError, "No credits", "You're out of credits. Add more at "+link)
		} else if errors.As(err, &providerErr) {
			if providerErr.Message == "The requested model is not supported." {
				url := "https://github.com/settings/copilot/features"
				link := linkStyle.Hyperlink(url, "id=copilot").Render(url)
				currentAssistant.AddFinish(
					message.FinishReasonError,
					"Copilot model not enabled",
					fmt.Sprintf("%q is not enabled in Copilot. Go to the following page to enable it. Then, wait 5 minutes before trying again. %s", largeModel.CatwalkCfg.Name, link),
				)
			} else {
				currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message)
			}
		} else if errors.As(err, &fantasyErr) {
			currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(fantasyErr.Title), defaultTitle), fantasyErr.Message)
		} else {
			currentAssistant.AddFinish(message.FinishReasonError, defaultTitle, err.Error())
		}
		// Note: we use the parent context here because the genCtx has been
		// cancelled.
		updateErr := a.messages.Update(ctx, *currentAssistant)
		if updateErr != nil {
			return nil, updateErr
		}
		return nil, err
	}

	if shouldSummarize {
		// Release the busy marker so Summarize does not reject the session
		// with ErrSessionBusy.
		a.activeRequests.Del(call.SessionID)
		if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
			return nil, summarizeErr
		}
		// If the agent wasn't done...
		// ...re-queue the original request so work resumes on the
		// summarized history.
		if len(currentAssistant.ToolCalls()) > 0 {
			existing, ok := a.messageQueue.Get(call.SessionID)
			if !ok {
				existing = []SessionAgentCall{}
			}
			call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
			existing = append(existing, call)
			a.messageQueue.Set(call.SessionID, existing)
		}
	}

	// Release active request before publishing the notification.
	// TUI handlers poll IsSessionBusy() and only re-evaluate when a
	// tea.Msg arrives, so the cleanup must precede the notify or
	// subscribers see stale busy state at the moment of receipt.
	a.activeRequests.Del(call.SessionID)
	cancel()

	// Send notification that agent has finished its turn (skip for
	// nested/non-interactive sessions).
	if !call.NonInteractive && a.notify != nil {
		a.notify.Publish(pubsub.CreatedEvent, notify.Notification{
			SessionID:    call.SessionID,
			SessionTitle: currentSession.Title,
			Type:         notify.TypeAgentFinished,
		})
	}

	queuedMessages, ok := a.messageQueue.Get(call.SessionID)
	if !ok || len(queuedMessages) == 0 {
		return result, err
	}
	// There are queued messages restart the loop.
	// The recursive Run returns its own result in place of this one.
	firstQueuedMessage := queuedMessages[0]
	a.messageQueue.Set(call.SessionID, queuedMessages[1:])
	return a.Run(ctx, firstQueuedMessage)
}
642
643func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
644 if a.IsSessionBusy(sessionID) {
645 return ErrSessionBusy
646 }
647
648 // Copy mutable fields under lock to avoid races with SetModels.
649 largeModel := a.largeModel.Get()
650 systemPromptPrefix := a.systemPromptPrefix.Get()
651
652 currentSession, err := a.sessions.Get(ctx, sessionID)
653 if err != nil {
654 return fmt.Errorf("failed to get session: %w", err)
655 }
656 msgs, err := a.getSessionMessages(ctx, currentSession)
657 if err != nil {
658 return err
659 }
660 if len(msgs) == 0 {
661 // Nothing to summarize.
662 return nil
663 }
664
665 aiMsgs, _ := a.preparePrompt(msgs, largeModel.CatwalkCfg.SupportsImages)
666
667 genCtx, cancel := context.WithCancel(ctx)
668 a.activeRequests.Set(sessionID, cancel)
669 defer a.activeRequests.Del(sessionID)
670 defer cancel()
671 defer func() {
672 if flushErr := a.messages.FlushAll(ctx); flushErr != nil {
673 slog.Error("Failed to flush pending message updates after summarize", "error", flushErr)
674 }
675 }()
676
677 agent := fantasy.NewAgent(
678 largeModel.Model,
679 fantasy.WithSystemPrompt(string(summaryPrompt)),
680 fantasy.WithUserAgent(userAgent),
681 )
682 summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
683 Role: message.Assistant,
684 Model: largeModel.ModelCfg.Model,
685 Provider: largeModel.ModelCfg.Provider,
686 IsSummaryMessage: true,
687 })
688 if err != nil {
689 return err
690 }
691
692 summaryPromptText := buildSummaryPrompt(currentSession.Todos)
693
694 resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
695 Prompt: summaryPromptText,
696 Messages: aiMsgs,
697 ProviderOptions: opts,
698 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
699 prepared.Messages = options.Messages
700 if systemPromptPrefix != "" {
701 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(systemPromptPrefix)}, prepared.Messages...)
702 }
703 return callContext, prepared, nil
704 },
705 OnReasoningDelta: func(id string, text string) error {
706 summaryMessage.AppendReasoningContent(text)
707 return a.messages.Update(genCtx, summaryMessage)
708 },
709 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
710 // Handle anthropic signature.
711 if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
712 if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
713 summaryMessage.AppendReasoningSignature(signature.Signature)
714 }
715 }
716 summaryMessage.FinishThinking()
717 return a.messages.Update(genCtx, summaryMessage)
718 },
719 OnTextDelta: func(id, text string) error {
720 summaryMessage.AppendContent(text)
721 return a.messages.Update(genCtx, summaryMessage)
722 },
723 })
724 if err != nil {
725 isCancelErr := errors.Is(err, context.Canceled)
726 if isCancelErr {
727 // User cancelled summarize we need to remove the summary message.
728 deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
729 return deleteErr
730 }
731 // Mark the summary message as finished with an error so the UI
732 // stops spinning.
733 summaryMessage.AddFinish(message.FinishReasonError, "Summarization Error", err.Error())
734 if updateErr := a.messages.Update(ctx, summaryMessage); updateErr != nil {
735 return updateErr
736 }
737 return err
738 }
739
740 summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
741 err = a.messages.Update(genCtx, summaryMessage)
742 if err != nil {
743 return err
744 }
745
746 var openrouterCost *float64
747 for _, step := range resp.Steps {
748 stepCost := a.openrouterCost(step.ProviderMetadata)
749 if stepCost != nil {
750 newCost := *stepCost
751 if openrouterCost != nil {
752 newCost += *openrouterCost
753 }
754 openrouterCost = &newCost
755 }
756 }
757
758 a.updateSessionUsage(largeModel, ¤tSession, resp.TotalUsage, openrouterCost, false)
759
760 // Just in case, get just the last usage info.
761 usage := resp.Response.Usage
762 currentSession.SummaryMessageID = summaryMessage.ID
763 currentSession.CompletionTokens = summaryCompletionTokens(usage, summaryMessage)
764 currentSession.PromptTokens = 0
765 currentSession.EstimatedUsage = usageIsZero(usage)
766 _, err = a.sessions.Save(genCtx, currentSession)
767 if err != nil {
768 return err
769 }
770
771 // Release the active request before processing queued messages so that
772 // Run() does not see the session as busy.
773 a.activeRequests.Del(sessionID)
774 cancel()
775
776 // Process any messages that were queued while summarizing.
777 queuedMessages, ok := a.messageQueue.Get(sessionID)
778 if !ok || len(queuedMessages) == 0 {
779 return nil
780 }
781 firstQueuedMessage := queuedMessages[0]
782 a.messageQueue.Set(sessionID, queuedMessages[1:])
783 _, qErr := a.Run(ctx, firstQueuedMessage)
784 return qErr
785}
786
787func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
788 if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
789 return fantasy.ProviderOptions{}
790 }
791 return fantasy.ProviderOptions{
792 anthropic.Name: &anthropic.ProviderCacheControlOptions{
793 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
794 },
795 bedrock.Name: &anthropic.ProviderCacheControlOptions{
796 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
797 },
798 vercel.Name: &anthropic.ProviderCacheControlOptions{
799 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
800 },
801 }
802}
803
804func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
805 parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
806 var attachmentParts []message.ContentPart
807 for _, attachment := range call.Attachments {
808 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
809 }
810 parts = append(parts, attachmentParts...)
811 msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
812 Role: message.User,
813 Parts: parts,
814 })
815 if err != nil {
816 return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
817 }
818 return msg, nil
819}
820
821func (a *sessionAgent) preparePrompt(msgs []message.Message, supportsImages bool, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
822 var history []fantasy.Message
823 if !a.isSubAgent {
824 history = append(history, fantasy.NewUserMessage(
825 fmt.Sprintf(
826 "<system_reminder>%s</system_reminder>",
827 `This is a reminder that your todo list is currently empty. DO NOT mention this to the user explicitly because they are already aware.
828If you are working on tasks that would benefit from a todo list please use the "todos" tool to create one.
829If not, please feel free to ignore. Again do not mention this message to the user.`,
830 ),
831 ))
832 }
833 // Collect all tool call IDs present in assistant messages and all tool
834 // result IDs present in tool messages. This lets us detect both orphaned
835 // tool results (result without a call) and orphaned tool calls (call
836 // without a result).
837 knownToolCallIDs := make(map[string]struct{})
838 knownToolResultIDs := make(map[string]struct{})
839 for _, m := range msgs {
840 switch m.Role {
841 case message.Assistant:
842 for _, tc := range m.ToolCalls() {
843 knownToolCallIDs[tc.ID] = struct{}{}
844 }
845 case message.Tool:
846 for _, tr := range m.ToolResults() {
847 knownToolResultIDs[tr.ToolCallID] = struct{}{}
848 }
849 }
850 }
851
852 for _, m := range msgs {
853 if len(m.Parts) == 0 {
854 continue
855 }
856 // Assistant message without content or tool calls (cancelled before it returned anything).
857 if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
858 continue
859 }
860 if m.Role == message.Tool {
861 if msg, ok := filterOrphanedToolResults(m, knownToolCallIDs); ok {
862 history = append(history, msg)
863 }
864 continue
865 }
866 aiMsgs := m.ToAIMessage()
867 if !supportsImages {
868 for i := range aiMsgs {
869 if aiMsgs[i].Role == fantasy.MessageRoleUser {
870 aiMsgs[i].Content = filterFileParts(aiMsgs[i].Content)
871 }
872 }
873 }
874 history = append(history, aiMsgs...)
875
876 if m.Role == message.Assistant {
877 if msg, ok := syntheticToolResultsForOrphanedCalls(m, knownToolResultIDs); ok {
878 history = append(history, msg)
879 }
880 }
881 }
882
883 var files []fantasy.FilePart
884 for _, attachment := range attachments {
885 if attachment.IsText() {
886 continue
887 }
888 files = append(files, fantasy.FilePart{
889 Filename: attachment.FileName,
890 Data: attachment.Content,
891 MediaType: attachment.MimeType,
892 })
893 }
894
895 return history, files
896}
897
898// filterFileParts removes fantasy.FilePart entries from a slice of message
899// parts. Used to strip image attachments from historical user messages when
900// the current model does not support them.
901func filterFileParts(parts []fantasy.MessagePart) []fantasy.MessagePart {
902 filtered := make([]fantasy.MessagePart, 0, len(parts))
903 for _, part := range parts {
904 if _, ok := fantasy.AsMessagePart[fantasy.FilePart](part); ok {
905 continue
906 }
907 filtered = append(filtered, part)
908 }
909 return filtered
910}
911
912// filterOrphanedToolResults converts a tool message to a fantasy.Message,
913// dropping any tool result parts whose tool_call_id has no matching tool call
914// in the known set. An orphaned result causes API validation to fail on every
915// subsequent turn, permanently locking the session. Returns the filtered
916// message and true if at least one valid part remains.
917func filterOrphanedToolResults(m message.Message, knownToolCallIDs map[string]struct{}) (fantasy.Message, bool) {
918 aiMsgs := m.ToAIMessage()
919 if len(aiMsgs) == 0 {
920 return fantasy.Message{}, false
921 }
922 var validParts []fantasy.MessagePart
923 for _, part := range aiMsgs[0].Content {
924 tr, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
925 if !ok {
926 validParts = append(validParts, part)
927 continue
928 }
929 if _, known := knownToolCallIDs[tr.ToolCallID]; known {
930 validParts = append(validParts, part)
931 } else {
932 slog.Warn(
933 "Dropping orphaned tool result with no matching tool call",
934 "tool_call_id", tr.ToolCallID,
935 )
936 }
937 }
938 if len(validParts) == 0 {
939 return fantasy.Message{}, false
940 }
941 msg := aiMsgs[0]
942 msg.Content = validParts
943 return msg, true
944}
945
946// syntheticToolResultsForOrphanedCalls returns a tool message containing
947// synthetic tool results for any tool calls in the assistant message that
948// have no matching result in knownToolResultIDs. LLM APIs require every
949// tool_use to be immediately followed by a tool_result; an interrupted
950// session can leave orphaned tool_use blocks that permanently lock the
951// conversation. Returns the message and true if any synthetic results were
952// produced.
953func syntheticToolResultsForOrphanedCalls(m message.Message, knownToolResultIDs map[string]struct{}) (fantasy.Message, bool) {
954 var syntheticParts []fantasy.MessagePart
955 for _, tc := range m.ToolCalls() {
956 if _, hasResult := knownToolResultIDs[tc.ID]; hasResult {
957 continue
958 }
959 slog.Warn(
960 "Injecting synthetic tool result for orphaned tool call",
961 "tool_call_id", tc.ID,
962 "tool_name", tc.Name,
963 )
964 syntheticParts = append(syntheticParts, fantasy.ToolResultPart{
965 ToolCallID: tc.ID,
966 Output: fantasy.ToolResultOutputContentError{
967 Error: errors.New("tool call was interrupted and did not produce a result, you may retry this call if the result is still needed"),
968 },
969 })
970 }
971 if len(syntheticParts) == 0 {
972 return fantasy.Message{}, false
973 }
974 return fantasy.Message{
975 Role: fantasy.MessageRoleTool,
976 Content: syntheticParts,
977 }, true
978}
979
980func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
981 msgs, err := a.messages.List(ctx, session.ID)
982 if err != nil {
983 return nil, fmt.Errorf("failed to list messages: %w", err)
984 }
985
986 if session.SummaryMessageID != "" {
987 summaryMsgIndex := -1
988 for i, msg := range msgs {
989 if msg.ID == session.SummaryMessageID {
990 summaryMsgIndex = i
991 break
992 }
993 }
994 if summaryMsgIndex != -1 {
995 msgs = msgs[summaryMsgIndex:]
996 msgs[0].Role = message.User
997 }
998 }
999 return msgs, nil
1000}
1001
1002// generateTitle generates a session titled based on the initial prompt.
1003func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, userPrompt string) {
1004 if userPrompt == "" {
1005 return
1006 }
1007
1008 smallModel := a.smallModel.Get()
1009 largeModel := a.largeModel.Get()
1010 systemPromptPrefix := a.systemPromptPrefix.Get()
1011
1012 var maxOutputTokens int64 = 40
1013 if smallModel.CatwalkCfg.CanReason {
1014 maxOutputTokens = smallModel.CatwalkCfg.DefaultMaxTokens
1015 }
1016
1017 newAgent := func(m fantasy.LanguageModel, p []byte, tok int64) fantasy.Agent {
1018 return fantasy.NewAgent(
1019 m,
1020 fantasy.WithSystemPrompt(string(p)+"\n /no_think"),
1021 fantasy.WithMaxOutputTokens(tok),
1022 fantasy.WithUserAgent(userAgent),
1023 )
1024 }
1025
1026 streamCall := fantasy.AgentStreamCall{
1027 Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", userPrompt),
1028 PrepareStep: func(callCtx context.Context, opts fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
1029 prepared.Messages = opts.Messages
1030 if systemPromptPrefix != "" {
1031 prepared.Messages = append([]fantasy.Message{
1032 fantasy.NewSystemMessage(systemPromptPrefix),
1033 }, prepared.Messages...)
1034 }
1035 return callCtx, prepared, nil
1036 },
1037 }
1038
1039 // Use the small model to generate the title.
1040 model := smallModel
1041 agent := newAgent(model.Model, titlePrompt, maxOutputTokens)
1042 resp, err := agent.Stream(ctx, streamCall)
1043 if err == nil {
1044 // We successfully generated a title with the small model.
1045 slog.Debug("Generated title with small model")
1046 } else {
1047 // It didn't work. Let's try with the big model.
1048 slog.Error("Error generating title with small model; trying big model", "err", err)
1049 model = largeModel
1050 agent = newAgent(model.Model, titlePrompt, maxOutputTokens)
1051 resp, err = agent.Stream(ctx, streamCall)
1052 if err == nil {
1053 slog.Debug("Generated title with large model")
1054 } else {
1055 // Welp, the large model didn't work either. Use the default
1056 // session name and return.
1057 slog.Error("Error generating title with large model", "err", err)
1058 saveErr := a.sessions.Rename(ctx, sessionID, DefaultSessionName)
1059 if saveErr != nil {
1060 slog.Error("Failed to save session title", "error", saveErr)
1061 }
1062 return
1063 }
1064 }
1065
1066 if resp == nil {
1067 // Actually, we didn't get a response so we can't. Use the default
1068 // session name and return.
1069 slog.Error("Response is nil; can't generate title")
1070 saveErr := a.sessions.Rename(ctx, sessionID, DefaultSessionName)
1071 if saveErr != nil {
1072 slog.Error("Failed to save session title", "error", saveErr)
1073 }
1074 return
1075 }
1076
1077 // Clean up title.
1078 var title string
1079 title = strings.ReplaceAll(resp.Response.Content.Text(), "\n", " ")
1080
1081 // Remove thinking tags if present.
1082 title = thinkTagRegex.ReplaceAllString(title, "")
1083 title = orphanThinkTagRegex.ReplaceAllString(title, "")
1084
1085 title = strings.TrimSpace(title)
1086 title = cmp.Or(title, DefaultSessionName)
1087
1088 // Calculate usage and cost.
1089 var openrouterCost *float64
1090 for _, step := range resp.Steps {
1091 stepCost := a.openrouterCost(step.ProviderMetadata)
1092 if stepCost != nil {
1093 newCost := *stepCost
1094 if openrouterCost != nil {
1095 newCost += *openrouterCost
1096 }
1097 openrouterCost = &newCost
1098 }
1099 }
1100
1101 modelConfig := model.CatwalkCfg
1102 cost := modelConfig.CostPer1MInCached/1e6*float64(resp.TotalUsage.CacheCreationTokens) +
1103 modelConfig.CostPer1MOutCached/1e6*float64(resp.TotalUsage.CacheReadTokens) +
1104 modelConfig.CostPer1MIn/1e6*float64(resp.TotalUsage.InputTokens) +
1105 modelConfig.CostPer1MOut/1e6*float64(resp.TotalUsage.OutputTokens)
1106
1107 // Use override cost if available (e.g., from OpenRouter).
1108 if openrouterCost != nil {
1109 cost = *openrouterCost
1110 }
1111
1112 // Skip cost accumulation
1113 if model.FlatRate {
1114 cost = 0
1115 }
1116
1117 promptTokens := resp.TotalUsage.InputTokens + resp.TotalUsage.CacheCreationTokens
1118 completionTokens := resp.TotalUsage.OutputTokens
1119
1120 // Atomically update only title and usage fields to avoid overriding other
1121 // concurrent session updates.
1122 saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, title, promptTokens, completionTokens, cost)
1123 if saveErr != nil {
1124 slog.Error("Failed to save session title and usage", "error", saveErr)
1125 return
1126 }
1127}
1128
1129func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
1130 openrouterMetadata, ok := metadata[openrouter.Name]
1131 if !ok {
1132 return nil
1133 }
1134
1135 opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
1136 if !ok {
1137 return nil
1138 }
1139 return &opts.Usage.Cost
1140}
1141
1142func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64, estimated bool) {
1143 if !usageIsZero(usage) {
1144 session.EstimatedUsage = estimated
1145 }
1146
1147 modelConfig := model.CatwalkCfg
1148 cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
1149 modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
1150 modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
1151 modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
1152
1153 if !estimated {
1154 a.eventTokensUsed(session.ID, model, usage, cost)
1155 }
1156
1157 if estimated {
1158 cost = 0
1159 } else {
1160 // Use override cost if available (e.g., from OpenRouter).
1161 if overrideCost != nil {
1162 cost = *overrideCost
1163 }
1164
1165 // Skip cost accumulation
1166 if model.FlatRate {
1167 cost = 0
1168 }
1169 }
1170
1171 session.Cost += cost
1172 updateSessionTokenCounters(session, usage)
1173}
1174
1175func updateSessionTokenCounters(session *session.Session, usage fantasy.Usage) {
1176 if usage.OutputTokens != 0 {
1177 session.CompletionTokens = usage.OutputTokens
1178 }
1179 if promptTokens := usage.InputTokens + usage.CacheReadTokens; promptTokens != 0 {
1180 session.PromptTokens = promptTokens
1181 }
1182}
1183
1184func summaryCompletionTokens(usage fantasy.Usage, summaryMessage message.Message) int64 {
1185 if usage.OutputTokens != 0 {
1186 return usage.OutputTokens
1187 }
1188 return approxTokenCount(summaryMessage.Content().Text) + approxTokenCount(summaryMessage.ReasoningContent().String())
1189}
1190
1191func (a *sessionAgent) Cancel(sessionID string) {
1192 // Cancel regular requests. Don't use Take() here - we need the entry to
1193 // remain in activeRequests so IsBusy() returns true until the goroutine
1194 // fully completes (including error handling that may access the DB).
1195 // The defer in processRequest will clean up the entry.
1196 if cancel, ok := a.activeRequests.Get(sessionID); ok && cancel != nil {
1197 slog.Debug("Request cancellation initiated", "session_id", sessionID)
1198 cancel()
1199 }
1200
1201 // Also check for summarize requests.
1202 if cancel, ok := a.activeRequests.Get(sessionID + "-summarize"); ok && cancel != nil {
1203 slog.Debug("Summarize cancellation initiated", "session_id", sessionID)
1204 cancel()
1205 }
1206
1207 if a.QueuedPrompts(sessionID) > 0 {
1208 slog.Debug("Clearing queued prompts", "session_id", sessionID)
1209 a.messageQueue.Del(sessionID)
1210 }
1211}
1212
1213func (a *sessionAgent) ClearQueue(sessionID string) {
1214 if a.QueuedPrompts(sessionID) > 0 {
1215 slog.Debug("Clearing queued prompts", "session_id", sessionID)
1216 a.messageQueue.Del(sessionID)
1217 }
1218}
1219
1220func (a *sessionAgent) CancelAll() {
1221 if !a.IsBusy() {
1222 return
1223 }
1224 for key := range a.activeRequests.Seq2() {
1225 a.Cancel(key) // key is sessionID
1226 }
1227
1228 timeout := time.After(5 * time.Second)
1229 for a.IsBusy() {
1230 select {
1231 case <-timeout:
1232 return
1233 default:
1234 time.Sleep(200 * time.Millisecond)
1235 }
1236 }
1237}
1238
1239func (a *sessionAgent) IsBusy() bool {
1240 var busy bool
1241 for cancelFunc := range a.activeRequests.Seq() {
1242 if cancelFunc != nil {
1243 busy = true
1244 break
1245 }
1246 }
1247 return busy
1248}
1249
1250func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
1251 _, busy := a.activeRequests.Get(sessionID)
1252 return busy
1253}
1254
1255func (a *sessionAgent) QueuedPrompts(sessionID string) int {
1256 l, ok := a.messageQueue.Get(sessionID)
1257 if !ok {
1258 return 0
1259 }
1260 return len(l)
1261}
1262
1263func (a *sessionAgent) QueuedPromptsList(sessionID string) []string {
1264 l, ok := a.messageQueue.Get(sessionID)
1265 if !ok {
1266 return nil
1267 }
1268 prompts := make([]string, len(l))
1269 for i, call := range l {
1270 prompts[i] = call.Prompt
1271 }
1272 return prompts
1273}
1274
1275func (a *sessionAgent) SetModels(large Model, small Model) {
1276 a.largeModel.Set(large)
1277 a.smallModel.Set(small)
1278}
1279
1280func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
1281 a.tools.SetSlice(tools)
1282}
1283
1284func (a *sessionAgent) SetSystemPrompt(systemPrompt string) {
1285 a.systemPrompt.Set(systemPrompt)
1286}
1287
1288func (a *sessionAgent) Model() Model {
1289 return a.largeModel.Get()
1290}
1291
1292// convertToToolResult converts a fantasy tool result to a message tool result.
1293func (a *sessionAgent) convertToToolResult(result fantasy.ToolResultContent) message.ToolResult {
1294 baseResult := message.ToolResult{
1295 ToolCallID: result.ToolCallID,
1296 Name: result.ToolName,
1297 Metadata: result.ClientMetadata,
1298 }
1299
1300 switch result.Result.GetType() {
1301 case fantasy.ToolResultContentTypeText:
1302 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result); ok {
1303 baseResult.Content = r.Text
1304 }
1305 case fantasy.ToolResultContentTypeError:
1306 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result); ok {
1307 baseResult.Content = r.Error.Error()
1308 baseResult.IsError = true
1309 }
1310 case fantasy.ToolResultContentTypeMedia:
1311 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](result.Result); ok {
1312 if !stringext.IsValidBase64(r.Data) {
1313 slog.Warn(
1314 "Tool returned media with invalid base64 data, discarding image",
1315 "tool", result.ToolName,
1316 "tool_call_id", result.ToolCallID,
1317 )
1318 baseResult.Content = "Tool returned image data with invalid encoding"
1319 baseResult.IsError = true
1320 } else {
1321 content := r.Text
1322 if content == "" {
1323 content = fmt.Sprintf("Loaded %s content", r.MediaType)
1324 }
1325 baseResult.Content = content
1326 baseResult.Data = r.Data
1327 baseResult.MIMEType = r.MediaType
1328 }
1329 }
1330 }
1331
1332 return baseResult
1333}
1334
1335// workaroundProviderMediaLimitations converts media content in tool results to
1336// user messages for providers that don't natively support images in tool results.
1337//
1338// Problem: OpenAI, Google, OpenRouter, and other OpenAI-compatible providers
1339// don't support sending images/media in tool result messages - they only accept
1340// text in tool results. However, they DO support images in user messages.
1341//
1342// If we send media in tool results to these providers, the API returns an error.
1343//
1344// Solution: For these providers, we:
1345// 1. Replace the media in the tool result with a text placeholder
1346// 2. Inject a user message immediately after with the image as a file attachment
1347// 3. This maintains the tool execution flow while working around API limitations
1348//
1349// Anthropic and Bedrock support images natively in tool results, so we skip
1350// this workaround for them.
1351//
1352// Example transformation:
1353//
1354// BEFORE: [tool result: image data]
1355// AFTER: [tool result: "Image loaded - see attached"], [user: image attachment]
1356func (a *sessionAgent) workaroundProviderMediaLimitations(messages []fantasy.Message, largeModel Model) []fantasy.Message {
1357 providerSupportsMedia := largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderAnthropic) ||
1358 largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderBedrock)
1359
1360 if providerSupportsMedia {
1361 return messages
1362 }
1363
1364 convertedMessages := make([]fantasy.Message, 0, len(messages))
1365
1366 for _, msg := range messages {
1367 if msg.Role != fantasy.MessageRoleTool {
1368 convertedMessages = append(convertedMessages, msg)
1369 continue
1370 }
1371
1372 textParts := make([]fantasy.MessagePart, 0, len(msg.Content))
1373 var mediaFiles []fantasy.FilePart
1374
1375 for _, part := range msg.Content {
1376 toolResult, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
1377 if !ok {
1378 textParts = append(textParts, part)
1379 continue
1380 }
1381
1382 if media, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](toolResult.Output); ok {
1383 decoded, err := base64.StdEncoding.DecodeString(media.Data)
1384 if err != nil {
1385 slog.Warn("Failed to decode media data", "error", err)
1386 textParts = append(textParts, part)
1387 continue
1388 }
1389
1390 mediaFiles = append(mediaFiles, fantasy.FilePart{
1391 Data: decoded,
1392 MediaType: media.MediaType,
1393 Filename: fmt.Sprintf("tool-result-%s", toolResult.ToolCallID),
1394 })
1395
1396 textParts = append(textParts, fantasy.ToolResultPart{
1397 ToolCallID: toolResult.ToolCallID,
1398 Output: fantasy.ToolResultOutputContentText{
1399 Text: "[Image/media content loaded - see attached file]",
1400 },
1401 ProviderOptions: toolResult.ProviderOptions,
1402 })
1403 } else {
1404 textParts = append(textParts, part)
1405 }
1406 }
1407
1408 convertedMessages = append(convertedMessages, fantasy.Message{
1409 Role: fantasy.MessageRoleTool,
1410 Content: textParts,
1411 })
1412
1413 if len(mediaFiles) > 0 {
1414 convertedMessages = append(convertedMessages, fantasy.NewUserMessage(
1415 "Here is the media content from the tool result:",
1416 mediaFiles...,
1417 ))
1418 }
1419 }
1420
1421 return convertedMessages
1422}
1423
1424// buildSummaryPrompt constructs the prompt text for session summarization.
1425func buildSummaryPrompt(todos []session.Todo) string {
1426 var sb strings.Builder
1427 sb.WriteString("Provide a detailed summary of our conversation above.")
1428 if len(todos) > 0 {
1429 sb.WriteString("\n\n## Current Todo List\n\n")
1430 for _, t := range todos {
1431 fmt.Fprintf(&sb, "- [%s] %s\n", t.Status, t.Content)
1432 }
1433 sb.WriteString("\nInclude these tasks and their statuses in your summary. ")
1434 sb.WriteString("Instruct the resuming assistant to use the `todos` tool to continue tracking progress on these tasks.")
1435 }
1436 return sb.String()
1437}
1438
1439func providerRetryLogFields(err *fantasy.ProviderError, delay time.Duration) []any {
1440 fields := []any{
1441 "retry_delay", delay.String(),
1442 }
1443 if err == nil {
1444 return fields
1445 }
1446 fields = append(fields, "status_code", err.StatusCode)
1447 if err.Title != "" {
1448 fields = append(fields, "title", err.Title)
1449 }
1450 if err.Message != "" {
1451 fields = append(fields, "message", err.Message)
1452 }
1453 return fields
1454}