1// Package agent is the core orchestration layer for Crush AI agents.
2//
3// It provides session-based AI agent functionality for managing
4// conversations, tool execution, and message handling. It coordinates
5// interactions between language models, messages, sessions, and tools while
6// handling features like automatic summarization, queuing, and token
7// management.
8package agent
9
10import (
11 "cmp"
12 "context"
13 _ "embed"
14 "encoding/base64"
15 "errors"
16 "fmt"
17 "log/slog"
18 "os"
19 "regexp"
20 "strconv"
21 "strings"
22 "sync"
23 "time"
24
25 "charm.land/fantasy"
26 "charm.land/fantasy/providers/anthropic"
27 "charm.land/fantasy/providers/bedrock"
28 "charm.land/fantasy/providers/google"
29 "charm.land/fantasy/providers/openai"
30 "charm.land/fantasy/providers/openrouter"
31 "charm.land/lipgloss/v2"
32 "github.com/charmbracelet/catwalk/pkg/catwalk"
33 "github.com/charmbracelet/crush/internal/agent/hyper"
34 "github.com/charmbracelet/crush/internal/agent/tools"
35 "github.com/charmbracelet/crush/internal/config"
36 "github.com/charmbracelet/crush/internal/csync"
37 "github.com/charmbracelet/crush/internal/message"
38 "github.com/charmbracelet/crush/internal/permission"
39 "github.com/charmbracelet/crush/internal/session"
40 "github.com/charmbracelet/crush/internal/stringext"
41)
42
// titlePrompt holds the contents of templates/title.md, embedded at build
// time. generateTitle uses it as the system prompt for title generation.
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt holds the contents of templates/summary.md, embedded at build
// time. Summarize uses it as the system prompt for conversation summaries.
//
//go:embed templates/summary.md
var summaryPrompt []byte

