1// Package agent is the core orchestration layer for Crush AI agents.
2//
3// It provides session-based AI agent functionality for managing
4// conversations, tool execution, and message handling. It coordinates
5// interactions between language models, messages, sessions, and tools while
6// handling features like automatic summarization, queuing, and token
7// management.
8package agent
9
10import (
11 "cmp"
12 "context"
13 _ "embed"
14 "encoding/base64"
15 "errors"
16 "fmt"
17 "log/slog"
18 "os"
19 "regexp"
20 "strconv"
21 "strings"
22 "sync"
23 "time"
24
25 "charm.land/fantasy"
26 "charm.land/fantasy/providers/anthropic"
27 "charm.land/fantasy/providers/bedrock"
28 "charm.land/fantasy/providers/google"
29 "charm.land/fantasy/providers/openai"
30 "charm.land/fantasy/providers/openrouter"
31 "charm.land/lipgloss/v2"
32 "github.com/charmbracelet/catwalk/pkg/catwalk"
33 "github.com/charmbracelet/crush/internal/agent/hyper"
34 "github.com/charmbracelet/crush/internal/agent/tools"
35 "github.com/charmbracelet/crush/internal/config"
36 "github.com/charmbracelet/crush/internal/csync"
37 "github.com/charmbracelet/crush/internal/message"
38 "github.com/charmbracelet/crush/internal/permission"
39 "github.com/charmbracelet/crush/internal/session"
40 "github.com/charmbracelet/crush/internal/stringext"
41)
42
const (
	// defaultSessionName is the title stored on a session when title
	// generation fails or yields an empty title.
	defaultSessionName = "Untitled Session"

	// Constants for auto-summarization thresholds.
	//
	// When the large model's context window is larger than
	// largeContextWindowThreshold tokens, a run is stopped for summarization
	// once largeContextWindowBuffer tokens or fewer remain. For smaller
	// windows the limit is smallContextWindowRatio of the window size.
	largeContextWindowThreshold = 200_000
	largeContextWindowBuffer    = 20_000
	smallContextWindowRatio     = 0.2
)
51
// titlePrompt is the embedded system prompt used when generating session
// titles.
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt is the embedded system prompt used when summarizing a
// session.
//
//go:embed templates/summary.md
var summaryPrompt []byte

