1// Package agent is the core orchestration layer for Crush AI agents.
2//
3// It provides session-based AI agent functionality for managing
4// conversations, tool execution, and message handling. It coordinates
5// interactions between language models, messages, sessions, and tools while
6// handling features like automatic summarization, queuing, and token
7// management.
8package agent
9
10import (
11 "cmp"
12 "context"
13 _ "embed"
14 "encoding/base64"
15 "errors"
16 "fmt"
17 "log/slog"
18 "os"
19 "regexp"
20 "strconv"
21 "strings"
22 "sync"
23 "time"
24
25 "charm.land/catwalk/pkg/catwalk"
26 "charm.land/fantasy"
27 "charm.land/fantasy/providers/anthropic"
28 "charm.land/fantasy/providers/bedrock"
29 "charm.land/fantasy/providers/google"
30 "charm.land/fantasy/providers/openai"
31 "charm.land/fantasy/providers/openrouter"
32 "charm.land/fantasy/providers/vercel"
33 "charm.land/lipgloss/v2"
34 "github.com/charmbracelet/crush/internal/agent/hyper"
35 "github.com/charmbracelet/crush/internal/agent/notify"
36 "github.com/charmbracelet/crush/internal/agent/tools"
37 "github.com/charmbracelet/crush/internal/agent/tools/mcp"
38 "github.com/charmbracelet/crush/internal/config"
39 "github.com/charmbracelet/crush/internal/csync"
40 "github.com/charmbracelet/crush/internal/message"
41 "github.com/charmbracelet/crush/internal/permission"
42 "github.com/charmbracelet/crush/internal/pubsub"
43 "github.com/charmbracelet/crush/internal/session"
44 "github.com/charmbracelet/crush/internal/stringext"
45 "github.com/charmbracelet/crush/internal/version"
46 "github.com/charmbracelet/x/exp/charmtone"
47)
48
const (
	// DefaultSessionName is the title a session falls back to when title
	// generation fails or yields an empty string.
	DefaultSessionName = "Untitled Session"

	// Constants for auto-summarization thresholds.
	//
	// For context windows larger than largeContextWindowThreshold tokens,
	// summarization is triggered once largeContextWindowBuffer or fewer tokens
	// remain. For smaller windows it is triggered once the remaining tokens
	// drop to smallContextWindowRatio of the window size.
	largeContextWindowThreshold = 200_000
	largeContextWindowBuffer    = 20_000
	smallContextWindowRatio     = 0.2
)
57
// userAgent is the User-Agent string handed to every fantasy agent created in
// this file. It embeds the running Crush version.
var userAgent = fmt.Sprintf("Charm-Crush/%s (https://charm.land/crush)", version.Version)
59
// titlePrompt is the embedded system prompt used when generating a session
// title.
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt is the embedded system prompt used when summarizing a
// session.
//
//go:embed templates/summary.md
var summaryPrompt []byte
65
// thinkTagRegex matches a <think>...</think> block so it can be removed from
// generated titles. The (?s) flag lets "." match newlines, so a thinking block
// that spans several lines is stripped even when the caller has not flattened
// the text first. The match is non-greedy so separate blocks are removed
// individually.
var thinkTagRegex = regexp.MustCompile(`(?s)<think>.*?</think>`)
68
// SessionAgentCall describes a single prompt submitted to a session agent.
type SessionAgentCall struct {
	// SessionID identifies the session the prompt belongs to. Run rejects
	// calls where it is empty.
	SessionID string
	// Prompt is the user's text. It may be empty only when Attachments
	// contains a text attachment.
	Prompt string
	// ProviderOptions is forwarded unchanged to the model stream call.
	ProviderOptions fantasy.ProviderOptions
	// Attachments are stored with the user message; non-text attachments are
	// also sent to the model as files.
	Attachments []message.Attachment
	// MaxOutputTokens is forwarded to the model stream call.
	MaxOutputTokens int64
	// Optional sampling parameters, forwarded to the model stream call as-is.
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
	// NonInteractive suppresses the "agent finished" notification.
	NonInteractive bool
}
82
// SessionAgent is the interface to a session-based agent: it runs prompts
// against sessions, queues prompts for busy sessions, and supports
// cancellation and summarization.
type SessionAgent interface {
	// Run submits a prompt to a session. When the session is busy the call is
	// queued and Run returns (nil, nil).
	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
	SetModels(large Model, small Model)
	SetTools(tools []fantasy.AgentTool)
	SetSystemPrompt(systemPrompt string)
	// Cancel cancels the active request for the session, if any, and drops
	// its queued prompts.
	Cancel(sessionID string)
	// CancelAll cancels every active request and waits, up to a timeout, for
	// the agent to become idle.
	CancelAll()
	IsSessionBusy(sessionID string) bool
	IsBusy() bool
	QueuedPrompts(sessionID string) int
	QueuedPromptsList(sessionID string) []string
	// ClearQueue drops the session's queued prompts without cancelling the
	// active request.
	ClearQueue(sessionID string)
	// Summarize replaces the session's history with a model-generated
	// summary message.
	Summarize(context.Context, string, fantasy.ProviderOptions) error
	Model() Model
}
98
// Model bundles a language model with the two configuration records the agent
// consults alongside it.
type Model struct {
	// Model is the language model handed to fantasy.NewAgent.
	Model fantasy.LanguageModel
	// CatwalkCfg supplies the model metadata read in this file: context
	// window, per-token costs, display name, and capability flags.
	CatwalkCfg catwalk.Model
	// ModelCfg supplies the model and provider identifiers recorded on
	// assistant messages.
	ModelCfg config.SelectedModel
}
104
// sessionAgent is the SessionAgent implementation returned by
// NewSessionAgent.
type sessionAgent struct {
	// Settings that can be replaced while requests are in flight; they are
	// held in csync wrappers and copied at the start of each request.
	largeModel         *csync.Value[Model]
	smallModel         *csync.Value[Model]
	systemPromptPrefix *csync.Value[string]
	systemPrompt       *csync.Value[string]
	tools              *csync.Slice[fantasy.AgentTool]

	// isSubAgent disables the todo-list reminder that is otherwise prepended
	// to the history.
	isSubAgent bool
	sessions   session.Service
	messages   message.Service
	// disableAutoSummarize turns off the context-window stop condition in
	// Run.
	disableAutoSummarize bool
	isYolo               bool
	// notify, when non-nil, receives a notification after an interactive
	// run finishes.
	notify pubsub.Publisher[notify.Notification]

	// messageQueue holds calls received while a session was busy, keyed by
	// session ID.
	messageQueue *csync.Map[string, []SessionAgentCall]
	// activeRequests maps a session ID to the cancel function of its
	// in-flight request.
	activeRequests *csync.Map[string, context.CancelFunc]
}
122
// SessionAgentOptions carries the initial configuration and dependencies
// passed to NewSessionAgent. Each field is copied onto the corresponding
// sessionAgent field.
type SessionAgentOptions struct {
	LargeModel           Model
	SmallModel           Model
	SystemPromptPrefix   string
	SystemPrompt         string
	IsSubAgent           bool
	DisableAutoSummarize bool
	IsYolo               bool
	Sessions             session.Service
	Messages             message.Service
	Tools                []fantasy.AgentTool
	// Notify may be nil, in which case no finish notifications are published.
	Notify pubsub.Publisher[notify.Notification]
}
136
137func NewSessionAgent(
138 opts SessionAgentOptions,
139) SessionAgent {
140 return &sessionAgent{
141 largeModel: csync.NewValue(opts.LargeModel),
142 smallModel: csync.NewValue(opts.SmallModel),
143 systemPromptPrefix: csync.NewValue(opts.SystemPromptPrefix),
144 systemPrompt: csync.NewValue(opts.SystemPrompt),
145 isSubAgent: opts.IsSubAgent,
146 sessions: opts.Sessions,
147 messages: opts.Messages,
148 disableAutoSummarize: opts.DisableAutoSummarize,
149 tools: csync.NewSliceFrom(opts.Tools),
150 isYolo: opts.IsYolo,
151 notify: opts.Notify,
152 messageQueue: csync.NewMap[string, []SessionAgentCall](),
153 activeRequests: csync.NewMap[string, context.CancelFunc](),
154 }
155}
156
// Run submits call to the session's agent and streams the model response into
// the message store.
//
// It returns ErrEmptyPrompt when the call has neither a prompt nor a text
// attachment, and ErrSessionMissing when no session ID is given. If the
// session is already busy, the call is appended to the session's queue and
// Run returns (nil, nil); queued calls are either folded into the next step of
// the running request or run afterwards.
//
// On a streaming error the current assistant message is finalized with a
// finish reason describing the failure, every tool call that has no recorded
// result gets an error tool result, and the original error is returned. When
// the context-window stop condition fires, the session is summarized and, if
// the assistant still had tool calls, the original request is re-queued.
func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
	if call.Prompt == "" && !message.ContainsTextAttachment(call.Attachments) {
		return nil, ErrEmptyPrompt
	}
	if call.SessionID == "" {
		return nil, ErrSessionMissing
	}

	// Queue the message if busy
	if a.IsSessionBusy(call.SessionID) {
		existing, ok := a.messageQueue.Get(call.SessionID)
		if !ok {
			existing = []SessionAgentCall{}
		}
		existing = append(existing, call)
		a.messageQueue.Set(call.SessionID, existing)
		return nil, nil
	}

	// Copy mutable fields under lock to avoid races with SetTools/SetModels.
	agentTools := a.tools.Copy()
	largeModel := a.largeModel.Get()
	systemPrompt := a.systemPrompt.Get()
	promptPrefix := a.systemPromptPrefix.Get()
	var instructions strings.Builder

	// Collect the instructions advertised by every connected MCP server.
	for _, server := range mcp.GetStates() {
		if server.State != mcp.StateConnected {
			continue
		}
		if s := server.Client.InitializeResult().Instructions; s != "" {
			instructions.WriteString(s)
			instructions.WriteString("\n\n")
		}
	}

	if s := instructions.String(); s != "" {
		systemPrompt += "\n\n<mcp-instructions>\n" + s + "\n</mcp-instructions>"
	}

	if len(agentTools) > 0 {
		// Add Anthropic caching to the last tool.
		agentTools[len(agentTools)-1].SetProviderOptions(a.getCacheControlOptions())
	}

	agent := fantasy.NewAgent(
		largeModel.Model,
		fantasy.WithSystemPrompt(systemPrompt),
		fantasy.WithTools(agentTools...),
		fantasy.WithUserAgent(userAgent),
	)

	// sessionLock serializes the session read-modify-write in OnStepFinish.
	sessionLock := sync.Mutex{}
	currentSession, err := a.sessions.Get(ctx, call.SessionID)
	if err != nil {
		return nil, fmt.Errorf("failed to get session: %w", err)
	}

	msgs, err := a.getSessionMessages(ctx, currentSession)
	if err != nil {
		return nil, fmt.Errorf("failed to get session messages: %w", err)
	}

	var wg sync.WaitGroup
	// Generate title if first message.
	if len(msgs) == 0 {
		titleCtx := ctx // Copy to avoid race with ctx reassignment below.
		wg.Go(func() {
			a.generateTitle(titleCtx, call.SessionID, call.Prompt)
		})
	}
	// Run does not return until title generation has finished.
	defer wg.Wait()

	// Add the user message to the session.
	_, err = a.createUserMessage(ctx, call)
	if err != nil {
		return nil, err
	}

	// Add the session to the context.
	ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)

	// Registering the cancel function is what marks the session as having an
	// active request; Cancel looks it up by session ID.
	genCtx, cancel := context.WithCancel(ctx)
	a.activeRequests.Set(call.SessionID, cancel)

	defer cancel()
	defer a.activeRequests.Del(call.SessionID)

	history, files := a.preparePrompt(msgs, call.Attachments...)

	startTime := time.Now()
	a.eventPromptSent(call.SessionID)

	// currentAssistant is set by PrepareStep at the start of every step and
	// mutated by the streaming callbacks below.
	var currentAssistant *message.Message
	// shouldSummarize is set by the context-window stop condition.
	var shouldSummarize bool
	result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
		Prompt:           message.PromptWithTextAttachments(call.Prompt, call.Attachments),
		Files:            files,
		Messages:         history,
		ProviderOptions:  call.ProviderOptions,
		MaxOutputTokens:  &call.MaxOutputTokens,
		TopP:             call.TopP,
		Temperature:      call.Temperature,
		PresencePenalty:  call.PresencePenalty,
		TopK:             call.TopK,
		FrequencyPenalty: call.FrequencyPenalty,
		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
			// Start from the step's messages with all provider options
			// cleared; cache-control options are re-applied below.
			prepared.Messages = options.Messages
			for i := range prepared.Messages {
				prepared.Messages[i].ProviderOptions = nil
			}

			// Use latest tools (updated by SetTools when MCP tools change).
			prepared.Tools = a.tools.Copy()

			// Drain prompts queued while this request was running and append
			// them to the conversation as user messages.
			queuedCalls, _ := a.messageQueue.Get(call.SessionID)
			a.messageQueue.Del(call.SessionID)
			for _, queued := range queuedCalls {
				userMessage, createErr := a.createUserMessage(callContext, queued)
				if createErr != nil {
					return callContext, prepared, createErr
				}
				prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
			}

			prepared.Messages = a.workaroundProviderMediaLimitations(prepared.Messages, largeModel)

			lastSystemRoleInx := 0
			systemMessageUpdated := false
			for i, msg := range prepared.Messages {
				// Add cache control to the last system message seen before the
				// first non-system message (index 0 when none was seen).
				if msg.Role == fantasy.MessageRoleSystem {
					lastSystemRoleInx = i
				} else if !systemMessageUpdated {
					prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
					systemMessageUpdated = true
				}
				// Then add cache control to the last 2 messages.
				if i > len(prepared.Messages)-3 {
					prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
				}
			}

			if promptPrefix != "" {
				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(promptPrefix)}, prepared.Messages...)
			}

			// Create the empty assistant message that the streaming callbacks
			// fill in for this step.
			var assistantMsg message.Message
			assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
				Role:     message.Assistant,
				Parts:    []message.ContentPart{},
				Model:    largeModel.ModelCfg.Model,
				Provider: largeModel.ModelCfg.Provider,
			})
			if err != nil {
				return callContext, prepared, err
			}
			// Attach the message ID and model details to the step context.
			callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
			callContext = context.WithValue(callContext, tools.SupportsImagesContextKey, largeModel.CatwalkCfg.SupportsImages)
			callContext = context.WithValue(callContext, tools.ModelNameContextKey, largeModel.CatwalkCfg.Name)
			currentAssistant = &assistantMsg
			return callContext, prepared, err
		},
		OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
			currentAssistant.AppendReasoningContent(reasoning.Text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnReasoningDelta: func(id string, text string) error {
			currentAssistant.AppendReasoningContent(text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
			// handle anthropic signature
			if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
				if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
					currentAssistant.AppendReasoningSignature(reasoning.Signature)
				}
			}
			// Handle Google thought signature.
			if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
				if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
					currentAssistant.AppendThoughtSignature(reasoning.Signature, reasoning.ToolID)
				}
			}
			// Handle OpenAI responses reasoning metadata.
			if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
				if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
					currentAssistant.SetReasoningResponsesData(reasoning)
				}
			}
			currentAssistant.FinishThinking()
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnTextDelta: func(id string, text string) error {
			// Strip leading newline from initial text content. This is
			// particularly important in non-interactive mode where leading
			// newlines are very visible.
			if len(currentAssistant.Parts) == 0 {
				text = strings.TrimPrefix(text, "\n")
			}

			currentAssistant.AppendContent(text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnToolInputStart: func(id string, toolName string) error {
			// Record the tool call as soon as its input starts streaming; it
			// is marked finished later by OnToolCall.
			toolCall := message.ToolCall{
				ID:               id,
				Name:             toolName,
				ProviderExecuted: false,
				Finished:         false,
			}
			currentAssistant.AddToolCall(toolCall)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnRetry: func(err *fantasy.ProviderError, delay time.Duration) {
			// TODO: implement
		},
		OnToolCall: func(tc fantasy.ToolCallContent) error {
			toolCall := message.ToolCall{
				ID:               tc.ToolCallID,
				Name:             tc.ToolName,
				Input:            tc.Input,
				ProviderExecuted: false,
				Finished:         true,
			}
			currentAssistant.AddToolCall(toolCall)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnToolResult: func(result fantasy.ToolResultContent) error {
			// Tool results are stored as their own messages with the Tool role.
			toolResult := a.convertToToolResult(result)
			_, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
				Role: message.Tool,
				Parts: []message.ContentPart{
					toolResult,
				},
			})
			return createMsgErr
		},
		OnStepFinish: func(stepResult fantasy.StepResult) error {
			// Map the provider finish reason onto the message finish reason;
			// anything unrecognized stays FinishReasonUnknown.
			finishReason := message.FinishReasonUnknown
			switch stepResult.FinishReason {
			case fantasy.FinishReasonLength:
				finishReason = message.FinishReasonMaxTokens
			case fantasy.FinishReasonStop:
				finishReason = message.FinishReasonEndTurn
			case fantasy.FinishReasonToolCalls:
				finishReason = message.FinishReasonToolUse
			}
			currentAssistant.AddFinish(finishReason, "", "")
			sessionLock.Lock()
			defer sessionLock.Unlock()

			// Re-read the session before applying this step's usage so the
			// save is based on the stored state rather than the local copy.
			updatedSession, getSessionErr := a.sessions.Get(ctx, call.SessionID)
			if getSessionErr != nil {
				return getSessionErr
			}
			a.updateSessionUsage(largeModel, &updatedSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
			_, sessionErr := a.sessions.Save(ctx, updatedSession)
			if sessionErr != nil {
				return sessionErr
			}
			currentSession = updatedSession
			return a.messages.Update(genCtx, *currentAssistant)
		},
		StopWhen: []fantasy.StopCondition{
			// Stop and request summarization when the remaining context
			// window falls to or below the threshold for the model's size.
			func(_ []fantasy.StepResult) bool {
				cw := int64(largeModel.CatwalkCfg.ContextWindow)
				tokens := currentSession.CompletionTokens + currentSession.PromptTokens
				remaining := cw - tokens
				var threshold int64
				if cw > largeContextWindowThreshold {
					threshold = largeContextWindowBuffer
				} else {
					threshold = int64(float64(cw) * smallContextWindowRatio)
				}
				if (remaining <= threshold) && !a.disableAutoSummarize {
					shouldSummarize = true
					return true
				}
				return false
			},
			// Stop when the recent steps repeat the same tool calls.
			func(steps []fantasy.StepResult) bool {
				return hasRepeatedToolCalls(steps, loopDetectionWindowSize, loopDetectionMaxRepeats)
			},
		},
	})

	a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))

	if err != nil {
		isCancelErr := errors.Is(err, context.Canceled)
		isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
		// No assistant message was created, so there is nothing to finalize.
		if currentAssistant == nil {
			return result, err
		}
		// Ensure we finish thinking on error to close the reasoning state.
		currentAssistant.FinishThinking()
		toolCalls := currentAssistant.ToolCalls()
		// INFO: we use the parent context here because the genCtx has been cancelled.
		msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
		if createErr != nil {
			return nil, createErr
		}
		for _, tc := range toolCalls {
			// Close tool calls whose input never finished streaming, giving
			// them an empty JSON object as input.
			if !tc.Finished {
				tc.Finished = true
				tc.Input = "{}"
				currentAssistant.AddToolCall(tc)
				updateErr := a.messages.Update(ctx, *currentAssistant)
				if updateErr != nil {
					return nil, updateErr
				}
			}

			// Skip tool calls that already have a stored result.
			found := false
			for _, msg := range msgs {
				if msg.Role == message.Tool {
					for _, tr := range msg.ToolResults() {
						if tr.ToolCallID == tc.ID {
							found = true
							break
						}
					}
				}
				if found {
					break
				}
			}
			if found {
				continue
			}
			// Otherwise store an error result so the tool call is not left
			// without one.
			content := "There was an error while executing the tool"
			if isCancelErr {
				content = "Tool execution canceled by user"
			} else if isPermissionErr {
				content = "User denied permission"
			}
			toolResult := message.ToolResult{
				ToolCallID: tc.ID,
				Name:       tc.Name,
				Content:    content,
				IsError:    true,
			}
			_, createErr = a.messages.Create(ctx, currentAssistant.SessionID, message.CreateMessageParams{
				Role: message.Tool,
				Parts: []message.ContentPart{
					toolResult,
				},
			})
			if createErr != nil {
				return nil, createErr
			}
		}
		// Record a finish reason on the assistant message that describes the
		// failure, from most to least specific error type.
		var fantasyErr *fantasy.Error
		var providerErr *fantasy.ProviderError
		const defaultTitle = "Provider Error"
		linkStyle := lipgloss.NewStyle().Foreground(charmtone.Guac).Underline(true)
		if isCancelErr {
			currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
		} else if isPermissionErr {
			currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "")
		} else if errors.Is(err, hyper.ErrNoCredits) {
			url := hyper.BaseURL()
			link := linkStyle.Hyperlink(url, "id=hyper").Render(url)
			currentAssistant.AddFinish(message.FinishReasonError, "No credits", "You're out of credits. Add more at "+link)
		} else if errors.As(err, &providerErr) {
			// This exact provider message is treated as the Copilot "model
			// not enabled" case and gets a dedicated explanation.
			if providerErr.Message == "The requested model is not supported." {
				url := "https://github.com/settings/copilot/features"
				link := linkStyle.Hyperlink(url, "id=copilot").Render(url)
				currentAssistant.AddFinish(
					message.FinishReasonError,
					"Copilot model not enabled",
					fmt.Sprintf("%q is not enabled in Copilot. Go to the following page to enable it. Then, wait 5 minutes before trying again. %s", largeModel.CatwalkCfg.Name, link),
				)
			} else {
				currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message)
			}
		} else if errors.As(err, &fantasyErr) {
			currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(fantasyErr.Title), defaultTitle), fantasyErr.Message)
		} else {
			currentAssistant.AddFinish(message.FinishReasonError, defaultTitle, err.Error())
		}
		// Note: we use the parent context here because the genCtx has been
		// cancelled.
		updateErr := a.messages.Update(ctx, *currentAssistant)
		if updateErr != nil {
			return nil, updateErr
		}
		return nil, err
	}

	// Send notification that agent has finished its turn (skip for
	// nested/non-interactive sessions).
	if !call.NonInteractive && a.notify != nil {
		a.notify.Publish(pubsub.CreatedEvent, notify.Notification{
			SessionID:    call.SessionID,
			SessionTitle: currentSession.Title,
			Type:         notify.TypeAgentFinished,
		})
	}

	if shouldSummarize {
		// Summarize refuses to run while the session is busy, so release the
		// active request first.
		a.activeRequests.Del(call.SessionID)
		if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
			return nil, summarizeErr
		}
		// If the agent wasn't done...
		if len(currentAssistant.ToolCalls()) > 0 {
			existing, ok := a.messageQueue.Get(call.SessionID)
			if !ok {
				existing = []SessionAgentCall{}
			}
			call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
			existing = append(existing, call)
			a.messageQueue.Set(call.SessionID, existing)
		}
	}

	// Release active request before processing queued messages.
	a.activeRequests.Del(call.SessionID)
	cancel()

	queuedMessages, ok := a.messageQueue.Get(call.SessionID)
	if !ok || len(queuedMessages) == 0 {
		return result, err
	}
	// There are queued messages restart the loop.
	firstQueuedMessage := queuedMessages[0]
	a.messageQueue.Set(call.SessionID, queuedMessages[1:])
	return a.Run(ctx, firstQueuedMessage)
}
586
587func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
588 if a.IsSessionBusy(sessionID) {
589 return ErrSessionBusy
590 }
591
592 // Copy mutable fields under lock to avoid races with SetModels.
593 largeModel := a.largeModel.Get()
594 systemPromptPrefix := a.systemPromptPrefix.Get()
595
596 currentSession, err := a.sessions.Get(ctx, sessionID)
597 if err != nil {
598 return fmt.Errorf("failed to get session: %w", err)
599 }
600 msgs, err := a.getSessionMessages(ctx, currentSession)
601 if err != nil {
602 return err
603 }
604 if len(msgs) == 0 {
605 // Nothing to summarize.
606 return nil
607 }
608
609 aiMsgs, _ := a.preparePrompt(msgs)
610
611 genCtx, cancel := context.WithCancel(ctx)
612 a.activeRequests.Set(sessionID, cancel)
613 defer a.activeRequests.Del(sessionID)
614 defer cancel()
615
616 agent := fantasy.NewAgent(largeModel.Model,
617 fantasy.WithSystemPrompt(string(summaryPrompt)),
618 fantasy.WithUserAgent(userAgent),
619 )
620 summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
621 Role: message.Assistant,
622 Model: largeModel.Model.Model(),
623 Provider: largeModel.Model.Provider(),
624 IsSummaryMessage: true,
625 })
626 if err != nil {
627 return err
628 }
629
630 summaryPromptText := buildSummaryPrompt(currentSession.Todos)
631
632 resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
633 Prompt: summaryPromptText,
634 Messages: aiMsgs,
635 ProviderOptions: opts,
636 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
637 prepared.Messages = options.Messages
638 if systemPromptPrefix != "" {
639 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(systemPromptPrefix)}, prepared.Messages...)
640 }
641 return callContext, prepared, nil
642 },
643 OnReasoningDelta: func(id string, text string) error {
644 summaryMessage.AppendReasoningContent(text)
645 return a.messages.Update(genCtx, summaryMessage)
646 },
647 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
648 // Handle anthropic signature.
649 if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
650 if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
651 summaryMessage.AppendReasoningSignature(signature.Signature)
652 }
653 }
654 summaryMessage.FinishThinking()
655 return a.messages.Update(genCtx, summaryMessage)
656 },
657 OnTextDelta: func(id, text string) error {
658 summaryMessage.AppendContent(text)
659 return a.messages.Update(genCtx, summaryMessage)
660 },
661 })
662 if err != nil {
663 isCancelErr := errors.Is(err, context.Canceled)
664 if isCancelErr {
665 // User cancelled summarize we need to remove the summary message.
666 deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
667 return deleteErr
668 }
669 return err
670 }
671
672 summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
673 err = a.messages.Update(genCtx, summaryMessage)
674 if err != nil {
675 return err
676 }
677
678 var openrouterCost *float64
679 for _, step := range resp.Steps {
680 stepCost := a.openrouterCost(step.ProviderMetadata)
681 if stepCost != nil {
682 newCost := *stepCost
683 if openrouterCost != nil {
684 newCost += *openrouterCost
685 }
686 openrouterCost = &newCost
687 }
688 }
689
690 a.updateSessionUsage(largeModel, ¤tSession, resp.TotalUsage, openrouterCost)
691
692 // Just in case, get just the last usage info.
693 usage := resp.Response.Usage
694 currentSession.SummaryMessageID = summaryMessage.ID
695 currentSession.CompletionTokens = usage.OutputTokens
696 currentSession.PromptTokens = 0
697 _, err = a.sessions.Save(genCtx, currentSession)
698 return err
699}
700
701func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
702 if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
703 return fantasy.ProviderOptions{}
704 }
705 return fantasy.ProviderOptions{
706 anthropic.Name: &anthropic.ProviderCacheControlOptions{
707 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
708 },
709 bedrock.Name: &anthropic.ProviderCacheControlOptions{
710 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
711 },
712 vercel.Name: &anthropic.ProviderCacheControlOptions{
713 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
714 },
715 }
716}
717
718func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
719 parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
720 var attachmentParts []message.ContentPart
721 for _, attachment := range call.Attachments {
722 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
723 }
724 parts = append(parts, attachmentParts...)
725 msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
726 Role: message.User,
727 Parts: parts,
728 })
729 if err != nil {
730 return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
731 }
732 return msg, nil
733}
734
735func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
736 var history []fantasy.Message
737 if !a.isSubAgent {
738 history = append(history, fantasy.NewUserMessage(
739 fmt.Sprintf("<system_reminder>%s</system_reminder>",
740 `This is a reminder that your todo list is currently empty. DO NOT mention this to the user explicitly because they are already aware.
741If you are working on tasks that would benefit from a todo list please use the "todos" tool to create one.
742If not, please feel free to ignore. Again do not mention this message to the user.`,
743 ),
744 ))
745 }
746 for _, m := range msgs {
747 if len(m.Parts) == 0 {
748 continue
749 }
750 // Assistant message without content or tool calls (cancelled before it
751 // returned anything).
752 if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
753 continue
754 }
755 history = append(history, m.ToAIMessage()...)
756 }
757
758 var files []fantasy.FilePart
759 for _, attachment := range attachments {
760 if attachment.IsText() {
761 continue
762 }
763 files = append(files, fantasy.FilePart{
764 Filename: attachment.FileName,
765 Data: attachment.Content,
766 MediaType: attachment.MimeType,
767 })
768 }
769
770 return history, files
771}
772
773func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
774 msgs, err := a.messages.List(ctx, session.ID)
775 if err != nil {
776 return nil, fmt.Errorf("failed to list messages: %w", err)
777 }
778
779 if session.SummaryMessageID != "" {
780 summaryMsgIndex := -1
781 for i, msg := range msgs {
782 if msg.ID == session.SummaryMessageID {
783 summaryMsgIndex = i
784 break
785 }
786 }
787 if summaryMsgIndex != -1 {
788 msgs = msgs[summaryMsgIndex:]
789 msgs[0].Role = message.User
790 }
791 }
792 return msgs, nil
793}
794
795// generateTitle generates a session titled based on the initial prompt.
796func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, userPrompt string) {
797 if userPrompt == "" {
798 return
799 }
800
801 smallModel := a.smallModel.Get()
802 largeModel := a.largeModel.Get()
803 systemPromptPrefix := a.systemPromptPrefix.Get()
804
805 var maxOutputTokens int64 = 40
806 if smallModel.CatwalkCfg.CanReason {
807 maxOutputTokens = smallModel.CatwalkCfg.DefaultMaxTokens
808 }
809
810 newAgent := func(m fantasy.LanguageModel, p []byte, tok int64) fantasy.Agent {
811 return fantasy.NewAgent(m,
812 fantasy.WithSystemPrompt(string(p)+"\n /no_think"),
813 fantasy.WithMaxOutputTokens(tok),
814 fantasy.WithUserAgent(userAgent),
815 )
816 }
817
818 streamCall := fantasy.AgentStreamCall{
819 Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", userPrompt),
820 PrepareStep: func(callCtx context.Context, opts fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
821 prepared.Messages = opts.Messages
822 if systemPromptPrefix != "" {
823 prepared.Messages = append([]fantasy.Message{
824 fantasy.NewSystemMessage(systemPromptPrefix),
825 }, prepared.Messages...)
826 }
827 return callCtx, prepared, nil
828 },
829 }
830
831 // Use the small model to generate the title.
832 model := smallModel
833 agent := newAgent(model.Model, titlePrompt, maxOutputTokens)
834 resp, err := agent.Stream(ctx, streamCall)
835 if err == nil {
836 // We successfully generated a title with the small model.
837 slog.Debug("Generated title with small model")
838 } else {
839 // It didn't work. Let's try with the big model.
840 slog.Error("Error generating title with small model; trying big model", "err", err)
841 model = largeModel
842 agent = newAgent(model.Model, titlePrompt, maxOutputTokens)
843 resp, err = agent.Stream(ctx, streamCall)
844 if err == nil {
845 slog.Debug("Generated title with large model")
846 } else {
847 // Welp, the large model didn't work either. Use the default
848 // session name and return.
849 slog.Error("Error generating title with large model", "err", err)
850 saveErr := a.sessions.Rename(ctx, sessionID, DefaultSessionName)
851 if saveErr != nil {
852 slog.Error("Failed to save session title", "error", saveErr)
853 }
854 return
855 }
856 }
857
858 if resp == nil {
859 // Actually, we didn't get a response so we can't. Use the default
860 // session name and return.
861 slog.Error("Response is nil; can't generate title")
862 saveErr := a.sessions.Rename(ctx, sessionID, DefaultSessionName)
863 if saveErr != nil {
864 slog.Error("Failed to save session title", "error", saveErr)
865 }
866 return
867 }
868
869 // Clean up title.
870 var title string
871 title = strings.ReplaceAll(resp.Response.Content.Text(), "\n", " ")
872
873 // Remove thinking tags if present.
874 title = thinkTagRegex.ReplaceAllString(title, "")
875
876 title = strings.TrimSpace(title)
877 title = cmp.Or(title, DefaultSessionName)
878
879 // Calculate usage and cost.
880 var openrouterCost *float64
881 for _, step := range resp.Steps {
882 stepCost := a.openrouterCost(step.ProviderMetadata)
883 if stepCost != nil {
884 newCost := *stepCost
885 if openrouterCost != nil {
886 newCost += *openrouterCost
887 }
888 openrouterCost = &newCost
889 }
890 }
891
892 modelConfig := model.CatwalkCfg
893 cost := modelConfig.CostPer1MInCached/1e6*float64(resp.TotalUsage.CacheCreationTokens) +
894 modelConfig.CostPer1MOutCached/1e6*float64(resp.TotalUsage.CacheReadTokens) +
895 modelConfig.CostPer1MIn/1e6*float64(resp.TotalUsage.InputTokens) +
896 modelConfig.CostPer1MOut/1e6*float64(resp.TotalUsage.OutputTokens)
897
898 // Use override cost if available (e.g., from OpenRouter).
899 if openrouterCost != nil {
900 cost = *openrouterCost
901 }
902
903 promptTokens := resp.TotalUsage.InputTokens + resp.TotalUsage.CacheCreationTokens
904 completionTokens := resp.TotalUsage.OutputTokens
905
906 // Atomically update only title and usage fields to avoid overriding other
907 // concurrent session updates.
908 saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, title, promptTokens, completionTokens, cost)
909 if saveErr != nil {
910 slog.Error("Failed to save session title and usage", "error", saveErr)
911 return
912 }
913}
914
915func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
916 openrouterMetadata, ok := metadata[openrouter.Name]
917 if !ok {
918 return nil
919 }
920
921 opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
922 if !ok {
923 return nil
924 }
925 return &opts.Usage.Cost
926}
927
928func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
929 modelConfig := model.CatwalkCfg
930 cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
931 modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
932 modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
933 modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
934
935 a.eventTokensUsed(session.ID, model, usage, cost)
936
937 if overrideCost != nil {
938 session.Cost += *overrideCost
939 } else {
940 session.Cost += cost
941 }
942
943 session.CompletionTokens = usage.OutputTokens
944 session.PromptTokens = usage.InputTokens + usage.CacheReadTokens
945}
946
947func (a *sessionAgent) Cancel(sessionID string) {
948 // Cancel regular requests. Don't use Take() here - we need the entry to
949 // remain in activeRequests so IsBusy() returns true until the goroutine
950 // fully completes (including error handling that may access the DB).
951 // The defer in processRequest will clean up the entry.
952 if cancel, ok := a.activeRequests.Get(sessionID); ok && cancel != nil {
953 slog.Debug("Request cancellation initiated", "session_id", sessionID)
954 cancel()
955 }
956
957 // Also check for summarize requests.
958 if cancel, ok := a.activeRequests.Get(sessionID + "-summarize"); ok && cancel != nil {
959 slog.Debug("Summarize cancellation initiated", "session_id", sessionID)
960 cancel()
961 }
962
963 if a.QueuedPrompts(sessionID) > 0 {
964 slog.Debug("Clearing queued prompts", "session_id", sessionID)
965 a.messageQueue.Del(sessionID)
966 }
967}
968
969func (a *sessionAgent) ClearQueue(sessionID string) {
970 if a.QueuedPrompts(sessionID) > 0 {
971 slog.Debug("Clearing queued prompts", "session_id", sessionID)
972 a.messageQueue.Del(sessionID)
973 }
974}
975
976func (a *sessionAgent) CancelAll() {
977 if !a.IsBusy() {
978 return
979 }
980 for key := range a.activeRequests.Seq2() {
981 a.Cancel(key) // key is sessionID
982 }
983
984 timeout := time.After(5 * time.Second)
985 for a.IsBusy() {
986 select {
987 case <-timeout:
988 return
989 default:
990 time.Sleep(200 * time.Millisecond)
991 }
992 }
993}
994
995func (a *sessionAgent) IsBusy() bool {
996 var busy bool
997 for cancelFunc := range a.activeRequests.Seq() {
998 if cancelFunc != nil {
999 busy = true
1000 break
1001 }
1002 }
1003 return busy
1004}
1005
1006func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
1007 _, busy := a.activeRequests.Get(sessionID)
1008 return busy
1009}
1010
1011func (a *sessionAgent) QueuedPrompts(sessionID string) int {
1012 l, ok := a.messageQueue.Get(sessionID)
1013 if !ok {
1014 return 0
1015 }
1016 return len(l)
1017}
1018
1019func (a *sessionAgent) QueuedPromptsList(sessionID string) []string {
1020 l, ok := a.messageQueue.Get(sessionID)
1021 if !ok {
1022 return nil
1023 }
1024 prompts := make([]string, len(l))
1025 for i, call := range l {
1026 prompts[i] = call.Prompt
1027 }
1028 return prompts
1029}
1030
// SetModels stores large as the agent's large model and small as its small
// model, replacing whatever values were held before.
func (a *sessionAgent) SetModels(large Model, small Model) {
	a.largeModel.Set(large)
	a.smallModel.Set(small)
}
1035
// SetTools hands tools to the agent's tool collection via SetSlice.
// NOTE(review): presumably this replaces the current tool list rather than
// appending to it — confirm against the collection's SetSlice implementation.
func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
	a.tools.SetSlice(tools)
}
1039
// SetSystemPrompt stores systemPrompt as the agent's system prompt.
func (a *sessionAgent) SetSystemPrompt(systemPrompt string) {
	a.systemPrompt.Set(systemPrompt)
}
1043
// Model returns the agent's currently stored large model.
func (a *sessionAgent) Model() Model {
	return a.largeModel.Get()
}
1047
1048// convertToToolResult converts a fantasy tool result to a message tool result.
1049func (a *sessionAgent) convertToToolResult(result fantasy.ToolResultContent) message.ToolResult {
1050 baseResult := message.ToolResult{
1051 ToolCallID: result.ToolCallID,
1052 Name: result.ToolName,
1053 Metadata: result.ClientMetadata,
1054 }
1055
1056 switch result.Result.GetType() {
1057 case fantasy.ToolResultContentTypeText:
1058 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result); ok {
1059 baseResult.Content = r.Text
1060 }
1061 case fantasy.ToolResultContentTypeError:
1062 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result); ok {
1063 baseResult.Content = r.Error.Error()
1064 baseResult.IsError = true
1065 }
1066 case fantasy.ToolResultContentTypeMedia:
1067 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](result.Result); ok {
1068 content := r.Text
1069 if content == "" {
1070 content = fmt.Sprintf("Loaded %s content", r.MediaType)
1071 }
1072 baseResult.Content = content
1073 baseResult.Data = r.Data
1074 baseResult.MIMEType = r.MediaType
1075 }
1076 }
1077
1078 return baseResult
1079}
1080
1081// workaroundProviderMediaLimitations converts media content in tool results to
1082// user messages for providers that don't natively support images in tool results.
1083//
1084// Problem: OpenAI, Google, OpenRouter, and other OpenAI-compatible providers
1085// don't support sending images/media in tool result messages - they only accept
1086// text in tool results. However, they DO support images in user messages.
1087//
1088// If we send media in tool results to these providers, the API returns an error.
1089//
1090// Solution: For these providers, we:
1091// 1. Replace the media in the tool result with a text placeholder
1092// 2. Inject a user message immediately after with the image as a file attachment
1093// 3. This maintains the tool execution flow while working around API limitations
1094//
1095// Anthropic and Bedrock support images natively in tool results, so we skip
1096// this workaround for them.
1097//
1098// Example transformation:
1099//
1100// BEFORE: [tool result: image data]
1101// AFTER: [tool result: "Image loaded - see attached"], [user: image attachment]
1102func (a *sessionAgent) workaroundProviderMediaLimitations(messages []fantasy.Message, largeModel Model) []fantasy.Message {
1103 providerSupportsMedia := largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderAnthropic) ||
1104 largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderBedrock)
1105
1106 if providerSupportsMedia {
1107 return messages
1108 }
1109
1110 convertedMessages := make([]fantasy.Message, 0, len(messages))
1111
1112 for _, msg := range messages {
1113 if msg.Role != fantasy.MessageRoleTool {
1114 convertedMessages = append(convertedMessages, msg)
1115 continue
1116 }
1117
1118 textParts := make([]fantasy.MessagePart, 0, len(msg.Content))
1119 var mediaFiles []fantasy.FilePart
1120
1121 for _, part := range msg.Content {
1122 toolResult, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
1123 if !ok {
1124 textParts = append(textParts, part)
1125 continue
1126 }
1127
1128 if media, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](toolResult.Output); ok {
1129 decoded, err := base64.StdEncoding.DecodeString(media.Data)
1130 if err != nil {
1131 slog.Warn("Failed to decode media data", "error", err)
1132 textParts = append(textParts, part)
1133 continue
1134 }
1135
1136 mediaFiles = append(mediaFiles, fantasy.FilePart{
1137 Data: decoded,
1138 MediaType: media.MediaType,
1139 Filename: fmt.Sprintf("tool-result-%s", toolResult.ToolCallID),
1140 })
1141
1142 textParts = append(textParts, fantasy.ToolResultPart{
1143 ToolCallID: toolResult.ToolCallID,
1144 Output: fantasy.ToolResultOutputContentText{
1145 Text: "[Image/media content loaded - see attached file]",
1146 },
1147 ProviderOptions: toolResult.ProviderOptions,
1148 })
1149 } else {
1150 textParts = append(textParts, part)
1151 }
1152 }
1153
1154 convertedMessages = append(convertedMessages, fantasy.Message{
1155 Role: fantasy.MessageRoleTool,
1156 Content: textParts,
1157 })
1158
1159 if len(mediaFiles) > 0 {
1160 convertedMessages = append(convertedMessages, fantasy.NewUserMessage(
1161 "Here is the media content from the tool result:",
1162 mediaFiles...,
1163 ))
1164 }
1165 }
1166
1167 return convertedMessages
1168}
1169
1170// buildSummaryPrompt constructs the prompt text for session summarization.
1171func buildSummaryPrompt(todos []session.Todo) string {
1172 var sb strings.Builder
1173 sb.WriteString("Provide a detailed summary of our conversation above.")
1174 if len(todos) > 0 {
1175 sb.WriteString("\n\n## Current Todo List\n\n")
1176 for _, t := range todos {
1177 fmt.Fprintf(&sb, "- [%s] %s\n", t.Status, t.Content)
1178 }
1179 sb.WriteString("\nInclude these tasks and their statuses in your summary. ")
1180 sb.WriteString("Instruct the resuming assistant to use the `todos` tool to continue tracking progress on these tasks.")
1181 }
1182 return sb.String()
1183}