1// Package agent is the core orchestration layer for Crush AI agents.
2//
3// It provides session-based AI agent functionality for managing
4// conversations, tool execution, and message handling. It coordinates
5// interactions between language models, messages, sessions, and tools while
6// handling features like automatic summarization, queuing, and token
7// management.
8package agent
9
10import (
11 "cmp"
12 "context"
13 _ "embed"
14 "encoding/base64"
15 "errors"
16 "fmt"
17 "log/slog"
18 "os"
19 "regexp"
20 "strconv"
21 "strings"
22 "sync"
23 "time"
24
25 "charm.land/catwalk/pkg/catwalk"
26 "charm.land/fantasy"
27 "charm.land/fantasy/providers/anthropic"
28 "charm.land/fantasy/providers/bedrock"
29 "charm.land/fantasy/providers/google"
30 "charm.land/fantasy/providers/openai"
31 "charm.land/fantasy/providers/openrouter"
32 "charm.land/fantasy/providers/vercel"
33 "charm.land/lipgloss/v2"
34 "github.com/charmbracelet/crush/internal/agent/hyper"
35 "github.com/charmbracelet/crush/internal/agent/notify"
36 "github.com/charmbracelet/crush/internal/agent/tools"
37 "github.com/charmbracelet/crush/internal/agent/tools/mcp"
38 "github.com/charmbracelet/crush/internal/config"
39 "github.com/charmbracelet/crush/internal/csync"
40 "github.com/charmbracelet/crush/internal/message"
41 "github.com/charmbracelet/crush/internal/permission"
42 "github.com/charmbracelet/crush/internal/pubsub"
43 "github.com/charmbracelet/crush/internal/session"
44 "github.com/charmbracelet/crush/internal/stringext"
45 "github.com/charmbracelet/crush/internal/version"
46 "github.com/charmbracelet/x/exp/charmtone"
47)
48
const (
	// DefaultSessionName is the fallback session title, used when title
	// generation fails or produces empty text.
	DefaultSessionName = "Untitled Session"

	// Constants for auto-summarization thresholds.
	//
	// For models whose context window is larger than
	// largeContextWindowThreshold tokens, auto-summarization triggers once
	// the remaining tokens drop to largeContextWindowBuffer or fewer. For
	// smaller windows it triggers once the remaining tokens drop to
	// smallContextWindowRatio of the window or fewer.
	largeContextWindowThreshold = 200_000
	largeContextWindowBuffer    = 20_000
	smallContextWindowRatio     = 0.2
)
57
// userAgent is the User-Agent passed to every fantasy agent constructed in
// this file.
var userAgent = fmt.Sprintf("Charm-Crush/%s (https://charm.land/crush)", version.Version)

// titlePrompt is the system prompt used by generateTitle.
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt is the system prompt used by Summarize.
//
//go:embed templates/summary.md
var summaryPrompt []byte