// Used to remove <think> tags from generated titles. The pattern does not
// match across newlines; generateTitle replaces newlines with spaces before
// applying it.
var thinkTagRegex = regexp.MustCompile(`<think>.*?</think>`)
60
// SessionAgentCall describes a single prompt sent to a session agent.
type SessionAgentCall struct {
	// SessionID identifies the session the prompt belongs to. Run rejects an
	// empty value with ErrSessionMissing.
	SessionID string
	// Prompt is the user's text. It may be empty only when Attachments
	// contains a text attachment.
	Prompt string
	// ProviderOptions is forwarded unchanged to the model stream call.
	ProviderOptions fantasy.ProviderOptions
	// Attachments are stored with the user message; non-text attachments are
	// also sent to the model as files.
	Attachments []message.Attachment
	// MaxOutputTokens is forwarded to the model stream call.
	MaxOutputTokens int64
	// The remaining sampling parameters are optional and are forwarded to the
	// model stream call as-is.
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
}
73
// SessionAgent is the session-scoped agent API. It is implemented by
// sessionAgent and constructed with NewSessionAgent.
type SessionAgent interface {
	// Run sends the call's prompt to the large model for call.SessionID. If
	// the session already has an active request the call is queued and
	// (nil, nil) is returned.
	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
	// SetModels replaces the large and small models used by later calls.
	SetModels(large Model, small Model)
	// SetTools replaces the tools offered to the model by later calls.
	SetTools(tools []fantasy.AgentTool)
	// SetSystemPrompt replaces the system prompt used by later calls.
	SetSystemPrompt(systemPrompt string)
	// Cancel cancels the active request of the session, if any, and drops
	// its queued prompts.
	Cancel(sessionID string)
	// CancelAll cancels every active request and waits a bounded time for
	// the agent to become idle.
	CancelAll()
	// IsSessionBusy reports whether the session has an active request.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any session has an active request.
	IsBusy() bool
	// QueuedPrompts returns the number of prompts queued for the session.
	QueuedPrompts(sessionID string) int
	// QueuedPromptsList returns the text of the prompts queued for the
	// session, in queue order.
	QueuedPromptsList(sessionID string) []string
	// ClearQueue drops the prompts queued for the session.
	ClearQueue(sessionID string)
	// Summarize replaces the session's history with a model-generated
	// summary. The string argument is the session ID.
	Summarize(context.Context, string, fantasy.ProviderOptions) error
	// Model returns the current large model.
	Model() Model
}
89
// Model bundles a language model with the metadata the agent needs about it.
type Model struct {
	// Model is the language model handed to fantasy.NewAgent.
	Model fantasy.LanguageModel
	// CatwalkCfg supplies the model's name, context window, pricing and
	// capability flags read by the agent.
	CatwalkCfg catwalk.Model
	// ModelCfg supplies the model and provider identifiers recorded on
	// assistant messages.
	ModelCfg config.SelectedModel
}
95
// sessionAgent is the SessionAgent implementation returned by
// NewSessionAgent.
type sessionAgent struct {
	// Settings that can be replaced while the agent is in use; each call
	// takes its own snapshot before streaming.
	largeModel         *csync.Value[Model]
	smallModel         *csync.Value[Model]
	systemPromptPrefix *csync.Value[string]
	systemPrompt       *csync.Value[string]
	tools              *csync.Slice[fantasy.AgentTool]

	// isSubAgent suppresses the todo-list reminder that preparePrompt
	// otherwise prepends to the history.
	isSubAgent bool
	sessions   session.Service
	messages   message.Service
	// disableAutoSummarize turns off the context-window stop condition in
	// Run.
	disableAutoSummarize bool
	isYolo               bool

	// messageQueue holds, per session ID, the calls received while that
	// session was busy.
	messageQueue *csync.Map[string, []SessionAgentCall]
	// activeRequests holds, per session ID, the cancel function of the
	// request currently in flight. Presence of a key marks the session busy.
	activeRequests *csync.Map[string, context.CancelFunc]
}
112
// SessionAgentOptions carries the initial configuration and dependencies
// consumed by NewSessionAgent.
type SessionAgentOptions struct {
	// LargeModel serves regular prompts and summaries; SmallModel is tried
	// first for title generation.
	LargeModel Model
	SmallModel Model
	// SystemPromptPrefix, when non-empty, is sent as an extra system message
	// ahead of all other messages.
	SystemPromptPrefix string
	SystemPrompt       string
	// IsSubAgent disables the todo-list reminder added to the history.
	IsSubAgent bool
	// DisableAutoSummarize disables stopping a run for summarization when
	// the context window is nearly full.
	DisableAutoSummarize bool
	IsYolo               bool
	// Sessions and Messages are the persistence services used by the agent.
	Sessions session.Service
	Messages message.Service
	// Tools is the initial tool set offered to the model.
	Tools []fantasy.AgentTool
}
125
126func NewSessionAgent(
127 opts SessionAgentOptions,
128) SessionAgent {
129 return &sessionAgent{
130 largeModel: csync.NewValue(opts.LargeModel),
131 smallModel: csync.NewValue(opts.SmallModel),
132 systemPromptPrefix: csync.NewValue(opts.SystemPromptPrefix),
133 systemPrompt: csync.NewValue(opts.SystemPrompt),
134 isSubAgent: opts.IsSubAgent,
135 sessions: opts.Sessions,
136 messages: opts.Messages,
137 disableAutoSummarize: opts.DisableAutoSummarize,
138 tools: csync.NewSliceFrom(opts.Tools),
139 isYolo: opts.IsYolo,
140 messageQueue: csync.NewMap[string, []SessionAgentCall](),
141 activeRequests: csync.NewMap[string, context.CancelFunc](),
142 }
143}
144
145func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
146 if call.Prompt == "" && !message.ContainsTextAttachment(call.Attachments) {
147 return nil, ErrEmptyPrompt
148 }
149 if call.SessionID == "" {
150 return nil, ErrSessionMissing
151 }
152
153 // Queue the message if busy
154 if a.IsSessionBusy(call.SessionID) {
155 existing, ok := a.messageQueue.Get(call.SessionID)
156 if !ok {
157 existing = []SessionAgentCall{}
158 }
159 existing = append(existing, call)
160 a.messageQueue.Set(call.SessionID, existing)
161 return nil, nil
162 }
163
164 // Copy mutable fields under lock to avoid races with SetTools/SetModels.
165 agentTools := a.tools.Copy()
166 largeModel := a.largeModel.Get()
167 systemPrompt := a.systemPrompt.Get()
168 promptPrefix := a.systemPromptPrefix.Get()
169
170 if len(agentTools) > 0 {
171 // Add Anthropic caching to the last tool.
172 agentTools[len(agentTools)-1].SetProviderOptions(a.getCacheControlOptions())
173 }
174
175 agent := fantasy.NewAgent(
176 largeModel.Model,
177 fantasy.WithSystemPrompt(systemPrompt),
178 fantasy.WithTools(agentTools...),
179 )
180
181 sessionLock := sync.Mutex{}
182 currentSession, err := a.sessions.Get(ctx, call.SessionID)
183 if err != nil {
184 return nil, fmt.Errorf("failed to get session: %w", err)
185 }
186
187 msgs, err := a.getSessionMessages(ctx, currentSession)
188 if err != nil {
189 return nil, fmt.Errorf("failed to get session messages: %w", err)
190 }
191
192 var wg sync.WaitGroup
193 // Generate title if first message.
194 if len(msgs) == 0 {
195 titleCtx := ctx // Copy to avoid race with ctx reassignment below.
196 wg.Go(func() {
197 a.generateTitle(titleCtx, call.SessionID, call.Prompt)
198 })
199 }
200 defer wg.Wait()
201
202 // Add the user message to the session.
203 _, err = a.createUserMessage(ctx, call)
204 if err != nil {
205 return nil, err
206 }
207
208 // Add the session to the context.
209 ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)
210
211 genCtx, cancel := context.WithCancel(ctx)
212 a.activeRequests.Set(call.SessionID, cancel)
213
214 defer cancel()
215 defer a.activeRequests.Del(call.SessionID)
216
217 history, files := a.preparePrompt(msgs, call.Attachments...)
218
219 startTime := time.Now()
220 a.eventPromptSent(call.SessionID)
221
222 var currentAssistant *message.Message
223 var shouldSummarize bool
224 result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
225 Prompt: message.PromptWithTextAttachments(call.Prompt, call.Attachments),
226 Files: files,
227 Messages: history,
228 ProviderOptions: call.ProviderOptions,
229 MaxOutputTokens: &call.MaxOutputTokens,
230 TopP: call.TopP,
231 Temperature: call.Temperature,
232 PresencePenalty: call.PresencePenalty,
233 TopK: call.TopK,
234 FrequencyPenalty: call.FrequencyPenalty,
235 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
236 prepared.Messages = options.Messages
237 for i := range prepared.Messages {
238 prepared.Messages[i].ProviderOptions = nil
239 }
240
241 queuedCalls, _ := a.messageQueue.Get(call.SessionID)
242 a.messageQueue.Del(call.SessionID)
243 for _, queued := range queuedCalls {
244 userMessage, createErr := a.createUserMessage(callContext, queued)
245 if createErr != nil {
246 return callContext, prepared, createErr
247 }
248 prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
249 }
250
251 prepared.Messages = a.workaroundProviderMediaLimitations(prepared.Messages, largeModel)
252
253 lastSystemRoleInx := 0
254 systemMessageUpdated := false
255 for i, msg := range prepared.Messages {
256 // Only add cache control to the last message.
257 if msg.Role == fantasy.MessageRoleSystem {
258 lastSystemRoleInx = i
259 } else if !systemMessageUpdated {
260 prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
261 systemMessageUpdated = true
262 }
263 // Than add cache control to the last 2 messages.
264 if i > len(prepared.Messages)-3 {
265 prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
266 }
267 }
268
269 if promptPrefix != "" {
270 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(promptPrefix)}, prepared.Messages...)
271 }
272
273 var assistantMsg message.Message
274 assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
275 Role: message.Assistant,
276 Parts: []message.ContentPart{},
277 Model: largeModel.ModelCfg.Model,
278 Provider: largeModel.ModelCfg.Provider,
279 })
280 if err != nil {
281 return callContext, prepared, err
282 }
283 callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
284 callContext = context.WithValue(callContext, tools.SupportsImagesContextKey, largeModel.CatwalkCfg.SupportsImages)
285 callContext = context.WithValue(callContext, tools.ModelNameContextKey, largeModel.CatwalkCfg.Name)
286 currentAssistant = &assistantMsg
287 return callContext, prepared, err
288 },
289 OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
290 currentAssistant.AppendReasoningContent(reasoning.Text)
291 return a.messages.Update(genCtx, *currentAssistant)
292 },
293 OnReasoningDelta: func(id string, text string) error {
294 currentAssistant.AppendReasoningContent(text)
295 return a.messages.Update(genCtx, *currentAssistant)
296 },
297 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
298 // handle anthropic signature
299 if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
300 if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
301 currentAssistant.AppendReasoningSignature(reasoning.Signature)
302 }
303 }
304 if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
305 if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
306 currentAssistant.AppendThoughtSignature(reasoning.Signature, reasoning.ToolID)
307 }
308 }
309 if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
310 if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
311 currentAssistant.SetReasoningResponsesData(reasoning)
312 }
313 }
314 currentAssistant.FinishThinking()
315 return a.messages.Update(genCtx, *currentAssistant)
316 },
317 OnTextDelta: func(id string, text string) error {
318 // Strip leading newline from initial text content. This is is
319 // particularly important in non-interactive mode where leading
320 // newlines are very visible.
321 if len(currentAssistant.Parts) == 0 {
322 text = strings.TrimPrefix(text, "\n")
323 }
324
325 currentAssistant.AppendContent(text)
326 return a.messages.Update(genCtx, *currentAssistant)
327 },
328 OnToolInputStart: func(id string, toolName string) error {
329 toolCall := message.ToolCall{
330 ID: id,
331 Name: toolName,
332 ProviderExecuted: false,
333 Finished: false,
334 }
335 currentAssistant.AddToolCall(toolCall)
336 return a.messages.Update(genCtx, *currentAssistant)
337 },
338 OnRetry: func(err *fantasy.ProviderError, delay time.Duration) {
339 // TODO: implement
340 },
341 OnToolCall: func(tc fantasy.ToolCallContent) error {
342 toolCall := message.ToolCall{
343 ID: tc.ToolCallID,
344 Name: tc.ToolName,
345 Input: tc.Input,
346 ProviderExecuted: false,
347 Finished: true,
348 }
349 currentAssistant.AddToolCall(toolCall)
350 return a.messages.Update(genCtx, *currentAssistant)
351 },
352 OnToolResult: func(result fantasy.ToolResultContent) error {
353 toolResult := a.convertToToolResult(result)
354 _, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
355 Role: message.Tool,
356 Parts: []message.ContentPart{
357 toolResult,
358 },
359 })
360 return createMsgErr
361 },
362 OnStepFinish: func(stepResult fantasy.StepResult) error {
363 finishReason := message.FinishReasonUnknown
364 switch stepResult.FinishReason {
365 case fantasy.FinishReasonLength:
366 finishReason = message.FinishReasonMaxTokens
367 case fantasy.FinishReasonStop:
368 finishReason = message.FinishReasonEndTurn
369 case fantasy.FinishReasonToolCalls:
370 finishReason = message.FinishReasonToolUse
371 }
372 currentAssistant.AddFinish(finishReason, "", "")
373 sessionLock.Lock()
374 updatedSession, getSessionErr := a.sessions.Get(genCtx, call.SessionID)
375 if getSessionErr != nil {
376 sessionLock.Unlock()
377 return getSessionErr
378 }
379 a.updateSessionUsage(largeModel, &updatedSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
380 _, sessionErr := a.sessions.Save(genCtx, updatedSession)
381 sessionLock.Unlock()
382 if sessionErr != nil {
383 return sessionErr
384 }
385 return a.messages.Update(genCtx, *currentAssistant)
386 },
387 StopWhen: []fantasy.StopCondition{
388 func(_ []fantasy.StepResult) bool {
389 cw := int64(largeModel.CatwalkCfg.ContextWindow)
390 tokens := currentSession.CompletionTokens + currentSession.PromptTokens
391 remaining := cw - tokens
392 var threshold int64
393 if cw > largeContextWindowThreshold {
394 threshold = largeContextWindowBuffer
395 } else {
396 threshold = int64(float64(cw) * smallContextWindowRatio)
397 }
398 if (remaining <= threshold) && !a.disableAutoSummarize {
399 shouldSummarize = true
400 return true
401 }
402 return false
403 },
404 },
405 })
406
407 a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))
408
409 if err != nil {
410 isCancelErr := errors.Is(err, context.Canceled)
411 isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
412 if currentAssistant == nil {
413 return result, err
414 }
415 // Ensure we finish thinking on error to close the reasoning state.
416 currentAssistant.FinishThinking()
417 toolCalls := currentAssistant.ToolCalls()
418 // INFO: we use the parent context here because the genCtx has been cancelled.
419 msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
420 if createErr != nil {
421 return nil, createErr
422 }
423 for _, tc := range toolCalls {
424 if !tc.Finished {
425 tc.Finished = true
426 tc.Input = "{}"
427 currentAssistant.AddToolCall(tc)
428 updateErr := a.messages.Update(ctx, *currentAssistant)
429 if updateErr != nil {
430 return nil, updateErr
431 }
432 }
433
434 found := false
435 for _, msg := range msgs {
436 if msg.Role == message.Tool {
437 for _, tr := range msg.ToolResults() {
438 if tr.ToolCallID == tc.ID {
439 found = true
440 break
441 }
442 }
443 }
444 if found {
445 break
446 }
447 }
448 if found {
449 continue
450 }
451 content := "There was an error while executing the tool"
452 if isCancelErr {
453 content = "Tool execution canceled by user"
454 } else if isPermissionErr {
455 content = "User denied permission"
456 }
457 toolResult := message.ToolResult{
458 ToolCallID: tc.ID,
459 Name: tc.Name,
460 Content: content,
461 IsError: true,
462 }
463 _, createErr = a.messages.Create(ctx, currentAssistant.SessionID, message.CreateMessageParams{
464 Role: message.Tool,
465 Parts: []message.ContentPart{
466 toolResult,
467 },
468 })
469 if createErr != nil {
470 return nil, createErr
471 }
472 }
473 var fantasyErr *fantasy.Error
474 var providerErr *fantasy.ProviderError
475 const defaultTitle = "Provider Error"
476 if isCancelErr {
477 currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
478 } else if isPermissionErr {
479 currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "")
480 } else if errors.Is(err, hyper.ErrNoCredits) {
481 url := hyper.BaseURL()
482 link := lipgloss.NewStyle().Hyperlink(url, "id=hyper").Render(url)
483 currentAssistant.AddFinish(message.FinishReasonError, "No credits", "You're out of credits. Add more at "+link)
484 } else if errors.As(err, &providerErr) {
485 if providerErr.Message == "The requested model is not supported." {
486 url := "https://github.com/settings/copilot/features"
487 link := lipgloss.NewStyle().Hyperlink(url, "id=hyper").Render(url)
488 currentAssistant.AddFinish(
489 message.FinishReasonError,
490 "Copilot model not enabled",
491 fmt.Sprintf("%q is not enabled in Copilot. Go to the following page to enable it. Then, wait 5 minutes before trying again. %s", largeModel.CatwalkCfg.Name, link),
492 )
493 } else {
494 currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message)
495 }
496 } else if errors.As(err, &fantasyErr) {
497 currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(fantasyErr.Title), defaultTitle), fantasyErr.Message)
498 } else {
499 currentAssistant.AddFinish(message.FinishReasonError, defaultTitle, err.Error())
500 }
501 // Note: we use the parent context here because the genCtx has been
502 // cancelled.
503 updateErr := a.messages.Update(ctx, *currentAssistant)
504 if updateErr != nil {
505 return nil, updateErr
506 }
507 return nil, err
508 }
509
510 if shouldSummarize {
511 a.activeRequests.Del(call.SessionID)
512 if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
513 return nil, summarizeErr
514 }
515 // If the agent wasn't done...
516 if len(currentAssistant.ToolCalls()) > 0 {
517 existing, ok := a.messageQueue.Get(call.SessionID)
518 if !ok {
519 existing = []SessionAgentCall{}
520 }
521 call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
522 existing = append(existing, call)
523 a.messageQueue.Set(call.SessionID, existing)
524 }
525 }
526
527 // Release active request before processing queued messages.
528 a.activeRequests.Del(call.SessionID)
529 cancel()
530
531 queuedMessages, ok := a.messageQueue.Get(call.SessionID)
532 if !ok || len(queuedMessages) == 0 {
533 return result, err
534 }
535 // There are queued messages restart the loop.
536 firstQueuedMessage := queuedMessages[0]
537 a.messageQueue.Set(call.SessionID, queuedMessages[1:])
538 return a.Run(ctx, firstQueuedMessage)
539}
540
541func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
542 if a.IsSessionBusy(sessionID) {
543 return ErrSessionBusy
544 }
545
546 // Copy mutable fields under lock to avoid races with SetModels.
547 largeModel := a.largeModel.Get()
548 systemPromptPrefix := a.systemPromptPrefix.Get()
549
550 currentSession, err := a.sessions.Get(ctx, sessionID)
551 if err != nil {
552 return fmt.Errorf("failed to get session: %w", err)
553 }
554 msgs, err := a.getSessionMessages(ctx, currentSession)
555 if err != nil {
556 return err
557 }
558 if len(msgs) == 0 {
559 // Nothing to summarize.
560 return nil
561 }
562
563 aiMsgs, _ := a.preparePrompt(msgs)
564
565 genCtx, cancel := context.WithCancel(ctx)
566 a.activeRequests.Set(sessionID, cancel)
567 defer a.activeRequests.Del(sessionID)
568 defer cancel()
569
570 agent := fantasy.NewAgent(largeModel.Model,
571 fantasy.WithSystemPrompt(string(summaryPrompt)),
572 )
573 summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
574 Role: message.Assistant,
575 Model: largeModel.Model.Model(),
576 Provider: largeModel.Model.Provider(),
577 IsSummaryMessage: true,
578 })
579 if err != nil {
580 return err
581 }
582
583 summaryPromptText := buildSummaryPrompt(currentSession.Todos)
584
585 resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
586 Prompt: summaryPromptText,
587 Messages: aiMsgs,
588 ProviderOptions: opts,
589 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
590 prepared.Messages = options.Messages
591 if systemPromptPrefix != "" {
592 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(systemPromptPrefix)}, prepared.Messages...)
593 }
594 return callContext, prepared, nil
595 },
596 OnReasoningDelta: func(id string, text string) error {
597 summaryMessage.AppendReasoningContent(text)
598 return a.messages.Update(genCtx, summaryMessage)
599 },
600 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
601 // Handle anthropic signature.
602 if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
603 if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
604 summaryMessage.AppendReasoningSignature(signature.Signature)
605 }
606 }
607 summaryMessage.FinishThinking()
608 return a.messages.Update(genCtx, summaryMessage)
609 },
610 OnTextDelta: func(id, text string) error {
611 summaryMessage.AppendContent(text)
612 return a.messages.Update(genCtx, summaryMessage)
613 },
614 })
615 if err != nil {
616 isCancelErr := errors.Is(err, context.Canceled)
617 if isCancelErr {
618 // User cancelled summarize we need to remove the summary message.
619 deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
620 return deleteErr
621 }
622 return err
623 }
624
625 summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
626 err = a.messages.Update(genCtx, summaryMessage)
627 if err != nil {
628 return err
629 }
630
631 var openrouterCost *float64
632 for _, step := range resp.Steps {
633 stepCost := a.openrouterCost(step.ProviderMetadata)
634 if stepCost != nil {
635 newCost := *stepCost
636 if openrouterCost != nil {
637 newCost += *openrouterCost
638 }
639 openrouterCost = &newCost
640 }
641 }
642
643 a.updateSessionUsage(largeModel, ¤tSession, resp.TotalUsage, openrouterCost)
644
645 // Just in case, get just the last usage info.
646 usage := resp.Response.Usage
647 currentSession.SummaryMessageID = summaryMessage.ID
648 currentSession.CompletionTokens = usage.OutputTokens
649 currentSession.PromptTokens = 0
650 _, err = a.sessions.Save(genCtx, currentSession)
651 return err
652}
653
654func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
655 if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
656 return fantasy.ProviderOptions{}
657 }
658 return fantasy.ProviderOptions{
659 anthropic.Name: &anthropic.ProviderCacheControlOptions{
660 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
661 },
662 bedrock.Name: &anthropic.ProviderCacheControlOptions{
663 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
664 },
665 }
666}
667
668func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
669 parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
670 var attachmentParts []message.ContentPart
671 for _, attachment := range call.Attachments {
672 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
673 }
674 parts = append(parts, attachmentParts...)
675 msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
676 Role: message.User,
677 Parts: parts,
678 })
679 if err != nil {
680 return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
681 }
682 return msg, nil
683}
684
685func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
686 var history []fantasy.Message
687 if !a.isSubAgent {
688 history = append(history, fantasy.NewUserMessage(
689 fmt.Sprintf("<system_reminder>%s</system_reminder>",
690 `This is a reminder that your todo list is currently empty. DO NOT mention this to the user explicitly because they are already aware.
691If you are working on tasks that would benefit from a todo list please use the "todos" tool to create one.
692If not, please feel free to ignore. Again do not mention this message to the user.`,
693 ),
694 ))
695 }
696 for _, m := range msgs {
697 if len(m.Parts) == 0 {
698 continue
699 }
700 // Assistant message without content or tool calls (cancelled before it
701 // returned anything).
702 if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
703 continue
704 }
705 history = append(history, m.ToAIMessage()...)
706 }
707
708 var files []fantasy.FilePart
709 for _, attachment := range attachments {
710 if attachment.IsText() {
711 continue
712 }
713 files = append(files, fantasy.FilePart{
714 Filename: attachment.FileName,
715 Data: attachment.Content,
716 MediaType: attachment.MimeType,
717 })
718 }
719
720 return history, files
721}
722
723func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
724 msgs, err := a.messages.List(ctx, session.ID)
725 if err != nil {
726 return nil, fmt.Errorf("failed to list messages: %w", err)
727 }
728
729 if session.SummaryMessageID != "" {
730 summaryMsgIndex := -1
731 for i, msg := range msgs {
732 if msg.ID == session.SummaryMessageID {
733 summaryMsgIndex = i
734 break
735 }
736 }
737 if summaryMsgIndex != -1 {
738 msgs = msgs[summaryMsgIndex:]
739 msgs[0].Role = message.User
740 }
741 }
742 return msgs, nil
743}
744
745// generateTitle generates a session titled based on the initial prompt.
746func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, userPrompt string) {
747 if userPrompt == "" {
748 return
749 }
750
751 smallModel := a.smallModel.Get()
752 largeModel := a.largeModel.Get()
753 systemPromptPrefix := a.systemPromptPrefix.Get()
754
755 var maxOutputTokens int64 = 40
756 if smallModel.CatwalkCfg.CanReason {
757 maxOutputTokens = smallModel.CatwalkCfg.DefaultMaxTokens
758 }
759
760 newAgent := func(m fantasy.LanguageModel, p []byte, tok int64) fantasy.Agent {
761 return fantasy.NewAgent(m,
762 fantasy.WithSystemPrompt(string(p)+"\n /no_think"),
763 fantasy.WithMaxOutputTokens(tok),
764 )
765 }
766
767 streamCall := fantasy.AgentStreamCall{
768 Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", userPrompt),
769 PrepareStep: func(callCtx context.Context, opts fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
770 prepared.Messages = opts.Messages
771 if systemPromptPrefix != "" {
772 prepared.Messages = append([]fantasy.Message{
773 fantasy.NewSystemMessage(systemPromptPrefix),
774 }, prepared.Messages...)
775 }
776 return callCtx, prepared, nil
777 },
778 }
779
780 // Use the small model to generate the title.
781 model := smallModel
782 agent := newAgent(model.Model, titlePrompt, maxOutputTokens)
783 resp, err := agent.Stream(ctx, streamCall)
784 if err == nil {
785 // We successfully generated a title with the small model.
786 slog.Info("generated title with small model")
787 } else {
788 // It didn't work. Let's try with the big model.
789 slog.Error("error generating title with small model; trying big model", "err", err)
790 model = largeModel
791 agent = newAgent(model.Model, titlePrompt, maxOutputTokens)
792 resp, err = agent.Stream(ctx, streamCall)
793 if err == nil {
794 slog.Info("generated title with large model")
795 } else {
796 // Welp, the large model didn't work either. Use the default
797 // session name and return.
798 slog.Error("error generating title with large model", "err", err)
799 saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, defaultSessionName, 0, 0, 0)
800 if saveErr != nil {
801 slog.Error("failed to save session title and usage", "error", saveErr)
802 }
803 return
804 }
805 }
806
807 if resp == nil {
808 // Actually, we didn't get a response so we can't. Use the default
809 // session name and return.
810 slog.Error("response is nil; can't generate title")
811 saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, defaultSessionName, 0, 0, 0)
812 if saveErr != nil {
813 slog.Error("failed to save session title and usage", "error", saveErr)
814 }
815 return
816 }
817
818 // Clean up title.
819 var title string
820 title = strings.ReplaceAll(resp.Response.Content.Text(), "\n", " ")
821 slog.Info("generated title", "title", title)
822
823 // Remove thinking tags if present.
824 title = thinkTagRegex.ReplaceAllString(title, "")
825
826 title = strings.TrimSpace(title)
827 if title == "" {
828 slog.Warn("empty title; using fallback")
829 title = defaultSessionName
830 }
831
832 // Calculate usage and cost.
833 var openrouterCost *float64
834 for _, step := range resp.Steps {
835 stepCost := a.openrouterCost(step.ProviderMetadata)
836 if stepCost != nil {
837 newCost := *stepCost
838 if openrouterCost != nil {
839 newCost += *openrouterCost
840 }
841 openrouterCost = &newCost
842 }
843 }
844
845 modelConfig := model.CatwalkCfg
846 cost := modelConfig.CostPer1MInCached/1e6*float64(resp.TotalUsage.CacheCreationTokens) +
847 modelConfig.CostPer1MOutCached/1e6*float64(resp.TotalUsage.CacheReadTokens) +
848 modelConfig.CostPer1MIn/1e6*float64(resp.TotalUsage.InputTokens) +
849 modelConfig.CostPer1MOut/1e6*float64(resp.TotalUsage.OutputTokens)
850
851 // Use override cost if available (e.g., from OpenRouter).
852 if openrouterCost != nil {
853 cost = *openrouterCost
854 }
855
856 promptTokens := resp.TotalUsage.InputTokens + resp.TotalUsage.CacheCreationTokens
857 completionTokens := resp.TotalUsage.OutputTokens + resp.TotalUsage.CacheReadTokens
858
859 // Atomically update only title and usage fields to avoid overriding other
860 // concurrent session updates.
861 saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, title, promptTokens, completionTokens, cost)
862 if saveErr != nil {
863 slog.Error("failed to save session title and usage", "error", saveErr)
864 return
865 }
866}
867
868func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
869 openrouterMetadata, ok := metadata[openrouter.Name]
870 if !ok {
871 return nil
872 }
873
874 opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
875 if !ok {
876 return nil
877 }
878 return &opts.Usage.Cost
879}
880
881func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
882 modelConfig := model.CatwalkCfg
883 cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
884 modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
885 modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
886 modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
887
888 a.eventTokensUsed(session.ID, model, usage, cost)
889
890 if overrideCost != nil {
891 session.Cost += *overrideCost
892 } else {
893 session.Cost += cost
894 }
895
896 session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
897 session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
898}
899
900func (a *sessionAgent) Cancel(sessionID string) {
901 // Cancel regular requests. Don't use Take() here - we need the entry to
902 // remain in activeRequests so IsBusy() returns true until the goroutine
903 // fully completes (including error handling that may access the DB).
904 // The defer in processRequest will clean up the entry.
905 if cancel, ok := a.activeRequests.Get(sessionID); ok && cancel != nil {
906 slog.Info("Request cancellation initiated", "session_id", sessionID)
907 cancel()
908 }
909
910 // Also check for summarize requests.
911 if cancel, ok := a.activeRequests.Get(sessionID + "-summarize"); ok && cancel != nil {
912 slog.Info("Summarize cancellation initiated", "session_id", sessionID)
913 cancel()
914 }
915
916 if a.QueuedPrompts(sessionID) > 0 {
917 slog.Info("Clearing queued prompts", "session_id", sessionID)
918 a.messageQueue.Del(sessionID)
919 }
920}
921
922func (a *sessionAgent) ClearQueue(sessionID string) {
923 if a.QueuedPrompts(sessionID) > 0 {
924 slog.Info("Clearing queued prompts", "session_id", sessionID)
925 a.messageQueue.Del(sessionID)
926 }
927}
928
929func (a *sessionAgent) CancelAll() {
930 if !a.IsBusy() {
931 return
932 }
933 for key := range a.activeRequests.Seq2() {
934 a.Cancel(key) // key is sessionID
935 }
936
937 timeout := time.After(5 * time.Second)
938 for a.IsBusy() {
939 select {
940 case <-timeout:
941 return
942 default:
943 time.Sleep(200 * time.Millisecond)
944 }
945 }
946}
947
948func (a *sessionAgent) IsBusy() bool {
949 var busy bool
950 for cancelFunc := range a.activeRequests.Seq() {
951 if cancelFunc != nil {
952 busy = true
953 break
954 }
955 }
956 return busy
957}
958
959func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
960 _, busy := a.activeRequests.Get(sessionID)
961 return busy
962}
963
964func (a *sessionAgent) QueuedPrompts(sessionID string) int {
965 l, ok := a.messageQueue.Get(sessionID)
966 if !ok {
967 return 0
968 }
969 return len(l)
970}
971
972func (a *sessionAgent) QueuedPromptsList(sessionID string) []string {
973 l, ok := a.messageQueue.Get(sessionID)
974 if !ok {
975 return nil
976 }
977 prompts := make([]string, len(l))
978 for i, call := range l {
979 prompts[i] = call.Prompt
980 }
981 return prompts
982}
983
984func (a *sessionAgent) SetModels(large Model, small Model) {
985 a.largeModel.Set(large)
986 a.smallModel.Set(small)
987}
988
989func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
990 a.tools.SetSlice(tools)
991}
992
993func (a *sessionAgent) SetSystemPrompt(systemPrompt string) {
994 a.systemPrompt.Set(systemPrompt)
995}
996
997func (a *sessionAgent) Model() Model {
998 return a.largeModel.Get()
999}
1000
1001// convertToToolResult converts a fantasy tool result to a message tool result.
1002func (a *sessionAgent) convertToToolResult(result fantasy.ToolResultContent) message.ToolResult {
1003 baseResult := message.ToolResult{
1004 ToolCallID: result.ToolCallID,
1005 Name: result.ToolName,
1006 Metadata: result.ClientMetadata,
1007 }
1008
1009 switch result.Result.GetType() {
1010 case fantasy.ToolResultContentTypeText:
1011 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result); ok {
1012 baseResult.Content = r.Text
1013 }
1014 case fantasy.ToolResultContentTypeError:
1015 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result); ok {
1016 baseResult.Content = r.Error.Error()
1017 baseResult.IsError = true
1018 }
1019 case fantasy.ToolResultContentTypeMedia:
1020 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](result.Result); ok {
1021 content := r.Text
1022 if content == "" {
1023 content = fmt.Sprintf("Loaded %s content", r.MediaType)
1024 }
1025 baseResult.Content = content
1026 baseResult.Data = r.Data
1027 baseResult.MIMEType = r.MediaType
1028 }
1029 }
1030
1031 return baseResult
1032}
1033
1034// workaroundProviderMediaLimitations converts media content in tool results to
1035// user messages for providers that don't natively support images in tool results.
1036//
1037// Problem: OpenAI, Google, OpenRouter, and other OpenAI-compatible providers
1038// don't support sending images/media in tool result messages - they only accept
1039// text in tool results. However, they DO support images in user messages.
1040//
1041// If we send media in tool results to these providers, the API returns an error.
1042//
1043// Solution: For these providers, we:
1044// 1. Replace the media in the tool result with a text placeholder
1045// 2. Inject a user message immediately after with the image as a file attachment
1046// 3. This maintains the tool execution flow while working around API limitations
1047//
1048// Anthropic and Bedrock support images natively in tool results, so we skip
1049// this workaround for them.
1050//
1051// Example transformation:
1052//
1053// BEFORE: [tool result: image data]
1054// AFTER: [tool result: "Image loaded - see attached"], [user: image attachment]
1055func (a *sessionAgent) workaroundProviderMediaLimitations(messages []fantasy.Message, largeModel Model) []fantasy.Message {
1056 providerSupportsMedia := largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderAnthropic) ||
1057 largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderBedrock)
1058
1059 if providerSupportsMedia {
1060 return messages
1061 }
1062
1063 convertedMessages := make([]fantasy.Message, 0, len(messages))
1064
1065 for _, msg := range messages {
1066 if msg.Role != fantasy.MessageRoleTool {
1067 convertedMessages = append(convertedMessages, msg)
1068 continue
1069 }
1070
1071 textParts := make([]fantasy.MessagePart, 0, len(msg.Content))
1072 var mediaFiles []fantasy.FilePart
1073
1074 for _, part := range msg.Content {
1075 toolResult, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
1076 if !ok {
1077 textParts = append(textParts, part)
1078 continue
1079 }
1080
1081 if media, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](toolResult.Output); ok {
1082 decoded, err := base64.StdEncoding.DecodeString(media.Data)
1083 if err != nil {
1084 slog.Warn("failed to decode media data", "error", err)
1085 textParts = append(textParts, part)
1086 continue
1087 }
1088
1089 mediaFiles = append(mediaFiles, fantasy.FilePart{
1090 Data: decoded,
1091 MediaType: media.MediaType,
1092 Filename: fmt.Sprintf("tool-result-%s", toolResult.ToolCallID),
1093 })
1094
1095 textParts = append(textParts, fantasy.ToolResultPart{
1096 ToolCallID: toolResult.ToolCallID,
1097 Output: fantasy.ToolResultOutputContentText{
1098 Text: "[Image/media content loaded - see attached file]",
1099 },
1100 ProviderOptions: toolResult.ProviderOptions,
1101 })
1102 } else {
1103 textParts = append(textParts, part)
1104 }
1105 }
1106
1107 convertedMessages = append(convertedMessages, fantasy.Message{
1108 Role: fantasy.MessageRoleTool,
1109 Content: textParts,
1110 })
1111
1112 if len(mediaFiles) > 0 {
1113 convertedMessages = append(convertedMessages, fantasy.NewUserMessage(
1114 "Here is the media content from the tool result:",
1115 mediaFiles...,
1116 ))
1117 }
1118 }
1119
1120 return convertedMessages
1121}
1122
1123// buildSummaryPrompt constructs the prompt text for session summarization.
1124func buildSummaryPrompt(todos []session.Todo) string {
1125 var sb strings.Builder
1126 sb.WriteString("Provide a detailed summary of our conversation above.")
1127 if len(todos) > 0 {
1128 sb.WriteString("\n\n## Current Todo List\n\n")
1129 for _, t := range todos {
1130 fmt.Fprintf(&sb, "- [%s] %s\n", t.Status, t.Content)
1131 }
1132 sb.WriteString("\nInclude these tasks and their statuses in your summary. ")
1133 sb.WriteString("Instruct the resuming assistant to use the `todos` tool to continue tracking progress on these tasks.")
1134 }
1135 return sb.String()
1136}