1// Package agent is the core orchestration layer for Crush AI agents.
2//
3// It provides session-based AI agent functionality for managing
4// conversations, tool execution, and message handling. It coordinates
5// interactions between language models, messages, sessions, and tools while
6// handling features like automatic summarization, queuing, and token
7// management.
8package agent
9
10import (
11 "cmp"
12 "context"
13 _ "embed"
14 "encoding/base64"
15 "errors"
16 "fmt"
17 "log/slog"
18 "os"
19 "regexp"
20 "strconv"
21 "strings"
22 "sync"
23 "time"
24
25 "charm.land/fantasy"
26 "charm.land/fantasy/providers/anthropic"
27 "charm.land/fantasy/providers/bedrock"
28 "charm.land/fantasy/providers/google"
29 "charm.land/fantasy/providers/openai"
30 "charm.land/fantasy/providers/openrouter"
31 "charm.land/lipgloss/v2"
32 "github.com/charmbracelet/catwalk/pkg/catwalk"
33 "github.com/charmbracelet/crush/internal/agent/hyper"
34 "github.com/charmbracelet/crush/internal/agent/tools"
35 "github.com/charmbracelet/crush/internal/config"
36 "github.com/charmbracelet/crush/internal/csync"
37 "github.com/charmbracelet/crush/internal/message"
38 "github.com/charmbracelet/crush/internal/permission"
39 "github.com/charmbracelet/crush/internal/session"
40 "github.com/charmbracelet/crush/internal/stringext"
41 "github.com/charmbracelet/x/exp/charmtone"
42)
43
const (
	// defaultSessionName is the title stored for a session when title
	// generation fails or yields an empty title.
	defaultSessionName = "Untitled Session"

	// Constants for auto-summarization thresholds.
	//
	// Models whose context window is larger than largeContextWindowThreshold
	// tokens use the fixed largeContextWindowBuffer as the remaining-token
	// threshold; smaller windows use smallContextWindowRatio of the window
	// size instead.
	largeContextWindowThreshold = 200_000
	largeContextWindowBuffer    = 20_000
	smallContextWindowRatio     = 0.2
)
52
// titlePrompt holds the embedded contents of templates/title.md. It is the
// base of the system prompt used when generating session titles.
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt holds the embedded contents of templates/summary.md. It is
// used as the system prompt when summarizing a session.
//
//go:embed templates/summary.md
var summaryPrompt []byte

