1// Package agent is the core orchestration layer for Crush AI agents.
2//
3// It provides session-based AI agent functionality for managing
4// conversations, tool execution, and message handling. It coordinates
5// interactions between language models, messages, sessions, and tools while
6// handling features like automatic summarization, queuing, and token
7// management.
8package agent
9
10import (
11 "cmp"
12 "context"
13 _ "embed"
14 "encoding/base64"
15 "errors"
16 "fmt"
17 "log/slog"
18 "os"
19 "regexp"
20 "strconv"
21 "strings"
22 "sync"
23 "time"
24
25 "charm.land/catwalk/pkg/catwalk"
26 "charm.land/fantasy"
27 "charm.land/fantasy/providers/anthropic"
28 "charm.land/fantasy/providers/bedrock"
29 "charm.land/fantasy/providers/google"
30 "charm.land/fantasy/providers/openai"
31 "charm.land/fantasy/providers/openrouter"
32 "charm.land/fantasy/providers/vercel"
33 "charm.land/lipgloss/v2"
34 "github.com/charmbracelet/crush/internal/agent/hyper"
35 "github.com/charmbracelet/crush/internal/agent/notify"
36 "github.com/charmbracelet/crush/internal/agent/tools"
37 "github.com/charmbracelet/crush/internal/agent/tools/mcp"
38 "github.com/charmbracelet/crush/internal/config"
39 "github.com/charmbracelet/crush/internal/csync"
40 "github.com/charmbracelet/crush/internal/message"
41 "github.com/charmbracelet/crush/internal/permission"
42 "github.com/charmbracelet/crush/internal/pubsub"
43 "github.com/charmbracelet/crush/internal/session"
44 "github.com/charmbracelet/crush/internal/stringext"
45 "github.com/charmbracelet/crush/internal/version"
46 "github.com/charmbracelet/x/exp/charmtone"
47)
48
// DefaultSessionName is the title used for a session when no better title
// could be generated.
const DefaultSessionName = "Untitled Session"

// Auto-summarization thresholds. When a model's context window is larger than
// largeContextWindowThreshold tokens, a fixed reserve of
// largeContextWindowBuffer tokens is kept free; smaller windows reserve
// smallContextWindowRatio of their size instead. Once the remaining window
// drops to the reserve, Run stops and summarizes the session.
const (
	largeContextWindowThreshold = 200_000
	largeContextWindowBuffer    = 20_000
	smallContextWindowRatio     = 0.2
)
57
// userAgent identifies Crush (and its version) to model providers; it is
// passed to every fantasy agent created in this file via
// fantasy.WithUserAgent.
var userAgent = fmt.Sprintf("Charm-Crush/%s (https://charm.land/crush)", version.Version)

// titlePrompt is the system prompt used when generating session titles.
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt is the system prompt used when summarizing a session.
//
//go:embed templates/summary.md
var summaryPrompt []byte

