1// Package agent is the core orchestration layer for Crush AI agents.
2//
3// It provides session-based AI agent functionality for managing
4// conversations, tool execution, and message handling. It coordinates
5// interactions between language models, messages, sessions, and tools while
6// handling features like automatic summarization, queuing, and token
7// management.
8package agent
9
10import (
11 "cmp"
12 "context"
13 _ "embed"
14 "encoding/base64"
15 "errors"
16 "fmt"
17 "log/slog"
18 "os"
19 "regexp"
20 "strconv"
21 "strings"
22 "sync"
23 "time"
24
25 "charm.land/fantasy"
26 "charm.land/fantasy/providers/anthropic"
27 "charm.land/fantasy/providers/bedrock"
28 "charm.land/fantasy/providers/google"
29 "charm.land/fantasy/providers/openai"
30 "charm.land/fantasy/providers/openrouter"
31 "charm.land/lipgloss/v2"
32 "github.com/charmbracelet/catwalk/pkg/catwalk"
33 "github.com/charmbracelet/crush/internal/agent/hyper"
34 "github.com/charmbracelet/crush/internal/agent/tools"
35 "github.com/charmbracelet/crush/internal/config"
36 "github.com/charmbracelet/crush/internal/csync"
37 "github.com/charmbracelet/crush/internal/message"
38 "github.com/charmbracelet/crush/internal/permission"
39 "github.com/charmbracelet/crush/internal/session"
40 "github.com/charmbracelet/crush/internal/stringext"
41 "github.com/charmbracelet/x/exp/charmtone"
42)
43
const (
	// defaultSessionName is the fallback title used when title generation
	// fails or produces an empty result.
	defaultSessionName = "Untitled Session"

	// Auto-summarization thresholds: when the tokens remaining in the
	// model's context window fall below a threshold, the session is
	// summarized (see the StopWhen condition in Run).
	largeContextWindowThreshold = 200_000 // windows larger than this use a fixed buffer
	largeContextWindowBuffer    = 20_000  // tokens kept free on large windows
	smallContextWindowRatio     = 0.2     // fraction of the window kept free on small windows
)
52
// titlePrompt is the embedded system prompt used to generate session titles.
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt is the embedded system prompt used to summarize sessions.
//
//go:embed templates/summary.md
var summaryPrompt []byte

