1// Package agent is the core orchestration layer for Crush AI agents.
2//
3// It provides session-based AI agent functionality for managing
4// conversations, tool execution, and message handling. It coordinates
5// interactions between language models, messages, sessions, and tools while
6// handling features like automatic summarization, queuing, and token
7// management.
8package agent
9
10import (
11 "cmp"
12 "context"
13 _ "embed"
14 "encoding/base64"
15 "errors"
16 "fmt"
17 "log/slog"
18 "os"
19 "regexp"
20 "strconv"
21 "strings"
22 "sync"
23 "time"
24
25 "charm.land/fantasy"
26 "charm.land/fantasy/providers/anthropic"
27 "charm.land/fantasy/providers/bedrock"
28 "charm.land/fantasy/providers/google"
29 "charm.land/fantasy/providers/openai"
30 "charm.land/fantasy/providers/openrouter"
31 "charm.land/lipgloss/v2"
32 "github.com/charmbracelet/catwalk/pkg/catwalk"
33 "github.com/charmbracelet/crush/internal/agent/hyper"
34 "github.com/charmbracelet/crush/internal/agent/tools"
35 "github.com/charmbracelet/crush/internal/agent/tools/mcp"
36 "github.com/charmbracelet/crush/internal/config"
37 "github.com/charmbracelet/crush/internal/csync"
38 "github.com/charmbracelet/crush/internal/message"
39 "github.com/charmbracelet/crush/internal/permission"
40 "github.com/charmbracelet/crush/internal/session"
41 "github.com/charmbracelet/crush/internal/stringext"
42 "github.com/charmbracelet/x/exp/charmtone"
43)
44
const (
	// defaultSessionName is the title stored on a session when no title could
	// be generated for it.
	defaultSessionName = "Untitled Session"

	// Constants for auto-summarization thresholds.
	//
	// Models whose context window is larger than largeContextWindowThreshold
	// tokens trigger summarization once largeContextWindowBuffer or fewer
	// tokens remain. Models at or below that size trigger it once the
	// remaining share of the window drops to smallContextWindowRatio.
	largeContextWindowThreshold = 200_000
	largeContextWindowBuffer    = 20_000
	smallContextWindowRatio     = 0.2
)
53
// titlePrompt holds the embedded contents of templates/title.md. It is used
// as the system prompt when generating session titles.
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt holds the embedded contents of templates/summary.md. It is
// used as the system prompt when summarizing a session.
//
//go:embed templates/summary.md
var summaryPrompt []byte

