1// Package agent is the core orchestration layer for Crush AI agents.
2//
3// It provides session-based AI agent functionality for managing
4// conversations, tool execution, and message handling. It coordinates
5// interactions between language models, messages, sessions, and tools while
6// handling features like automatic summarization, queuing, and token
7// management.
8package agent
9
10import (
11 "cmp"
12 "context"
13 _ "embed"
14 "encoding/base64"
15 "errors"
16 "fmt"
17 "log/slog"
18 "os"
19 "regexp"
20 "strconv"
21 "strings"
22 "sync"
23 "time"
24
25 "charm.land/fantasy"
26 "charm.land/fantasy/providers/anthropic"
27 "charm.land/fantasy/providers/bedrock"
28 "charm.land/fantasy/providers/google"
29 "charm.land/fantasy/providers/openai"
30 "charm.land/fantasy/providers/openrouter"
31 "charm.land/lipgloss/v2"
32 "github.com/charmbracelet/catwalk/pkg/catwalk"
33 "github.com/charmbracelet/crush/internal/agent/hyper"
34 "github.com/charmbracelet/crush/internal/agent/tools"
35 "github.com/charmbracelet/crush/internal/config"
36 "github.com/charmbracelet/crush/internal/csync"
37 "github.com/charmbracelet/crush/internal/message"
38 "github.com/charmbracelet/crush/internal/permission"
39 "github.com/charmbracelet/crush/internal/session"
40 "github.com/charmbracelet/crush/internal/stringext"
41)
42
// defaultSessionName is the title stored for a session when title generation
// fails or produces an empty title.
const defaultSessionName = "Untitled Session"

// titlePrompt holds the embedded contents of templates/title.md. It is used
// as the base of the system prompt when generating session titles.
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt holds the embedded contents of templates/summary.md. It is
// used as the system prompt when summarizing a session.
//
//go:embed templates/summary.md
var summaryPrompt []byte