// thinkTagRegex is used to remove <think>…</think> tags from generated
// titles (some reasoning models leak them into the output text).
var thinkTagRegex = regexp.MustCompile(`<think>.*?</think>`)
61
// SessionAgentCall describes a single prompt submitted to a session agent,
// including sampling parameters forwarded to the underlying model.
type SessionAgentCall struct {
	SessionID       string                  // target session; required
	Prompt          string                  // user prompt text; may be empty if a text attachment is present
	ProviderOptions fantasy.ProviderOptions // provider-specific options passed through to the stream call
	Attachments     []message.Attachment    // text and binary attachments accompanying the prompt
	MaxOutputTokens int64
	// Optional sampling parameters; nil means provider default.
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
}
74
// SessionAgent runs prompts against sessions, queuing concurrent prompts for
// the same session and supporting cancellation and summarization.
type SessionAgent interface {
	// Run executes a prompt. If the session is busy the call is queued and
	// Run returns (nil, nil); queued prompts are drained when the active
	// request finishes.
	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
	// SetModels swaps the large and small models used for new requests.
	SetModels(large Model, small Model)
	// SetTools replaces the tool set used for new requests.
	SetTools(tools []fantasy.AgentTool)
	// SetSystemPrompt replaces the system prompt used for new requests.
	SetSystemPrompt(systemPrompt string)
	// Cancel stops the active request (and summarize, if any) for a session
	// and clears its queue.
	Cancel(sessionID string)
	// CancelAll cancels every active session and waits (bounded) for them
	// to finish.
	CancelAll()
	IsSessionBusy(sessionID string) bool
	IsBusy() bool
	// QueuedPrompts reports the number of prompts queued for a session.
	QueuedPrompts(sessionID string) int
	// QueuedPromptsList returns the queued prompt texts for a session.
	QueuedPromptsList(sessionID string) []string
	ClearQueue(sessionID string)
	// Summarize compacts the session history into a summary message.
	Summarize(context.Context, string, fantasy.ProviderOptions) error
	Model() Model
}
90
// Model bundles a fantasy language model with its catwalk metadata (context
// window, pricing, capabilities) and the user's selected-model config.
type Model struct {
	Model      fantasy.LanguageModel
	CatwalkCfg catwalk.Model
	ModelCfg   config.SelectedModel
}
96
// sessionAgent is the default SessionAgent implementation. Mutable fields
// are wrapped in csync containers so they can be read and swapped from
// concurrent goroutines.
type sessionAgent struct {
	largeModel         *csync.Value[Model] // primary model for conversation turns
	smallModel         *csync.Value[Model] // cheaper model for auxiliary work (e.g. titles)
	systemPromptPrefix *csync.Value[string]
	systemPrompt       *csync.Value[string]
	tools              *csync.Slice[fantasy.AgentTool]

	isSubAgent           bool // sub-agents skip the todo-list system reminder
	sessions             session.Service
	messages             message.Service
	disableAutoSummarize bool
	isYolo               bool

	// messageQueue holds prompts submitted while a session was busy,
	// keyed by session ID.
	messageQueue *csync.Map[string, []SessionAgentCall]
	// activeRequests maps a session ID to the cancel func of its in-flight
	// request; presence of a key means the session is busy.
	activeRequests *csync.Map[string, context.CancelFunc]
}
113
// SessionAgentOptions configures NewSessionAgent.
type SessionAgentOptions struct {
	LargeModel           Model
	SmallModel           Model
	SystemPromptPrefix   string // optional extra system message prepended to every request
	SystemPrompt         string
	IsSubAgent           bool
	DisableAutoSummarize bool
	IsYolo               bool
	Sessions             session.Service
	Messages             message.Service
	Tools                []fantasy.AgentTool
}
126
127func NewSessionAgent(
128 opts SessionAgentOptions,
129) SessionAgent {
130 return &sessionAgent{
131 largeModel: csync.NewValue(opts.LargeModel),
132 smallModel: csync.NewValue(opts.SmallModel),
133 systemPromptPrefix: csync.NewValue(opts.SystemPromptPrefix),
134 systemPrompt: csync.NewValue(opts.SystemPrompt),
135 isSubAgent: opts.IsSubAgent,
136 sessions: opts.Sessions,
137 messages: opts.Messages,
138 disableAutoSummarize: opts.DisableAutoSummarize,
139 tools: csync.NewSliceFrom(opts.Tools),
140 isYolo: opts.IsYolo,
141 messageQueue: csync.NewMap[string, []SessionAgentCall](),
142 activeRequests: csync.NewMap[string, context.CancelFunc](),
143 }
144}
145
// Run executes a prompt for a session. If the session already has an active
// request the call is queued and (nil, nil) is returned. Otherwise it streams
// a response from the large model, persisting assistant/tool messages as they
// arrive, and finally drains any prompts queued during the run by recursing.
func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
	if call.Prompt == "" && !message.ContainsTextAttachment(call.Attachments) {
		return nil, ErrEmptyPrompt
	}
	if call.SessionID == "" {
		return nil, ErrSessionMissing
	}

	// Queue the message if the session is busy; it is picked up either by
	// the in-flight request's PrepareStep or after it completes.
	if a.IsSessionBusy(call.SessionID) {
		existing, ok := a.messageQueue.Get(call.SessionID)
		if !ok {
			existing = []SessionAgentCall{}
		}
		existing = append(existing, call)
		a.messageQueue.Set(call.SessionID, existing)
		return nil, nil
	}

	// Copy mutable fields under lock to avoid races with SetTools/SetModels.
	agentTools := a.tools.Copy()
	largeModel := a.largeModel.Get()
	systemPrompt := a.systemPrompt.Get()
	promptPrefix := a.systemPromptPrefix.Get()

	if len(agentTools) > 0 {
		// Add Anthropic caching to the last tool.
		agentTools[len(agentTools)-1].SetProviderOptions(a.getCacheControlOptions())
	}

	agent := fantasy.NewAgent(
		largeModel.Model,
		fantasy.WithSystemPrompt(systemPrompt),
		fantasy.WithTools(agentTools...),
	)

	// sessionLock serializes the read-modify-write of session usage done in
	// OnStepFinish.
	sessionLock := sync.Mutex{}
	currentSession, err := a.sessions.Get(ctx, call.SessionID)
	if err != nil {
		return nil, fmt.Errorf("failed to get session: %w", err)
	}

	msgs, err := a.getSessionMessages(ctx, currentSession)
	if err != nil {
		return nil, fmt.Errorf("failed to get session messages: %w", err)
	}

	var wg sync.WaitGroup
	// Generate a title in the background if this is the session's first
	// message.
	if len(msgs) == 0 {
		titleCtx := ctx // Copy to avoid race with ctx reassignment below.
		wg.Go(func() {
			a.generateTitle(titleCtx, call.SessionID, call.Prompt)
		})
	}
	defer wg.Wait()

	// Add the user message to the session.
	_, err = a.createUserMessage(ctx, call)
	if err != nil {
		return nil, err
	}

	// Add the session ID to the context so tools can access it.
	ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)

	genCtx, cancel := context.WithCancel(ctx)
	a.activeRequests.Set(call.SessionID, cancel)

	defer cancel()
	defer a.activeRequests.Del(call.SessionID)

	history, files := a.preparePrompt(msgs, call.Attachments...)

	startTime := time.Now()
	a.eventPromptSent(call.SessionID)

	var currentAssistant *message.Message
	var shouldSummarize bool
	result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
		Prompt:           message.PromptWithTextAttachments(call.Prompt, call.Attachments),
		Files:            files,
		Messages:         history,
		ProviderOptions:  call.ProviderOptions,
		MaxOutputTokens:  &call.MaxOutputTokens,
		TopP:             call.TopP,
		Temperature:      call.Temperature,
		PresencePenalty:  call.PresencePenalty,
		TopK:             call.TopK,
		FrequencyPenalty: call.FrequencyPenalty,
		// PrepareStep runs before each model step: it injects prompts that
		// were queued mid-run, applies cache-control options, prepends the
		// system prompt prefix, and creates the assistant message that the
		// streaming callbacks below append to.
		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
			prepared.Messages = options.Messages
			// Clear stale provider options; cache control is re-applied
			// below to the messages that should carry it.
			for i := range prepared.Messages {
				prepared.Messages[i].ProviderOptions = nil
			}

			// Drain any prompts queued while this request has been running
			// and append them to the step's message list.
			queuedCalls, _ := a.messageQueue.Get(call.SessionID)
			a.messageQueue.Del(call.SessionID)
			for _, queued := range queuedCalls {
				userMessage, createErr := a.createUserMessage(callContext, queued)
				if createErr != nil {
					return callContext, prepared, createErr
				}
				prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
			}

			prepared.Messages = a.workaroundProviderMediaLimitations(prepared.Messages, largeModel)

			// Apply cache control to the last system message (done once,
			// when the first non-system message is seen)...
			lastSystemRoleInx := 0
			systemMessageUpdated := false
			for i, msg := range prepared.Messages {
				if msg.Role == fantasy.MessageRoleSystem {
					lastSystemRoleInx = i
				} else if !systemMessageUpdated {
					prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
					systemMessageUpdated = true
				}
				// ...then add cache control to the last 2 messages.
				if i > len(prepared.Messages)-3 {
					prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
				}
			}

			if promptPrefix != "" {
				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(promptPrefix)}, prepared.Messages...)
			}

			// Create the (initially empty) assistant message this step will
			// stream into, and expose its ID and model info to tools.
			var assistantMsg message.Message
			assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
				Role:     message.Assistant,
				Parts:    []message.ContentPart{},
				Model:    largeModel.ModelCfg.Model,
				Provider: largeModel.ModelCfg.Provider,
			})
			if err != nil {
				return callContext, prepared, err
			}
			callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
			callContext = context.WithValue(callContext, tools.SupportsImagesContextKey, largeModel.CatwalkCfg.SupportsImages)
			callContext = context.WithValue(callContext, tools.ModelNameContextKey, largeModel.CatwalkCfg.Name)
			currentAssistant = &assistantMsg
			return callContext, prepared, err
		},
		OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
			currentAssistant.AppendReasoningContent(reasoning.Text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnReasoningDelta: func(id string, text string) error {
			currentAssistant.AppendReasoningContent(text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
			// Persist provider-specific reasoning metadata (signatures etc.)
			// so it can be sent back on subsequent requests.
			// handle anthropic signature
			if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
				if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
					currentAssistant.AppendReasoningSignature(reasoning.Signature)
				}
			}
			if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
				if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
					currentAssistant.AppendThoughtSignature(reasoning.Signature, reasoning.ToolID)
				}
			}
			if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
				if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
					currentAssistant.SetReasoningResponsesData(reasoning)
				}
			}
			currentAssistant.FinishThinking()
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnTextDelta: func(id string, text string) error {
			// Strip leading newline from initial text content. This is
			// particularly important in non-interactive mode where leading
			// newlines are very visible.
			if len(currentAssistant.Parts) == 0 {
				text = strings.TrimPrefix(text, "\n")
			}

			currentAssistant.AppendContent(text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnToolInputStart: func(id string, toolName string) error {
			// Record the tool call as soon as its name is known; the input
			// is filled in by OnToolCall once complete.
			toolCall := message.ToolCall{
				ID:               id,
				Name:             toolName,
				ProviderExecuted: false,
				Finished:         false,
			}
			currentAssistant.AddToolCall(toolCall)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnRetry: func(err *fantasy.ProviderError, delay time.Duration) {
			// TODO: implement
		},
		OnToolCall: func(tc fantasy.ToolCallContent) error {
			toolCall := message.ToolCall{
				ID:               tc.ToolCallID,
				Name:             tc.ToolName,
				Input:            tc.Input,
				ProviderExecuted: false,
				Finished:         true,
			}
			currentAssistant.AddToolCall(toolCall)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnToolResult: func(result fantasy.ToolResultContent) error {
			// Tool results are stored as separate messages with the Tool role.
			toolResult := a.convertToToolResult(result)
			_, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
				Role: message.Tool,
				Parts: []message.ContentPart{
					toolResult,
				},
			})
			return createMsgErr
		},
		OnStepFinish: func(stepResult fantasy.StepResult) error {
			// Map the provider finish reason onto our message model, then
			// fold the step's token usage/cost into the session record.
			finishReason := message.FinishReasonUnknown
			switch stepResult.FinishReason {
			case fantasy.FinishReasonLength:
				finishReason = message.FinishReasonMaxTokens
			case fantasy.FinishReasonStop:
				finishReason = message.FinishReasonEndTurn
			case fantasy.FinishReasonToolCalls:
				finishReason = message.FinishReasonToolUse
			}
			currentAssistant.AddFinish(finishReason, "", "")
			sessionLock.Lock()
			updatedSession, getSessionErr := a.sessions.Get(genCtx, call.SessionID)
			if getSessionErr != nil {
				sessionLock.Unlock()
				return getSessionErr
			}
			a.updateSessionUsage(largeModel, &updatedSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
			_, sessionErr := a.sessions.Save(genCtx, updatedSession)
			sessionLock.Unlock()
			if sessionErr != nil {
				return sessionErr
			}
			return a.messages.Update(genCtx, *currentAssistant)
		},
		StopWhen: []fantasy.StopCondition{
			// Stop the agent loop early when the context window is nearly
			// full so the session can be summarized (unless disabled).
			func(_ []fantasy.StepResult) bool {
				cw := int64(largeModel.CatwalkCfg.ContextWindow)
				tokens := currentSession.CompletionTokens + currentSession.PromptTokens
				remaining := cw - tokens
				var threshold int64
				if cw > largeContextWindowThreshold {
					threshold = largeContextWindowBuffer
				} else {
					threshold = int64(float64(cw) * smallContextWindowRatio)
				}
				if (remaining <= threshold) && !a.disableAutoSummarize {
					shouldSummarize = true
					return true
				}
				return false
			},
		},
	})

	a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))

	if err != nil {
		isCancelErr := errors.Is(err, context.Canceled)
		isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
		if currentAssistant == nil {
			return result, err
		}
		// Ensure we finish thinking on error to close the reasoning state.
		currentAssistant.FinishThinking()
		toolCalls := currentAssistant.ToolCalls()
		// INFO: we use the parent context here because the genCtx has been cancelled.
		msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
		if createErr != nil {
			return nil, createErr
		}
		// Every recorded tool call must have a matching tool result, or the
		// conversation history becomes invalid for the next request. Close
		// out unfinished calls and synthesize error results for calls that
		// never produced one.
		for _, tc := range toolCalls {
			if !tc.Finished {
				tc.Finished = true
				tc.Input = "{}"
				currentAssistant.AddToolCall(tc)
				updateErr := a.messages.Update(ctx, *currentAssistant)
				if updateErr != nil {
					return nil, updateErr
				}
			}

			// Check whether a result for this tool call already exists.
			found := false
			for _, msg := range msgs {
				if msg.Role == message.Tool {
					for _, tr := range msg.ToolResults() {
						if tr.ToolCallID == tc.ID {
							found = true
							break
						}
					}
				}
				if found {
					break
				}
			}
			if found {
				continue
			}
			content := "There was an error while executing the tool"
			if isCancelErr {
				content = "Tool execution canceled by user"
			} else if isPermissionErr {
				content = "User denied permission"
			}
			toolResult := message.ToolResult{
				ToolCallID: tc.ID,
				Name:       tc.Name,
				Content:    content,
				IsError:    true,
			}
			_, createErr = a.messages.Create(ctx, currentAssistant.SessionID, message.CreateMessageParams{
				Role: message.Tool,
				Parts: []message.ContentPart{
					toolResult,
				},
			})
			if createErr != nil {
				return nil, createErr
			}
		}
		// Translate the error into a finish state on the assistant message,
		// with friendlier messaging for known error classes.
		var fantasyErr *fantasy.Error
		var providerErr *fantasy.ProviderError
		const defaultTitle = "Provider Error"
		linkStyle := lipgloss.NewStyle().Foreground(charmtone.Guac).Underline(true)
		if isCancelErr {
			currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
		} else if isPermissionErr {
			currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "")
		} else if errors.Is(err, hyper.ErrNoCredits) {
			url := hyper.BaseURL()
			link := linkStyle.Hyperlink(url, "id=hyper").Render(url)
			currentAssistant.AddFinish(message.FinishReasonError, "No credits", "You're out of credits. Add more at "+link)
		} else if errors.As(err, &providerErr) {
			if providerErr.Message == "The requested model is not supported." {
				url := "https://github.com/settings/copilot/features"
				link := linkStyle.Hyperlink(url, "id=copilot").Render(url)
				currentAssistant.AddFinish(
					message.FinishReasonError,
					"Copilot model not enabled",
					fmt.Sprintf("%q is not enabled in Copilot. Go to the following page to enable it. Then, wait 5 minutes before trying again. %s", largeModel.CatwalkCfg.Name, link),
				)
			} else {
				currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message)
			}
		} else if errors.As(err, &fantasyErr) {
			currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(fantasyErr.Title), defaultTitle), fantasyErr.Message)
		} else {
			currentAssistant.AddFinish(message.FinishReasonError, defaultTitle, err.Error())
		}
		// Note: we use the parent context here because the genCtx has been
		// cancelled.
		updateErr := a.messages.Update(ctx, *currentAssistant)
		if updateErr != nil {
			return nil, updateErr
		}
		return nil, err
	}

	if shouldSummarize {
		// Summarize checks IsSessionBusy, so release the active request
		// entry first.
		a.activeRequests.Del(call.SessionID)
		if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
			return nil, summarizeErr
		}
		// If the agent wasn't done (it still had pending tool calls),
		// re-queue the original request so work continues after the
		// summarization.
		if len(currentAssistant.ToolCalls()) > 0 {
			existing, ok := a.messageQueue.Get(call.SessionID)
			if !ok {
				existing = []SessionAgentCall{}
			}
			call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
			existing = append(existing, call)
			a.messageQueue.Set(call.SessionID, existing)
		}
	}

	// Release active request before processing queued messages.
	a.activeRequests.Del(call.SessionID)
	cancel()

	queuedMessages, ok := a.messageQueue.Get(call.SessionID)
	if !ok || len(queuedMessages) == 0 {
		return result, err
	}
	// There are queued messages; restart the loop with the first one.
	firstQueuedMessage := queuedMessages[0]
	a.messageQueue.Set(call.SessionID, queuedMessages[1:])
	return a.Run(ctx, firstQueuedMessage)
}
542
543func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
544 if a.IsSessionBusy(sessionID) {
545 return ErrSessionBusy
546 }
547
548 // Copy mutable fields under lock to avoid races with SetModels.
549 largeModel := a.largeModel.Get()
550 systemPromptPrefix := a.systemPromptPrefix.Get()
551
552 currentSession, err := a.sessions.Get(ctx, sessionID)
553 if err != nil {
554 return fmt.Errorf("failed to get session: %w", err)
555 }
556 msgs, err := a.getSessionMessages(ctx, currentSession)
557 if err != nil {
558 return err
559 }
560 if len(msgs) == 0 {
561 // Nothing to summarize.
562 return nil
563 }
564
565 aiMsgs, _ := a.preparePrompt(msgs)
566
567 genCtx, cancel := context.WithCancel(ctx)
568 a.activeRequests.Set(sessionID, cancel)
569 defer a.activeRequests.Del(sessionID)
570 defer cancel()
571
572 agent := fantasy.NewAgent(largeModel.Model,
573 fantasy.WithSystemPrompt(string(summaryPrompt)),
574 )
575 summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
576 Role: message.Assistant,
577 Model: largeModel.Model.Model(),
578 Provider: largeModel.Model.Provider(),
579 IsSummaryMessage: true,
580 })
581 if err != nil {
582 return err
583 }
584
585 summaryPromptText := buildSummaryPrompt(currentSession.Todos)
586
587 resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
588 Prompt: summaryPromptText,
589 Messages: aiMsgs,
590 ProviderOptions: opts,
591 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
592 prepared.Messages = options.Messages
593 if systemPromptPrefix != "" {
594 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(systemPromptPrefix)}, prepared.Messages...)
595 }
596 return callContext, prepared, nil
597 },
598 OnReasoningDelta: func(id string, text string) error {
599 summaryMessage.AppendReasoningContent(text)
600 return a.messages.Update(genCtx, summaryMessage)
601 },
602 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
603 // Handle anthropic signature.
604 if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
605 if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
606 summaryMessage.AppendReasoningSignature(signature.Signature)
607 }
608 }
609 summaryMessage.FinishThinking()
610 return a.messages.Update(genCtx, summaryMessage)
611 },
612 OnTextDelta: func(id, text string) error {
613 summaryMessage.AppendContent(text)
614 return a.messages.Update(genCtx, summaryMessage)
615 },
616 })
617 if err != nil {
618 isCancelErr := errors.Is(err, context.Canceled)
619 if isCancelErr {
620 // User cancelled summarize we need to remove the summary message.
621 deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
622 return deleteErr
623 }
624 return err
625 }
626
627 summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
628 err = a.messages.Update(genCtx, summaryMessage)
629 if err != nil {
630 return err
631 }
632
633 var openrouterCost *float64
634 for _, step := range resp.Steps {
635 stepCost := a.openrouterCost(step.ProviderMetadata)
636 if stepCost != nil {
637 newCost := *stepCost
638 if openrouterCost != nil {
639 newCost += *openrouterCost
640 }
641 openrouterCost = &newCost
642 }
643 }
644
645 a.updateSessionUsage(largeModel, ¤tSession, resp.TotalUsage, openrouterCost)
646
647 // Just in case, get just the last usage info.
648 usage := resp.Response.Usage
649 currentSession.SummaryMessageID = summaryMessage.ID
650 currentSession.CompletionTokens = usage.OutputTokens
651 currentSession.PromptTokens = 0
652 _, err = a.sessions.Save(genCtx, currentSession)
653 return err
654}
655
656func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
657 if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
658 return fantasy.ProviderOptions{}
659 }
660 return fantasy.ProviderOptions{
661 anthropic.Name: &anthropic.ProviderCacheControlOptions{
662 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
663 },
664 bedrock.Name: &anthropic.ProviderCacheControlOptions{
665 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
666 },
667 }
668}
669
670func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
671 parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
672 var attachmentParts []message.ContentPart
673 for _, attachment := range call.Attachments {
674 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
675 }
676 parts = append(parts, attachmentParts...)
677 msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
678 Role: message.User,
679 Parts: parts,
680 })
681 if err != nil {
682 return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
683 }
684 return msg, nil
685}
686
687func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
688 var history []fantasy.Message
689 if !a.isSubAgent {
690 history = append(history, fantasy.NewUserMessage(
691 fmt.Sprintf("<system_reminder>%s</system_reminder>",
692 `This is a reminder that your todo list is currently empty. DO NOT mention this to the user explicitly because they are already aware.
693If you are working on tasks that would benefit from a todo list please use the "todos" tool to create one.
694If not, please feel free to ignore. Again do not mention this message to the user.`,
695 ),
696 ))
697 }
698 for _, m := range msgs {
699 if len(m.Parts) == 0 {
700 continue
701 }
702 // Assistant message without content or tool calls (cancelled before it
703 // returned anything).
704 if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
705 continue
706 }
707 history = append(history, m.ToAIMessage()...)
708 }
709
710 var files []fantasy.FilePart
711 for _, attachment := range attachments {
712 if attachment.IsText() {
713 continue
714 }
715 files = append(files, fantasy.FilePart{
716 Filename: attachment.FileName,
717 Data: attachment.Content,
718 MediaType: attachment.MimeType,
719 })
720 }
721
722 return history, files
723}
724
725func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
726 msgs, err := a.messages.List(ctx, session.ID)
727 if err != nil {
728 return nil, fmt.Errorf("failed to list messages: %w", err)
729 }
730
731 if session.SummaryMessageID != "" {
732 summaryMsgIndex := -1
733 for i, msg := range msgs {
734 if msg.ID == session.SummaryMessageID {
735 summaryMsgIndex = i
736 break
737 }
738 }
739 if summaryMsgIndex != -1 {
740 msgs = msgs[summaryMsgIndex:]
741 msgs[0].Role = message.User
742 }
743 }
744 return msgs, nil
745}
746
// generateTitle generates a session title based on the initial prompt. It
// tries the small model first, falls back to the large model, and finally to
// defaultSessionName. Errors are logged rather than returned since this runs
// in a background goroutine (see Run).
func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, userPrompt string) {
	if userPrompt == "" {
		return
	}

	smallModel := a.smallModel.Get()
	largeModel := a.largeModel.Get()
	systemPromptPrefix := a.systemPromptPrefix.Get()

	// Reasoning models need extra budget for thinking tokens; otherwise a
	// short title needs very few output tokens.
	var maxOutputTokens int64 = 40
	if smallModel.CatwalkCfg.CanReason {
		maxOutputTokens = smallModel.CatwalkCfg.DefaultMaxTokens
	}

	// The "/no_think" suffix discourages thinking output on models that
	// support that directive.
	newAgent := func(m fantasy.LanguageModel, p []byte, tok int64) fantasy.Agent {
		return fantasy.NewAgent(m,
			fantasy.WithSystemPrompt(string(p)+"\n /no_think"),
			fantasy.WithMaxOutputTokens(tok),
		)
	}

	var reasoningContent strings.Builder

	streamCall := fantasy.AgentStreamCall{
		Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", userPrompt),
		PrepareStep: func(callCtx context.Context, opts fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
			prepared.Messages = opts.Messages
			if systemPromptPrefix != "" {
				prepared.Messages = append([]fantasy.Message{
					fantasy.NewSystemMessage(systemPromptPrefix),
				}, prepared.Messages...)
			}
			return callCtx, prepared, nil
		},
		OnReasoningDelta: func(id string, text string) error {
			// Also capture reasoning for title fallback.
			reasoningContent.WriteString(text)
			return nil
		},
	}

	// Use the small model to generate the title.
	model := smallModel
	agent := newAgent(model.Model, titlePrompt, maxOutputTokens)
	resp, err := agent.Stream(ctx, streamCall)
	if err == nil {
		// We successfully generated a title with the small model.
		slog.Info("generated title with small model")
	} else {
		// It didn't work. Let's try with the big model.
		slog.Error("error generating title with small model; trying big model", "err", err)
		model = largeModel
		agent = newAgent(model.Model, titlePrompt, maxOutputTokens)
		resp, err = agent.Stream(ctx, streamCall)
		if err == nil {
			slog.Info("generated title with large model")
		} else {
			// Welp, the large model didn't work either. Use the default
			// session name and return.
			slog.Error("error generating title with large model", "err", err)
			saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, defaultSessionName, 0, 0, 0)
			if saveErr != nil {
				slog.Error("failed to save session title and usage", "error", saveErr)
			}
			return
		}
	}

	if resp == nil {
		// Actually, we didn't get a response so we can't. Use the default
		// session name and return.
		slog.Error("response is nil; can't generate title")
		saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, defaultSessionName, 0, 0, 0)
		if saveErr != nil {
			slog.Error("failed to save session title and usage", "error", saveErr)
		}
		return
	}

	// Clean up title: collapse newlines into spaces.
	var title string
	title = strings.ReplaceAll(resp.Response.Content.Text(), "\n", " ")

	// Remove thinking tags if present.
	title = removeThinkingTags(title)

	// If title is empty, try reasoning content (models may put the title in
	// reasoning).
	if title == "" && reasoningContent.Len() > 0 {
		reasoningTitle := strings.ReplaceAll(reasoningContent.String(), "\n", " ")
		reasoningTitle = removeThinkingTags(reasoningTitle)
		// Extract last sentence or reasonable length from reasoning, if
		// present.
		if len(reasoningTitle) > 0 {
			if sentences := strings.Split(reasoningTitle, "."); len(sentences) > 1 {
				reasoningTitle = strings.TrimSpace(sentences[len(sentences)-1])
			}
		}
		title = reasoningTitle
	}

	if title == "" {
		slog.Warn("empty title; using fallback")
		title = defaultSessionName
	}

	// Calculate usage and cost: sum OpenRouter-reported cost across steps,
	// if present.
	var openrouterCost *float64
	for _, step := range resp.Steps {
		stepCost := a.openrouterCost(step.ProviderMetadata)
		if stepCost != nil {
			newCost := *stepCost
			if openrouterCost != nil {
				newCost += *openrouterCost
			}
			openrouterCost = &newCost
		}
	}

	// Compute cost from the model's per-million-token pricing.
	modelConfig := model.CatwalkCfg
	cost := modelConfig.CostPer1MInCached/1e6*float64(resp.TotalUsage.CacheCreationTokens) +
		modelConfig.CostPer1MOutCached/1e6*float64(resp.TotalUsage.CacheReadTokens) +
		modelConfig.CostPer1MIn/1e6*float64(resp.TotalUsage.InputTokens) +
		modelConfig.CostPer1MOut/1e6*float64(resp.TotalUsage.OutputTokens)

	// Use override cost if available (e.g., from OpenRouter).
	if openrouterCost != nil {
		cost = *openrouterCost
	}

	promptTokens := resp.TotalUsage.InputTokens + resp.TotalUsage.CacheCreationTokens
	completionTokens := resp.TotalUsage.OutputTokens + resp.TotalUsage.CacheReadTokens

	// Atomically update only title and usage fields to avoid overriding other
	// concurrent session updates.
	saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, title, promptTokens, completionTokens, cost)
	if saveErr != nil {
		slog.Error("failed to save session title and usage", "error", saveErr)
		return
	}
}
889
890func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
891 openrouterMetadata, ok := metadata[openrouter.Name]
892 if !ok {
893 return nil
894 }
895
896 opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
897 if !ok {
898 return nil
899 }
900 return &opts.Usage.Cost
901}
902
903func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
904 modelConfig := model.CatwalkCfg
905 cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
906 modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
907 modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
908 modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
909
910 a.eventTokensUsed(session.ID, model, usage, cost)
911
912 if overrideCost != nil {
913 session.Cost += *overrideCost
914 } else {
915 session.Cost += cost
916 }
917
918 session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
919 session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
920}
921
922func (a *sessionAgent) Cancel(sessionID string) {
923 // Cancel regular requests. Don't use Take() here - we need the entry to
924 // remain in activeRequests so IsBusy() returns true until the goroutine
925 // fully completes (including error handling that may access the DB).
926 // The defer in processRequest will clean up the entry.
927 if cancel, ok := a.activeRequests.Get(sessionID); ok && cancel != nil {
928 slog.Info("Request cancellation initiated", "session_id", sessionID)
929 cancel()
930 }
931
932 // Also check for summarize requests.
933 if cancel, ok := a.activeRequests.Get(sessionID + "-summarize"); ok && cancel != nil {
934 slog.Info("Summarize cancellation initiated", "session_id", sessionID)
935 cancel()
936 }
937
938 if a.QueuedPrompts(sessionID) > 0 {
939 slog.Info("Clearing queued prompts", "session_id", sessionID)
940 a.messageQueue.Del(sessionID)
941 }
942}
943
944func (a *sessionAgent) ClearQueue(sessionID string) {
945 if a.QueuedPrompts(sessionID) > 0 {
946 slog.Info("Clearing queued prompts", "session_id", sessionID)
947 a.messageQueue.Del(sessionID)
948 }
949}
950
951func (a *sessionAgent) CancelAll() {
952 if !a.IsBusy() {
953 return
954 }
955 for key := range a.activeRequests.Seq2() {
956 a.Cancel(key) // key is sessionID
957 }
958
959 timeout := time.After(5 * time.Second)
960 for a.IsBusy() {
961 select {
962 case <-timeout:
963 return
964 default:
965 time.Sleep(200 * time.Millisecond)
966 }
967 }
968}
969
970func (a *sessionAgent) IsBusy() bool {
971 var busy bool
972 for cancelFunc := range a.activeRequests.Seq() {
973 if cancelFunc != nil {
974 busy = true
975 break
976 }
977 }
978 return busy
979}
980
981func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
982 _, busy := a.activeRequests.Get(sessionID)
983 return busy
984}
985
986func (a *sessionAgent) QueuedPrompts(sessionID string) int {
987 l, ok := a.messageQueue.Get(sessionID)
988 if !ok {
989 return 0
990 }
991 return len(l)
992}
993
994func (a *sessionAgent) QueuedPromptsList(sessionID string) []string {
995 l, ok := a.messageQueue.Get(sessionID)
996 if !ok {
997 return nil
998 }
999 prompts := make([]string, len(l))
1000 for i, call := range l {
1001 prompts[i] = call.Prompt
1002 }
1003 return prompts
1004}
1005
// SetModels atomically replaces the large (primary) and small (auxiliary,
// e.g. title/summary generation) models used by the agent.
func (a *sessionAgent) SetModels(large Model, small Model) {
	a.largeModel.Set(large)
	a.smallModel.Set(small)
}
1010
// SetTools atomically replaces the full set of tools available to the agent.
func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
	a.tools.SetSlice(tools)
}
1014
// SetSystemPrompt atomically replaces the system prompt used for requests.
func (a *sessionAgent) SetSystemPrompt(systemPrompt string) {
	a.systemPrompt.Set(systemPrompt)
}
1018
// Model returns the agent's current large (primary) model.
func (a *sessionAgent) Model() Model {
	return a.largeModel.Get()
}
1022
1023// convertToToolResult converts a fantasy tool result to a message tool result.
1024func (a *sessionAgent) convertToToolResult(result fantasy.ToolResultContent) message.ToolResult {
1025 baseResult := message.ToolResult{
1026 ToolCallID: result.ToolCallID,
1027 Name: result.ToolName,
1028 Metadata: result.ClientMetadata,
1029 }
1030
1031 switch result.Result.GetType() {
1032 case fantasy.ToolResultContentTypeText:
1033 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result); ok {
1034 baseResult.Content = r.Text
1035 }
1036 case fantasy.ToolResultContentTypeError:
1037 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result); ok {
1038 baseResult.Content = r.Error.Error()
1039 baseResult.IsError = true
1040 }
1041 case fantasy.ToolResultContentTypeMedia:
1042 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](result.Result); ok {
1043 content := r.Text
1044 if content == "" {
1045 content = fmt.Sprintf("Loaded %s content", r.MediaType)
1046 }
1047 baseResult.Content = content
1048 baseResult.Data = r.Data
1049 baseResult.MIMEType = r.MediaType
1050 }
1051 }
1052
1053 return baseResult
1054}
1055
1056// workaroundProviderMediaLimitations converts media content in tool results to
1057// user messages for providers that don't natively support images in tool results.
1058//
1059// Problem: OpenAI, Google, OpenRouter, and other OpenAI-compatible providers
1060// don't support sending images/media in tool result messages - they only accept
1061// text in tool results. However, they DO support images in user messages.
1062//
1063// If we send media in tool results to these providers, the API returns an error.
1064//
1065// Solution: For these providers, we:
1066// 1. Replace the media in the tool result with a text placeholder
1067// 2. Inject a user message immediately after with the image as a file attachment
1068// 3. This maintains the tool execution flow while working around API limitations
1069//
1070// Anthropic and Bedrock support images natively in tool results, so we skip
1071// this workaround for them.
1072//
1073// Example transformation:
1074//
1075// BEFORE: [tool result: image data]
1076// AFTER: [tool result: "Image loaded - see attached"], [user: image attachment]
1077func (a *sessionAgent) workaroundProviderMediaLimitations(messages []fantasy.Message, largeModel Model) []fantasy.Message {
1078 providerSupportsMedia := largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderAnthropic) ||
1079 largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderBedrock)
1080
1081 if providerSupportsMedia {
1082 return messages
1083 }
1084
1085 convertedMessages := make([]fantasy.Message, 0, len(messages))
1086
1087 for _, msg := range messages {
1088 if msg.Role != fantasy.MessageRoleTool {
1089 convertedMessages = append(convertedMessages, msg)
1090 continue
1091 }
1092
1093 textParts := make([]fantasy.MessagePart, 0, len(msg.Content))
1094 var mediaFiles []fantasy.FilePart
1095
1096 for _, part := range msg.Content {
1097 toolResult, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
1098 if !ok {
1099 textParts = append(textParts, part)
1100 continue
1101 }
1102
1103 if media, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](toolResult.Output); ok {
1104 decoded, err := base64.StdEncoding.DecodeString(media.Data)
1105 if err != nil {
1106 slog.Warn("failed to decode media data", "error", err)
1107 textParts = append(textParts, part)
1108 continue
1109 }
1110
1111 mediaFiles = append(mediaFiles, fantasy.FilePart{
1112 Data: decoded,
1113 MediaType: media.MediaType,
1114 Filename: fmt.Sprintf("tool-result-%s", toolResult.ToolCallID),
1115 })
1116
1117 textParts = append(textParts, fantasy.ToolResultPart{
1118 ToolCallID: toolResult.ToolCallID,
1119 Output: fantasy.ToolResultOutputContentText{
1120 Text: "[Image/media content loaded - see attached file]",
1121 },
1122 ProviderOptions: toolResult.ProviderOptions,
1123 })
1124 } else {
1125 textParts = append(textParts, part)
1126 }
1127 }
1128
1129 convertedMessages = append(convertedMessages, fantasy.Message{
1130 Role: fantasy.MessageRoleTool,
1131 Content: textParts,
1132 })
1133
1134 if len(mediaFiles) > 0 {
1135 convertedMessages = append(convertedMessages, fantasy.NewUserMessage(
1136 "Here is the media content from the tool result:",
1137 mediaFiles...,
1138 ))
1139 }
1140 }
1141
1142 return convertedMessages
1143}
1144
1145// buildSummaryPrompt constructs the prompt text for session summarization.
1146func buildSummaryPrompt(todos []session.Todo) string {
1147 var sb strings.Builder
1148 sb.WriteString("Provide a detailed summary of our conversation above.")
1149 if len(todos) > 0 {
1150 sb.WriteString("\n\n## Current Todo List\n\n")
1151 for _, t := range todos {
1152 fmt.Fprintf(&sb, "- [%s] %s\n", t.Status, t.Content)
1153 }
1154 sb.WriteString("\nInclude these tasks and their statuses in your summary. ")
1155 sb.WriteString("Instruct the resuming assistant to use the `todos` tool to continue tracking progress on these tasks.")
1156 }
1157 return sb.String()
1158}
1159
1160// removeThinkingTags removes <think>...</think> tags from the given string.
1161// Used to clean up generated session titles.
1162func removeThinkingTags(s string) string {
1163 s = thinkTagRegex.ReplaceAllString(s, "")
1164 return strings.TrimSpace(s)
1165}