// Used to remove <think> tags from generated titles. The match is non-greedy
// and '.' does not match newlines, which is why generateTitle flattens
// newlines before applying it.
var thinkTagRegex = regexp.MustCompile(`<think>.*?</think>`)
68
// SessionAgentCall describes a single prompt submitted to a SessionAgent.
type SessionAgentCall struct {
	// SessionID identifies the session the prompt belongs to; required.
	SessionID string
	// Prompt is the user's text. It may be empty only when Attachments
	// contain a text attachment.
	Prompt string
	// ProviderOptions are forwarded to the provider for this call.
	ProviderOptions fantasy.ProviderOptions
	// Attachments are files sent with the prompt. Text attachments are
	// folded into the prompt; the rest are sent as file parts.
	Attachments []message.Attachment
	// MaxOutputTokens caps the response length. Zero means no limit is sent.
	MaxOutputTokens int64
	// Optional sampling parameters, forwarded to the stream call as-is
	// (nil presumably leaves the provider default — confirm per provider).
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
	// NonInteractive suppresses the "agent finished" notification.
	NonInteractive bool
}
82
// SessionAgent runs prompts against language models on a per-session basis.
// Prompts submitted while a session is busy are queued.
type SessionAgent interface {
	// Run sends a prompt and streams the response into the session. If the
	// session is busy the call is queued and Run returns (nil, nil).
	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
	// SetModels replaces the large and small models.
	SetModels(large Model, small Model)
	// SetTools replaces the tools offered to the model.
	SetTools(tools []fantasy.AgentTool)
	// SetSystemPrompt replaces the system prompt.
	SetSystemPrompt(systemPrompt string)
	// Cancel cancels in-flight work (including summaries) for a session.
	Cancel(sessionID string)
	// CancelAll presumably cancels in-flight work for every session.
	CancelAll()
	// IsSessionBusy reports whether the session has an active request.
	IsSessionBusy(sessionID string) bool
	// IsBusy presumably reports whether any session has an active request.
	IsBusy() bool
	// QueuedPrompts returns how many prompts are waiting for a session.
	QueuedPrompts(sessionID string) int
	// QueuedPromptsList returns the prompts waiting for a session.
	QueuedPromptsList(sessionID string) []string
	// ClearQueue discards the prompts waiting for a session.
	ClearQueue(sessionID string)
	// Summarize condenses a session's history into a summary message.
	Summarize(context.Context, string, fantasy.ProviderOptions) error
	// Model returns the agent's model (presumably the large one — confirm).
	Model() Model
}
98
// Model bundles a language-model client with its catalog metadata and the
// user's selected-model configuration.
type Model struct {
	// Model is the client used to issue requests.
	Model fantasy.LanguageModel
	// CatwalkCfg holds catalog metadata such as context window, per-token
	// costs and capabilities (reasoning, image support).
	CatwalkCfg catwalk.Model
	// ModelCfg is the configured selection (model and provider IDs).
	ModelCfg config.SelectedModel
}
104
// sessionAgent is the SessionAgent implementation returned by
// NewSessionAgent. Mutable configuration is kept in csync containers so it
// can be replaced (SetModels/SetTools/...) while requests are running.
type sessionAgent struct {
	largeModel         *csync.Value[Model]
	smallModel         *csync.Value[Model]
	systemPromptPrefix *csync.Value[string]
	systemPrompt       *csync.Value[string]
	tools              *csync.Slice[fantasy.AgentTool]

	// isSubAgent disables the todo-list reminder in preparePrompt.
	isSubAgent           bool
	sessions             session.Service
	messages             message.Service
	disableAutoSummarize bool
	isYolo               bool
	// notify receives agent-finished notifications; may be nil.
	notify pubsub.Publisher[notify.Notification]

	// messageQueue holds, per session ID, prompts submitted while that
	// session was busy.
	messageQueue *csync.Map[string, []SessionAgentCall]
	// activeRequests maps a session ID to the cancel func of its in-flight
	// request; while an entry exists the session is considered busy.
	activeRequests *csync.Map[string, context.CancelFunc]
}
122
// SessionAgentOptions holds the dependencies and settings passed to
// NewSessionAgent.
type SessionAgentOptions struct {
	// LargeModel runs prompts and summaries.
	LargeModel Model
	// SmallModel is tried first for title generation; LargeModel is the
	// fallback.
	SmallModel Model
	// SystemPromptPrefix, when non-empty, is sent as an extra leading
	// system message.
	SystemPromptPrefix string
	// SystemPrompt is the main system prompt.
	SystemPrompt string
	// IsSubAgent omits the todo-list reminder from the prompt history.
	IsSubAgent bool
	// DisableAutoSummarize turns off summarization when the context window
	// runs low.
	DisableAutoSummarize bool
	IsYolo               bool
	Sessions             session.Service
	Messages             message.Service
	// Tools are the tools initially offered to the model.
	Tools []fantasy.AgentTool
	// Notify receives agent-finished notifications; may be nil.
	Notify pubsub.Publisher[notify.Notification]
}
136
137func NewSessionAgent(
138 opts SessionAgentOptions,
139) SessionAgent {
140 return &sessionAgent{
141 largeModel: csync.NewValue(opts.LargeModel),
142 smallModel: csync.NewValue(opts.SmallModel),
143 systemPromptPrefix: csync.NewValue(opts.SystemPromptPrefix),
144 systemPrompt: csync.NewValue(opts.SystemPrompt),
145 isSubAgent: opts.IsSubAgent,
146 sessions: opts.Sessions,
147 messages: opts.Messages,
148 disableAutoSummarize: opts.DisableAutoSummarize,
149 tools: csync.NewSliceFrom(opts.Tools),
150 isYolo: opts.IsYolo,
151 notify: opts.Notify,
152 messageQueue: csync.NewMap[string, []SessionAgentCall](),
153 activeRequests: csync.NewMap[string, context.CancelFunc](),
154 }
155}
156
// Run sends call's prompt to the large model and streams the response into
// the session. Assistant, reasoning and tool messages are persisted as they
// arrive.
//
// Busy sessions: if the session already has an active request, call is
// appended to the session's queue and Run returns (nil, nil). Queued prompts
// are picked up in one of two ways:
//   - PrepareStep injects them into the running conversation at the next
//     step.
//   - Any prompts still queued when this request finishes are run by a
//     recursive call to Run.
//
// Title generation: the first prompt of a session starts it in the
// background, and Run waits for it before returning.
//
// Auto-summarization: a stop condition ends the stream when the context
// window is nearly full. The session is then summarized. If the model was
// still calling tools, the original prompt is re-queued.
//
// Errors:
//   - ErrEmptyPrompt when there is neither a prompt nor a text attachment.
//   - ErrSessionMissing when call.SessionID is empty.
//   - On a stream error, the assistant message is finished with a reason
//     describing the failure, and unanswered tool calls get synthetic error
//     results. The stream error is then returned.
func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
	if call.Prompt == "" && !message.ContainsTextAttachment(call.Attachments) {
		return nil, ErrEmptyPrompt
	}
	if call.SessionID == "" {
		return nil, ErrSessionMissing
	}

	// Queue the message if busy. The active request drains the queue either
	// in PrepareStep or after it finishes (see the end of this function).
	// NOTE(review): the busy check and the activeRequests.Set below are not
	// atomic; confirm two concurrent Runs for one session cannot both pass.
	if a.IsSessionBusy(call.SessionID) {
		existing, ok := a.messageQueue.Get(call.SessionID)
		if !ok {
			existing = []SessionAgentCall{}
		}
		existing = append(existing, call)
		a.messageQueue.Set(call.SessionID, existing)
		return nil, nil
	}

	// Copy mutable fields under lock to avoid races with SetTools/SetModels.
	agentTools := a.tools.Copy()
	largeModel := a.largeModel.Get()
	systemPrompt := a.systemPrompt.Get()
	promptPrefix := a.systemPromptPrefix.Get()
	var instructions strings.Builder

	// Collect instructions advertised by connected MCP servers and append
	// them to the system prompt.
	for _, server := range mcp.GetStates() {
		if server.State != mcp.StateConnected {
			continue
		}
		if s := server.Client.InitializeResult().Instructions; s != "" {
			instructions.WriteString(s)
			instructions.WriteString("\n\n")
		}
	}

	if s := instructions.String(); s != "" {
		systemPrompt += "\n\n<mcp-instructions>\n" + s + "\n</mcp-instructions>"
	}

	if len(agentTools) > 0 {
		// Add Anthropic caching to the last tool.
		// NOTE(review): Copy() duplicates the slice, not the tools. If the
		// tool values are references this also mutates the shared tool.
		agentTools[len(agentTools)-1].SetProviderOptions(a.getCacheControlOptions())
	}

	agent := fantasy.NewAgent(
		largeModel.Model,
		fantasy.WithSystemPrompt(systemPrompt),
		fantasy.WithTools(agentTools...),
		fantasy.WithUserAgent(userAgent),
	)

	// sessionLock serializes the session read-modify-write in OnStepFinish.
	sessionLock := sync.Mutex{}
	currentSession, err := a.sessions.Get(ctx, call.SessionID)
	if err != nil {
		return nil, fmt.Errorf("failed to get session: %w", err)
	}

	msgs, err := a.getSessionMessages(ctx, currentSession)
	if err != nil {
		return nil, fmt.Errorf("failed to get session messages: %w", err)
	}

	var wg sync.WaitGroup
	// Generate title if first message.
	if len(msgs) == 0 {
		titleCtx := ctx // Copy to avoid race with ctx reassignment below.
		wg.Go(func() {
			a.generateTitle(titleCtx, call.SessionID, call.Prompt)
		})
	}
	defer wg.Wait()

	// Add the user message to the session.
	_, err = a.createUserMessage(ctx, call)
	if err != nil {
		return nil, err
	}

	// Add the session to the context.
	ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)

	// Register the cancel func so Cancel can stop this request. While the
	// entry exists the session counts as busy.
	genCtx, cancel := context.WithCancel(ctx)
	a.activeRequests.Set(call.SessionID, cancel)

	defer cancel()
	defer a.activeRequests.Del(call.SessionID)

	history, files := a.preparePrompt(msgs, call.Attachments...)

	startTime := time.Now()
	a.eventPromptSent(call.SessionID)

	// currentAssistant is the assistant message of the current step. It is
	// created in PrepareStep and filled in by the streaming callbacks.
	var currentAssistant *message.Message
	// shouldSummarize is set by the context-window stop condition below.
	var shouldSummarize bool
	// Don't send MaxOutputTokens if 0 — some providers (e.g. LM Studio) reject it
	var maxOutputTokens *int64
	if call.MaxOutputTokens > 0 {
		maxOutputTokens = &call.MaxOutputTokens
	}
	result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
		Prompt:           message.PromptWithTextAttachments(call.Prompt, call.Attachments),
		Files:            files,
		Messages:         history,
		ProviderOptions:  call.ProviderOptions,
		MaxOutputTokens:  maxOutputTokens,
		TopP:             call.TopP,
		Temperature:      call.Temperature,
		PresencePenalty:  call.PresencePenalty,
		TopK:             call.TopK,
		FrequencyPenalty: call.FrequencyPenalty,
		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
			// Reset per-message provider options; cache-control markers are
			// re-applied below.
			prepared.Messages = options.Messages
			for i := range prepared.Messages {
				prepared.Messages[i].ProviderOptions = nil
			}

			// Use latest tools (updated by SetTools when MCP tools change).
			prepared.Tools = a.tools.Copy()

			// Drain prompts queued while this request was running. Each one
			// is persisted and appended as a new user message for this step.
			queuedCalls, _ := a.messageQueue.Get(call.SessionID)
			a.messageQueue.Del(call.SessionID)
			for _, queued := range queuedCalls {
				userMessage, createErr := a.createUserMessage(callContext, queued)
				if createErr != nil {
					return callContext, prepared, createErr
				}
				prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
			}

			prepared.Messages = a.workaroundProviderMediaLimitations(prepared.Messages, largeModel)

			lastSystemRoleInx := 0
			systemMessageUpdated := false
			for i, msg := range prepared.Messages {
				// Add cache control to the last system message that precedes
				// the first non-system message.
				if msg.Role == fantasy.MessageRoleSystem {
					lastSystemRoleInx = i
				} else if !systemMessageUpdated {
					prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
					systemMessageUpdated = true
				}
				// Then add cache control to the last 2 messages.
				if i > len(prepared.Messages)-3 {
					prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
				}
			}

			// The prefix is prepended as its own system message, after the
			// cache-control markers have been placed.
			if promptPrefix != "" {
				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(promptPrefix)}, prepared.Messages...)
			}

			// Create the assistant message for this step. Its ID and the
			// model's capabilities are exposed to tools through the context.
			var assistantMsg message.Message
			assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
				Role:     message.Assistant,
				Parts:    []message.ContentPart{},
				Model:    largeModel.ModelCfg.Model,
				Provider: largeModel.ModelCfg.Provider,
			})
			if err != nil {
				return callContext, prepared, err
			}
			callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
			callContext = context.WithValue(callContext, tools.SupportsImagesContextKey, largeModel.CatwalkCfg.SupportsImages)
			callContext = context.WithValue(callContext, tools.ModelNameContextKey, largeModel.CatwalkCfg.Name)
			currentAssistant = &assistantMsg
			return callContext, prepared, err
		},
		OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
			currentAssistant.AppendReasoningContent(reasoning.Text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnReasoningDelta: func(id string, text string) error {
			currentAssistant.AppendReasoningContent(text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
			// Record provider-specific reasoning metadata on the message.
			// handle anthropic signature
			if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
				if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
					currentAssistant.AppendReasoningSignature(reasoning.Signature)
				}
			}
			if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
				if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
					currentAssistant.AppendThoughtSignature(reasoning.Signature, reasoning.ToolID)
				}
			}
			if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
				if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
					currentAssistant.SetReasoningResponsesData(reasoning)
				}
			}
			currentAssistant.FinishThinking()
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnTextDelta: func(id string, text string) error {
			// Strip leading newline from initial text content. This is
			// particularly important in non-interactive mode where leading
			// newlines are very visible.
			if len(currentAssistant.Parts) == 0 {
				text = strings.TrimPrefix(text, "\n")
			}

			currentAssistant.AppendContent(text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnToolInputStart: func(id string, toolName string) error {
			// Record an unfinished tool call as soon as the model starts
			// emitting its input; OnToolCall later records the finished call
			// with the same ID.
			toolCall := message.ToolCall{
				ID:               id,
				Name:             toolName,
				ProviderExecuted: false,
				Finished:         false,
			}
			currentAssistant.AddToolCall(toolCall)
			// Use parent ctx instead of genCtx to ensure the update succeeds
			// even if the request is canceled mid-stream
			return a.messages.Update(ctx, *currentAssistant)
		},
		OnRetry: func(err *fantasy.ProviderError, delay time.Duration) {
			// TODO: implement
		},
		OnToolCall: func(tc fantasy.ToolCallContent) error {
			toolCall := message.ToolCall{
				ID:               tc.ToolCallID,
				Name:             tc.ToolName,
				Input:            tc.Input,
				ProviderExecuted: false,
				Finished:         true,
			}
			currentAssistant.AddToolCall(toolCall)
			// Use parent ctx instead of genCtx to ensure the update succeeds
			// even if the request is canceled mid-stream
			return a.messages.Update(ctx, *currentAssistant)
		},
		OnToolResult: func(result fantasy.ToolResultContent) error {
			toolResult := a.convertToToolResult(result)
			// Use parent ctx instead of genCtx to ensure the message is created
			// even if the request is canceled mid-stream
			_, createMsgErr := a.messages.Create(ctx, currentAssistant.SessionID, message.CreateMessageParams{
				Role: message.Tool,
				Parts: []message.ContentPart{
					toolResult,
				},
			})
			return createMsgErr
		},
		OnStepFinish: func(stepResult fantasy.StepResult) error {
			// Map the provider's finish reason onto the message model.
			finishReason := message.FinishReasonUnknown
			switch stepResult.FinishReason {
			case fantasy.FinishReasonLength:
				finishReason = message.FinishReasonMaxTokens
			case fantasy.FinishReasonStop:
				finishReason = message.FinishReasonEndTurn
			case fantasy.FinishReasonToolCalls:
				finishReason = message.FinishReasonToolUse
			}
			currentAssistant.AddFinish(finishReason, "", "")
			sessionLock.Lock()
			defer sessionLock.Unlock()

			// Fold this step's usage into a freshly loaded copy of the
			// session. The re-read is presumably so concurrent updates, such
			// as the generated title, are not overwritten.
			updatedSession, getSessionErr := a.sessions.Get(ctx, call.SessionID)
			if getSessionErr != nil {
				return getSessionErr
			}
			a.updateSessionUsage(largeModel, &updatedSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
			_, sessionErr := a.sessions.Save(ctx, updatedSession)
			if sessionErr != nil {
				return sessionErr
			}
			// Refresh so the context-window stop condition sees the latest
			// token counts.
			currentSession = updatedSession
			return a.messages.Update(genCtx, *currentAssistant)
		},
		StopWhen: []fantasy.StopCondition{
			// Stop when the context window is nearly full so the session can
			// be summarized afterwards.
			func(_ []fantasy.StepResult) bool {
				cw := int64(largeModel.CatwalkCfg.ContextWindow)
				// If context window is unknown (0), skip auto-summarize
				// to avoid immediately truncating custom/local models.
				if cw == 0 {
					return false
				}
				tokens := currentSession.CompletionTokens + currentSession.PromptTokens
				remaining := cw - tokens
				var threshold int64
				if cw > largeContextWindowThreshold {
					threshold = largeContextWindowBuffer
				} else {
					threshold = int64(float64(cw) * smallContextWindowRatio)
				}
				if (remaining <= threshold) && !a.disableAutoSummarize {
					shouldSummarize = true
					return true
				}
				return false
			},
			// Stop on repeated tool calls (loop detection; see
			// hasRepeatedToolCalls).
			func(steps []fantasy.StepResult) bool {
				return hasRepeatedToolCalls(steps, loopDetectionWindowSize, loopDetectionMaxRepeats)
			},
		},
	})

	a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))

	if err != nil {
		isCancelErr := errors.Is(err, context.Canceled)
		isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
		// No step was prepared, so there is no assistant message to close out.
		if currentAssistant == nil {
			return result, err
		}
		// Ensure we finish thinking on error to close the reasoning state.
		currentAssistant.FinishThinking()
		toolCalls := currentAssistant.ToolCalls()
		// INFO: we use the parent context here because genCtx may have been
		// cancelled.
		msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
		if createErr != nil {
			return nil, createErr
		}
		// Close out every tool call of the failed step:
		//   - Mark unfinished calls finished with empty input.
		//   - Give any call without a stored tool result a synthetic error
		//     result.
		for _, tc := range toolCalls {
			if !tc.Finished {
				tc.Finished = true
				tc.Input = "{}"
				currentAssistant.AddToolCall(tc)
				updateErr := a.messages.Update(ctx, *currentAssistant)
				if updateErr != nil {
					return nil, updateErr
				}
			}

			found := false
			for _, msg := range msgs {
				if msg.Role == message.Tool {
					for _, tr := range msg.ToolResults() {
						if tr.ToolCallID == tc.ID {
							found = true
							break
						}
					}
				}
				if found {
					break
				}
			}
			if found {
				continue
			}
			content := "There was an error while executing the tool"
			if isCancelErr {
				content = "Error: user cancelled assistant tool calling"
			} else if isPermissionErr {
				content = "User denied permission"
			}
			toolResult := message.ToolResult{
				ToolCallID: tc.ID,
				Name:       tc.Name,
				Content:    content,
				IsError:    true,
			}
			_, createErr = a.messages.Create(ctx, currentAssistant.SessionID, message.CreateMessageParams{
				Role: message.Tool,
				Parts: []message.ContentPart{
					toolResult,
				},
			})
			if createErr != nil {
				return nil, createErr
			}
		}
		// Finish the assistant message with a reason and a user-facing
		// title/details derived from the error.
		var fantasyErr *fantasy.Error
		var providerErr *fantasy.ProviderError
		const defaultTitle = "Provider Error"
		linkStyle := lipgloss.NewStyle().Foreground(charmtone.Guac).Underline(true)
		if isCancelErr {
			currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
		} else if isPermissionErr {
			currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "")
		} else if errors.Is(err, hyper.ErrNoCredits) {
			url := hyper.BaseURL()
			link := linkStyle.Hyperlink(url, "id=hyper").Render(url)
			currentAssistant.AddFinish(message.FinishReasonError, "No credits", "You're out of credits. Add more at "+link)
		} else if errors.As(err, &providerErr) {
			if providerErr.Message == "The requested model is not supported." {
				url := "https://github.com/settings/copilot/features"
				link := linkStyle.Hyperlink(url, "id=copilot").Render(url)
				currentAssistant.AddFinish(
					message.FinishReasonError,
					"Copilot model not enabled",
					fmt.Sprintf("%q is not enabled in Copilot. Go to the following page to enable it. Then, wait 5 minutes before trying again. %s", largeModel.CatwalkCfg.Name, link),
				)
			} else {
				currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message)
			}
		} else if errors.As(err, &fantasyErr) {
			currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(fantasyErr.Title), defaultTitle), fantasyErr.Message)
		} else {
			currentAssistant.AddFinish(message.FinishReasonError, defaultTitle, err.Error())
		}
		// Note: we use the parent context here because the genCtx has been
		// cancelled.
		updateErr := a.messages.Update(ctx, *currentAssistant)
		if updateErr != nil {
			return nil, updateErr
		}
		return nil, err
	}

	// Send notification that agent has finished its turn (skip for
	// nested/non-interactive sessions).
	if !call.NonInteractive && a.notify != nil {
		a.notify.Publish(pubsub.CreatedEvent, notify.Notification{
			SessionID:    call.SessionID,
			SessionTitle: currentSession.Title,
			Type:         notify.TypeAgentFinished,
		})
	}

	if shouldSummarize {
		// Summarize refuses to run while the session is busy, so drop this
		// request's registration first.
		a.activeRequests.Del(call.SessionID)
		if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
			return nil, summarizeErr
		}
		// If the agent wasn't done (the last step ended with tool calls),
		// re-queue the original prompt with a note so work resumes from the
		// summary.
		if len(currentAssistant.ToolCalls()) > 0 {
			existing, ok := a.messageQueue.Get(call.SessionID)
			if !ok {
				existing = []SessionAgentCall{}
			}
			call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
			existing = append(existing, call)
			a.messageQueue.Set(call.SessionID, existing)
		}
	}

	// Release active request before processing queued messages.
	a.activeRequests.Del(call.SessionID)
	cancel()

	queuedMessages, ok := a.messageQueue.Get(call.SessionID)
	if !ok || len(queuedMessages) == 0 {
		// err is nil here; the error path returned above.
		return result, err
	}
	// There are queued messages: run the oldest one next.
	firstQueuedMessage := queuedMessages[0]
	a.messageQueue.Set(call.SessionID, queuedMessages[1:])
	return a.Run(ctx, firstQueuedMessage)
}
602
603func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
604 if a.IsSessionBusy(sessionID) {
605 return ErrSessionBusy
606 }
607
608 // Copy mutable fields under lock to avoid races with SetModels.
609 largeModel := a.largeModel.Get()
610 systemPromptPrefix := a.systemPromptPrefix.Get()
611
612 currentSession, err := a.sessions.Get(ctx, sessionID)
613 if err != nil {
614 return fmt.Errorf("failed to get session: %w", err)
615 }
616 msgs, err := a.getSessionMessages(ctx, currentSession)
617 if err != nil {
618 return err
619 }
620 if len(msgs) == 0 {
621 // Nothing to summarize.
622 return nil
623 }
624
625 aiMsgs, _ := a.preparePrompt(msgs)
626
627 genCtx, cancel := context.WithCancel(ctx)
628 a.activeRequests.Set(sessionID, cancel)
629 defer a.activeRequests.Del(sessionID)
630 defer cancel()
631
632 agent := fantasy.NewAgent(largeModel.Model,
633 fantasy.WithSystemPrompt(string(summaryPrompt)),
634 fantasy.WithUserAgent(userAgent),
635 )
636 summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
637 Role: message.Assistant,
638 Model: largeModel.Model.Model(),
639 Provider: largeModel.Model.Provider(),
640 IsSummaryMessage: true,
641 })
642 if err != nil {
643 return err
644 }
645
646 summaryPromptText := buildSummaryPrompt(currentSession.Todos)
647
648 resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
649 Prompt: summaryPromptText,
650 Messages: aiMsgs,
651 ProviderOptions: opts,
652 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
653 prepared.Messages = options.Messages
654 if systemPromptPrefix != "" {
655 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(systemPromptPrefix)}, prepared.Messages...)
656 }
657 return callContext, prepared, nil
658 },
659 OnReasoningDelta: func(id string, text string) error {
660 summaryMessage.AppendReasoningContent(text)
661 return a.messages.Update(genCtx, summaryMessage)
662 },
663 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
664 // Handle anthropic signature.
665 if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
666 if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
667 summaryMessage.AppendReasoningSignature(signature.Signature)
668 }
669 }
670 summaryMessage.FinishThinking()
671 return a.messages.Update(genCtx, summaryMessage)
672 },
673 OnTextDelta: func(id, text string) error {
674 summaryMessage.AppendContent(text)
675 return a.messages.Update(genCtx, summaryMessage)
676 },
677 })
678 if err != nil {
679 isCancelErr := errors.Is(err, context.Canceled)
680 if isCancelErr {
681 // User cancelled summarize we need to remove the summary message.
682 deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
683 return deleteErr
684 }
685 return err
686 }
687
688 summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
689 err = a.messages.Update(genCtx, summaryMessage)
690 if err != nil {
691 return err
692 }
693
694 var openrouterCost *float64
695 for _, step := range resp.Steps {
696 stepCost := a.openrouterCost(step.ProviderMetadata)
697 if stepCost != nil {
698 newCost := *stepCost
699 if openrouterCost != nil {
700 newCost += *openrouterCost
701 }
702 openrouterCost = &newCost
703 }
704 }
705
706 a.updateSessionUsage(largeModel, ¤tSession, resp.TotalUsage, openrouterCost)
707
708 // Just in case, get just the last usage info.
709 usage := resp.Response.Usage
710 currentSession.SummaryMessageID = summaryMessage.ID
711 currentSession.CompletionTokens = usage.OutputTokens
712 currentSession.PromptTokens = 0
713 _, err = a.sessions.Save(genCtx, currentSession)
714 return err
715}
716
717func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
718 if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
719 return fantasy.ProviderOptions{}
720 }
721 return fantasy.ProviderOptions{
722 anthropic.Name: &anthropic.ProviderCacheControlOptions{
723 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
724 },
725 bedrock.Name: &anthropic.ProviderCacheControlOptions{
726 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
727 },
728 vercel.Name: &anthropic.ProviderCacheControlOptions{
729 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
730 },
731 }
732}
733
734func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
735 parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
736 var attachmentParts []message.ContentPart
737 for _, attachment := range call.Attachments {
738 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
739 }
740 parts = append(parts, attachmentParts...)
741 msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
742 Role: message.User,
743 Parts: parts,
744 })
745 if err != nil {
746 return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
747 }
748 return msg, nil
749}
750
751func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
752 var history []fantasy.Message
753 if !a.isSubAgent {
754 history = append(history, fantasy.NewUserMessage(
755 fmt.Sprintf("<system_reminder>%s</system_reminder>",
756 `This is a reminder that your todo list is currently empty. DO NOT mention this to the user explicitly because they are already aware.
757If you are working on tasks that would benefit from a todo list please use the "todos" tool to create one.
758If not, please feel free to ignore. Again do not mention this message to the user.`,
759 ),
760 ))
761 }
762 for _, m := range msgs {
763 if len(m.Parts) == 0 {
764 continue
765 }
766 // Assistant message without content or tool calls (cancelled before it
767 // returned anything).
768 if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
769 continue
770 }
771 history = append(history, m.ToAIMessage()...)
772 }
773
774 var files []fantasy.FilePart
775 for _, attachment := range attachments {
776 if attachment.IsText() {
777 continue
778 }
779 files = append(files, fantasy.FilePart{
780 Filename: attachment.FileName,
781 Data: attachment.Content,
782 MediaType: attachment.MimeType,
783 })
784 }
785
786 return history, files
787}
788
789func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
790 msgs, err := a.messages.List(ctx, session.ID)
791 if err != nil {
792 return nil, fmt.Errorf("failed to list messages: %w", err)
793 }
794
795 if session.SummaryMessageID != "" {
796 summaryMsgIndex := -1
797 for i, msg := range msgs {
798 if msg.ID == session.SummaryMessageID {
799 summaryMsgIndex = i
800 break
801 }
802 }
803 if summaryMsgIndex != -1 {
804 msgs = msgs[summaryMsgIndex:]
805 msgs[0].Role = message.User
806 }
807 }
808 return msgs, nil
809}
810
811// generateTitle generates a session titled based on the initial prompt.
812func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, userPrompt string) {
813 if userPrompt == "" {
814 return
815 }
816
817 smallModel := a.smallModel.Get()
818 largeModel := a.largeModel.Get()
819 systemPromptPrefix := a.systemPromptPrefix.Get()
820
821 var maxOutputTokens int64 = 40
822 if smallModel.CatwalkCfg.CanReason {
823 maxOutputTokens = smallModel.CatwalkCfg.DefaultMaxTokens
824 }
825
826 newAgent := func(m fantasy.LanguageModel, p []byte, tok int64) fantasy.Agent {
827 return fantasy.NewAgent(m,
828 fantasy.WithSystemPrompt(string(p)+"\n /no_think"),
829 fantasy.WithMaxOutputTokens(tok),
830 fantasy.WithUserAgent(userAgent),
831 )
832 }
833
834 streamCall := fantasy.AgentStreamCall{
835 Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", userPrompt),
836 PrepareStep: func(callCtx context.Context, opts fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
837 prepared.Messages = opts.Messages
838 if systemPromptPrefix != "" {
839 prepared.Messages = append([]fantasy.Message{
840 fantasy.NewSystemMessage(systemPromptPrefix),
841 }, prepared.Messages...)
842 }
843 return callCtx, prepared, nil
844 },
845 }
846
847 // Use the small model to generate the title.
848 model := smallModel
849 agent := newAgent(model.Model, titlePrompt, maxOutputTokens)
850 resp, err := agent.Stream(ctx, streamCall)
851 if err == nil {
852 // We successfully generated a title with the small model.
853 slog.Debug("Generated title with small model")
854 } else {
855 // It didn't work. Let's try with the big model.
856 slog.Error("Error generating title with small model; trying big model", "err", err)
857 model = largeModel
858 agent = newAgent(model.Model, titlePrompt, maxOutputTokens)
859 resp, err = agent.Stream(ctx, streamCall)
860 if err == nil {
861 slog.Debug("Generated title with large model")
862 } else {
863 // Welp, the large model didn't work either. Use the default
864 // session name and return.
865 slog.Error("Error generating title with large model", "err", err)
866 saveErr := a.sessions.Rename(ctx, sessionID, DefaultSessionName)
867 if saveErr != nil {
868 slog.Error("Failed to save session title", "error", saveErr)
869 }
870 return
871 }
872 }
873
874 if resp == nil {
875 // Actually, we didn't get a response so we can't. Use the default
876 // session name and return.
877 slog.Error("Response is nil; can't generate title")
878 saveErr := a.sessions.Rename(ctx, sessionID, DefaultSessionName)
879 if saveErr != nil {
880 slog.Error("Failed to save session title", "error", saveErr)
881 }
882 return
883 }
884
885 // Clean up title.
886 var title string
887 title = strings.ReplaceAll(resp.Response.Content.Text(), "\n", " ")
888
889 // Remove thinking tags if present.
890 title = thinkTagRegex.ReplaceAllString(title, "")
891
892 title = strings.TrimSpace(title)
893 title = cmp.Or(title, DefaultSessionName)
894
895 // Calculate usage and cost.
896 var openrouterCost *float64
897 for _, step := range resp.Steps {
898 stepCost := a.openrouterCost(step.ProviderMetadata)
899 if stepCost != nil {
900 newCost := *stepCost
901 if openrouterCost != nil {
902 newCost += *openrouterCost
903 }
904 openrouterCost = &newCost
905 }
906 }
907
908 modelConfig := model.CatwalkCfg
909 cost := modelConfig.CostPer1MInCached/1e6*float64(resp.TotalUsage.CacheCreationTokens) +
910 modelConfig.CostPer1MOutCached/1e6*float64(resp.TotalUsage.CacheReadTokens) +
911 modelConfig.CostPer1MIn/1e6*float64(resp.TotalUsage.InputTokens) +
912 modelConfig.CostPer1MOut/1e6*float64(resp.TotalUsage.OutputTokens)
913
914 // Use override cost if available (e.g., from OpenRouter).
915 if openrouterCost != nil {
916 cost = *openrouterCost
917 }
918
919 promptTokens := resp.TotalUsage.InputTokens + resp.TotalUsage.CacheCreationTokens
920 completionTokens := resp.TotalUsage.OutputTokens
921
922 // Atomically update only title and usage fields to avoid overriding other
923 // concurrent session updates.
924 saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, title, promptTokens, completionTokens, cost)
925 if saveErr != nil {
926 slog.Error("Failed to save session title and usage", "error", saveErr)
927 return
928 }
929}
930
931func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
932 openrouterMetadata, ok := metadata[openrouter.Name]
933 if !ok {
934 return nil
935 }
936
937 opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
938 if !ok {
939 return nil
940 }
941 return &opts.Usage.Cost
942}
943
944func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
945 modelConfig := model.CatwalkCfg
946 cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
947 modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
948 modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
949 modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
950
951 a.eventTokensUsed(session.ID, model, usage, cost)
952
953 if overrideCost != nil {
954 session.Cost += *overrideCost
955 } else {
956 session.Cost += cost
957 }
958
959 session.CompletionTokens = usage.OutputTokens
960 session.PromptTokens = usage.InputTokens + usage.CacheReadTokens
961}
962
// Cancel aborts outstanding work for sessionID. This covers the regular
// request tracked under sessionID, the summarize request tracked under
// sessionID+"-summarize", and any prompts still queued for the session.
func (a *sessionAgent) Cancel(sessionID string) {
	// Look entries up with Get rather than Take. The entry has to stay in
	// activeRequests so IsBusy() keeps returning true until the request
	// goroutine has fully completed, including error handling that may touch
	// the DB. The defer in processRequest is responsible for removing it.
	targets := []struct {
		key     string
		logLine string
	}{
		{key: sessionID, logLine: "Request cancellation initiated"},
		{key: sessionID + "-summarize", logLine: "Summarize cancellation initiated"},
	}
	for _, target := range targets {
		cancel, found := a.activeRequests.Get(target.key)
		if !found || cancel == nil {
			continue
		}
		slog.Debug(target.logLine, "session_id", sessionID)
		cancel()
	}

	a.ClearQueue(sessionID)
}
984
// ClearQueue drops every prompt queued for sessionID. It does nothing, and
// logs nothing, when the session has no queued prompts.
func (a *sessionAgent) ClearQueue(sessionID string) {
	if a.QueuedPrompts(sessionID) <= 0 {
		return
	}
	slog.Debug("Clearing queued prompts", "session_id", sessionID)
	a.messageQueue.Del(sessionID)
}
991
// CancelAll cancels every tracked request. It then waits for the requests to
// wind down, re-checking IsBusy every 200ms for at most 5 seconds. It returns
// right away when nothing is busy to begin with.
func (a *sessionAgent) CancelAll() {
	if !a.IsBusy() {
		return
	}

	// Cancel looks up the entry stored under exactly the key it is given. Keys
	// carrying the "-summarize" suffix are therefore cancelled just like bare
	// session IDs.
	for requestKey := range a.activeRequests.Seq2() {
		a.Cancel(requestKey)
	}

	deadline := time.Now().Add(5 * time.Second)
	for a.IsBusy() {
		if !time.Now().Before(deadline) {
			return
		}
		time.Sleep(200 * time.Millisecond)
	}
}
1010
1011func (a *sessionAgent) IsBusy() bool {
1012 var busy bool
1013 for cancelFunc := range a.activeRequests.Seq() {
1014 if cancelFunc != nil {
1015 busy = true
1016 break
1017 }
1018 }
1019 return busy
1020}
1021
1022func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
1023 _, busy := a.activeRequests.Get(sessionID)
1024 return busy
1025}
1026
1027func (a *sessionAgent) QueuedPrompts(sessionID string) int {
1028 l, ok := a.messageQueue.Get(sessionID)
1029 if !ok {
1030 return 0
1031 }
1032 return len(l)
1033}
1034
1035func (a *sessionAgent) QueuedPromptsList(sessionID string) []string {
1036 l, ok := a.messageQueue.Get(sessionID)
1037 if !ok {
1038 return nil
1039 }
1040 prompts := make([]string, len(l))
1041 for i, call := range l {
1042 prompts[i] = call.Prompt
1043 }
1044 return prompts
1045}
1046
1047func (a *sessionAgent) SetModels(large Model, small Model) {
1048 a.largeModel.Set(large)
1049 a.smallModel.Set(small)
1050}
1051
1052func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
1053 a.tools.SetSlice(tools)
1054}
1055
1056func (a *sessionAgent) SetSystemPrompt(systemPrompt string) {
1057 a.systemPrompt.Set(systemPrompt)
1058}
1059
1060func (a *sessionAgent) Model() Model {
1061 return a.largeModel.Get()
1062}
1063
// convertToToolResult maps a fantasy tool result onto a message.ToolResult.
//
// The call ID, tool name, and client metadata are always copied over. The
// remaining fields depend on the result type:
//   - text: the text becomes Content.
//   - error: the error message becomes Content and IsError is set.
//   - media: Data and MIMEType are copied. Content is the accompanying text,
//     or "Loaded <media type> content" when there is none.
//
// Any other result type, or an output whose concrete type does not match its
// declared type, leaves those fields at their zero values.
func (a *sessionAgent) convertToToolResult(result fantasy.ToolResultContent) message.ToolResult {
	converted := message.ToolResult{
		ToolCallID: result.ToolCallID,
		Name:       result.ToolName,
		Metadata:   result.ClientMetadata,
	}

	output := result.Result
	switch output.GetType() {
	case fantasy.ToolResultContentTypeText:
		text, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](output)
		if !ok {
			break
		}
		converted.Content = text.Text

	case fantasy.ToolResultContentTypeError:
		failure, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](output)
		if !ok {
			break
		}
		converted.Content = failure.Error.Error()
		converted.IsError = true

	case fantasy.ToolResultContentTypeMedia:
		media, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](output)
		if !ok {
			break
		}
		converted.Content = cmp.Or(media.Text, fmt.Sprintf("Loaded %s content", media.MediaType))
		converted.Data = media.Data
		converted.MIMEType = media.MediaType
	}

	return converted
}
1096
// workaroundProviderMediaLimitations converts media content in tool results to
// user messages for providers that don't natively support images in tool results.
//
// Problem: OpenAI, Google, OpenRouter, and other OpenAI-compatible providers
// don't support sending images/media in tool result messages - they only accept
// text in tool results. However, they DO support images in user messages.
//
// If we send media in tool results to these providers, the API returns an error.
//
// Solution: For these providers, we:
//  1. Replace the media in the tool result with a text placeholder
//  2. Inject a user message immediately after with the image as a file attachment
//  3. This maintains the tool execution flow while working around API limitations
//
// Anthropic and Bedrock support images natively in tool results, so we skip
// this workaround for them.
//
// Only the parts that change are rewritten. Each tool message and each
// replaced tool-result part is copied from the original, so every other field
// (e.g. provider options) is carried over unchanged. Media whose base64 data
// fails to decode is left in place and a warning is logged.
//
// Example transformation:
//
//	BEFORE: [tool result: image data]
//	AFTER: [tool result: "Image loaded - see attached"], [user: image attachment]
func (a *sessionAgent) workaroundProviderMediaLimitations(messages []fantasy.Message, largeModel Model) []fantasy.Message {
	providerSupportsMedia := largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderAnthropic) ||
		largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderBedrock)

	if providerSupportsMedia {
		return messages
	}

	convertedMessages := make([]fantasy.Message, 0, len(messages))

	for _, msg := range messages {
		if msg.Role != fantasy.MessageRoleTool {
			convertedMessages = append(convertedMessages, msg)
			continue
		}

		textParts := make([]fantasy.MessagePart, 0, len(msg.Content))
		var mediaFiles []fantasy.FilePart

		for _, part := range msg.Content {
			toolResult, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
			if !ok {
				textParts = append(textParts, part)
				continue
			}

			media, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](toolResult.Output)
			if !ok {
				textParts = append(textParts, part)
				continue
			}

			decoded, err := base64.StdEncoding.DecodeString(media.Data)
			if err != nil {
				slog.Warn("Failed to decode media data", "error", err)
				textParts = append(textParts, part)
				continue
			}

			mediaFiles = append(mediaFiles, fantasy.FilePart{
				Data:      decoded,
				MediaType: media.MediaType,
				Filename:  fmt.Sprintf("tool-result-%s", toolResult.ToolCallID),
			})

			// Start from a copy of the original part so that every field other
			// than Output survives the rewrite.
			placeholder := toolResult
			placeholder.Output = fantasy.ToolResultOutputContentText{
				Text: "[Image/media content loaded - see attached file]",
			}
			textParts = append(textParts, placeholder)
		}

		// msg is a per-iteration copy. Replacing only its Content keeps the
		// message's other fields instead of rebuilding it from Role alone.
		msg.Content = textParts
		convertedMessages = append(convertedMessages, msg)

		if len(mediaFiles) > 0 {
			convertedMessages = append(convertedMessages, fantasy.NewUserMessage(
				"Here is the media content from the tool result:",
				mediaFiles...,
			))
		}
	}

	return convertedMessages
}
1185
// buildSummaryPrompt returns the instruction used to ask the model for a
// session summary.
//
// When todos is non-empty, the text also includes:
//   - a "## Current Todo List" section with one "- [status] content" line per
//     todo;
//   - a request to include those tasks in the summary;
//   - an instruction that the resuming assistant keep tracking them with the
//     `todos` tool.
func buildSummaryPrompt(todos []session.Todo) string {
	const request = "Provide a detailed summary of our conversation above."
	if len(todos) == 0 {
		return request
	}

	var prompt strings.Builder
	prompt.WriteString(request)
	prompt.WriteString("\n\n## Current Todo List\n\n")
	for _, todo := range todos {
		fmt.Fprintf(&prompt, "- [%s] %s\n", todo.Status, todo.Content)
	}
	prompt.WriteString("\nInclude these tasks and their statuses in your summary. ")
	prompt.WriteString("Instruct the resuming assistant to use the `todos` tool to continue tracking progress on these tasks.")
	return prompt.String()
}