1// Package agent is the core orchestration layer for Crush AI agents.
2//
3// It provides session-based AI agent functionality for managing
4// conversations, tool execution, and message handling. It coordinates
5// interactions between language models, messages, sessions, and tools while
6// handling features like automatic summarization, queuing, and token
7// management.
8package agent
9
10import (
11 "cmp"
12 "context"
13 _ "embed"
14 "encoding/base64"
15 "errors"
16 "fmt"
17 "log/slog"
18 "os"
19 "regexp"
20 "strconv"
21 "strings"
22 "sync"
23 "time"
24
25 "charm.land/catwalk/pkg/catwalk"
26 "charm.land/fantasy"
27 "charm.land/fantasy/providers/anthropic"
28 "charm.land/fantasy/providers/bedrock"
29 "charm.land/fantasy/providers/google"
30 "charm.land/fantasy/providers/openai"
31 "charm.land/fantasy/providers/openrouter"
32 "charm.land/fantasy/providers/vercel"
33 "charm.land/lipgloss/v2"
34 "github.com/charmbracelet/crush/internal/agent/hyper"
35 "github.com/charmbracelet/crush/internal/agent/notify"
36 "github.com/charmbracelet/crush/internal/agent/tools"
37 "github.com/charmbracelet/crush/internal/agent/tools/mcp"
38 "github.com/charmbracelet/crush/internal/config"
39 "github.com/charmbracelet/crush/internal/csync"
40 "github.com/charmbracelet/crush/internal/message"
41 "github.com/charmbracelet/crush/internal/permission"
42 "github.com/charmbracelet/crush/internal/pubsub"
43 "github.com/charmbracelet/crush/internal/session"
44 "github.com/charmbracelet/crush/internal/stringext"
45 "github.com/charmbracelet/crush/internal/version"
46 "github.com/charmbracelet/x/exp/charmtone"
47)
48
const (
	// DefaultSessionName is the title given to a session when no title could
	// be generated or the generated title is empty.
	DefaultSessionName = "Untitled Session"

	// Constants for auto-summarization thresholds. Run stops streaming and
	// summarizes the session once the remaining context window (context
	// window minus prompt and completion tokens) drops to the threshold.

	// largeContextWindowThreshold is the context-window size (in tokens)
	// above which the fixed largeContextWindowBuffer is used as the threshold.
	largeContextWindowThreshold = 200_000
	// largeContextWindowBuffer is the number of remaining tokens at or below
	// which a session on a large-window model is summarized.
	largeContextWindowBuffer = 20_000
	// smallContextWindowRatio is the fraction of the context window used as
	// the threshold for models at or below largeContextWindowThreshold.
	smallContextWindowRatio = 0.2
)
57
// userAgent is passed to every fantasy agent via fantasy.WithUserAgent.
var userAgent = fmt.Sprintf("Charm-Crush/%s (https://charm.land/crush)", version.Version)

// titlePrompt is the system prompt used when generating session titles.
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt is the system prompt used when summarizing a session.
//
//go:embed templates/summary.md
var summaryPrompt []byte

