1// Package agent is the core orchestration layer for Crush AI agents.
2//
3// It provides session-based AI agent functionality for managing
4// conversations, tool execution, and message handling. It coordinates
5// interactions between language models, messages, sessions, and tools while
6// handling features like automatic summarization, queuing, and token
7// management.
8package agent
9
10import (
11 "cmp"
12 "context"
13 _ "embed"
14 "encoding/base64"
15 "errors"
16 "fmt"
17 "log/slog"
18 "os"
19 "regexp"
20 "strconv"
21 "strings"
22 "sync"
23 "time"
24
25 "charm.land/fantasy"
26 "charm.land/fantasy/providers/anthropic"
27 "charm.land/fantasy/providers/bedrock"
28 "charm.land/fantasy/providers/google"
29 "charm.land/fantasy/providers/openai"
30 "charm.land/fantasy/providers/openrouter"
31 "charm.land/lipgloss/v2"
32 "github.com/charmbracelet/catwalk/pkg/catwalk"
33 "github.com/charmbracelet/crush/internal/agent/hyper"
34 "github.com/charmbracelet/crush/internal/agent/tools"
35 "github.com/charmbracelet/crush/internal/config"
36 "github.com/charmbracelet/crush/internal/csync"
37 "github.com/charmbracelet/crush/internal/message"
38 "github.com/charmbracelet/crush/internal/permission"
39 "github.com/charmbracelet/crush/internal/session"
40 "github.com/charmbracelet/crush/internal/stringext"
41 "github.com/charmbracelet/x/exp/charmtone"
42)
43
const (
	// defaultSessionName is the title stored for a session when title
	// generation fails or produces an empty title.
	defaultSessionName = "Untitled Session"

	// Constants for auto-summarization thresholds.

	// largeContextWindowThreshold is the context-window size (in tokens) above
	// which a model is treated as "large" and the fixed
	// largeContextWindowBuffer is used as the summarization threshold.
	largeContextWindowThreshold = 200_000
	// largeContextWindowBuffer is the number of remaining tokens at or below
	// which a large-context session is summarized.
	largeContextWindowBuffer = 20_000
	// smallContextWindowRatio is the fraction of the context window that must
	// remain free for models at or below largeContextWindowThreshold; once the
	// remaining tokens drop to this fraction or less, the session is
	// summarized.
	smallContextWindowRatio = 0.2
)
52
// titlePrompt holds the embedded contents of templates/title.md. It is used
// as the system prompt when generating a session title.
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt holds the embedded contents of templates/summary.md. It is
// used as the system prompt when summarizing a session.
//
//go:embed templates/summary.md
var summaryPrompt []byte
58
// thinkTagRegex matches a <think>...</think> span (non-greedy). It is used to
// strip thinking tags from generated titles; the title has its newlines
// replaced with spaces before the regex is applied, so `.` not matching
// newlines does not matter there.
var thinkTagRegex = regexp.MustCompile(`<think>.*?</think>`)
61
// SessionAgentCall describes a single prompt submitted to a session.
type SessionAgentCall struct {
	// SessionID identifies the session the prompt belongs to. Run rejects an
	// empty value with ErrSessionMissing.
	SessionID string
	// Prompt is the user's text. It may be empty only when Attachments
	// contains a text attachment.
	Prompt string
	// ProviderOptions is forwarded unchanged to the model stream call.
	ProviderOptions fantasy.ProviderOptions
	// Attachments are stored with the user message; non-text attachments are
	// also sent to the model as files.
	Attachments []message.Attachment
	// MaxOutputTokens and the sampling parameters below are forwarded
	// unchanged to the model stream call.
	MaxOutputTokens  int64
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
}
74
// SessionAgent runs prompts against sessions and manages their lifecycle:
// queuing while a session is busy, cancellation, and summarization.
type SessionAgent interface {
	// Run executes the call against its session, or queues it when the
	// session is already busy (in which case it returns nil, nil).
	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
	// SetModels replaces the large and small models used for later calls.
	SetModels(large Model, small Model)
	// SetTools replaces the tool set used for later calls.
	SetTools(tools []fantasy.AgentTool)
	// SetSystemPrompt replaces the system prompt used for later calls.
	SetSystemPrompt(systemPrompt string)
	// Cancel cancels the active request for the session, if any, and drops
	// its queued prompts.
	Cancel(sessionID string)
	// CancelAll cancels every active request and waits for them to finish.
	CancelAll()
	// IsSessionBusy reports whether the session has an active request.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any session has an active request.
	IsBusy() bool
	// QueuedPrompts returns the number of prompts queued for the session.
	QueuedPrompts(sessionID string) int
	// QueuedPromptsList returns the text of the prompts queued for the
	// session, in queue order.
	QueuedPromptsList(sessionID string) []string
	// ClearQueue drops all prompts queued for the session.
	ClearQueue(sessionID string)
	// Summarize replaces the session's working history with a generated
	// summary. The string argument is the session ID.
	Summarize(context.Context, string, fantasy.ProviderOptions) error
	// Model returns the current large model.
	Model() Model
}
90
// Model bundles a language model with the configuration that describes it.
type Model struct {
	// Model is the language model that requests are sent to.
	Model fantasy.LanguageModel
	// CatwalkCfg is the model's catalog entry; this file reads its context
	// window, per-token costs, display name and capability flags.
	CatwalkCfg catwalk.Model
	// ModelCfg is the selected-model configuration; its Model and Provider
	// values are recorded on assistant messages.
	ModelCfg config.SelectedModel
}
96
// sessionAgent is the SessionAgent implementation returned by
// NewSessionAgent.
type sessionAgent struct {
	// Settings that can be replaced while the agent is in use; each call
	// takes a snapshot before it starts.
	largeModel         *csync.Value[Model]
	smallModel         *csync.Value[Model]
	systemPromptPrefix *csync.Value[string]
	systemPrompt       *csync.Value[string]
	tools              *csync.Slice[fantasy.AgentTool]

	// isSubAgent suppresses the todo-list reminder that is otherwise
	// prepended to the history.
	isSubAgent bool
	sessions   session.Service
	messages   message.Service
	// disableAutoSummarize turns off the context-window stop condition in Run.
	disableAutoSummarize bool
	// NOTE(review): isYolo is stored but not read in the visible part of this
	// file; presumably it is consulted elsewhere — confirm.
	isYolo bool

	// messageQueue holds calls received while a session was busy, keyed by
	// session ID.
	messageQueue *csync.Map[string, []SessionAgentCall]
	// activeRequests maps a session ID to the cancel function of its
	// in-flight request; presence of a key is what marks a session as busy.
	activeRequests *csync.Map[string, context.CancelFunc]
}
113
// SessionAgentOptions carries the initial configuration and dependencies for
// NewSessionAgent.
type SessionAgentOptions struct {
	// LargeModel handles prompts and summaries; it is also the fallback for
	// title generation.
	LargeModel Model
	// SmallModel is tried first for title generation.
	SmallModel Model
	// SystemPromptPrefix, when non-empty, is sent as an extra system message
	// ahead of all other messages.
	SystemPromptPrefix string
	// SystemPrompt is the system prompt for regular prompts.
	SystemPrompt string
	// IsSubAgent suppresses the todo-list reminder added to the history.
	IsSubAgent bool
	// DisableAutoSummarize turns off automatic summarization when the context
	// window is nearly full.
	DisableAutoSummarize bool
	IsYolo               bool
	// Sessions and Messages are the persistence services the agent uses.
	Sessions session.Service
	Messages message.Service
	// Tools is the initial tool set offered to the model.
	Tools []fantasy.AgentTool
}
126
127func NewSessionAgent(
128 opts SessionAgentOptions,
129) SessionAgent {
130 return &sessionAgent{
131 largeModel: csync.NewValue(opts.LargeModel),
132 smallModel: csync.NewValue(opts.SmallModel),
133 systemPromptPrefix: csync.NewValue(opts.SystemPromptPrefix),
134 systemPrompt: csync.NewValue(opts.SystemPrompt),
135 isSubAgent: opts.IsSubAgent,
136 sessions: opts.Sessions,
137 messages: opts.Messages,
138 disableAutoSummarize: opts.DisableAutoSummarize,
139 tools: csync.NewSliceFrom(opts.Tools),
140 isYolo: opts.IsYolo,
141 messageQueue: csync.NewMap[string, []SessionAgentCall](),
142 activeRequests: csync.NewMap[string, context.CancelFunc](),
143 }
144}
145
// Run sends call to the large model on behalf of call.SessionID and streams
// the response into the message store.
//
// It returns ErrEmptyPrompt when the prompt is empty and no text attachment
// is present, and ErrSessionMissing when no session ID is given. If the
// session already has an active request, the call is appended to the
// session's queue and Run returns (nil, nil) immediately.
//
// Otherwise Run stores the user message, registers a cancel function in
// activeRequests, and streams the model response, persisting the assistant
// message as it grows. For the first message of a session a title is
// generated concurrently; Run waits for that goroutine before returning.
//
// On a streaming error the in-flight assistant message is finalized: tool
// calls without a stored result get an error tool result, a finish reason
// describing the error is recorded, and the error is returned. If no
// assistant message was created yet, the stream result and error are returned
// as-is.
//
// When the context-window stop condition fires, the session is summarized,
// and if the last assistant message still had tool calls the original request
// is re-queued. Finally, if prompts are queued for the session, Run calls
// itself with the first of them and returns that result.
func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
	if call.Prompt == "" && !message.ContainsTextAttachment(call.Attachments) {
		return nil, ErrEmptyPrompt
	}
	if call.SessionID == "" {
		return nil, ErrSessionMissing
	}

	// Queue the message if busy
	if a.IsSessionBusy(call.SessionID) {
		existing, ok := a.messageQueue.Get(call.SessionID)
		if !ok {
			existing = []SessionAgentCall{}
		}
		existing = append(existing, call)
		a.messageQueue.Set(call.SessionID, existing)
		return nil, nil
	}

	// Copy mutable fields under lock to avoid races with SetTools/SetModels.
	agentTools := a.tools.Copy()
	largeModel := a.largeModel.Get()
	systemPrompt := a.systemPrompt.Get()
	promptPrefix := a.systemPromptPrefix.Get()

	if len(agentTools) > 0 {
		// Add Anthropic caching to the last tool.
		agentTools[len(agentTools)-1].SetProviderOptions(a.getCacheControlOptions())
	}

	agent := fantasy.NewAgent(
		largeModel.Model,
		fantasy.WithSystemPrompt(systemPrompt),
		fantasy.WithTools(agentTools...),
	)

	// sessionLock serializes the read-modify-write of the session in
	// OnStepFinish.
	sessionLock := sync.Mutex{}
	currentSession, err := a.sessions.Get(ctx, call.SessionID)
	if err != nil {
		return nil, fmt.Errorf("failed to get session: %w", err)
	}

	msgs, err := a.getSessionMessages(ctx, currentSession)
	if err != nil {
		return nil, fmt.Errorf("failed to get session messages: %w", err)
	}

	var wg sync.WaitGroup
	// Generate title if first message.
	if len(msgs) == 0 {
		titleCtx := ctx // Copy to avoid race with ctx reassignment below.
		wg.Go(func() {
			a.generateTitle(titleCtx, call.SessionID, call.Prompt)
		})
	}
	// Run does not return before title generation has finished.
	defer wg.Wait()

	// Add the user message to the session.
	_, err = a.createUserMessage(ctx, call)
	if err != nil {
		return nil, err
	}

	// Add the session to the context.
	ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)

	// Registering the cancel function is what marks the session as busy.
	genCtx, cancel := context.WithCancel(ctx)
	a.activeRequests.Set(call.SessionID, cancel)

	defer cancel()
	defer a.activeRequests.Del(call.SessionID)

	history, files := a.preparePrompt(msgs, call.Attachments...)

	startTime := time.Now()
	a.eventPromptSent(call.SessionID)

	// currentAssistant is set by PrepareStep at the start of every step and is
	// the message all streaming callbacks below append to.
	var currentAssistant *message.Message
	// shouldSummarize is set by the stop condition when the context window is
	// nearly exhausted.
	var shouldSummarize bool
	result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
		Prompt:           message.PromptWithTextAttachments(call.Prompt, call.Attachments),
		Files:            files,
		Messages:         history,
		ProviderOptions:  call.ProviderOptions,
		MaxOutputTokens:  &call.MaxOutputTokens,
		TopP:             call.TopP,
		Temperature:      call.Temperature,
		PresencePenalty:  call.PresencePenalty,
		TopK:             call.TopK,
		FrequencyPenalty: call.FrequencyPenalty,
		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
			// Start from the step's messages with any previous per-message
			// provider options cleared; cache control is re-applied below.
			prepared.Messages = options.Messages
			for i := range prepared.Messages {
				prepared.Messages[i].ProviderOptions = nil
			}

			// Drain prompts that were queued while this request was running
			// and feed them to the model as additional user messages.
			queuedCalls, _ := a.messageQueue.Get(call.SessionID)
			a.messageQueue.Del(call.SessionID)
			for _, queued := range queuedCalls {
				userMessage, createErr := a.createUserMessage(callContext, queued)
				if createErr != nil {
					return callContext, prepared, createErr
				}
				prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
			}

			prepared.Messages = a.workaroundProviderMediaLimitations(prepared.Messages, largeModel)

			lastSystemRoleInx := 0
			systemMessageUpdated := false
			for i, msg := range prepared.Messages {
				// Add cache control to the last system message seen before
				// the first non-system message (index 0 if none was seen).
				if msg.Role == fantasy.MessageRoleSystem {
					lastSystemRoleInx = i
				} else if !systemMessageUpdated {
					prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
					systemMessageUpdated = true
				}
				// Then add cache control to the last 2 messages.
				if i > len(prepared.Messages)-3 {
					prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
				}
			}

			if promptPrefix != "" {
				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(promptPrefix)}, prepared.Messages...)
			}

			// Create the (initially empty) assistant message for this step
			// and expose its ID and model details to tools via the context.
			var assistantMsg message.Message
			assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
				Role:     message.Assistant,
				Parts:    []message.ContentPart{},
				Model:    largeModel.ModelCfg.Model,
				Provider: largeModel.ModelCfg.Provider,
			})
			if err != nil {
				return callContext, prepared, err
			}
			callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
			callContext = context.WithValue(callContext, tools.SupportsImagesContextKey, largeModel.CatwalkCfg.SupportsImages)
			callContext = context.WithValue(callContext, tools.ModelNameContextKey, largeModel.CatwalkCfg.Name)
			currentAssistant = &assistantMsg
			return callContext, prepared, err
		},
		OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
			currentAssistant.AppendReasoningContent(reasoning.Text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnReasoningDelta: func(id string, text string) error {
			currentAssistant.AppendReasoningContent(text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
			// handle anthropic signature
			if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
				if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
					currentAssistant.AppendReasoningSignature(reasoning.Signature)
				}
			}
			// Google thought signature.
			if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
				if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
					currentAssistant.AppendThoughtSignature(reasoning.Signature, reasoning.ToolID)
				}
			}
			// OpenAI Responses reasoning metadata.
			if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
				if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
					currentAssistant.SetReasoningResponsesData(reasoning)
				}
			}
			currentAssistant.FinishThinking()
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnTextDelta: func(id string, text string) error {
			// Strip leading newline from initial text content. This is
			// particularly important in non-interactive mode where leading
			// newlines are very visible.
			if len(currentAssistant.Parts) == 0 {
				text = strings.TrimPrefix(text, "\n")
			}

			currentAssistant.AppendContent(text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnToolInputStart: func(id string, toolName string) error {
			// Record the tool call as soon as the model starts emitting it;
			// it is marked unfinished until OnToolCall delivers the input.
			toolCall := message.ToolCall{
				ID:               id,
				Name:             toolName,
				ProviderExecuted: false,
				Finished:         false,
			}
			currentAssistant.AddToolCall(toolCall)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnRetry: func(err *fantasy.ProviderError, delay time.Duration) {
			// TODO: implement
		},
		OnToolCall: func(tc fantasy.ToolCallContent) error {
			toolCall := message.ToolCall{
				ID:               tc.ToolCallID,
				Name:             tc.ToolName,
				Input:            tc.Input,
				ProviderExecuted: false,
				Finished:         true,
			}
			currentAssistant.AddToolCall(toolCall)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnToolResult: func(result fantasy.ToolResultContent) error {
			// Tool results are stored as their own tool-role message.
			toolResult := a.convertToToolResult(result)
			_, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
				Role: message.Tool,
				Parts: []message.ContentPart{
					toolResult,
				},
			})
			return createMsgErr
		},
		OnStepFinish: func(stepResult fantasy.StepResult) error {
			// Map the provider finish reason onto the message finish reason;
			// anything unrecognized stays "unknown".
			finishReason := message.FinishReasonUnknown
			switch stepResult.FinishReason {
			case fantasy.FinishReasonLength:
				finishReason = message.FinishReasonMaxTokens
			case fantasy.FinishReasonStop:
				finishReason = message.FinishReasonEndTurn
			case fantasy.FinishReasonToolCalls:
				finishReason = message.FinishReasonToolUse
			}
			currentAssistant.AddFinish(finishReason, "", "")
			sessionLock.Lock()
			defer sessionLock.Unlock()

			// Re-read the session before updating usage so that fields
			// written since it was loaded are not overwritten by the save.
			updatedSession, getSessionErr := a.sessions.Get(ctx, call.SessionID)
			if getSessionErr != nil {
				return getSessionErr
			}
			a.updateSessionUsage(largeModel, &updatedSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
			_, sessionErr := a.sessions.Save(ctx, updatedSession)
			if sessionErr != nil {
				return sessionErr
			}
			currentSession = updatedSession
			return a.messages.Update(genCtx, *currentAssistant)
		},
		StopWhen: []fantasy.StopCondition{
			// Stop (and request summarization) once the tokens remaining in
			// the context window drop to the threshold: a fixed buffer for
			// large windows, a fraction of the window otherwise.
			func(_ []fantasy.StepResult) bool {
				cw := int64(largeModel.CatwalkCfg.ContextWindow)
				tokens := currentSession.CompletionTokens + currentSession.PromptTokens
				remaining := cw - tokens
				var threshold int64
				if cw > largeContextWindowThreshold {
					threshold = largeContextWindowBuffer
				} else {
					threshold = int64(float64(cw) * smallContextWindowRatio)
				}
				if (remaining <= threshold) && !a.disableAutoSummarize {
					shouldSummarize = true
					return true
				}
				return false
			},
		},
	})

	a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))

	if err != nil {
		isCancelErr := errors.Is(err, context.Canceled)
		isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
		// The stream failed before any assistant message was created, so
		// there is nothing to finalize.
		if currentAssistant == nil {
			return result, err
		}
		// Ensure we finish thinking on error to close the reasoning state.
		currentAssistant.FinishThinking()
		toolCalls := currentAssistant.ToolCalls()
		// INFO: we use the parent context here because the genCtx has been cancelled.
		msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
		if createErr != nil {
			return nil, createErr
		}
		for _, tc := range toolCalls {
			// Close tool calls whose input never finished streaming.
			if !tc.Finished {
				tc.Finished = true
				tc.Input = "{}"
				currentAssistant.AddToolCall(tc)
				updateErr := a.messages.Update(ctx, *currentAssistant)
				if updateErr != nil {
					return nil, updateErr
				}
			}

			// Skip tool calls that already have a stored result.
			found := false
			for _, msg := range msgs {
				if msg.Role == message.Tool {
					for _, tr := range msg.ToolResults() {
						if tr.ToolCallID == tc.ID {
							found = true
							break
						}
					}
				}
				if found {
					break
				}
			}
			if found {
				continue
			}
			// Otherwise store an error result for the tool call.
			content := "There was an error while executing the tool"
			if isCancelErr {
				content = "Tool execution canceled by user"
			} else if isPermissionErr {
				content = "User denied permission"
			}
			toolResult := message.ToolResult{
				ToolCallID: tc.ID,
				Name:       tc.Name,
				Content:    content,
				IsError:    true,
			}
			_, createErr = a.messages.Create(ctx, currentAssistant.SessionID, message.CreateMessageParams{
				Role: message.Tool,
				Parts: []message.ContentPart{
					toolResult,
				},
			})
			if createErr != nil {
				return nil, createErr
			}
		}
		// Record a finish reason on the assistant message that describes the
		// failure as specifically as the error type allows.
		var fantasyErr *fantasy.Error
		var providerErr *fantasy.ProviderError
		const defaultTitle = "Provider Error"
		linkStyle := lipgloss.NewStyle().Foreground(charmtone.Guac).Underline(true)
		if isCancelErr {
			currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
		} else if isPermissionErr {
			currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "")
		} else if errors.Is(err, hyper.ErrNoCredits) {
			url := hyper.BaseURL()
			link := linkStyle.Hyperlink(url, "id=hyper").Render(url)
			currentAssistant.AddFinish(message.FinishReasonError, "No credits", "You're out of credits. Add more at "+link)
		} else if errors.As(err, &providerErr) {
			if providerErr.Message == "The requested model is not supported." {
				url := "https://github.com/settings/copilot/features"
				link := linkStyle.Hyperlink(url, "id=copilot").Render(url)
				currentAssistant.AddFinish(
					message.FinishReasonError,
					"Copilot model not enabled",
					fmt.Sprintf("%q is not enabled in Copilot. Go to the following page to enable it. Then, wait 5 minutes before trying again. %s", largeModel.CatwalkCfg.Name, link),
				)
			} else {
				currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message)
			}
		} else if errors.As(err, &fantasyErr) {
			currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(fantasyErr.Title), defaultTitle), fantasyErr.Message)
		} else {
			currentAssistant.AddFinish(message.FinishReasonError, defaultTitle, err.Error())
		}
		// Note: we use the parent context here because the genCtx has been
		// cancelled.
		updateErr := a.messages.Update(ctx, *currentAssistant)
		if updateErr != nil {
			return nil, updateErr
		}
		return nil, err
	}

	if shouldSummarize {
		// Summarize refuses to run on a busy session, so release this
		// request's entry first.
		a.activeRequests.Del(call.SessionID)
		if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
			return nil, summarizeErr
		}
		// If the agent wasn't done...
		if len(currentAssistant.ToolCalls()) > 0 {
			existing, ok := a.messageQueue.Get(call.SessionID)
			if !ok {
				existing = []SessionAgentCall{}
			}
			call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
			existing = append(existing, call)
			a.messageQueue.Set(call.SessionID, existing)
		}
	}

	// Release active request before processing queued messages.
	a.activeRequests.Del(call.SessionID)
	cancel()

	queuedMessages, ok := a.messageQueue.Get(call.SessionID)
	if !ok || len(queuedMessages) == 0 {
		return result, err
	}
	// There are queued messages restart the loop.
	firstQueuedMessage := queuedMessages[0]
	a.messageQueue.Set(call.SessionID, queuedMessages[1:])
	return a.Run(ctx, firstQueuedMessage)
}
543
544func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
545 if a.IsSessionBusy(sessionID) {
546 return ErrSessionBusy
547 }
548
549 // Copy mutable fields under lock to avoid races with SetModels.
550 largeModel := a.largeModel.Get()
551 systemPromptPrefix := a.systemPromptPrefix.Get()
552
553 currentSession, err := a.sessions.Get(ctx, sessionID)
554 if err != nil {
555 return fmt.Errorf("failed to get session: %w", err)
556 }
557 msgs, err := a.getSessionMessages(ctx, currentSession)
558 if err != nil {
559 return err
560 }
561 if len(msgs) == 0 {
562 // Nothing to summarize.
563 return nil
564 }
565
566 aiMsgs, _ := a.preparePrompt(msgs)
567
568 genCtx, cancel := context.WithCancel(ctx)
569 a.activeRequests.Set(sessionID, cancel)
570 defer a.activeRequests.Del(sessionID)
571 defer cancel()
572
573 agent := fantasy.NewAgent(largeModel.Model,
574 fantasy.WithSystemPrompt(string(summaryPrompt)),
575 )
576 summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
577 Role: message.Assistant,
578 Model: largeModel.Model.Model(),
579 Provider: largeModel.Model.Provider(),
580 IsSummaryMessage: true,
581 })
582 if err != nil {
583 return err
584 }
585
586 summaryPromptText := buildSummaryPrompt(currentSession.Todos)
587
588 resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
589 Prompt: summaryPromptText,
590 Messages: aiMsgs,
591 ProviderOptions: opts,
592 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
593 prepared.Messages = options.Messages
594 if systemPromptPrefix != "" {
595 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(systemPromptPrefix)}, prepared.Messages...)
596 }
597 return callContext, prepared, nil
598 },
599 OnReasoningDelta: func(id string, text string) error {
600 summaryMessage.AppendReasoningContent(text)
601 return a.messages.Update(genCtx, summaryMessage)
602 },
603 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
604 // Handle anthropic signature.
605 if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
606 if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
607 summaryMessage.AppendReasoningSignature(signature.Signature)
608 }
609 }
610 summaryMessage.FinishThinking()
611 return a.messages.Update(genCtx, summaryMessage)
612 },
613 OnTextDelta: func(id, text string) error {
614 summaryMessage.AppendContent(text)
615 return a.messages.Update(genCtx, summaryMessage)
616 },
617 })
618 if err != nil {
619 isCancelErr := errors.Is(err, context.Canceled)
620 if isCancelErr {
621 // User cancelled summarize we need to remove the summary message.
622 deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
623 return deleteErr
624 }
625 return err
626 }
627
628 summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
629 err = a.messages.Update(genCtx, summaryMessage)
630 if err != nil {
631 return err
632 }
633
634 var openrouterCost *float64
635 for _, step := range resp.Steps {
636 stepCost := a.openrouterCost(step.ProviderMetadata)
637 if stepCost != nil {
638 newCost := *stepCost
639 if openrouterCost != nil {
640 newCost += *openrouterCost
641 }
642 openrouterCost = &newCost
643 }
644 }
645
646 a.updateSessionUsage(largeModel, ¤tSession, resp.TotalUsage, openrouterCost)
647
648 // Just in case, get just the last usage info.
649 usage := resp.Response.Usage
650 currentSession.SummaryMessageID = summaryMessage.ID
651 currentSession.CompletionTokens = usage.OutputTokens
652 currentSession.PromptTokens = 0
653 _, err = a.sessions.Save(genCtx, currentSession)
654 return err
655}
656
657func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
658 if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
659 return fantasy.ProviderOptions{}
660 }
661 return fantasy.ProviderOptions{
662 anthropic.Name: &anthropic.ProviderCacheControlOptions{
663 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
664 },
665 bedrock.Name: &anthropic.ProviderCacheControlOptions{
666 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
667 },
668 }
669}
670
671func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
672 parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
673 var attachmentParts []message.ContentPart
674 for _, attachment := range call.Attachments {
675 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
676 }
677 parts = append(parts, attachmentParts...)
678 msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
679 Role: message.User,
680 Parts: parts,
681 })
682 if err != nil {
683 return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
684 }
685 return msg, nil
686}
687
688func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
689 var history []fantasy.Message
690 if !a.isSubAgent {
691 history = append(history, fantasy.NewUserMessage(
692 fmt.Sprintf("<system_reminder>%s</system_reminder>",
693 `This is a reminder that your todo list is currently empty. DO NOT mention this to the user explicitly because they are already aware.
694If you are working on tasks that would benefit from a todo list please use the "todos" tool to create one.
695If not, please feel free to ignore. Again do not mention this message to the user.`,
696 ),
697 ))
698 }
699 for _, m := range msgs {
700 if len(m.Parts) == 0 {
701 continue
702 }
703 // Assistant message without content or tool calls (cancelled before it
704 // returned anything).
705 if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
706 continue
707 }
708 history = append(history, m.ToAIMessage()...)
709 }
710
711 var files []fantasy.FilePart
712 for _, attachment := range attachments {
713 if attachment.IsText() {
714 continue
715 }
716 files = append(files, fantasy.FilePart{
717 Filename: attachment.FileName,
718 Data: attachment.Content,
719 MediaType: attachment.MimeType,
720 })
721 }
722
723 return history, files
724}
725
726func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
727 msgs, err := a.messages.List(ctx, session.ID)
728 if err != nil {
729 return nil, fmt.Errorf("failed to list messages: %w", err)
730 }
731
732 if session.SummaryMessageID != "" {
733 summaryMsgIndex := -1
734 for i, msg := range msgs {
735 if msg.ID == session.SummaryMessageID {
736 summaryMsgIndex = i
737 break
738 }
739 }
740 if summaryMsgIndex != -1 {
741 msgs = msgs[summaryMsgIndex:]
742 msgs[0].Role = message.User
743 }
744 }
745 return msgs, nil
746}
747
748// generateTitle generates a session titled based on the initial prompt.
749func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, userPrompt string) {
750 if userPrompt == "" {
751 return
752 }
753
754 smallModel := a.smallModel.Get()
755 largeModel := a.largeModel.Get()
756 systemPromptPrefix := a.systemPromptPrefix.Get()
757
758 var maxOutputTokens int64 = 40
759 if smallModel.CatwalkCfg.CanReason {
760 maxOutputTokens = smallModel.CatwalkCfg.DefaultMaxTokens
761 }
762
763 newAgent := func(m fantasy.LanguageModel, p []byte, tok int64) fantasy.Agent {
764 return fantasy.NewAgent(m,
765 fantasy.WithSystemPrompt(string(p)+"\n /no_think"),
766 fantasy.WithMaxOutputTokens(tok),
767 )
768 }
769
770 streamCall := fantasy.AgentStreamCall{
771 Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", userPrompt),
772 PrepareStep: func(callCtx context.Context, opts fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
773 prepared.Messages = opts.Messages
774 if systemPromptPrefix != "" {
775 prepared.Messages = append([]fantasy.Message{
776 fantasy.NewSystemMessage(systemPromptPrefix),
777 }, prepared.Messages...)
778 }
779 return callCtx, prepared, nil
780 },
781 }
782
783 // Use the small model to generate the title.
784 model := smallModel
785 agent := newAgent(model.Model, titlePrompt, maxOutputTokens)
786 resp, err := agent.Stream(ctx, streamCall)
787 if err == nil {
788 // We successfully generated a title with the small model.
789 slog.Info("generated title with small model")
790 } else {
791 // It didn't work. Let's try with the big model.
792 slog.Error("error generating title with small model; trying big model", "err", err)
793 model = largeModel
794 agent = newAgent(model.Model, titlePrompt, maxOutputTokens)
795 resp, err = agent.Stream(ctx, streamCall)
796 if err == nil {
797 slog.Info("generated title with large model")
798 } else {
799 // Welp, the large model didn't work either. Use the default
800 // session name and return.
801 slog.Error("error generating title with large model", "err", err)
802 saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, defaultSessionName, 0, 0, 0)
803 if saveErr != nil {
804 slog.Error("failed to save session title and usage", "error", saveErr)
805 }
806 return
807 }
808 }
809
810 if resp == nil {
811 // Actually, we didn't get a response so we can't. Use the default
812 // session name and return.
813 slog.Error("response is nil; can't generate title")
814 saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, defaultSessionName, 0, 0, 0)
815 if saveErr != nil {
816 slog.Error("failed to save session title and usage", "error", saveErr)
817 }
818 return
819 }
820
821 // Clean up title.
822 var title string
823 title = strings.ReplaceAll(resp.Response.Content.Text(), "\n", " ")
824
825 // Remove thinking tags if present.
826 title = thinkTagRegex.ReplaceAllString(title, "")
827
828 title = strings.TrimSpace(title)
829 if title == "" {
830 slog.Warn("empty title; using fallback")
831 title = defaultSessionName
832 }
833
834 // Calculate usage and cost.
835 var openrouterCost *float64
836 for _, step := range resp.Steps {
837 stepCost := a.openrouterCost(step.ProviderMetadata)
838 if stepCost != nil {
839 newCost := *stepCost
840 if openrouterCost != nil {
841 newCost += *openrouterCost
842 }
843 openrouterCost = &newCost
844 }
845 }
846
847 modelConfig := model.CatwalkCfg
848 cost := modelConfig.CostPer1MInCached/1e6*float64(resp.TotalUsage.CacheCreationTokens) +
849 modelConfig.CostPer1MOutCached/1e6*float64(resp.TotalUsage.CacheReadTokens) +
850 modelConfig.CostPer1MIn/1e6*float64(resp.TotalUsage.InputTokens) +
851 modelConfig.CostPer1MOut/1e6*float64(resp.TotalUsage.OutputTokens)
852
853 // Use override cost if available (e.g., from OpenRouter).
854 if openrouterCost != nil {
855 cost = *openrouterCost
856 }
857
858 promptTokens := resp.TotalUsage.InputTokens + resp.TotalUsage.CacheCreationTokens
859 completionTokens := resp.TotalUsage.OutputTokens
860
861 // Atomically update only title and usage fields to avoid overriding other
862 // concurrent session updates.
863 saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, title, promptTokens, completionTokens, cost)
864 if saveErr != nil {
865 slog.Error("failed to save session title and usage", "error", saveErr)
866 return
867 }
868}
869
870func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
871 openrouterMetadata, ok := metadata[openrouter.Name]
872 if !ok {
873 return nil
874 }
875
876 opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
877 if !ok {
878 return nil
879 }
880 return &opts.Usage.Cost
881}
882
883func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
884 modelConfig := model.CatwalkCfg
885 cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
886 modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
887 modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
888 modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
889
890 a.eventTokensUsed(session.ID, model, usage, cost)
891
892 if overrideCost != nil {
893 session.Cost += *overrideCost
894 } else {
895 session.Cost += cost
896 }
897
898 session.CompletionTokens = usage.OutputTokens
899 session.PromptTokens = usage.InputTokens + usage.CacheReadTokens
900}
901
902func (a *sessionAgent) Cancel(sessionID string) {
903 // Cancel regular requests. Don't use Take() here - we need the entry to
904 // remain in activeRequests so IsBusy() returns true until the goroutine
905 // fully completes (including error handling that may access the DB).
906 // The defer in processRequest will clean up the entry.
907 if cancel, ok := a.activeRequests.Get(sessionID); ok && cancel != nil {
908 slog.Info("Request cancellation initiated", "session_id", sessionID)
909 cancel()
910 }
911
912 // Also check for summarize requests.
913 if cancel, ok := a.activeRequests.Get(sessionID + "-summarize"); ok && cancel != nil {
914 slog.Info("Summarize cancellation initiated", "session_id", sessionID)
915 cancel()
916 }
917
918 if a.QueuedPrompts(sessionID) > 0 {
919 slog.Info("Clearing queued prompts", "session_id", sessionID)
920 a.messageQueue.Del(sessionID)
921 }
922}
923
924func (a *sessionAgent) ClearQueue(sessionID string) {
925 if a.QueuedPrompts(sessionID) > 0 {
926 slog.Info("Clearing queued prompts", "session_id", sessionID)
927 a.messageQueue.Del(sessionID)
928 }
929}
930
931func (a *sessionAgent) CancelAll() {
932 if !a.IsBusy() {
933 return
934 }
935 for key := range a.activeRequests.Seq2() {
936 a.Cancel(key) // key is sessionID
937 }
938
939 timeout := time.After(5 * time.Second)
940 for a.IsBusy() {
941 select {
942 case <-timeout:
943 return
944 default:
945 time.Sleep(200 * time.Millisecond)
946 }
947 }
948}
949
950func (a *sessionAgent) IsBusy() bool {
951 var busy bool
952 for cancelFunc := range a.activeRequests.Seq() {
953 if cancelFunc != nil {
954 busy = true
955 break
956 }
957 }
958 return busy
959}
960
961func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
962 _, busy := a.activeRequests.Get(sessionID)
963 return busy
964}
965
966func (a *sessionAgent) QueuedPrompts(sessionID string) int {
967 l, ok := a.messageQueue.Get(sessionID)
968 if !ok {
969 return 0
970 }
971 return len(l)
972}
973
974func (a *sessionAgent) QueuedPromptsList(sessionID string) []string {
975 l, ok := a.messageQueue.Get(sessionID)
976 if !ok {
977 return nil
978 }
979 prompts := make([]string, len(l))
980 for i, call := range l {
981 prompts[i] = call.Prompt
982 }
983 return prompts
984}
985
986func (a *sessionAgent) SetModels(large Model, small Model) {
987 a.largeModel.Set(large)
988 a.smallModel.Set(small)
989}
990
991func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
992 a.tools.SetSlice(tools)
993}
994
995func (a *sessionAgent) SetSystemPrompt(systemPrompt string) {
996 a.systemPrompt.Set(systemPrompt)
997}
998
999func (a *sessionAgent) Model() Model {
1000 return a.largeModel.Get()
1001}
1002
1003// convertToToolResult converts a fantasy tool result to a message tool result.
1004func (a *sessionAgent) convertToToolResult(result fantasy.ToolResultContent) message.ToolResult {
1005 baseResult := message.ToolResult{
1006 ToolCallID: result.ToolCallID,
1007 Name: result.ToolName,
1008 Metadata: result.ClientMetadata,
1009 }
1010
1011 switch result.Result.GetType() {
1012 case fantasy.ToolResultContentTypeText:
1013 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result); ok {
1014 baseResult.Content = r.Text
1015 }
1016 case fantasy.ToolResultContentTypeError:
1017 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result); ok {
1018 baseResult.Content = r.Error.Error()
1019 baseResult.IsError = true
1020 }
1021 case fantasy.ToolResultContentTypeMedia:
1022 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](result.Result); ok {
1023 content := r.Text
1024 if content == "" {
1025 content = fmt.Sprintf("Loaded %s content", r.MediaType)
1026 }
1027 baseResult.Content = content
1028 baseResult.Data = r.Data
1029 baseResult.MIMEType = r.MediaType
1030 }
1031 }
1032
1033 return baseResult
1034}
1035
1036// workaroundProviderMediaLimitations converts media content in tool results to
1037// user messages for providers that don't natively support images in tool results.
1038//
1039// Problem: OpenAI, Google, OpenRouter, and other OpenAI-compatible providers
1040// don't support sending images/media in tool result messages - they only accept
1041// text in tool results. However, they DO support images in user messages.
1042//
1043// If we send media in tool results to these providers, the API returns an error.
1044//
1045// Solution: For these providers, we:
1046// 1. Replace the media in the tool result with a text placeholder
1047// 2. Inject a user message immediately after with the image as a file attachment
1048// 3. This maintains the tool execution flow while working around API limitations
1049//
1050// Anthropic and Bedrock support images natively in tool results, so we skip
1051// this workaround for them.
1052//
1053// Example transformation:
1054//
1055// BEFORE: [tool result: image data]
1056// AFTER: [tool result: "Image loaded - see attached"], [user: image attachment]
1057func (a *sessionAgent) workaroundProviderMediaLimitations(messages []fantasy.Message, largeModel Model) []fantasy.Message {
1058 providerSupportsMedia := largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderAnthropic) ||
1059 largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderBedrock)
1060
1061 if providerSupportsMedia {
1062 return messages
1063 }
1064
1065 convertedMessages := make([]fantasy.Message, 0, len(messages))
1066
1067 for _, msg := range messages {
1068 if msg.Role != fantasy.MessageRoleTool {
1069 convertedMessages = append(convertedMessages, msg)
1070 continue
1071 }
1072
1073 textParts := make([]fantasy.MessagePart, 0, len(msg.Content))
1074 var mediaFiles []fantasy.FilePart
1075
1076 for _, part := range msg.Content {
1077 toolResult, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
1078 if !ok {
1079 textParts = append(textParts, part)
1080 continue
1081 }
1082
1083 if media, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](toolResult.Output); ok {
1084 decoded, err := base64.StdEncoding.DecodeString(media.Data)
1085 if err != nil {
1086 slog.Warn("failed to decode media data", "error", err)
1087 textParts = append(textParts, part)
1088 continue
1089 }
1090
1091 mediaFiles = append(mediaFiles, fantasy.FilePart{
1092 Data: decoded,
1093 MediaType: media.MediaType,
1094 Filename: fmt.Sprintf("tool-result-%s", toolResult.ToolCallID),
1095 })
1096
1097 textParts = append(textParts, fantasy.ToolResultPart{
1098 ToolCallID: toolResult.ToolCallID,
1099 Output: fantasy.ToolResultOutputContentText{
1100 Text: "[Image/media content loaded - see attached file]",
1101 },
1102 ProviderOptions: toolResult.ProviderOptions,
1103 })
1104 } else {
1105 textParts = append(textParts, part)
1106 }
1107 }
1108
1109 convertedMessages = append(convertedMessages, fantasy.Message{
1110 Role: fantasy.MessageRoleTool,
1111 Content: textParts,
1112 })
1113
1114 if len(mediaFiles) > 0 {
1115 convertedMessages = append(convertedMessages, fantasy.NewUserMessage(
1116 "Here is the media content from the tool result:",
1117 mediaFiles...,
1118 ))
1119 }
1120 }
1121
1122 return convertedMessages
1123}
1124
1125// buildSummaryPrompt constructs the prompt text for session summarization.
1126func buildSummaryPrompt(todos []session.Todo) string {
1127 var sb strings.Builder
1128 sb.WriteString("Provide a detailed summary of our conversation above.")
1129 if len(todos) > 0 {
1130 sb.WriteString("\n\n## Current Todo List\n\n")
1131 for _, t := range todos {
1132 fmt.Fprintf(&sb, "- [%s] %s\n", t.Status, t.Content)
1133 }
1134 sb.WriteString("\nInclude these tasks and their statuses in your summary. ")
1135 sb.WriteString("Instruct the resuming assistant to use the `todos` tool to continue tracking progress on these tasks.")
1136 }
1137 return sb.String()
1138}