1// Package agent is the core orchestration layer for Crush AI agents.
2//
3// It provides session-based AI agent functionality for managing
4// conversations, tool execution, and message handling. It coordinates
5// interactions between language models, messages, sessions, and tools while
6// handling features like automatic summarization, queuing, and token
7// management.
8package agent
9
10import (
11 "cmp"
12 "context"
13 _ "embed"
14 "encoding/base64"
15 "errors"
16 "fmt"
17 "log/slog"
18 "os"
19 "regexp"
20 "strconv"
21 "strings"
22 "sync"
23 "time"
24
25 "charm.land/catwalk/pkg/catwalk"
26 "charm.land/fantasy"
27 "charm.land/fantasy/providers/anthropic"
28 "charm.land/fantasy/providers/bedrock"
29 "charm.land/fantasy/providers/google"
30 "charm.land/fantasy/providers/openai"
31 "charm.land/fantasy/providers/openrouter"
32 "charm.land/fantasy/providers/vercel"
33 "charm.land/lipgloss/v2"
34 "github.com/charmbracelet/crush/internal/agent/hyper"
35 "github.com/charmbracelet/crush/internal/agent/tools"
36 "github.com/charmbracelet/crush/internal/agent/tools/mcp"
37 "github.com/charmbracelet/crush/internal/config"
38 "github.com/charmbracelet/crush/internal/csync"
39 "github.com/charmbracelet/crush/internal/message"
40 "github.com/charmbracelet/crush/internal/permission"
41 "github.com/charmbracelet/crush/internal/session"
42 "github.com/charmbracelet/crush/internal/stringext"
43 "github.com/charmbracelet/x/exp/charmtone"
44)
45
const (
	// defaultSessionName is the title stored for a session when title
	// generation fails or produces an empty result.
	defaultSessionName = "Untitled Session"

	// Constants for auto-summarization thresholds.

	// largeContextWindowThreshold is the context-window size (in tokens)
	// above which the fixed largeContextWindowBuffer is used as the
	// summarization threshold.
	largeContextWindowThreshold = 200_000
	// largeContextWindowBuffer is the number of remaining tokens at or below
	// which a large-context model triggers summarization.
	largeContextWindowBuffer = 20_000
	// smallContextWindowRatio is the fraction of the context window used as
	// the remaining-token threshold for models at or below
	// largeContextWindowThreshold.
	smallContextWindowRatio = 0.2
)
54
// titlePrompt is the embedded system prompt used by generateTitle.
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt is the embedded system prompt used by Summarize.
//
//go:embed templates/summary.md
var summaryPrompt []byte