// thinkTagRegex matches a <think>...</think> span (non-greedy). It is used to
// remove <think> tags from generated titles.
var thinkTagRegex = regexp.MustCompile(`<think>.*?</think>`)
53
// SessionAgentCall describes a single prompt submitted to a session agent
// through Run.
type SessionAgentCall struct {
	// SessionID identifies the session the prompt belongs to. Run rejects an
	// empty value with ErrSessionMissing.
	SessionID string
	// Prompt is the user's text. Run rejects an empty value with
	// ErrEmptyPrompt.
	Prompt string
	// ProviderOptions is forwarded to the model stream call.
	ProviderOptions fantasy.ProviderOptions
	// Attachments are stored alongside the user message; non-text attachments
	// are also passed to the model as files.
	Attachments []message.Attachment
	// MaxOutputTokens is forwarded to the model stream call.
	MaxOutputTokens int64
	// Temperature, TopP, TopK, FrequencyPenalty and PresencePenalty are
	// optional sampling parameters forwarded unchanged to the model stream
	// call.
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
}
66
// SessionAgent is the interface exposed by session-based agents. The notes on
// each method describe the behavior of the implementation in this file
// (sessionAgent).
type SessionAgent interface {
	// Run submits a prompt to a session. If the session already has an active
	// request the call is queued and Run returns (nil, nil).
	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
	// SetModels replaces the large and small models used by the agent.
	SetModels(large Model, small Model)
	// SetTools replaces the tool set offered to the model.
	SetTools(tools []fantasy.AgentTool)
	// Cancel cancels the active request for the session, if any, and drops
	// its queued prompts.
	Cancel(sessionID string)
	// CancelAll cancels every active request and waits, up to a fixed
	// timeout, for the agent to become idle.
	CancelAll()
	// IsSessionBusy reports whether the session has an active request.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any session has an active request.
	IsBusy() bool
	// QueuedPrompts returns the number of prompts queued for the session.
	QueuedPrompts(sessionID string) int
	// QueuedPromptsList returns the text of the prompts queued for the
	// session, or nil when the session has no queue entry.
	QueuedPromptsList(sessionID string) []string
	// ClearQueue drops the prompts queued for the session.
	ClearQueue(sessionID string)
	// Summarize generates a summary message for the session identified by
	// the string argument and records it on the session.
	Summarize(context.Context, string, fantasy.ProviderOptions) error
	// Model returns the large model currently in use.
	Model() Model
}
81
// Model bundles a language model with the catalog and configuration data the
// agent reads alongside it.
type Model struct {
	// Model is the language model used to create fantasy agents.
	Model fantasy.LanguageModel
	// CatwalkCfg is the catalog entry for the model; this file reads its
	// context window, pricing, name and capability flags.
	CatwalkCfg catwalk.Model
	// ModelCfg is the selected-model configuration; this file reads its
	// model and provider identifiers.
	ModelCfg config.SelectedModel
}
87
// sessionAgent is the SessionAgent implementation returned by
// NewSessionAgent.
type sessionAgent struct {
	// largeModel drives the main conversation and summarization; it is also
	// the fallback for title generation.
	largeModel Model
	// smallModel is tried first when generating session titles.
	smallModel Model
	// systemPromptPrefix, when non-empty, is prepended to requests as a
	// separate system message.
	systemPromptPrefix string
	// systemPrompt is the system prompt of the main agent.
	systemPrompt string
	// isSubAgent disables the todo-list reminder that is otherwise placed at
	// the start of the history.
	isSubAgent bool
	// tools is the tool set offered to the model in Run.
	tools []fantasy.AgentTool
	// sessions and messages are the persistence services for sessions and
	// their messages.
	sessions session.Service
	messages message.Service
	// disableAutoSummarize turns off the context-window check that stops a
	// run in order to summarize.
	disableAutoSummarize bool
	// isYolo is copied from SessionAgentOptions.IsYolo; it is not read in the
	// visible part of this file.
	isYolo bool

	// messageQueue holds, per session ID, the calls received while that
	// session was busy.
	messageQueue *csync.Map[string, []SessionAgentCall]
	// activeRequests holds, per session ID, the cancel function of the
	// request currently in flight.
	activeRequests *csync.Map[string, context.CancelFunc]
}
103
// SessionAgentOptions carries the settings and dependencies NewSessionAgent
// copies into a new session agent. Each field maps one-to-one onto the
// sessionAgent field of the same (lower-cased) name.
type SessionAgentOptions struct {
	// LargeModel and SmallModel are the models the agent starts with.
	LargeModel Model
	SmallModel Model
	// SystemPromptPrefix and SystemPrompt configure the system messages sent
	// to the model.
	SystemPromptPrefix string
	SystemPrompt       string
	// IsSubAgent marks the agent as a sub-agent.
	IsSubAgent bool
	// DisableAutoSummarize turns off automatic summarization.
	DisableAutoSummarize bool
	// IsYolo is stored on the agent as-is.
	IsYolo bool
	// Sessions and Messages are the persistence services the agent uses.
	Sessions session.Service
	Messages message.Service
	// Tools is the initial tool set.
	Tools []fantasy.AgentTool
}
116
117func NewSessionAgent(
118 opts SessionAgentOptions,
119) SessionAgent {
120 return &sessionAgent{
121 largeModel: opts.LargeModel,
122 smallModel: opts.SmallModel,
123 systemPromptPrefix: opts.SystemPromptPrefix,
124 systemPrompt: opts.SystemPrompt,
125 isSubAgent: opts.IsSubAgent,
126 sessions: opts.Sessions,
127 messages: opts.Messages,
128 disableAutoSummarize: opts.DisableAutoSummarize,
129 tools: opts.Tools,
130 isYolo: opts.IsYolo,
131 messageQueue: csync.NewMap[string, []SessionAgentCall](),
132 activeRequests: csync.NewMap[string, context.CancelFunc](),
133 }
134}
135
136func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
137 if call.Prompt == "" {
138 return nil, ErrEmptyPrompt
139 }
140 if call.SessionID == "" {
141 return nil, ErrSessionMissing
142 }
143
144 // Queue the message if busy
145 if a.IsSessionBusy(call.SessionID) {
146 existing, ok := a.messageQueue.Get(call.SessionID)
147 if !ok {
148 existing = []SessionAgentCall{}
149 }
150 existing = append(existing, call)
151 a.messageQueue.Set(call.SessionID, existing)
152 return nil, nil
153 }
154
155 if len(a.tools) > 0 {
156 // Add Anthropic caching to the last tool.
157 a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
158 }
159
160 agent := fantasy.NewAgent(
161 a.largeModel.Model,
162 fantasy.WithSystemPrompt(a.systemPrompt),
163 fantasy.WithTools(a.tools...),
164 )
165
166 sessionLock := sync.Mutex{}
167 currentSession, err := a.sessions.Get(ctx, call.SessionID)
168 if err != nil {
169 return nil, fmt.Errorf("failed to get session: %w", err)
170 }
171
172 msgs, err := a.getSessionMessages(ctx, currentSession)
173 if err != nil {
174 return nil, fmt.Errorf("failed to get session messages: %w", err)
175 }
176
177 var wg sync.WaitGroup
178 // Generate title if first message.
179 if len(msgs) == 0 {
180 titleCtx := ctx // Copy to avoid race with ctx reassignment below.
181 wg.Go(func() {
182 a.generateTitle(titleCtx, call.SessionID, call.Prompt)
183 })
184 }
185
186 // Add the user message to the session.
187 _, err = a.createUserMessage(ctx, call)
188 if err != nil {
189 return nil, err
190 }
191
192 // Add the session to the context.
193 ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)
194
195 genCtx, cancel := context.WithCancel(ctx)
196 a.activeRequests.Set(call.SessionID, cancel)
197
198 defer cancel()
199 defer a.activeRequests.Del(call.SessionID)
200
201 history, files := a.preparePrompt(msgs, call.Attachments...)
202
203 startTime := time.Now()
204 a.eventPromptSent(call.SessionID)
205
206 var currentAssistant *message.Message
207 var shouldSummarize bool
208 result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
209 Prompt: message.PromptWithTextAttachments(call.Prompt, call.Attachments),
210 Files: files,
211 Messages: history,
212 ProviderOptions: call.ProviderOptions,
213 MaxOutputTokens: &call.MaxOutputTokens,
214 TopP: call.TopP,
215 Temperature: call.Temperature,
216 PresencePenalty: call.PresencePenalty,
217 TopK: call.TopK,
218 FrequencyPenalty: call.FrequencyPenalty,
219 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
220 prepared.Messages = options.Messages
221 for i := range prepared.Messages {
222 prepared.Messages[i].ProviderOptions = nil
223 }
224
225 queuedCalls, _ := a.messageQueue.Get(call.SessionID)
226 a.messageQueue.Del(call.SessionID)
227 for _, queued := range queuedCalls {
228 userMessage, createErr := a.createUserMessage(callContext, queued)
229 if createErr != nil {
230 return callContext, prepared, createErr
231 }
232 prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
233 }
234
235 prepared.Messages = a.workaroundProviderMediaLimitations(prepared.Messages)
236
237 lastSystemRoleInx := 0
238 systemMessageUpdated := false
239 for i, msg := range prepared.Messages {
240 // Only add cache control to the last message.
241 if msg.Role == fantasy.MessageRoleSystem {
242 lastSystemRoleInx = i
243 } else if !systemMessageUpdated {
244 prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
245 systemMessageUpdated = true
246 }
247 // Than add cache control to the last 2 messages.
248 if i > len(prepared.Messages)-3 {
249 prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
250 }
251 }
252
253 if promptPrefix := a.promptPrefix(); promptPrefix != "" {
254 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(promptPrefix)}, prepared.Messages...)
255 }
256
257 var assistantMsg message.Message
258 assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
259 Role: message.Assistant,
260 Parts: []message.ContentPart{},
261 Model: a.largeModel.ModelCfg.Model,
262 Provider: a.largeModel.ModelCfg.Provider,
263 })
264 if err != nil {
265 return callContext, prepared, err
266 }
267 callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
268 callContext = context.WithValue(callContext, tools.SupportsImagesContextKey, a.largeModel.CatwalkCfg.SupportsImages)
269 callContext = context.WithValue(callContext, tools.ModelNameContextKey, a.largeModel.CatwalkCfg.Name)
270 currentAssistant = &assistantMsg
271 return callContext, prepared, err
272 },
273 OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
274 currentAssistant.AppendReasoningContent(reasoning.Text)
275 return a.messages.Update(genCtx, *currentAssistant)
276 },
277 OnReasoningDelta: func(id string, text string) error {
278 currentAssistant.AppendReasoningContent(text)
279 return a.messages.Update(genCtx, *currentAssistant)
280 },
281 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
282 // handle anthropic signature
283 if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
284 if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
285 currentAssistant.AppendReasoningSignature(reasoning.Signature)
286 }
287 }
288 if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
289 if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
290 currentAssistant.AppendThoughtSignature(reasoning.Signature, reasoning.ToolID)
291 }
292 }
293 if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
294 if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
295 currentAssistant.SetReasoningResponsesData(reasoning)
296 }
297 }
298 currentAssistant.FinishThinking()
299 return a.messages.Update(genCtx, *currentAssistant)
300 },
301 OnTextDelta: func(id string, text string) error {
302 // Strip leading newline from initial text content. This is is
303 // particularly important in non-interactive mode where leading
304 // newlines are very visible.
305 if len(currentAssistant.Parts) == 0 {
306 text = strings.TrimPrefix(text, "\n")
307 }
308
309 currentAssistant.AppendContent(text)
310 return a.messages.Update(genCtx, *currentAssistant)
311 },
312 OnToolInputStart: func(id string, toolName string) error {
313 toolCall := message.ToolCall{
314 ID: id,
315 Name: toolName,
316 ProviderExecuted: false,
317 Finished: false,
318 }
319 currentAssistant.AddToolCall(toolCall)
320 return a.messages.Update(genCtx, *currentAssistant)
321 },
322 OnRetry: func(err *fantasy.ProviderError, delay time.Duration) {
323 // TODO: implement
324 },
325 OnToolCall: func(tc fantasy.ToolCallContent) error {
326 toolCall := message.ToolCall{
327 ID: tc.ToolCallID,
328 Name: tc.ToolName,
329 Input: tc.Input,
330 ProviderExecuted: false,
331 Finished: true,
332 }
333 currentAssistant.AddToolCall(toolCall)
334 return a.messages.Update(genCtx, *currentAssistant)
335 },
336 OnToolResult: func(result fantasy.ToolResultContent) error {
337 toolResult := a.convertToToolResult(result)
338 _, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
339 Role: message.Tool,
340 Parts: []message.ContentPart{
341 toolResult,
342 },
343 })
344 return createMsgErr
345 },
346 OnStepFinish: func(stepResult fantasy.StepResult) error {
347 finishReason := message.FinishReasonUnknown
348 switch stepResult.FinishReason {
349 case fantasy.FinishReasonLength:
350 finishReason = message.FinishReasonMaxTokens
351 case fantasy.FinishReasonStop:
352 finishReason = message.FinishReasonEndTurn
353 case fantasy.FinishReasonToolCalls:
354 finishReason = message.FinishReasonToolUse
355 }
356 currentAssistant.AddFinish(finishReason, "", "")
357 sessionLock.Lock()
358 updatedSession, getSessionErr := a.sessions.Get(genCtx, call.SessionID)
359 if getSessionErr != nil {
360 sessionLock.Unlock()
361 return getSessionErr
362 }
363 a.updateSessionUsage(a.largeModel, &updatedSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
364 _, sessionErr := a.sessions.Save(genCtx, updatedSession)
365 sessionLock.Unlock()
366 if sessionErr != nil {
367 return sessionErr
368 }
369 return a.messages.Update(genCtx, *currentAssistant)
370 },
371 StopWhen: []fantasy.StopCondition{
372 func(_ []fantasy.StepResult) bool {
373 cw := int64(a.largeModel.CatwalkCfg.ContextWindow)
374 tokens := currentSession.CompletionTokens + currentSession.PromptTokens
375 remaining := cw - tokens
376 var threshold int64
377 if cw > 200_000 {
378 threshold = 20_000
379 } else {
380 threshold = int64(float64(cw) * 0.2)
381 }
382 if (remaining <= threshold) && !a.disableAutoSummarize {
383 shouldSummarize = true
384 return true
385 }
386 return false
387 },
388 },
389 })
390
391 a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))
392
393 if err != nil {
394 isCancelErr := errors.Is(err, context.Canceled)
395 isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
396 if currentAssistant == nil {
397 return result, err
398 }
399 // Ensure we finish thinking on error to close the reasoning state.
400 currentAssistant.FinishThinking()
401 toolCalls := currentAssistant.ToolCalls()
402 // INFO: we use the parent context here because the genCtx has been cancelled.
403 msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
404 if createErr != nil {
405 return nil, createErr
406 }
407 for _, tc := range toolCalls {
408 if !tc.Finished {
409 tc.Finished = true
410 tc.Input = "{}"
411 currentAssistant.AddToolCall(tc)
412 updateErr := a.messages.Update(ctx, *currentAssistant)
413 if updateErr != nil {
414 return nil, updateErr
415 }
416 }
417
418 found := false
419 for _, msg := range msgs {
420 if msg.Role == message.Tool {
421 for _, tr := range msg.ToolResults() {
422 if tr.ToolCallID == tc.ID {
423 found = true
424 break
425 }
426 }
427 }
428 if found {
429 break
430 }
431 }
432 if found {
433 continue
434 }
435 content := "There was an error while executing the tool"
436 if isCancelErr {
437 content = "Tool execution canceled by user"
438 } else if isPermissionErr {
439 content = "User denied permission"
440 }
441 toolResult := message.ToolResult{
442 ToolCallID: tc.ID,
443 Name: tc.Name,
444 Content: content,
445 IsError: true,
446 }
447 _, createErr = a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
448 Role: message.Tool,
449 Parts: []message.ContentPart{
450 toolResult,
451 },
452 })
453 if createErr != nil {
454 return nil, createErr
455 }
456 }
457 var fantasyErr *fantasy.Error
458 var providerErr *fantasy.ProviderError
459 const defaultTitle = "Provider Error"
460 if isCancelErr {
461 currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
462 } else if isPermissionErr {
463 currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "")
464 } else if errors.Is(err, hyper.ErrNoCredits) {
465 url := hyper.BaseURL()
466 link := lipgloss.NewStyle().Hyperlink(url, "id=hyper").Render(url)
467 currentAssistant.AddFinish(message.FinishReasonError, "No credits", "You're out of credits. Add more at "+link)
468 } else if errors.As(err, &providerErr) {
469 if providerErr.Message == "The requested model is not supported." {
470 url := "https://github.com/settings/copilot/features"
471 link := lipgloss.NewStyle().Hyperlink(url, "id=hyper").Render(url)
472 currentAssistant.AddFinish(
473 message.FinishReasonError,
474 "Copilot model not enabled",
475 fmt.Sprintf("%q is not enabled in Copilot. Go to the following page to enable it. Then, wait a minute before trying again. %s", a.largeModel.CatwalkCfg.Name, link),
476 )
477 } else {
478 currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message)
479 }
480 } else if errors.As(err, &fantasyErr) {
481 currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(fantasyErr.Title), defaultTitle), fantasyErr.Message)
482 } else {
483 currentAssistant.AddFinish(message.FinishReasonError, defaultTitle, err.Error())
484 }
485 // Note: we use the parent context here because the genCtx has been
486 // cancelled.
487 updateErr := a.messages.Update(ctx, *currentAssistant)
488 if updateErr != nil {
489 return nil, updateErr
490 }
491 return nil, err
492 }
493 wg.Wait()
494
495 if shouldSummarize {
496 a.activeRequests.Del(call.SessionID)
497 if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
498 return nil, summarizeErr
499 }
500 // If the agent wasn't done...
501 if len(currentAssistant.ToolCalls()) > 0 {
502 existing, ok := a.messageQueue.Get(call.SessionID)
503 if !ok {
504 existing = []SessionAgentCall{}
505 }
506 call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
507 existing = append(existing, call)
508 a.messageQueue.Set(call.SessionID, existing)
509 }
510 }
511
512 // Release active request before processing queued messages.
513 a.activeRequests.Del(call.SessionID)
514 cancel()
515
516 queuedMessages, ok := a.messageQueue.Get(call.SessionID)
517 if !ok || len(queuedMessages) == 0 {
518 return result, err
519 }
520 // There are queued messages restart the loop.
521 firstQueuedMessage := queuedMessages[0]
522 a.messageQueue.Set(call.SessionID, queuedMessages[1:])
523 return a.Run(ctx, firstQueuedMessage)
524}
525
526func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
527 if a.IsSessionBusy(sessionID) {
528 return ErrSessionBusy
529 }
530
531 currentSession, err := a.sessions.Get(ctx, sessionID)
532 if err != nil {
533 return fmt.Errorf("failed to get session: %w", err)
534 }
535 msgs, err := a.getSessionMessages(ctx, currentSession)
536 if err != nil {
537 return err
538 }
539 if len(msgs) == 0 {
540 // Nothing to summarize.
541 return nil
542 }
543
544 aiMsgs, _ := a.preparePrompt(msgs)
545
546 genCtx, cancel := context.WithCancel(ctx)
547 a.activeRequests.Set(sessionID, cancel)
548 defer a.activeRequests.Del(sessionID)
549 defer cancel()
550
551 agent := fantasy.NewAgent(a.largeModel.Model,
552 fantasy.WithSystemPrompt(string(summaryPrompt)),
553 )
554 summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
555 Role: message.Assistant,
556 Model: a.largeModel.Model.Model(),
557 Provider: a.largeModel.Model.Provider(),
558 IsSummaryMessage: true,
559 })
560 if err != nil {
561 return err
562 }
563
564 summaryPromptText := "Provide a detailed summary of our conversation above."
565 if len(currentSession.Todos) > 0 {
566 summaryPromptText += "\n\n## Current Todo List\n\n"
567 for _, t := range currentSession.Todos {
568 summaryPromptText += fmt.Sprintf("- [%s] %s\n", t.Status, t.Content)
569 }
570 summaryPromptText += "\nInclude these tasks and their statuses in your summary. "
571 summaryPromptText += "Instruct the resuming assistant to use the `todos` tool to continue tracking progress on these tasks."
572 }
573
574 resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
575 Prompt: summaryPromptText,
576 Messages: aiMsgs,
577 ProviderOptions: opts,
578 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
579 prepared.Messages = options.Messages
580 if a.systemPromptPrefix != "" {
581 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
582 }
583 return callContext, prepared, nil
584 },
585 OnReasoningDelta: func(id string, text string) error {
586 summaryMessage.AppendReasoningContent(text)
587 return a.messages.Update(genCtx, summaryMessage)
588 },
589 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
590 // Handle anthropic signature.
591 if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
592 if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
593 summaryMessage.AppendReasoningSignature(signature.Signature)
594 }
595 }
596 summaryMessage.FinishThinking()
597 return a.messages.Update(genCtx, summaryMessage)
598 },
599 OnTextDelta: func(id, text string) error {
600 summaryMessage.AppendContent(text)
601 return a.messages.Update(genCtx, summaryMessage)
602 },
603 })
604 if err != nil {
605 isCancelErr := errors.Is(err, context.Canceled)
606 if isCancelErr {
607 // User cancelled summarize we need to remove the summary message.
608 deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
609 return deleteErr
610 }
611 return err
612 }
613
614 summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
615 err = a.messages.Update(genCtx, summaryMessage)
616 if err != nil {
617 return err
618 }
619
620 var openrouterCost *float64
621 for _, step := range resp.Steps {
622 stepCost := a.openrouterCost(step.ProviderMetadata)
623 if stepCost != nil {
624 newCost := *stepCost
625 if openrouterCost != nil {
626 newCost += *openrouterCost
627 }
628 openrouterCost = &newCost
629 }
630 }
631
632 a.updateSessionUsage(a.largeModel, ¤tSession, resp.TotalUsage, openrouterCost)
633
634 // Just in case, get just the last usage info.
635 usage := resp.Response.Usage
636 currentSession.SummaryMessageID = summaryMessage.ID
637 currentSession.CompletionTokens = usage.OutputTokens
638 currentSession.PromptTokens = 0
639 _, err = a.sessions.Save(genCtx, currentSession)
640 return err
641}
642
643func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
644 if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
645 return fantasy.ProviderOptions{}
646 }
647 return fantasy.ProviderOptions{
648 anthropic.Name: &anthropic.ProviderCacheControlOptions{
649 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
650 },
651 bedrock.Name: &anthropic.ProviderCacheControlOptions{
652 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
653 },
654 }
655}
656
657func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
658 parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
659 var attachmentParts []message.ContentPart
660 for _, attachment := range call.Attachments {
661 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
662 }
663 parts = append(parts, attachmentParts...)
664 msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
665 Role: message.User,
666 Parts: parts,
667 })
668 if err != nil {
669 return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
670 }
671 return msg, nil
672}
673
674func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
675 var history []fantasy.Message
676 if !a.isSubAgent {
677 history = append(history, fantasy.NewUserMessage(
678 fmt.Sprintf("<system_reminder>%s</system_reminder>",
679 `This is a reminder that your todo list is currently empty. DO NOT mention this to the user explicitly because they are already aware.
680If you are working on tasks that would benefit from a todo list please use the "todos" tool to create one.
681If not, please feel free to ignore. Again do not mention this message to the user.`,
682 ),
683 ))
684 }
685 for _, m := range msgs {
686 if len(m.Parts) == 0 {
687 continue
688 }
689 // Assistant message without content or tool calls (cancelled before it
690 // returned anything).
691 if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
692 continue
693 }
694 history = append(history, m.ToAIMessage()...)
695 }
696
697 var files []fantasy.FilePart
698 for _, attachment := range attachments {
699 if attachment.IsText() {
700 continue
701 }
702 files = append(files, fantasy.FilePart{
703 Filename: attachment.FileName,
704 Data: attachment.Content,
705 MediaType: attachment.MimeType,
706 })
707 }
708
709 return history, files
710}
711
712func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
713 msgs, err := a.messages.List(ctx, session.ID)
714 if err != nil {
715 return nil, fmt.Errorf("failed to list messages: %w", err)
716 }
717
718 if session.SummaryMessageID != "" {
719 summaryMsgInex := -1
720 for i, msg := range msgs {
721 if msg.ID == session.SummaryMessageID {
722 summaryMsgInex = i
723 break
724 }
725 }
726 if summaryMsgInex != -1 {
727 msgs = msgs[summaryMsgInex:]
728 msgs[0].Role = message.User
729 }
730 }
731 return msgs, nil
732}
733
734// generateTitle generates a session titled based on the initial prompt.
735func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, userPrompt string) {
736 if userPrompt == "" {
737 return
738 }
739
740 var maxOutputTokens int64 = 40
741 if a.smallModel.CatwalkCfg.CanReason {
742 maxOutputTokens = a.smallModel.CatwalkCfg.DefaultMaxTokens
743 }
744
745 newAgent := func(m fantasy.LanguageModel, p []byte, tok int64) fantasy.Agent {
746 return fantasy.NewAgent(m,
747 fantasy.WithSystemPrompt(string(p)+"\n /no_think"),
748 fantasy.WithMaxOutputTokens(tok),
749 )
750 }
751
752 streamCall := fantasy.AgentStreamCall{
753 Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", userPrompt),
754 PrepareStep: func(callCtx context.Context, opts fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
755 prepared.Messages = opts.Messages
756 if a.systemPromptPrefix != "" {
757 prepared.Messages = append([]fantasy.Message{
758 fantasy.NewSystemMessage(a.systemPromptPrefix),
759 }, prepared.Messages...)
760 }
761 return callCtx, prepared, nil
762 },
763 }
764
765 // Use the small model to generate the title.
766 model := &a.smallModel
767 agent := newAgent(model.Model, titlePrompt, maxOutputTokens)
768 resp, err := agent.Stream(ctx, streamCall)
769 if err == nil {
770 // We successfully generated a title with the small model.
771 slog.Info("generated title with small model")
772 } else {
773 // It didn't work. Let's try with the big model.
774 slog.Error("error generating title with small model; trying big model", "err", err)
775 model = &a.largeModel
776 agent = newAgent(model.Model, titlePrompt, maxOutputTokens)
777 resp, err = agent.Stream(ctx, streamCall)
778 if err == nil {
779 slog.Info("generated title with large model")
780 } else {
781 // Welp, the large model didn't work either. Use the default
782 // session name and return.
783 slog.Error("error generating title with large model", "err", err)
784 saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, defaultSessionName, 0, 0, 0)
785 if saveErr != nil {
786 slog.Error("failed to save session title and usage", "error", saveErr)
787 }
788 return
789 }
790 }
791
792 if resp == nil {
793 // Actually, we didn't get a response so we can't. Use the default
794 // session name and return.
795 slog.Error("response is nil; can't generate title")
796 saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, defaultSessionName, 0, 0, 0)
797 if saveErr != nil {
798 slog.Error("failed to save session title and usage", "error", saveErr)
799 }
800 return
801 }
802
803 // Clean up title.
804 var title string
805 title = strings.ReplaceAll(resp.Response.Content.Text(), "\n", " ")
806 slog.Info("generated title", "title", title)
807
808 // Remove thinking tags if present.
809 title = thinkTagRegex.ReplaceAllString(title, "")
810
811 title = strings.TrimSpace(title)
812 if title == "" {
813 slog.Warn("empty title; using fallback")
814 title = defaultSessionName
815 }
816
817 // Calculate usage and cost.
818 var openrouterCost *float64
819 for _, step := range resp.Steps {
820 stepCost := a.openrouterCost(step.ProviderMetadata)
821 if stepCost != nil {
822 newCost := *stepCost
823 if openrouterCost != nil {
824 newCost += *openrouterCost
825 }
826 openrouterCost = &newCost
827 }
828 }
829
830 modelConfig := model.CatwalkCfg
831 cost := modelConfig.CostPer1MInCached/1e6*float64(resp.TotalUsage.CacheCreationTokens) +
832 modelConfig.CostPer1MOutCached/1e6*float64(resp.TotalUsage.CacheReadTokens) +
833 modelConfig.CostPer1MIn/1e6*float64(resp.TotalUsage.InputTokens) +
834 modelConfig.CostPer1MOut/1e6*float64(resp.TotalUsage.OutputTokens)
835
836 if a.isClaudeCode() {
837 cost = 0
838 }
839
840 // Use override cost if available (e.g., from OpenRouter).
841 if openrouterCost != nil {
842 cost = *openrouterCost
843 }
844
845 promptTokens := resp.TotalUsage.InputTokens + resp.TotalUsage.CacheCreationTokens
846 completionTokens := resp.TotalUsage.OutputTokens + resp.TotalUsage.CacheReadTokens
847
848 // Atomically update only title and usage fields to avoid overriding other
849 // concurrent session updates.
850 saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, title, promptTokens, completionTokens, cost)
851 if saveErr != nil {
852 slog.Error("failed to save session title and usage", "error", saveErr)
853 return
854 }
855}
856
857func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
858 openrouterMetadata, ok := metadata[openrouter.Name]
859 if !ok {
860 return nil
861 }
862
863 opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
864 if !ok {
865 return nil
866 }
867 return &opts.Usage.Cost
868}
869
870func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
871 modelConfig := model.CatwalkCfg
872 cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
873 modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
874 modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
875 modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
876
877 if a.isClaudeCode() {
878 cost = 0
879 }
880
881 a.eventTokensUsed(session.ID, model, usage, cost)
882
883 if overrideCost != nil {
884 session.Cost += *overrideCost
885 } else {
886 session.Cost += cost
887 }
888
889 session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
890 session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
891}
892
893func (a *sessionAgent) Cancel(sessionID string) {
894 // Cancel regular requests.
895 if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
896 slog.Info("Request cancellation initiated", "session_id", sessionID)
897 cancel()
898 }
899
900 // Also check for summarize requests.
901 if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
902 slog.Info("Summarize cancellation initiated", "session_id", sessionID)
903 cancel()
904 }
905
906 if a.QueuedPrompts(sessionID) > 0 {
907 slog.Info("Clearing queued prompts", "session_id", sessionID)
908 a.messageQueue.Del(sessionID)
909 }
910}
911
912func (a *sessionAgent) ClearQueue(sessionID string) {
913 if a.QueuedPrompts(sessionID) > 0 {
914 slog.Info("Clearing queued prompts", "session_id", sessionID)
915 a.messageQueue.Del(sessionID)
916 }
917}
918
919func (a *sessionAgent) CancelAll() {
920 if !a.IsBusy() {
921 return
922 }
923 for key := range a.activeRequests.Seq2() {
924 a.Cancel(key) // key is sessionID
925 }
926
927 timeout := time.After(5 * time.Second)
928 for a.IsBusy() {
929 select {
930 case <-timeout:
931 return
932 default:
933 time.Sleep(200 * time.Millisecond)
934 }
935 }
936}
937
938func (a *sessionAgent) IsBusy() bool {
939 var busy bool
940 for cancelFunc := range a.activeRequests.Seq() {
941 if cancelFunc != nil {
942 busy = true
943 break
944 }
945 }
946 return busy
947}
948
949func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
950 _, busy := a.activeRequests.Get(sessionID)
951 return busy
952}
953
954func (a *sessionAgent) QueuedPrompts(sessionID string) int {
955 l, ok := a.messageQueue.Get(sessionID)
956 if !ok {
957 return 0
958 }
959 return len(l)
960}
961
962func (a *sessionAgent) QueuedPromptsList(sessionID string) []string {
963 l, ok := a.messageQueue.Get(sessionID)
964 if !ok {
965 return nil
966 }
967 prompts := make([]string, len(l))
968 for i, call := range l {
969 prompts[i] = call.Prompt
970 }
971 return prompts
972}
973
974func (a *sessionAgent) SetModels(large Model, small Model) {
975 a.largeModel = large
976 a.smallModel = small
977}
978
979func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
980 a.tools = tools
981}
982
983func (a *sessionAgent) Model() Model {
984 return a.largeModel
985}
986
987func (a *sessionAgent) promptPrefix() string {
988 if a.isClaudeCode() {
989 return "You are Claude Code, Anthropic's official CLI for Claude."
990 }
991 return a.systemPromptPrefix
992}
993
994// XXX: this should be generalized to cover other subscription plans, like Copilot.
995func (a *sessionAgent) isClaudeCode() bool {
996 cfg := config.Get()
997 pc, ok := cfg.Providers.Get(a.largeModel.ModelCfg.Provider)
998 return ok && pc.ID == string(catwalk.InferenceProviderAnthropic) && pc.OAuthToken != nil
999}
1000
1001// convertToToolResult converts a fantasy tool result to a message tool result.
1002func (a *sessionAgent) convertToToolResult(result fantasy.ToolResultContent) message.ToolResult {
1003 baseResult := message.ToolResult{
1004 ToolCallID: result.ToolCallID,
1005 Name: result.ToolName,
1006 Metadata: result.ClientMetadata,
1007 }
1008
1009 switch result.Result.GetType() {
1010 case fantasy.ToolResultContentTypeText:
1011 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result); ok {
1012 baseResult.Content = r.Text
1013 }
1014 case fantasy.ToolResultContentTypeError:
1015 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result); ok {
1016 baseResult.Content = r.Error.Error()
1017 baseResult.IsError = true
1018 }
1019 case fantasy.ToolResultContentTypeMedia:
1020 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](result.Result); ok {
1021 content := r.Text
1022 if content == "" {
1023 content = fmt.Sprintf("Loaded %s content", r.MediaType)
1024 }
1025 baseResult.Content = content
1026 baseResult.Data = r.Data
1027 baseResult.MIMEType = r.MediaType
1028 }
1029 }
1030
1031 return baseResult
1032}
1033
1034// workaroundProviderMediaLimitations converts media content in tool results to
1035// user messages for providers that don't natively support images in tool results.
1036//
1037// Problem: OpenAI, Google, OpenRouter, and other OpenAI-compatible providers
1038// don't support sending images/media in tool result messages - they only accept
1039// text in tool results. However, they DO support images in user messages.
1040//
1041// If we send media in tool results to these providers, the API returns an error.
1042//
1043// Solution: For these providers, we:
1044// 1. Replace the media in the tool result with a text placeholder
1045// 2. Inject a user message immediately after with the image as a file attachment
1046// 3. This maintains the tool execution flow while working around API limitations
1047//
1048// Anthropic and Bedrock support images natively in tool results, so we skip
1049// this workaround for them.
1050//
1051// Example transformation:
1052//
1053// BEFORE: [tool result: image data]
1054// AFTER: [tool result: "Image loaded - see attached"], [user: image attachment]
1055func (a *sessionAgent) workaroundProviderMediaLimitations(messages []fantasy.Message) []fantasy.Message {
1056 providerSupportsMedia := a.largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderAnthropic) ||
1057 a.largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderBedrock)
1058
1059 if providerSupportsMedia {
1060 return messages
1061 }
1062
1063 convertedMessages := make([]fantasy.Message, 0, len(messages))
1064
1065 for _, msg := range messages {
1066 if msg.Role != fantasy.MessageRoleTool {
1067 convertedMessages = append(convertedMessages, msg)
1068 continue
1069 }
1070
1071 textParts := make([]fantasy.MessagePart, 0, len(msg.Content))
1072 var mediaFiles []fantasy.FilePart
1073
1074 for _, part := range msg.Content {
1075 toolResult, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
1076 if !ok {
1077 textParts = append(textParts, part)
1078 continue
1079 }
1080
1081 if media, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](toolResult.Output); ok {
1082 decoded, err := base64.StdEncoding.DecodeString(media.Data)
1083 if err != nil {
1084 slog.Warn("failed to decode media data", "error", err)
1085 textParts = append(textParts, part)
1086 continue
1087 }
1088
1089 mediaFiles = append(mediaFiles, fantasy.FilePart{
1090 Data: decoded,
1091 MediaType: media.MediaType,
1092 Filename: fmt.Sprintf("tool-result-%s", toolResult.ToolCallID),
1093 })
1094
1095 textParts = append(textParts, fantasy.ToolResultPart{
1096 ToolCallID: toolResult.ToolCallID,
1097 Output: fantasy.ToolResultOutputContentText{
1098 Text: "[Image/media content loaded - see attached file]",
1099 },
1100 ProviderOptions: toolResult.ProviderOptions,
1101 })
1102 } else {
1103 textParts = append(textParts, part)
1104 }
1105 }
1106
1107 convertedMessages = append(convertedMessages, fantasy.Message{
1108 Role: fantasy.MessageRoleTool,
1109 Content: textParts,
1110 })
1111
1112 if len(mediaFiles) > 0 {
1113 convertedMessages = append(convertedMessages, fantasy.NewUserMessage(
1114 "Here is the media content from the tool result:",
1115 mediaFiles...,
1116 ))
1117 }
1118 }
1119
1120 return convertedMessages
1121}