// Used to remove <think> tags from generated titles. The pattern is lazy and
// does not match across newlines; the title has its newlines replaced by
// spaces before this expression is applied.
var thinkTagRegex = regexp.MustCompile(`<think>.*?</think>`)
62
// SessionAgentCall describes a single prompt to run against a session.
type SessionAgentCall struct {
	// SessionID identifies the session the prompt belongs to. Required.
	SessionID string
	// Prompt is the user's text. It may be empty only when Attachments
	// contains a text attachment.
	Prompt string
	// ProviderOptions is forwarded unchanged to the model stream call.
	ProviderOptions fantasy.ProviderOptions
	// Attachments are stored with the user message; non-text attachments are
	// also sent to the model as files.
	Attachments []message.Attachment
	// MaxOutputTokens and the sampling parameters below are forwarded to the
	// model stream call as given.
	MaxOutputTokens  int64
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
}
75
// SessionAgent runs prompts against sessions and exposes control over
// in-flight and queued work.
type SessionAgent interface {
	// Run executes the call, or queues it when the session is already busy
	// (in which case both return values are nil).
	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
	// SetModels replaces the large and small models.
	SetModels(large Model, small Model)
	// SetTools replaces the tools offered to the model.
	SetTools(tools []fantasy.AgentTool)
	// SetSystemPrompt replaces the system prompt.
	SetSystemPrompt(systemPrompt string)
	// Cancel cancels the active request of a session and drops its queued
	// prompts.
	Cancel(sessionID string)
	// CancelAll cancels every active request and waits a bounded time for
	// them to finish.
	CancelAll()
	// IsSessionBusy reports whether the session has an active request.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any session has an active request.
	IsBusy() bool
	// QueuedPrompts returns the number of prompts queued for the session.
	QueuedPrompts(sessionID string) int
	// QueuedPromptsList returns the text of the prompts queued for the
	// session, in queue order.
	QueuedPromptsList(sessionID string) []string
	// ClearQueue drops the prompts queued for the session.
	ClearQueue(sessionID string)
	// Summarize replaces the session's working history with a model-written
	// summary. The string argument is the session ID.
	Summarize(context.Context, string, fantasy.ProviderOptions) error
	// Model returns a model used by the agent.
	// NOTE(review): presumably the large model; confirm against the
	// implementation.
	Model() Model
}
91
// Model bundles a language model with its catalog entry and the user's
// selection for it.
type Model struct {
	// Model is the language model that requests are sent to.
	Model fantasy.LanguageModel
	// CatwalkCfg is the catalog entry for the model; pricing, context window
	// size and capability flags are read from it.
	CatwalkCfg catwalk.Model
	// ModelCfg is the selected-model configuration; its Model and Provider
	// fields are recorded on assistant messages.
	ModelCfg config.SelectedModel
}
97
// sessionAgent is the SessionAgent implementation returned by
// NewSessionAgent.
type sessionAgent struct {
	// Settings that can be replaced while the agent is in use; each request
	// takes its own copy at the start.
	largeModel         *csync.Value[Model]
	smallModel         *csync.Value[Model]
	systemPromptPrefix *csync.Value[string]
	systemPrompt       *csync.Value[string]
	tools              *csync.Slice[fantasy.AgentTool]

	// isSubAgent suppresses the todo-list reminder that is otherwise placed
	// at the start of the history sent to the model.
	isSubAgent bool
	sessions   session.Service
	messages   message.Service
	// disableAutoSummarize turns off the context-window stop condition that
	// triggers summarization.
	disableAutoSummarize bool
	isYolo               bool

	// messageQueue holds, per session ID, the calls received while that
	// session was busy.
	messageQueue *csync.Map[string, []SessionAgentCall]
	// activeRequests holds, per session ID, the cancel function of the
	// request currently running; presence of a key marks the session busy.
	activeRequests *csync.Map[string, context.CancelFunc]
}
114
// SessionAgentOptions carries the initial settings and dependencies passed
// to NewSessionAgent.
type SessionAgentOptions struct {
	// LargeModel handles prompts and summaries; SmallModel is tried first for
	// title generation.
	LargeModel Model
	SmallModel Model
	// SystemPromptPrefix, when non-empty, is sent as an additional system
	// message ahead of all other messages.
	SystemPromptPrefix string
	SystemPrompt       string
	// IsSubAgent suppresses the todo-list reminder in the prompt history.
	IsSubAgent bool
	// DisableAutoSummarize turns off automatic summarization when the
	// context window runs low.
	DisableAutoSummarize bool
	IsYolo               bool
	Sessions             session.Service
	Messages             message.Service
	Tools                []fantasy.AgentTool
}
127
128func NewSessionAgent(
129 opts SessionAgentOptions,
130) SessionAgent {
131 return &sessionAgent{
132 largeModel: csync.NewValue(opts.LargeModel),
133 smallModel: csync.NewValue(opts.SmallModel),
134 systemPromptPrefix: csync.NewValue(opts.SystemPromptPrefix),
135 systemPrompt: csync.NewValue(opts.SystemPrompt),
136 isSubAgent: opts.IsSubAgent,
137 sessions: opts.Sessions,
138 messages: opts.Messages,
139 disableAutoSummarize: opts.DisableAutoSummarize,
140 tools: csync.NewSliceFrom(opts.Tools),
141 isYolo: opts.IsYolo,
142 messageQueue: csync.NewMap[string, []SessionAgentCall](),
143 activeRequests: csync.NewMap[string, context.CancelFunc](),
144 }
145}
146
// Run executes a single prompt against the session named by call.SessionID
// using the large model.
//
// It returns ErrEmptyPrompt when the call has neither a prompt nor a text
// attachment, and ErrSessionMissing when no session ID is given. If the
// session already has an active request, the call is appended to that
// session's queue and Run returns (nil, nil) immediately.
//
// Otherwise the user message is stored and the model response is streamed;
// every streaming callback persists the in-progress assistant message. When
// the stream fails, unfinished tool calls are closed out, missing tool
// results are synthesized, and the failure is recorded on the assistant
// message before the error is returned. When the remaining context window
// falls to the auto-summarization threshold the session is summarized. Once
// the request has finished, calls still queued for the session are processed
// by calling Run again.
func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
	if call.Prompt == "" && !message.ContainsTextAttachment(call.Attachments) {
		return nil, ErrEmptyPrompt
	}
	if call.SessionID == "" {
		return nil, ErrSessionMissing
	}

	// Queue the message if busy. The queued call is picked up either by the
	// running request's next PrepareStep or after that request finishes.
	if a.IsSessionBusy(call.SessionID) {
		existing, ok := a.messageQueue.Get(call.SessionID)
		if !ok {
			existing = []SessionAgentCall{}
		}
		existing = append(existing, call)
		a.messageQueue.Set(call.SessionID, existing)
		return nil, nil
	}

	// Copy mutable fields under lock to avoid races with SetTools/SetModels.
	agentTools := a.tools.Copy()
	largeModel := a.largeModel.Get()
	systemPrompt := a.systemPrompt.Get()
	promptPrefix := a.systemPromptPrefix.Get()
	var instructions strings.Builder

	// Collect the instructions advertised by every connected MCP server.
	for _, server := range mcp.GetStates() {
		if server.State != mcp.StateConnected {
			continue
		}
		if s := server.Client.InitializeResult().Instructions; s != "" {
			instructions.WriteString(s)
			instructions.WriteString("\n\n")
		}
	}

	// Append the collected instructions to the system prompt, if any.
	if s := instructions.String(); s != "" {
		systemPrompt += "\n\n<mcp-instructions>\n" + s + "\n</mcp-instructions>"
	}

	if len(agentTools) > 0 {
		// Add Anthropic caching to the last tool.
		agentTools[len(agentTools)-1].SetProviderOptions(a.getCacheControlOptions())
	}

	agent := fantasy.NewAgent(
		largeModel.Model,
		fantasy.WithSystemPrompt(systemPrompt),
		fantasy.WithTools(agentTools...),
	)

	// sessionLock serializes the session read-modify-write in OnStepFinish.
	sessionLock := sync.Mutex{}
	currentSession, err := a.sessions.Get(ctx, call.SessionID)
	if err != nil {
		return nil, fmt.Errorf("failed to get session: %w", err)
	}

	msgs, err := a.getSessionMessages(ctx, currentSession)
	if err != nil {
		return nil, fmt.Errorf("failed to get session messages: %w", err)
	}

	var wg sync.WaitGroup
	// Generate title if first message.
	if len(msgs) == 0 {
		titleCtx := ctx // Copy to avoid race with ctx reassignment below.
		wg.Go(func() {
			a.generateTitle(titleCtx, call.SessionID, call.Prompt)
		})
	}
	// Run does not return before title generation has finished.
	defer wg.Wait()

	// Add the user message to the session.
	_, err = a.createUserMessage(ctx, call)
	if err != nil {
		return nil, err
	}

	// Add the session to the context.
	ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)

	// Registering the cancel function marks the session as busy; it stays
	// registered until the deferred Del below or an explicit Del further
	// down.
	genCtx, cancel := context.WithCancel(ctx)
	a.activeRequests.Set(call.SessionID, cancel)

	defer cancel()
	defer a.activeRequests.Del(call.SessionID)

	history, files := a.preparePrompt(msgs, call.Attachments...)

	startTime := time.Now()
	a.eventPromptSent(call.SessionID)

	// currentAssistant is set by PrepareStep at the start of each step and
	// shared with the other callbacks. shouldSummarize is set by the stop
	// condition.
	var currentAssistant *message.Message
	var shouldSummarize bool
	result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
		Prompt:           message.PromptWithTextAttachments(call.Prompt, call.Attachments),
		Files:            files,
		Messages:         history,
		ProviderOptions:  call.ProviderOptions,
		MaxOutputTokens:  &call.MaxOutputTokens,
		TopP:             call.TopP,
		Temperature:      call.Temperature,
		PresencePenalty:  call.PresencePenalty,
		TopK:             call.TopK,
		FrequencyPenalty: call.FrequencyPenalty,
		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
			// Start from the step's messages with all provider options
			// cleared; cache control is re-applied below.
			prepared.Messages = options.Messages
			for i := range prepared.Messages {
				prepared.Messages[i].ProviderOptions = nil
			}

			// Drain prompts queued while this request was running and append
			// them as user messages for this step.
			queuedCalls, _ := a.messageQueue.Get(call.SessionID)
			a.messageQueue.Del(call.SessionID)
			for _, queued := range queuedCalls {
				userMessage, createErr := a.createUserMessage(callContext, queued)
				if createErr != nil {
					return callContext, prepared, createErr
				}
				prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
			}

			prepared.Messages = a.workaroundProviderMediaLimitations(prepared.Messages, largeModel)

			lastSystemRoleInx := 0
			systemMessageUpdated := false
			for i, msg := range prepared.Messages {
				// Only add cache control to the last message.
				if msg.Role == fantasy.MessageRoleSystem {
					lastSystemRoleInx = i
				} else if !systemMessageUpdated {
					prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
					systemMessageUpdated = true
				}
				// Then add cache control to the last 2 messages.
				if i > len(prepared.Messages)-3 {
					prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
				}
			}

			// The prefix is prepended after cache control has been assigned,
			// so it carries no provider options.
			if promptPrefix != "" {
				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(promptPrefix)}, prepared.Messages...)
			}

			// Create the empty assistant message that the streaming callbacks
			// fill in for this step.
			var assistantMsg message.Message
			assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
				Role:     message.Assistant,
				Parts:    []message.ContentPart{},
				Model:    largeModel.ModelCfg.Model,
				Provider: largeModel.ModelCfg.Provider,
			})
			if err != nil {
				return callContext, prepared, err
			}
			// Attach the message ID and model details to the step context.
			callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
			callContext = context.WithValue(callContext, tools.SupportsImagesContextKey, largeModel.CatwalkCfg.SupportsImages)
			callContext = context.WithValue(callContext, tools.ModelNameContextKey, largeModel.CatwalkCfg.Name)
			currentAssistant = &assistantMsg
			return callContext, prepared, err
		},
		OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
			currentAssistant.AppendReasoningContent(reasoning.Text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnReasoningDelta: func(id string, text string) error {
			currentAssistant.AppendReasoningContent(text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
			// Record provider-specific reasoning metadata when present.
			// handle anthropic signature
			if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
				if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
					currentAssistant.AppendReasoningSignature(reasoning.Signature)
				}
			}
			if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
				if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
					currentAssistant.AppendThoughtSignature(reasoning.Signature, reasoning.ToolID)
				}
			}
			if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
				if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
					currentAssistant.SetReasoningResponsesData(reasoning)
				}
			}
			currentAssistant.FinishThinking()
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnTextDelta: func(id string, text string) error {
			// Strip leading newline from initial text content. This is
			// particularly important in non-interactive mode where leading
			// newlines are very visible.
			if len(currentAssistant.Parts) == 0 {
				text = strings.TrimPrefix(text, "\n")
			}

			currentAssistant.AppendContent(text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnToolInputStart: func(id string, toolName string) error {
			// Record the tool call as unfinished; OnToolCall adds the
			// finished version with its input.
			toolCall := message.ToolCall{
				ID:               id,
				Name:             toolName,
				ProviderExecuted: false,
				Finished:         false,
			}
			currentAssistant.AddToolCall(toolCall)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnRetry: func(err *fantasy.ProviderError, delay time.Duration) {
			// TODO: implement
		},
		OnToolCall: func(tc fantasy.ToolCallContent) error {
			toolCall := message.ToolCall{
				ID:               tc.ToolCallID,
				Name:             tc.ToolName,
				Input:            tc.Input,
				ProviderExecuted: false,
				Finished:         true,
			}
			currentAssistant.AddToolCall(toolCall)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnToolResult: func(result fantasy.ToolResultContent) error {
			// Each tool result is stored as its own tool-role message.
			toolResult := a.convertToToolResult(result)
			_, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
				Role: message.Tool,
				Parts: []message.ContentPart{
					toolResult,
				},
			})
			return createMsgErr
		},
		OnStepFinish: func(stepResult fantasy.StepResult) error {
			// Map the provider finish reason; anything unlisted stays
			// "unknown".
			finishReason := message.FinishReasonUnknown
			switch stepResult.FinishReason {
			case fantasy.FinishReasonLength:
				finishReason = message.FinishReasonMaxTokens
			case fantasy.FinishReasonStop:
				finishReason = message.FinishReasonEndTurn
			case fantasy.FinishReasonToolCalls:
				finishReason = message.FinishReasonToolUse
			}
			currentAssistant.AddFinish(finishReason, "", "")
			sessionLock.Lock()
			defer sessionLock.Unlock()

			// Re-read the session before applying this step's usage and
			// saving it back.
			updatedSession, getSessionErr := a.sessions.Get(ctx, call.SessionID)
			if getSessionErr != nil {
				return getSessionErr
			}
			a.updateSessionUsage(largeModel, &updatedSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
			_, sessionErr := a.sessions.Save(ctx, updatedSession)
			if sessionErr != nil {
				return sessionErr
			}
			// The stop condition below reads token counts from currentSession.
			currentSession = updatedSession
			return a.messages.Update(genCtx, *currentAssistant)
		},
		StopWhen: []fantasy.StopCondition{
			// Stop, and request summarization, once the tokens remaining in
			// the context window are at or below the threshold, unless
			// auto-summarization is disabled.
			func(_ []fantasy.StepResult) bool {
				cw := int64(largeModel.CatwalkCfg.ContextWindow)
				tokens := currentSession.CompletionTokens + currentSession.PromptTokens
				remaining := cw - tokens
				var threshold int64
				if cw > largeContextWindowThreshold {
					threshold = largeContextWindowBuffer
				} else {
					threshold = int64(float64(cw) * smallContextWindowRatio)
				}
				if (remaining <= threshold) && !a.disableAutoSummarize {
					shouldSummarize = true
					return true
				}
				return false
			},
		},
	})

	a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))

	if err != nil {
		isCancelErr := errors.Is(err, context.Canceled)
		isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
		// No assistant message was created, so there is nothing to repair.
		if currentAssistant == nil {
			return result, err
		}
		// Ensure we finish thinking on error to close the reasoning state.
		currentAssistant.FinishThinking()
		toolCalls := currentAssistant.ToolCalls()
		// INFO: we use the parent context here because the genCtx has been cancelled.
		msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
		if createErr != nil {
			return nil, createErr
		}
		for _, tc := range toolCalls {
			// Close out tool calls that never completed, giving them an empty
			// JSON object as input.
			if !tc.Finished {
				tc.Finished = true
				tc.Input = "{}"
				currentAssistant.AddToolCall(tc)
				updateErr := a.messages.Update(ctx, *currentAssistant)
				if updateErr != nil {
					return nil, updateErr
				}
			}

			// Look for an existing tool result for this call.
			found := false
			for _, msg := range msgs {
				if msg.Role == message.Tool {
					for _, tr := range msg.ToolResults() {
						if tr.ToolCallID == tc.ID {
							found = true
							break
						}
					}
				}
				if found {
					break
				}
			}
			if found {
				continue
			}
			// No result was stored: add an error result describing why.
			content := "There was an error while executing the tool"
			if isCancelErr {
				content = "Tool execution canceled by user"
			} else if isPermissionErr {
				content = "User denied permission"
			}
			toolResult := message.ToolResult{
				ToolCallID: tc.ID,
				Name:       tc.Name,
				Content:    content,
				IsError:    true,
			}
			_, createErr = a.messages.Create(ctx, currentAssistant.SessionID, message.CreateMessageParams{
				Role: message.Tool,
				Parts: []message.ContentPart{
					toolResult,
				},
			})
			if createErr != nil {
				return nil, createErr
			}
		}
		// Record the failure on the assistant message, choosing the finish
		// reason and text by error kind.
		var fantasyErr *fantasy.Error
		var providerErr *fantasy.ProviderError
		const defaultTitle = "Provider Error"
		linkStyle := lipgloss.NewStyle().Foreground(charmtone.Guac).Underline(true)
		if isCancelErr {
			currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
		} else if isPermissionErr {
			currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "")
		} else if errors.Is(err, hyper.ErrNoCredits) {
			url := hyper.BaseURL()
			link := linkStyle.Hyperlink(url, "id=hyper").Render(url)
			currentAssistant.AddFinish(message.FinishReasonError, "No credits", "You're out of credits. Add more at "+link)
		} else if errors.As(err, &providerErr) {
			// This exact provider message is reported as a Copilot model that
			// has not been enabled.
			if providerErr.Message == "The requested model is not supported." {
				url := "https://github.com/settings/copilot/features"
				link := linkStyle.Hyperlink(url, "id=copilot").Render(url)
				currentAssistant.AddFinish(
					message.FinishReasonError,
					"Copilot model not enabled",
					fmt.Sprintf("%q is not enabled in Copilot. Go to the following page to enable it. Then, wait 5 minutes before trying again. %s", largeModel.CatwalkCfg.Name, link),
				)
			} else {
				currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message)
			}
		} else if errors.As(err, &fantasyErr) {
			currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(fantasyErr.Title), defaultTitle), fantasyErr.Message)
		} else {
			currentAssistant.AddFinish(message.FinishReasonError, defaultTitle, err.Error())
		}
		// Note: we use the parent context here because the genCtx has been
		// cancelled.
		updateErr := a.messages.Update(ctx, *currentAssistant)
		if updateErr != nil {
			return nil, updateErr
		}
		return nil, err
	}

	if shouldSummarize {
		// Summarize returns ErrSessionBusy while the session is registered as
		// active, so the entry is removed first.
		a.activeRequests.Del(call.SessionID)
		if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
			return nil, summarizeErr
		}
		// If the agent wasn't done...
		// ...queue a follow-up prompt that restates the original request.
		if len(currentAssistant.ToolCalls()) > 0 {
			existing, ok := a.messageQueue.Get(call.SessionID)
			if !ok {
				existing = []SessionAgentCall{}
			}
			call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
			existing = append(existing, call)
			a.messageQueue.Set(call.SessionID, existing)
		}
	}

	// Release active request before processing queued messages.
	a.activeRequests.Del(call.SessionID)
	cancel()

	queuedMessages, ok := a.messageQueue.Get(call.SessionID)
	if !ok || len(queuedMessages) == 0 {
		return result, err
	}
	// There are queued messages restart the loop.
	firstQueuedMessage := queuedMessages[0]
	a.messageQueue.Set(call.SessionID, queuedMessages[1:])
	return a.Run(ctx, firstQueuedMessage)
}
559
560func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
561 if a.IsSessionBusy(sessionID) {
562 return ErrSessionBusy
563 }
564
565 // Copy mutable fields under lock to avoid races with SetModels.
566 largeModel := a.largeModel.Get()
567 systemPromptPrefix := a.systemPromptPrefix.Get()
568
569 currentSession, err := a.sessions.Get(ctx, sessionID)
570 if err != nil {
571 return fmt.Errorf("failed to get session: %w", err)
572 }
573 msgs, err := a.getSessionMessages(ctx, currentSession)
574 if err != nil {
575 return err
576 }
577 if len(msgs) == 0 {
578 // Nothing to summarize.
579 return nil
580 }
581
582 aiMsgs, _ := a.preparePrompt(msgs)
583
584 genCtx, cancel := context.WithCancel(ctx)
585 a.activeRequests.Set(sessionID, cancel)
586 defer a.activeRequests.Del(sessionID)
587 defer cancel()
588
589 agent := fantasy.NewAgent(largeModel.Model,
590 fantasy.WithSystemPrompt(string(summaryPrompt)),
591 )
592 summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
593 Role: message.Assistant,
594 Model: largeModel.Model.Model(),
595 Provider: largeModel.Model.Provider(),
596 IsSummaryMessage: true,
597 })
598 if err != nil {
599 return err
600 }
601
602 summaryPromptText := buildSummaryPrompt(currentSession.Todos)
603
604 resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
605 Prompt: summaryPromptText,
606 Messages: aiMsgs,
607 ProviderOptions: opts,
608 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
609 prepared.Messages = options.Messages
610 if systemPromptPrefix != "" {
611 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(systemPromptPrefix)}, prepared.Messages...)
612 }
613 return callContext, prepared, nil
614 },
615 OnReasoningDelta: func(id string, text string) error {
616 summaryMessage.AppendReasoningContent(text)
617 return a.messages.Update(genCtx, summaryMessage)
618 },
619 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
620 // Handle anthropic signature.
621 if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
622 if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
623 summaryMessage.AppendReasoningSignature(signature.Signature)
624 }
625 }
626 summaryMessage.FinishThinking()
627 return a.messages.Update(genCtx, summaryMessage)
628 },
629 OnTextDelta: func(id, text string) error {
630 summaryMessage.AppendContent(text)
631 return a.messages.Update(genCtx, summaryMessage)
632 },
633 })
634 if err != nil {
635 isCancelErr := errors.Is(err, context.Canceled)
636 if isCancelErr {
637 // User cancelled summarize we need to remove the summary message.
638 deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
639 return deleteErr
640 }
641 return err
642 }
643
644 summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
645 err = a.messages.Update(genCtx, summaryMessage)
646 if err != nil {
647 return err
648 }
649
650 var openrouterCost *float64
651 for _, step := range resp.Steps {
652 stepCost := a.openrouterCost(step.ProviderMetadata)
653 if stepCost != nil {
654 newCost := *stepCost
655 if openrouterCost != nil {
656 newCost += *openrouterCost
657 }
658 openrouterCost = &newCost
659 }
660 }
661
662 a.updateSessionUsage(largeModel, ¤tSession, resp.TotalUsage, openrouterCost)
663
664 // Just in case, get just the last usage info.
665 usage := resp.Response.Usage
666 currentSession.SummaryMessageID = summaryMessage.ID
667 currentSession.CompletionTokens = usage.OutputTokens
668 currentSession.PromptTokens = 0
669 _, err = a.sessions.Save(genCtx, currentSession)
670 return err
671}
672
673func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
674 if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
675 return fantasy.ProviderOptions{}
676 }
677 return fantasy.ProviderOptions{
678 anthropic.Name: &anthropic.ProviderCacheControlOptions{
679 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
680 },
681 bedrock.Name: &anthropic.ProviderCacheControlOptions{
682 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
683 },
684 }
685}
686
687func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
688 parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
689 var attachmentParts []message.ContentPart
690 for _, attachment := range call.Attachments {
691 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
692 }
693 parts = append(parts, attachmentParts...)
694 msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
695 Role: message.User,
696 Parts: parts,
697 })
698 if err != nil {
699 return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
700 }
701 return msg, nil
702}
703
704func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
705 var history []fantasy.Message
706 if !a.isSubAgent {
707 history = append(history, fantasy.NewUserMessage(
708 fmt.Sprintf("<system_reminder>%s</system_reminder>",
709 `This is a reminder that your todo list is currently empty. DO NOT mention this to the user explicitly because they are already aware.
710If you are working on tasks that would benefit from a todo list please use the "todos" tool to create one.
711If not, please feel free to ignore. Again do not mention this message to the user.`,
712 ),
713 ))
714 }
715 for _, m := range msgs {
716 if len(m.Parts) == 0 {
717 continue
718 }
719 // Assistant message without content or tool calls (cancelled before it
720 // returned anything).
721 if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
722 continue
723 }
724 history = append(history, m.ToAIMessage()...)
725 }
726
727 var files []fantasy.FilePart
728 for _, attachment := range attachments {
729 if attachment.IsText() {
730 continue
731 }
732 files = append(files, fantasy.FilePart{
733 Filename: attachment.FileName,
734 Data: attachment.Content,
735 MediaType: attachment.MimeType,
736 })
737 }
738
739 return history, files
740}
741
742func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
743 msgs, err := a.messages.List(ctx, session.ID)
744 if err != nil {
745 return nil, fmt.Errorf("failed to list messages: %w", err)
746 }
747
748 if session.SummaryMessageID != "" {
749 summaryMsgIndex := -1
750 for i, msg := range msgs {
751 if msg.ID == session.SummaryMessageID {
752 summaryMsgIndex = i
753 break
754 }
755 }
756 if summaryMsgIndex != -1 {
757 msgs = msgs[summaryMsgIndex:]
758 msgs[0].Role = message.User
759 }
760 }
761 return msgs, nil
762}
763
764// generateTitle generates a session titled based on the initial prompt.
765func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, userPrompt string) {
766 if userPrompt == "" {
767 return
768 }
769
770 smallModel := a.smallModel.Get()
771 largeModel := a.largeModel.Get()
772 systemPromptPrefix := a.systemPromptPrefix.Get()
773
774 var maxOutputTokens int64 = 40
775 if smallModel.CatwalkCfg.CanReason {
776 maxOutputTokens = smallModel.CatwalkCfg.DefaultMaxTokens
777 }
778
779 newAgent := func(m fantasy.LanguageModel, p []byte, tok int64) fantasy.Agent {
780 return fantasy.NewAgent(m,
781 fantasy.WithSystemPrompt(string(p)+"\n /no_think"),
782 fantasy.WithMaxOutputTokens(tok),
783 )
784 }
785
786 streamCall := fantasy.AgentStreamCall{
787 Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", userPrompt),
788 PrepareStep: func(callCtx context.Context, opts fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
789 prepared.Messages = opts.Messages
790 if systemPromptPrefix != "" {
791 prepared.Messages = append([]fantasy.Message{
792 fantasy.NewSystemMessage(systemPromptPrefix),
793 }, prepared.Messages...)
794 }
795 return callCtx, prepared, nil
796 },
797 }
798
799 // Use the small model to generate the title.
800 model := smallModel
801 agent := newAgent(model.Model, titlePrompt, maxOutputTokens)
802 resp, err := agent.Stream(ctx, streamCall)
803 if err == nil {
804 // We successfully generated a title with the small model.
805 slog.Debug("Generated title with small model")
806 } else {
807 // It didn't work. Let's try with the big model.
808 slog.Error("Error generating title with small model; trying big model", "err", err)
809 model = largeModel
810 agent = newAgent(model.Model, titlePrompt, maxOutputTokens)
811 resp, err = agent.Stream(ctx, streamCall)
812 if err == nil {
813 slog.Debug("Generated title with large model")
814 } else {
815 // Welp, the large model didn't work either. Use the default
816 // session name and return.
817 slog.Error("Error generating title with large model", "err", err)
818 saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, defaultSessionName, 0, 0, 0)
819 if saveErr != nil {
820 slog.Error("Failed to save session title and usage", "error", saveErr)
821 }
822 return
823 }
824 }
825
826 if resp == nil {
827 // Actually, we didn't get a response so we can't. Use the default
828 // session name and return.
829 slog.Error("Response is nil; can't generate title")
830 saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, defaultSessionName, 0, 0, 0)
831 if saveErr != nil {
832 slog.Error("Failed to save session title and usage", "error", saveErr)
833 }
834 return
835 }
836
837 // Clean up title.
838 var title string
839 title = strings.ReplaceAll(resp.Response.Content.Text(), "\n", " ")
840
841 // Remove thinking tags if present.
842 title = thinkTagRegex.ReplaceAllString(title, "")
843
844 title = strings.TrimSpace(title)
845 if title == "" {
846 slog.Debug("Empty title; using fallback")
847 title = defaultSessionName
848 }
849
850 // Calculate usage and cost.
851 var openrouterCost *float64
852 for _, step := range resp.Steps {
853 stepCost := a.openrouterCost(step.ProviderMetadata)
854 if stepCost != nil {
855 newCost := *stepCost
856 if openrouterCost != nil {
857 newCost += *openrouterCost
858 }
859 openrouterCost = &newCost
860 }
861 }
862
863 modelConfig := model.CatwalkCfg
864 cost := modelConfig.CostPer1MInCached/1e6*float64(resp.TotalUsage.CacheCreationTokens) +
865 modelConfig.CostPer1MOutCached/1e6*float64(resp.TotalUsage.CacheReadTokens) +
866 modelConfig.CostPer1MIn/1e6*float64(resp.TotalUsage.InputTokens) +
867 modelConfig.CostPer1MOut/1e6*float64(resp.TotalUsage.OutputTokens)
868
869 // Use override cost if available (e.g., from OpenRouter).
870 if openrouterCost != nil {
871 cost = *openrouterCost
872 }
873
874 promptTokens := resp.TotalUsage.InputTokens + resp.TotalUsage.CacheCreationTokens
875 completionTokens := resp.TotalUsage.OutputTokens
876
877 // Atomically update only title and usage fields to avoid overriding other
878 // concurrent session updates.
879 saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, title, promptTokens, completionTokens, cost)
880 if saveErr != nil {
881 slog.Error("Failed to save session title and usage", "error", saveErr)
882 return
883 }
884}
885
886func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
887 openrouterMetadata, ok := metadata[openrouter.Name]
888 if !ok {
889 return nil
890 }
891
892 opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
893 if !ok {
894 return nil
895 }
896 return &opts.Usage.Cost
897}
898
899func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
900 modelConfig := model.CatwalkCfg
901 cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
902 modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
903 modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
904 modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
905
906 a.eventTokensUsed(session.ID, model, usage, cost)
907
908 if overrideCost != nil {
909 session.Cost += *overrideCost
910 } else {
911 session.Cost += cost
912 }
913
914 session.CompletionTokens = usage.OutputTokens
915 session.PromptTokens = usage.InputTokens + usage.CacheReadTokens
916}
917
918func (a *sessionAgent) Cancel(sessionID string) {
919 // Cancel regular requests. Don't use Take() here - we need the entry to
920 // remain in activeRequests so IsBusy() returns true until the goroutine
921 // fully completes (including error handling that may access the DB).
922 // The defer in processRequest will clean up the entry.
923 if cancel, ok := a.activeRequests.Get(sessionID); ok && cancel != nil {
924 slog.Debug("Request cancellation initiated", "session_id", sessionID)
925 cancel()
926 }
927
928 // Also check for summarize requests.
929 if cancel, ok := a.activeRequests.Get(sessionID + "-summarize"); ok && cancel != nil {
930 slog.Debug("Summarize cancellation initiated", "session_id", sessionID)
931 cancel()
932 }
933
934 if a.QueuedPrompts(sessionID) > 0 {
935 slog.Debug("Clearing queued prompts", "session_id", sessionID)
936 a.messageQueue.Del(sessionID)
937 }
938}
939
940func (a *sessionAgent) ClearQueue(sessionID string) {
941 if a.QueuedPrompts(sessionID) > 0 {
942 slog.Debug("Clearing queued prompts", "session_id", sessionID)
943 a.messageQueue.Del(sessionID)
944 }
945}
946
947func (a *sessionAgent) CancelAll() {
948 if !a.IsBusy() {
949 return
950 }
951 for key := range a.activeRequests.Seq2() {
952 a.Cancel(key) // key is sessionID
953 }
954
955 timeout := time.After(5 * time.Second)
956 for a.IsBusy() {
957 select {
958 case <-timeout:
959 return
960 default:
961 time.Sleep(200 * time.Millisecond)
962 }
963 }
964}
965
966func (a *sessionAgent) IsBusy() bool {
967 var busy bool
968 for cancelFunc := range a.activeRequests.Seq() {
969 if cancelFunc != nil {
970 busy = true
971 break
972 }
973 }
974 return busy
975}
976
977func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
978 _, busy := a.activeRequests.Get(sessionID)
979 return busy
980}
981
982func (a *sessionAgent) QueuedPrompts(sessionID string) int {
983 l, ok := a.messageQueue.Get(sessionID)
984 if !ok {
985 return 0
986 }
987 return len(l)
988}
989
990func (a *sessionAgent) QueuedPromptsList(sessionID string) []string {
991 l, ok := a.messageQueue.Get(sessionID)
992 if !ok {
993 return nil
994 }
995 prompts := make([]string, len(l))
996 for i, call := range l {
997 prompts[i] = call.Prompt
998 }
999 return prompts
1000}
1001
// SetModels stores large as the agent's large model and small as its small
// model.
func (a *sessionAgent) SetModels(large Model, small Model) {
	a.largeModel.Set(large)
	a.smallModel.Set(small)
}
1006
// SetTools sets the agent's tools from the given slice.
func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
	a.tools.SetSlice(tools)
}
1010
// SetSystemPrompt stores systemPrompt as the agent's system prompt.
func (a *sessionAgent) SetSystemPrompt(systemPrompt string) {
	a.systemPrompt.Set(systemPrompt)
}
1014
// Model returns the agent's currently stored large model.
func (a *sessionAgent) Model() Model {
	return a.largeModel.Get()
}
1018
1019// convertToToolResult converts a fantasy tool result to a message tool result.
1020func (a *sessionAgent) convertToToolResult(result fantasy.ToolResultContent) message.ToolResult {
1021 baseResult := message.ToolResult{
1022 ToolCallID: result.ToolCallID,
1023 Name: result.ToolName,
1024 Metadata: result.ClientMetadata,
1025 }
1026
1027 switch result.Result.GetType() {
1028 case fantasy.ToolResultContentTypeText:
1029 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result); ok {
1030 baseResult.Content = r.Text
1031 }
1032 case fantasy.ToolResultContentTypeError:
1033 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result); ok {
1034 baseResult.Content = r.Error.Error()
1035 baseResult.IsError = true
1036 }
1037 case fantasy.ToolResultContentTypeMedia:
1038 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](result.Result); ok {
1039 content := r.Text
1040 if content == "" {
1041 content = fmt.Sprintf("Loaded %s content", r.MediaType)
1042 }
1043 baseResult.Content = content
1044 baseResult.Data = r.Data
1045 baseResult.MIMEType = r.MediaType
1046 }
1047 }
1048
1049 return baseResult
1050}
1051
1052// workaroundProviderMediaLimitations converts media content in tool results to
1053// user messages for providers that don't natively support images in tool results.
1054//
1055// Problem: OpenAI, Google, OpenRouter, and other OpenAI-compatible providers
1056// don't support sending images/media in tool result messages - they only accept
1057// text in tool results. However, they DO support images in user messages.
1058//
1059// If we send media in tool results to these providers, the API returns an error.
1060//
1061// Solution: For these providers, we:
1062// 1. Replace the media in the tool result with a text placeholder
1063// 2. Inject a user message immediately after with the image as a file attachment
1064// 3. This maintains the tool execution flow while working around API limitations
1065//
1066// Anthropic and Bedrock support images natively in tool results, so we skip
1067// this workaround for them.
1068//
1069// Example transformation:
1070//
1071// BEFORE: [tool result: image data]
1072// AFTER: [tool result: "Image loaded - see attached"], [user: image attachment]
1073func (a *sessionAgent) workaroundProviderMediaLimitations(messages []fantasy.Message, largeModel Model) []fantasy.Message {
1074 providerSupportsMedia := largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderAnthropic) ||
1075 largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderBedrock)
1076
1077 if providerSupportsMedia {
1078 return messages
1079 }
1080
1081 convertedMessages := make([]fantasy.Message, 0, len(messages))
1082
1083 for _, msg := range messages {
1084 if msg.Role != fantasy.MessageRoleTool {
1085 convertedMessages = append(convertedMessages, msg)
1086 continue
1087 }
1088
1089 textParts := make([]fantasy.MessagePart, 0, len(msg.Content))
1090 var mediaFiles []fantasy.FilePart
1091
1092 for _, part := range msg.Content {
1093 toolResult, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
1094 if !ok {
1095 textParts = append(textParts, part)
1096 continue
1097 }
1098
1099 if media, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](toolResult.Output); ok {
1100 decoded, err := base64.StdEncoding.DecodeString(media.Data)
1101 if err != nil {
1102 slog.Warn("Failed to decode media data", "error", err)
1103 textParts = append(textParts, part)
1104 continue
1105 }
1106
1107 mediaFiles = append(mediaFiles, fantasy.FilePart{
1108 Data: decoded,
1109 MediaType: media.MediaType,
1110 Filename: fmt.Sprintf("tool-result-%s", toolResult.ToolCallID),
1111 })
1112
1113 textParts = append(textParts, fantasy.ToolResultPart{
1114 ToolCallID: toolResult.ToolCallID,
1115 Output: fantasy.ToolResultOutputContentText{
1116 Text: "[Image/media content loaded - see attached file]",
1117 },
1118 ProviderOptions: toolResult.ProviderOptions,
1119 })
1120 } else {
1121 textParts = append(textParts, part)
1122 }
1123 }
1124
1125 convertedMessages = append(convertedMessages, fantasy.Message{
1126 Role: fantasy.MessageRoleTool,
1127 Content: textParts,
1128 })
1129
1130 if len(mediaFiles) > 0 {
1131 convertedMessages = append(convertedMessages, fantasy.NewUserMessage(
1132 "Here is the media content from the tool result:",
1133 mediaFiles...,
1134 ))
1135 }
1136 }
1137
1138 return convertedMessages
1139}
1140
1141// buildSummaryPrompt constructs the prompt text for session summarization.
1142func buildSummaryPrompt(todos []session.Todo) string {
1143 var sb strings.Builder
1144 sb.WriteString("Provide a detailed summary of our conversation above.")
1145 if len(todos) > 0 {
1146 sb.WriteString("\n\n## Current Todo List\n\n")
1147 for _, t := range todos {
1148 fmt.Fprintf(&sb, "- [%s] %s\n", t.Status, t.Content)
1149 }
1150 sb.WriteString("\nInclude these tasks and their statuses in your summary. ")
1151 sb.WriteString("Instruct the resuming assistant to use the `todos` tool to continue tracking progress on these tasks.")
1152 }
1153 return sb.String()
1154}