// thinkTagRegex matches a <think>...</think> pair (non-greedy) so it can be
// removed from generated titles. The pattern does not match across newlines;
// generateTitle replaces newlines with spaces before applying it.
var thinkTagRegex = regexp.MustCompile(`<think>.*?</think>`)
51
// SessionAgentCall describes a single prompt submitted to a SessionAgent.
type SessionAgentCall struct {
	// SessionID identifies the session the prompt belongs to. Run rejects
	// an empty value with ErrSessionMissing.
	SessionID string
	// Prompt is the user text. Run rejects an empty value with
	// ErrEmptyPrompt.
	Prompt string
	// ProviderOptions is forwarded unchanged to the model stream call.
	ProviderOptions fantasy.ProviderOptions
	// Attachments are stored with the user message; non-text attachments
	// are additionally sent to the model as file parts.
	Attachments []message.Attachment
	// MaxOutputTokens is forwarded to the model stream call.
	MaxOutputTokens int64
	// The sampling parameters below are forwarded to the model stream call
	// as-is; nil pointers are passed through as nil.
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
}
64
// SessionAgent is the session-scoped agent API: it runs prompts, manages the
// per-session prompt queue, and supports cancellation and summarization.
type SessionAgent interface {
	// Run executes the call against the large model. When the session is
	// already busy the call is queued and Run returns (nil, nil).
	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
	// SetModels replaces the large and small models.
	SetModels(large Model, small Model)
	// SetTools replaces the tool set offered to the model.
	SetTools(tools []fantasy.AgentTool)
	// Cancel cancels the active request for the session, if any, and drops
	// its queued prompts.
	Cancel(sessionID string)
	// CancelAll cancels every active request and waits (bounded) for the
	// agent to become idle.
	CancelAll()
	// IsSessionBusy reports whether the session has an active request.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any session has an active request.
	IsBusy() bool
	// QueuedPrompts returns the number of prompts queued for the session.
	QueuedPrompts(sessionID string) int
	// QueuedPromptsList returns the text of the prompts queued for the
	// session, in queue order.
	QueuedPromptsList(sessionID string) []string
	// ClearQueue drops the prompts queued for the session.
	ClearQueue(sessionID string)
	// Summarize replaces the session history with a model-written summary.
	// It returns ErrSessionBusy when the session has an active request.
	Summarize(context.Context, string, fantasy.ProviderOptions) error
	// Model returns the large model.
	Model() Model
}
79
// Model bundles a language model with the metadata the agent needs about it.
type Model struct {
	// Model is the language model handed to fantasy.NewAgent.
	Model fantasy.LanguageModel
	// CatwalkCfg is the catalog entry for the model; this file reads its
	// context window, per-token costs, display name and capabilities.
	CatwalkCfg catwalk.Model
	// ModelCfg is the selected-model configuration; this file reads its
	// Model and Provider identifiers.
	ModelCfg config.SelectedModel
}
85
// sessionAgent is the SessionAgent implementation returned by NewSessionAgent.
type sessionAgent struct {
	// largeModel serves regular runs and summaries; smallModel generates
	// session titles.
	largeModel Model
	smallModel Model
	// systemPromptPrefix, when non-empty, is prepended as an extra system
	// message on every step (see promptPrefix for the override).
	systemPromptPrefix string
	// systemPrompt is the system prompt for regular runs.
	systemPrompt string
	// isSubAgent suppresses the todo-list reminder added by preparePrompt.
	isSubAgent bool
	// tools is the tool set offered to the model in Run.
	tools    []fantasy.AgentTool
	sessions session.Service
	messages message.Service
	// disableAutoSummarize turns off the context-window stop condition.
	disableAutoSummarize bool
	// isYolo is stored from the options; it is not read in the visible
	// portion of this file.
	isYolo bool

	// messageQueue holds prompts received while a session was busy, keyed
	// by session ID.
	messageQueue *csync.Map[string, []SessionAgentCall]
	// activeRequests maps a session ID to the cancel func of its running
	// request; presence of a key is what marks a session as busy.
	activeRequests *csync.Map[string, context.CancelFunc]
}
101
// SessionAgentOptions carries the dependencies and settings consumed by
// NewSessionAgent. Each field is copied to the sessionAgent field of the same
// purpose.
type SessionAgentOptions struct {
	// LargeModel serves regular runs and summaries.
	LargeModel Model
	// SmallModel generates session titles.
	SmallModel Model
	// SystemPromptPrefix, when non-empty, is prepended as an extra system
	// message on every step.
	SystemPromptPrefix string
	// SystemPrompt is the system prompt for regular runs.
	SystemPrompt string
	// IsSubAgent suppresses the todo-list reminder in the prompt history.
	IsSubAgent bool
	// DisableAutoSummarize turns off automatic summarization when the
	// context window is nearly full.
	DisableAutoSummarize bool
	// IsYolo is stored on the agent; it is not read in the visible portion
	// of this file.
	IsYolo bool
	// Sessions and Messages are the persistence services.
	Sessions session.Service
	Messages message.Service
	// Tools is the initial tool set.
	Tools []fantasy.AgentTool
}
114
115func NewSessionAgent(
116 opts SessionAgentOptions,
117) SessionAgent {
118 return &sessionAgent{
119 largeModel: opts.LargeModel,
120 smallModel: opts.SmallModel,
121 systemPromptPrefix: opts.SystemPromptPrefix,
122 systemPrompt: opts.SystemPrompt,
123 isSubAgent: opts.IsSubAgent,
124 sessions: opts.Sessions,
125 messages: opts.Messages,
126 disableAutoSummarize: opts.DisableAutoSummarize,
127 tools: opts.Tools,
128 isYolo: opts.IsYolo,
129 messageQueue: csync.NewMap[string, []SessionAgentCall](),
130 activeRequests: csync.NewMap[string, context.CancelFunc](),
131 }
132}
133
134func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
135 if call.Prompt == "" {
136 return nil, ErrEmptyPrompt
137 }
138 if call.SessionID == "" {
139 return nil, ErrSessionMissing
140 }
141
142 // Queue the message if busy
143 if a.IsSessionBusy(call.SessionID) {
144 existing, ok := a.messageQueue.Get(call.SessionID)
145 if !ok {
146 existing = []SessionAgentCall{}
147 }
148 existing = append(existing, call)
149 a.messageQueue.Set(call.SessionID, existing)
150 return nil, nil
151 }
152
153 if len(a.tools) > 0 {
154 // Add Anthropic caching to the last tool.
155 a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
156 }
157
158 agent := fantasy.NewAgent(
159 a.largeModel.Model,
160 fantasy.WithSystemPrompt(a.systemPrompt),
161 fantasy.WithTools(a.tools...),
162 )
163
164 sessionLock := sync.Mutex{}
165 currentSession, err := a.sessions.Get(ctx, call.SessionID)
166 if err != nil {
167 return nil, fmt.Errorf("failed to get session: %w", err)
168 }
169
170 msgs, err := a.getSessionMessages(ctx, currentSession)
171 if err != nil {
172 return nil, fmt.Errorf("failed to get session messages: %w", err)
173 }
174
175 var wg sync.WaitGroup
176 // Generate title if first message.
177 if len(msgs) == 0 {
178 titleCtx := ctx // Copy to avoid race with ctx reassignment below.
179 wg.Go(func() {
180 a.generateTitle(titleCtx, call.SessionID, call.Prompt)
181 })
182 }
183
184 // Add the user message to the session.
185 _, err = a.createUserMessage(ctx, call)
186 if err != nil {
187 return nil, err
188 }
189
190 // Add the session to the context.
191 ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)
192
193 genCtx, cancel := context.WithCancel(ctx)
194 a.activeRequests.Set(call.SessionID, cancel)
195
196 defer cancel()
197 defer a.activeRequests.Del(call.SessionID)
198
199 history, files := a.preparePrompt(msgs, call.Attachments...)
200
201 startTime := time.Now()
202 a.eventPromptSent(call.SessionID)
203
204 var currentAssistant *message.Message
205 var shouldSummarize bool
206 result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
207 Prompt: message.PromptWithTextAttachments(call.Prompt, call.Attachments),
208 Files: files,
209 Messages: history,
210 ProviderOptions: call.ProviderOptions,
211 MaxOutputTokens: &call.MaxOutputTokens,
212 TopP: call.TopP,
213 Temperature: call.Temperature,
214 PresencePenalty: call.PresencePenalty,
215 TopK: call.TopK,
216 FrequencyPenalty: call.FrequencyPenalty,
217 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
218 prepared.Messages = options.Messages
219 for i := range prepared.Messages {
220 prepared.Messages[i].ProviderOptions = nil
221 }
222
223 queuedCalls, _ := a.messageQueue.Get(call.SessionID)
224 a.messageQueue.Del(call.SessionID)
225 for _, queued := range queuedCalls {
226 userMessage, createErr := a.createUserMessage(callContext, queued)
227 if createErr != nil {
228 return callContext, prepared, createErr
229 }
230 prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
231 }
232
233 prepared.Messages = a.workaroundProviderMediaLimitations(prepared.Messages)
234
235 lastSystemRoleInx := 0
236 systemMessageUpdated := false
237 for i, msg := range prepared.Messages {
238 // Only add cache control to the last message.
239 if msg.Role == fantasy.MessageRoleSystem {
240 lastSystemRoleInx = i
241 } else if !systemMessageUpdated {
242 prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
243 systemMessageUpdated = true
244 }
245 // Than add cache control to the last 2 messages.
246 if i > len(prepared.Messages)-3 {
247 prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
248 }
249 }
250
251 if promptPrefix := a.promptPrefix(); promptPrefix != "" {
252 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(promptPrefix)}, prepared.Messages...)
253 }
254
255 var assistantMsg message.Message
256 assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
257 Role: message.Assistant,
258 Parts: []message.ContentPart{},
259 Model: a.largeModel.ModelCfg.Model,
260 Provider: a.largeModel.ModelCfg.Provider,
261 })
262 if err != nil {
263 return callContext, prepared, err
264 }
265 callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
266 callContext = context.WithValue(callContext, tools.SupportsImagesContextKey, a.largeModel.CatwalkCfg.SupportsImages)
267 callContext = context.WithValue(callContext, tools.ModelNameContextKey, a.largeModel.CatwalkCfg.Name)
268 currentAssistant = &assistantMsg
269 return callContext, prepared, err
270 },
271 OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
272 currentAssistant.AppendReasoningContent(reasoning.Text)
273 return a.messages.Update(genCtx, *currentAssistant)
274 },
275 OnReasoningDelta: func(id string, text string) error {
276 currentAssistant.AppendReasoningContent(text)
277 return a.messages.Update(genCtx, *currentAssistant)
278 },
279 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
280 // handle anthropic signature
281 if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
282 if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
283 currentAssistant.AppendReasoningSignature(reasoning.Signature)
284 }
285 }
286 if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
287 if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
288 currentAssistant.AppendThoughtSignature(reasoning.Signature, reasoning.ToolID)
289 }
290 }
291 if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
292 if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
293 currentAssistant.SetReasoningResponsesData(reasoning)
294 }
295 }
296 currentAssistant.FinishThinking()
297 return a.messages.Update(genCtx, *currentAssistant)
298 },
299 OnTextDelta: func(id string, text string) error {
300 // Strip leading newline from initial text content. This is is
301 // particularly important in non-interactive mode where leading
302 // newlines are very visible.
303 if len(currentAssistant.Parts) == 0 {
304 text = strings.TrimPrefix(text, "\n")
305 }
306
307 currentAssistant.AppendContent(text)
308 return a.messages.Update(genCtx, *currentAssistant)
309 },
310 OnToolInputStart: func(id string, toolName string) error {
311 toolCall := message.ToolCall{
312 ID: id,
313 Name: toolName,
314 ProviderExecuted: false,
315 Finished: false,
316 }
317 currentAssistant.AddToolCall(toolCall)
318 return a.messages.Update(genCtx, *currentAssistant)
319 },
320 OnRetry: func(err *fantasy.ProviderError, delay time.Duration) {
321 // TODO: implement
322 },
323 OnToolCall: func(tc fantasy.ToolCallContent) error {
324 toolCall := message.ToolCall{
325 ID: tc.ToolCallID,
326 Name: tc.ToolName,
327 Input: tc.Input,
328 ProviderExecuted: false,
329 Finished: true,
330 }
331 currentAssistant.AddToolCall(toolCall)
332 return a.messages.Update(genCtx, *currentAssistant)
333 },
334 OnToolResult: func(result fantasy.ToolResultContent) error {
335 toolResult := a.convertToToolResult(result)
336 _, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
337 Role: message.Tool,
338 Parts: []message.ContentPart{
339 toolResult,
340 },
341 })
342 return createMsgErr
343 },
344 OnStepFinish: func(stepResult fantasy.StepResult) error {
345 finishReason := message.FinishReasonUnknown
346 switch stepResult.FinishReason {
347 case fantasy.FinishReasonLength:
348 finishReason = message.FinishReasonMaxTokens
349 case fantasy.FinishReasonStop:
350 finishReason = message.FinishReasonEndTurn
351 case fantasy.FinishReasonToolCalls:
352 finishReason = message.FinishReasonToolUse
353 }
354 currentAssistant.AddFinish(finishReason, "", "")
355 sessionLock.Lock()
356 updatedSession, getSessionErr := a.sessions.Get(genCtx, call.SessionID)
357 if getSessionErr != nil {
358 sessionLock.Unlock()
359 return getSessionErr
360 }
361 a.updateSessionUsage(a.largeModel, &updatedSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
362 _, sessionErr := a.sessions.Save(genCtx, updatedSession)
363 sessionLock.Unlock()
364 if sessionErr != nil {
365 return sessionErr
366 }
367 return a.messages.Update(genCtx, *currentAssistant)
368 },
369 StopWhen: []fantasy.StopCondition{
370 func(_ []fantasy.StepResult) bool {
371 cw := int64(a.largeModel.CatwalkCfg.ContextWindow)
372 tokens := currentSession.CompletionTokens + currentSession.PromptTokens
373 remaining := cw - tokens
374 var threshold int64
375 if cw > 200_000 {
376 threshold = 20_000
377 } else {
378 threshold = int64(float64(cw) * 0.2)
379 }
380 if (remaining <= threshold) && !a.disableAutoSummarize {
381 shouldSummarize = true
382 return true
383 }
384 return false
385 },
386 },
387 })
388
389 a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))
390
391 if err != nil {
392 isCancelErr := errors.Is(err, context.Canceled)
393 isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
394 if currentAssistant == nil {
395 return result, err
396 }
397 // Ensure we finish thinking on error to close the reasoning state.
398 currentAssistant.FinishThinking()
399 toolCalls := currentAssistant.ToolCalls()
400 // INFO: we use the parent context here because the genCtx has been cancelled.
401 msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
402 if createErr != nil {
403 return nil, createErr
404 }
405 for _, tc := range toolCalls {
406 if !tc.Finished {
407 tc.Finished = true
408 tc.Input = "{}"
409 currentAssistant.AddToolCall(tc)
410 updateErr := a.messages.Update(ctx, *currentAssistant)
411 if updateErr != nil {
412 return nil, updateErr
413 }
414 }
415
416 found := false
417 for _, msg := range msgs {
418 if msg.Role == message.Tool {
419 for _, tr := range msg.ToolResults() {
420 if tr.ToolCallID == tc.ID {
421 found = true
422 break
423 }
424 }
425 }
426 if found {
427 break
428 }
429 }
430 if found {
431 continue
432 }
433 content := "There was an error while executing the tool"
434 if isCancelErr {
435 content = "Tool execution canceled by user"
436 } else if isPermissionErr {
437 content = "User denied permission"
438 }
439 toolResult := message.ToolResult{
440 ToolCallID: tc.ID,
441 Name: tc.Name,
442 Content: content,
443 IsError: true,
444 }
445 _, createErr = a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
446 Role: message.Tool,
447 Parts: []message.ContentPart{
448 toolResult,
449 },
450 })
451 if createErr != nil {
452 return nil, createErr
453 }
454 }
455 var fantasyErr *fantasy.Error
456 var providerErr *fantasy.ProviderError
457 const defaultTitle = "Provider Error"
458 if isCancelErr {
459 currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
460 } else if isPermissionErr {
461 currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "")
462 } else if errors.Is(err, hyper.ErrNoCredits) {
463 url := hyper.BaseURL()
464 link := lipgloss.NewStyle().Hyperlink(url, "id=hyper").Render(url)
465 currentAssistant.AddFinish(message.FinishReasonError, "No credits", "You're out of credits. Add more at "+link)
466 } else if errors.As(err, &providerErr) {
467 if providerErr.Message == "The requested model is not supported." {
468 url := "https://github.com/settings/copilot/features"
469 link := lipgloss.NewStyle().Hyperlink(url, "id=hyper").Render(url)
470 currentAssistant.AddFinish(
471 message.FinishReasonError,
472 "Copilot model not enabled",
473 fmt.Sprintf("%q is not enabled in Copilot. Go to the following page to enable it. Then, wait a minute before trying again. %s", a.largeModel.CatwalkCfg.Name, link),
474 )
475 } else {
476 currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message)
477 }
478 } else if errors.As(err, &fantasyErr) {
479 currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(fantasyErr.Title), defaultTitle), fantasyErr.Message)
480 } else {
481 currentAssistant.AddFinish(message.FinishReasonError, defaultTitle, err.Error())
482 }
483 // Note: we use the parent context here because the genCtx has been
484 // cancelled.
485 updateErr := a.messages.Update(ctx, *currentAssistant)
486 if updateErr != nil {
487 return nil, updateErr
488 }
489 return nil, err
490 }
491 wg.Wait()
492
493 if shouldSummarize {
494 a.activeRequests.Del(call.SessionID)
495 if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
496 return nil, summarizeErr
497 }
498 // If the agent wasn't done...
499 if len(currentAssistant.ToolCalls()) > 0 {
500 existing, ok := a.messageQueue.Get(call.SessionID)
501 if !ok {
502 existing = []SessionAgentCall{}
503 }
504 call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
505 existing = append(existing, call)
506 a.messageQueue.Set(call.SessionID, existing)
507 }
508 }
509
510 // Release active request before processing queued messages.
511 a.activeRequests.Del(call.SessionID)
512 cancel()
513
514 queuedMessages, ok := a.messageQueue.Get(call.SessionID)
515 if !ok || len(queuedMessages) == 0 {
516 return result, err
517 }
518 // There are queued messages restart the loop.
519 firstQueuedMessage := queuedMessages[0]
520 a.messageQueue.Set(call.SessionID, queuedMessages[1:])
521 return a.Run(ctx, firstQueuedMessage)
522}
523
524func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
525 if a.IsSessionBusy(sessionID) {
526 return ErrSessionBusy
527 }
528
529 currentSession, err := a.sessions.Get(ctx, sessionID)
530 if err != nil {
531 return fmt.Errorf("failed to get session: %w", err)
532 }
533 msgs, err := a.getSessionMessages(ctx, currentSession)
534 if err != nil {
535 return err
536 }
537 if len(msgs) == 0 {
538 // Nothing to summarize.
539 return nil
540 }
541
542 aiMsgs, _ := a.preparePrompt(msgs)
543
544 genCtx, cancel := context.WithCancel(ctx)
545 a.activeRequests.Set(sessionID, cancel)
546 defer a.activeRequests.Del(sessionID)
547 defer cancel()
548
549 agent := fantasy.NewAgent(a.largeModel.Model,
550 fantasy.WithSystemPrompt(string(summaryPrompt)),
551 )
552 summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
553 Role: message.Assistant,
554 Model: a.largeModel.Model.Model(),
555 Provider: a.largeModel.Model.Provider(),
556 IsSummaryMessage: true,
557 })
558 if err != nil {
559 return err
560 }
561
562 summaryPromptText := "Provide a detailed summary of our conversation above."
563 if len(currentSession.Todos) > 0 {
564 summaryPromptText += "\n\n## Current Todo List\n\n"
565 for _, t := range currentSession.Todos {
566 summaryPromptText += fmt.Sprintf("- [%s] %s\n", t.Status, t.Content)
567 }
568 summaryPromptText += "\nInclude these tasks and their statuses in your summary. "
569 summaryPromptText += "Instruct the resuming assistant to use the `todos` tool to continue tracking progress on these tasks."
570 }
571
572 resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
573 Prompt: summaryPromptText,
574 Messages: aiMsgs,
575 ProviderOptions: opts,
576 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
577 prepared.Messages = options.Messages
578 if a.systemPromptPrefix != "" {
579 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
580 }
581 return callContext, prepared, nil
582 },
583 OnReasoningDelta: func(id string, text string) error {
584 summaryMessage.AppendReasoningContent(text)
585 return a.messages.Update(genCtx, summaryMessage)
586 },
587 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
588 // Handle anthropic signature.
589 if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
590 if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
591 summaryMessage.AppendReasoningSignature(signature.Signature)
592 }
593 }
594 summaryMessage.FinishThinking()
595 return a.messages.Update(genCtx, summaryMessage)
596 },
597 OnTextDelta: func(id, text string) error {
598 summaryMessage.AppendContent(text)
599 return a.messages.Update(genCtx, summaryMessage)
600 },
601 })
602 if err != nil {
603 isCancelErr := errors.Is(err, context.Canceled)
604 if isCancelErr {
605 // User cancelled summarize we need to remove the summary message.
606 deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
607 return deleteErr
608 }
609 return err
610 }
611
612 summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
613 err = a.messages.Update(genCtx, summaryMessage)
614 if err != nil {
615 return err
616 }
617
618 var openrouterCost *float64
619 for _, step := range resp.Steps {
620 stepCost := a.openrouterCost(step.ProviderMetadata)
621 if stepCost != nil {
622 newCost := *stepCost
623 if openrouterCost != nil {
624 newCost += *openrouterCost
625 }
626 openrouterCost = &newCost
627 }
628 }
629
630 a.updateSessionUsage(a.largeModel, ¤tSession, resp.TotalUsage, openrouterCost)
631
632 // Just in case, get just the last usage info.
633 usage := resp.Response.Usage
634 currentSession.SummaryMessageID = summaryMessage.ID
635 currentSession.CompletionTokens = usage.OutputTokens
636 currentSession.PromptTokens = 0
637 _, err = a.sessions.Save(genCtx, currentSession)
638 return err
639}
640
641func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
642 if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
643 return fantasy.ProviderOptions{}
644 }
645 return fantasy.ProviderOptions{
646 anthropic.Name: &anthropic.ProviderCacheControlOptions{
647 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
648 },
649 bedrock.Name: &anthropic.ProviderCacheControlOptions{
650 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
651 },
652 }
653}
654
655func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
656 parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
657 var attachmentParts []message.ContentPart
658 for _, attachment := range call.Attachments {
659 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
660 }
661 parts = append(parts, attachmentParts...)
662 msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
663 Role: message.User,
664 Parts: parts,
665 })
666 if err != nil {
667 return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
668 }
669 return msg, nil
670}
671
672func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
673 var history []fantasy.Message
674 if !a.isSubAgent {
675 history = append(history, fantasy.NewUserMessage(
676 fmt.Sprintf("<system_reminder>%s</system_reminder>",
677 `This is a reminder that your todo list is currently empty. DO NOT mention this to the user explicitly because they are already aware.
678If you are working on tasks that would benefit from a todo list please use the "todos" tool to create one.
679If not, please feel free to ignore. Again do not mention this message to the user.`,
680 ),
681 ))
682 }
683 for _, m := range msgs {
684 if len(m.Parts) == 0 {
685 continue
686 }
687 // Assistant message without content or tool calls (cancelled before it
688 // returned anything).
689 if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
690 continue
691 }
692 history = append(history, m.ToAIMessage()...)
693 }
694
695 var files []fantasy.FilePart
696 for _, attachment := range attachments {
697 if attachment.IsText() {
698 continue
699 }
700 files = append(files, fantasy.FilePart{
701 Filename: attachment.FileName,
702 Data: attachment.Content,
703 MediaType: attachment.MimeType,
704 })
705 }
706
707 return history, files
708}
709
710func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
711 msgs, err := a.messages.List(ctx, session.ID)
712 if err != nil {
713 return nil, fmt.Errorf("failed to list messages: %w", err)
714 }
715
716 if session.SummaryMessageID != "" {
717 summaryMsgInex := -1
718 for i, msg := range msgs {
719 if msg.ID == session.SummaryMessageID {
720 summaryMsgInex = i
721 break
722 }
723 }
724 if summaryMsgInex != -1 {
725 msgs = msgs[summaryMsgInex:]
726 msgs[0].Role = message.User
727 }
728 }
729 return msgs, nil
730}
731
732func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, prompt string) {
733 if prompt == "" {
734 return
735 }
736
737 var maxOutput int64 = 40
738 if a.smallModel.CatwalkCfg.CanReason {
739 maxOutput = a.smallModel.CatwalkCfg.DefaultMaxTokens
740 }
741
742 agent := fantasy.NewAgent(a.smallModel.Model,
743 fantasy.WithSystemPrompt(string(titlePrompt)+"\n /no_think"),
744 fantasy.WithMaxOutputTokens(maxOutput),
745 )
746
747 resp, err := agent.Stream(ctx, fantasy.AgentStreamCall{
748 Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", prompt),
749 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
750 prepared.Messages = options.Messages
751 if a.systemPromptPrefix != "" {
752 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
753 }
754 return callContext, prepared, nil
755 },
756 })
757 if err != nil {
758 slog.Error("error generating title", "err", err)
759 return
760 }
761
762 title := resp.Response.Content.Text()
763
764 title = strings.ReplaceAll(title, "\n", " ")
765
766 // Remove thinking tags if present.
767 title = thinkTagRegex.ReplaceAllString(title, "")
768
769 title = strings.TrimSpace(title)
770 if title == "" {
771 slog.Warn("failed to generate title", "warn", "empty title")
772 return
773 }
774
775 // Calculate usage and cost.
776 var openrouterCost *float64
777 for _, step := range resp.Steps {
778 stepCost := a.openrouterCost(step.ProviderMetadata)
779 if stepCost != nil {
780 newCost := *stepCost
781 if openrouterCost != nil {
782 newCost += *openrouterCost
783 }
784 openrouterCost = &newCost
785 }
786 }
787
788 modelConfig := a.smallModel.CatwalkCfg
789 cost := modelConfig.CostPer1MInCached/1e6*float64(resp.TotalUsage.CacheCreationTokens) +
790 modelConfig.CostPer1MOutCached/1e6*float64(resp.TotalUsage.CacheReadTokens) +
791 modelConfig.CostPer1MIn/1e6*float64(resp.TotalUsage.InputTokens) +
792 modelConfig.CostPer1MOut/1e6*float64(resp.TotalUsage.OutputTokens)
793
794 if a.isClaudeCode() {
795 cost = 0
796 }
797
798 // Use override cost if available (e.g., from OpenRouter).
799 if openrouterCost != nil {
800 cost = *openrouterCost
801 }
802
803 promptTokens := resp.TotalUsage.InputTokens + resp.TotalUsage.CacheCreationTokens
804 completionTokens := resp.TotalUsage.OutputTokens + resp.TotalUsage.CacheReadTokens
805
806 // Atomically update only title and usage fields to avoid overriding other
807 // concurrent session updates.
808 saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, title, promptTokens, completionTokens, cost)
809 if saveErr != nil {
810 slog.Error("failed to save session title & usage", "error", saveErr)
811 return
812 }
813}
814
815func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
816 openrouterMetadata, ok := metadata[openrouter.Name]
817 if !ok {
818 return nil
819 }
820
821 opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
822 if !ok {
823 return nil
824 }
825 return &opts.Usage.Cost
826}
827
828func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
829 modelConfig := model.CatwalkCfg
830 cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
831 modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
832 modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
833 modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
834
835 if a.isClaudeCode() {
836 cost = 0
837 }
838
839 a.eventTokensUsed(session.ID, model, usage, cost)
840
841 if overrideCost != nil {
842 session.Cost += *overrideCost
843 } else {
844 session.Cost += cost
845 }
846
847 session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
848 session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
849}
850
851func (a *sessionAgent) Cancel(sessionID string) {
852 // Cancel regular requests.
853 if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
854 slog.Info("Request cancellation initiated", "session_id", sessionID)
855 cancel()
856 }
857
858 // Also check for summarize requests.
859 if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
860 slog.Info("Summarize cancellation initiated", "session_id", sessionID)
861 cancel()
862 }
863
864 if a.QueuedPrompts(sessionID) > 0 {
865 slog.Info("Clearing queued prompts", "session_id", sessionID)
866 a.messageQueue.Del(sessionID)
867 }
868}
869
870func (a *sessionAgent) ClearQueue(sessionID string) {
871 if a.QueuedPrompts(sessionID) > 0 {
872 slog.Info("Clearing queued prompts", "session_id", sessionID)
873 a.messageQueue.Del(sessionID)
874 }
875}
876
877func (a *sessionAgent) CancelAll() {
878 if !a.IsBusy() {
879 return
880 }
881 for key := range a.activeRequests.Seq2() {
882 a.Cancel(key) // key is sessionID
883 }
884
885 timeout := time.After(5 * time.Second)
886 for a.IsBusy() {
887 select {
888 case <-timeout:
889 return
890 default:
891 time.Sleep(200 * time.Millisecond)
892 }
893 }
894}
895
896func (a *sessionAgent) IsBusy() bool {
897 var busy bool
898 for cancelFunc := range a.activeRequests.Seq() {
899 if cancelFunc != nil {
900 busy = true
901 break
902 }
903 }
904 return busy
905}
906
907func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
908 _, busy := a.activeRequests.Get(sessionID)
909 return busy
910}
911
912func (a *sessionAgent) QueuedPrompts(sessionID string) int {
913 l, ok := a.messageQueue.Get(sessionID)
914 if !ok {
915 return 0
916 }
917 return len(l)
918}
919
920func (a *sessionAgent) QueuedPromptsList(sessionID string) []string {
921 l, ok := a.messageQueue.Get(sessionID)
922 if !ok {
923 return nil
924 }
925 prompts := make([]string, len(l))
926 for i, call := range l {
927 prompts[i] = call.Prompt
928 }
929 return prompts
930}
931
932func (a *sessionAgent) SetModels(large Model, small Model) {
933 a.largeModel = large
934 a.smallModel = small
935}
936
// SetTools replaces the tool set offered to the model. The slice is stored
// as given, not copied.
func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
	a.tools = tools
}
940
// Model returns the agent's large model.
func (a *sessionAgent) Model() Model {
	return a.largeModel
}
944
945func (a *sessionAgent) promptPrefix() string {
946 if a.isClaudeCode() {
947 return "You are Claude Code, Anthropic's official CLI for Claude."
948 }
949 return a.systemPromptPrefix
950}
951
952func (a *sessionAgent) isClaudeCode() bool {
953 cfg := config.Get()
954 pc, ok := cfg.Providers.Get(a.largeModel.ModelCfg.Provider)
955 return ok && pc.ID == string(catwalk.InferenceProviderAnthropic) && pc.OAuthToken != nil
956}
957
958// convertToToolResult converts a fantasy tool result to a message tool result.
959func (a *sessionAgent) convertToToolResult(result fantasy.ToolResultContent) message.ToolResult {
960 baseResult := message.ToolResult{
961 ToolCallID: result.ToolCallID,
962 Name: result.ToolName,
963 Metadata: result.ClientMetadata,
964 }
965
966 switch result.Result.GetType() {
967 case fantasy.ToolResultContentTypeText:
968 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result); ok {
969 baseResult.Content = r.Text
970 }
971 case fantasy.ToolResultContentTypeError:
972 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result); ok {
973 baseResult.Content = r.Error.Error()
974 baseResult.IsError = true
975 }
976 case fantasy.ToolResultContentTypeMedia:
977 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](result.Result); ok {
978 content := r.Text
979 if content == "" {
980 content = fmt.Sprintf("Loaded %s content", r.MediaType)
981 }
982 baseResult.Content = content
983 baseResult.Data = r.Data
984 baseResult.MIMEType = r.MediaType
985 }
986 }
987
988 return baseResult
989}
990
991// workaroundProviderMediaLimitations converts media content in tool results to
992// user messages for providers that don't natively support images in tool results.
993//
994// Problem: OpenAI, Google, OpenRouter, and other OpenAI-compatible providers
995// don't support sending images/media in tool result messages - they only accept
996// text in tool results. However, they DO support images in user messages.
997//
998// If we send media in tool results to these providers, the API returns an error.
999//
1000// Solution: For these providers, we:
1001// 1. Replace the media in the tool result with a text placeholder
1002// 2. Inject a user message immediately after with the image as a file attachment
1003// 3. This maintains the tool execution flow while working around API limitations
1004//
1005// Anthropic and Bedrock support images natively in tool results, so we skip
1006// this workaround for them.
1007//
1008// Example transformation:
1009//
1010// BEFORE: [tool result: image data]
1011// AFTER: [tool result: "Image loaded - see attached"], [user: image attachment]
1012func (a *sessionAgent) workaroundProviderMediaLimitations(messages []fantasy.Message) []fantasy.Message {
1013 providerSupportsMedia := a.largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderAnthropic) ||
1014 a.largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderBedrock)
1015
1016 if providerSupportsMedia {
1017 return messages
1018 }
1019
1020 convertedMessages := make([]fantasy.Message, 0, len(messages))
1021
1022 for _, msg := range messages {
1023 if msg.Role != fantasy.MessageRoleTool {
1024 convertedMessages = append(convertedMessages, msg)
1025 continue
1026 }
1027
1028 textParts := make([]fantasy.MessagePart, 0, len(msg.Content))
1029 var mediaFiles []fantasy.FilePart
1030
1031 for _, part := range msg.Content {
1032 toolResult, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
1033 if !ok {
1034 textParts = append(textParts, part)
1035 continue
1036 }
1037
1038 if media, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](toolResult.Output); ok {
1039 decoded, err := base64.StdEncoding.DecodeString(media.Data)
1040 if err != nil {
1041 slog.Warn("failed to decode media data", "error", err)
1042 textParts = append(textParts, part)
1043 continue
1044 }
1045
1046 mediaFiles = append(mediaFiles, fantasy.FilePart{
1047 Data: decoded,
1048 MediaType: media.MediaType,
1049 Filename: fmt.Sprintf("tool-result-%s", toolResult.ToolCallID),
1050 })
1051
1052 textParts = append(textParts, fantasy.ToolResultPart{
1053 ToolCallID: toolResult.ToolCallID,
1054 Output: fantasy.ToolResultOutputContentText{
1055 Text: "[Image/media content loaded - see attached file]",
1056 },
1057 ProviderOptions: toolResult.ProviderOptions,
1058 })
1059 } else {
1060 textParts = append(textParts, part)
1061 }
1062 }
1063
1064 convertedMessages = append(convertedMessages, fantasy.Message{
1065 Role: fantasy.MessageRoleTool,
1066 Content: textParts,
1067 })
1068
1069 if len(mediaFiles) > 0 {
1070 convertedMessages = append(convertedMessages, fantasy.NewUserMessage(
1071 "Here is the media content from the tool result:",
1072 mediaFiles...,
1073 ))
1074 }
1075 }
1076
1077 return convertedMessages
1078}