// Used to remove <think> tags from generated titles. The pattern has no "s"
// flag, so "." does not match newlines; generateTitle replaces newlines with
// spaces before applying it.
var thinkTagRegex = regexp.MustCompile(`<think>.*?</think>`)
61
// SessionAgentCall describes a single prompt to run against a session.
type SessionAgentCall struct {
	// SessionID identifies the session the prompt belongs to. Run rejects
	// calls where it is empty.
	SessionID string
	// Prompt is the user's text. Run rejects calls where it is empty unless
	// the attachments include a text attachment.
	Prompt string
	// ProviderOptions is passed through to the model stream call.
	ProviderOptions fantasy.ProviderOptions
	// Attachments are stored with the user message; non-text attachments are
	// also sent to the model as files.
	Attachments []message.Attachment
	// The remaining fields are output-limit and sampling settings that Run
	// forwards unchanged to the model stream call.
	MaxOutputTokens  int64
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
}
74
// SessionAgent runs prompts against sessions and manages in-flight and queued
// work. The per-method notes describe the behavior of the sessionAgent
// implementation in this file.
type SessionAgent interface {
	// Run executes the call against its session. When the session already
	// has an active request the call is queued and Run returns a nil result
	// with a nil error.
	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
	// SetModels replaces the large and small models.
	SetModels(large Model, small Model)
	// SetTools replaces the tools offered to the model.
	SetTools(tools []fantasy.AgentTool)
	// SetSystemPrompt replaces the system prompt.
	SetSystemPrompt(systemPrompt string)
	// Cancel cancels the session's active request, if any, and drops its
	// queued prompts.
	Cancel(sessionID string)
	// CancelAll cancels every active request and waits up to five seconds
	// for the agent to become idle.
	CancelAll()
	// IsSessionBusy reports whether the session has an active request.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any session has an active request.
	IsBusy() bool
	// QueuedPrompts returns the number of prompts queued for the session.
	QueuedPrompts(sessionID string) int
	// QueuedPromptsList returns the text of the prompts queued for the
	// session.
	QueuedPromptsList(sessionID string) []string
	// ClearQueue drops the prompts queued for the session.
	ClearQueue(sessionID string)
	// Summarize generates a summary message for the session identified by
	// the string argument and records it as the session's summary. It
	// returns ErrSessionBusy when the session has an active request.
	Summarize(context.Context, string, fantasy.ProviderOptions) error
	// Model returns the current large model.
	Model() Model
}
90
// Model bundles a language model with its catalog and configuration data.
type Model struct {
	// Model is the language model that requests are sent to.
	Model fantasy.LanguageModel
	// CatwalkCfg is the catwalk description of the model; this file reads its
	// context window, pricing, name, and capability fields.
	CatwalkCfg catwalk.Model
	// ModelCfg is the selected-model configuration; its Model and Provider
	// values are recorded on assistant messages created by Run.
	ModelCfg config.SelectedModel
}
96
// sessionAgent is the SessionAgent implementation returned by
// NewSessionAgent.
type sessionAgent struct {
	// Settings that can be replaced while requests are running; they are held
	// in csync containers and copied at the start of each request.
	largeModel         *csync.Value[Model]
	smallModel         *csync.Value[Model]
	systemPromptPrefix *csync.Value[string]
	systemPrompt       *csync.Value[string]
	tools              *csync.Slice[fantasy.AgentTool]

	// isSubAgent suppresses the todo-list reminder that preparePrompt adds to
	// the history.
	isSubAgent bool
	sessions   session.Service
	messages   message.Service
	// disableAutoSummarize turns off the context-window stop condition in
	// Run.
	disableAutoSummarize bool
	// isYolo is copied from the options; it is not read in the visible part
	// of this file.
	isYolo bool

	// messageQueue holds calls received while a session was busy, keyed by
	// session ID.
	messageQueue *csync.Map[string, []SessionAgentCall]
	// activeRequests holds the cancel function of each in-flight request,
	// keyed by session ID. An entry's presence marks the session as busy.
	activeRequests *csync.Map[string, context.CancelFunc]
}
113
// SessionAgentOptions carries the initial settings and dependencies for
// NewSessionAgent.
type SessionAgentOptions struct {
	// LargeModel runs prompts and summaries; SmallModel is tried first for
	// title generation.
	LargeModel Model
	SmallModel Model
	// SystemPromptPrefix, when non-empty, is prepended as an extra system
	// message to every model call.
	SystemPromptPrefix string
	SystemPrompt       string
	// IsSubAgent suppresses the todo-list reminder added to the history.
	IsSubAgent bool
	// DisableAutoSummarize turns off automatic summarization when the
	// context window is nearly full.
	DisableAutoSummarize bool
	IsYolo               bool
	Sessions             session.Service
	Messages             message.Service
	Tools                []fantasy.AgentTool
}
126
127func NewSessionAgent(
128 opts SessionAgentOptions,
129) SessionAgent {
130 return &sessionAgent{
131 largeModel: csync.NewValue(opts.LargeModel),
132 smallModel: csync.NewValue(opts.SmallModel),
133 systemPromptPrefix: csync.NewValue(opts.SystemPromptPrefix),
134 systemPrompt: csync.NewValue(opts.SystemPrompt),
135 isSubAgent: opts.IsSubAgent,
136 sessions: opts.Sessions,
137 messages: opts.Messages,
138 disableAutoSummarize: opts.DisableAutoSummarize,
139 tools: csync.NewSliceFrom(opts.Tools),
140 isYolo: opts.IsYolo,
141 messageQueue: csync.NewMap[string, []SessionAgentCall](),
142 activeRequests: csync.NewMap[string, context.CancelFunc](),
143 }
144}
145
146func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
147 if call.Prompt == "" && !message.ContainsTextAttachment(call.Attachments) {
148 return nil, ErrEmptyPrompt
149 }
150 if call.SessionID == "" {
151 return nil, ErrSessionMissing
152 }
153
154 // Queue the message if busy
155 if a.IsSessionBusy(call.SessionID) {
156 existing, ok := a.messageQueue.Get(call.SessionID)
157 if !ok {
158 existing = []SessionAgentCall{}
159 }
160 existing = append(existing, call)
161 a.messageQueue.Set(call.SessionID, existing)
162 return nil, nil
163 }
164
165 // Copy mutable fields under lock to avoid races with SetTools/SetModels.
166 agentTools := a.tools.Copy()
167 largeModel := a.largeModel.Get()
168 systemPrompt := a.systemPrompt.Get()
169 promptPrefix := a.systemPromptPrefix.Get()
170
171 if len(agentTools) > 0 {
172 // Add Anthropic caching to the last tool.
173 agentTools[len(agentTools)-1].SetProviderOptions(a.getCacheControlOptions())
174 }
175
176 agent := fantasy.NewAgent(
177 largeModel.Model,
178 fantasy.WithSystemPrompt(systemPrompt),
179 fantasy.WithTools(agentTools...),
180 )
181
182 sessionLock := sync.Mutex{}
183 currentSession, err := a.sessions.Get(ctx, call.SessionID)
184 if err != nil {
185 return nil, fmt.Errorf("failed to get session: %w", err)
186 }
187
188 msgs, err := a.getSessionMessages(ctx, currentSession)
189 if err != nil {
190 return nil, fmt.Errorf("failed to get session messages: %w", err)
191 }
192
193 var wg sync.WaitGroup
194 // Generate title if first message.
195 if len(msgs) == 0 {
196 titleCtx := ctx // Copy to avoid race with ctx reassignment below.
197 wg.Go(func() {
198 a.generateTitle(titleCtx, call.SessionID, call.Prompt)
199 })
200 }
201 defer wg.Wait()
202
203 // Add the user message to the session.
204 _, err = a.createUserMessage(ctx, call)
205 if err != nil {
206 return nil, err
207 }
208
209 // Add the session to the context.
210 ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)
211
212 genCtx, cancel := context.WithCancel(ctx)
213 a.activeRequests.Set(call.SessionID, cancel)
214
215 defer cancel()
216 defer a.activeRequests.Del(call.SessionID)
217
218 history, files := a.preparePrompt(msgs, call.Attachments...)
219
220 startTime := time.Now()
221 a.eventPromptSent(call.SessionID)
222
223 var currentAssistant *message.Message
224 var shouldSummarize bool
225 result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
226 Prompt: message.PromptWithTextAttachments(call.Prompt, call.Attachments),
227 Files: files,
228 Messages: history,
229 ProviderOptions: call.ProviderOptions,
230 MaxOutputTokens: &call.MaxOutputTokens,
231 TopP: call.TopP,
232 Temperature: call.Temperature,
233 PresencePenalty: call.PresencePenalty,
234 TopK: call.TopK,
235 FrequencyPenalty: call.FrequencyPenalty,
236 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
237 prepared.Messages = options.Messages
238 for i := range prepared.Messages {
239 prepared.Messages[i].ProviderOptions = nil
240 }
241
242 queuedCalls, _ := a.messageQueue.Get(call.SessionID)
243 a.messageQueue.Del(call.SessionID)
244 for _, queued := range queuedCalls {
245 userMessage, createErr := a.createUserMessage(callContext, queued)
246 if createErr != nil {
247 return callContext, prepared, createErr
248 }
249 prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
250 }
251
252 prepared.Messages = a.workaroundProviderMediaLimitations(prepared.Messages, largeModel)
253
254 lastSystemRoleInx := 0
255 systemMessageUpdated := false
256 for i, msg := range prepared.Messages {
257 // Only add cache control to the last message.
258 if msg.Role == fantasy.MessageRoleSystem {
259 lastSystemRoleInx = i
260 } else if !systemMessageUpdated {
261 prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
262 systemMessageUpdated = true
263 }
264 // Than add cache control to the last 2 messages.
265 if i > len(prepared.Messages)-3 {
266 prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
267 }
268 }
269
270 if promptPrefix != "" {
271 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(promptPrefix)}, prepared.Messages...)
272 }
273
274 var assistantMsg message.Message
275 assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
276 Role: message.Assistant,
277 Parts: []message.ContentPart{},
278 Model: largeModel.ModelCfg.Model,
279 Provider: largeModel.ModelCfg.Provider,
280 })
281 if err != nil {
282 return callContext, prepared, err
283 }
284 callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
285 callContext = context.WithValue(callContext, tools.SupportsImagesContextKey, largeModel.CatwalkCfg.SupportsImages)
286 callContext = context.WithValue(callContext, tools.ModelNameContextKey, largeModel.CatwalkCfg.Name)
287 currentAssistant = &assistantMsg
288 return callContext, prepared, err
289 },
290 OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
291 currentAssistant.AppendReasoningContent(reasoning.Text)
292 return a.messages.Update(genCtx, *currentAssistant)
293 },
294 OnReasoningDelta: func(id string, text string) error {
295 currentAssistant.AppendReasoningContent(text)
296 return a.messages.Update(genCtx, *currentAssistant)
297 },
298 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
299 // handle anthropic signature
300 if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
301 if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
302 currentAssistant.AppendReasoningSignature(reasoning.Signature)
303 }
304 }
305 if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
306 if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
307 currentAssistant.AppendThoughtSignature(reasoning.Signature, reasoning.ToolID)
308 }
309 }
310 if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
311 if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
312 currentAssistant.SetReasoningResponsesData(reasoning)
313 }
314 }
315 currentAssistant.FinishThinking()
316 return a.messages.Update(genCtx, *currentAssistant)
317 },
318 OnTextDelta: func(id string, text string) error {
319 // Strip leading newline from initial text content. This is is
320 // particularly important in non-interactive mode where leading
321 // newlines are very visible.
322 if len(currentAssistant.Parts) == 0 {
323 text = strings.TrimPrefix(text, "\n")
324 }
325
326 currentAssistant.AppendContent(text)
327 return a.messages.Update(genCtx, *currentAssistant)
328 },
329 OnToolInputStart: func(id string, toolName string) error {
330 toolCall := message.ToolCall{
331 ID: id,
332 Name: toolName,
333 ProviderExecuted: false,
334 Finished: false,
335 }
336 currentAssistant.AddToolCall(toolCall)
337 return a.messages.Update(genCtx, *currentAssistant)
338 },
339 OnRetry: func(err *fantasy.ProviderError, delay time.Duration) {
340 // TODO: implement
341 },
342 OnToolCall: func(tc fantasy.ToolCallContent) error {
343 toolCall := message.ToolCall{
344 ID: tc.ToolCallID,
345 Name: tc.ToolName,
346 Input: tc.Input,
347 ProviderExecuted: false,
348 Finished: true,
349 }
350 currentAssistant.AddToolCall(toolCall)
351 return a.messages.Update(genCtx, *currentAssistant)
352 },
353 OnToolResult: func(result fantasy.ToolResultContent) error {
354 toolResult := a.convertToToolResult(result)
355 _, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
356 Role: message.Tool,
357 Parts: []message.ContentPart{
358 toolResult,
359 },
360 })
361 return createMsgErr
362 },
363 OnStepFinish: func(stepResult fantasy.StepResult) error {
364 finishReason := message.FinishReasonUnknown
365 switch stepResult.FinishReason {
366 case fantasy.FinishReasonLength:
367 finishReason = message.FinishReasonMaxTokens
368 case fantasy.FinishReasonStop:
369 finishReason = message.FinishReasonEndTurn
370 case fantasy.FinishReasonToolCalls:
371 finishReason = message.FinishReasonToolUse
372 }
373 currentAssistant.AddFinish(finishReason, "", "")
374 sessionLock.Lock()
375 updatedSession, getSessionErr := a.sessions.Get(genCtx, call.SessionID)
376 if getSessionErr != nil {
377 sessionLock.Unlock()
378 return getSessionErr
379 }
380 a.updateSessionUsage(largeModel, &updatedSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
381 _, sessionErr := a.sessions.Save(genCtx, updatedSession)
382 sessionLock.Unlock()
383 if sessionErr != nil {
384 return sessionErr
385 }
386 return a.messages.Update(genCtx, *currentAssistant)
387 },
388 StopWhen: []fantasy.StopCondition{
389 func(_ []fantasy.StepResult) bool {
390 cw := int64(largeModel.CatwalkCfg.ContextWindow)
391 tokens := currentSession.CompletionTokens + currentSession.PromptTokens
392 remaining := cw - tokens
393 var threshold int64
394 if cw > largeContextWindowThreshold {
395 threshold = largeContextWindowBuffer
396 } else {
397 threshold = int64(float64(cw) * smallContextWindowRatio)
398 }
399 if (remaining <= threshold) && !a.disableAutoSummarize {
400 shouldSummarize = true
401 return true
402 }
403 return false
404 },
405 },
406 })
407
408 a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))
409
410 if err != nil {
411 isCancelErr := errors.Is(err, context.Canceled)
412 isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
413 if currentAssistant == nil {
414 return result, err
415 }
416 // Ensure we finish thinking on error to close the reasoning state.
417 currentAssistant.FinishThinking()
418 toolCalls := currentAssistant.ToolCalls()
419 // INFO: we use the parent context here because the genCtx has been cancelled.
420 msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
421 if createErr != nil {
422 return nil, createErr
423 }
424 for _, tc := range toolCalls {
425 if !tc.Finished {
426 tc.Finished = true
427 tc.Input = "{}"
428 currentAssistant.AddToolCall(tc)
429 updateErr := a.messages.Update(ctx, *currentAssistant)
430 if updateErr != nil {
431 return nil, updateErr
432 }
433 }
434
435 found := false
436 for _, msg := range msgs {
437 if msg.Role == message.Tool {
438 for _, tr := range msg.ToolResults() {
439 if tr.ToolCallID == tc.ID {
440 found = true
441 break
442 }
443 }
444 }
445 if found {
446 break
447 }
448 }
449 if found {
450 continue
451 }
452 content := "There was an error while executing the tool"
453 if isCancelErr {
454 content = "Tool execution canceled by user"
455 } else if isPermissionErr {
456 content = "User denied permission"
457 }
458 toolResult := message.ToolResult{
459 ToolCallID: tc.ID,
460 Name: tc.Name,
461 Content: content,
462 IsError: true,
463 }
464 _, createErr = a.messages.Create(ctx, currentAssistant.SessionID, message.CreateMessageParams{
465 Role: message.Tool,
466 Parts: []message.ContentPart{
467 toolResult,
468 },
469 })
470 if createErr != nil {
471 return nil, createErr
472 }
473 }
474 var fantasyErr *fantasy.Error
475 var providerErr *fantasy.ProviderError
476 const defaultTitle = "Provider Error"
477 linkStyle := lipgloss.NewStyle().Foreground(charmtone.Guac).Underline(true)
478 if isCancelErr {
479 currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
480 } else if isPermissionErr {
481 currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "")
482 } else if errors.Is(err, hyper.ErrNoCredits) {
483 url := hyper.BaseURL()
484 link := linkStyle.Hyperlink(url, "id=hyper").Render(url)
485 currentAssistant.AddFinish(message.FinishReasonError, "No credits", "You're out of credits. Add more at "+link)
486 } else if errors.As(err, &providerErr) {
487 if providerErr.Message == "The requested model is not supported." {
488 url := "https://github.com/settings/copilot/features"
489 link := linkStyle.Hyperlink(url, "id=copilot").Render(url)
490 currentAssistant.AddFinish(
491 message.FinishReasonError,
492 "Copilot model not enabled",
493 fmt.Sprintf("%q is not enabled in Copilot. Go to the following page to enable it. Then, wait 5 minutes before trying again. %s", largeModel.CatwalkCfg.Name, link),
494 )
495 } else {
496 currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message)
497 }
498 } else if errors.As(err, &fantasyErr) {
499 currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(fantasyErr.Title), defaultTitle), fantasyErr.Message)
500 } else {
501 currentAssistant.AddFinish(message.FinishReasonError, defaultTitle, err.Error())
502 }
503 // Note: we use the parent context here because the genCtx has been
504 // cancelled.
505 updateErr := a.messages.Update(ctx, *currentAssistant)
506 if updateErr != nil {
507 return nil, updateErr
508 }
509 return nil, err
510 }
511
512 if shouldSummarize {
513 a.activeRequests.Del(call.SessionID)
514 if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
515 return nil, summarizeErr
516 }
517 // If the agent wasn't done...
518 if len(currentAssistant.ToolCalls()) > 0 {
519 existing, ok := a.messageQueue.Get(call.SessionID)
520 if !ok {
521 existing = []SessionAgentCall{}
522 }
523 call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
524 existing = append(existing, call)
525 a.messageQueue.Set(call.SessionID, existing)
526 }
527 }
528
529 // Release active request before processing queued messages.
530 a.activeRequests.Del(call.SessionID)
531 cancel()
532
533 queuedMessages, ok := a.messageQueue.Get(call.SessionID)
534 if !ok || len(queuedMessages) == 0 {
535 return result, err
536 }
537 // There are queued messages restart the loop.
538 firstQueuedMessage := queuedMessages[0]
539 a.messageQueue.Set(call.SessionID, queuedMessages[1:])
540 return a.Run(ctx, firstQueuedMessage)
541}
542
543func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
544 if a.IsSessionBusy(sessionID) {
545 return ErrSessionBusy
546 }
547
548 // Copy mutable fields under lock to avoid races with SetModels.
549 largeModel := a.largeModel.Get()
550 systemPromptPrefix := a.systemPromptPrefix.Get()
551
552 currentSession, err := a.sessions.Get(ctx, sessionID)
553 if err != nil {
554 return fmt.Errorf("failed to get session: %w", err)
555 }
556 msgs, err := a.getSessionMessages(ctx, currentSession)
557 if err != nil {
558 return err
559 }
560 if len(msgs) == 0 {
561 // Nothing to summarize.
562 return nil
563 }
564
565 aiMsgs, _ := a.preparePrompt(msgs)
566
567 genCtx, cancel := context.WithCancel(ctx)
568 a.activeRequests.Set(sessionID, cancel)
569 defer a.activeRequests.Del(sessionID)
570 defer cancel()
571
572 agent := fantasy.NewAgent(largeModel.Model,
573 fantasy.WithSystemPrompt(string(summaryPrompt)),
574 )
575 summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
576 Role: message.Assistant,
577 Model: largeModel.Model.Model(),
578 Provider: largeModel.Model.Provider(),
579 IsSummaryMessage: true,
580 })
581 if err != nil {
582 return err
583 }
584
585 summaryPromptText := buildSummaryPrompt(currentSession.Todos)
586
587 resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
588 Prompt: summaryPromptText,
589 Messages: aiMsgs,
590 ProviderOptions: opts,
591 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
592 prepared.Messages = options.Messages
593 if systemPromptPrefix != "" {
594 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(systemPromptPrefix)}, prepared.Messages...)
595 }
596 return callContext, prepared, nil
597 },
598 OnReasoningDelta: func(id string, text string) error {
599 summaryMessage.AppendReasoningContent(text)
600 return a.messages.Update(genCtx, summaryMessage)
601 },
602 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
603 // Handle anthropic signature.
604 if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
605 if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
606 summaryMessage.AppendReasoningSignature(signature.Signature)
607 }
608 }
609 summaryMessage.FinishThinking()
610 return a.messages.Update(genCtx, summaryMessage)
611 },
612 OnTextDelta: func(id, text string) error {
613 summaryMessage.AppendContent(text)
614 return a.messages.Update(genCtx, summaryMessage)
615 },
616 })
617 if err != nil {
618 isCancelErr := errors.Is(err, context.Canceled)
619 if isCancelErr {
620 // User cancelled summarize we need to remove the summary message.
621 deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
622 return deleteErr
623 }
624 return err
625 }
626
627 summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
628 err = a.messages.Update(genCtx, summaryMessage)
629 if err != nil {
630 return err
631 }
632
633 var openrouterCost *float64
634 for _, step := range resp.Steps {
635 stepCost := a.openrouterCost(step.ProviderMetadata)
636 if stepCost != nil {
637 newCost := *stepCost
638 if openrouterCost != nil {
639 newCost += *openrouterCost
640 }
641 openrouterCost = &newCost
642 }
643 }
644
645 a.updateSessionUsage(largeModel, ¤tSession, resp.TotalUsage, openrouterCost)
646
647 // Just in case, get just the last usage info.
648 usage := resp.Response.Usage
649 currentSession.SummaryMessageID = summaryMessage.ID
650 currentSession.CompletionTokens = usage.OutputTokens
651 currentSession.PromptTokens = 0
652 _, err = a.sessions.Save(genCtx, currentSession)
653 return err
654}
655
656func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
657 if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
658 return fantasy.ProviderOptions{}
659 }
660 return fantasy.ProviderOptions{
661 anthropic.Name: &anthropic.ProviderCacheControlOptions{
662 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
663 },
664 bedrock.Name: &anthropic.ProviderCacheControlOptions{
665 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
666 },
667 }
668}
669
670func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
671 parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
672 var attachmentParts []message.ContentPart
673 for _, attachment := range call.Attachments {
674 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
675 }
676 parts = append(parts, attachmentParts...)
677 msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
678 Role: message.User,
679 Parts: parts,
680 })
681 if err != nil {
682 return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
683 }
684 return msg, nil
685}
686
687func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
688 var history []fantasy.Message
689 if !a.isSubAgent {
690 history = append(history, fantasy.NewUserMessage(
691 fmt.Sprintf("<system_reminder>%s</system_reminder>",
692 `This is a reminder that your todo list is currently empty. DO NOT mention this to the user explicitly because they are already aware.
693If you are working on tasks that would benefit from a todo list please use the "todos" tool to create one.
694If not, please feel free to ignore. Again do not mention this message to the user.`,
695 ),
696 ))
697 }
698 for _, m := range msgs {
699 if len(m.Parts) == 0 {
700 continue
701 }
702 // Assistant message without content or tool calls (cancelled before it
703 // returned anything).
704 if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
705 continue
706 }
707 history = append(history, m.ToAIMessage()...)
708 }
709
710 var files []fantasy.FilePart
711 for _, attachment := range attachments {
712 if attachment.IsText() {
713 continue
714 }
715 files = append(files, fantasy.FilePart{
716 Filename: attachment.FileName,
717 Data: attachment.Content,
718 MediaType: attachment.MimeType,
719 })
720 }
721
722 return history, files
723}
724
725func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
726 msgs, err := a.messages.List(ctx, session.ID)
727 if err != nil {
728 return nil, fmt.Errorf("failed to list messages: %w", err)
729 }
730
731 if session.SummaryMessageID != "" {
732 summaryMsgIndex := -1
733 for i, msg := range msgs {
734 if msg.ID == session.SummaryMessageID {
735 summaryMsgIndex = i
736 break
737 }
738 }
739 if summaryMsgIndex != -1 {
740 msgs = msgs[summaryMsgIndex:]
741 msgs[0].Role = message.User
742 }
743 }
744 return msgs, nil
745}
746
747// generateTitle generates a session titled based on the initial prompt.
748func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, userPrompt string) {
749 if userPrompt == "" {
750 return
751 }
752
753 smallModel := a.smallModel.Get()
754 largeModel := a.largeModel.Get()
755 systemPromptPrefix := a.systemPromptPrefix.Get()
756
757 var maxOutputTokens int64 = 40
758 if smallModel.CatwalkCfg.CanReason {
759 maxOutputTokens = smallModel.CatwalkCfg.DefaultMaxTokens
760 }
761
762 newAgent := func(m fantasy.LanguageModel, p []byte, tok int64) fantasy.Agent {
763 return fantasy.NewAgent(m,
764 fantasy.WithSystemPrompt(string(p)+"\n /no_think"),
765 fantasy.WithMaxOutputTokens(tok),
766 )
767 }
768
769 streamCall := fantasy.AgentStreamCall{
770 Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", userPrompt),
771 PrepareStep: func(callCtx context.Context, opts fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
772 prepared.Messages = opts.Messages
773 if systemPromptPrefix != "" {
774 prepared.Messages = append([]fantasy.Message{
775 fantasy.NewSystemMessage(systemPromptPrefix),
776 }, prepared.Messages...)
777 }
778 return callCtx, prepared, nil
779 },
780 }
781
782 // Use the small model to generate the title.
783 model := smallModel
784 agent := newAgent(model.Model, titlePrompt, maxOutputTokens)
785 resp, err := agent.Stream(ctx, streamCall)
786 if err == nil {
787 // We successfully generated a title with the small model.
788 slog.Info("generated title with small model")
789 } else {
790 // It didn't work. Let's try with the big model.
791 slog.Error("error generating title with small model; trying big model", "err", err)
792 model = largeModel
793 agent = newAgent(model.Model, titlePrompt, maxOutputTokens)
794 resp, err = agent.Stream(ctx, streamCall)
795 if err == nil {
796 slog.Info("generated title with large model")
797 } else {
798 // Welp, the large model didn't work either. Use the default
799 // session name and return.
800 slog.Error("error generating title with large model", "err", err)
801 saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, defaultSessionName, 0, 0, 0)
802 if saveErr != nil {
803 slog.Error("failed to save session title and usage", "error", saveErr)
804 }
805 return
806 }
807 }
808
809 if resp == nil {
810 // Actually, we didn't get a response so we can't. Use the default
811 // session name and return.
812 slog.Error("response is nil; can't generate title")
813 saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, defaultSessionName, 0, 0, 0)
814 if saveErr != nil {
815 slog.Error("failed to save session title and usage", "error", saveErr)
816 }
817 return
818 }
819
820 // Clean up title.
821 var title string
822 title = strings.ReplaceAll(resp.Response.Content.Text(), "\n", " ")
823
824 // Remove thinking tags if present.
825 title = thinkTagRegex.ReplaceAllString(title, "")
826
827 title = strings.TrimSpace(title)
828 if title == "" {
829 slog.Warn("empty title; using fallback")
830 title = defaultSessionName
831 }
832
833 // Calculate usage and cost.
834 var openrouterCost *float64
835 for _, step := range resp.Steps {
836 stepCost := a.openrouterCost(step.ProviderMetadata)
837 if stepCost != nil {
838 newCost := *stepCost
839 if openrouterCost != nil {
840 newCost += *openrouterCost
841 }
842 openrouterCost = &newCost
843 }
844 }
845
846 modelConfig := model.CatwalkCfg
847 cost := modelConfig.CostPer1MInCached/1e6*float64(resp.TotalUsage.CacheCreationTokens) +
848 modelConfig.CostPer1MOutCached/1e6*float64(resp.TotalUsage.CacheReadTokens) +
849 modelConfig.CostPer1MIn/1e6*float64(resp.TotalUsage.InputTokens) +
850 modelConfig.CostPer1MOut/1e6*float64(resp.TotalUsage.OutputTokens)
851
852 // Use override cost if available (e.g., from OpenRouter).
853 if openrouterCost != nil {
854 cost = *openrouterCost
855 }
856
857 promptTokens := resp.TotalUsage.InputTokens + resp.TotalUsage.CacheCreationTokens
858 completionTokens := resp.TotalUsage.OutputTokens + resp.TotalUsage.CacheReadTokens
859
860 // Atomically update only title and usage fields to avoid overriding other
861 // concurrent session updates.
862 saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, title, promptTokens, completionTokens, cost)
863 if saveErr != nil {
864 slog.Error("failed to save session title and usage", "error", saveErr)
865 return
866 }
867}
868
869func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
870 openrouterMetadata, ok := metadata[openrouter.Name]
871 if !ok {
872 return nil
873 }
874
875 opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
876 if !ok {
877 return nil
878 }
879 return &opts.Usage.Cost
880}
881
882func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
883 modelConfig := model.CatwalkCfg
884 cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
885 modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
886 modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
887 modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
888
889 a.eventTokensUsed(session.ID, model, usage, cost)
890
891 if overrideCost != nil {
892 session.Cost += *overrideCost
893 } else {
894 session.Cost += cost
895 }
896
897 session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
898 session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
899}
900
901func (a *sessionAgent) Cancel(sessionID string) {
902 // Cancel regular requests. Don't use Take() here - we need the entry to
903 // remain in activeRequests so IsBusy() returns true until the goroutine
904 // fully completes (including error handling that may access the DB).
905 // The defer in processRequest will clean up the entry.
906 if cancel, ok := a.activeRequests.Get(sessionID); ok && cancel != nil {
907 slog.Info("Request cancellation initiated", "session_id", sessionID)
908 cancel()
909 }
910
911 // Also check for summarize requests.
912 if cancel, ok := a.activeRequests.Get(sessionID + "-summarize"); ok && cancel != nil {
913 slog.Info("Summarize cancellation initiated", "session_id", sessionID)
914 cancel()
915 }
916
917 if a.QueuedPrompts(sessionID) > 0 {
918 slog.Info("Clearing queued prompts", "session_id", sessionID)
919 a.messageQueue.Del(sessionID)
920 }
921}
922
923func (a *sessionAgent) ClearQueue(sessionID string) {
924 if a.QueuedPrompts(sessionID) > 0 {
925 slog.Info("Clearing queued prompts", "session_id", sessionID)
926 a.messageQueue.Del(sessionID)
927 }
928}
929
930func (a *sessionAgent) CancelAll() {
931 if !a.IsBusy() {
932 return
933 }
934 for key := range a.activeRequests.Seq2() {
935 a.Cancel(key) // key is sessionID
936 }
937
938 timeout := time.After(5 * time.Second)
939 for a.IsBusy() {
940 select {
941 case <-timeout:
942 return
943 default:
944 time.Sleep(200 * time.Millisecond)
945 }
946 }
947}
948
949func (a *sessionAgent) IsBusy() bool {
950 var busy bool
951 for cancelFunc := range a.activeRequests.Seq() {
952 if cancelFunc != nil {
953 busy = true
954 break
955 }
956 }
957 return busy
958}
959
960func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
961 _, busy := a.activeRequests.Get(sessionID)
962 return busy
963}
964
965func (a *sessionAgent) QueuedPrompts(sessionID string) int {
966 l, ok := a.messageQueue.Get(sessionID)
967 if !ok {
968 return 0
969 }
970 return len(l)
971}
972
973func (a *sessionAgent) QueuedPromptsList(sessionID string) []string {
974 l, ok := a.messageQueue.Get(sessionID)
975 if !ok {
976 return nil
977 }
978 prompts := make([]string, len(l))
979 for i, call := range l {
980 prompts[i] = call.Prompt
981 }
982 return prompts
983}
984
985func (a *sessionAgent) SetModels(large Model, small Model) {
986 a.largeModel.Set(large)
987 a.smallModel.Set(small)
988}
989
990func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
991 a.tools.SetSlice(tools)
992}
993
994func (a *sessionAgent) SetSystemPrompt(systemPrompt string) {
995 a.systemPrompt.Set(systemPrompt)
996}
997
998func (a *sessionAgent) Model() Model {
999 return a.largeModel.Get()
1000}
1001
1002// convertToToolResult converts a fantasy tool result to a message tool result.
1003func (a *sessionAgent) convertToToolResult(result fantasy.ToolResultContent) message.ToolResult {
1004 baseResult := message.ToolResult{
1005 ToolCallID: result.ToolCallID,
1006 Name: result.ToolName,
1007 Metadata: result.ClientMetadata,
1008 }
1009
1010 switch result.Result.GetType() {
1011 case fantasy.ToolResultContentTypeText:
1012 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result); ok {
1013 baseResult.Content = r.Text
1014 }
1015 case fantasy.ToolResultContentTypeError:
1016 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result); ok {
1017 baseResult.Content = r.Error.Error()
1018 baseResult.IsError = true
1019 }
1020 case fantasy.ToolResultContentTypeMedia:
1021 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](result.Result); ok {
1022 content := r.Text
1023 if content == "" {
1024 content = fmt.Sprintf("Loaded %s content", r.MediaType)
1025 }
1026 baseResult.Content = content
1027 baseResult.Data = r.Data
1028 baseResult.MIMEType = r.MediaType
1029 }
1030 }
1031
1032 return baseResult
1033}
1034
1035// workaroundProviderMediaLimitations converts media content in tool results to
1036// user messages for providers that don't natively support images in tool results.
1037//
1038// Problem: OpenAI, Google, OpenRouter, and other OpenAI-compatible providers
1039// don't support sending images/media in tool result messages - they only accept
1040// text in tool results. However, they DO support images in user messages.
1041//
1042// If we send media in tool results to these providers, the API returns an error.
1043//
1044// Solution: For these providers, we:
1045// 1. Replace the media in the tool result with a text placeholder
1046// 2. Inject a user message immediately after with the image as a file attachment
1047// 3. This maintains the tool execution flow while working around API limitations
1048//
1049// Anthropic and Bedrock support images natively in tool results, so we skip
1050// this workaround for them.
1051//
1052// Example transformation:
1053//
1054// BEFORE: [tool result: image data]
1055// AFTER: [tool result: "Image loaded - see attached"], [user: image attachment]
1056func (a *sessionAgent) workaroundProviderMediaLimitations(messages []fantasy.Message, largeModel Model) []fantasy.Message {
1057 providerSupportsMedia := largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderAnthropic) ||
1058 largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderBedrock)
1059
1060 if providerSupportsMedia {
1061 return messages
1062 }
1063
1064 convertedMessages := make([]fantasy.Message, 0, len(messages))
1065
1066 for _, msg := range messages {
1067 if msg.Role != fantasy.MessageRoleTool {
1068 convertedMessages = append(convertedMessages, msg)
1069 continue
1070 }
1071
1072 textParts := make([]fantasy.MessagePart, 0, len(msg.Content))
1073 var mediaFiles []fantasy.FilePart
1074
1075 for _, part := range msg.Content {
1076 toolResult, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
1077 if !ok {
1078 textParts = append(textParts, part)
1079 continue
1080 }
1081
1082 if media, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](toolResult.Output); ok {
1083 decoded, err := base64.StdEncoding.DecodeString(media.Data)
1084 if err != nil {
1085 slog.Warn("failed to decode media data", "error", err)
1086 textParts = append(textParts, part)
1087 continue
1088 }
1089
1090 mediaFiles = append(mediaFiles, fantasy.FilePart{
1091 Data: decoded,
1092 MediaType: media.MediaType,
1093 Filename: fmt.Sprintf("tool-result-%s", toolResult.ToolCallID),
1094 })
1095
1096 textParts = append(textParts, fantasy.ToolResultPart{
1097 ToolCallID: toolResult.ToolCallID,
1098 Output: fantasy.ToolResultOutputContentText{
1099 Text: "[Image/media content loaded - see attached file]",
1100 },
1101 ProviderOptions: toolResult.ProviderOptions,
1102 })
1103 } else {
1104 textParts = append(textParts, part)
1105 }
1106 }
1107
1108 convertedMessages = append(convertedMessages, fantasy.Message{
1109 Role: fantasy.MessageRoleTool,
1110 Content: textParts,
1111 })
1112
1113 if len(mediaFiles) > 0 {
1114 convertedMessages = append(convertedMessages, fantasy.NewUserMessage(
1115 "Here is the media content from the tool result:",
1116 mediaFiles...,
1117 ))
1118 }
1119 }
1120
1121 return convertedMessages
1122}
1123
1124// buildSummaryPrompt constructs the prompt text for session summarization.
1125func buildSummaryPrompt(todos []session.Todo) string {
1126 var sb strings.Builder
1127 sb.WriteString("Provide a detailed summary of our conversation above.")
1128 if len(todos) > 0 {
1129 sb.WriteString("\n\n## Current Todo List\n\n")
1130 for _, t := range todos {
1131 fmt.Fprintf(&sb, "- [%s] %s\n", t.Status, t.Content)
1132 }
1133 sb.WriteString("\nInclude these tasks and their statuses in your summary. ")
1134 sb.WriteString("Instruct the resuming assistant to use the `todos` tool to continue tracking progress on these tasks.")
1135 }
1136 return sb.String()
1137}