1// Package agent is the core orchestration layer for Crush AI agents.
2//
3// It provides session-based AI agent functionality for managing
4// conversations, tool execution, and message handling. It coordinates
5// interactions between language models, messages, sessions, and tools while
6// handling features like automatic summarization, queuing, and token
7// management.
8package agent
9
10import (
11 "cmp"
12 "context"
13 _ "embed"
14 "encoding/base64"
15 "errors"
16 "fmt"
17 "log/slog"
18 "os"
19 "regexp"
20 "strconv"
21 "strings"
22 "sync"
23 "time"
24
25 "charm.land/catwalk/pkg/catwalk"
26 "charm.land/fantasy"
27 "charm.land/fantasy/providers/anthropic"
28 "charm.land/fantasy/providers/bedrock"
29 "charm.land/fantasy/providers/google"
30 "charm.land/fantasy/providers/openai"
31 "charm.land/fantasy/providers/openrouter"
32 "charm.land/fantasy/providers/vercel"
33 "charm.land/lipgloss/v2"
34 "github.com/charmbracelet/crush/internal/agent/hyper"
35 "github.com/charmbracelet/crush/internal/agent/notify"
36 "github.com/charmbracelet/crush/internal/agent/tools"
37 "github.com/charmbracelet/crush/internal/agent/tools/mcp"
38 "github.com/charmbracelet/crush/internal/config"
39 "github.com/charmbracelet/crush/internal/csync"
40 "github.com/charmbracelet/crush/internal/message"
41 "github.com/charmbracelet/crush/internal/permission"
42 "github.com/charmbracelet/crush/internal/pubsub"
43 "github.com/charmbracelet/crush/internal/session"
44 "github.com/charmbracelet/crush/internal/stringext"
45 "github.com/charmbracelet/crush/internal/version"
46 "github.com/charmbracelet/x/exp/charmtone"
47)
48
// DefaultSessionName is the title a session falls back to when no title has
// been generated for it (or title generation fails).
const DefaultSessionName = "Untitled Session"