// thinkTagRegex matches <think>...</think> spans so they can be removed from
// generated titles. The match is non-greedy and, without the (?s) flag, does
// not cross newlines; generateTitle replaces newlines with spaces before
// applying it.
var thinkTagRegex = regexp.MustCompile(`<think>.*?</think>`)
68
// SessionAgentCall is a single prompt submitted to a SessionAgent, together
// with its attachments and per-call generation settings.
type SessionAgentCall struct {
	// SessionID identifies the target session; Run rejects an empty ID.
	SessionID string
	// Prompt is the user's text. It may be empty only when Attachments
	// contains a text attachment.
	Prompt string
	// ProviderOptions are passed to the provider on each stream call, and to
	// Summarize when Run triggers auto-summarization.
	ProviderOptions fantasy.ProviderOptions
	// Attachments are stored on the user message. Non-text attachments are
	// also sent to the model as file parts; text ones are merged into the
	// prompt via message.PromptWithTextAttachments.
	Attachments []message.Attachment
	// The generation settings below are forwarded unchanged to
	// fantasy.AgentStreamCall. MaxOutputTokens is always sent (by pointer),
	// even when it is zero.
	MaxOutputTokens  int64
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
	// NonInteractive suppresses the "agent finished" notification that Run
	// publishes at the end of a turn.
	NonInteractive bool
}
82
83type SessionAgent interface {
84 Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
85 SetModels(large Model, small Model)
86 SetTools(tools []fantasy.AgentTool)
87 SetSystemPrompt(systemPrompt string)
88 Cancel(sessionID string)
89 CancelAll()
90 IsSessionBusy(sessionID string) bool
91 IsBusy() bool
92 QueuedPrompts(sessionID string) int
93 QueuedPromptsList(sessionID string) []string
94 ClearQueue(sessionID string)
95 Summarize(context.Context, string, fantasy.ProviderOptions) error
96 Model() Model
97}
98
// Model bundles a language model with its catwalk metadata and the model
// selection from the user's configuration.
type Model struct {
	// Model is the language model used for generation.
	Model fantasy.LanguageModel
	// CatwalkCfg holds model metadata such as context window size, per-token
	// pricing, reasoning and image support.
	CatwalkCfg catwalk.Model
	// ModelCfg is the selected model/provider from config; its Model and
	// Provider are recorded on assistant messages.
	ModelCfg config.SelectedModel
}
104
// sessionAgent is the SessionAgent implementation returned by
// NewSessionAgent.
type sessionAgent struct {
	// Swappable configuration, held in csync containers so it can be replaced
	// (SetModels, SetTools, SetSystemPrompt) while requests are running. Run
	// and Summarize take a snapshot at the start of each request.
	largeModel         *csync.Value[Model]
	smallModel         *csync.Value[Model]
	systemPromptPrefix *csync.Value[string]
	systemPrompt       *csync.Value[string]
	tools              *csync.Slice[fantasy.AgentTool]

	// isSubAgent omits the todo-list reminder from the prompt history.
	isSubAgent bool
	sessions   session.Service
	messages   message.Service
	// disableAutoSummarize turns off summarization when the context window
	// is nearly full.
	disableAutoSummarize bool
	// isYolo is not read in this part of the file; presumably it bypasses
	// permission prompts — confirm.
	isYolo bool
	// notify, if non-nil, receives agent-finished notifications.
	notify pubsub.Publisher[notify.Notification]

	// messageQueue holds prompts submitted while a session was busy, keyed
	// by session ID.
	messageQueue *csync.Map[string, []SessionAgentCall]
	// activeRequests maps a session ID to the cancel func of its in-flight
	// request; the session counts as busy while it has an entry.
	activeRequests *csync.Map[string, context.CancelFunc]
}
122
// SessionAgentOptions configures NewSessionAgent.
type SessionAgentOptions struct {
	// LargeModel drives conversations and summarization. SmallModel is tried
	// first for title generation, with LargeModel as the fallback.
	LargeModel Model
	SmallModel Model
	// SystemPromptPrefix, when non-empty, is prepended as a separate system
	// message to each request. SystemPrompt is the agent's main system prompt.
	SystemPromptPrefix string
	SystemPrompt       string
	// IsSubAgent omits the todo-list reminder from the prompt history.
	IsSubAgent bool
	// DisableAutoSummarize turns off summarization when the context window is
	// nearly full.
	DisableAutoSummarize bool
	// IsYolo is stored on the agent; it is not read in this part of the file
	// (presumably it bypasses permission prompts — confirm).
	IsYolo   bool
	Sessions session.Service
	Messages message.Service
	// Tools are offered to the model on every Run.
	Tools []fantasy.AgentTool
	// Notify, if non-nil, receives an agent-finished notification after each
	// interactive turn.
	Notify pubsub.Publisher[notify.Notification]
}
136
137func NewSessionAgent(
138 opts SessionAgentOptions,
139) SessionAgent {
140 return &sessionAgent{
141 largeModel: csync.NewValue(opts.LargeModel),
142 smallModel: csync.NewValue(opts.SmallModel),
143 systemPromptPrefix: csync.NewValue(opts.SystemPromptPrefix),
144 systemPrompt: csync.NewValue(opts.SystemPrompt),
145 isSubAgent: opts.IsSubAgent,
146 sessions: opts.Sessions,
147 messages: opts.Messages,
148 disableAutoSummarize: opts.DisableAutoSummarize,
149 tools: csync.NewSliceFrom(opts.Tools),
150 isYolo: opts.IsYolo,
151 notify: opts.Notify,
152 messageQueue: csync.NewMap[string, []SessionAgentCall](),
153 activeRequests: csync.NewMap[string, context.CancelFunc](),
154 }
155}
156
// Run sends call's prompt to the large model and streams the response into
// the session's message history.
//
// If the session is already busy, call is appended to the session's queue and
// Run returns (nil, nil) without contacting the model. Queued prompts are added
// to the ongoing conversation at the start of the next step (see PrepareStep).
// Any prompts still queued when the turn ends are processed by a recursive call
// to Run before it returns.
//
// For the first message of a session, a title is generated concurrently, and
// Run waits for it before returning. When the remaining context window drops to
// the configured reserve (see StopWhen) and auto-summarization is enabled, the
// session is summarized after the turn. If the assistant was in the middle of
// tool use, the original prompt is then re-queued with a note that the session
// was interrupted.
//
// Run returns ErrEmptyPrompt when there is neither prompt text nor a text
// attachment, and ErrSessionMissing when call.SessionID is empty. Stream errors
// are recorded on the assistant message as a finish reason before being
// returned. Synthetic error results are also created for any tool calls left
// without one.
func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
	if call.Prompt == "" && !message.ContainsTextAttachment(call.Attachments) {
		return nil, ErrEmptyPrompt
	}
	if call.SessionID == "" {
		return nil, ErrSessionMissing
	}

	// Queue the message if busy; the in-flight request will pick it up.
	// NOTE(review): the session is only registered in activeRequests further
	// below, and this Get/Set pair is not atomic, so concurrent Run calls for
	// the same session could both get past the check or drop a queued call.
	// Confirm whether callers serialize Run per session.
	if a.IsSessionBusy(call.SessionID) {
		existing, ok := a.messageQueue.Get(call.SessionID)
		if !ok {
			existing = []SessionAgentCall{}
		}
		existing = append(existing, call)
		a.messageQueue.Set(call.SessionID, existing)
		return nil, nil
	}

	// Copy mutable fields under lock to avoid races with SetTools/SetModels.
	agentTools := a.tools.Copy()
	largeModel := a.largeModel.Get()
	systemPrompt := a.systemPrompt.Get()
	promptPrefix := a.systemPromptPrefix.Get()
	var instructions strings.Builder

	// Gather instructions advertised by connected MCP servers so they can be
	// appended to the system prompt.
	for _, server := range mcp.GetStates() {
		if server.State != mcp.StateConnected {
			continue
		}
		if s := server.Client.InitializeResult().Instructions; s != "" {
			instructions.WriteString(s)
			instructions.WriteString("\n\n")
		}
	}

	if s := instructions.String(); s != "" {
		systemPrompt += "\n\n<mcp-instructions>\n" + s + "\n</mcp-instructions>"
	}

	if len(agentTools) > 0 {
		// Add Anthropic caching to the last tool.
		agentTools[len(agentTools)-1].SetProviderOptions(a.getCacheControlOptions())
	}

	agent := fantasy.NewAgent(
		largeModel.Model,
		fantasy.WithSystemPrompt(systemPrompt),
		fantasy.WithTools(agentTools...),
		fantasy.WithUserAgent(userAgent),
	)

	// sessionLock serializes OnStepFinish's read-modify-write of the session
	// record and its update of currentSession.
	sessionLock := sync.Mutex{}
	currentSession, err := a.sessions.Get(ctx, call.SessionID)
	if err != nil {
		return nil, fmt.Errorf("failed to get session: %w", err)
	}

	msgs, err := a.getSessionMessages(ctx, currentSession)
	if err != nil {
		return nil, fmt.Errorf("failed to get session messages: %w", err)
	}

	var wg sync.WaitGroup
	// Generate a title if this is the first message. It runs concurrently,
	// and the deferred wg.Wait keeps Run from returning before it finishes.
	if len(msgs) == 0 {
		titleCtx := ctx // Copy to avoid race with ctx reassignment below.
		wg.Go(func() {
			a.generateTitle(titleCtx, call.SessionID, call.Prompt)
		})
	}
	defer wg.Wait()

	// Add the user message to the session.
	_, err = a.createUserMessage(ctx, call)
	if err != nil {
		return nil, err
	}

	// Add the session to the context.
	ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)

	// Register a cancellable context for this generation so Cancel can stop
	// it; the session counts as busy while the activeRequests entry exists.
	genCtx, cancel := context.WithCancel(ctx)
	a.activeRequests.Set(call.SessionID, cancel)

	defer cancel()
	defer a.activeRequests.Del(call.SessionID)

	history, files := a.preparePrompt(msgs, call.Attachments...)

	startTime := time.Now()
	a.eventPromptSent(call.SessionID)

	// currentAssistant is the assistant message for the current step. It is
	// created in PrepareStep and filled in by the streaming callbacks.
	var currentAssistant *message.Message
	// shouldSummarize is set by the context-window StopWhen condition.
	var shouldSummarize bool
	result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
		Prompt:           message.PromptWithTextAttachments(call.Prompt, call.Attachments),
		Files:            files,
		Messages:         history,
		ProviderOptions:  call.ProviderOptions,
		MaxOutputTokens:  &call.MaxOutputTokens,
		TopP:             call.TopP,
		Temperature:      call.Temperature,
		PresencePenalty:  call.PresencePenalty,
		TopK:             call.TopK,
		FrequencyPenalty: call.FrequencyPenalty,
		// PrepareStep runs before each model step. It:
		//   - resets provider options on the accumulated messages;
		//   - injects prompts queued while this request was running;
		//   - places cache-control markers;
		//   - prepends the system prompt prefix;
		//   - creates the assistant message that the callbacks below fill in.
		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
			prepared.Messages = options.Messages
			for i := range prepared.Messages {
				prepared.Messages[i].ProviderOptions = nil
			}

			queuedCalls, _ := a.messageQueue.Get(call.SessionID)
			a.messageQueue.Del(call.SessionID)
			for _, queued := range queuedCalls {
				userMessage, createErr := a.createUserMessage(callContext, queued)
				if createErr != nil {
					return callContext, prepared, createErr
				}
				prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
			}

			prepared.Messages = a.workaroundProviderMediaLimitations(prepared.Messages, largeModel)

			lastSystemRoleInx := 0
			systemMessageUpdated := false
			for i, msg := range prepared.Messages {
				// Add cache control to the last system message that precedes
				// the first non-system message (index 0 if there is none).
				if msg.Role == fantasy.MessageRoleSystem {
					lastSystemRoleInx = i
				} else if !systemMessageUpdated {
					prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
					systemMessageUpdated = true
				}
				// Then add cache control to the last 2 messages.
				if i > len(prepared.Messages)-3 {
					prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
				}
			}

			// The prefix is prepended after cache markers are placed, so it
			// carries none.
			if promptPrefix != "" {
				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(promptPrefix)}, prepared.Messages...)
			}

			var assistantMsg message.Message
			assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
				Role:     message.Assistant,
				Parts:    []message.ContentPart{},
				Model:    largeModel.ModelCfg.Model,
				Provider: largeModel.ModelCfg.Provider,
			})
			if err != nil {
				return callContext, prepared, err
			}
			// Expose step details to tools through the context.
			callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
			callContext = context.WithValue(callContext, tools.SupportsImagesContextKey, largeModel.CatwalkCfg.SupportsImages)
			callContext = context.WithValue(callContext, tools.ModelNameContextKey, largeModel.CatwalkCfg.Name)
			currentAssistant = &assistantMsg
			return callContext, prepared, err
		},
		OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
			currentAssistant.AppendReasoningContent(reasoning.Text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnReasoningDelta: func(id string, text string) error {
			currentAssistant.AppendReasoningContent(text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
			// Persist provider-specific reasoning metadata (signatures /
			// response data) so it can be sent back on later turns.
			// handle anthropic signature
			if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
				if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
					currentAssistant.AppendReasoningSignature(reasoning.Signature)
				}
			}
			if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
				if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
					currentAssistant.AppendThoughtSignature(reasoning.Signature, reasoning.ToolID)
				}
			}
			if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
				if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
					currentAssistant.SetReasoningResponsesData(reasoning)
				}
			}
			currentAssistant.FinishThinking()
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnTextDelta: func(id string, text string) error {
			// Strip a leading newline from the initial text content. This is
			// particularly important in non-interactive mode, where leading
			// newlines are very visible.
			if len(currentAssistant.Parts) == 0 {
				text = strings.TrimPrefix(text, "\n")
			}

			currentAssistant.AppendContent(text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnToolInputStart: func(id string, toolName string) error {
			// Record the tool call as soon as it starts; OnToolCall later
			// re-adds it with its input and Finished set.
			toolCall := message.ToolCall{
				ID:               id,
				Name:             toolName,
				ProviderExecuted: false,
				Finished:         false,
			}
			currentAssistant.AddToolCall(toolCall)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnRetry: func(err *fantasy.ProviderError, delay time.Duration) {
			// TODO: implement
		},
		OnToolCall: func(tc fantasy.ToolCallContent) error {
			toolCall := message.ToolCall{
				ID:               tc.ToolCallID,
				Name:             tc.ToolName,
				Input:            tc.Input,
				ProviderExecuted: false,
				Finished:         true,
			}
			currentAssistant.AddToolCall(toolCall)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnToolResult: func(result fantasy.ToolResultContent) error {
			// Each tool result is stored as its own tool-role message.
			toolResult := a.convertToToolResult(result)
			_, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
				Role: message.Tool,
				Parts: []message.ContentPart{
					toolResult,
				},
			})
			return createMsgErr
		},
		OnStepFinish: func(stepResult fantasy.StepResult) error {
			// Record the finish reason and fold this step's usage into the
			// session. The session is re-read (under sessionLock, with the
			// parent ctx) and currentSession is refreshed so the StopWhen
			// check below sees up-to-date token counts.
			finishReason := message.FinishReasonUnknown
			switch stepResult.FinishReason {
			case fantasy.FinishReasonLength:
				finishReason = message.FinishReasonMaxTokens
			case fantasy.FinishReasonStop:
				finishReason = message.FinishReasonEndTurn
			case fantasy.FinishReasonToolCalls:
				finishReason = message.FinishReasonToolUse
			}
			currentAssistant.AddFinish(finishReason, "", "")
			sessionLock.Lock()
			defer sessionLock.Unlock()

			updatedSession, getSessionErr := a.sessions.Get(ctx, call.SessionID)
			if getSessionErr != nil {
				return getSessionErr
			}
			a.updateSessionUsage(largeModel, &updatedSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
			_, sessionErr := a.sessions.Save(ctx, updatedSession)
			if sessionErr != nil {
				return sessionErr
			}
			currentSession = updatedSession
			return a.messages.Update(genCtx, *currentAssistant)
		},
		StopWhen: []fantasy.StopCondition{
			// Stop, and flag the session for summarization, once the
			// remaining context window is at or below the reserve threshold.
			// When auto-summarization is disabled, this never stops the run.
			func(_ []fantasy.StepResult) bool {
				cw := int64(largeModel.CatwalkCfg.ContextWindow)
				tokens := currentSession.CompletionTokens + currentSession.PromptTokens
				remaining := cw - tokens
				var threshold int64
				if cw > largeContextWindowThreshold {
					threshold = largeContextWindowBuffer
				} else {
					threshold = int64(float64(cw) * smallContextWindowRatio)
				}
				if (remaining <= threshold) && !a.disableAutoSummarize {
					shouldSummarize = true
					return true
				}
				return false
			},
			// Stop when the model keeps repeating tool calls (loop
			// detection; see hasRepeatedToolCalls).
			func(steps []fantasy.StepResult) bool {
				return hasRepeatedToolCalls(steps, loopDetectionWindowSize, loopDetectionMaxRepeats)
			},
		},
	})

	a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))

	if err != nil {
		isCancelErr := errors.Is(err, context.Canceled)
		isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
		// Failed before any assistant message was created; nothing to clean up.
		if currentAssistant == nil {
			return result, err
		}
		// Ensure we finish thinking on error to close the reasoning state.
		currentAssistant.FinishThinking()
		toolCalls := currentAssistant.ToolCalls()
		// INFO: we use the parent context here because the genCtx has been cancelled.
		msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
		if createErr != nil {
			return nil, createErr
		}
		// Close out every tool call: mark unfinished ones finished, and give
		// any call without a stored result a synthetic error result so the
		// history stays well-formed.
		for _, tc := range toolCalls {
			if !tc.Finished {
				tc.Finished = true
				tc.Input = "{}"
				currentAssistant.AddToolCall(tc)
				updateErr := a.messages.Update(ctx, *currentAssistant)
				if updateErr != nil {
					return nil, updateErr
				}
			}

			found := false
			for _, msg := range msgs {
				if msg.Role == message.Tool {
					for _, tr := range msg.ToolResults() {
						if tr.ToolCallID == tc.ID {
							found = true
							break
						}
					}
				}
				if found {
					break
				}
			}
			if found {
				continue
			}
			content := "There was an error while executing the tool"
			if isCancelErr {
				content = "Tool execution canceled by user"
			} else if isPermissionErr {
				content = "User denied permission"
			}
			toolResult := message.ToolResult{
				ToolCallID: tc.ID,
				Name:       tc.Name,
				Content:    content,
				IsError:    true,
			}
			_, createErr = a.messages.Create(ctx, currentAssistant.SessionID, message.CreateMessageParams{
				Role: message.Tool,
				Parts: []message.ContentPart{
					toolResult,
				},
			})
			if createErr != nil {
				return nil, createErr
			}
		}
		// Translate the error into a user-facing finish reason.
		var fantasyErr *fantasy.Error
		var providerErr *fantasy.ProviderError
		const defaultTitle = "Provider Error"
		linkStyle := lipgloss.NewStyle().Foreground(charmtone.Guac).Underline(true)
		if isCancelErr {
			currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
		} else if isPermissionErr {
			currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "")
		} else if errors.Is(err, hyper.ErrNoCredits) {
			url := hyper.BaseURL()
			link := linkStyle.Hyperlink(url, "id=hyper").Render(url)
			currentAssistant.AddFinish(message.FinishReasonError, "No credits", "You're out of credits. Add more at "+link)
		} else if errors.As(err, &providerErr) {
			// Copilot reports models that are not enabled with this exact message.
			if providerErr.Message == "The requested model is not supported." {
				url := "https://github.com/settings/copilot/features"
				link := linkStyle.Hyperlink(url, "id=copilot").Render(url)
				currentAssistant.AddFinish(
					message.FinishReasonError,
					"Copilot model not enabled",
					fmt.Sprintf("%q is not enabled in Copilot. Go to the following page to enable it. Then, wait 5 minutes before trying again. %s", largeModel.CatwalkCfg.Name, link),
				)
			} else {
				currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message)
			}
		} else if errors.As(err, &fantasyErr) {
			currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(fantasyErr.Title), defaultTitle), fantasyErr.Message)
		} else {
			currentAssistant.AddFinish(message.FinishReasonError, defaultTitle, err.Error())
		}
		// Note: we use the parent context here because the genCtx has been
		// cancelled.
		updateErr := a.messages.Update(ctx, *currentAssistant)
		if updateErr != nil {
			return nil, updateErr
		}
		return nil, err
	}

	// Send notification that agent has finished its turn (skip for
	// nested/non-interactive sessions).
	if !call.NonInteractive && a.notify != nil {
		a.notify.Publish(pubsub.CreatedEvent, notify.Notification{
			SessionID:    call.SessionID,
			SessionTitle: currentSession.Title,
			Type:         notify.TypeAgentFinished,
		})
	}

	if shouldSummarize {
		// Clear the active request so Summarize's busy check passes.
		a.activeRequests.Del(call.SessionID)
		if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
			return nil, summarizeErr
		}
		// If the agent wasn't done (it was still issuing tool calls), queue the
		// original request again so work resumes from the summary.
		if len(currentAssistant.ToolCalls()) > 0 {
			existing, ok := a.messageQueue.Get(call.SessionID)
			if !ok {
				existing = []SessionAgentCall{}
			}
			call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
			existing = append(existing, call)
			a.messageQueue.Set(call.SessionID, existing)
		}
	}

	// Release active request before processing queued messages.
	a.activeRequests.Del(call.SessionID)
	cancel()

	queuedMessages, ok := a.messageQueue.Get(call.SessionID)
	if !ok || len(queuedMessages) == 0 {
		return result, err
	}
	// There are queued messages; restart the loop with the oldest one.
	firstQueuedMessage := queuedMessages[0]
	a.messageQueue.Set(call.SessionID, queuedMessages[1:])
	return a.Run(ctx, firstQueuedMessage)
}
583
584func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
585 if a.IsSessionBusy(sessionID) {
586 return ErrSessionBusy
587 }
588
589 // Copy mutable fields under lock to avoid races with SetModels.
590 largeModel := a.largeModel.Get()
591 systemPromptPrefix := a.systemPromptPrefix.Get()
592
593 currentSession, err := a.sessions.Get(ctx, sessionID)
594 if err != nil {
595 return fmt.Errorf("failed to get session: %w", err)
596 }
597 msgs, err := a.getSessionMessages(ctx, currentSession)
598 if err != nil {
599 return err
600 }
601 if len(msgs) == 0 {
602 // Nothing to summarize.
603 return nil
604 }
605
606 aiMsgs, _ := a.preparePrompt(msgs)
607
608 genCtx, cancel := context.WithCancel(ctx)
609 a.activeRequests.Set(sessionID, cancel)
610 defer a.activeRequests.Del(sessionID)
611 defer cancel()
612
613 agent := fantasy.NewAgent(largeModel.Model,
614 fantasy.WithSystemPrompt(string(summaryPrompt)),
615 fantasy.WithUserAgent(userAgent),
616 )
617 summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
618 Role: message.Assistant,
619 Model: largeModel.Model.Model(),
620 Provider: largeModel.Model.Provider(),
621 IsSummaryMessage: true,
622 })
623 if err != nil {
624 return err
625 }
626
627 summaryPromptText := buildSummaryPrompt(currentSession.Todos)
628
629 resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
630 Prompt: summaryPromptText,
631 Messages: aiMsgs,
632 ProviderOptions: opts,
633 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
634 prepared.Messages = options.Messages
635 if systemPromptPrefix != "" {
636 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(systemPromptPrefix)}, prepared.Messages...)
637 }
638 return callContext, prepared, nil
639 },
640 OnReasoningDelta: func(id string, text string) error {
641 summaryMessage.AppendReasoningContent(text)
642 return a.messages.Update(genCtx, summaryMessage)
643 },
644 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
645 // Handle anthropic signature.
646 if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
647 if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
648 summaryMessage.AppendReasoningSignature(signature.Signature)
649 }
650 }
651 summaryMessage.FinishThinking()
652 return a.messages.Update(genCtx, summaryMessage)
653 },
654 OnTextDelta: func(id, text string) error {
655 summaryMessage.AppendContent(text)
656 return a.messages.Update(genCtx, summaryMessage)
657 },
658 })
659 if err != nil {
660 isCancelErr := errors.Is(err, context.Canceled)
661 if isCancelErr {
662 // User cancelled summarize we need to remove the summary message.
663 deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
664 return deleteErr
665 }
666 return err
667 }
668
669 summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
670 err = a.messages.Update(genCtx, summaryMessage)
671 if err != nil {
672 return err
673 }
674
675 var openrouterCost *float64
676 for _, step := range resp.Steps {
677 stepCost := a.openrouterCost(step.ProviderMetadata)
678 if stepCost != nil {
679 newCost := *stepCost
680 if openrouterCost != nil {
681 newCost += *openrouterCost
682 }
683 openrouterCost = &newCost
684 }
685 }
686
687 a.updateSessionUsage(largeModel, ¤tSession, resp.TotalUsage, openrouterCost)
688
689 // Just in case, get just the last usage info.
690 usage := resp.Response.Usage
691 currentSession.SummaryMessageID = summaryMessage.ID
692 currentSession.CompletionTokens = usage.OutputTokens
693 currentSession.PromptTokens = 0
694 _, err = a.sessions.Save(genCtx, currentSession)
695 return err
696}
697
698func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
699 if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
700 return fantasy.ProviderOptions{}
701 }
702 return fantasy.ProviderOptions{
703 anthropic.Name: &anthropic.ProviderCacheControlOptions{
704 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
705 },
706 bedrock.Name: &anthropic.ProviderCacheControlOptions{
707 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
708 },
709 vercel.Name: &anthropic.ProviderCacheControlOptions{
710 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
711 },
712 }
713}
714
715func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
716 parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
717 var attachmentParts []message.ContentPart
718 for _, attachment := range call.Attachments {
719 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
720 }
721 parts = append(parts, attachmentParts...)
722 msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
723 Role: message.User,
724 Parts: parts,
725 })
726 if err != nil {
727 return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
728 }
729 return msg, nil
730}
731
732func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
733 var history []fantasy.Message
734 if !a.isSubAgent {
735 history = append(history, fantasy.NewUserMessage(
736 fmt.Sprintf("<system_reminder>%s</system_reminder>",
737 `This is a reminder that your todo list is currently empty. DO NOT mention this to the user explicitly because they are already aware.
738If you are working on tasks that would benefit from a todo list please use the "todos" tool to create one.
739If not, please feel free to ignore. Again do not mention this message to the user.`,
740 ),
741 ))
742 }
743 for _, m := range msgs {
744 if len(m.Parts) == 0 {
745 continue
746 }
747 // Assistant message without content or tool calls (cancelled before it
748 // returned anything).
749 if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
750 continue
751 }
752 history = append(history, m.ToAIMessage()...)
753 }
754
755 var files []fantasy.FilePart
756 for _, attachment := range attachments {
757 if attachment.IsText() {
758 continue
759 }
760 files = append(files, fantasy.FilePart{
761 Filename: attachment.FileName,
762 Data: attachment.Content,
763 MediaType: attachment.MimeType,
764 })
765 }
766
767 return history, files
768}
769
770func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
771 msgs, err := a.messages.List(ctx, session.ID)
772 if err != nil {
773 return nil, fmt.Errorf("failed to list messages: %w", err)
774 }
775
776 if session.SummaryMessageID != "" {
777 summaryMsgIndex := -1
778 for i, msg := range msgs {
779 if msg.ID == session.SummaryMessageID {
780 summaryMsgIndex = i
781 break
782 }
783 }
784 if summaryMsgIndex != -1 {
785 msgs = msgs[summaryMsgIndex:]
786 msgs[0].Role = message.User
787 }
788 }
789 return msgs, nil
790}
791
792// generateTitle generates a session titled based on the initial prompt.
793func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, userPrompt string) {
794 if userPrompt == "" {
795 return
796 }
797
798 smallModel := a.smallModel.Get()
799 largeModel := a.largeModel.Get()
800 systemPromptPrefix := a.systemPromptPrefix.Get()
801
802 var maxOutputTokens int64 = 40
803 if smallModel.CatwalkCfg.CanReason {
804 maxOutputTokens = smallModel.CatwalkCfg.DefaultMaxTokens
805 }
806
807 newAgent := func(m fantasy.LanguageModel, p []byte, tok int64) fantasy.Agent {
808 return fantasy.NewAgent(m,
809 fantasy.WithSystemPrompt(string(p)+"\n /no_think"),
810 fantasy.WithMaxOutputTokens(tok),
811 fantasy.WithUserAgent(userAgent),
812 )
813 }
814
815 streamCall := fantasy.AgentStreamCall{
816 Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", userPrompt),
817 PrepareStep: func(callCtx context.Context, opts fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
818 prepared.Messages = opts.Messages
819 if systemPromptPrefix != "" {
820 prepared.Messages = append([]fantasy.Message{
821 fantasy.NewSystemMessage(systemPromptPrefix),
822 }, prepared.Messages...)
823 }
824 return callCtx, prepared, nil
825 },
826 }
827
828 // Use the small model to generate the title.
829 model := smallModel
830 agent := newAgent(model.Model, titlePrompt, maxOutputTokens)
831 resp, err := agent.Stream(ctx, streamCall)
832 if err == nil {
833 // We successfully generated a title with the small model.
834 slog.Debug("Generated title with small model")
835 } else {
836 // It didn't work. Let's try with the big model.
837 slog.Error("Error generating title with small model; trying big model", "err", err)
838 model = largeModel
839 agent = newAgent(model.Model, titlePrompt, maxOutputTokens)
840 resp, err = agent.Stream(ctx, streamCall)
841 if err == nil {
842 slog.Debug("Generated title with large model")
843 } else {
844 // Welp, the large model didn't work either. Use the default
845 // session name and return.
846 slog.Error("Error generating title with large model", "err", err)
847 saveErr := a.sessions.Rename(ctx, sessionID, DefaultSessionName)
848 if saveErr != nil {
849 slog.Error("Failed to save session title", "error", saveErr)
850 }
851 return
852 }
853 }
854
855 if resp == nil {
856 // Actually, we didn't get a response so we can't. Use the default
857 // session name and return.
858 slog.Error("Response is nil; can't generate title")
859 saveErr := a.sessions.Rename(ctx, sessionID, DefaultSessionName)
860 if saveErr != nil {
861 slog.Error("Failed to save session title", "error", saveErr)
862 }
863 return
864 }
865
866 // Clean up title.
867 var title string
868 title = strings.ReplaceAll(resp.Response.Content.Text(), "\n", " ")
869
870 // Remove thinking tags if present.
871 title = thinkTagRegex.ReplaceAllString(title, "")
872
873 title = strings.TrimSpace(title)
874 title = cmp.Or(title, DefaultSessionName)
875
876 // Calculate usage and cost.
877 var openrouterCost *float64
878 for _, step := range resp.Steps {
879 stepCost := a.openrouterCost(step.ProviderMetadata)
880 if stepCost != nil {
881 newCost := *stepCost
882 if openrouterCost != nil {
883 newCost += *openrouterCost
884 }
885 openrouterCost = &newCost
886 }
887 }
888
889 modelConfig := model.CatwalkCfg
890 cost := modelConfig.CostPer1MInCached/1e6*float64(resp.TotalUsage.CacheCreationTokens) +
891 modelConfig.CostPer1MOutCached/1e6*float64(resp.TotalUsage.CacheReadTokens) +
892 modelConfig.CostPer1MIn/1e6*float64(resp.TotalUsage.InputTokens) +
893 modelConfig.CostPer1MOut/1e6*float64(resp.TotalUsage.OutputTokens)
894
895 // Use override cost if available (e.g., from OpenRouter).
896 if openrouterCost != nil {
897 cost = *openrouterCost
898 }
899
900 promptTokens := resp.TotalUsage.InputTokens + resp.TotalUsage.CacheCreationTokens
901 completionTokens := resp.TotalUsage.OutputTokens
902
903 // Atomically update only title and usage fields to avoid overriding other
904 // concurrent session updates.
905 saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, title, promptTokens, completionTokens, cost)
906 if saveErr != nil {
907 slog.Error("Failed to save session title and usage", "error", saveErr)
908 return
909 }
910}
911
912func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
913 openrouterMetadata, ok := metadata[openrouter.Name]
914 if !ok {
915 return nil
916 }
917
918 opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
919 if !ok {
920 return nil
921 }
922 return &opts.Usage.Cost
923}
924
925func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
926 modelConfig := model.CatwalkCfg
927 cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
928 modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
929 modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
930 modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
931
932 a.eventTokensUsed(session.ID, model, usage, cost)
933
934 if overrideCost != nil {
935 session.Cost += *overrideCost
936 } else {
937 session.Cost += cost
938 }
939
940 session.CompletionTokens = usage.OutputTokens
941 session.PromptTokens = usage.InputTokens + usage.CacheReadTokens
942}
943
// Cancel aborts the in-flight request and the in-flight summarization (if
// any) for sessionID, then discards any prompts queued for it.
//
// Entries are read with Get rather than Take, so they stay in activeRequests.
// IsBusy must keep returning true until the request goroutine has fully
// completed, including error handling that may access the DB. The defer in
// processRequest removes the entry.
func (a *sessionAgent) Cancel(sessionID string) {
	targets := []struct {
		key     string
		logLine string
	}{
		{key: sessionID, logLine: "Request cancellation initiated"},
		{key: sessionID + "-summarize", logLine: "Summarize cancellation initiated"},
	}
	for _, t := range targets {
		cancel, ok := a.activeRequests.Get(t.key)
		if !ok || cancel == nil {
			continue
		}
		slog.Debug(t.logLine, "session_id", sessionID)
		cancel()
	}

	a.ClearQueue(sessionID)
}
965
// ClearQueue drops every prompt queued for sessionID. When nothing is queued
// it does nothing and logs nothing.
func (a *sessionAgent) ClearQueue(sessionID string) {
	if a.QueuedPrompts(sessionID) == 0 {
		return
	}
	slog.Debug("Clearing queued prompts", "session_id", sessionID)
	a.messageQueue.Del(sessionID)
}
972
// CancelAll cancels every active request and then waits for the agent to
// become idle. It polls IsBusy every 200ms and gives up after five seconds.
// It returns immediately when nothing is running.
func (a *sessionAgent) CancelAll() {
	if !a.IsBusy() {
		return
	}

	// Keys are either a plain session ID or a "<sessionID>-summarize" entry.
	// Both are passed to Cancel as-is.
	for key := range a.activeRequests.Seq2() {
		a.Cancel(key)
	}

	deadline := time.Now().Add(5 * time.Second)
	for a.IsBusy() && time.Now().Before(deadline) {
		time.Sleep(200 * time.Millisecond)
	}
}
991
992func (a *sessionAgent) IsBusy() bool {
993 var busy bool
994 for cancelFunc := range a.activeRequests.Seq() {
995 if cancelFunc != nil {
996 busy = true
997 break
998 }
999 }
1000 return busy
1001}
1002
1003func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
1004 _, busy := a.activeRequests.Get(sessionID)
1005 return busy
1006}
1007
1008func (a *sessionAgent) QueuedPrompts(sessionID string) int {
1009 l, ok := a.messageQueue.Get(sessionID)
1010 if !ok {
1011 return 0
1012 }
1013 return len(l)
1014}
1015
1016func (a *sessionAgent) QueuedPromptsList(sessionID string) []string {
1017 l, ok := a.messageQueue.Get(sessionID)
1018 if !ok {
1019 return nil
1020 }
1021 prompts := make([]string, len(l))
1022 for i, call := range l {
1023 prompts[i] = call.Prompt
1024 }
1025 return prompts
1026}
1027
1028func (a *sessionAgent) SetModels(large Model, small Model) {
1029 a.largeModel.Set(large)
1030 a.smallModel.Set(small)
1031}
1032
1033func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
1034 a.tools.SetSlice(tools)
1035}
1036
1037func (a *sessionAgent) SetSystemPrompt(systemPrompt string) {
1038 a.systemPrompt.Set(systemPrompt)
1039}
1040
1041func (a *sessionAgent) Model() Model {
1042 return a.largeModel.Get()
1043}
1044
// convertToToolResult converts a fantasy tool result to a message tool result.
//
// ToolCallID, Name and Metadata are always copied. The remaining fields
// depend on the output type:
//   - text: Text becomes Content.
//   - error: the error message becomes Content and IsError is set.
//   - media: Data and MIME type are kept. Content is the tool's text, or
//     "Loaded <type> content" when the tool supplied none.
//
// If the output does not match the type reported by GetType(), or the type is
// unrecognized, only the always-copied fields are populated.
func (a *sessionAgent) convertToToolResult(result fantasy.ToolResultContent) message.ToolResult {
	baseResult := message.ToolResult{
		ToolCallID: result.ToolCallID,
		Name:       result.ToolName,
		Metadata:   result.ClientMetadata,
	}

	switch result.Result.GetType() {
	case fantasy.ToolResultContentTypeText:
		if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result); ok {
			baseResult.Content = r.Text
		}
	case fantasy.ToolResultContentTypeError:
		if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result); ok {
			// A nil error would make Error() panic. Still mark the result as
			// an error, just without a message.
			if r.Error != nil {
				baseResult.Content = r.Error.Error()
			}
			baseResult.IsError = true
		}
	case fantasy.ToolResultContentTypeMedia:
		if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](result.Result); ok {
			content := r.Text
			if content == "" {
				content = fmt.Sprintf("Loaded %s content", r.MediaType)
			}
			baseResult.Content = content
			baseResult.Data = r.Data
			baseResult.MIMEType = r.MediaType
		}
	}

	return baseResult
}
1077
// workaroundProviderMediaLimitations converts media content in tool results to
// user messages for providers that don't natively support images in tool results.
//
// Problem: OpenAI, Google, OpenRouter, and other OpenAI-compatible providers
// don't support sending images/media in tool result messages - they only accept
// text in tool results. However, they DO support images in user messages.
//
// If we send media in tool results to these providers, the API returns an error.
//
// Solution: For these providers, we:
//  1. Replace the media in the tool result with a text placeholder
//  2. Inject a user message immediately after with the image as a file attachment
//  3. This maintains the tool execution flow while working around API limitations
//
// Anthropic and Bedrock support images natively in tool results, so we skip
// this workaround for them and return messages unchanged.
//
// Each tool message is copied and only its Content is replaced. Message-level
// fields other than Content are therefore preserved. If media data fails to
// base64-decode, the original part is kept as-is and a warning is logged.
//
// Example transformation:
//
//	BEFORE: [tool result: image data]
//	AFTER:  [tool result: "Image loaded - see attached"], [user: image attachment]
func (a *sessionAgent) workaroundProviderMediaLimitations(messages []fantasy.Message, largeModel Model) []fantasy.Message {
	providerSupportsMedia := largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderAnthropic) ||
		largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderBedrock)

	if providerSupportsMedia {
		return messages
	}

	convertedMessages := make([]fantasy.Message, 0, len(messages))

	for _, msg := range messages {
		if msg.Role != fantasy.MessageRoleTool {
			convertedMessages = append(convertedMessages, msg)
			continue
		}

		textParts := make([]fantasy.MessagePart, 0, len(msg.Content))
		var mediaFiles []fantasy.FilePart

		for _, part := range msg.Content {
			toolResult, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
			if !ok {
				textParts = append(textParts, part)
				continue
			}

			media, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](toolResult.Output)
			if !ok {
				textParts = append(textParts, part)
				continue
			}

			decoded, err := base64.StdEncoding.DecodeString(media.Data)
			if err != nil {
				// Keep the original part rather than silently dropping the
				// tool result.
				slog.Warn("Failed to decode media data", "error", err)
				textParts = append(textParts, part)
				continue
			}

			mediaFiles = append(mediaFiles, fantasy.FilePart{
				Data:      decoded,
				MediaType: media.MediaType,
				Filename:  fmt.Sprintf("tool-result-%s", toolResult.ToolCallID),
			})

			textParts = append(textParts, fantasy.ToolResultPart{
				ToolCallID: toolResult.ToolCallID,
				Output: fantasy.ToolResultOutputContentText{
					Text: "[Image/media content loaded - see attached file]",
				},
				ProviderOptions: toolResult.ProviderOptions,
			})
		}

		// msg is a per-iteration copy. Replacing only Content keeps the role
		// and every other message-level field. Rebuilding the message from
		// Role+Content would drop them.
		msg.Content = textParts
		convertedMessages = append(convertedMessages, msg)

		if len(mediaFiles) > 0 {
			convertedMessages = append(convertedMessages, fantasy.NewUserMessage(
				"Here is the media content from the tool result:",
				mediaFiles...,
			))
		}
	}

	return convertedMessages
}
1166
// buildSummaryPrompt returns the prompt that asks the model to summarize the
// conversation so far.
//
// When todos is non-empty, it appends a "Current Todo List" section with one
// "- [status] content" line per todo. It then asks for the tasks to be carried
// into the summary, and tells the resuming assistant to keep using the `todos`
// tool.
func buildSummaryPrompt(todos []session.Todo) string {
	const request = "Provide a detailed summary of our conversation above."
	if len(todos) == 0 {
		return request
	}

	var b strings.Builder
	b.WriteString(request)
	b.WriteString("\n\n## Current Todo List\n\n")
	for _, todo := range todos {
		fmt.Fprintf(&b, "- [%s] %s\n", todo.Status, todo.Content)
	}
	b.WriteString("\nInclude these tasks and their statuses in your summary. " +
		"Instruct the resuming assistant to use the `todos` tool to continue tracking progress on these tasks.")
	return b.String()
}