// thinkTagRegex removes <think>...</think> blocks from generated titles.
// Without the (?s) flag, `.` does not match newlines, so callers replace
// newlines before applying it.
var thinkTagRegex = regexp.MustCompile(`<think>.*?</think>`)
68
// SessionAgentCall is a single prompt submitted to a SessionAgent.
type SessionAgentCall struct {
	// SessionID identifies the session the prompt belongs to. Required.
	SessionID string
	// Prompt is the user's text. It may be empty only when Attachments
	// contains a text attachment.
	Prompt string
	// ProviderOptions are forwarded to the model request.
	ProviderOptions fantasy.ProviderOptions
	// Attachments are files sent along with the prompt. Text attachments are
	// folded into the prompt text; the rest are sent as file parts.
	Attachments []message.Attachment
	// The following sampling and length controls are forwarded to the model
	// request as-is.
	MaxOutputTokens  int64
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
	// NonInteractive suppresses the "agent finished" notification that Run
	// otherwise publishes when the turn completes.
	NonInteractive bool
}
82
// SessionAgent runs prompts against language models within sessions,
// queueing prompts for sessions that already have a request in flight.
type SessionAgent interface {
	// Run sends a prompt for a session. If the session is busy, the call is
	// queued and Run returns (nil, nil).
	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
	// SetModels replaces the large and small models.
	SetModels(large Model, small Model)
	// SetTools replaces the tool set.
	SetTools(tools []fantasy.AgentTool)
	// SetSystemPrompt replaces the system prompt.
	SetSystemPrompt(systemPrompt string)
	// Cancel aborts any in-flight request for sessionID and drops its
	// queued prompts.
	Cancel(sessionID string)
	// CancelAll cancels requests for all sessions.
	CancelAll()
	// IsSessionBusy reports whether sessionID has a request in flight.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any session has a request in flight.
	IsBusy() bool
	// QueuedPrompts returns the number of prompts queued for sessionID.
	QueuedPrompts(sessionID string) int
	// QueuedPromptsList returns the queued prompts for sessionID.
	QueuedPromptsList(sessionID string) []string
	// ClearQueue drops all prompts queued for sessionID.
	ClearQueue(sessionID string)
	// Summarize replaces the session's history with a model-written summary.
	Summarize(context.Context, string, fantasy.ProviderOptions) error
	// Model returns the agent's current model. NOTE(review): presumably the
	// large model; confirm against the implementation.
	Model() Model
}
98
// Model bundles a language model with the configuration used to drive it.
type Model struct {
	// Model is the fantasy language model that requests are streamed to.
	Model fantasy.LanguageModel
	// CatwalkCfg is the model's catalog entry: context window, per-token
	// costs, and capabilities such as CanReason and SupportsImages.
	CatwalkCfg catwalk.Model
	// ModelCfg is the selected model configuration; its Model and Provider
	// fields are recorded on assistant messages.
	ModelCfg config.SelectedModel
}
104
// sessionAgent is the SessionAgent implementation returned by
// NewSessionAgent.
type sessionAgent struct {
	// Mutable configuration, held in csync containers so Run can take a
	// snapshot while SetModels/SetTools replace them.
	largeModel         *csync.Value[Model]
	smallModel         *csync.Value[Model]
	systemPromptPrefix *csync.Value[string]
	systemPrompt       *csync.Value[string]
	tools              *csync.Slice[fantasy.AgentTool]

	// isSubAgent suppresses the todo-list reminder in the prompt history.
	isSubAgent bool
	sessions   session.Service
	messages   message.Service
	// disableAutoSummarize stops Run from halting and summarizing when the
	// context window runs low.
	disableAutoSummarize bool
	isYolo               bool
	// notify receives the "agent finished" notification; may be nil.
	notify pubsub.Publisher[notify.Notification]

	// messageQueue holds calls submitted while a session was busy, keyed by
	// session ID.
	messageQueue *csync.Map[string, []SessionAgentCall]
	// activeRequests maps a session ID to the cancel func of its in-flight
	// Run or Summarize.
	activeRequests *csync.Map[string, context.CancelFunc]
}
122
// SessionAgentOptions configures NewSessionAgent.
type SessionAgentOptions struct {
	// LargeModel handles prompts and summaries; it is also the fallback for
	// title generation.
	LargeModel Model
	// SmallModel is tried first when generating session titles.
	SmallModel Model
	// SystemPromptPrefix, when non-empty, is prepended as a separate system
	// message to every request.
	SystemPromptPrefix string
	SystemPrompt       string
	// IsSubAgent suppresses the todo-list reminder in the prompt history.
	IsSubAgent bool
	// DisableAutoSummarize turns off automatic summarization when the
	// context window runs low.
	DisableAutoSummarize bool
	IsYolo               bool
	Sessions             session.Service
	Messages             message.Service
	Tools                []fantasy.AgentTool
	// Notify receives "agent finished" notifications; may be nil.
	Notify pubsub.Publisher[notify.Notification]
}
136
137func NewSessionAgent(
138 opts SessionAgentOptions,
139) SessionAgent {
140 return &sessionAgent{
141 largeModel: csync.NewValue(opts.LargeModel),
142 smallModel: csync.NewValue(opts.SmallModel),
143 systemPromptPrefix: csync.NewValue(opts.SystemPromptPrefix),
144 systemPrompt: csync.NewValue(opts.SystemPrompt),
145 isSubAgent: opts.IsSubAgent,
146 sessions: opts.Sessions,
147 messages: opts.Messages,
148 disableAutoSummarize: opts.DisableAutoSummarize,
149 tools: csync.NewSliceFrom(opts.Tools),
150 isYolo: opts.IsYolo,
151 notify: opts.Notify,
152 messageQueue: csync.NewMap[string, []SessionAgentCall](),
153 activeRequests: csync.NewMap[string, context.CancelFunc](),
154 }
155}
156
// Run sends call.Prompt (plus its attachments) to the large model within the
// given session and streams the response. As events arrive it persists the
// user message, the assistant message(s), tool calls and tool results.
//
// If the session already has a request in flight, the call is appended to
// the session's queue and Run returns (nil, nil). Queued calls are either
// injected into the running conversation at the next step (see PrepareStep)
// or processed by a recursive Run after this one finishes; in that case the
// returned result is that of the queued call.
//
// On the first message of a session, a title is generated concurrently. Run
// waits for it before returning. When the remaining context window falls to
// the summarization threshold, the stream is stopped and the session is
// summarized. If the assistant was still calling tools, a follow-up prompt
// referencing the original request is queued.
//
// On stream error, the assistant message is finalized with a finish reason
// describing the failure and unfinished tool calls are closed. Tool calls
// without a stored result get a synthetic error result. The error is
// returned.
func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
	if call.Prompt == "" && !message.ContainsTextAttachment(call.Attachments) {
		return nil, ErrEmptyPrompt
	}
	if call.SessionID == "" {
		return nil, ErrSessionMissing
	}

	// Queue the message if busy
	if a.IsSessionBusy(call.SessionID) {
		existing, ok := a.messageQueue.Get(call.SessionID)
		if !ok {
			existing = []SessionAgentCall{}
		}
		existing = append(existing, call)
		a.messageQueue.Set(call.SessionID, existing)
		return nil, nil
	}

	// Copy mutable fields under lock to avoid races with SetTools/SetModels.
	agentTools := a.tools.Copy()
	largeModel := a.largeModel.Get()
	systemPrompt := a.systemPrompt.Get()
	promptPrefix := a.systemPromptPrefix.Get()
	var instructions strings.Builder

	// Collect the instructions advertised by connected MCP servers so they
	// can be appended to the system prompt.
	for _, server := range mcp.GetStates() {
		if server.State != mcp.StateConnected {
			continue
		}
		if s := server.Client.InitializeResult().Instructions; s != "" {
			instructions.WriteString(s)
			instructions.WriteString("\n\n")
		}
	}

	if s := instructions.String(); s != "" {
		systemPrompt += "\n\n<mcp-instructions>\n" + s + "\n</mcp-instructions>"
	}

	if len(agentTools) > 0 {
		// Add Anthropic caching to the last tool.
		agentTools[len(agentTools)-1].SetProviderOptions(a.getCacheControlOptions())
	}

	agent := fantasy.NewAgent(
		largeModel.Model,
		fantasy.WithSystemPrompt(systemPrompt),
		fantasy.WithTools(agentTools...),
		fantasy.WithUserAgent(userAgent),
	)

	// sessionLock guards the read-modify-save of session usage in
	// OnStepFinish below.
	sessionLock := sync.Mutex{}
	currentSession, err := a.sessions.Get(ctx, call.SessionID)
	if err != nil {
		return nil, fmt.Errorf("failed to get session: %w", err)
	}

	msgs, err := a.getSessionMessages(ctx, currentSession)
	if err != nil {
		return nil, fmt.Errorf("failed to get session messages: %w", err)
	}

	var wg sync.WaitGroup
	// Generate title if first message.
	if len(msgs) == 0 {
		titleCtx := ctx // Copy to avoid race with ctx reassignment below.
		wg.Go(func() {
			a.generateTitle(titleCtx, call.SessionID, call.Prompt)
		})
	}
	defer wg.Wait()

	// Add the user message to the session.
	_, err = a.createUserMessage(ctx, call)
	if err != nil {
		return nil, err
	}

	// Add the session to the context.
	ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)

	// Register the cancel func so Cancel(sessionID) can abort this stream.
	genCtx, cancel := context.WithCancel(ctx)
	a.activeRequests.Set(call.SessionID, cancel)

	defer cancel()
	defer a.activeRequests.Del(call.SessionID)

	history, files := a.preparePrompt(msgs, call.Attachments...)

	startTime := time.Now()
	a.eventPromptSent(call.SessionID)

	// currentAssistant is the assistant message of the current step. It is
	// created in PrepareStep and mutated by the streaming callbacks.
	var currentAssistant *message.Message
	var shouldSummarize bool
	result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
		Prompt:           message.PromptWithTextAttachments(call.Prompt, call.Attachments),
		Files:            files,
		Messages:         history,
		ProviderOptions:  call.ProviderOptions,
		MaxOutputTokens:  &call.MaxOutputTokens,
		TopP:             call.TopP,
		Temperature:      call.Temperature,
		PresencePenalty:  call.PresencePenalty,
		TopK:             call.TopK,
		FrequencyPenalty: call.FrequencyPenalty,
		// PrepareStep runs before each model step. It:
		//   - refreshes the tool set,
		//   - appends prompts queued meanwhile,
		//   - applies cache-control markers,
		//   - prepends the system prompt prefix,
		//   - creates the assistant message the callbacks below write into.
		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
			prepared.Messages = options.Messages
			// Reset per-message provider options so cache-control markers end
			// up only where they are set below.
			for i := range prepared.Messages {
				prepared.Messages[i].ProviderOptions = nil
			}

			// Use latest tools (updated by SetTools when MCP tools change).
			prepared.Tools = a.tools.Copy()

			// Drain prompts queued while this request was running and add
			// them to the conversation as new user messages.
			queuedCalls, _ := a.messageQueue.Get(call.SessionID)
			a.messageQueue.Del(call.SessionID)
			for _, queued := range queuedCalls {
				userMessage, createErr := a.createUserMessage(callContext, queued)
				if createErr != nil {
					return callContext, prepared, createErr
				}
				prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
			}

			prepared.Messages = a.workaroundProviderMediaLimitations(prepared.Messages, largeModel)

			lastSystemRoleInx := 0
			systemMessageUpdated := false
			for i, msg := range prepared.Messages {
				// Add cache control to the last system message, i.e. the one
				// just before the first non-system message. If there is no
				// leading system message, index 0 is marked instead.
				if msg.Role == fantasy.MessageRoleSystem {
					lastSystemRoleInx = i
				} else if !systemMessageUpdated {
					prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
					systemMessageUpdated = true
				}
				// Then add cache control to the last 2 messages.
				if i > len(prepared.Messages)-3 {
					prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
				}
			}

			if promptPrefix != "" {
				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(promptPrefix)}, prepared.Messages...)
			}

			var assistantMsg message.Message
			assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
				Role:     message.Assistant,
				Parts:    []message.ContentPart{},
				Model:    largeModel.ModelCfg.Model,
				Provider: largeModel.ModelCfg.Provider,
			})
			if err != nil {
				return callContext, prepared, err
			}
			// Put the assistant message ID and model details on the step
			// context under the tools package's context keys.
			callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
			callContext = context.WithValue(callContext, tools.SupportsImagesContextKey, largeModel.CatwalkCfg.SupportsImages)
			callContext = context.WithValue(callContext, tools.ModelNameContextKey, largeModel.CatwalkCfg.Name)
			currentAssistant = &assistantMsg
			return callContext, prepared, err
		},
		OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
			currentAssistant.AppendReasoningContent(reasoning.Text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnReasoningDelta: func(id string, text string) error {
			currentAssistant.AppendReasoningContent(text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
			// Record provider-specific reasoning metadata on the assistant
			// message. Anthropic signature first:
			// handle anthropic signature
			if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
				if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
					currentAssistant.AppendReasoningSignature(reasoning.Signature)
				}
			}
			if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
				if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
					currentAssistant.AppendThoughtSignature(reasoning.Signature, reasoning.ToolID)
				}
			}
			if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
				if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
					currentAssistant.SetReasoningResponsesData(reasoning)
				}
			}
			currentAssistant.FinishThinking()
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnTextDelta: func(id string, text string) error {
			// Strip leading newline from initial text content. This is
			// particularly important in non-interactive mode where leading
			// newlines are very visible.
			if len(currentAssistant.Parts) == 0 {
				text = strings.TrimPrefix(text, "\n")
			}

			currentAssistant.AppendContent(text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnToolInputStart: func(id string, toolName string) error {
			toolCall := message.ToolCall{
				ID:               id,
				Name:             toolName,
				ProviderExecuted: false,
				Finished:         false,
			}
			currentAssistant.AddToolCall(toolCall)
			// Use parent ctx instead of genCtx to ensure the update succeeds
			// even if the request is canceled mid-stream
			return a.messages.Update(ctx, *currentAssistant)
		},
		OnRetry: func(err *fantasy.ProviderError, delay time.Duration) {
			// TODO: implement
		},
		OnToolCall: func(tc fantasy.ToolCallContent) error {
			toolCall := message.ToolCall{
				ID:               tc.ToolCallID,
				Name:             tc.ToolName,
				Input:            tc.Input,
				ProviderExecuted: false,
				Finished:         true,
			}
			currentAssistant.AddToolCall(toolCall)
			// Use parent ctx instead of genCtx to ensure the update succeeds
			// even if the request is canceled mid-stream
			return a.messages.Update(ctx, *currentAssistant)
		},
		OnToolResult: func(result fantasy.ToolResultContent) error {
			toolResult := a.convertToToolResult(result)
			// Use parent ctx instead of genCtx to ensure the message is created
			// even if the request is canceled mid-stream
			_, createMsgErr := a.messages.Create(ctx, currentAssistant.SessionID, message.CreateMessageParams{
				Role: message.Tool,
				Parts: []message.ContentPart{
					toolResult,
				},
			})
			return createMsgErr
		},
		OnStepFinish: func(stepResult fantasy.StepResult) error {
			// Map the provider finish reason onto the stored message's.
			finishReason := message.FinishReasonUnknown
			switch stepResult.FinishReason {
			case fantasy.FinishReasonLength:
				finishReason = message.FinishReasonMaxTokens
			case fantasy.FinishReasonStop:
				finishReason = message.FinishReasonEndTurn
			case fantasy.FinishReasonToolCalls:
				finishReason = message.FinishReasonToolUse
			}
			currentAssistant.AddFinish(finishReason, "", "")
			sessionLock.Lock()
			defer sessionLock.Unlock()

			// Re-read the session rather than reusing currentSession, so fields
			// changed elsewhere since the last read are kept. Then add this
			// step's usage and save it. currentSession is refreshed so the
			// StopWhen check below sees the latest token counts.
			updatedSession, getSessionErr := a.sessions.Get(ctx, call.SessionID)
			if getSessionErr != nil {
				return getSessionErr
			}
			a.updateSessionUsage(largeModel, &updatedSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
			_, sessionErr := a.sessions.Save(ctx, updatedSession)
			if sessionErr != nil {
				return sessionErr
			}
			currentSession = updatedSession
			return a.messages.Update(genCtx, *currentAssistant)
		},
		StopWhen: []fantasy.StopCondition{
			// Stop and flag the session for summarization when the remaining
			// context window drops to the threshold:
			//   - largeContextWindowBuffer tokens for windows larger than
			//     largeContextWindowThreshold,
			//   - otherwise smallContextWindowRatio of the window.
			// Skipped entirely when disableAutoSummarize is set.
			func(_ []fantasy.StepResult) bool {
				cw := int64(largeModel.CatwalkCfg.ContextWindow)
				tokens := currentSession.CompletionTokens + currentSession.PromptTokens
				remaining := cw - tokens
				var threshold int64
				if cw > largeContextWindowThreshold {
					threshold = largeContextWindowBuffer
				} else {
					threshold = int64(float64(cw) * smallContextWindowRatio)
				}
				if (remaining <= threshold) && !a.disableAutoSummarize {
					shouldSummarize = true
					return true
				}
				return false
			},
			// Stop when hasRepeatedToolCalls (defined elsewhere) reports a
			// tool-call loop within the detection window.
			func(steps []fantasy.StepResult) bool {
				return hasRepeatedToolCalls(steps, loopDetectionWindowSize, loopDetectionMaxRepeats)
			},
		},
	})

	a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))

	if err != nil {
		isCancelErr := errors.Is(err, context.Canceled)
		isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
		// PrepareStep never created an assistant message, so there is
		// nothing to finalize.
		if currentAssistant == nil {
			return result, err
		}
		// Ensure we finish thinking on error to close the reasoning state.
		currentAssistant.FinishThinking()
		toolCalls := currentAssistant.ToolCalls()
		// INFO: we use the parent context here because genCtx may have been
		// cancelled.
		msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
		if createErr != nil {
			return nil, createErr
		}
		// For each tool call:
		//   - close it if the failure left it unfinished (input reset to "{}"),
		//   - if no tool result is stored for it, write a synthetic error
		//     result explaining why.
		for _, tc := range toolCalls {
			if !tc.Finished {
				tc.Finished = true
				tc.Input = "{}"
				currentAssistant.AddToolCall(tc)
				updateErr := a.messages.Update(ctx, *currentAssistant)
				if updateErr != nil {
					return nil, updateErr
				}
			}

			found := false
			for _, msg := range msgs {
				if msg.Role == message.Tool {
					for _, tr := range msg.ToolResults() {
						if tr.ToolCallID == tc.ID {
							found = true
							break
						}
					}
				}
				if found {
					break
				}
			}
			if found {
				continue
			}
			content := "There was an error while executing the tool"
			if isCancelErr {
				content = "Error: user cancelled assistant tool calling"
			} else if isPermissionErr {
				content = "User denied permission"
			}
			toolResult := message.ToolResult{
				ToolCallID: tc.ID,
				Name:       tc.Name,
				Content:    content,
				IsError:    true,
			}
			_, createErr = a.messages.Create(ctx, currentAssistant.SessionID, message.CreateMessageParams{
				Role: message.Tool,
				Parts: []message.ContentPart{
					toolResult,
				},
			})
			if createErr != nil {
				return nil, createErr
			}
		}
		// Translate the error into a finish reason plus a title and
		// user-facing message on the assistant message.
		var fantasyErr *fantasy.Error
		var providerErr *fantasy.ProviderError
		const defaultTitle = "Provider Error"
		linkStyle := lipgloss.NewStyle().Foreground(charmtone.Guac).Underline(true)
		if isCancelErr {
			currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
		} else if isPermissionErr {
			currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "")
		} else if errors.Is(err, hyper.ErrNoCredits) {
			url := hyper.BaseURL()
			link := linkStyle.Hyperlink(url, "id=hyper").Render(url)
			currentAssistant.AddFinish(message.FinishReasonError, "No credits", "You're out of credits. Add more at "+link)
		} else if errors.As(err, &providerErr) {
			if providerErr.Message == "The requested model is not supported." {
				url := "https://github.com/settings/copilot/features"
				link := linkStyle.Hyperlink(url, "id=copilot").Render(url)
				currentAssistant.AddFinish(
					message.FinishReasonError,
					"Copilot model not enabled",
					fmt.Sprintf("%q is not enabled in Copilot. Go to the following page to enable it. Then, wait 5 minutes before trying again. %s", largeModel.CatwalkCfg.Name, link),
				)
			} else {
				currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message)
			}
		} else if errors.As(err, &fantasyErr) {
			currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(fantasyErr.Title), defaultTitle), fantasyErr.Message)
		} else {
			currentAssistant.AddFinish(message.FinishReasonError, defaultTitle, err.Error())
		}
		// Note: we use the parent context here because genCtx may have been
		// cancelled.
		updateErr := a.messages.Update(ctx, *currentAssistant)
		if updateErr != nil {
			return nil, updateErr
		}
		return nil, err
	}

	// Send notification that agent has finished its turn (skip for
	// nested/non-interactive sessions).
	if !call.NonInteractive && a.notify != nil {
		a.notify.Publish(pubsub.CreatedEvent, notify.Notification{
			SessionID:    call.SessionID,
			SessionTitle: currentSession.Title,
			Type:         notify.TypeAgentFinished,
		})
	}

	if shouldSummarize {
		// Summarize returns ErrSessionBusy while the session is busy, so drop
		// this request's activeRequests entry before calling it.
		a.activeRequests.Del(call.SessionID)
		if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
			return nil, summarizeErr
		}
		// If the agent wasn't done (it was still calling tools), queue a
		// follow-up that carries the original request so work resumes on
		// the summarized session.
		if len(currentAssistant.ToolCalls()) > 0 {
			existing, ok := a.messageQueue.Get(call.SessionID)
			if !ok {
				existing = []SessionAgentCall{}
			}
			call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
			existing = append(existing, call)
			a.messageQueue.Set(call.SessionID, existing)
		}
	}

	// Release active request before processing queued messages.
	a.activeRequests.Del(call.SessionID)
	cancel()

	queuedMessages, ok := a.messageQueue.Get(call.SessionID)
	if !ok || len(queuedMessages) == 0 {
		return result, err
	}
	// There are queued messages, so restart the loop with the next one. The
	// returned result is that of the queued call.
	firstQueuedMessage := queuedMessages[0]
	a.messageQueue.Set(call.SessionID, queuedMessages[1:])
	return a.Run(ctx, firstQueuedMessage)
}
592
593func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
594 if a.IsSessionBusy(sessionID) {
595 return ErrSessionBusy
596 }
597
598 // Copy mutable fields under lock to avoid races with SetModels.
599 largeModel := a.largeModel.Get()
600 systemPromptPrefix := a.systemPromptPrefix.Get()
601
602 currentSession, err := a.sessions.Get(ctx, sessionID)
603 if err != nil {
604 return fmt.Errorf("failed to get session: %w", err)
605 }
606 msgs, err := a.getSessionMessages(ctx, currentSession)
607 if err != nil {
608 return err
609 }
610 if len(msgs) == 0 {
611 // Nothing to summarize.
612 return nil
613 }
614
615 aiMsgs, _ := a.preparePrompt(msgs)
616
617 genCtx, cancel := context.WithCancel(ctx)
618 a.activeRequests.Set(sessionID, cancel)
619 defer a.activeRequests.Del(sessionID)
620 defer cancel()
621
622 agent := fantasy.NewAgent(largeModel.Model,
623 fantasy.WithSystemPrompt(string(summaryPrompt)),
624 fantasy.WithUserAgent(userAgent),
625 )
626 summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
627 Role: message.Assistant,
628 Model: largeModel.Model.Model(),
629 Provider: largeModel.Model.Provider(),
630 IsSummaryMessage: true,
631 })
632 if err != nil {
633 return err
634 }
635
636 summaryPromptText := buildSummaryPrompt(currentSession.Todos)
637
638 resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
639 Prompt: summaryPromptText,
640 Messages: aiMsgs,
641 ProviderOptions: opts,
642 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
643 prepared.Messages = options.Messages
644 if systemPromptPrefix != "" {
645 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(systemPromptPrefix)}, prepared.Messages...)
646 }
647 return callContext, prepared, nil
648 },
649 OnReasoningDelta: func(id string, text string) error {
650 summaryMessage.AppendReasoningContent(text)
651 return a.messages.Update(genCtx, summaryMessage)
652 },
653 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
654 // Handle anthropic signature.
655 if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
656 if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
657 summaryMessage.AppendReasoningSignature(signature.Signature)
658 }
659 }
660 summaryMessage.FinishThinking()
661 return a.messages.Update(genCtx, summaryMessage)
662 },
663 OnTextDelta: func(id, text string) error {
664 summaryMessage.AppendContent(text)
665 return a.messages.Update(genCtx, summaryMessage)
666 },
667 })
668 if err != nil {
669 isCancelErr := errors.Is(err, context.Canceled)
670 if isCancelErr {
671 // User cancelled summarize we need to remove the summary message.
672 deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
673 return deleteErr
674 }
675 return err
676 }
677
678 summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
679 err = a.messages.Update(genCtx, summaryMessage)
680 if err != nil {
681 return err
682 }
683
684 var openrouterCost *float64
685 for _, step := range resp.Steps {
686 stepCost := a.openrouterCost(step.ProviderMetadata)
687 if stepCost != nil {
688 newCost := *stepCost
689 if openrouterCost != nil {
690 newCost += *openrouterCost
691 }
692 openrouterCost = &newCost
693 }
694 }
695
696 a.updateSessionUsage(largeModel, ¤tSession, resp.TotalUsage, openrouterCost)
697
698 // Just in case, get just the last usage info.
699 usage := resp.Response.Usage
700 currentSession.SummaryMessageID = summaryMessage.ID
701 currentSession.CompletionTokens = usage.OutputTokens
702 currentSession.PromptTokens = 0
703 _, err = a.sessions.Save(genCtx, currentSession)
704 return err
705}
706
707func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
708 if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
709 return fantasy.ProviderOptions{}
710 }
711 return fantasy.ProviderOptions{
712 anthropic.Name: &anthropic.ProviderCacheControlOptions{
713 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
714 },
715 bedrock.Name: &anthropic.ProviderCacheControlOptions{
716 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
717 },
718 vercel.Name: &anthropic.ProviderCacheControlOptions{
719 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
720 },
721 }
722}
723
724func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
725 parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
726 var attachmentParts []message.ContentPart
727 for _, attachment := range call.Attachments {
728 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
729 }
730 parts = append(parts, attachmentParts...)
731 msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
732 Role: message.User,
733 Parts: parts,
734 })
735 if err != nil {
736 return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
737 }
738 return msg, nil
739}
740
741func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
742 var history []fantasy.Message
743 if !a.isSubAgent {
744 history = append(history, fantasy.NewUserMessage(
745 fmt.Sprintf("<system_reminder>%s</system_reminder>",
746 `This is a reminder that your todo list is currently empty. DO NOT mention this to the user explicitly because they are already aware.
747If you are working on tasks that would benefit from a todo list please use the "todos" tool to create one.
748If not, please feel free to ignore. Again do not mention this message to the user.`,
749 ),
750 ))
751 }
752 for _, m := range msgs {
753 if len(m.Parts) == 0 {
754 continue
755 }
756 // Assistant message without content or tool calls (cancelled before it
757 // returned anything).
758 if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
759 continue
760 }
761 history = append(history, m.ToAIMessage()...)
762 }
763
764 var files []fantasy.FilePart
765 for _, attachment := range attachments {
766 if attachment.IsText() {
767 continue
768 }
769 files = append(files, fantasy.FilePart{
770 Filename: attachment.FileName,
771 Data: attachment.Content,
772 MediaType: attachment.MimeType,
773 })
774 }
775
776 return history, files
777}
778
779func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
780 msgs, err := a.messages.List(ctx, session.ID)
781 if err != nil {
782 return nil, fmt.Errorf("failed to list messages: %w", err)
783 }
784
785 if session.SummaryMessageID != "" {
786 summaryMsgIndex := -1
787 for i, msg := range msgs {
788 if msg.ID == session.SummaryMessageID {
789 summaryMsgIndex = i
790 break
791 }
792 }
793 if summaryMsgIndex != -1 {
794 msgs = msgs[summaryMsgIndex:]
795 msgs[0].Role = message.User
796 }
797 }
798 return msgs, nil
799}
800
801// generateTitle generates a session titled based on the initial prompt.
802func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, userPrompt string) {
803 if userPrompt == "" {
804 return
805 }
806
807 smallModel := a.smallModel.Get()
808 largeModel := a.largeModel.Get()
809 systemPromptPrefix := a.systemPromptPrefix.Get()
810
811 var maxOutputTokens int64 = 40
812 if smallModel.CatwalkCfg.CanReason {
813 maxOutputTokens = smallModel.CatwalkCfg.DefaultMaxTokens
814 }
815
816 newAgent := func(m fantasy.LanguageModel, p []byte, tok int64) fantasy.Agent {
817 return fantasy.NewAgent(m,
818 fantasy.WithSystemPrompt(string(p)+"\n /no_think"),
819 fantasy.WithMaxOutputTokens(tok),
820 fantasy.WithUserAgent(userAgent),
821 )
822 }
823
824 streamCall := fantasy.AgentStreamCall{
825 Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", userPrompt),
826 PrepareStep: func(callCtx context.Context, opts fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
827 prepared.Messages = opts.Messages
828 if systemPromptPrefix != "" {
829 prepared.Messages = append([]fantasy.Message{
830 fantasy.NewSystemMessage(systemPromptPrefix),
831 }, prepared.Messages...)
832 }
833 return callCtx, prepared, nil
834 },
835 }
836
837 // Use the small model to generate the title.
838 model := smallModel
839 agent := newAgent(model.Model, titlePrompt, maxOutputTokens)
840 resp, err := agent.Stream(ctx, streamCall)
841 if err == nil {
842 // We successfully generated a title with the small model.
843 slog.Debug("Generated title with small model")
844 } else {
845 // It didn't work. Let's try with the big model.
846 slog.Error("Error generating title with small model; trying big model", "err", err)
847 model = largeModel
848 agent = newAgent(model.Model, titlePrompt, maxOutputTokens)
849 resp, err = agent.Stream(ctx, streamCall)
850 if err == nil {
851 slog.Debug("Generated title with large model")
852 } else {
853 // Welp, the large model didn't work either. Use the default
854 // session name and return.
855 slog.Error("Error generating title with large model", "err", err)
856 saveErr := a.sessions.Rename(ctx, sessionID, DefaultSessionName)
857 if saveErr != nil {
858 slog.Error("Failed to save session title", "error", saveErr)
859 }
860 return
861 }
862 }
863
864 if resp == nil {
865 // Actually, we didn't get a response so we can't. Use the default
866 // session name and return.
867 slog.Error("Response is nil; can't generate title")
868 saveErr := a.sessions.Rename(ctx, sessionID, DefaultSessionName)
869 if saveErr != nil {
870 slog.Error("Failed to save session title", "error", saveErr)
871 }
872 return
873 }
874
875 // Clean up title.
876 var title string
877 title = strings.ReplaceAll(resp.Response.Content.Text(), "\n", " ")
878
879 // Remove thinking tags if present.
880 title = thinkTagRegex.ReplaceAllString(title, "")
881
882 title = strings.TrimSpace(title)
883 title = cmp.Or(title, DefaultSessionName)
884
885 // Calculate usage and cost.
886 var openrouterCost *float64
887 for _, step := range resp.Steps {
888 stepCost := a.openrouterCost(step.ProviderMetadata)
889 if stepCost != nil {
890 newCost := *stepCost
891 if openrouterCost != nil {
892 newCost += *openrouterCost
893 }
894 openrouterCost = &newCost
895 }
896 }
897
898 modelConfig := model.CatwalkCfg
899 cost := modelConfig.CostPer1MInCached/1e6*float64(resp.TotalUsage.CacheCreationTokens) +
900 modelConfig.CostPer1MOutCached/1e6*float64(resp.TotalUsage.CacheReadTokens) +
901 modelConfig.CostPer1MIn/1e6*float64(resp.TotalUsage.InputTokens) +
902 modelConfig.CostPer1MOut/1e6*float64(resp.TotalUsage.OutputTokens)
903
904 // Use override cost if available (e.g., from OpenRouter).
905 if openrouterCost != nil {
906 cost = *openrouterCost
907 }
908
909 promptTokens := resp.TotalUsage.InputTokens + resp.TotalUsage.CacheCreationTokens
910 completionTokens := resp.TotalUsage.OutputTokens
911
912 // Atomically update only title and usage fields to avoid overriding other
913 // concurrent session updates.
914 saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, title, promptTokens, completionTokens, cost)
915 if saveErr != nil {
916 slog.Error("Failed to save session title and usage", "error", saveErr)
917 return
918 }
919}
920
921func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
922 openrouterMetadata, ok := metadata[openrouter.Name]
923 if !ok {
924 return nil
925 }
926
927 opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
928 if !ok {
929 return nil
930 }
931 return &opts.Usage.Cost
932}
933
934func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
935 modelConfig := model.CatwalkCfg
936 cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
937 modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
938 modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
939 modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
940
941 a.eventTokensUsed(session.ID, model, usage, cost)
942
943 if overrideCost != nil {
944 session.Cost += *overrideCost
945 } else {
946 session.Cost += cost
947 }
948
949 session.CompletionTokens = usage.OutputTokens
950 session.PromptTokens = usage.InputTokens + usage.CacheReadTokens
951}
952
953func (a *sessionAgent) Cancel(sessionID string) {
954 // Cancel regular requests. Don't use Take() here - we need the entry to
955 // remain in activeRequests so IsBusy() returns true until the goroutine
956 // fully completes (including error handling that may access the DB).
957 // The defer in processRequest will clean up the entry.
958 if cancel, ok := a.activeRequests.Get(sessionID); ok && cancel != nil {
959 slog.Debug("Request cancellation initiated", "session_id", sessionID)
960 cancel()
961 }
962
963 // Also check for summarize requests.
964 if cancel, ok := a.activeRequests.Get(sessionID + "-summarize"); ok && cancel != nil {
965 slog.Debug("Summarize cancellation initiated", "session_id", sessionID)
966 cancel()
967 }
968
969 if a.QueuedPrompts(sessionID) > 0 {
970 slog.Debug("Clearing queued prompts", "session_id", sessionID)
971 a.messageQueue.Del(sessionID)
972 }
973}
974
975func (a *sessionAgent) ClearQueue(sessionID string) {
976 if a.QueuedPrompts(sessionID) > 0 {
977 slog.Debug("Clearing queued prompts", "session_id", sessionID)
978 a.messageQueue.Del(sessionID)
979 }
980}
981
982func (a *sessionAgent) CancelAll() {
983 if !a.IsBusy() {
984 return
985 }
986 for key := range a.activeRequests.Seq2() {
987 a.Cancel(key) // key is sessionID
988 }
989
990 timeout := time.After(5 * time.Second)
991 for a.IsBusy() {
992 select {
993 case <-timeout:
994 return
995 default:
996 time.Sleep(200 * time.Millisecond)
997 }
998 }
999}
1000
1001func (a *sessionAgent) IsBusy() bool {
1002 var busy bool
1003 for cancelFunc := range a.activeRequests.Seq() {
1004 if cancelFunc != nil {
1005 busy = true
1006 break
1007 }
1008 }
1009 return busy
1010}
1011
1012func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
1013 _, busy := a.activeRequests.Get(sessionID)
1014 return busy
1015}
1016
1017func (a *sessionAgent) QueuedPrompts(sessionID string) int {
1018 l, ok := a.messageQueue.Get(sessionID)
1019 if !ok {
1020 return 0
1021 }
1022 return len(l)
1023}
1024
1025func (a *sessionAgent) QueuedPromptsList(sessionID string) []string {
1026 l, ok := a.messageQueue.Get(sessionID)
1027 if !ok {
1028 return nil
1029 }
1030 prompts := make([]string, len(l))
1031 for i, call := range l {
1032 prompts[i] = call.Prompt
1033 }
1034 return prompts
1035}
1036
1037func (a *sessionAgent) SetModels(large Model, small Model) {
1038 a.largeModel.Set(large)
1039 a.smallModel.Set(small)
1040}
1041
1042func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
1043 a.tools.SetSlice(tools)
1044}
1045
1046func (a *sessionAgent) SetSystemPrompt(systemPrompt string) {
1047 a.systemPrompt.Set(systemPrompt)
1048}
1049
1050func (a *sessionAgent) Model() Model {
1051 return a.largeModel.Get()
1052}
1053
1054// convertToToolResult converts a fantasy tool result to a message tool result.
1055func (a *sessionAgent) convertToToolResult(result fantasy.ToolResultContent) message.ToolResult {
1056 baseResult := message.ToolResult{
1057 ToolCallID: result.ToolCallID,
1058 Name: result.ToolName,
1059 Metadata: result.ClientMetadata,
1060 }
1061
1062 switch result.Result.GetType() {
1063 case fantasy.ToolResultContentTypeText:
1064 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result); ok {
1065 baseResult.Content = r.Text
1066 }
1067 case fantasy.ToolResultContentTypeError:
1068 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result); ok {
1069 baseResult.Content = r.Error.Error()
1070 baseResult.IsError = true
1071 }
1072 case fantasy.ToolResultContentTypeMedia:
1073 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](result.Result); ok {
1074 content := r.Text
1075 if content == "" {
1076 content = fmt.Sprintf("Loaded %s content", r.MediaType)
1077 }
1078 baseResult.Content = content
1079 baseResult.Data = r.Data
1080 baseResult.MIMEType = r.MediaType
1081 }
1082 }
1083
1084 return baseResult
1085}
1086
1087// workaroundProviderMediaLimitations converts media content in tool results to
1088// user messages for providers that don't natively support images in tool results.
1089//
1090// Problem: OpenAI, Google, OpenRouter, and other OpenAI-compatible providers
1091// don't support sending images/media in tool result messages - they only accept
1092// text in tool results. However, they DO support images in user messages.
1093//
1094// If we send media in tool results to these providers, the API returns an error.
1095//
1096// Solution: For these providers, we:
1097// 1. Replace the media in the tool result with a text placeholder
1098// 2. Inject a user message immediately after with the image as a file attachment
1099// 3. This maintains the tool execution flow while working around API limitations
1100//
1101// Anthropic and Bedrock support images natively in tool results, so we skip
1102// this workaround for them.
1103//
1104// Example transformation:
1105//
1106// BEFORE: [tool result: image data]
1107// AFTER: [tool result: "Image loaded - see attached"], [user: image attachment]
1108func (a *sessionAgent) workaroundProviderMediaLimitations(messages []fantasy.Message, largeModel Model) []fantasy.Message {
1109 providerSupportsMedia := largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderAnthropic) ||
1110 largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderBedrock)
1111
1112 if providerSupportsMedia {
1113 return messages
1114 }
1115
1116 convertedMessages := make([]fantasy.Message, 0, len(messages))
1117
1118 for _, msg := range messages {
1119 if msg.Role != fantasy.MessageRoleTool {
1120 convertedMessages = append(convertedMessages, msg)
1121 continue
1122 }
1123
1124 textParts := make([]fantasy.MessagePart, 0, len(msg.Content))
1125 var mediaFiles []fantasy.FilePart
1126
1127 for _, part := range msg.Content {
1128 toolResult, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
1129 if !ok {
1130 textParts = append(textParts, part)
1131 continue
1132 }
1133
1134 if media, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](toolResult.Output); ok {
1135 decoded, err := base64.StdEncoding.DecodeString(media.Data)
1136 if err != nil {
1137 slog.Warn("Failed to decode media data", "error", err)
1138 textParts = append(textParts, part)
1139 continue
1140 }
1141
1142 mediaFiles = append(mediaFiles, fantasy.FilePart{
1143 Data: decoded,
1144 MediaType: media.MediaType,
1145 Filename: fmt.Sprintf("tool-result-%s", toolResult.ToolCallID),
1146 })
1147
1148 textParts = append(textParts, fantasy.ToolResultPart{
1149 ToolCallID: toolResult.ToolCallID,
1150 Output: fantasy.ToolResultOutputContentText{
1151 Text: "[Image/media content loaded - see attached file]",
1152 },
1153 ProviderOptions: toolResult.ProviderOptions,
1154 })
1155 } else {
1156 textParts = append(textParts, part)
1157 }
1158 }
1159
1160 convertedMessages = append(convertedMessages, fantasy.Message{
1161 Role: fantasy.MessageRoleTool,
1162 Content: textParts,
1163 })
1164
1165 if len(mediaFiles) > 0 {
1166 convertedMessages = append(convertedMessages, fantasy.NewUserMessage(
1167 "Here is the media content from the tool result:",
1168 mediaFiles...,
1169 ))
1170 }
1171 }
1172
1173 return convertedMessages
1174}
1175
1176// buildSummaryPrompt constructs the prompt text for session summarization.
1177func buildSummaryPrompt(todos []session.Todo) string {
1178 var sb strings.Builder
1179 sb.WriteString("Provide a detailed summary of our conversation above.")
1180 if len(todos) > 0 {
1181 sb.WriteString("\n\n## Current Todo List\n\n")
1182 for _, t := range todos {
1183 fmt.Fprintf(&sb, "- [%s] %s\n", t.Status, t.Content)
1184 }
1185 sb.WriteString("\nInclude these tasks and their statuses in your summary. ")
1186 sb.WriteString("Instruct the resuming assistant to use the `todos` tool to continue tracking progress on these tasks.")
1187 }
1188 return sb.String()
1189}