1// Package agent is the core orchestration layer for Crush AI agents.
2//
3// It provides session-based AI agent functionality for managing
4// conversations, tool execution, and message handling. It coordinates
5// interactions between language models, messages, sessions, and tools while
6// handling features like automatic summarization, queuing, and token
7// management.
8package agent
9
10import (
11 "cmp"
12 "context"
13 _ "embed"
14 "encoding/base64"
15 "errors"
16 "fmt"
17 "log/slog"
18 "os"
19 "regexp"
20 "strconv"
21 "strings"
22 "sync"
23 "time"
24
25 "charm.land/catwalk/pkg/catwalk"
26 "charm.land/fantasy"
27 "charm.land/fantasy/providers/anthropic"
28 "charm.land/fantasy/providers/bedrock"
29 "charm.land/fantasy/providers/google"
30 "charm.land/fantasy/providers/openai"
31 "charm.land/fantasy/providers/openrouter"
32 "charm.land/fantasy/providers/vercel"
33 "charm.land/lipgloss/v2"
34 "github.com/charmbracelet/crush/internal/agent/hyper"
35 "github.com/charmbracelet/crush/internal/agent/tools"
36 "github.com/charmbracelet/crush/internal/agent/tools/mcp"
37 "github.com/charmbracelet/crush/internal/config"
38 "github.com/charmbracelet/crush/internal/csync"
39 "github.com/charmbracelet/crush/internal/message"
40 "github.com/charmbracelet/crush/internal/permission"
41 "github.com/charmbracelet/crush/internal/session"
42 "github.com/charmbracelet/crush/internal/stringext"
43 "github.com/charmbracelet/x/exp/charmtone"
44)
45
const (
	// defaultSessionName is the title stored for a session when title
	// generation fails or yields an empty string.
	defaultSessionName = "Untitled Session"

	// Auto-summarization thresholds, used by the stop condition in Run.

	// largeContextWindowThreshold is the context-window size above which the
	// fixed largeContextWindowBuffer is used instead of a ratio.
	largeContextWindowThreshold = 200_000
	// largeContextWindowBuffer is the remaining-token count at or below which
	// a model with a large context window triggers summarization.
	largeContextWindowBuffer = 20_000
	// smallContextWindowRatio is the fraction of the context window used as
	// the remaining-token threshold for models at or below
	// largeContextWindowThreshold.
	smallContextWindowRatio = 0.2
)
54
// titlePrompt is the embedded system prompt used when generating session
// titles.
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt is the embedded system prompt used when summarizing a
// session.
//
//go:embed templates/summary.md
var summaryPrompt []byte
60
// thinkTagRegex matches a <think>...</think> span so it can be removed from
// generated titles. The pattern carries no (?s) flag, so "." does not match
// newlines; generateTitle replaces newlines with spaces before applying it.
var thinkTagRegex = regexp.MustCompile(`<think>.*?</think>`)
63
// SessionAgentCall describes one prompt submitted to a SessionAgent, together
// with the generation parameters that Run forwards to the model.
type SessionAgentCall struct {
	// SessionID identifies the target session. Run rejects an empty value
	// with ErrSessionMissing.
	SessionID string
	// Prompt is the user's text. Run rejects an empty prompt with
	// ErrEmptyPrompt unless Attachments contains a text attachment.
	Prompt string
	// ProviderOptions is passed through to the stream call (and to Summarize
	// when Run triggers auto-summarization).
	ProviderOptions fantasy.ProviderOptions
	// Attachments are stored on the user message; non-text attachments are
	// also sent to the model as files.
	Attachments []message.Attachment
	// MaxOutputTokens is passed to the stream call by address.
	MaxOutputTokens int64
	// ShowToolCalls makes Run log each tool call and a truncated form of each
	// tool result through slog.
	ShowToolCalls bool
	// The sampling parameters below are forwarded to the stream call as-is.
	// NOTE(review): nil presumably means "use the provider default" — confirm
	// against fantasy.AgentStreamCall.
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
}
77
// SessionAgent is the interface of a session-scoped agent: it runs prompts,
// queues prompts submitted while a session is busy, and supports cancellation
// and summarization.
type SessionAgent interface {
	// Run executes the call against its session. When the session is already
	// busy the call is queued and Run returns (nil, nil).
	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
	// SetModels presumably replaces the large and small models used by later
	// calls (implementation not shown in this part of the file).
	SetModels(large Model, small Model)
	// SetTools presumably replaces the tool set used by later calls.
	SetTools(tools []fantasy.AgentTool)
	// SetSystemPrompt presumably replaces the system prompt used by later
	// calls.
	SetSystemPrompt(systemPrompt string)
	// Cancel cancels the in-flight request for the session, if any, and drops
	// the session's queued prompts.
	Cancel(sessionID string)
	// CancelAll cancels the in-flight requests of every session.
	CancelAll()
	// IsSessionBusy reports whether the given session has work in flight.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any session has work in flight.
	IsBusy() bool
	// QueuedPrompts returns the number of prompts queued for the session.
	QueuedPrompts(sessionID string) int
	// QueuedPromptsList presumably returns the text of the prompts queued for
	// the session.
	QueuedPromptsList(sessionID string) []string
	// ClearQueue drops the session's queued prompts without cancelling the
	// in-flight request.
	ClearQueue(sessionID string)
	// Summarize generates a summary message for the session identified by the
	// string argument and records it as the session's summary. It returns
	// ErrSessionBusy when the session is busy.
	Summarize(context.Context, string, fantasy.ProviderOptions) error
	// Model presumably returns the model currently in use — TODO confirm
	// whether this is the large model.
	Model() Model
}
93
// Model bundles a language model with the catalog metadata and the selected
// configuration that the agent consults alongside it.
type Model struct {
	// Model is the language model handed to fantasy.NewAgent.
	Model fantasy.LanguageModel
	// CatwalkCfg is the catalog entry; the agent reads its context window,
	// pricing, display name, and capability flags.
	CatwalkCfg catwalk.Model
	// ModelCfg is the selected-model configuration; its Model and Provider
	// values are recorded on assistant messages created by Run.
	ModelCfg config.SelectedModel
}
99
// sessionAgent is the SessionAgent implementation returned by
// NewSessionAgent.
type sessionAgent struct {
	// Settings held in csync wrappers; Run and Summarize snapshot them with
	// Get/Copy before use.
	largeModel         *csync.Value[Model]
	smallModel         *csync.Value[Model]
	systemPromptPrefix *csync.Value[string]
	systemPrompt       *csync.Value[string]
	tools              *csync.Slice[fantasy.AgentTool]

	// isSubAgent suppresses the todo-list reminder that preparePrompt
	// otherwise prepends to the history.
	isSubAgent bool
	sessions   session.Service
	messages   message.Service
	// disableAutoSummarize turns off the context-window stop condition in
	// Run.
	disableAutoSummarize bool
	isYolo               bool

	// messageQueue holds, per session ID, the calls received while that
	// session was busy.
	messageQueue *csync.Map[string, []SessionAgentCall]
	// activeRequests maps a session ID to the cancel function of its
	// in-flight request.
	activeRequests *csync.Map[string, context.CancelFunc]
}
116
// SessionAgentOptions carries the initial configuration and the service
// dependencies that NewSessionAgent copies into a new agent.
type SessionAgentOptions struct {
	// LargeModel is the model used for the main conversation and for
	// summaries.
	LargeModel Model
	// SmallModel is tried first when generating session titles.
	SmallModel Model
	// SystemPromptPrefix, when non-empty, is prepended as an extra system
	// message on each step.
	SystemPromptPrefix string
	// SystemPrompt is the system prompt for the main conversation.
	SystemPrompt string
	// IsSubAgent suppresses the todo-list reminder in the prepared history.
	IsSubAgent bool
	// DisableAutoSummarize turns off automatic summarization when the context
	// window is nearly full.
	DisableAutoSummarize bool
	IsYolo               bool
	// Sessions and Messages are the persistence services used by the agent.
	Sessions session.Service
	Messages message.Service
	// Tools is the initial tool set.
	Tools []fantasy.AgentTool
}
129
130func NewSessionAgent(
131 opts SessionAgentOptions,
132) SessionAgent {
133 return &sessionAgent{
134 largeModel: csync.NewValue(opts.LargeModel),
135 smallModel: csync.NewValue(opts.SmallModel),
136 systemPromptPrefix: csync.NewValue(opts.SystemPromptPrefix),
137 systemPrompt: csync.NewValue(opts.SystemPrompt),
138 isSubAgent: opts.IsSubAgent,
139 sessions: opts.Sessions,
140 messages: opts.Messages,
141 disableAutoSummarize: opts.DisableAutoSummarize,
142 tools: csync.NewSliceFrom(opts.Tools),
143 isYolo: opts.IsYolo,
144 messageQueue: csync.NewMap[string, []SessionAgentCall](),
145 activeRequests: csync.NewMap[string, context.CancelFunc](),
146 }
147}
148
// Run submits call to the agent for its session and streams the model's
// response, persisting the assistant message and tool results as they arrive.
//
// It returns ErrEmptyPrompt when there is neither a prompt nor a text
// attachment, and ErrSessionMissing when no session ID is given. If the
// session is already busy, the call is appended to that session's queue and
// Run returns (nil, nil) immediately; queued calls are either folded into the
// next step of the running request or executed by a recursive Run once the
// current request finishes.
//
// When the stream fails, unfinished tool calls are closed out, a finish reason
// describing the failure is stored on the assistant message, and the original
// error is returned.
func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
	if call.Prompt == "" && !message.ContainsTextAttachment(call.Attachments) {
		return nil, ErrEmptyPrompt
	}
	if call.SessionID == "" {
		return nil, ErrSessionMissing
	}

	// Queue the message if busy
	if a.IsSessionBusy(call.SessionID) {
		existing, ok := a.messageQueue.Get(call.SessionID)
		if !ok {
			existing = []SessionAgentCall{}
		}
		existing = append(existing, call)
		a.messageQueue.Set(call.SessionID, existing)
		return nil, nil
	}

	// Copy mutable fields under lock to avoid races with SetTools/SetModels.
	agentTools := a.tools.Copy()
	largeModel := a.largeModel.Get()
	systemPrompt := a.systemPrompt.Get()
	promptPrefix := a.systemPromptPrefix.Get()
	var instructions strings.Builder

	// Collect the instructions advertised by every connected MCP server.
	for _, server := range mcp.GetStates() {
		if server.State != mcp.StateConnected {
			continue
		}
		if s := server.Client.InitializeResult().Instructions; s != "" {
			instructions.WriteString(s)
			instructions.WriteString("\n\n")
		}
	}

	// Append the collected MCP instructions to the system prompt.
	if s := instructions.String(); s != "" {
		systemPrompt += "\n\n<mcp-instructions>\n" + s + "\n</mcp-instructions>"
	}

	if len(agentTools) > 0 {
		// Add Anthropic caching to the last tool.
		agentTools[len(agentTools)-1].SetProviderOptions(a.getCacheControlOptions())
	}

	agent := fantasy.NewAgent(
		largeModel.Model,
		fantasy.WithSystemPrompt(systemPrompt),
		fantasy.WithTools(agentTools...),
	)

	// sessionLock serializes the session read-modify-write in OnStepFinish.
	sessionLock := sync.Mutex{}
	currentSession, err := a.sessions.Get(ctx, call.SessionID)
	if err != nil {
		return nil, fmt.Errorf("failed to get session: %w", err)
	}

	msgs, err := a.getSessionMessages(ctx, currentSession)
	if err != nil {
		return nil, fmt.Errorf("failed to get session messages: %w", err)
	}

	var wg sync.WaitGroup
	// Generate title if first message.
	if len(msgs) == 0 {
		titleCtx := ctx // Copy to avoid race with ctx reassignment below.
		wg.Go(func() {
			a.generateTitle(titleCtx, call.SessionID, call.Prompt)
		})
	}
	// Run does not return until title generation has finished.
	defer wg.Wait()

	// Add the user message to the session.
	_, err = a.createUserMessage(ctx, call)
	if err != nil {
		return nil, err
	}

	// Add the session to the context.
	ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)

	// Registering the cancel function is what marks the request as active
	// for Cancel; the deferred calls below unregister it and release genCtx
	// on every return path.
	genCtx, cancel := context.WithCancel(ctx)
	a.activeRequests.Set(call.SessionID, cancel)

	defer cancel()
	defer a.activeRequests.Del(call.SessionID)

	history, files := a.preparePrompt(msgs, call.Attachments...)

	startTime := time.Now()
	a.eventPromptSent(call.SessionID)

	// currentAssistant points at the assistant message of the step in
	// progress; PrepareStep assigns it before the streaming callbacks use it.
	var currentAssistant *message.Message
	// shouldSummarize is set by the context-window stop condition below.
	var shouldSummarize bool
	result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
		Prompt:           message.PromptWithTextAttachments(call.Prompt, call.Attachments),
		Files:            files,
		Messages:         history,
		ProviderOptions:  call.ProviderOptions,
		MaxOutputTokens:  &call.MaxOutputTokens,
		TopP:             call.TopP,
		Temperature:      call.Temperature,
		PresencePenalty:  call.PresencePenalty,
		TopK:             call.TopK,
		FrequencyPenalty: call.FrequencyPenalty,
		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
			// Start from the step's messages with all provider options
			// cleared; cache-control options are re-applied below.
			prepared.Messages = options.Messages
			for i := range prepared.Messages {
				prepared.Messages[i].ProviderOptions = nil
			}

			// Drain prompts queued while this request was running and append
			// them to this step as user messages.
			queuedCalls, _ := a.messageQueue.Get(call.SessionID)
			a.messageQueue.Del(call.SessionID)
			for _, queued := range queuedCalls {
				userMessage, createErr := a.createUserMessage(callContext, queued)
				if createErr != nil {
					return callContext, prepared, createErr
				}
				prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
			}

			prepared.Messages = a.workaroundProviderMediaLimitations(prepared.Messages, largeModel)

			lastSystemRoleInx := 0
			systemMessageUpdated := false
			for i, msg := range prepared.Messages {
				// Track the most recent system message; at the first
				// non-system message, add cache control to the last system
				// message seen so far (index 0 if none was seen).
				if msg.Role == fantasy.MessageRoleSystem {
					lastSystemRoleInx = i
				} else if !systemMessageUpdated {
					prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
					systemMessageUpdated = true
				}
				// Then add cache control to the last 2 messages.
				if i > len(prepared.Messages)-3 {
					prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
				}
			}

			// The prompt prefix, when set, goes in front as its own system
			// message.
			if promptPrefix != "" {
				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(promptPrefix)}, prepared.Messages...)
			}

			// Create the (initially empty) assistant message that the
			// streaming callbacks fill in for this step.
			var assistantMsg message.Message
			assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
				Role:     message.Assistant,
				Parts:    []message.ContentPart{},
				Model:    largeModel.ModelCfg.Model,
				Provider: largeModel.ModelCfg.Provider,
			})
			if err != nil {
				return callContext, prepared, err
			}
			// Expose the message ID and model details through the step
			// context.
			callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
			callContext = context.WithValue(callContext, tools.SupportsImagesContextKey, largeModel.CatwalkCfg.SupportsImages)
			callContext = context.WithValue(callContext, tools.ModelNameContextKey, largeModel.CatwalkCfg.Name)
			currentAssistant = &assistantMsg
			return callContext, prepared, err
		},
		OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
			currentAssistant.AppendReasoningContent(reasoning.Text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnReasoningDelta: func(id string, text string) error {
			currentAssistant.AppendReasoningContent(text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
			// Copy provider-specific reasoning metadata (Anthropic, Google,
			// OpenAI) onto the assistant message.
			// handle anthropic signature
			if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
				if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
					currentAssistant.AppendReasoningSignature(reasoning.Signature)
				}
			}
			if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
				if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
					currentAssistant.AppendThoughtSignature(reasoning.Signature, reasoning.ToolID)
				}
			}
			if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
				if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
					currentAssistant.SetReasoningResponsesData(reasoning)
				}
			}
			currentAssistant.FinishThinking()
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnTextDelta: func(id string, text string) error {
			// Strip leading newline from initial text content. This is
			// particularly important in non-interactive mode where leading
			// newlines are very visible.
			if len(currentAssistant.Parts) == 0 {
				text = strings.TrimPrefix(text, "\n")
			}

			currentAssistant.AppendContent(text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnToolInputStart: func(id string, toolName string) error {
			// Record the tool call as unfinished as soon as its input starts
			// streaming.
			toolCall := message.ToolCall{
				ID:               id,
				Name:             toolName,
				ProviderExecuted: false,
				Finished:         false,
			}
			currentAssistant.AddToolCall(toolCall)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnRetry: func(err *fantasy.ProviderError, delay time.Duration) {
			// TODO: implement
		},
		OnToolCall: func(tc fantasy.ToolCallContent) error {
			if call.ShowToolCalls {
				slog.Default().WithGroup("TOOL").Info("call", "name", tc.ToolName, "input", tc.Input)
			}
			// The tool call now has its full input, so it is stored as
			// finished.
			toolCall := message.ToolCall{
				ID:               tc.ToolCallID,
				Name:             tc.ToolName,
				Input:            tc.Input,
				ProviderExecuted: false,
				Finished:         true,
			}
			currentAssistant.AddToolCall(toolCall)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnToolResult: func(result fantasy.ToolResultContent) error {
			if call.ShowToolCalls {
				// Build a short textual form of the result for the log.
				content := ""
				switch result.Result.GetType() {
				case fantasy.ToolResultContentTypeText:
					if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result); ok {
						content = r.Text
					}
				case fantasy.ToolResultContentTypeError:
					if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result); ok {
						content = "Error: " + r.Error.Error()
					}
				case fantasy.ToolResultContentTypeMedia:
					if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](result.Result); ok {
						content = r.Text
						if content == "" {
							content = fmt.Sprintf("[%s content]", r.MediaType)
						}
					}
				}
				// Truncation is by bytes, so it may cut a multi-byte
				// character.
				if len(content) > 200 {
					content = content[:200] + "..."
				}
				slog.Default().WithGroup("TOOL").Info("result", "name", result.ToolName, "output", content)
			}
			// Each tool result is stored as its own tool-role message.
			toolResult := a.convertToToolResult(result)
			_, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
				Role: message.Tool,
				Parts: []message.ContentPart{
					toolResult,
				},
			})
			return createMsgErr
		},
		OnStepFinish: func(stepResult fantasy.StepResult) error {
			// Map the provider finish reason onto the message finish reason;
			// anything unrecognized stays FinishReasonUnknown.
			finishReason := message.FinishReasonUnknown
			switch stepResult.FinishReason {
			case fantasy.FinishReasonLength:
				finishReason = message.FinishReasonMaxTokens
			case fantasy.FinishReasonStop:
				finishReason = message.FinishReasonEndTurn
			case fantasy.FinishReasonToolCalls:
				finishReason = message.FinishReasonToolUse
			}
			currentAssistant.AddFinish(finishReason, "", "")
			sessionLock.Lock()
			defer sessionLock.Unlock()

			// Re-read the session before applying this step's usage, then
			// save it. These calls use ctx rather than genCtx.
			updatedSession, getSessionErr := a.sessions.Get(ctx, call.SessionID)
			if getSessionErr != nil {
				return getSessionErr
			}
			a.updateSessionUsage(largeModel, &updatedSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
			_, sessionErr := a.sessions.Save(ctx, updatedSession)
			if sessionErr != nil {
				return sessionErr
			}
			// The stop condition below reads token counts from
			// currentSession.
			currentSession = updatedSession
			return a.messages.Update(genCtx, *currentAssistant)
		},
		StopWhen: []fantasy.StopCondition{
			// Stop and request summarization when the remaining context
			// window drops to the threshold: a fixed buffer for large
			// windows, a ratio of the window otherwise.
			func(_ []fantasy.StepResult) bool {
				cw := int64(largeModel.CatwalkCfg.ContextWindow)
				tokens := currentSession.CompletionTokens + currentSession.PromptTokens
				remaining := cw - tokens
				var threshold int64
				if cw > largeContextWindowThreshold {
					threshold = largeContextWindowBuffer
				} else {
					threshold = int64(float64(cw) * smallContextWindowRatio)
				}
				if (remaining <= threshold) && !a.disableAutoSummarize {
					shouldSummarize = true
					return true
				}
				return false
			},
			// Stop when the recent steps repeat the same tool calls.
			func(steps []fantasy.StepResult) bool {
				return hasRepeatedToolCalls(steps, loopDetectionWindowSize, loopDetectionMaxRepeats)
			},
		},
	})

	a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))

	if err != nil {
		isCancelErr := errors.Is(err, context.Canceled)
		isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
		// No assistant message was created yet, so there is nothing to
		// annotate.
		if currentAssistant == nil {
			return result, err
		}
		// Ensure we finish thinking on error to close the reasoning state.
		currentAssistant.FinishThinking()
		toolCalls := currentAssistant.ToolCalls()
		// INFO: we use the parent context here because the genCtx has been cancelled.
		msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
		if createErr != nil {
			return nil, createErr
		}
		// Close out every tool call: mark unfinished ones as finished with
		// empty input, and store an error result for any call that has no
		// stored result yet.
		for _, tc := range toolCalls {
			if !tc.Finished {
				tc.Finished = true
				tc.Input = "{}"
				currentAssistant.AddToolCall(tc)
				updateErr := a.messages.Update(ctx, *currentAssistant)
				if updateErr != nil {
					return nil, updateErr
				}
			}

			// Look for an existing tool result for this call.
			found := false
			for _, msg := range msgs {
				if msg.Role == message.Tool {
					for _, tr := range msg.ToolResults() {
						if tr.ToolCallID == tc.ID {
							found = true
							break
						}
					}
				}
				if found {
					break
				}
			}
			if found {
				continue
			}
			content := "There was an error while executing the tool"
			if isCancelErr {
				content = "Tool execution canceled by user"
			} else if isPermissionErr {
				content = "User denied permission"
			}
			toolResult := message.ToolResult{
				ToolCallID: tc.ID,
				Name:       tc.Name,
				Content:    content,
				IsError:    true,
			}
			_, createErr = a.messages.Create(ctx, currentAssistant.SessionID, message.CreateMessageParams{
				Role: message.Tool,
				Parts: []message.ContentPart{
					toolResult,
				},
			})
			if createErr != nil {
				return nil, createErr
			}
		}
		// Store a finish reason on the assistant message that matches the
		// kind of failure.
		var fantasyErr *fantasy.Error
		var providerErr *fantasy.ProviderError
		const defaultTitle = "Provider Error"
		linkStyle := lipgloss.NewStyle().Foreground(charmtone.Guac).Underline(true)
		if isCancelErr {
			currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
		} else if isPermissionErr {
			currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "")
		} else if errors.Is(err, hyper.ErrNoCredits) {
			url := hyper.BaseURL()
			link := linkStyle.Hyperlink(url, "id=hyper").Render(url)
			currentAssistant.AddFinish(message.FinishReasonError, "No credits", "You're out of credits. Add more at "+link)
		} else if errors.As(err, &providerErr) {
			// This exact provider message is reported as a Copilot model
			// that has not been enabled.
			if providerErr.Message == "The requested model is not supported." {
				url := "https://github.com/settings/copilot/features"
				link := linkStyle.Hyperlink(url, "id=copilot").Render(url)
				currentAssistant.AddFinish(
					message.FinishReasonError,
					"Copilot model not enabled",
					fmt.Sprintf("%q is not enabled in Copilot. Go to the following page to enable it. Then, wait 5 minutes before trying again. %s", largeModel.CatwalkCfg.Name, link),
				)
			} else {
				currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message)
			}
		} else if errors.As(err, &fantasyErr) {
			currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(fantasyErr.Title), defaultTitle), fantasyErr.Message)
		} else {
			currentAssistant.AddFinish(message.FinishReasonError, defaultTitle, err.Error())
		}
		// Note: we use the parent context here because the genCtx has been
		// cancelled.
		updateErr := a.messages.Update(ctx, *currentAssistant)
		if updateErr != nil {
			return nil, updateErr
		}
		return nil, err
	}

	if shouldSummarize {
		// Unregister first: Summarize returns ErrSessionBusy while the
		// session is busy.
		a.activeRequests.Del(call.SessionID)
		if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
			return nil, summarizeErr
		}
		// If the agent wasn't done...
		// ...re-queue the original request so that it resumes after the
		// summary.
		if len(currentAssistant.ToolCalls()) > 0 {
			existing, ok := a.messageQueue.Get(call.SessionID)
			if !ok {
				existing = []SessionAgentCall{}
			}
			call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
			existing = append(existing, call)
			a.messageQueue.Set(call.SessionID, existing)
		}
	}

	// Release active request before processing queued messages.
	a.activeRequests.Del(call.SessionID)
	cancel()

	queuedMessages, ok := a.messageQueue.Get(call.SessionID)
	if !ok || len(queuedMessages) == 0 {
		return result, err
	}
	// There are queued messages restart the loop.
	firstQueuedMessage := queuedMessages[0]
	a.messageQueue.Set(call.SessionID, queuedMessages[1:])
	return a.Run(ctx, firstQueuedMessage)
}
591
592func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
593 if a.IsSessionBusy(sessionID) {
594 return ErrSessionBusy
595 }
596
597 // Copy mutable fields under lock to avoid races with SetModels.
598 largeModel := a.largeModel.Get()
599 systemPromptPrefix := a.systemPromptPrefix.Get()
600
601 currentSession, err := a.sessions.Get(ctx, sessionID)
602 if err != nil {
603 return fmt.Errorf("failed to get session: %w", err)
604 }
605 msgs, err := a.getSessionMessages(ctx, currentSession)
606 if err != nil {
607 return err
608 }
609 if len(msgs) == 0 {
610 // Nothing to summarize.
611 return nil
612 }
613
614 aiMsgs, _ := a.preparePrompt(msgs)
615
616 genCtx, cancel := context.WithCancel(ctx)
617 a.activeRequests.Set(sessionID, cancel)
618 defer a.activeRequests.Del(sessionID)
619 defer cancel()
620
621 agent := fantasy.NewAgent(largeModel.Model,
622 fantasy.WithSystemPrompt(string(summaryPrompt)),
623 )
624 summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
625 Role: message.Assistant,
626 Model: largeModel.Model.Model(),
627 Provider: largeModel.Model.Provider(),
628 IsSummaryMessage: true,
629 })
630 if err != nil {
631 return err
632 }
633
634 summaryPromptText := buildSummaryPrompt(currentSession.Todos)
635
636 resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
637 Prompt: summaryPromptText,
638 Messages: aiMsgs,
639 ProviderOptions: opts,
640 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
641 prepared.Messages = options.Messages
642 if systemPromptPrefix != "" {
643 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(systemPromptPrefix)}, prepared.Messages...)
644 }
645 return callContext, prepared, nil
646 },
647 OnReasoningDelta: func(id string, text string) error {
648 summaryMessage.AppendReasoningContent(text)
649 return a.messages.Update(genCtx, summaryMessage)
650 },
651 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
652 // Handle anthropic signature.
653 if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
654 if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
655 summaryMessage.AppendReasoningSignature(signature.Signature)
656 }
657 }
658 summaryMessage.FinishThinking()
659 return a.messages.Update(genCtx, summaryMessage)
660 },
661 OnTextDelta: func(id, text string) error {
662 summaryMessage.AppendContent(text)
663 return a.messages.Update(genCtx, summaryMessage)
664 },
665 })
666 if err != nil {
667 isCancelErr := errors.Is(err, context.Canceled)
668 if isCancelErr {
669 // User cancelled summarize we need to remove the summary message.
670 deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
671 return deleteErr
672 }
673 return err
674 }
675
676 summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
677 err = a.messages.Update(genCtx, summaryMessage)
678 if err != nil {
679 return err
680 }
681
682 var openrouterCost *float64
683 for _, step := range resp.Steps {
684 stepCost := a.openrouterCost(step.ProviderMetadata)
685 if stepCost != nil {
686 newCost := *stepCost
687 if openrouterCost != nil {
688 newCost += *openrouterCost
689 }
690 openrouterCost = &newCost
691 }
692 }
693
694 a.updateSessionUsage(largeModel, ¤tSession, resp.TotalUsage, openrouterCost)
695
696 // Just in case, get just the last usage info.
697 usage := resp.Response.Usage
698 currentSession.SummaryMessageID = summaryMessage.ID
699 currentSession.CompletionTokens = usage.OutputTokens
700 currentSession.PromptTokens = 0
701 _, err = a.sessions.Save(genCtx, currentSession)
702 return err
703}
704
705func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
706 if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
707 return fantasy.ProviderOptions{}
708 }
709 return fantasy.ProviderOptions{
710 anthropic.Name: &anthropic.ProviderCacheControlOptions{
711 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
712 },
713 bedrock.Name: &anthropic.ProviderCacheControlOptions{
714 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
715 },
716 vercel.Name: &anthropic.ProviderCacheControlOptions{
717 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
718 },
719 }
720}
721
722func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
723 parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
724 var attachmentParts []message.ContentPart
725 for _, attachment := range call.Attachments {
726 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
727 }
728 parts = append(parts, attachmentParts...)
729 msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
730 Role: message.User,
731 Parts: parts,
732 })
733 if err != nil {
734 return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
735 }
736 return msg, nil
737}
738
739func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
740 var history []fantasy.Message
741 if !a.isSubAgent {
742 history = append(history, fantasy.NewUserMessage(
743 fmt.Sprintf("<system_reminder>%s</system_reminder>",
744 `This is a reminder that your todo list is currently empty. DO NOT mention this to the user explicitly because they are already aware.
745If you are working on tasks that would benefit from a todo list please use the "todos" tool to create one.
746If not, please feel free to ignore. Again do not mention this message to the user.`,
747 ),
748 ))
749 }
750 for _, m := range msgs {
751 if len(m.Parts) == 0 {
752 continue
753 }
754 // Assistant message without content or tool calls (cancelled before it
755 // returned anything).
756 if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
757 continue
758 }
759 history = append(history, m.ToAIMessage()...)
760 }
761
762 var files []fantasy.FilePart
763 for _, attachment := range attachments {
764 if attachment.IsText() {
765 continue
766 }
767 files = append(files, fantasy.FilePart{
768 Filename: attachment.FileName,
769 Data: attachment.Content,
770 MediaType: attachment.MimeType,
771 })
772 }
773
774 return history, files
775}
776
777func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
778 msgs, err := a.messages.List(ctx, session.ID)
779 if err != nil {
780 return nil, fmt.Errorf("failed to list messages: %w", err)
781 }
782
783 if session.SummaryMessageID != "" {
784 summaryMsgIndex := -1
785 for i, msg := range msgs {
786 if msg.ID == session.SummaryMessageID {
787 summaryMsgIndex = i
788 break
789 }
790 }
791 if summaryMsgIndex != -1 {
792 msgs = msgs[summaryMsgIndex:]
793 msgs[0].Role = message.User
794 }
795 }
796 return msgs, nil
797}
798
799// generateTitle generates a session titled based on the initial prompt.
800func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, userPrompt string) {
801 if userPrompt == "" {
802 return
803 }
804
805 smallModel := a.smallModel.Get()
806 largeModel := a.largeModel.Get()
807 systemPromptPrefix := a.systemPromptPrefix.Get()
808
809 var maxOutputTokens int64 = 40
810 if smallModel.CatwalkCfg.CanReason {
811 maxOutputTokens = smallModel.CatwalkCfg.DefaultMaxTokens
812 }
813
814 newAgent := func(m fantasy.LanguageModel, p []byte, tok int64) fantasy.Agent {
815 return fantasy.NewAgent(m,
816 fantasy.WithSystemPrompt(string(p)+"\n /no_think"),
817 fantasy.WithMaxOutputTokens(tok),
818 )
819 }
820
821 streamCall := fantasy.AgentStreamCall{
822 Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", userPrompt),
823 PrepareStep: func(callCtx context.Context, opts fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
824 prepared.Messages = opts.Messages
825 if systemPromptPrefix != "" {
826 prepared.Messages = append([]fantasy.Message{
827 fantasy.NewSystemMessage(systemPromptPrefix),
828 }, prepared.Messages...)
829 }
830 return callCtx, prepared, nil
831 },
832 }
833
834 // Use the small model to generate the title.
835 model := smallModel
836 agent := newAgent(model.Model, titlePrompt, maxOutputTokens)
837 resp, err := agent.Stream(ctx, streamCall)
838 if err == nil {
839 // We successfully generated a title with the small model.
840 slog.Debug("Generated title with small model")
841 } else {
842 // It didn't work. Let's try with the big model.
843 slog.Error("Error generating title with small model; trying big model", "err", err)
844 model = largeModel
845 agent = newAgent(model.Model, titlePrompt, maxOutputTokens)
846 resp, err = agent.Stream(ctx, streamCall)
847 if err == nil {
848 slog.Debug("Generated title with large model")
849 } else {
850 // Welp, the large model didn't work either. Use the default
851 // session name and return.
852 slog.Error("Error generating title with large model", "err", err)
853 saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, defaultSessionName, 0, 0, 0)
854 if saveErr != nil {
855 slog.Error("Failed to save session title and usage", "error", saveErr)
856 }
857 return
858 }
859 }
860
861 if resp == nil {
862 // Actually, we didn't get a response so we can't. Use the default
863 // session name and return.
864 slog.Error("Response is nil; can't generate title")
865 saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, defaultSessionName, 0, 0, 0)
866 if saveErr != nil {
867 slog.Error("Failed to save session title and usage", "error", saveErr)
868 }
869 return
870 }
871
872 // Clean up title.
873 var title string
874 title = strings.ReplaceAll(resp.Response.Content.Text(), "\n", " ")
875
876 // Remove thinking tags if present.
877 title = thinkTagRegex.ReplaceAllString(title, "")
878
879 title = strings.TrimSpace(title)
880 title = cmp.Or(title, defaultSessionName)
881
882 // Calculate usage and cost.
883 var openrouterCost *float64
884 for _, step := range resp.Steps {
885 stepCost := a.openrouterCost(step.ProviderMetadata)
886 if stepCost != nil {
887 newCost := *stepCost
888 if openrouterCost != nil {
889 newCost += *openrouterCost
890 }
891 openrouterCost = &newCost
892 }
893 }
894
895 modelConfig := model.CatwalkCfg
896 cost := modelConfig.CostPer1MInCached/1e6*float64(resp.TotalUsage.CacheCreationTokens) +
897 modelConfig.CostPer1MOutCached/1e6*float64(resp.TotalUsage.CacheReadTokens) +
898 modelConfig.CostPer1MIn/1e6*float64(resp.TotalUsage.InputTokens) +
899 modelConfig.CostPer1MOut/1e6*float64(resp.TotalUsage.OutputTokens)
900
901 // Use override cost if available (e.g., from OpenRouter).
902 if openrouterCost != nil {
903 cost = *openrouterCost
904 }
905
906 promptTokens := resp.TotalUsage.InputTokens + resp.TotalUsage.CacheCreationTokens
907 completionTokens := resp.TotalUsage.OutputTokens
908
909 // Atomically update only title and usage fields to avoid overriding other
910 // concurrent session updates.
911 saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, title, promptTokens, completionTokens, cost)
912 if saveErr != nil {
913 slog.Error("Failed to save session title and usage", "error", saveErr)
914 return
915 }
916}
917
918func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
919 openrouterMetadata, ok := metadata[openrouter.Name]
920 if !ok {
921 return nil
922 }
923
924 opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
925 if !ok {
926 return nil
927 }
928 return &opts.Usage.Cost
929}
930
931func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
932 modelConfig := model.CatwalkCfg
933 cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
934 modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
935 modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
936 modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
937
938 a.eventTokensUsed(session.ID, model, usage, cost)
939
940 if overrideCost != nil {
941 session.Cost += *overrideCost
942 } else {
943 session.Cost += cost
944 }
945
946 session.CompletionTokens = usage.OutputTokens
947 session.PromptTokens = usage.InputTokens + usage.CacheReadTokens
948}
949
950func (a *sessionAgent) Cancel(sessionID string) {
951 // Cancel regular requests. Don't use Take() here - we need the entry to
952 // remain in activeRequests so IsBusy() returns true until the goroutine
953 // fully completes (including error handling that may access the DB).
954 // The defer in processRequest will clean up the entry.
955 if cancel, ok := a.activeRequests.Get(sessionID); ok && cancel != nil {
956 slog.Debug("Request cancellation initiated", "session_id", sessionID)
957 cancel()
958 }
959
960 // Also check for summarize requests.
961 if cancel, ok := a.activeRequests.Get(sessionID + "-summarize"); ok && cancel != nil {
962 slog.Debug("Summarize cancellation initiated", "session_id", sessionID)
963 cancel()
964 }
965
966 if a.QueuedPrompts(sessionID) > 0 {
967 slog.Debug("Clearing queued prompts", "session_id", sessionID)
968 a.messageQueue.Del(sessionID)
969 }
970}
971
972func (a *sessionAgent) ClearQueue(sessionID string) {
973 if a.QueuedPrompts(sessionID) > 0 {
974 slog.Debug("Clearing queued prompts", "session_id", sessionID)
975 a.messageQueue.Del(sessionID)
976 }
977}
978
979func (a *sessionAgent) CancelAll() {
980 if !a.IsBusy() {
981 return
982 }
983 for key := range a.activeRequests.Seq2() {
984 a.Cancel(key) // key is sessionID
985 }
986
987 timeout := time.After(5 * time.Second)
988 for a.IsBusy() {
989 select {
990 case <-timeout:
991 return
992 default:
993 time.Sleep(200 * time.Millisecond)
994 }
995 }
996}
997
998func (a *sessionAgent) IsBusy() bool {
999 var busy bool
1000 for cancelFunc := range a.activeRequests.Seq() {
1001 if cancelFunc != nil {
1002 busy = true
1003 break
1004 }
1005 }
1006 return busy
1007}
1008
1009func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
1010 _, busy := a.activeRequests.Get(sessionID)
1011 return busy
1012}
1013
1014func (a *sessionAgent) QueuedPrompts(sessionID string) int {
1015 l, ok := a.messageQueue.Get(sessionID)
1016 if !ok {
1017 return 0
1018 }
1019 return len(l)
1020}
1021
1022func (a *sessionAgent) QueuedPromptsList(sessionID string) []string {
1023 l, ok := a.messageQueue.Get(sessionID)
1024 if !ok {
1025 return nil
1026 }
1027 prompts := make([]string, len(l))
1028 for i, call := range l {
1029 prompts[i] = call.Prompt
1030 }
1031 return prompts
1032}
1033
1034func (a *sessionAgent) SetModels(large Model, small Model) {
1035 a.largeModel.Set(large)
1036 a.smallModel.Set(small)
1037}
1038
1039func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
1040 a.tools.SetSlice(tools)
1041}
1042
1043func (a *sessionAgent) SetSystemPrompt(systemPrompt string) {
1044 a.systemPrompt.Set(systemPrompt)
1045}
1046
1047func (a *sessionAgent) Model() Model {
1048 return a.largeModel.Get()
1049}
1050
1051// convertToToolResult converts a fantasy tool result to a message tool result.
1052func (a *sessionAgent) convertToToolResult(result fantasy.ToolResultContent) message.ToolResult {
1053 baseResult := message.ToolResult{
1054 ToolCallID: result.ToolCallID,
1055 Name: result.ToolName,
1056 Metadata: result.ClientMetadata,
1057 }
1058
1059 switch result.Result.GetType() {
1060 case fantasy.ToolResultContentTypeText:
1061 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result); ok {
1062 baseResult.Content = r.Text
1063 }
1064 case fantasy.ToolResultContentTypeError:
1065 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result); ok {
1066 baseResult.Content = r.Error.Error()
1067 baseResult.IsError = true
1068 }
1069 case fantasy.ToolResultContentTypeMedia:
1070 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](result.Result); ok {
1071 content := r.Text
1072 if content == "" {
1073 content = fmt.Sprintf("Loaded %s content", r.MediaType)
1074 }
1075 baseResult.Content = content
1076 baseResult.Data = r.Data
1077 baseResult.MIMEType = r.MediaType
1078 }
1079 }
1080
1081 return baseResult
1082}
1083
1084// workaroundProviderMediaLimitations converts media content in tool results to
1085// user messages for providers that don't natively support images in tool results.
1086//
1087// Problem: OpenAI, Google, OpenRouter, and other OpenAI-compatible providers
1088// don't support sending images/media in tool result messages - they only accept
1089// text in tool results. However, they DO support images in user messages.
1090//
1091// If we send media in tool results to these providers, the API returns an error.
1092//
1093// Solution: For these providers, we:
1094// 1. Replace the media in the tool result with a text placeholder
1095// 2. Inject a user message immediately after with the image as a file attachment
1096// 3. This maintains the tool execution flow while working around API limitations
1097//
1098// Anthropic and Bedrock support images natively in tool results, so we skip
1099// this workaround for them.
1100//
1101// Example transformation:
1102//
1103// BEFORE: [tool result: image data]
1104// AFTER: [tool result: "Image loaded - see attached"], [user: image attachment]
1105func (a *sessionAgent) workaroundProviderMediaLimitations(messages []fantasy.Message, largeModel Model) []fantasy.Message {
1106 providerSupportsMedia := largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderAnthropic) ||
1107 largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderBedrock)
1108
1109 if providerSupportsMedia {
1110 return messages
1111 }
1112
1113 convertedMessages := make([]fantasy.Message, 0, len(messages))
1114
1115 for _, msg := range messages {
1116 if msg.Role != fantasy.MessageRoleTool {
1117 convertedMessages = append(convertedMessages, msg)
1118 continue
1119 }
1120
1121 textParts := make([]fantasy.MessagePart, 0, len(msg.Content))
1122 var mediaFiles []fantasy.FilePart
1123
1124 for _, part := range msg.Content {
1125 toolResult, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
1126 if !ok {
1127 textParts = append(textParts, part)
1128 continue
1129 }
1130
1131 if media, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](toolResult.Output); ok {
1132 decoded, err := base64.StdEncoding.DecodeString(media.Data)
1133 if err != nil {
1134 slog.Warn("Failed to decode media data", "error", err)
1135 textParts = append(textParts, part)
1136 continue
1137 }
1138
1139 mediaFiles = append(mediaFiles, fantasy.FilePart{
1140 Data: decoded,
1141 MediaType: media.MediaType,
1142 Filename: fmt.Sprintf("tool-result-%s", toolResult.ToolCallID),
1143 })
1144
1145 textParts = append(textParts, fantasy.ToolResultPart{
1146 ToolCallID: toolResult.ToolCallID,
1147 Output: fantasy.ToolResultOutputContentText{
1148 Text: "[Image/media content loaded - see attached file]",
1149 },
1150 ProviderOptions: toolResult.ProviderOptions,
1151 })
1152 } else {
1153 textParts = append(textParts, part)
1154 }
1155 }
1156
1157 convertedMessages = append(convertedMessages, fantasy.Message{
1158 Role: fantasy.MessageRoleTool,
1159 Content: textParts,
1160 })
1161
1162 if len(mediaFiles) > 0 {
1163 convertedMessages = append(convertedMessages, fantasy.NewUserMessage(
1164 "Here is the media content from the tool result:",
1165 mediaFiles...,
1166 ))
1167 }
1168 }
1169
1170 return convertedMessages
1171}
1172
1173// buildSummaryPrompt constructs the prompt text for session summarization.
1174func buildSummaryPrompt(todos []session.Todo) string {
1175 var sb strings.Builder
1176 sb.WriteString("Provide a detailed summary of our conversation above.")
1177 if len(todos) > 0 {
1178 sb.WriteString("\n\n## Current Todo List\n\n")
1179 for _, t := range todos {
1180 fmt.Fprintf(&sb, "- [%s] %s\n", t.Status, t.Content)
1181 }
1182 sb.WriteString("\nInclude these tasks and their statuses in your summary. ")
1183 sb.WriteString("Instruct the resuming assistant to use the `todos` tool to continue tracking progress on these tasks.")
1184 }
1185 return sb.String()
1186}