1// Package agent is the core orchestration layer for Crush AI agents.
2//
3// It provides session-based AI agent functionality for managing
4// conversations, tool execution, and message handling. It coordinates
5// interactions between language models, messages, sessions, and tools while
6// handling features like automatic summarization, queuing, and token
7// management.
8package agent
9
10import (
11 "cmp"
12 "context"
13 _ "embed"
14 "encoding/base64"
15 "errors"
16 "fmt"
17 "log/slog"
18 "os"
19 "regexp"
20 "strconv"
21 "strings"
22 "sync"
23 "time"
24
25 "charm.land/fantasy"
26 "charm.land/fantasy/providers/anthropic"
27 "charm.land/fantasy/providers/bedrock"
28 "charm.land/fantasy/providers/google"
29 "charm.land/fantasy/providers/openai"
30 "charm.land/fantasy/providers/openrouter"
31 "charm.land/lipgloss/v2"
32 "github.com/charmbracelet/catwalk/pkg/catwalk"
33 "github.com/charmbracelet/crush/internal/agent/hyper"
34 "github.com/charmbracelet/crush/internal/agent/tools"
35 "github.com/charmbracelet/crush/internal/config"
36 "github.com/charmbracelet/crush/internal/csync"
37 "github.com/charmbracelet/crush/internal/message"
38 "github.com/charmbracelet/crush/internal/permission"
39 "github.com/charmbracelet/crush/internal/session"
40 "github.com/charmbracelet/crush/internal/stringext"
41)
42
// defaultSessionName is the title stored on a session when title generation
// fails or produces an empty string.
const defaultSessionName = "Untitled Session"

// titlePrompt is the embedded system prompt used when generating session
// titles.
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt is the embedded system prompt used when summarizing a
// session.
//
//go:embed templates/summary.md
var summaryPrompt []byte