// Auto-summarization thresholds. A session is summarized once the tokens left
// in the model's context window drop to or below a threshold: a fixed buffer
// for windows larger than largeContextWindowThreshold, otherwise a fraction
// (smallContextWindowRatio) of the window.
const (
	largeContextWindowThreshold = 200_000
	largeContextWindowBuffer    = 20_000
	smallContextWindowRatio     = 0.2
)
57
58var userAgent = fmt.Sprintf("Charm-Crush/%s (https://charm.land/crush)", version.Version)
59
60//go:embed templates/title.md
61var titlePrompt []byte
62
63//go:embed templates/summary.md
64var summaryPrompt []byte
65
66// Used to remove <think> tags from generated titles.
67var thinkTagRegex = regexp.MustCompile(`<think>.*?</think>`)
68
// SessionAgentCall describes a single prompt to run against a session via
// SessionAgent.Run.
type SessionAgentCall struct {
	// SessionID identifies the target session; Run rejects an empty value
	// with ErrSessionMissing.
	SessionID string
	// Prompt is the user's text. Run rejects an empty Prompt unless
	// Attachments contains a text attachment.
	Prompt string
	// ProviderOptions are passed through to the model stream call.
	ProviderOptions fantasy.ProviderOptions
	// Attachments are files attached to the prompt. Non-text attachments are
	// sent as file parts; text attachments are handed to
	// message.PromptWithTextAttachments together with Prompt instead.
	Attachments []message.Attachment
	// MaxOutputTokens caps the response length. Zero means no cap is sent,
	// because some providers reject an explicit zero.
	MaxOutputTokens int64
	// Optional sampling parameters, passed through unchanged; nil means unset.
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
	// NonInteractive suppresses the "agent finished" notification published
	// at the end of a successful turn.
	NonInteractive bool
}
82
// SessionAgent runs prompts against chat sessions and tracks per-session
// request state (in-flight requests and queued prompts).
type SessionAgent interface {
	// Run sends a prompt and streams the model's reply into the session. If
	// the session is busy, the call is queued and (nil, nil) is returned.
	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
	// SetModels replaces the large and small models. NOTE(review):
	// implementation not shown here; presumably affects later requests only.
	SetModels(large Model, small Model)
	// SetTools replaces the tools offered to the model. Running requests
	// re-read the tool list at each step.
	SetTools(tools []fantasy.AgentTool)
	// SetSystemPrompt replaces the system prompt. NOTE(review):
	// implementation not shown here; presumably affects later requests only.
	SetSystemPrompt(systemPrompt string)
	// Cancel cancels the in-flight request for sessionID, if any.
	Cancel(sessionID string)
	// CancelAll cancels in-flight requests. NOTE(review): implementation not
	// shown here; presumably for every session.
	CancelAll()
	// IsSessionBusy reports whether sessionID has a request in flight.
	// NOTE(review): implementation not shown here.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any request is in flight. NOTE(review):
	// implementation not shown here.
	IsBusy() bool
	// QueuedPrompts returns how many prompts are queued for sessionID.
	// NOTE(review): implementation not shown here.
	QueuedPrompts(sessionID string) int
	// QueuedPromptsList returns the prompts queued for sessionID.
	// NOTE(review): implementation not shown here.
	QueuedPromptsList(sessionID string) []string
	// ClearQueue drops the prompts queued for sessionID. NOTE(review):
	// implementation not shown here.
	ClearQueue(sessionID string)
	// Summarize condenses a session's history into a summary message. The
	// arguments are the context, the session ID, and provider options for
	// the summarization request.
	Summarize(context.Context, string, fantasy.ProviderOptions) error
	// Model returns a model of the agent. NOTE(review): implementation not
	// shown here; presumably the large model.
	Model() Model
}
98
// Model bundles a language model client with its catalog metadata and the
// user's selection for it.
type Model struct {
	// Model is the client used to issue requests.
	Model fantasy.LanguageModel
	// CatwalkCfg holds catalog metadata such as the context window, pricing,
	// and capabilities (image support, reasoning).
	CatwalkCfg catwalk.Model
	// ModelCfg is the selected model/provider configuration; its Model and
	// Provider are recorded on assistant messages.
	ModelCfg config.SelectedModel
}
104
// sessionAgent is the SessionAgent implementation returned by
// NewSessionAgent.
type sessionAgent struct {
	// Settings that may be swapped at runtime (e.g. by SetModels/SetTools).
	// They are read through csync wrappers so a request can copy a consistent
	// snapshot.
	largeModel         *csync.Value[Model]
	smallModel         *csync.Value[Model]
	systemPromptPrefix *csync.Value[string]
	systemPrompt       *csync.Value[string]
	tools              *csync.Slice[fantasy.AgentTool]

	// isSubAgent omits the todo-list reminder from the prompt history.
	isSubAgent bool
	sessions   session.Service
	messages   message.Service
	// disableAutoSummarize turns off summarization when the context window
	// runs low.
	disableAutoSummarize bool
	isYolo               bool
	// notify, when non-nil, receives an "agent finished" event after each
	// successful interactive turn.
	notify pubsub.Publisher[notify.Notification]

	// messageQueue holds, per session ID, calls submitted while the session
	// was busy.
	messageQueue *csync.Map[string, []SessionAgentCall]
	// activeRequests maps a session ID to the cancel func of its in-flight
	// request (Run or Summarize).
	activeRequests *csync.Map[string, context.CancelFunc]
}
122
// SessionAgentOptions configures NewSessionAgent. Each field seeds the field
// of the same purpose on the returned agent.
type SessionAgentOptions struct {
	// LargeModel runs conversations and summaries.
	LargeModel Model
	// SmallModel is tried first for title generation.
	SmallModel Model
	// SystemPromptPrefix, when non-empty, is prepended as an extra system
	// message to every request.
	SystemPromptPrefix   string
	SystemPrompt         string
	IsSubAgent           bool
	DisableAutoSummarize bool
	IsYolo               bool
	Sessions             session.Service
	Messages             message.Service
	Tools                []fantasy.AgentTool
	// Notify is optional; when nil, no notifications are published.
	Notify pubsub.Publisher[notify.Notification]
}
136
137func NewSessionAgent(
138 opts SessionAgentOptions,
139) SessionAgent {
140 return &sessionAgent{
141 largeModel: csync.NewValue(opts.LargeModel),
142 smallModel: csync.NewValue(opts.SmallModel),
143 systemPromptPrefix: csync.NewValue(opts.SystemPromptPrefix),
144 systemPrompt: csync.NewValue(opts.SystemPrompt),
145 isSubAgent: opts.IsSubAgent,
146 sessions: opts.Sessions,
147 messages: opts.Messages,
148 disableAutoSummarize: opts.DisableAutoSummarize,
149 tools: csync.NewSliceFrom(opts.Tools),
150 isYolo: opts.IsYolo,
151 notify: opts.Notify,
152 messageQueue: csync.NewMap[string, []SessionAgentCall](),
153 activeRequests: csync.NewMap[string, context.CancelFunc](),
154 }
155}
156
// Run sends call's prompt to the large model on behalf of call.SessionID and
// streams the reply into the session's message history.
//
// If IsSessionBusy reports the session as busy, call is appended to the
// session's queue and Run returns (nil, nil) immediately. Queued calls are
// either injected into the running conversation at its next step (see
// PrepareStep) or run after the current call finishes (see the end of Run).
//
// For the first message of a session, a title is generated concurrently, and
// Run waits for it before returning. When the remaining context window drops
// below the auto-summarize threshold, the session is summarized after the
// turn. If the model was still calling tools at that point, a continuation
// prompt is queued.
//
// Run returns ErrEmptyPrompt or ErrSessionMissing for invalid calls, and
// storage or stream errors otherwise. On a stream error, any tool calls
// without a result get a synthetic error result, and the assistant message is
// finished with a reason describing the failure before the error is returned.
func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
	if call.Prompt == "" && !message.ContainsTextAttachment(call.Attachments) {
		return nil, ErrEmptyPrompt
	}
	if call.SessionID == "" {
		return nil, ErrSessionMissing
	}

	// Queue the message if busy. The running request picks it up in
	// PrepareStep, or after it finishes.
	if a.IsSessionBusy(call.SessionID) {
		existing, ok := a.messageQueue.Get(call.SessionID)
		if !ok {
			existing = []SessionAgentCall{}
		}
		existing = append(existing, call)
		a.messageQueue.Set(call.SessionID, existing)
		return nil, nil
	}

	// Copy mutable fields under lock to avoid races with SetTools/SetModels.
	agentTools := a.tools.Copy()
	largeModel := a.largeModel.Get()
	systemPrompt := a.systemPrompt.Get()
	promptPrefix := a.systemPromptPrefix.Get()
	// Collect the instructions advertised by connected MCP servers so they
	// can be appended to the system prompt.
	var instructions strings.Builder

	for _, server := range mcp.GetStates() {
		if server.State != mcp.StateConnected {
			continue
		}
		if s := server.Client.InitializeResult().Instructions; s != "" {
			instructions.WriteString(s)
			instructions.WriteString("\n\n")
		}
	}

	if s := instructions.String(); s != "" {
		systemPrompt += "\n\n<mcp-instructions>\n" + s + "\n</mcp-instructions>"
	}

	if len(agentTools) > 0 {
		// Add Anthropic caching to the last tool.
		agentTools[len(agentTools)-1].SetProviderOptions(a.getCacheControlOptions())
	}

	agent := fantasy.NewAgent(
		largeModel.Model,
		fantasy.WithSystemPrompt(systemPrompt),
		fantasy.WithTools(agentTools...),
		fantasy.WithUserAgent(userAgent),
	)

	// sessionLock serializes the read-modify-write of session usage in
	// OnStepFinish.
	sessionLock := sync.Mutex{}
	currentSession, err := a.sessions.Get(ctx, call.SessionID)
	if err != nil {
		return nil, fmt.Errorf("failed to get session: %w", err)
	}

	msgs, err := a.getSessionMessages(ctx, currentSession)
	if err != nil {
		return nil, fmt.Errorf("failed to get session messages: %w", err)
	}

	var wg sync.WaitGroup
	// Generate title if first message.
	if len(msgs) == 0 {
		titleCtx := ctx // Copy to avoid race with ctx reassignment below.
		wg.Go(func() {
			a.generateTitle(titleCtx, call.SessionID, call.Prompt)
		})
	}
	// Don't return while the title goroutine is still running.
	defer wg.Wait()

	// Add the user message to the session.
	_, err = a.createUserMessage(ctx, call)
	if err != nil {
		return nil, err
	}

	// Add the session to the context.
	ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)

	// genCtx is what Cancel aborts; ctx stays alive for cleanup writes.
	genCtx, cancel := context.WithCancel(ctx)
	a.activeRequests.Set(call.SessionID, cancel)

	defer cancel()
	defer a.activeRequests.Del(call.SessionID)

	history, files := a.preparePrompt(msgs, call.Attachments...)

	startTime := time.Now()
	a.eventPromptSent(call.SessionID)

	// currentAssistant is the assistant message created for the current step
	// in PrepareStep; the streaming callbacks below append to it.
	var currentAssistant *message.Message
	// shouldSummarize is set by the context-window stop condition.
	var shouldSummarize bool
	// Don't send MaxOutputTokens if 0 — some providers (e.g. LM Studio) reject it
	var maxOutputTokens *int64
	if call.MaxOutputTokens > 0 {
		maxOutputTokens = &call.MaxOutputTokens
	}
	result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
		Prompt:           message.PromptWithTextAttachments(call.Prompt, call.Attachments),
		Files:            files,
		Messages:         history,
		ProviderOptions:  call.ProviderOptions,
		MaxOutputTokens:  maxOutputTokens,
		TopP:             call.TopP,
		Temperature:      call.Temperature,
		PresencePenalty:  call.PresencePenalty,
		TopK:             call.TopK,
		FrequencyPenalty: call.FrequencyPenalty,
		// PrepareStep runs before each model step. It:
		//   - refreshes tools,
		//   - injects queued prompts,
		//   - applies cache-control markers,
		//   - prepends the prompt prefix,
		//   - creates the assistant message for the step.
		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
			prepared.Messages = options.Messages
			// Clear provider options (cache markers from earlier steps)
			// before re-applying them below.
			for i := range prepared.Messages {
				prepared.Messages[i].ProviderOptions = nil
			}

			// Use latest tools (updated by SetTools when MCP tools change).
			prepared.Tools = a.tools.Copy()

			// Drain prompts queued while this request was running and append
			// them as user messages.
			queuedCalls, _ := a.messageQueue.Get(call.SessionID)
			a.messageQueue.Del(call.SessionID)
			for _, queued := range queuedCalls {
				userMessage, createErr := a.createUserMessage(callContext, queued)
				if createErr != nil {
					return callContext, prepared, createErr
				}
				prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
			}

			prepared.Messages = a.workaroundProviderMediaLimitations(prepared.Messages, largeModel)

			lastSystemRoleInx := 0
			systemMessageUpdated := false
			for i, msg := range prepared.Messages {
				// At the first non-system message, add cache control to the
				// last system message before it (index 0 if there is none).
				if msg.Role == fantasy.MessageRoleSystem {
					lastSystemRoleInx = i
				} else if !systemMessageUpdated {
					prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
					systemMessageUpdated = true
				}
				// Then add cache control to the last 2 messages.
				if i > len(prepared.Messages)-3 {
					prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
				}
			}

			if promptPrefix != "" {
				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(promptPrefix)}, prepared.Messages...)
			}

			var assistantMsg message.Message
			assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
				Role:     message.Assistant,
				Parts:    []message.ContentPart{},
				Model:    largeModel.ModelCfg.Model,
				Provider: largeModel.ModelCfg.Provider,
			})
			if err != nil {
				return callContext, prepared, err
			}
			// Expose the message ID and model capabilities to tools.
			callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
			callContext = context.WithValue(callContext, tools.SupportsImagesContextKey, largeModel.CatwalkCfg.SupportsImages)
			callContext = context.WithValue(callContext, tools.ModelNameContextKey, largeModel.CatwalkCfg.Name)
			currentAssistant = &assistantMsg
			return callContext, prepared, err
		},
		OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
			currentAssistant.AppendReasoningContent(reasoning.Text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnReasoningDelta: func(id string, text string) error {
			currentAssistant.AppendReasoningContent(text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
			// Persist provider-specific reasoning metadata so it can be sent
			// back in later turns.
			// handle anthropic signature
			if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
				if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
					currentAssistant.AppendReasoningSignature(reasoning.Signature)
				}
			}
			if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
				if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
					currentAssistant.AppendThoughtSignature(reasoning.Signature, reasoning.ToolID)
				}
			}
			if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
				if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
					currentAssistant.SetReasoningResponsesData(reasoning)
				}
			}
			currentAssistant.FinishThinking()
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnTextDelta: func(id string, text string) error {
			// Strip leading newline from initial text content. This is
			// particularly important in non-interactive mode where leading
			// newlines are very visible.
			if len(currentAssistant.Parts) == 0 {
				text = strings.TrimPrefix(text, "\n")
			}

			currentAssistant.AppendContent(text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnToolInputStart: func(id string, toolName string) error {
			toolCall := message.ToolCall{
				ID:               id,
				Name:             toolName,
				ProviderExecuted: false,
				Finished:         false,
			}
			currentAssistant.AddToolCall(toolCall)
			// Use parent ctx instead of genCtx to ensure the update succeeds
			// even if the request is canceled mid-stream
			return a.messages.Update(ctx, *currentAssistant)
		},
		OnRetry: func(err *fantasy.ProviderError, delay time.Duration) {
			// TODO: implement
		},
		OnToolCall: func(tc fantasy.ToolCallContent) error {
			toolCall := message.ToolCall{
				ID:               tc.ToolCallID,
				Name:             tc.ToolName,
				Input:            tc.Input,
				ProviderExecuted: false,
				Finished:         true,
			}
			currentAssistant.AddToolCall(toolCall)
			// Use parent ctx instead of genCtx to ensure the update succeeds
			// even if the request is canceled mid-stream
			return a.messages.Update(ctx, *currentAssistant)
		},
		OnToolResult: func(result fantasy.ToolResultContent) error {
			toolResult := a.convertToToolResult(result)
			// Use parent ctx instead of genCtx to ensure the message is created
			// even if the request is canceled mid-stream
			_, createMsgErr := a.messages.Create(ctx, currentAssistant.SessionID, message.CreateMessageParams{
				Role: message.Tool,
				Parts: []message.ContentPart{
					toolResult,
				},
			})
			return createMsgErr
		},
		OnStepFinish: func(stepResult fantasy.StepResult) error {
			// Map the provider finish reason onto the message model.
			finishReason := message.FinishReasonUnknown
			switch stepResult.FinishReason {
			case fantasy.FinishReasonLength:
				finishReason = message.FinishReasonMaxTokens
			case fantasy.FinishReasonStop:
				finishReason = message.FinishReasonEndTurn
			case fantasy.FinishReasonToolCalls:
				finishReason = message.FinishReasonToolUse
			}
			currentAssistant.AddFinish(finishReason, "", "")
			sessionLock.Lock()
			defer sessionLock.Unlock()

			// Re-read the session so usage is added to its latest state.
			updatedSession, getSessionErr := a.sessions.Get(ctx, call.SessionID)
			if getSessionErr != nil {
				return getSessionErr
			}
			a.updateSessionUsage(largeModel, &updatedSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
			_, sessionErr := a.sessions.Save(ctx, updatedSession)
			if sessionErr != nil {
				return sessionErr
			}
			// Keep the local copy in sync for the stop condition below.
			currentSession = updatedSession
			return a.messages.Update(genCtx, *currentAssistant)
		},
		StopWhen: []fantasy.StopCondition{
			// Stop (and request a summary) when the remaining context
			// window falls to or below the auto-summarize threshold.
			func(_ []fantasy.StepResult) bool {
				cw := int64(largeModel.CatwalkCfg.ContextWindow)
				// If context window is unknown (0), skip auto-summarize
				// to avoid immediately truncating custom/local models.
				if cw == 0 {
					return false
				}
				tokens := currentSession.CompletionTokens + currentSession.PromptTokens
				remaining := cw - tokens
				var threshold int64
				if cw > largeContextWindowThreshold {
					threshold = largeContextWindowBuffer
				} else {
					threshold = int64(float64(cw) * smallContextWindowRatio)
				}
				if (remaining <= threshold) && !a.disableAutoSummarize {
					shouldSummarize = true
					return true
				}
				return false
			},
			// Stop when the model keeps repeating the same tool calls.
			func(steps []fantasy.StepResult) bool {
				return hasRepeatedToolCalls(steps, loopDetectionWindowSize, loopDetectionMaxRepeats)
			},
		},
	})

	a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))

	if err != nil {
		isCancelErr := errors.Is(err, context.Canceled)
		isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
		// The stream failed before any step created an assistant message.
		if currentAssistant == nil {
			return result, err
		}
		// Ensure we finish thinking on error to close the reasoning state.
		currentAssistant.FinishThinking()
		toolCalls := currentAssistant.ToolCalls()
		// INFO: we use the parent context here because the genCtx has been cancelled.
		msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
		if createErr != nil {
			return nil, createErr
		}
		// Close out every tool call: mark unfinished ones finished, and give
		// any call that has no stored result a synthetic error result.
		for _, tc := range toolCalls {
			if !tc.Finished {
				tc.Finished = true
				tc.Input = "{}"
				currentAssistant.AddToolCall(tc)
				updateErr := a.messages.Update(ctx, *currentAssistant)
				if updateErr != nil {
					return nil, updateErr
				}
			}

			found := false
			for _, msg := range msgs {
				if msg.Role == message.Tool {
					for _, tr := range msg.ToolResults() {
						if tr.ToolCallID == tc.ID {
							found = true
							break
						}
					}
				}
				if found {
					break
				}
			}
			if found {
				continue
			}
			content := "There was an error while executing the tool"
			if isCancelErr {
				content = "Error: user cancelled assistant tool calling"
			} else if isPermissionErr {
				content = "User denied permission"
			}
			toolResult := message.ToolResult{
				ToolCallID: tc.ID,
				Name:       tc.Name,
				Content:    content,
				IsError:    true,
			}
			_, createErr = a.messages.Create(ctx, currentAssistant.SessionID, message.CreateMessageParams{
				Role: message.Tool,
				Parts: []message.ContentPart{
					toolResult,
				},
			})
			if createErr != nil {
				return nil, createErr
			}
		}
		// Finish the assistant message with a reason that describes the
		// failure.
		var fantasyErr *fantasy.Error
		var providerErr *fantasy.ProviderError
		const defaultTitle = "Provider Error"
		linkStyle := lipgloss.NewStyle().Foreground(charmtone.Guac).Underline(true)
		if isCancelErr {
			currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
		} else if isPermissionErr {
			currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "")
		} else if errors.Is(err, hyper.ErrUnauthorized) {
			currentAssistant.AddFinish(message.FinishReasonError, "Unauthorized", `Authentication with Hyper failed. Please run "crush auth" to re-authenticate.`)
		} else if errors.Is(err, hyper.ErrNoCredits) {
			url := hyper.BaseURL()
			link := linkStyle.Hyperlink(url, "id=hyper").Render(url)
			currentAssistant.AddFinish(message.FinishReasonError, "No credits", "You're out of credits. Add more at "+link)
		} else if errors.As(err, &providerErr) {
			if providerErr.Message == "The requested model is not supported." {
				url := "https://github.com/settings/copilot/features"
				link := linkStyle.Hyperlink(url, "id=copilot").Render(url)
				currentAssistant.AddFinish(
					message.FinishReasonError,
					"Copilot model not enabled",
					fmt.Sprintf("%q is not enabled in Copilot. Go to the following page to enable it. Then, wait 5 minutes before trying again. %s", largeModel.CatwalkCfg.Name, link),
				)
			} else {
				currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message)
			}
		} else if errors.As(err, &fantasyErr) {
			currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(fantasyErr.Title), defaultTitle), fantasyErr.Message)
		} else {
			currentAssistant.AddFinish(message.FinishReasonError, defaultTitle, err.Error())
		}
		// Note: we use the parent context here because the genCtx has been
		// cancelled.
		updateErr := a.messages.Update(ctx, *currentAssistant)
		if updateErr != nil {
			return nil, updateErr
		}
		return nil, err
	}

	// Send notification that agent has finished its turn (skip for
	// nested/non-interactive sessions).
	if !call.NonInteractive && a.notify != nil {
		a.notify.Publish(pubsub.CreatedEvent, notify.Notification{
			SessionID:    call.SessionID,
			SessionTitle: currentSession.Title,
			Type:         notify.TypeAgentFinished,
		})
	}

	if shouldSummarize {
		// Summarize refuses to run on a busy session, so release this
		// request's slot first.
		a.activeRequests.Del(call.SessionID)
		if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
			return nil, summarizeErr
		}
		// If the agent wasn't done...
		if len(currentAssistant.ToolCalls()) > 0 {
			// ...queue a continuation prompt carrying the original request.
			existing, ok := a.messageQueue.Get(call.SessionID)
			if !ok {
				existing = []SessionAgentCall{}
			}
			call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
			existing = append(existing, call)
			a.messageQueue.Set(call.SessionID, existing)
		}
	}

	// Release active request before processing queued messages.
	a.activeRequests.Del(call.SessionID)
	cancel()

	queuedMessages, ok := a.messageQueue.Get(call.SessionID)
	if !ok || len(queuedMessages) == 0 {
		return result, err
	}
	// There are queued messages restart the loop.
	firstQueuedMessage := queuedMessages[0]
	a.messageQueue.Set(call.SessionID, queuedMessages[1:])
	return a.Run(ctx, firstQueuedMessage)
}
604
605func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
606 if a.IsSessionBusy(sessionID) {
607 return ErrSessionBusy
608 }
609
610 // Copy mutable fields under lock to avoid races with SetModels.
611 largeModel := a.largeModel.Get()
612 systemPromptPrefix := a.systemPromptPrefix.Get()
613
614 currentSession, err := a.sessions.Get(ctx, sessionID)
615 if err != nil {
616 return fmt.Errorf("failed to get session: %w", err)
617 }
618 msgs, err := a.getSessionMessages(ctx, currentSession)
619 if err != nil {
620 return err
621 }
622 if len(msgs) == 0 {
623 // Nothing to summarize.
624 return nil
625 }
626
627 aiMsgs, _ := a.preparePrompt(msgs)
628
629 genCtx, cancel := context.WithCancel(ctx)
630 a.activeRequests.Set(sessionID, cancel)
631 defer a.activeRequests.Del(sessionID)
632 defer cancel()
633
634 agent := fantasy.NewAgent(largeModel.Model,
635 fantasy.WithSystemPrompt(string(summaryPrompt)),
636 fantasy.WithUserAgent(userAgent),
637 )
638 summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
639 Role: message.Assistant,
640 Model: largeModel.Model.Model(),
641 Provider: largeModel.Model.Provider(),
642 IsSummaryMessage: true,
643 })
644 if err != nil {
645 return err
646 }
647
648 summaryPromptText := buildSummaryPrompt(currentSession.Todos)
649
650 resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
651 Prompt: summaryPromptText,
652 Messages: aiMsgs,
653 ProviderOptions: opts,
654 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
655 prepared.Messages = options.Messages
656 if systemPromptPrefix != "" {
657 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(systemPromptPrefix)}, prepared.Messages...)
658 }
659 return callContext, prepared, nil
660 },
661 OnReasoningDelta: func(id string, text string) error {
662 summaryMessage.AppendReasoningContent(text)
663 return a.messages.Update(genCtx, summaryMessage)
664 },
665 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
666 // Handle anthropic signature.
667 if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
668 if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
669 summaryMessage.AppendReasoningSignature(signature.Signature)
670 }
671 }
672 summaryMessage.FinishThinking()
673 return a.messages.Update(genCtx, summaryMessage)
674 },
675 OnTextDelta: func(id, text string) error {
676 summaryMessage.AppendContent(text)
677 return a.messages.Update(genCtx, summaryMessage)
678 },
679 })
680 if err != nil {
681 isCancelErr := errors.Is(err, context.Canceled)
682 if isCancelErr {
683 // User cancelled summarize we need to remove the summary message.
684 deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
685 return deleteErr
686 }
687 return err
688 }
689
690 summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
691 err = a.messages.Update(genCtx, summaryMessage)
692 if err != nil {
693 return err
694 }
695
696 var openrouterCost *float64
697 for _, step := range resp.Steps {
698 stepCost := a.openrouterCost(step.ProviderMetadata)
699 if stepCost != nil {
700 newCost := *stepCost
701 if openrouterCost != nil {
702 newCost += *openrouterCost
703 }
704 openrouterCost = &newCost
705 }
706 }
707
708 a.updateSessionUsage(largeModel, ¤tSession, resp.TotalUsage, openrouterCost)
709
710 // Just in case, get just the last usage info.
711 usage := resp.Response.Usage
712 currentSession.SummaryMessageID = summaryMessage.ID
713 currentSession.CompletionTokens = usage.OutputTokens
714 currentSession.PromptTokens = 0
715 _, err = a.sessions.Save(genCtx, currentSession)
716 return err
717}
718
719func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
720 if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
721 return fantasy.ProviderOptions{}
722 }
723 return fantasy.ProviderOptions{
724 anthropic.Name: &anthropic.ProviderCacheControlOptions{
725 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
726 },
727 bedrock.Name: &anthropic.ProviderCacheControlOptions{
728 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
729 },
730 vercel.Name: &anthropic.ProviderCacheControlOptions{
731 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
732 },
733 }
734}
735
736func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
737 parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
738 var attachmentParts []message.ContentPart
739 for _, attachment := range call.Attachments {
740 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
741 }
742 parts = append(parts, attachmentParts...)
743 msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
744 Role: message.User,
745 Parts: parts,
746 })
747 if err != nil {
748 return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
749 }
750 return msg, nil
751}
752
753func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
754 var history []fantasy.Message
755 if !a.isSubAgent {
756 history = append(history, fantasy.NewUserMessage(
757 fmt.Sprintf("<system_reminder>%s</system_reminder>",
758 `This is a reminder that your todo list is currently empty. DO NOT mention this to the user explicitly because they are already aware.
759If you are working on tasks that would benefit from a todo list please use the "todos" tool to create one.
760If not, please feel free to ignore. Again do not mention this message to the user.`,
761 ),
762 ))
763 }
764 for _, m := range msgs {
765 if len(m.Parts) == 0 {
766 continue
767 }
768 // Assistant message without content or tool calls (cancelled before it
769 // returned anything).
770 if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
771 continue
772 }
773 history = append(history, m.ToAIMessage()...)
774 }
775
776 var files []fantasy.FilePart
777 for _, attachment := range attachments {
778 if attachment.IsText() {
779 continue
780 }
781 files = append(files, fantasy.FilePart{
782 Filename: attachment.FileName,
783 Data: attachment.Content,
784 MediaType: attachment.MimeType,
785 })
786 }
787
788 return history, files
789}
790
791func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
792 msgs, err := a.messages.List(ctx, session.ID)
793 if err != nil {
794 return nil, fmt.Errorf("failed to list messages: %w", err)
795 }
796
797 if session.SummaryMessageID != "" {
798 summaryMsgIndex := -1
799 for i, msg := range msgs {
800 if msg.ID == session.SummaryMessageID {
801 summaryMsgIndex = i
802 break
803 }
804 }
805 if summaryMsgIndex != -1 {
806 msgs = msgs[summaryMsgIndex:]
807 msgs[0].Role = message.User
808 }
809 }
810 return msgs, nil
811}
812
813// generateTitle generates a session titled based on the initial prompt.
814func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, userPrompt string) {
815 if userPrompt == "" {
816 return
817 }
818
819 smallModel := a.smallModel.Get()
820 largeModel := a.largeModel.Get()
821 systemPromptPrefix := a.systemPromptPrefix.Get()
822
823 var maxOutputTokens int64 = 40
824 if smallModel.CatwalkCfg.CanReason {
825 maxOutputTokens = smallModel.CatwalkCfg.DefaultMaxTokens
826 }
827
828 newAgent := func(m fantasy.LanguageModel, p []byte, tok int64) fantasy.Agent {
829 return fantasy.NewAgent(m,
830 fantasy.WithSystemPrompt(string(p)+"\n /no_think"),
831 fantasy.WithMaxOutputTokens(tok),
832 fantasy.WithUserAgent(userAgent),
833 )
834 }
835
836 streamCall := fantasy.AgentStreamCall{
837 Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", userPrompt),
838 PrepareStep: func(callCtx context.Context, opts fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
839 prepared.Messages = opts.Messages
840 if systemPromptPrefix != "" {
841 prepared.Messages = append([]fantasy.Message{
842 fantasy.NewSystemMessage(systemPromptPrefix),
843 }, prepared.Messages...)
844 }
845 return callCtx, prepared, nil
846 },
847 }
848
849 // Use the small model to generate the title.
850 model := smallModel
851 agent := newAgent(model.Model, titlePrompt, maxOutputTokens)
852 resp, err := agent.Stream(ctx, streamCall)
853 if err == nil {
854 // We successfully generated a title with the small model.
855 slog.Debug("Generated title with small model")
856 } else {
857 // It didn't work. Let's try with the big model.
858 slog.Error("Error generating title with small model; trying big model", "err", err)
859 model = largeModel
860 agent = newAgent(model.Model, titlePrompt, maxOutputTokens)
861 resp, err = agent.Stream(ctx, streamCall)
862 if err == nil {
863 slog.Debug("Generated title with large model")
864 } else {
865 // Welp, the large model didn't work either. Use the default
866 // session name and return.
867 slog.Error("Error generating title with large model", "err", err)
868 saveErr := a.sessions.Rename(ctx, sessionID, DefaultSessionName)
869 if saveErr != nil {
870 slog.Error("Failed to save session title", "error", saveErr)
871 }
872 return
873 }
874 }
875
876 if resp == nil {
877 // Actually, we didn't get a response so we can't. Use the default
878 // session name and return.
879 slog.Error("Response is nil; can't generate title")
880 saveErr := a.sessions.Rename(ctx, sessionID, DefaultSessionName)
881 if saveErr != nil {
882 slog.Error("Failed to save session title", "error", saveErr)
883 }
884 return
885 }
886
887 // Clean up title.
888 var title string
889 title = strings.ReplaceAll(resp.Response.Content.Text(), "\n", " ")
890
891 // Remove thinking tags if present.
892 title = thinkTagRegex.ReplaceAllString(title, "")
893
894 title = strings.TrimSpace(title)
895 title = cmp.Or(title, DefaultSessionName)
896
897 // Calculate usage and cost.
898 var openrouterCost *float64
899 for _, step := range resp.Steps {
900 stepCost := a.openrouterCost(step.ProviderMetadata)
901 if stepCost != nil {
902 newCost := *stepCost
903 if openrouterCost != nil {
904 newCost += *openrouterCost
905 }
906 openrouterCost = &newCost
907 }
908 }
909
910 modelConfig := model.CatwalkCfg
911 cost := modelConfig.CostPer1MInCached/1e6*float64(resp.TotalUsage.CacheCreationTokens) +
912 modelConfig.CostPer1MOutCached/1e6*float64(resp.TotalUsage.CacheReadTokens) +
913 modelConfig.CostPer1MIn/1e6*float64(resp.TotalUsage.InputTokens) +
914 modelConfig.CostPer1MOut/1e6*float64(resp.TotalUsage.OutputTokens)
915
916 // Use override cost if available (e.g., from OpenRouter).
917 if openrouterCost != nil {
918 cost = *openrouterCost
919 }
920
921 promptTokens := resp.TotalUsage.InputTokens + resp.TotalUsage.CacheCreationTokens
922 completionTokens := resp.TotalUsage.OutputTokens
923
924 // Atomically update only title and usage fields to avoid overriding other
925 // concurrent session updates.
926 saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, title, promptTokens, completionTokens, cost)
927 if saveErr != nil {
928 slog.Error("Failed to save session title and usage", "error", saveErr)
929 return
930 }
931}
932
933func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
934 openrouterMetadata, ok := metadata[openrouter.Name]
935 if !ok {
936 return nil
937 }
938
939 opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
940 if !ok {
941 return nil
942 }
943 return &opts.Usage.Cost
944}
945
946func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
947 modelConfig := model.CatwalkCfg
948 cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
949 modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
950 modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
951 modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
952
953 a.eventTokensUsed(session.ID, model, usage, cost)
954
955 if overrideCost != nil {
956 session.Cost += *overrideCost
957 } else {
958 session.Cost += cost
959 }
960
961 session.CompletionTokens = usage.OutputTokens
962 session.PromptTokens = usage.InputTokens + usage.CacheReadTokens
963}
964
965func (a *sessionAgent) Cancel(sessionID string) {
966 // Cancel regular requests. Don't use Take() here - we need the entry to
967 // remain in activeRequests so IsBusy() returns true until the goroutine
968 // fully completes (including error handling that may access the DB).
969 // The defer in processRequest will clean up the entry.
970 if cancel, ok := a.activeRequests.Get(sessionID); ok && cancel != nil {
971 slog.Debug("Request cancellation initiated", "session_id", sessionID)
972 cancel()
973 }
974
975 // Also check for summarize requests.
976 if cancel, ok := a.activeRequests.Get(sessionID + "-summarize"); ok && cancel != nil {
977 slog.Debug("Summarize cancellation initiated", "session_id", sessionID)
978 cancel()
979 }
980
981 if a.QueuedPrompts(sessionID) > 0 {
982 slog.Debug("Clearing queued prompts", "session_id", sessionID)
983 a.messageQueue.Del(sessionID)
984 }
985}
986
987func (a *sessionAgent) ClearQueue(sessionID string) {
988 if a.QueuedPrompts(sessionID) > 0 {
989 slog.Debug("Clearing queued prompts", "session_id", sessionID)
990 a.messageQueue.Del(sessionID)
991 }
992}
993
994func (a *sessionAgent) CancelAll() {
995 if !a.IsBusy() {
996 return
997 }
998 for key := range a.activeRequests.Seq2() {
999 a.Cancel(key) // key is sessionID
1000 }
1001
1002 timeout := time.After(5 * time.Second)
1003 for a.IsBusy() {
1004 select {
1005 case <-timeout:
1006 return
1007 default:
1008 time.Sleep(200 * time.Millisecond)
1009 }
1010 }
1011}
1012
1013func (a *sessionAgent) IsBusy() bool {
1014 var busy bool
1015 for cancelFunc := range a.activeRequests.Seq() {
1016 if cancelFunc != nil {
1017 busy = true
1018 break
1019 }
1020 }
1021 return busy
1022}
1023
1024func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
1025 _, busy := a.activeRequests.Get(sessionID)
1026 return busy
1027}
1028
1029func (a *sessionAgent) QueuedPrompts(sessionID string) int {
1030 l, ok := a.messageQueue.Get(sessionID)
1031 if !ok {
1032 return 0
1033 }
1034 return len(l)
1035}
1036
1037func (a *sessionAgent) QueuedPromptsList(sessionID string) []string {
1038 l, ok := a.messageQueue.Get(sessionID)
1039 if !ok {
1040 return nil
1041 }
1042 prompts := make([]string, len(l))
1043 for i, call := range l {
1044 prompts[i] = call.Prompt
1045 }
1046 return prompts
1047}
1048
1049func (a *sessionAgent) SetModels(large Model, small Model) {
1050 a.largeModel.Set(large)
1051 a.smallModel.Set(small)
1052}
1053
1054func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
1055 a.tools.SetSlice(tools)
1056}
1057
1058func (a *sessionAgent) SetSystemPrompt(systemPrompt string) {
1059 a.systemPrompt.Set(systemPrompt)
1060}
1061
1062func (a *sessionAgent) Model() Model {
1063 return a.largeModel.Get()
1064}
1065
1066// convertToToolResult converts a fantasy tool result to a message tool result.
1067func (a *sessionAgent) convertToToolResult(result fantasy.ToolResultContent) message.ToolResult {
1068 baseResult := message.ToolResult{
1069 ToolCallID: result.ToolCallID,
1070 Name: result.ToolName,
1071 Metadata: result.ClientMetadata,
1072 }
1073
1074 switch result.Result.GetType() {
1075 case fantasy.ToolResultContentTypeText:
1076 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result); ok {
1077 baseResult.Content = r.Text
1078 }
1079 case fantasy.ToolResultContentTypeError:
1080 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result); ok {
1081 baseResult.Content = r.Error.Error()
1082 baseResult.IsError = true
1083 }
1084 case fantasy.ToolResultContentTypeMedia:
1085 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](result.Result); ok {
1086 content := r.Text
1087 if content == "" {
1088 content = fmt.Sprintf("Loaded %s content", r.MediaType)
1089 }
1090 baseResult.Content = content
1091 baseResult.Data = r.Data
1092 baseResult.MIMEType = r.MediaType
1093 }
1094 }
1095
1096 return baseResult
1097}
1098
1099// workaroundProviderMediaLimitations converts media content in tool results to
1100// user messages for providers that don't natively support images in tool results.
1101//
1102// Problem: OpenAI, Google, OpenRouter, and other OpenAI-compatible providers
1103// don't support sending images/media in tool result messages - they only accept
1104// text in tool results. However, they DO support images in user messages.
1105//
1106// If we send media in tool results to these providers, the API returns an error.
1107//
1108// Solution: For these providers, we:
1109// 1. Replace the media in the tool result with a text placeholder
1110// 2. Inject a user message immediately after with the image as a file attachment
1111// 3. This maintains the tool execution flow while working around API limitations
1112//
1113// Anthropic and Bedrock support images natively in tool results, so we skip
1114// this workaround for them.
1115//
1116// Example transformation:
1117//
1118// BEFORE: [tool result: image data]
1119// AFTER: [tool result: "Image loaded - see attached"], [user: image attachment]
1120func (a *sessionAgent) workaroundProviderMediaLimitations(messages []fantasy.Message, largeModel Model) []fantasy.Message {
1121 providerSupportsMedia := largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderAnthropic) ||
1122 largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderBedrock)
1123
1124 if providerSupportsMedia {
1125 return messages
1126 }
1127
1128 convertedMessages := make([]fantasy.Message, 0, len(messages))
1129
1130 for _, msg := range messages {
1131 if msg.Role != fantasy.MessageRoleTool {
1132 convertedMessages = append(convertedMessages, msg)
1133 continue
1134 }
1135
1136 textParts := make([]fantasy.MessagePart, 0, len(msg.Content))
1137 var mediaFiles []fantasy.FilePart
1138
1139 for _, part := range msg.Content {
1140 toolResult, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
1141 if !ok {
1142 textParts = append(textParts, part)
1143 continue
1144 }
1145
1146 if media, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](toolResult.Output); ok {
1147 decoded, err := base64.StdEncoding.DecodeString(media.Data)
1148 if err != nil {
1149 slog.Warn("Failed to decode media data", "error", err)
1150 textParts = append(textParts, part)
1151 continue
1152 }
1153
1154 mediaFiles = append(mediaFiles, fantasy.FilePart{
1155 Data: decoded,
1156 MediaType: media.MediaType,
1157 Filename: fmt.Sprintf("tool-result-%s", toolResult.ToolCallID),
1158 })
1159
1160 textParts = append(textParts, fantasy.ToolResultPart{
1161 ToolCallID: toolResult.ToolCallID,
1162 Output: fantasy.ToolResultOutputContentText{
1163 Text: "[Image/media content loaded - see attached file]",
1164 },
1165 ProviderOptions: toolResult.ProviderOptions,
1166 })
1167 } else {
1168 textParts = append(textParts, part)
1169 }
1170 }
1171
1172 convertedMessages = append(convertedMessages, fantasy.Message{
1173 Role: fantasy.MessageRoleTool,
1174 Content: textParts,
1175 })
1176
1177 if len(mediaFiles) > 0 {
1178 convertedMessages = append(convertedMessages, fantasy.NewUserMessage(
1179 "Here is the media content from the tool result:",
1180 mediaFiles...,
1181 ))
1182 }
1183 }
1184
1185 return convertedMessages
1186}
1187
1188// buildSummaryPrompt constructs the prompt text for session summarization.
1189func buildSummaryPrompt(todos []session.Todo) string {
1190 var sb strings.Builder
1191 sb.WriteString("Provide a detailed summary of our conversation above.")
1192 if len(todos) > 0 {
1193 sb.WriteString("\n\n## Current Todo List\n\n")
1194 for _, t := range todos {
1195 fmt.Fprintf(&sb, "- [%s] %s\n", t.Status, t.Content)
1196 }
1197 sb.WriteString("\nInclude these tasks and their statuses in your summary. ")
1198 sb.WriteString("Instruct the resuming assistant to use the `todos` tool to continue tracking progress on these tasks.")
1199 }
1200 return sb.String()
1201}