// thinkTagRegex is used to remove <think> tags from generated titles. The
// pattern does not match across newlines; generateTitle replaces newlines with
// spaces before applying it.
var thinkTagRegex = regexp.MustCompile(`<think>.*?</think>`)
63
// SessionAgentCall describes a single prompt submitted to a SessionAgent.
type SessionAgentCall struct {
	// SessionID identifies the session the prompt belongs to. Run rejects
	// calls with an empty SessionID.
	SessionID string
	// Prompt is the user's text. It may be empty only when Attachments
	// contains a text attachment.
	Prompt string
	// ProviderOptions is forwarded unchanged to the model stream call.
	ProviderOptions fantasy.ProviderOptions
	// Attachments are stored with the user message; non-text attachments are
	// also sent to the model as files.
	Attachments []message.Attachment
	// MaxOutputTokens is forwarded to the model stream call.
	MaxOutputTokens int64
	// The remaining fields are optional sampling parameters forwarded to the
	// model stream call as-is.
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
}
76
// SessionAgent runs prompts against sessions and exposes control over the
// in-flight and queued work of those sessions.
type SessionAgent interface {
	// Run executes the call, or queues it when the session is already busy
	// (in which case it returns a nil result and a nil error).
	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
	// SetModels replaces the large and small models used by later calls.
	SetModels(large Model, small Model)
	// SetTools replaces the tool set used by later calls.
	SetTools(tools []fantasy.AgentTool)
	// SetSystemPrompt replaces the system prompt.
	// NOTE(review): implementation is not in this part of the file — confirm.
	SetSystemPrompt(systemPrompt string)
	// Cancel cancels the active request of the session, if any, and drops
	// its queued prompts.
	Cancel(sessionID string)
	// CancelAll cancels every active request and waits a bounded time for
	// the agent to become idle.
	CancelAll()
	// IsSessionBusy reports whether the session has an active request.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any session has an active request.
	IsBusy() bool
	// QueuedPrompts returns the number of prompts queued for the session.
	QueuedPrompts(sessionID string) int
	// QueuedPromptsList returns the prompt text of each queued call for the
	// session, in queue order.
	QueuedPromptsList(sessionID string) []string
	// ClearQueue drops the prompts queued for the session.
	ClearQueue(sessionID string)
	// Summarize replaces the session's history with a generated summary. The
	// string argument is the session ID.
	Summarize(context.Context, string, fantasy.ProviderOptions) error
	// Model returns a Model.
	// NOTE(review): presumably the current large model; the implementation
	// is not in this part of the file — confirm.
	Model() Model
}
92
// Model bundles a language model with the metadata and configuration the
// agent needs alongside it.
type Model struct {
	// Model is the language model used to create fantasy agents.
	Model fantasy.LanguageModel
	// CatwalkCfg carries model metadata read by the agent, such as the
	// context window, per-token costs and capability flags.
	CatwalkCfg catwalk.Model
	// ModelCfg is the selected-model configuration; its Model and Provider
	// fields are recorded on assistant messages.
	ModelCfg config.SelectedModel
}
98
// sessionAgent is the SessionAgent implementation returned by
// NewSessionAgent.
type sessionAgent struct {
	// Settings that can be replaced while the agent is in use; they are read
	// through the csync wrappers at the start of each request.
	largeModel         *csync.Value[Model]
	smallModel         *csync.Value[Model]
	systemPromptPrefix *csync.Value[string]
	systemPrompt       *csync.Value[string]
	tools              *csync.Slice[fantasy.AgentTool]

	// isSubAgent suppresses the todo-list reminder that preparePrompt
	// otherwise prepends to the history.
	isSubAgent bool
	sessions   session.Service
	messages   message.Service
	// disableAutoSummarize turns off the context-window stop condition in
	// Run.
	disableAutoSummarize bool
	isYolo               bool

	// messageQueue holds calls received while a session was busy, keyed by
	// session ID.
	messageQueue *csync.Map[string, []SessionAgentCall]
	// activeRequests holds the cancel function of each in-flight request,
	// keyed by session ID. Presence of a key is what marks a session busy.
	activeRequests *csync.Map[string, context.CancelFunc]
}
115
// SessionAgentOptions holds the values NewSessionAgent uses to initialize a
// SessionAgent.
type SessionAgentOptions struct {
	// LargeModel is used for prompts and summaries; SmallModel is tried first
	// for title generation.
	LargeModel Model
	SmallModel Model
	// SystemPromptPrefix, when non-empty, is sent as an extra system message
	// placed before all other messages.
	SystemPromptPrefix string
	SystemPrompt       string
	// IsSubAgent disables the todo-list reminder added to the history.
	IsSubAgent bool
	// DisableAutoSummarize disables summarization when the context window is
	// nearly full.
	DisableAutoSummarize bool
	IsYolo               bool
	Sessions             session.Service
	Messages             message.Service
	Tools                []fantasy.AgentTool
}
128
129func NewSessionAgent(
130 opts SessionAgentOptions,
131) SessionAgent {
132 return &sessionAgent{
133 largeModel: csync.NewValue(opts.LargeModel),
134 smallModel: csync.NewValue(opts.SmallModel),
135 systemPromptPrefix: csync.NewValue(opts.SystemPromptPrefix),
136 systemPrompt: csync.NewValue(opts.SystemPrompt),
137 isSubAgent: opts.IsSubAgent,
138 sessions: opts.Sessions,
139 messages: opts.Messages,
140 disableAutoSummarize: opts.DisableAutoSummarize,
141 tools: csync.NewSliceFrom(opts.Tools),
142 isYolo: opts.IsYolo,
143 messageQueue: csync.NewMap[string, []SessionAgentCall](),
144 activeRequests: csync.NewMap[string, context.CancelFunc](),
145 }
146}
147
// Run executes call against the session identified by call.SessionID and
// persists the streamed response through the message service.
//
// It returns ErrEmptyPrompt when the call has neither a prompt nor a text
// attachment, and ErrSessionMissing when the session ID is empty. If the
// session already has an active request, the call is appended to the
// session's queue and Run returns (nil, nil) immediately.
//
// Otherwise Run stores the user message, registers a cancel function in
// activeRequests, and streams the large model's response. Each step creates a
// new assistant message, and the callbacks update it as content arrives. When
// the stream fails, unfinished tool calls are closed, missing tool results are
// filled in with an error result, a finish reason describing the failure is
// recorded, and the original error is returned. When the stop condition
// requests summarization, Summarize runs after the stream. Finally, if calls
// are queued for the session, Run invokes itself with the first queued call
// and returns that call's result.
func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
	if call.Prompt == "" && !message.ContainsTextAttachment(call.Attachments) {
		return nil, ErrEmptyPrompt
	}
	if call.SessionID == "" {
		return nil, ErrSessionMissing
	}

	// Queue the message if busy. The queued call is picked up either by the
	// running request's PrepareStep or after that request finishes.
	if a.IsSessionBusy(call.SessionID) {
		existing, ok := a.messageQueue.Get(call.SessionID)
		if !ok {
			existing = []SessionAgentCall{}
		}
		existing = append(existing, call)
		a.messageQueue.Set(call.SessionID, existing)
		return nil, nil
	}

	// Copy mutable fields under lock to avoid races with SetTools/SetModels.
	agentTools := a.tools.Copy()
	largeModel := a.largeModel.Get()
	systemPrompt := a.systemPrompt.Get()
	promptPrefix := a.systemPromptPrefix.Get()
	var instructions strings.Builder

	// Collect the instructions advertised by every connected MCP server.
	for _, server := range mcp.GetStates() {
		if server.State != mcp.StateConnected {
			continue
		}
		if s := server.Client.InitializeResult().Instructions; s != "" {
			instructions.WriteString(s)
			instructions.WriteString("\n\n")
		}
	}

	// Append the collected MCP instructions to the system prompt.
	if s := instructions.String(); s != "" {
		systemPrompt += "\n\n<mcp-instructions>\n" + s + "\n</mcp-instructions>"
	}

	if len(agentTools) > 0 {
		// Add Anthropic caching to the last tool.
		agentTools[len(agentTools)-1].SetProviderOptions(a.getCacheControlOptions())
	}

	agent := fantasy.NewAgent(
		largeModel.Model,
		fantasy.WithSystemPrompt(systemPrompt),
		fantasy.WithTools(agentTools...),
	)

	// sessionLock guards the session read-modify-write in OnStepFinish.
	sessionLock := sync.Mutex{}
	currentSession, err := a.sessions.Get(ctx, call.SessionID)
	if err != nil {
		return nil, fmt.Errorf("failed to get session: %w", err)
	}

	msgs, err := a.getSessionMessages(ctx, currentSession)
	if err != nil {
		return nil, fmt.Errorf("failed to get session messages: %w", err)
	}

	var wg sync.WaitGroup
	// Generate title if first message.
	if len(msgs) == 0 {
		titleCtx := ctx // Copy to avoid race with ctx reassignment below.
		wg.Go(func() {
			a.generateTitle(titleCtx, call.SessionID, call.Prompt)
		})
	}
	// Run does not return until title generation has finished.
	defer wg.Wait()

	// Add the user message to the session.
	_, err = a.createUserMessage(ctx, call)
	if err != nil {
		return nil, err
	}

	// Add the session to the context.
	ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)

	// Registering the cancel function is what marks the session as busy.
	genCtx, cancel := context.WithCancel(ctx)
	a.activeRequests.Set(call.SessionID, cancel)

	defer cancel()
	defer a.activeRequests.Del(call.SessionID)

	history, files := a.preparePrompt(msgs, call.Attachments...)

	startTime := time.Now()
	a.eventPromptSent(call.SessionID)

	// currentAssistant points at the assistant message of the step in
	// progress; it is assigned in PrepareStep and updated by the callbacks.
	var currentAssistant *message.Message
	// shouldSummarize is set by the stop condition when the context window
	// is nearly exhausted.
	var shouldSummarize bool
	result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
		Prompt:           message.PromptWithTextAttachments(call.Prompt, call.Attachments),
		Files:            files,
		Messages:         history,
		ProviderOptions:  call.ProviderOptions,
		MaxOutputTokens:  &call.MaxOutputTokens,
		TopP:             call.TopP,
		Temperature:      call.Temperature,
		PresencePenalty:  call.PresencePenalty,
		TopK:             call.TopK,
		FrequencyPenalty: call.FrequencyPenalty,
		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
			// Start from the step's messages with all provider options
			// cleared; cache-control options are re-applied below.
			prepared.Messages = options.Messages
			for i := range prepared.Messages {
				prepared.Messages[i].ProviderOptions = nil
			}

			// Drain prompts that were queued while this request was running
			// and append them to the conversation as user messages.
			queuedCalls, _ := a.messageQueue.Get(call.SessionID)
			a.messageQueue.Del(call.SessionID)
			for _, queued := range queuedCalls {
				userMessage, createErr := a.createUserMessage(callContext, queued)
				if createErr != nil {
					return callContext, prepared, createErr
				}
				prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
			}

			prepared.Messages = a.workaroundProviderMediaLimitations(prepared.Messages, largeModel)

			lastSystemRoleInx := 0
			systemMessageUpdated := false
			for i, msg := range prepared.Messages {
				// On reaching the first non-system message, add cache control
				// to the last system message seen before it.
				// NOTE(review): if no system message precedes it, index 0
				// receives the options instead.
				if msg.Role == fantasy.MessageRoleSystem {
					lastSystemRoleInx = i
				} else if !systemMessageUpdated {
					prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
					systemMessageUpdated = true
				}
				// Then add cache control to the last 2 messages.
				if i > len(prepared.Messages)-3 {
					prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
				}
			}

			// The prefix is inserted after the cache-control pass, so it
			// never carries cache-control options itself.
			if promptPrefix != "" {
				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(promptPrefix)}, prepared.Messages...)
			}

			// Create the empty assistant message that the streaming
			// callbacks will fill in for this step.
			var assistantMsg message.Message
			assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
				Role:     message.Assistant,
				Parts:    []message.ContentPart{},
				Model:    largeModel.ModelCfg.Model,
				Provider: largeModel.ModelCfg.Provider,
			})
			if err != nil {
				return callContext, prepared, err
			}
			// Expose the message ID and model details to tools through the
			// step context.
			callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
			callContext = context.WithValue(callContext, tools.SupportsImagesContextKey, largeModel.CatwalkCfg.SupportsImages)
			callContext = context.WithValue(callContext, tools.ModelNameContextKey, largeModel.CatwalkCfg.Name)
			currentAssistant = &assistantMsg
			return callContext, prepared, err
		},
		OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
			currentAssistant.AppendReasoningContent(reasoning.Text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnReasoningDelta: func(id string, text string) error {
			currentAssistant.AppendReasoningContent(text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
			// Handle the Anthropic signature.
			if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
				if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
					currentAssistant.AppendReasoningSignature(reasoning.Signature)
				}
			}
			// Handle the Google thought signature.
			if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
				if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
					currentAssistant.AppendThoughtSignature(reasoning.Signature, reasoning.ToolID)
				}
			}
			// Handle OpenAI Responses reasoning metadata.
			if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
				if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
					currentAssistant.SetReasoningResponsesData(reasoning)
				}
			}
			currentAssistant.FinishThinking()
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnTextDelta: func(id string, text string) error {
			// Strip leading newline from initial text content. This is
			// particularly important in non-interactive mode where leading
			// newlines are very visible.
			if len(currentAssistant.Parts) == 0 {
				text = strings.TrimPrefix(text, "\n")
			}

			currentAssistant.AppendContent(text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnToolInputStart: func(id string, toolName string) error {
			// Record the tool call as unfinished; OnToolCall adds the
			// finished version once the input is complete.
			toolCall := message.ToolCall{
				ID:               id,
				Name:             toolName,
				ProviderExecuted: false,
				Finished:         false,
			}
			currentAssistant.AddToolCall(toolCall)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnRetry: func(err *fantasy.ProviderError, delay time.Duration) {
			// TODO: implement
		},
		OnToolCall: func(tc fantasy.ToolCallContent) error {
			toolCall := message.ToolCall{
				ID:               tc.ToolCallID,
				Name:             tc.ToolName,
				Input:            tc.Input,
				ProviderExecuted: false,
				Finished:         true,
			}
			currentAssistant.AddToolCall(toolCall)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnToolResult: func(result fantasy.ToolResultContent) error {
			// Tool results are stored as their own message with the Tool
			// role.
			toolResult := a.convertToToolResult(result)
			_, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
				Role: message.Tool,
				Parts: []message.ContentPart{
					toolResult,
				},
			})
			return createMsgErr
		},
		OnStepFinish: func(stepResult fantasy.StepResult) error {
			// Map the provider finish reason onto the message finish reason;
			// anything unrecognized stays FinishReasonUnknown.
			finishReason := message.FinishReasonUnknown
			switch stepResult.FinishReason {
			case fantasy.FinishReasonLength:
				finishReason = message.FinishReasonMaxTokens
			case fantasy.FinishReasonStop:
				finishReason = message.FinishReasonEndTurn
			case fantasy.FinishReasonToolCalls:
				finishReason = message.FinishReasonToolUse
			}
			currentAssistant.AddFinish(finishReason, "", "")
			sessionLock.Lock()
			defer sessionLock.Unlock()

			// Re-read the session before applying usage so this step's update
			// builds on the latest stored state.
			updatedSession, getSessionErr := a.sessions.Get(ctx, call.SessionID)
			if getSessionErr != nil {
				return getSessionErr
			}
			a.updateSessionUsage(largeModel, &updatedSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
			_, sessionErr := a.sessions.Save(ctx, updatedSession)
			if sessionErr != nil {
				return sessionErr
			}
			currentSession = updatedSession
			return a.messages.Update(genCtx, *currentAssistant)
		},
		StopWhen: []fantasy.StopCondition{
			// Stop and request summarization once the remaining context
			// window drops to the threshold, unless auto-summarize is off.
			func(_ []fantasy.StepResult) bool {
				cw := int64(largeModel.CatwalkCfg.ContextWindow)
				tokens := currentSession.CompletionTokens + currentSession.PromptTokens
				remaining := cw - tokens
				var threshold int64
				if cw > largeContextWindowThreshold {
					threshold = largeContextWindowBuffer
				} else {
					threshold = int64(float64(cw) * smallContextWindowRatio)
				}
				if (remaining <= threshold) && !a.disableAutoSummarize {
					shouldSummarize = true
					return true
				}
				return false
			},
		},
	})

	a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))

	if err != nil {
		isCancelErr := errors.Is(err, context.Canceled)
		isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
		// No assistant message was created, so there is nothing to finalize.
		if currentAssistant == nil {
			return result, err
		}
		// Ensure we finish thinking on error to close the reasoning state.
		currentAssistant.FinishThinking()
		toolCalls := currentAssistant.ToolCalls()
		// INFO: we use the parent context here because the genCtx has been cancelled.
		msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
		if createErr != nil {
			return nil, createErr
		}
		for _, tc := range toolCalls {
			// Close tool calls that never completed, giving them an empty
			// JSON object as input.
			if !tc.Finished {
				tc.Finished = true
				tc.Input = "{}"
				currentAssistant.AddToolCall(tc)
				updateErr := a.messages.Update(ctx, *currentAssistant)
				if updateErr != nil {
					return nil, updateErr
				}
			}

			// Look for an existing tool result for this call.
			found := false
			for _, msg := range msgs {
				if msg.Role == message.Tool {
					for _, tr := range msg.ToolResults() {
						if tr.ToolCallID == tc.ID {
							found = true
							break
						}
					}
				}
				if found {
					break
				}
			}
			if found {
				continue
			}
			// No result was stored; add an error result so the tool call is
			// not left without one.
			content := "There was an error while executing the tool"
			if isCancelErr {
				content = "Tool execution canceled by user"
			} else if isPermissionErr {
				content = "User denied permission"
			}
			toolResult := message.ToolResult{
				ToolCallID: tc.ID,
				Name:       tc.Name,
				Content:    content,
				IsError:    true,
			}
			_, createErr = a.messages.Create(ctx, currentAssistant.SessionID, message.CreateMessageParams{
				Role: message.Tool,
				Parts: []message.ContentPart{
					toolResult,
				},
			})
			if createErr != nil {
				return nil, createErr
			}
		}
		// Record a finish reason on the assistant message that describes
		// the failure.
		var fantasyErr *fantasy.Error
		var providerErr *fantasy.ProviderError
		const defaultTitle = "Provider Error"
		linkStyle := lipgloss.NewStyle().Foreground(charmtone.Guac).Underline(true)
		if isCancelErr {
			currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
		} else if isPermissionErr {
			currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "")
		} else if errors.Is(err, hyper.ErrNoCredits) {
			url := hyper.BaseURL()
			link := linkStyle.Hyperlink(url, "id=hyper").Render(url)
			currentAssistant.AddFinish(message.FinishReasonError, "No credits", "You're out of credits. Add more at "+link)
		} else if errors.As(err, &providerErr) {
			// This exact provider message is reported as a Copilot model
			// that has not been enabled, with a link to the settings page.
			if providerErr.Message == "The requested model is not supported." {
				url := "https://github.com/settings/copilot/features"
				link := linkStyle.Hyperlink(url, "id=copilot").Render(url)
				currentAssistant.AddFinish(
					message.FinishReasonError,
					"Copilot model not enabled",
					fmt.Sprintf("%q is not enabled in Copilot. Go to the following page to enable it. Then, wait 5 minutes before trying again. %s", largeModel.CatwalkCfg.Name, link),
				)
			} else {
				currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message)
			}
		} else if errors.As(err, &fantasyErr) {
			currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(fantasyErr.Title), defaultTitle), fantasyErr.Message)
		} else {
			currentAssistant.AddFinish(message.FinishReasonError, defaultTitle, err.Error())
		}
		// Note: we use the parent context here because the genCtx has been
		// cancelled.
		updateErr := a.messages.Update(ctx, *currentAssistant)
		if updateErr != nil {
			return nil, updateErr
		}
		return nil, err
	}

	if shouldSummarize {
		// Summarize refuses to run while the session is marked busy, so
		// release the active request first.
		a.activeRequests.Del(call.SessionID)
		if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
			return nil, summarizeErr
		}
		// If the agent wasn't done (the last assistant message still has
		// tool calls), re-queue the original request so work resumes after
		// the summary.
		if len(currentAssistant.ToolCalls()) > 0 {
			existing, ok := a.messageQueue.Get(call.SessionID)
			if !ok {
				existing = []SessionAgentCall{}
			}
			call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
			existing = append(existing, call)
			a.messageQueue.Set(call.SessionID, existing)
		}
	}

	// Release active request before processing queued messages.
	a.activeRequests.Del(call.SessionID)
	cancel()

	queuedMessages, ok := a.messageQueue.Get(call.SessionID)
	if !ok || len(queuedMessages) == 0 {
		return result, err
	}
	// There are queued messages; restart the loop with the first one.
	firstQueuedMessage := queuedMessages[0]
	a.messageQueue.Set(call.SessionID, queuedMessages[1:])
	return a.Run(ctx, firstQueuedMessage)
}
560
561func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
562 if a.IsSessionBusy(sessionID) {
563 return ErrSessionBusy
564 }
565
566 // Copy mutable fields under lock to avoid races with SetModels.
567 largeModel := a.largeModel.Get()
568 systemPromptPrefix := a.systemPromptPrefix.Get()
569
570 currentSession, err := a.sessions.Get(ctx, sessionID)
571 if err != nil {
572 return fmt.Errorf("failed to get session: %w", err)
573 }
574 msgs, err := a.getSessionMessages(ctx, currentSession)
575 if err != nil {
576 return err
577 }
578 if len(msgs) == 0 {
579 // Nothing to summarize.
580 return nil
581 }
582
583 aiMsgs, _ := a.preparePrompt(msgs)
584
585 genCtx, cancel := context.WithCancel(ctx)
586 a.activeRequests.Set(sessionID, cancel)
587 defer a.activeRequests.Del(sessionID)
588 defer cancel()
589
590 agent := fantasy.NewAgent(largeModel.Model,
591 fantasy.WithSystemPrompt(string(summaryPrompt)),
592 )
593 summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
594 Role: message.Assistant,
595 Model: largeModel.Model.Model(),
596 Provider: largeModel.Model.Provider(),
597 IsSummaryMessage: true,
598 })
599 if err != nil {
600 return err
601 }
602
603 summaryPromptText := buildSummaryPrompt(currentSession.Todos)
604
605 resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
606 Prompt: summaryPromptText,
607 Messages: aiMsgs,
608 ProviderOptions: opts,
609 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
610 prepared.Messages = options.Messages
611 if systemPromptPrefix != "" {
612 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(systemPromptPrefix)}, prepared.Messages...)
613 }
614 return callContext, prepared, nil
615 },
616 OnReasoningDelta: func(id string, text string) error {
617 summaryMessage.AppendReasoningContent(text)
618 return a.messages.Update(genCtx, summaryMessage)
619 },
620 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
621 // Handle anthropic signature.
622 if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
623 if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
624 summaryMessage.AppendReasoningSignature(signature.Signature)
625 }
626 }
627 summaryMessage.FinishThinking()
628 return a.messages.Update(genCtx, summaryMessage)
629 },
630 OnTextDelta: func(id, text string) error {
631 summaryMessage.AppendContent(text)
632 return a.messages.Update(genCtx, summaryMessage)
633 },
634 })
635 if err != nil {
636 isCancelErr := errors.Is(err, context.Canceled)
637 if isCancelErr {
638 // User cancelled summarize we need to remove the summary message.
639 deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
640 return deleteErr
641 }
642 return err
643 }
644
645 summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
646 err = a.messages.Update(genCtx, summaryMessage)
647 if err != nil {
648 return err
649 }
650
651 var openrouterCost *float64
652 for _, step := range resp.Steps {
653 stepCost := a.openrouterCost(step.ProviderMetadata)
654 if stepCost != nil {
655 newCost := *stepCost
656 if openrouterCost != nil {
657 newCost += *openrouterCost
658 }
659 openrouterCost = &newCost
660 }
661 }
662
663 a.updateSessionUsage(largeModel, ¤tSession, resp.TotalUsage, openrouterCost)
664
665 // Just in case, get just the last usage info.
666 usage := resp.Response.Usage
667 currentSession.SummaryMessageID = summaryMessage.ID
668 currentSession.CompletionTokens = usage.OutputTokens
669 currentSession.PromptTokens = 0
670 _, err = a.sessions.Save(genCtx, currentSession)
671 return err
672}
673
674func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
675 if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
676 return fantasy.ProviderOptions{}
677 }
678 return fantasy.ProviderOptions{
679 anthropic.Name: &anthropic.ProviderCacheControlOptions{
680 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
681 },
682 bedrock.Name: &anthropic.ProviderCacheControlOptions{
683 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
684 },
685 vercel.Name: &anthropic.ProviderCacheControlOptions{
686 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
687 },
688 }
689}
690
691func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
692 parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
693 var attachmentParts []message.ContentPart
694 for _, attachment := range call.Attachments {
695 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
696 }
697 parts = append(parts, attachmentParts...)
698 msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
699 Role: message.User,
700 Parts: parts,
701 })
702 if err != nil {
703 return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
704 }
705 return msg, nil
706}
707
708func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
709 var history []fantasy.Message
710 if !a.isSubAgent {
711 history = append(history, fantasy.NewUserMessage(
712 fmt.Sprintf("<system_reminder>%s</system_reminder>",
713 `This is a reminder that your todo list is currently empty. DO NOT mention this to the user explicitly because they are already aware.
714If you are working on tasks that would benefit from a todo list please use the "todos" tool to create one.
715If not, please feel free to ignore. Again do not mention this message to the user.`,
716 ),
717 ))
718 }
719 for _, m := range msgs {
720 if len(m.Parts) == 0 {
721 continue
722 }
723 // Assistant message without content or tool calls (cancelled before it
724 // returned anything).
725 if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
726 continue
727 }
728 history = append(history, m.ToAIMessage()...)
729 }
730
731 var files []fantasy.FilePart
732 for _, attachment := range attachments {
733 if attachment.IsText() {
734 continue
735 }
736 files = append(files, fantasy.FilePart{
737 Filename: attachment.FileName,
738 Data: attachment.Content,
739 MediaType: attachment.MimeType,
740 })
741 }
742
743 return history, files
744}
745
746func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
747 msgs, err := a.messages.List(ctx, session.ID)
748 if err != nil {
749 return nil, fmt.Errorf("failed to list messages: %w", err)
750 }
751
752 if session.SummaryMessageID != "" {
753 summaryMsgIndex := -1
754 for i, msg := range msgs {
755 if msg.ID == session.SummaryMessageID {
756 summaryMsgIndex = i
757 break
758 }
759 }
760 if summaryMsgIndex != -1 {
761 msgs = msgs[summaryMsgIndex:]
762 msgs[0].Role = message.User
763 }
764 }
765 return msgs, nil
766}
767
768// generateTitle generates a session titled based on the initial prompt.
769func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, userPrompt string) {
770 if userPrompt == "" {
771 return
772 }
773
774 smallModel := a.smallModel.Get()
775 largeModel := a.largeModel.Get()
776 systemPromptPrefix := a.systemPromptPrefix.Get()
777
778 var maxOutputTokens int64 = 40
779 if smallModel.CatwalkCfg.CanReason {
780 maxOutputTokens = smallModel.CatwalkCfg.DefaultMaxTokens
781 }
782
783 newAgent := func(m fantasy.LanguageModel, p []byte, tok int64) fantasy.Agent {
784 return fantasy.NewAgent(m,
785 fantasy.WithSystemPrompt(string(p)+"\n /no_think"),
786 fantasy.WithMaxOutputTokens(tok),
787 )
788 }
789
790 streamCall := fantasy.AgentStreamCall{
791 Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", userPrompt),
792 PrepareStep: func(callCtx context.Context, opts fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
793 prepared.Messages = opts.Messages
794 if systemPromptPrefix != "" {
795 prepared.Messages = append([]fantasy.Message{
796 fantasy.NewSystemMessage(systemPromptPrefix),
797 }, prepared.Messages...)
798 }
799 return callCtx, prepared, nil
800 },
801 }
802
803 // Use the small model to generate the title.
804 model := smallModel
805 agent := newAgent(model.Model, titlePrompt, maxOutputTokens)
806 resp, err := agent.Stream(ctx, streamCall)
807 if err == nil {
808 // We successfully generated a title with the small model.
809 slog.Debug("Generated title with small model")
810 } else {
811 // It didn't work. Let's try with the big model.
812 slog.Error("Error generating title with small model; trying big model", "err", err)
813 model = largeModel
814 agent = newAgent(model.Model, titlePrompt, maxOutputTokens)
815 resp, err = agent.Stream(ctx, streamCall)
816 if err == nil {
817 slog.Debug("Generated title with large model")
818 } else {
819 // Welp, the large model didn't work either. Use the default
820 // session name and return.
821 slog.Error("Error generating title with large model", "err", err)
822 saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, defaultSessionName, 0, 0, 0)
823 if saveErr != nil {
824 slog.Error("Failed to save session title and usage", "error", saveErr)
825 }
826 return
827 }
828 }
829
830 if resp == nil {
831 // Actually, we didn't get a response so we can't. Use the default
832 // session name and return.
833 slog.Error("Response is nil; can't generate title")
834 saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, defaultSessionName, 0, 0, 0)
835 if saveErr != nil {
836 slog.Error("Failed to save session title and usage", "error", saveErr)
837 }
838 return
839 }
840
841 // Clean up title.
842 var title string
843 title = strings.ReplaceAll(resp.Response.Content.Text(), "\n", " ")
844
845 // Remove thinking tags if present.
846 title = thinkTagRegex.ReplaceAllString(title, "")
847
848 title = strings.TrimSpace(title)
849 if title == "" {
850 slog.Debug("Empty title; using fallback")
851 title = defaultSessionName
852 }
853
854 // Calculate usage and cost.
855 var openrouterCost *float64
856 for _, step := range resp.Steps {
857 stepCost := a.openrouterCost(step.ProviderMetadata)
858 if stepCost != nil {
859 newCost := *stepCost
860 if openrouterCost != nil {
861 newCost += *openrouterCost
862 }
863 openrouterCost = &newCost
864 }
865 }
866
867 modelConfig := model.CatwalkCfg
868 cost := modelConfig.CostPer1MInCached/1e6*float64(resp.TotalUsage.CacheCreationTokens) +
869 modelConfig.CostPer1MOutCached/1e6*float64(resp.TotalUsage.CacheReadTokens) +
870 modelConfig.CostPer1MIn/1e6*float64(resp.TotalUsage.InputTokens) +
871 modelConfig.CostPer1MOut/1e6*float64(resp.TotalUsage.OutputTokens)
872
873 // Use override cost if available (e.g., from OpenRouter).
874 if openrouterCost != nil {
875 cost = *openrouterCost
876 }
877
878 promptTokens := resp.TotalUsage.InputTokens + resp.TotalUsage.CacheCreationTokens
879 completionTokens := resp.TotalUsage.OutputTokens
880
881 // Atomically update only title and usage fields to avoid overriding other
882 // concurrent session updates.
883 saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, title, promptTokens, completionTokens, cost)
884 if saveErr != nil {
885 slog.Error("Failed to save session title and usage", "error", saveErr)
886 return
887 }
888}
889
890func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
891 openrouterMetadata, ok := metadata[openrouter.Name]
892 if !ok {
893 return nil
894 }
895
896 opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
897 if !ok {
898 return nil
899 }
900 return &opts.Usage.Cost
901}
902
903func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
904 modelConfig := model.CatwalkCfg
905 cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
906 modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
907 modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
908 modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
909
910 a.eventTokensUsed(session.ID, model, usage, cost)
911
912 if overrideCost != nil {
913 session.Cost += *overrideCost
914 } else {
915 session.Cost += cost
916 }
917
918 session.CompletionTokens = usage.OutputTokens
919 session.PromptTokens = usage.InputTokens + usage.CacheReadTokens
920}
921
922func (a *sessionAgent) Cancel(sessionID string) {
923 // Cancel regular requests. Don't use Take() here - we need the entry to
924 // remain in activeRequests so IsBusy() returns true until the goroutine
925 // fully completes (including error handling that may access the DB).
926 // The defer in processRequest will clean up the entry.
927 if cancel, ok := a.activeRequests.Get(sessionID); ok && cancel != nil {
928 slog.Debug("Request cancellation initiated", "session_id", sessionID)
929 cancel()
930 }
931
932 // Also check for summarize requests.
933 if cancel, ok := a.activeRequests.Get(sessionID + "-summarize"); ok && cancel != nil {
934 slog.Debug("Summarize cancellation initiated", "session_id", sessionID)
935 cancel()
936 }
937
938 if a.QueuedPrompts(sessionID) > 0 {
939 slog.Debug("Clearing queued prompts", "session_id", sessionID)
940 a.messageQueue.Del(sessionID)
941 }
942}
943
944func (a *sessionAgent) ClearQueue(sessionID string) {
945 if a.QueuedPrompts(sessionID) > 0 {
946 slog.Debug("Clearing queued prompts", "session_id", sessionID)
947 a.messageQueue.Del(sessionID)
948 }
949}
950
951func (a *sessionAgent) CancelAll() {
952 if !a.IsBusy() {
953 return
954 }
955 for key := range a.activeRequests.Seq2() {
956 a.Cancel(key) // key is sessionID
957 }
958
959 timeout := time.After(5 * time.Second)
960 for a.IsBusy() {
961 select {
962 case <-timeout:
963 return
964 default:
965 time.Sleep(200 * time.Millisecond)
966 }
967 }
968}
969
970func (a *sessionAgent) IsBusy() bool {
971 var busy bool
972 for cancelFunc := range a.activeRequests.Seq() {
973 if cancelFunc != nil {
974 busy = true
975 break
976 }
977 }
978 return busy
979}
980
981func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
982 _, busy := a.activeRequests.Get(sessionID)
983 return busy
984}
985
986func (a *sessionAgent) QueuedPrompts(sessionID string) int {
987 l, ok := a.messageQueue.Get(sessionID)
988 if !ok {
989 return 0
990 }
991 return len(l)
992}
993
994func (a *sessionAgent) QueuedPromptsList(sessionID string) []string {
995 l, ok := a.messageQueue.Get(sessionID)
996 if !ok {
997 return nil
998 }
999 prompts := make([]string, len(l))
1000 for i, call := range l {
1001 prompts[i] = call.Prompt
1002 }
1003 return prompts
1004}
1005
// SetModels replaces the agent's large and small models with the given
// values. The large model is stored first, then the small one.
func (a *sessionAgent) SetModels(large Model, small Model) {
	a.largeModel.Set(large)
	a.smallModel.Set(small)
}
1010
// SetTools replaces the agent's tool set with tools.
func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
	a.tools.SetSlice(tools)
}
1014
// SetSystemPrompt replaces the agent's system prompt with systemPrompt.
func (a *sessionAgent) SetSystemPrompt(systemPrompt string) {
	a.systemPrompt.Set(systemPrompt)
}
1018
// Model returns the agent's current large model.
func (a *sessionAgent) Model() Model {
	return a.largeModel.Get()
}
1022
1023// convertToToolResult converts a fantasy tool result to a message tool result.
1024func (a *sessionAgent) convertToToolResult(result fantasy.ToolResultContent) message.ToolResult {
1025 baseResult := message.ToolResult{
1026 ToolCallID: result.ToolCallID,
1027 Name: result.ToolName,
1028 Metadata: result.ClientMetadata,
1029 }
1030
1031 switch result.Result.GetType() {
1032 case fantasy.ToolResultContentTypeText:
1033 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result); ok {
1034 baseResult.Content = r.Text
1035 }
1036 case fantasy.ToolResultContentTypeError:
1037 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result); ok {
1038 baseResult.Content = r.Error.Error()
1039 baseResult.IsError = true
1040 }
1041 case fantasy.ToolResultContentTypeMedia:
1042 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](result.Result); ok {
1043 content := r.Text
1044 if content == "" {
1045 content = fmt.Sprintf("Loaded %s content", r.MediaType)
1046 }
1047 baseResult.Content = content
1048 baseResult.Data = r.Data
1049 baseResult.MIMEType = r.MediaType
1050 }
1051 }
1052
1053 return baseResult
1054}
1055
1056// workaroundProviderMediaLimitations converts media content in tool results to
1057// user messages for providers that don't natively support images in tool results.
1058//
1059// Problem: OpenAI, Google, OpenRouter, and other OpenAI-compatible providers
1060// don't support sending images/media in tool result messages - they only accept
1061// text in tool results. However, they DO support images in user messages.
1062//
1063// If we send media in tool results to these providers, the API returns an error.
1064//
1065// Solution: For these providers, we:
1066// 1. Replace the media in the tool result with a text placeholder
1067// 2. Inject a user message immediately after with the image as a file attachment
1068// 3. This maintains the tool execution flow while working around API limitations
1069//
1070// Anthropic and Bedrock support images natively in tool results, so we skip
1071// this workaround for them.
1072//
1073// Example transformation:
1074//
1075// BEFORE: [tool result: image data]
1076// AFTER: [tool result: "Image loaded - see attached"], [user: image attachment]
1077func (a *sessionAgent) workaroundProviderMediaLimitations(messages []fantasy.Message, largeModel Model) []fantasy.Message {
1078 providerSupportsMedia := largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderAnthropic) ||
1079 largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderBedrock)
1080
1081 if providerSupportsMedia {
1082 return messages
1083 }
1084
1085 convertedMessages := make([]fantasy.Message, 0, len(messages))
1086
1087 for _, msg := range messages {
1088 if msg.Role != fantasy.MessageRoleTool {
1089 convertedMessages = append(convertedMessages, msg)
1090 continue
1091 }
1092
1093 textParts := make([]fantasy.MessagePart, 0, len(msg.Content))
1094 var mediaFiles []fantasy.FilePart
1095
1096 for _, part := range msg.Content {
1097 toolResult, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
1098 if !ok {
1099 textParts = append(textParts, part)
1100 continue
1101 }
1102
1103 if media, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](toolResult.Output); ok {
1104 decoded, err := base64.StdEncoding.DecodeString(media.Data)
1105 if err != nil {
1106 slog.Warn("Failed to decode media data", "error", err)
1107 textParts = append(textParts, part)
1108 continue
1109 }
1110
1111 mediaFiles = append(mediaFiles, fantasy.FilePart{
1112 Data: decoded,
1113 MediaType: media.MediaType,
1114 Filename: fmt.Sprintf("tool-result-%s", toolResult.ToolCallID),
1115 })
1116
1117 textParts = append(textParts, fantasy.ToolResultPart{
1118 ToolCallID: toolResult.ToolCallID,
1119 Output: fantasy.ToolResultOutputContentText{
1120 Text: "[Image/media content loaded - see attached file]",
1121 },
1122 ProviderOptions: toolResult.ProviderOptions,
1123 })
1124 } else {
1125 textParts = append(textParts, part)
1126 }
1127 }
1128
1129 convertedMessages = append(convertedMessages, fantasy.Message{
1130 Role: fantasy.MessageRoleTool,
1131 Content: textParts,
1132 })
1133
1134 if len(mediaFiles) > 0 {
1135 convertedMessages = append(convertedMessages, fantasy.NewUserMessage(
1136 "Here is the media content from the tool result:",
1137 mediaFiles...,
1138 ))
1139 }
1140 }
1141
1142 return convertedMessages
1143}
1144
1145// buildSummaryPrompt constructs the prompt text for session summarization.
1146func buildSummaryPrompt(todos []session.Todo) string {
1147 var sb strings.Builder
1148 sb.WriteString("Provide a detailed summary of our conversation above.")
1149 if len(todos) > 0 {
1150 sb.WriteString("\n\n## Current Todo List\n\n")
1151 for _, t := range todos {
1152 fmt.Fprintf(&sb, "- [%s] %s\n", t.Status, t.Content)
1153 }
1154 sb.WriteString("\nInclude these tasks and their statuses in your summary. ")
1155 sb.WriteString("Instruct the resuming assistant to use the `todos` tool to continue tracking progress on these tasks.")
1156 }
1157 return sb.String()
1158}