// thinkTagRegex matches a <think>...</think> span so it can be removed from
// generated titles. The pattern does not enable dot-matches-newline;
// generateTitle replaces newlines with spaces before applying it.
var thinkTagRegex = regexp.MustCompile(`<think>.*?</think>`)
53
// SessionAgentCall describes a single prompt submitted to a SessionAgent,
// together with the per-call generation settings forwarded to the model.
type SessionAgentCall struct {
	// SessionID identifies the session the prompt belongs to. Run rejects an
	// empty value with ErrSessionMissing.
	SessionID string
	// Prompt is the user's text. Run rejects an empty prompt unless the call
	// carries at least one text attachment.
	Prompt string
	// ProviderOptions is forwarded unchanged to the model stream call and to
	// Summarize when auto-summarization kicks in.
	ProviderOptions fantasy.ProviderOptions
	// Attachments are stored on the user message; non-text attachments are
	// additionally sent to the model as files.
	Attachments []message.Attachment
	// MaxOutputTokens is passed to the stream call as the output token limit.
	MaxOutputTokens int64
	// Temperature, TopP, TopK, FrequencyPenalty and PresencePenalty are passed
	// through to the stream call as-is; nil leaves the setting unspecified.
	Temperature *float64
	TopP *float64
	TopK *int64
	FrequencyPenalty *float64
	PresencePenalty *float64
}
66
// SessionAgent is the interface through which callers drive a session-scoped
// agent: running prompts, swapping models, tools and prompts at runtime, and
// inspecting or cancelling in-flight and queued work.
//
// The per-method notes below describe the sessionAgent implementation in this
// file.
type SessionAgent interface {
	// Run submits a call for its session. If the session is already busy the
	// call is queued and Run returns (nil, nil).
	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
	// SetModels replaces the large and small models used for later calls.
	SetModels(large Model, small Model)
	// SetTools replaces the tool set used for later calls.
	SetTools(tools []fantasy.AgentTool)
	// SetSystemPrompt replaces the system prompt used for later calls.
	SetSystemPrompt(systemPrompt string)
	// Cancel cancels the active request for the session, if any, and drops
	// the session's queued prompts.
	Cancel(sessionID string)
	// CancelAll cancels every active request and waits up to five seconds for
	// the agent to become idle.
	CancelAll()
	// IsSessionBusy reports whether the session has an active request.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any session has an active request.
	IsBusy() bool
	// QueuedPrompts returns the number of calls queued for the session.
	QueuedPrompts(sessionID string) int
	// QueuedPromptsList returns the prompt text of each call queued for the
	// session, or nil when nothing is queued.
	QueuedPromptsList(sessionID string) []string
	// ClearQueue drops the session's queued prompts without cancelling the
	// active request.
	ClearQueue(sessionID string)
	// Summarize replaces the session's history with a model-generated summary
	// message. It returns ErrSessionBusy if the session is busy.
	Summarize(context.Context, string, fantasy.ProviderOptions) error
	// Model returns the current large model.
	Model() Model
}
82
// Model bundles a language model with the catalog metadata and user
// configuration that describe it.
type Model struct {
	// Model is the language model handed to fantasy.NewAgent.
	Model fantasy.LanguageModel
	// CatwalkCfg is the catalog entry for the model; this file reads its
	// context window, pricing, name and capability flags.
	CatwalkCfg catwalk.Model
	// ModelCfg is the selected-model configuration; its Model and Provider
	// values are recorded on assistant messages created by Run.
	ModelCfg config.SelectedModel
}
88
// sessionAgent is the SessionAgent implementation returned by
// NewSessionAgent.
type sessionAgent struct {
	// The fields in this group can be replaced at runtime through the Set*
	// methods, so they are held in csync wrappers and copied out at the start
	// of each operation.
	largeModel *csync.Value[Model]
	smallModel *csync.Value[Model]
	systemPromptPrefix *csync.Value[string]
	systemPrompt *csync.Value[string]
	tools *csync.Slice[fantasy.AgentTool]

	// isSubAgent suppresses the todo-list reminder that preparePrompt
	// otherwise prepends to the history.
	isSubAgent bool
	sessions session.Service
	messages message.Service
	// disableAutoSummarize turns off the context-window stop condition in Run.
	disableAutoSummarize bool
	// isYolo is copied from SessionAgentOptions.IsYolo; it is not read in the
	// portion of the file visible here.
	isYolo bool

	// messageQueue holds calls submitted while a session was busy, keyed by
	// session ID.
	messageQueue *csync.Map[string, []SessionAgentCall]
	// activeRequests maps a session ID to the cancel func of its in-flight
	// request; presence of a key is what marks a session as busy.
	activeRequests *csync.Map[string, context.CancelFunc]
}
105
// SessionAgentOptions holds the dependencies and initial settings consumed by
// NewSessionAgent.
type SessionAgentOptions struct {
	// LargeModel is the model used for Run and Summarize, and as the fallback
	// for title generation.
	LargeModel Model
	// SmallModel is tried first for title generation.
	SmallModel Model
	// SystemPromptPrefix, when non-empty, is prepended as an extra system
	// message to each model request.
	SystemPromptPrefix string
	// SystemPrompt is the initial system prompt for Run.
	SystemPrompt string
	// IsSubAgent suppresses the todo-list reminder in the prompt history.
	IsSubAgent bool
	// DisableAutoSummarize turns off automatic summarization when the context
	// window is nearly full.
	DisableAutoSummarize bool
	// IsYolo is stored on the agent; its effect is not visible in this part
	// of the file.
	IsYolo bool
	// Sessions and Messages are the persistence services for sessions and
	// their messages.
	Sessions session.Service
	Messages message.Service
	// Tools is the initial tool set exposed to the model.
	Tools []fantasy.AgentTool
}
118
119func NewSessionAgent(
120 opts SessionAgentOptions,
121) SessionAgent {
122 return &sessionAgent{
123 largeModel: csync.NewValue(opts.LargeModel),
124 smallModel: csync.NewValue(opts.SmallModel),
125 systemPromptPrefix: csync.NewValue(opts.SystemPromptPrefix),
126 systemPrompt: csync.NewValue(opts.SystemPrompt),
127 isSubAgent: opts.IsSubAgent,
128 sessions: opts.Sessions,
129 messages: opts.Messages,
130 disableAutoSummarize: opts.DisableAutoSummarize,
131 tools: csync.NewSliceFrom(opts.Tools),
132 isYolo: opts.IsYolo,
133 messageQueue: csync.NewMap[string, []SessionAgentCall](),
134 activeRequests: csync.NewMap[string, context.CancelFunc](),
135 }
136}
137
138func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
139 if call.Prompt == "" && !message.ContainsTextAttachment(call.Attachments) {
140 return nil, ErrEmptyPrompt
141 }
142 if call.SessionID == "" {
143 return nil, ErrSessionMissing
144 }
145
146 // Queue the message if busy
147 if a.IsSessionBusy(call.SessionID) {
148 existing, ok := a.messageQueue.Get(call.SessionID)
149 if !ok {
150 existing = []SessionAgentCall{}
151 }
152 existing = append(existing, call)
153 a.messageQueue.Set(call.SessionID, existing)
154 return nil, nil
155 }
156
157 // Copy mutable fields under lock to avoid races with SetTools/SetModels.
158 agentTools := a.tools.Copy()
159 largeModel := a.largeModel.Get()
160 systemPrompt := a.systemPrompt.Get()
161 promptPrefix := a.systemPromptPrefix.Get()
162
163 if len(agentTools) > 0 {
164 // Add Anthropic caching to the last tool.
165 agentTools[len(agentTools)-1].SetProviderOptions(a.getCacheControlOptions())
166 }
167
168 agent := fantasy.NewAgent(
169 largeModel.Model,
170 fantasy.WithSystemPrompt(systemPrompt),
171 fantasy.WithTools(agentTools...),
172 )
173
174 sessionLock := sync.Mutex{}
175 currentSession, err := a.sessions.Get(ctx, call.SessionID)
176 if err != nil {
177 return nil, fmt.Errorf("failed to get session: %w", err)
178 }
179
180 msgs, err := a.getSessionMessages(ctx, currentSession)
181 if err != nil {
182 return nil, fmt.Errorf("failed to get session messages: %w", err)
183 }
184
185 var wg sync.WaitGroup
186 // Generate title if first message.
187 if len(msgs) == 0 {
188 titleCtx := ctx // Copy to avoid race with ctx reassignment below.
189 wg.Go(func() {
190 a.generateTitle(titleCtx, call.SessionID, call.Prompt)
191 })
192 }
193 defer wg.Wait()
194
195 // Add the user message to the session.
196 _, err = a.createUserMessage(ctx, call)
197 if err != nil {
198 return nil, err
199 }
200
201 // Add the session to the context.
202 ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)
203
204 genCtx, cancel := context.WithCancel(ctx)
205 a.activeRequests.Set(call.SessionID, cancel)
206
207 defer cancel()
208 defer a.activeRequests.Del(call.SessionID)
209
210 history, files := a.preparePrompt(msgs, call.Attachments...)
211
212 startTime := time.Now()
213 a.eventPromptSent(call.SessionID)
214
215 var currentAssistant *message.Message
216 var shouldSummarize bool
217 result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
218 Prompt: message.PromptWithTextAttachments(call.Prompt, call.Attachments),
219 Files: files,
220 Messages: history,
221 ProviderOptions: call.ProviderOptions,
222 MaxOutputTokens: &call.MaxOutputTokens,
223 TopP: call.TopP,
224 Temperature: call.Temperature,
225 PresencePenalty: call.PresencePenalty,
226 TopK: call.TopK,
227 FrequencyPenalty: call.FrequencyPenalty,
228 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
229 prepared.Messages = options.Messages
230 for i := range prepared.Messages {
231 prepared.Messages[i].ProviderOptions = nil
232 }
233
234 queuedCalls, _ := a.messageQueue.Get(call.SessionID)
235 a.messageQueue.Del(call.SessionID)
236 for _, queued := range queuedCalls {
237 userMessage, createErr := a.createUserMessage(callContext, queued)
238 if createErr != nil {
239 return callContext, prepared, createErr
240 }
241 prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
242 }
243
244 prepared.Messages = a.workaroundProviderMediaLimitations(prepared.Messages, largeModel)
245
246 lastSystemRoleInx := 0
247 systemMessageUpdated := false
248 for i, msg := range prepared.Messages {
249 // Only add cache control to the last message.
250 if msg.Role == fantasy.MessageRoleSystem {
251 lastSystemRoleInx = i
252 } else if !systemMessageUpdated {
253 prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
254 systemMessageUpdated = true
255 }
256 // Than add cache control to the last 2 messages.
257 if i > len(prepared.Messages)-3 {
258 prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
259 }
260 }
261
262 if promptPrefix != "" {
263 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(promptPrefix)}, prepared.Messages...)
264 }
265
266 var assistantMsg message.Message
267 assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
268 Role: message.Assistant,
269 Parts: []message.ContentPart{},
270 Model: largeModel.ModelCfg.Model,
271 Provider: largeModel.ModelCfg.Provider,
272 })
273 if err != nil {
274 return callContext, prepared, err
275 }
276 callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
277 callContext = context.WithValue(callContext, tools.SupportsImagesContextKey, largeModel.CatwalkCfg.SupportsImages)
278 callContext = context.WithValue(callContext, tools.ModelNameContextKey, largeModel.CatwalkCfg.Name)
279 currentAssistant = &assistantMsg
280 return callContext, prepared, err
281 },
282 OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
283 currentAssistant.AppendReasoningContent(reasoning.Text)
284 return a.messages.Update(genCtx, *currentAssistant)
285 },
286 OnReasoningDelta: func(id string, text string) error {
287 currentAssistant.AppendReasoningContent(text)
288 return a.messages.Update(genCtx, *currentAssistant)
289 },
290 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
291 // handle anthropic signature
292 if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
293 if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
294 currentAssistant.AppendReasoningSignature(reasoning.Signature)
295 }
296 }
297 if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
298 if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
299 currentAssistant.AppendThoughtSignature(reasoning.Signature, reasoning.ToolID)
300 }
301 }
302 if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
303 if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
304 currentAssistant.SetReasoningResponsesData(reasoning)
305 }
306 }
307 currentAssistant.FinishThinking()
308 return a.messages.Update(genCtx, *currentAssistant)
309 },
310 OnTextDelta: func(id string, text string) error {
311 // Strip leading newline from initial text content. This is is
312 // particularly important in non-interactive mode where leading
313 // newlines are very visible.
314 if len(currentAssistant.Parts) == 0 {
315 text = strings.TrimPrefix(text, "\n")
316 }
317
318 currentAssistant.AppendContent(text)
319 return a.messages.Update(genCtx, *currentAssistant)
320 },
321 OnToolInputStart: func(id string, toolName string) error {
322 toolCall := message.ToolCall{
323 ID: id,
324 Name: toolName,
325 ProviderExecuted: false,
326 Finished: false,
327 }
328 currentAssistant.AddToolCall(toolCall)
329 return a.messages.Update(genCtx, *currentAssistant)
330 },
331 OnRetry: func(err *fantasy.ProviderError, delay time.Duration) {
332 // TODO: implement
333 },
334 OnToolCall: func(tc fantasy.ToolCallContent) error {
335 toolCall := message.ToolCall{
336 ID: tc.ToolCallID,
337 Name: tc.ToolName,
338 Input: tc.Input,
339 ProviderExecuted: false,
340 Finished: true,
341 }
342 currentAssistant.AddToolCall(toolCall)
343 return a.messages.Update(genCtx, *currentAssistant)
344 },
345 OnToolResult: func(result fantasy.ToolResultContent) error {
346 toolResult := a.convertToToolResult(result)
347 _, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
348 Role: message.Tool,
349 Parts: []message.ContentPart{
350 toolResult,
351 },
352 })
353 return createMsgErr
354 },
355 OnStepFinish: func(stepResult fantasy.StepResult) error {
356 finishReason := message.FinishReasonUnknown
357 switch stepResult.FinishReason {
358 case fantasy.FinishReasonLength:
359 finishReason = message.FinishReasonMaxTokens
360 case fantasy.FinishReasonStop:
361 finishReason = message.FinishReasonEndTurn
362 case fantasy.FinishReasonToolCalls:
363 finishReason = message.FinishReasonToolUse
364 }
365 currentAssistant.AddFinish(finishReason, "", "")
366 sessionLock.Lock()
367 updatedSession, getSessionErr := a.sessions.Get(genCtx, call.SessionID)
368 if getSessionErr != nil {
369 sessionLock.Unlock()
370 return getSessionErr
371 }
372 a.updateSessionUsage(largeModel, &updatedSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
373 _, sessionErr := a.sessions.Save(genCtx, updatedSession)
374 sessionLock.Unlock()
375 if sessionErr != nil {
376 return sessionErr
377 }
378 return a.messages.Update(genCtx, *currentAssistant)
379 },
380 StopWhen: []fantasy.StopCondition{
381 func(_ []fantasy.StepResult) bool {
382 cw := int64(largeModel.CatwalkCfg.ContextWindow)
383 tokens := currentSession.CompletionTokens + currentSession.PromptTokens
384 remaining := cw - tokens
385 var threshold int64
386 if cw > 200_000 {
387 threshold = 20_000
388 } else {
389 threshold = int64(float64(cw) * 0.2)
390 }
391 if (remaining <= threshold) && !a.disableAutoSummarize {
392 shouldSummarize = true
393 return true
394 }
395 return false
396 },
397 },
398 })
399
400 a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))
401
402 if err != nil {
403 isCancelErr := errors.Is(err, context.Canceled)
404 isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
405 if currentAssistant == nil {
406 return result, err
407 }
408 // Ensure we finish thinking on error to close the reasoning state.
409 currentAssistant.FinishThinking()
410 toolCalls := currentAssistant.ToolCalls()
411 // INFO: we use the parent context here because the genCtx has been cancelled.
412 msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
413 if createErr != nil {
414 return nil, createErr
415 }
416 for _, tc := range toolCalls {
417 if !tc.Finished {
418 tc.Finished = true
419 tc.Input = "{}"
420 currentAssistant.AddToolCall(tc)
421 updateErr := a.messages.Update(ctx, *currentAssistant)
422 if updateErr != nil {
423 return nil, updateErr
424 }
425 }
426
427 found := false
428 for _, msg := range msgs {
429 if msg.Role == message.Tool {
430 for _, tr := range msg.ToolResults() {
431 if tr.ToolCallID == tc.ID {
432 found = true
433 break
434 }
435 }
436 }
437 if found {
438 break
439 }
440 }
441 if found {
442 continue
443 }
444 content := "There was an error while executing the tool"
445 if isCancelErr {
446 content = "Tool execution canceled by user"
447 } else if isPermissionErr {
448 content = "User denied permission"
449 }
450 toolResult := message.ToolResult{
451 ToolCallID: tc.ID,
452 Name: tc.Name,
453 Content: content,
454 IsError: true,
455 }
456 _, createErr = a.messages.Create(ctx, currentAssistant.SessionID, message.CreateMessageParams{
457 Role: message.Tool,
458 Parts: []message.ContentPart{
459 toolResult,
460 },
461 })
462 if createErr != nil {
463 return nil, createErr
464 }
465 }
466 var fantasyErr *fantasy.Error
467 var providerErr *fantasy.ProviderError
468 const defaultTitle = "Provider Error"
469 if isCancelErr {
470 currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
471 } else if isPermissionErr {
472 currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "")
473 } else if errors.Is(err, hyper.ErrNoCredits) {
474 url := hyper.BaseURL()
475 link := lipgloss.NewStyle().Hyperlink(url, "id=hyper").Render(url)
476 currentAssistant.AddFinish(message.FinishReasonError, "No credits", "You're out of credits. Add more at "+link)
477 } else if errors.As(err, &providerErr) {
478 if providerErr.Message == "The requested model is not supported." {
479 url := "https://github.com/settings/copilot/features"
480 link := lipgloss.NewStyle().Hyperlink(url, "id=hyper").Render(url)
481 currentAssistant.AddFinish(
482 message.FinishReasonError,
483 "Copilot model not enabled",
484 fmt.Sprintf("%q is not enabled in Copilot. Go to the following page to enable it. Then, wait 5 minutes before trying again. %s", largeModel.CatwalkCfg.Name, link),
485 )
486 } else {
487 currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message)
488 }
489 } else if errors.As(err, &fantasyErr) {
490 currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(fantasyErr.Title), defaultTitle), fantasyErr.Message)
491 } else {
492 currentAssistant.AddFinish(message.FinishReasonError, defaultTitle, err.Error())
493 }
494 // Note: we use the parent context here because the genCtx has been
495 // cancelled.
496 updateErr := a.messages.Update(ctx, *currentAssistant)
497 if updateErr != nil {
498 return nil, updateErr
499 }
500 return nil, err
501 }
502
503 if shouldSummarize {
504 a.activeRequests.Del(call.SessionID)
505 if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
506 return nil, summarizeErr
507 }
508 // If the agent wasn't done...
509 if len(currentAssistant.ToolCalls()) > 0 {
510 existing, ok := a.messageQueue.Get(call.SessionID)
511 if !ok {
512 existing = []SessionAgentCall{}
513 }
514 call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
515 existing = append(existing, call)
516 a.messageQueue.Set(call.SessionID, existing)
517 }
518 }
519
520 // Release active request before processing queued messages.
521 a.activeRequests.Del(call.SessionID)
522 cancel()
523
524 queuedMessages, ok := a.messageQueue.Get(call.SessionID)
525 if !ok || len(queuedMessages) == 0 {
526 return result, err
527 }
528 // There are queued messages restart the loop.
529 firstQueuedMessage := queuedMessages[0]
530 a.messageQueue.Set(call.SessionID, queuedMessages[1:])
531 return a.Run(ctx, firstQueuedMessage)
532}
533
534func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
535 if a.IsSessionBusy(sessionID) {
536 return ErrSessionBusy
537 }
538
539 // Copy mutable fields under lock to avoid races with SetModels.
540 largeModel := a.largeModel.Get()
541 systemPromptPrefix := a.systemPromptPrefix.Get()
542
543 currentSession, err := a.sessions.Get(ctx, sessionID)
544 if err != nil {
545 return fmt.Errorf("failed to get session: %w", err)
546 }
547 msgs, err := a.getSessionMessages(ctx, currentSession)
548 if err != nil {
549 return err
550 }
551 if len(msgs) == 0 {
552 // Nothing to summarize.
553 return nil
554 }
555
556 aiMsgs, _ := a.preparePrompt(msgs)
557
558 genCtx, cancel := context.WithCancel(ctx)
559 a.activeRequests.Set(sessionID, cancel)
560 defer a.activeRequests.Del(sessionID)
561 defer cancel()
562
563 agent := fantasy.NewAgent(largeModel.Model,
564 fantasy.WithSystemPrompt(string(summaryPrompt)),
565 )
566 summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
567 Role: message.Assistant,
568 Model: largeModel.Model.Model(),
569 Provider: largeModel.Model.Provider(),
570 IsSummaryMessage: true,
571 })
572 if err != nil {
573 return err
574 }
575
576 summaryPromptText := buildSummaryPrompt(currentSession.Todos)
577
578 resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
579 Prompt: summaryPromptText,
580 Messages: aiMsgs,
581 ProviderOptions: opts,
582 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
583 prepared.Messages = options.Messages
584 if systemPromptPrefix != "" {
585 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(systemPromptPrefix)}, prepared.Messages...)
586 }
587 return callContext, prepared, nil
588 },
589 OnReasoningDelta: func(id string, text string) error {
590 summaryMessage.AppendReasoningContent(text)
591 return a.messages.Update(genCtx, summaryMessage)
592 },
593 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
594 // Handle anthropic signature.
595 if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
596 if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
597 summaryMessage.AppendReasoningSignature(signature.Signature)
598 }
599 }
600 summaryMessage.FinishThinking()
601 return a.messages.Update(genCtx, summaryMessage)
602 },
603 OnTextDelta: func(id, text string) error {
604 summaryMessage.AppendContent(text)
605 return a.messages.Update(genCtx, summaryMessage)
606 },
607 })
608 if err != nil {
609 isCancelErr := errors.Is(err, context.Canceled)
610 if isCancelErr {
611 // User cancelled summarize we need to remove the summary message.
612 deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
613 return deleteErr
614 }
615 return err
616 }
617
618 summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
619 err = a.messages.Update(genCtx, summaryMessage)
620 if err != nil {
621 return err
622 }
623
624 var openrouterCost *float64
625 for _, step := range resp.Steps {
626 stepCost := a.openrouterCost(step.ProviderMetadata)
627 if stepCost != nil {
628 newCost := *stepCost
629 if openrouterCost != nil {
630 newCost += *openrouterCost
631 }
632 openrouterCost = &newCost
633 }
634 }
635
636 a.updateSessionUsage(largeModel, ¤tSession, resp.TotalUsage, openrouterCost)
637
638 // Just in case, get just the last usage info.
639 usage := resp.Response.Usage
640 currentSession.SummaryMessageID = summaryMessage.ID
641 currentSession.CompletionTokens = usage.OutputTokens
642 currentSession.PromptTokens = 0
643 _, err = a.sessions.Save(genCtx, currentSession)
644 return err
645}
646
647func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
648 if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
649 return fantasy.ProviderOptions{}
650 }
651 return fantasy.ProviderOptions{
652 anthropic.Name: &anthropic.ProviderCacheControlOptions{
653 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
654 },
655 bedrock.Name: &anthropic.ProviderCacheControlOptions{
656 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
657 },
658 }
659}
660
661func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
662 parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
663 var attachmentParts []message.ContentPart
664 for _, attachment := range call.Attachments {
665 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
666 }
667 parts = append(parts, attachmentParts...)
668 msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
669 Role: message.User,
670 Parts: parts,
671 })
672 if err != nil {
673 return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
674 }
675 return msg, nil
676}
677
678func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
679 var history []fantasy.Message
680 if !a.isSubAgent {
681 history = append(history, fantasy.NewUserMessage(
682 fmt.Sprintf("<system_reminder>%s</system_reminder>",
683 `This is a reminder that your todo list is currently empty. DO NOT mention this to the user explicitly because they are already aware.
684If you are working on tasks that would benefit from a todo list please use the "todos" tool to create one.
685If not, please feel free to ignore. Again do not mention this message to the user.`,
686 ),
687 ))
688 }
689 for _, m := range msgs {
690 if len(m.Parts) == 0 {
691 continue
692 }
693 // Assistant message without content or tool calls (cancelled before it
694 // returned anything).
695 if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
696 continue
697 }
698 history = append(history, m.ToAIMessage()...)
699 }
700
701 var files []fantasy.FilePart
702 for _, attachment := range attachments {
703 if attachment.IsText() {
704 continue
705 }
706 files = append(files, fantasy.FilePart{
707 Filename: attachment.FileName,
708 Data: attachment.Content,
709 MediaType: attachment.MimeType,
710 })
711 }
712
713 return history, files
714}
715
716func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
717 msgs, err := a.messages.List(ctx, session.ID)
718 if err != nil {
719 return nil, fmt.Errorf("failed to list messages: %w", err)
720 }
721
722 if session.SummaryMessageID != "" {
723 summaryMsgInex := -1
724 for i, msg := range msgs {
725 if msg.ID == session.SummaryMessageID {
726 summaryMsgInex = i
727 break
728 }
729 }
730 if summaryMsgInex != -1 {
731 msgs = msgs[summaryMsgInex:]
732 msgs[0].Role = message.User
733 }
734 }
735 return msgs, nil
736}
737
738// generateTitle generates a session titled based on the initial prompt.
739func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, userPrompt string) {
740 if userPrompt == "" {
741 return
742 }
743
744 smallModel := a.smallModel.Get()
745 largeModel := a.largeModel.Get()
746 systemPromptPrefix := a.systemPromptPrefix.Get()
747
748 var maxOutputTokens int64 = 40
749 if smallModel.CatwalkCfg.CanReason {
750 maxOutputTokens = smallModel.CatwalkCfg.DefaultMaxTokens
751 }
752
753 newAgent := func(m fantasy.LanguageModel, p []byte, tok int64) fantasy.Agent {
754 return fantasy.NewAgent(m,
755 fantasy.WithSystemPrompt(string(p)+"\n /no_think"),
756 fantasy.WithMaxOutputTokens(tok),
757 )
758 }
759
760 streamCall := fantasy.AgentStreamCall{
761 Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", userPrompt),
762 PrepareStep: func(callCtx context.Context, opts fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
763 prepared.Messages = opts.Messages
764 if systemPromptPrefix != "" {
765 prepared.Messages = append([]fantasy.Message{
766 fantasy.NewSystemMessage(systemPromptPrefix),
767 }, prepared.Messages...)
768 }
769 return callCtx, prepared, nil
770 },
771 }
772
773 // Use the small model to generate the title.
774 model := smallModel
775 agent := newAgent(model.Model, titlePrompt, maxOutputTokens)
776 resp, err := agent.Stream(ctx, streamCall)
777 if err == nil {
778 // We successfully generated a title with the small model.
779 slog.Info("generated title with small model")
780 } else {
781 // It didn't work. Let's try with the big model.
782 slog.Error("error generating title with small model; trying big model", "err", err)
783 model = largeModel
784 agent = newAgent(model.Model, titlePrompt, maxOutputTokens)
785 resp, err = agent.Stream(ctx, streamCall)
786 if err == nil {
787 slog.Info("generated title with large model")
788 } else {
789 // Welp, the large model didn't work either. Use the default
790 // session name and return.
791 slog.Error("error generating title with large model", "err", err)
792 saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, defaultSessionName, 0, 0, 0)
793 if saveErr != nil {
794 slog.Error("failed to save session title and usage", "error", saveErr)
795 }
796 return
797 }
798 }
799
800 if resp == nil {
801 // Actually, we didn't get a response so we can't. Use the default
802 // session name and return.
803 slog.Error("response is nil; can't generate title")
804 saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, defaultSessionName, 0, 0, 0)
805 if saveErr != nil {
806 slog.Error("failed to save session title and usage", "error", saveErr)
807 }
808 return
809 }
810
811 // Clean up title.
812 var title string
813 title = strings.ReplaceAll(resp.Response.Content.Text(), "\n", " ")
814 slog.Info("generated title", "title", title)
815
816 // Remove thinking tags if present.
817 title = thinkTagRegex.ReplaceAllString(title, "")
818
819 title = strings.TrimSpace(title)
820 if title == "" {
821 slog.Warn("empty title; using fallback")
822 title = defaultSessionName
823 }
824
825 // Calculate usage and cost.
826 var openrouterCost *float64
827 for _, step := range resp.Steps {
828 stepCost := a.openrouterCost(step.ProviderMetadata)
829 if stepCost != nil {
830 newCost := *stepCost
831 if openrouterCost != nil {
832 newCost += *openrouterCost
833 }
834 openrouterCost = &newCost
835 }
836 }
837
838 modelConfig := model.CatwalkCfg
839 cost := modelConfig.CostPer1MInCached/1e6*float64(resp.TotalUsage.CacheCreationTokens) +
840 modelConfig.CostPer1MOutCached/1e6*float64(resp.TotalUsage.CacheReadTokens) +
841 modelConfig.CostPer1MIn/1e6*float64(resp.TotalUsage.InputTokens) +
842 modelConfig.CostPer1MOut/1e6*float64(resp.TotalUsage.OutputTokens)
843
844 // Use override cost if available (e.g., from OpenRouter).
845 if openrouterCost != nil {
846 cost = *openrouterCost
847 }
848
849 promptTokens := resp.TotalUsage.InputTokens + resp.TotalUsage.CacheCreationTokens
850 completionTokens := resp.TotalUsage.OutputTokens + resp.TotalUsage.CacheReadTokens
851
852 // Atomically update only title and usage fields to avoid overriding other
853 // concurrent session updates.
854 saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, title, promptTokens, completionTokens, cost)
855 if saveErr != nil {
856 slog.Error("failed to save session title and usage", "error", saveErr)
857 return
858 }
859}
860
861func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
862 openrouterMetadata, ok := metadata[openrouter.Name]
863 if !ok {
864 return nil
865 }
866
867 opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
868 if !ok {
869 return nil
870 }
871 return &opts.Usage.Cost
872}
873
874func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
875 modelConfig := model.CatwalkCfg
876 cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
877 modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
878 modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
879 modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
880
881 a.eventTokensUsed(session.ID, model, usage, cost)
882
883 if overrideCost != nil {
884 session.Cost += *overrideCost
885 } else {
886 session.Cost += cost
887 }
888
889 session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
890 session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
891}
892
893func (a *sessionAgent) Cancel(sessionID string) {
894 // Cancel regular requests. Don't use Take() here - we need the entry to
895 // remain in activeRequests so IsBusy() returns true until the goroutine
896 // fully completes (including error handling that may access the DB).
897 // The defer in processRequest will clean up the entry.
898 if cancel, ok := a.activeRequests.Get(sessionID); ok && cancel != nil {
899 slog.Info("Request cancellation initiated", "session_id", sessionID)
900 cancel()
901 }
902
903 // Also check for summarize requests.
904 if cancel, ok := a.activeRequests.Get(sessionID + "-summarize"); ok && cancel != nil {
905 slog.Info("Summarize cancellation initiated", "session_id", sessionID)
906 cancel()
907 }
908
909 if a.QueuedPrompts(sessionID) > 0 {
910 slog.Info("Clearing queued prompts", "session_id", sessionID)
911 a.messageQueue.Del(sessionID)
912 }
913}
914
915func (a *sessionAgent) ClearQueue(sessionID string) {
916 if a.QueuedPrompts(sessionID) > 0 {
917 slog.Info("Clearing queued prompts", "session_id", sessionID)
918 a.messageQueue.Del(sessionID)
919 }
920}
921
922func (a *sessionAgent) CancelAll() {
923 if !a.IsBusy() {
924 return
925 }
926 for key := range a.activeRequests.Seq2() {
927 a.Cancel(key) // key is sessionID
928 }
929
930 timeout := time.After(5 * time.Second)
931 for a.IsBusy() {
932 select {
933 case <-timeout:
934 return
935 default:
936 time.Sleep(200 * time.Millisecond)
937 }
938 }
939}
940
941func (a *sessionAgent) IsBusy() bool {
942 var busy bool
943 for cancelFunc := range a.activeRequests.Seq() {
944 if cancelFunc != nil {
945 busy = true
946 break
947 }
948 }
949 return busy
950}
951
952func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
953 _, busy := a.activeRequests.Get(sessionID)
954 return busy
955}
956
957func (a *sessionAgent) QueuedPrompts(sessionID string) int {
958 l, ok := a.messageQueue.Get(sessionID)
959 if !ok {
960 return 0
961 }
962 return len(l)
963}
964
965func (a *sessionAgent) QueuedPromptsList(sessionID string) []string {
966 l, ok := a.messageQueue.Get(sessionID)
967 if !ok {
968 return nil
969 }
970 prompts := make([]string, len(l))
971 for i, call := range l {
972 prompts[i] = call.Prompt
973 }
974 return prompts
975}
976
977func (a *sessionAgent) SetModels(large Model, small Model) {
978 a.largeModel.Set(large)
979 a.smallModel.Set(small)
980}
981
982func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
983 a.tools.SetSlice(tools)
984}
985
986func (a *sessionAgent) SetSystemPrompt(systemPrompt string) {
987 a.systemPrompt.Set(systemPrompt)
988}
989
990func (a *sessionAgent) Model() Model {
991 return a.largeModel.Get()
992}
993
994// convertToToolResult converts a fantasy tool result to a message tool result.
995func (a *sessionAgent) convertToToolResult(result fantasy.ToolResultContent) message.ToolResult {
996 baseResult := message.ToolResult{
997 ToolCallID: result.ToolCallID,
998 Name: result.ToolName,
999 Metadata: result.ClientMetadata,
1000 }
1001
1002 switch result.Result.GetType() {
1003 case fantasy.ToolResultContentTypeText:
1004 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result); ok {
1005 baseResult.Content = r.Text
1006 }
1007 case fantasy.ToolResultContentTypeError:
1008 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result); ok {
1009 baseResult.Content = r.Error.Error()
1010 baseResult.IsError = true
1011 }
1012 case fantasy.ToolResultContentTypeMedia:
1013 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](result.Result); ok {
1014 content := r.Text
1015 if content == "" {
1016 content = fmt.Sprintf("Loaded %s content", r.MediaType)
1017 }
1018 baseResult.Content = content
1019 baseResult.Data = r.Data
1020 baseResult.MIMEType = r.MediaType
1021 }
1022 }
1023
1024 return baseResult
1025}
1026
1027// workaroundProviderMediaLimitations converts media content in tool results to
1028// user messages for providers that don't natively support images in tool results.
1029//
1030// Problem: OpenAI, Google, OpenRouter, and other OpenAI-compatible providers
1031// don't support sending images/media in tool result messages - they only accept
1032// text in tool results. However, they DO support images in user messages.
1033//
1034// If we send media in tool results to these providers, the API returns an error.
1035//
1036// Solution: For these providers, we:
1037// 1. Replace the media in the tool result with a text placeholder
1038// 2. Inject a user message immediately after with the image as a file attachment
1039// 3. This maintains the tool execution flow while working around API limitations
1040//
1041// Anthropic and Bedrock support images natively in tool results, so we skip
1042// this workaround for them.
1043//
1044// Example transformation:
1045//
1046// BEFORE: [tool result: image data]
1047// AFTER: [tool result: "Image loaded - see attached"], [user: image attachment]
1048func (a *sessionAgent) workaroundProviderMediaLimitations(messages []fantasy.Message, largeModel Model) []fantasy.Message {
1049 providerSupportsMedia := largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderAnthropic) ||
1050 largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderBedrock)
1051
1052 if providerSupportsMedia {
1053 return messages
1054 }
1055
1056 convertedMessages := make([]fantasy.Message, 0, len(messages))
1057
1058 for _, msg := range messages {
1059 if msg.Role != fantasy.MessageRoleTool {
1060 convertedMessages = append(convertedMessages, msg)
1061 continue
1062 }
1063
1064 textParts := make([]fantasy.MessagePart, 0, len(msg.Content))
1065 var mediaFiles []fantasy.FilePart
1066
1067 for _, part := range msg.Content {
1068 toolResult, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
1069 if !ok {
1070 textParts = append(textParts, part)
1071 continue
1072 }
1073
1074 if media, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](toolResult.Output); ok {
1075 decoded, err := base64.StdEncoding.DecodeString(media.Data)
1076 if err != nil {
1077 slog.Warn("failed to decode media data", "error", err)
1078 textParts = append(textParts, part)
1079 continue
1080 }
1081
1082 mediaFiles = append(mediaFiles, fantasy.FilePart{
1083 Data: decoded,
1084 MediaType: media.MediaType,
1085 Filename: fmt.Sprintf("tool-result-%s", toolResult.ToolCallID),
1086 })
1087
1088 textParts = append(textParts, fantasy.ToolResultPart{
1089 ToolCallID: toolResult.ToolCallID,
1090 Output: fantasy.ToolResultOutputContentText{
1091 Text: "[Image/media content loaded - see attached file]",
1092 },
1093 ProviderOptions: toolResult.ProviderOptions,
1094 })
1095 } else {
1096 textParts = append(textParts, part)
1097 }
1098 }
1099
1100 convertedMessages = append(convertedMessages, fantasy.Message{
1101 Role: fantasy.MessageRoleTool,
1102 Content: textParts,
1103 })
1104
1105 if len(mediaFiles) > 0 {
1106 convertedMessages = append(convertedMessages, fantasy.NewUserMessage(
1107 "Here is the media content from the tool result:",
1108 mediaFiles...,
1109 ))
1110 }
1111 }
1112
1113 return convertedMessages
1114}
1115
1116// buildSummaryPrompt constructs the prompt text for session summarization.
1117func buildSummaryPrompt(todos []session.Todo) string {
1118 var sb strings.Builder
1119 sb.WriteString("Provide a detailed summary of our conversation above.")
1120 if len(todos) > 0 {
1121 sb.WriteString("\n\n## Current Todo List\n\n")
1122 for _, t := range todos {
1123 fmt.Fprintf(&sb, "- [%s] %s\n", t.Status, t.Content)
1124 }
1125 sb.WriteString("\nInclude these tasks and their statuses in your summary. ")
1126 sb.WriteString("Instruct the resuming assistant to use the `todos` tool to continue tracking progress on these tasks.")
1127 }
1128 return sb.String()
1129}