// Package agent is the core orchestration layer for Crush AI agents.
//
// It provides session-based AI agent functionality for managing
// conversations, tool execution, and message handling. It coordinates
// interactions between language models, messages, sessions, and tools while
// handling features like automatic summarization, queuing, and token
// management.
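//
// A minimal usage sketch; the session and message services, models, and
// tools are assumed to be configured elsewhere:
//
//	agent := NewSessionAgent(SessionAgentOptions{
//		LargeModel: largeModel,
//		SmallModel: smallModel,
//		Sessions:   sessions,
//		Messages:   messages,
//		Tools:      agentTools,
//	})
//	result, err := agent.Run(ctx, SessionAgentCall{
//		SessionID: sess.ID,
//		Prompt:    "Summarize the TODOs in this repository.",
//	})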
package agent

import (
	"cmp"
	"context"
	_ "embed"
	"encoding/base64"
	"errors"
	"fmt"
	"log/slog"
	"os"
	"regexp"
	"strconv"
	"strings"
	"sync"
	"time"

	"charm.land/fantasy"
	"charm.land/fantasy/providers/anthropic"
	"charm.land/fantasy/providers/bedrock"
	"charm.land/fantasy/providers/google"
	"charm.land/fantasy/providers/openai"
	"charm.land/fantasy/providers/openrouter"
	"charm.land/lipgloss/v2"
	"github.com/charmbracelet/catwalk/pkg/catwalk"
	"github.com/charmbracelet/crush/internal/agent/hyper"
	"github.com/charmbracelet/crush/internal/agent/tools"
	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/csync"
	"github.com/charmbracelet/crush/internal/message"
	"github.com/charmbracelet/crush/internal/permission"
	"github.com/charmbracelet/crush/internal/session"
	"github.com/charmbracelet/crush/internal/stringext"
)

const defaultSessionName = "Untitled Session"

//go:embed templates/title.md
var titlePrompt []byte

//go:embed templates/summary.md
var summaryPrompt []byte

// Used to remove <think> tags from generated titles.
var thinkTagRegex = regexp.MustCompile(`<think>.*?</think>`)

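// SessionAgentCall describes a single prompt to run against a session,
// including optional attachments and model sampling parameters.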
type SessionAgentCall struct {
	SessionID        string
	Prompt           string
	ProviderOptions  fantasy.ProviderOptions
	Attachments      []message.Attachment
	MaxOutputTokens  int64
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
}

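// SessionAgent runs prompts against sessions, queuing prompts for sessions
// that are already busy, and exposes cancellation, queue inspection, and
// summarization controls.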
type SessionAgent interface {
	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
	SetModels(large Model, small Model)
	SetTools(tools []fantasy.AgentTool)
	Cancel(sessionID string)
	CancelAll()
	IsSessionBusy(sessionID string) bool
	IsBusy() bool
	QueuedPrompts(sessionID string) int
	QueuedPromptsList(sessionID string) []string
	ClearQueue(sessionID string)
	Summarize(context.Context, string, fantasy.ProviderOptions) error
	Model() Model
}

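// Model bundles a fantasy language model with its catwalk metadata and the
// user's selected model configuration.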
type Model struct {
	Model      fantasy.LanguageModel
	CatwalkCfg catwalk.Model
	ModelCfg   config.SelectedModel
}

type sessionAgent struct {
	largeModel           Model
	smallModel           Model
	systemPromptPrefix   string
	systemPrompt         string
	isSubAgent           bool
	tools                []fantasy.AgentTool
	sessions             session.Service
	messages             message.Service
	disableAutoSummarize bool
	isYolo               bool

	messageQueue   *csync.Map[string, []SessionAgentCall]
	activeRequests *csync.Map[string, context.CancelFunc]
}

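// SessionAgentOptions configures a new SessionAgent.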
type SessionAgentOptions struct {
	LargeModel           Model
	SmallModel           Model
	SystemPromptPrefix   string
	SystemPrompt         string
	IsSubAgent           bool
	DisableAutoSummarize bool
	IsYolo               bool
	Sessions             session.Service
	Messages             message.Service
	Tools                []fantasy.AgentTool
}

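// NewSessionAgent creates a SessionAgent from the given options.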
func NewSessionAgent(
	opts SessionAgentOptions,
) SessionAgent {
	return &sessionAgent{
		largeModel:           opts.LargeModel,
		smallModel:           opts.SmallModel,
		systemPromptPrefix:   opts.SystemPromptPrefix,
		systemPrompt:         opts.SystemPrompt,
		isSubAgent:           opts.IsSubAgent,
		sessions:             opts.Sessions,
		messages:             opts.Messages,
		disableAutoSummarize: opts.DisableAutoSummarize,
		tools:                opts.Tools,
		isYolo:               opts.IsYolo,
		messageQueue:         csync.NewMap[string, []SessionAgentCall](),
		activeRequests:       csync.NewMap[string, context.CancelFunc](),
	}
}

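// Run sends a prompt to the agent for the given session. If the session is
// already busy the call is queued and (nil, nil) is returned; queued prompts
// are run once the active request finishes.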
func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
	if call.Prompt == "" {
		return nil, ErrEmptyPrompt
	}
	if call.SessionID == "" {
		return nil, ErrSessionMissing
	}

	// Queue the message if busy
	if a.IsSessionBusy(call.SessionID) {
		existing, ok := a.messageQueue.Get(call.SessionID)
		if !ok {
			existing = []SessionAgentCall{}
		}
		existing = append(existing, call)
		a.messageQueue.Set(call.SessionID, existing)
		return nil, nil
	}

	if len(a.tools) > 0 {
		// Add Anthropic caching to the last tool.
		a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
	}

	agent := fantasy.NewAgent(
		a.largeModel.Model,
		fantasy.WithSystemPrompt(a.systemPrompt),
		fantasy.WithTools(a.tools...),
	)

	sessionLock := sync.Mutex{}
	currentSession, err := a.sessions.Get(ctx, call.SessionID)
	if err != nil {
		return nil, fmt.Errorf("failed to get session: %w", err)
	}

	msgs, err := a.getSessionMessages(ctx, currentSession)
	if err != nil {
		return nil, fmt.Errorf("failed to get session messages: %w", err)
	}

	var wg sync.WaitGroup
	// Generate title if first message.
	if len(msgs) == 0 {
		titleCtx := ctx // Copy to avoid race with ctx reassignment below.
		wg.Go(func() {
			a.generateTitle(titleCtx, call.SessionID, call.Prompt)
		})
	}

	// Add the user message to the session.
	_, err = a.createUserMessage(ctx, call)
	if err != nil {
		return nil, err
	}

	// Add the session to the context.
	ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)

	genCtx, cancel := context.WithCancel(ctx)
	a.activeRequests.Set(call.SessionID, cancel)

	defer cancel()
	defer a.activeRequests.Del(call.SessionID)

	history, files := a.preparePrompt(msgs, call.Attachments...)

	startTime := time.Now()
	a.eventPromptSent(call.SessionID)

	var currentAssistant *message.Message
	var shouldSummarize bool
	result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
		Prompt:           message.PromptWithTextAttachments(call.Prompt, call.Attachments),
		Files:            files,
		Messages:         history,
		ProviderOptions:  call.ProviderOptions,
		MaxOutputTokens:  &call.MaxOutputTokens,
		TopP:             call.TopP,
		Temperature:      call.Temperature,
		PresencePenalty:  call.PresencePenalty,
		TopK:             call.TopK,
		FrequencyPenalty: call.FrequencyPenalty,
		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
			prepared.Messages = options.Messages
			for i := range prepared.Messages {
				prepared.Messages[i].ProviderOptions = nil
			}

			queuedCalls, _ := a.messageQueue.Get(call.SessionID)
			a.messageQueue.Del(call.SessionID)
			for _, queued := range queuedCalls {
				userMessage, createErr := a.createUserMessage(callContext, queued)
				if createErr != nil {
					return callContext, prepared, createErr
				}
				prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
			}

			prepared.Messages = a.workaroundProviderMediaLimitations(prepared.Messages)

			lastSystemRoleIdx := 0
			systemMessageUpdated := false
			for i, msg := range prepared.Messages {
				// Add cache control to the last system message only.
				if msg.Role == fantasy.MessageRoleSystem {
					lastSystemRoleIdx = i
				} else if !systemMessageUpdated {
					prepared.Messages[lastSystemRoleIdx].ProviderOptions = a.getCacheControlOptions()
					systemMessageUpdated = true
				}
				// Then add cache control to the last two messages.
				if i > len(prepared.Messages)-3 {
					prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
				}
			}

			if promptPrefix := a.promptPrefix(); promptPrefix != "" {
				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(promptPrefix)}, prepared.Messages...)
			}

			var assistantMsg message.Message
			assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
				Role:     message.Assistant,
				Parts:    []message.ContentPart{},
				Model:    a.largeModel.ModelCfg.Model,
				Provider: a.largeModel.ModelCfg.Provider,
			})
			if err != nil {
				return callContext, prepared, err
			}
			callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
			callContext = context.WithValue(callContext, tools.SupportsImagesContextKey, a.largeModel.CatwalkCfg.SupportsImages)
			callContext = context.WithValue(callContext, tools.ModelNameContextKey, a.largeModel.CatwalkCfg.Name)
			currentAssistant = &assistantMsg
			return callContext, prepared, err
		},
		OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
			currentAssistant.AppendReasoningContent(reasoning.Text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnReasoningDelta: func(id string, text string) error {
			currentAssistant.AppendReasoningContent(text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
			// Handle anthropic signature.
			if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
				if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
					currentAssistant.AppendReasoningSignature(reasoning.Signature)
				}
			}
			if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
				if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
					currentAssistant.AppendThoughtSignature(reasoning.Signature, reasoning.ToolID)
				}
			}
			if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
				if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
					currentAssistant.SetReasoningResponsesData(reasoning)
				}
			}
			currentAssistant.FinishThinking()
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnTextDelta: func(id string, text string) error {
			// Strip the leading newline from initial text content. This is
			// particularly important in non-interactive mode where leading
			// newlines are very visible.
			if len(currentAssistant.Parts) == 0 {
				text = strings.TrimPrefix(text, "\n")
			}

			currentAssistant.AppendContent(text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnToolInputStart: func(id string, toolName string) error {
			toolCall := message.ToolCall{
				ID:               id,
				Name:             toolName,
				ProviderExecuted: false,
				Finished:         false,
			}
			currentAssistant.AddToolCall(toolCall)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnRetry: func(err *fantasy.ProviderError, delay time.Duration) {
			// TODO: implement
		},
		OnToolCall: func(tc fantasy.ToolCallContent) error {
			toolCall := message.ToolCall{
				ID:               tc.ToolCallID,
				Name:             tc.ToolName,
				Input:            tc.Input,
				ProviderExecuted: false,
				Finished:         true,
			}
			currentAssistant.AddToolCall(toolCall)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnToolResult: func(result fantasy.ToolResultContent) error {
			toolResult := a.convertToToolResult(result)
			_, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
				Role: message.Tool,
				Parts: []message.ContentPart{
					toolResult,
				},
			})
			return createMsgErr
		},
		OnStepFinish: func(stepResult fantasy.StepResult) error {
			finishReason := message.FinishReasonUnknown
			switch stepResult.FinishReason {
			case fantasy.FinishReasonLength:
				finishReason = message.FinishReasonMaxTokens
			case fantasy.FinishReasonStop:
				finishReason = message.FinishReasonEndTurn
			case fantasy.FinishReasonToolCalls:
				finishReason = message.FinishReasonToolUse
			}
			currentAssistant.AddFinish(finishReason, "", "")
			sessionLock.Lock()
			updatedSession, getSessionErr := a.sessions.Get(genCtx, call.SessionID)
			if getSessionErr != nil {
				sessionLock.Unlock()
				return getSessionErr
			}
			a.updateSessionUsage(a.largeModel, &updatedSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
			_, sessionErr := a.sessions.Save(genCtx, updatedSession)
			sessionLock.Unlock()
			if sessionErr != nil {
				return sessionErr
			}
			return a.messages.Update(genCtx, *currentAssistant)
		},
		StopWhen: []fantasy.StopCondition{
			// Stop (and later summarize) when the remaining context falls
			// below a threshold: 20k tokens for windows larger than 200k
			// tokens, otherwise 20% of the window.
			func(_ []fantasy.StepResult) bool {
				cw := int64(a.largeModel.CatwalkCfg.ContextWindow)
				tokens := currentSession.CompletionTokens + currentSession.PromptTokens
				remaining := cw - tokens
				var threshold int64
				if cw > 200_000 {
					threshold = 20_000
				} else {
					threshold = int64(float64(cw) * 0.2)
				}
				if (remaining <= threshold) && !a.disableAutoSummarize {
					shouldSummarize = true
					return true
				}
				return false
			},
		},
	})

	a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))

	if err != nil {
		isCancelErr := errors.Is(err, context.Canceled)
		isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
		if currentAssistant == nil {
			return result, err
		}
		// Ensure we finish thinking on error to close the reasoning state.
		currentAssistant.FinishThinking()
		toolCalls := currentAssistant.ToolCalls()
		// INFO: we use the parent context here because the genCtx has been cancelled.
		msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
		if createErr != nil {
			return nil, createErr
		}
		for _, tc := range toolCalls {
			if !tc.Finished {
				tc.Finished = true
				tc.Input = "{}"
				currentAssistant.AddToolCall(tc)
				updateErr := a.messages.Update(ctx, *currentAssistant)
				if updateErr != nil {
					return nil, updateErr
				}
			}

			found := false
			for _, msg := range msgs {
				if msg.Role == message.Tool {
					for _, tr := range msg.ToolResults() {
						if tr.ToolCallID == tc.ID {
							found = true
							break
						}
					}
				}
				if found {
					break
				}
			}
			if found {
				continue
			}
			content := "There was an error while executing the tool"
			if isCancelErr {
				content = "Tool execution canceled by user"
			} else if isPermissionErr {
				content = "User denied permission"
			}
			toolResult := message.ToolResult{
				ToolCallID: tc.ID,
				Name:       tc.Name,
				Content:    content,
				IsError:    true,
			}
			_, createErr = a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
				Role: message.Tool,
				Parts: []message.ContentPart{
					toolResult,
				},
			})
			if createErr != nil {
				return nil, createErr
			}
		}
		var fantasyErr *fantasy.Error
		var providerErr *fantasy.ProviderError
		const defaultTitle = "Provider Error"
		if isCancelErr {
			currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
		} else if isPermissionErr {
			currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "")
		} else if errors.Is(err, hyper.ErrNoCredits) {
			url := hyper.BaseURL()
			link := lipgloss.NewStyle().Hyperlink(url, "id=hyper").Render(url)
			currentAssistant.AddFinish(message.FinishReasonError, "No credits", "You're out of credits. Add more at "+link)
		} else if errors.As(err, &providerErr) {
			if providerErr.Message == "The requested model is not supported." {
				url := "https://github.com/settings/copilot/features"
				link := lipgloss.NewStyle().Hyperlink(url, "id=hyper").Render(url)
				currentAssistant.AddFinish(
					message.FinishReasonError,
					"Copilot model not enabled",
					fmt.Sprintf("%q is not enabled in Copilot. Go to the following page to enable it. Then, wait a minute before trying again. %s", a.largeModel.CatwalkCfg.Name, link),
				)
			} else {
				currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message)
			}
		} else if errors.As(err, &fantasyErr) {
			currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(fantasyErr.Title), defaultTitle), fantasyErr.Message)
		} else {
			currentAssistant.AddFinish(message.FinishReasonError, defaultTitle, err.Error())
		}
		// Note: we use the parent context here because the genCtx has been
		// cancelled.
		updateErr := a.messages.Update(ctx, *currentAssistant)
		if updateErr != nil {
			return nil, updateErr
		}
		return nil, err
	}
	wg.Wait()

	if shouldSummarize {
		a.activeRequests.Del(call.SessionID)
		if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
			return nil, summarizeErr
		}
		// If the agent wasn't done, re-queue the request so it can continue
		// after the summary.
		if len(currentAssistant.ToolCalls()) > 0 {
			existing, ok := a.messageQueue.Get(call.SessionID)
			if !ok {
				existing = []SessionAgentCall{}
			}
			call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
			existing = append(existing, call)
			a.messageQueue.Set(call.SessionID, existing)
		}
	}

	// Release the active request before processing queued messages.
	a.activeRequests.Del(call.SessionID)
	cancel()

	queuedMessages, ok := a.messageQueue.Get(call.SessionID)
	if !ok || len(queuedMessages) == 0 {
		return result, err
	}
	// There are queued messages; restart the loop.
	firstQueuedMessage := queuedMessages[0]
	a.messageQueue.Set(call.SessionID, queuedMessages[1:])
	return a.Run(ctx, firstQueuedMessage)
}

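// Summarize generates a summary of the session's conversation so far, records
// it as a summary message, and resets the session's token counts so the
// conversation can continue from the summary.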
func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
	if a.IsSessionBusy(sessionID) {
		return ErrSessionBusy
	}

	currentSession, err := a.sessions.Get(ctx, sessionID)
	if err != nil {
		return fmt.Errorf("failed to get session: %w", err)
	}
	msgs, err := a.getSessionMessages(ctx, currentSession)
	if err != nil {
		return err
	}
	if len(msgs) == 0 {
		// Nothing to summarize.
		return nil
	}

	aiMsgs, _ := a.preparePrompt(msgs)

	genCtx, cancel := context.WithCancel(ctx)
	a.activeRequests.Set(sessionID, cancel)
	defer a.activeRequests.Del(sessionID)
	defer cancel()

	agent := fantasy.NewAgent(a.largeModel.Model,
		fantasy.WithSystemPrompt(string(summaryPrompt)),
	)
	summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
		Role:             message.Assistant,
		Model:            a.largeModel.Model.Model(),
		Provider:         a.largeModel.Model.Provider(),
		IsSummaryMessage: true,
	})
	if err != nil {
		return err
	}

	summaryPromptText := "Provide a detailed summary of our conversation above."
	if len(currentSession.Todos) > 0 {
		summaryPromptText += "\n\n## Current Todo List\n\n"
		for _, t := range currentSession.Todos {
			summaryPromptText += fmt.Sprintf("- [%s] %s\n", t.Status, t.Content)
		}
		summaryPromptText += "\nInclude these tasks and their statuses in your summary. "
		summaryPromptText += "Instruct the resuming assistant to use the `todos` tool to continue tracking progress on these tasks."
	}

	resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
		Prompt:          summaryPromptText,
		Messages:        aiMsgs,
		ProviderOptions: opts,
		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
			prepared.Messages = options.Messages
			if a.systemPromptPrefix != "" {
				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
			}
			return callContext, prepared, nil
		},
		OnReasoningDelta: func(id string, text string) error {
			summaryMessage.AppendReasoningContent(text)
			return a.messages.Update(genCtx, summaryMessage)
		},
		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
			// Handle anthropic signature.
			if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
				if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
					summaryMessage.AppendReasoningSignature(signature.Signature)
				}
			}
			summaryMessage.FinishThinking()
			return a.messages.Update(genCtx, summaryMessage)
		},
		OnTextDelta: func(id, text string) error {
			summaryMessage.AppendContent(text)
			return a.messages.Update(genCtx, summaryMessage)
		},
	})
	if err != nil {
		isCancelErr := errors.Is(err, context.Canceled)
		if isCancelErr {
			// The user cancelled summarization; remove the summary message.
			deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
			return deleteErr
		}
		return err
	}

	summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
	err = a.messages.Update(genCtx, summaryMessage)
	if err != nil {
		return err
	}

	var openrouterCost *float64
	for _, step := range resp.Steps {
		stepCost := a.openrouterCost(step.ProviderMetadata)
		if stepCost != nil {
			newCost := *stepCost
			if openrouterCost != nil {
				newCost += *openrouterCost
			}
			openrouterCost = &newCost
		}
	}

	a.updateSessionUsage(a.largeModel, &currentSession, resp.TotalUsage, openrouterCost)

	// Just in case, use only the last response's usage info.
	usage := resp.Response.Usage
	currentSession.SummaryMessageID = summaryMessage.ID
	currentSession.CompletionTokens = usage.OutputTokens
	currentSession.PromptTokens = 0
	_, err = a.sessions.Save(genCtx, currentSession)
	return err
}

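// getCacheControlOptions returns ephemeral prompt-caching provider options for
// Anthropic and Bedrock, unless caching is disabled via
// CRUSH_DISABLE_ANTHROPIC_CACHE.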
func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
	if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
		return fantasy.ProviderOptions{}
	}
	return fantasy.ProviderOptions{
		anthropic.Name: &anthropic.ProviderCacheControlOptions{
			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
		},
		bedrock.Name: &anthropic.ProviderCacheControlOptions{
			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
		},
	}
}

func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
	parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
	var attachmentParts []message.ContentPart
	for _, attachment := range call.Attachments {
		attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
	}
	parts = append(parts, attachmentParts...)
	msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
		Role:  message.User,
		Parts: parts,
	})
	if err != nil {
		return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
	}
	return msg, nil
}

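// preparePrompt converts stored session messages into fantasy messages and
// splits non-text attachments out as file parts. For top-level agents it also
// prepends a todo-list system reminder.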
func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
	var history []fantasy.Message
	if !a.isSubAgent {
		history = append(history, fantasy.NewUserMessage(
			fmt.Sprintf("<system_reminder>%s</system_reminder>",
				`This is a reminder that your todo list is currently empty. DO NOT mention this to the user explicitly because they are already aware.
If you are working on tasks that would benefit from a todo list please use the "todos" tool to create one.
If not, please feel free to ignore. Again do not mention this message to the user.`,
			),
		))
	}
	for _, m := range msgs {
		if len(m.Parts) == 0 {
			continue
		}
		// Skip assistant messages without content or tool calls (cancelled
		// before they returned anything).
		if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
			continue
		}
		history = append(history, m.ToAIMessage()...)
	}

	var files []fantasy.FilePart
	for _, attachment := range attachments {
		if attachment.IsText() {
			continue
		}
		files = append(files, fantasy.FilePart{
			Filename:  attachment.FileName,
			Data:      attachment.Content,
			MediaType: attachment.MimeType,
		})
	}

	return history, files
}

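// getSessionMessages lists the session's messages, truncating the history to
// start at the summary message when the session has been summarized.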
func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
	msgs, err := a.messages.List(ctx, session.ID)
	if err != nil {
		return nil, fmt.Errorf("failed to list messages: %w", err)
	}

	if session.SummaryMessageID != "" {
		summaryMsgIndex := -1
		for i, msg := range msgs {
			if msg.ID == session.SummaryMessageID {
				summaryMsgIndex = i
				break
			}
		}
		if summaryMsgIndex != -1 {
			msgs = msgs[summaryMsgIndex:]
			msgs[0].Role = message.User
		}
	}
	return msgs, nil
}

// generateTitle generates a session title based on the initial prompt.
func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, userPrompt string) {
	if userPrompt == "" {
		return
	}

	var maxOutputTokens int64 = 40
	if a.smallModel.CatwalkCfg.CanReason {
		maxOutputTokens = a.smallModel.CatwalkCfg.DefaultMaxTokens
	}

	newAgent := func(m fantasy.LanguageModel, p []byte, tok int64) fantasy.Agent {
		return fantasy.NewAgent(m,
			fantasy.WithSystemPrompt(string(p)+"\n /no_think"),
			fantasy.WithMaxOutputTokens(tok),
		)
	}

	streamCall := fantasy.AgentStreamCall{
		Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", userPrompt),
		PrepareStep: func(callCtx context.Context, opts fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
			prepared.Messages = opts.Messages
			if a.systemPromptPrefix != "" {
				prepared.Messages = append([]fantasy.Message{
					fantasy.NewSystemMessage(a.systemPromptPrefix),
				}, prepared.Messages...)
			}
			return callCtx, prepared, nil
		},
	}

	// Use the small model to generate the title.
	model := &a.smallModel
	agent := newAgent(model.Model, titlePrompt, maxOutputTokens)
	resp, err := agent.Stream(ctx, streamCall)
	if err == nil {
		// We successfully generated a title with the small model.
		slog.Info("generated title with small model")
	} else {
		// It didn't work. Let's try with the big model.
		slog.Error("error generating title with small model; trying big model", "err", err)
		model = &a.largeModel
		agent = newAgent(model.Model, titlePrompt, maxOutputTokens)
		resp, err = agent.Stream(ctx, streamCall)
		if err == nil {
			slog.Info("generated title with large model")
		} else {
			// Welp, the large model didn't work either.
			slog.Error("error generating title with large model", "err", err)
		}
	}

	// Guard against a nil response if both attempts failed.
	if resp == nil {
		return
	}

	title := strings.ReplaceAll(resp.Response.Content.Text(), "\n", " ")
	slog.Info("generated title", "title", title)

	// Remove thinking tags if present.
	title = thinkTagRegex.ReplaceAllString(title, "")

	title = strings.TrimSpace(title)
	if title == "" {
		slog.Warn("empty title; using fallback")
		title = defaultSessionName
	}

	// Calculate usage and cost.
	var openrouterCost *float64
	for _, step := range resp.Steps {
		stepCost := a.openrouterCost(step.ProviderMetadata)
		if stepCost != nil {
			newCost := *stepCost
			if openrouterCost != nil {
				newCost += *openrouterCost
			}
			openrouterCost = &newCost
		}
	}

	if model == nil {
		slog.Error("no model available for cost calculation")
		return
	}
	modelConfig := model.CatwalkCfg
	cost := modelConfig.CostPer1MInCached/1e6*float64(resp.TotalUsage.CacheCreationTokens) +
		modelConfig.CostPer1MOutCached/1e6*float64(resp.TotalUsage.CacheReadTokens) +
		modelConfig.CostPer1MIn/1e6*float64(resp.TotalUsage.InputTokens) +
		modelConfig.CostPer1MOut/1e6*float64(resp.TotalUsage.OutputTokens)

	if a.isClaudeCode() {
		cost = 0
	}

	// Use the override cost if available (e.g., from OpenRouter).
	if openrouterCost != nil {
		cost = *openrouterCost
	}

	promptTokens := resp.TotalUsage.InputTokens + resp.TotalUsage.CacheCreationTokens
	completionTokens := resp.TotalUsage.OutputTokens + resp.TotalUsage.CacheReadTokens

	// Atomically update only the title and usage fields to avoid overriding
	// other concurrent session updates.
	saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, title, promptTokens, completionTokens, cost)
	if saveErr != nil {
		slog.Error("failed to save session title and usage", "error", saveErr)
		return
	}
}

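// openrouterCost extracts the reported cost from OpenRouter provider
// metadata, if present.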
func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
	openrouterMetadata, ok := metadata[openrouter.Name]
	if !ok {
		return nil
	}

	opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
	if !ok {
		return nil
	}
	return &opts.Usage.Cost
}

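// updateSessionUsage recomputes the session's cost and token counts from the
// latest usage, preferring the provider-reported cost when available.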
func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
	modelConfig := model.CatwalkCfg
	cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
		modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
		modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
		modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)

	if a.isClaudeCode() {
		cost = 0
	}

	a.eventTokensUsed(session.ID, model, usage, cost)

	if overrideCost != nil {
		session.Cost += *overrideCost
	} else {
		session.Cost += cost
	}

	session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
	session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
}

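// Cancel cancels the active request for the given session, including any
// in-flight summarization, and clears its queued prompts.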
func (a *sessionAgent) Cancel(sessionID string) {
	// Cancel regular requests.
	if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
		slog.Info("Request cancellation initiated", "session_id", sessionID)
		cancel()
	}

	// Also check for summarize requests.
	if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
		slog.Info("Summarize cancellation initiated", "session_id", sessionID)
		cancel()
	}

	if a.QueuedPrompts(sessionID) > 0 {
		slog.Info("Clearing queued prompts", "session_id", sessionID)
		a.messageQueue.Del(sessionID)
	}
}

func (a *sessionAgent) ClearQueue(sessionID string) {
	if a.QueuedPrompts(sessionID) > 0 {
		slog.Info("Clearing queued prompts", "session_id", sessionID)
		a.messageQueue.Del(sessionID)
	}
}

func (a *sessionAgent) CancelAll() {
	if !a.IsBusy() {
		return
	}
	for key := range a.activeRequests.Seq2() {
		a.Cancel(key) // key is sessionID
	}

	timeout := time.After(5 * time.Second)
	for a.IsBusy() {
		select {
		case <-timeout:
			return
		default:
			time.Sleep(200 * time.Millisecond)
		}
	}
}

func (a *sessionAgent) IsBusy() bool {
	var busy bool
	for cancelFunc := range a.activeRequests.Seq() {
		if cancelFunc != nil {
			busy = true
			break
		}
	}
	return busy
}

func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
	_, busy := a.activeRequests.Get(sessionID)
	return busy
}

func (a *sessionAgent) QueuedPrompts(sessionID string) int {
	l, ok := a.messageQueue.Get(sessionID)
	if !ok {
		return 0
	}
	return len(l)
}

func (a *sessionAgent) QueuedPromptsList(sessionID string) []string {
	l, ok := a.messageQueue.Get(sessionID)
	if !ok {
		return nil
	}
	prompts := make([]string, len(l))
	for i, call := range l {
		prompts[i] = call.Prompt
	}
	return prompts
}

func (a *sessionAgent) SetModels(large Model, small Model) {
	a.largeModel = large
	a.smallModel = small
}

func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
	a.tools = tools
}

func (a *sessionAgent) Model() Model {
	return a.largeModel
}

func (a *sessionAgent) promptPrefix() string {
	if a.isClaudeCode() {
		return "You are Claude Code, Anthropic's official CLI for Claude."
	}
	return a.systemPromptPrefix
}

// XXX: this should be generalized to cover other subscription plans, like Copilot.
func (a *sessionAgent) isClaudeCode() bool {
	cfg := config.Get()
	pc, ok := cfg.Providers.Get(a.largeModel.ModelCfg.Provider)
	return ok && pc.ID == string(catwalk.InferenceProviderAnthropic) && pc.OAuthToken != nil
}

// convertToToolResult converts a fantasy tool result to a message tool result.
func (a *sessionAgent) convertToToolResult(result fantasy.ToolResultContent) message.ToolResult {
	baseResult := message.ToolResult{
		ToolCallID: result.ToolCallID,
		Name:       result.ToolName,
		Metadata:   result.ClientMetadata,
	}

	switch result.Result.GetType() {
	case fantasy.ToolResultContentTypeText:
		if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result); ok {
			baseResult.Content = r.Text
		}
	case fantasy.ToolResultContentTypeError:
		if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result); ok {
			baseResult.Content = r.Error.Error()
			baseResult.IsError = true
		}
	case fantasy.ToolResultContentTypeMedia:
		if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](result.Result); ok {
			content := r.Text
			if content == "" {
				content = fmt.Sprintf("Loaded %s content", r.MediaType)
			}
			baseResult.Content = content
			baseResult.Data = r.Data
			baseResult.MIMEType = r.MediaType
		}
	}

	return baseResult
}

// workaroundProviderMediaLimitations converts media content in tool results to
// user messages for providers that don't natively support images in tool results.
//
// Problem: OpenAI, Google, OpenRouter, and other OpenAI-compatible providers
// don't support sending images/media in tool result messages - they only accept
// text in tool results. However, they DO support images in user messages.
//
// If we send media in tool results to these providers, the API returns an error.
//
// Solution: For these providers, we:
// 1. Replace the media in the tool result with a text placeholder
// 2. Inject a user message immediately after with the image as a file attachment
// 3. This maintains the tool execution flow while working around API limitations
//
// Anthropic and Bedrock support images natively in tool results, so we skip
// this workaround for them.
//
// Example transformation:
//
// BEFORE: [tool result: image data]
// AFTER:  [tool result: "Image loaded - see attached"], [user: image attachment]
func (a *sessionAgent) workaroundProviderMediaLimitations(messages []fantasy.Message) []fantasy.Message {
	providerSupportsMedia := a.largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderAnthropic) ||
		a.largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderBedrock)

	if providerSupportsMedia {
		return messages
	}

	convertedMessages := make([]fantasy.Message, 0, len(messages))

	for _, msg := range messages {
		if msg.Role != fantasy.MessageRoleTool {
			convertedMessages = append(convertedMessages, msg)
			continue
		}

		textParts := make([]fantasy.MessagePart, 0, len(msg.Content))
		var mediaFiles []fantasy.FilePart

		for _, part := range msg.Content {
			toolResult, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
			if !ok {
				textParts = append(textParts, part)
				continue
			}

			if media, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](toolResult.Output); ok {
				decoded, err := base64.StdEncoding.DecodeString(media.Data)
				if err != nil {
					slog.Warn("failed to decode media data", "error", err)
					textParts = append(textParts, part)
					continue
				}

				mediaFiles = append(mediaFiles, fantasy.FilePart{
					Data:      decoded,
					MediaType: media.MediaType,
					Filename:  fmt.Sprintf("tool-result-%s", toolResult.ToolCallID),
				})

				textParts = append(textParts, fantasy.ToolResultPart{
					ToolCallID: toolResult.ToolCallID,
					Output: fantasy.ToolResultOutputContentText{
						Text: "[Image/media content loaded - see attached file]",
					},
					ProviderOptions: toolResult.ProviderOptions,
				})
			} else {
				textParts = append(textParts, part)
			}
		}

		convertedMessages = append(convertedMessages, fantasy.Message{
			Role:    fantasy.MessageRoleTool,
			Content: textParts,
		})

		if len(mediaFiles) > 0 {
			convertedMessages = append(convertedMessages, fantasy.NewUserMessage(
				"Here is the media content from the tool result:",
				mediaFiles...,
			))
		}
	}

	return convertedMessages
}