1// Package agent is the core orchestration layer for Crush AI agents.
2//
3// It provides session-based AI agent functionality for managing
4// conversations, tool execution, and message handling. It coordinates
5// interactions between language models, messages, sessions, and tools while
6// handling features like automatic summarization, queuing, and token
7// management.
8package agent
9
10import (
11 "cmp"
12 "context"
13 _ "embed"
14 "encoding/base64"
15 "errors"
16 "fmt"
17 "log/slog"
18 "os"
19 "regexp"
20 "strconv"
21 "strings"
22 "sync"
23 "time"
24
25 "charm.land/catwalk/pkg/catwalk"
26 "charm.land/fantasy"
27 "charm.land/fantasy/providers/anthropic"
28 "charm.land/fantasy/providers/bedrock"
29 "charm.land/fantasy/providers/google"
30 "charm.land/fantasy/providers/openai"
31 "charm.land/fantasy/providers/openrouter"
32 "charm.land/fantasy/providers/vercel"
33 "charm.land/lipgloss/v2"
34 "github.com/charmbracelet/crush/internal/agent/hyper"
35 "github.com/charmbracelet/crush/internal/agent/tools"
36 "github.com/charmbracelet/crush/internal/agent/tools/mcp"
37 "github.com/charmbracelet/crush/internal/config"
38 "github.com/charmbracelet/crush/internal/csync"
39 "github.com/charmbracelet/crush/internal/message"
40 "github.com/charmbracelet/crush/internal/permission"
41 "github.com/charmbracelet/crush/internal/session"
42 "github.com/charmbracelet/crush/internal/stringext"
43 "github.com/charmbracelet/crush/internal/version"
44 "github.com/charmbracelet/x/exp/charmtone"
45)
46
const (
	// DefaultSessionName is the title stored for a session when title
	// generation fails or produces an empty title.
	DefaultSessionName = "Untitled Session"

	// Constants for auto-summarization thresholds.

	// largeContextWindowThreshold is the context window size (in tokens)
	// above which the fixed largeContextWindowBuffer is used as the
	// remaining-token threshold instead of a ratio of the window.
	largeContextWindowThreshold = 200_000
	// largeContextWindowBuffer is the number of remaining tokens at or below
	// which a large-context model triggers auto-summarization.
	largeContextWindowBuffer = 20_000
	// smallContextWindowRatio is the fraction of the context window that, for
	// models at or below largeContextWindowThreshold, is used as the
	// remaining-token threshold for auto-summarization.
	smallContextWindowRatio = 0.2
)
55
// titlePrompt is the embedded template used as the system prompt when
// generating a session title.
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt is the embedded template used as the system prompt when
// summarizing a session.
//
//go:embed templates/summary.md
var summaryPrompt []byte

// Used to remove <think> tags from generated titles. The pattern does not
// match across newlines; generateTitle replaces newlines with spaces before
// applying it.
var thinkTagRegex = regexp.MustCompile(`<think>.*?</think>`)
64
// SessionAgentCall describes a single prompt submitted to a SessionAgent.
type SessionAgentCall struct {
	// SessionID identifies the session the prompt belongs to. Run rejects a
	// call with an empty SessionID.
	SessionID string
	// Prompt is the user's text. Run rejects an empty Prompt unless the call
	// carries a text attachment.
	Prompt string
	// ProviderOptions is forwarded to the model stream call (and to
	// Summarize when auto-summarization is triggered).
	ProviderOptions fantasy.ProviderOptions
	// Attachments are stored with the user message; non-text attachments are
	// also sent to the model as files.
	Attachments []message.Attachment
	// MaxOutputTokens is forwarded to the model stream call.
	MaxOutputTokens int64
	// Temperature, TopP, TopK, FrequencyPenalty and PresencePenalty are
	// optional sampling parameters forwarded unchanged to the model stream
	// call.
	Temperature *float64
	TopP *float64
	TopK *int64
	FrequencyPenalty *float64
	PresencePenalty *float64
}
77
// SessionAgent is the interface through which callers drive an agent bound
// to sessions. The behavior notes below describe the sessionAgent
// implementation in this file.
type SessionAgent interface {
	// Run submits a prompt for a session. When the session is already busy
	// the call is queued and Run returns (nil, nil).
	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
	// SetModels replaces the large and small models.
	SetModels(large Model, small Model)
	// SetTools replaces the set of tools offered to the model.
	SetTools(tools []fantasy.AgentTool)
	// SetSystemPrompt replaces the system prompt.
	SetSystemPrompt(systemPrompt string)
	// Cancel cancels the active request for the session, if any, and drops
	// its queued prompts.
	Cancel(sessionID string)
	// CancelAll cancels every active request and waits (up to a timeout) for
	// the agent to become idle.
	CancelAll()
	// IsSessionBusy reports whether the session has an active request.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any session has an active request.
	IsBusy() bool
	// QueuedPrompts returns the number of prompts queued for the session.
	QueuedPrompts(sessionID string) int
	// QueuedPromptsList returns the prompts queued for the session.
	QueuedPromptsList(sessionID string) []string
	// ClearQueue drops the prompts queued for the session.
	ClearQueue(sessionID string)
	// Summarize replaces the session's working history with a generated
	// summary message. It fails with ErrSessionBusy when the session has an
	// active request.
	Summarize(context.Context, string, fantasy.ProviderOptions) error
	// Model returns a model used by the agent.
	// NOTE(review): the implementation is outside this view — confirm which
	// model (large or small) is returned.
	Model() Model
}
93
// Model bundles a language model with the catalog and user configuration
// that describe it.
type Model struct {
	// Model is the language model handed to fantasy.NewAgent.
	Model fantasy.LanguageModel
	// CatwalkCfg is the catalog entry for the model; this file reads its
	// context window, per-token costs, display name, image support and
	// reasoning settings.
	CatwalkCfg catwalk.Model
	// ModelCfg is the selected-model configuration; its Model and Provider
	// values are recorded on assistant messages.
	ModelCfg config.SelectedModel
}
99
// sessionAgent is the SessionAgent implementation returned by
// NewSessionAgent.
type sessionAgent struct {
	// Fields wrapped in csync containers are snapshotted at the start of a
	// request (see Run) so they can be swapped while a request is in flight.
	largeModel *csync.Value[Model]
	smallModel *csync.Value[Model]
	// systemPromptPrefix, when non-empty, is prepended as an extra system
	// message on every model step.
	systemPromptPrefix *csync.Value[string]
	systemPrompt *csync.Value[string]
	tools *csync.Slice[fantasy.AgentTool]

	// isSubAgent suppresses the todo-list reminder that preparePrompt
	// otherwise puts at the start of the history.
	isSubAgent bool
	sessions session.Service
	messages message.Service
	// disableAutoSummarize turns off the context-window stop condition in
	// Run.
	disableAutoSummarize bool
	// isYolo is copied from SessionAgentOptions.IsYolo.
	// NOTE(review): not read anywhere in the visible part of this file.
	isYolo bool

	// messageQueue holds, per session ID, the calls received while that
	// session was busy.
	messageQueue *csync.Map[string, []SessionAgentCall]
	// activeRequests maps a session ID to the cancel function of its
	// in-flight request; presence of a key is what marks a session as busy.
	activeRequests *csync.Map[string, context.CancelFunc]
}
116
// SessionAgentOptions carries the initial configuration and dependencies
// consumed by NewSessionAgent.
type SessionAgentOptions struct {
	// LargeModel is the primary model used for conversation and
	// summarization.
	LargeModel Model
	// SmallModel is tried first for title generation.
	SmallModel Model
	// SystemPromptPrefix, when non-empty, is sent as an additional leading
	// system message.
	SystemPromptPrefix string
	// SystemPrompt is the agent's system prompt.
	SystemPrompt string
	// IsSubAgent disables the todo-list reminder added to the history.
	IsSubAgent bool
	// DisableAutoSummarize disables automatic summarization when the context
	// window is nearly full.
	DisableAutoSummarize bool
	// IsYolo is stored on the agent as-is.
	// NOTE(review): its effect is not visible in this part of the file.
	IsYolo bool
	// Sessions is the session store.
	Sessions session.Service
	// Messages is the message store.
	Messages message.Service
	// Tools is the initial set of tools offered to the model.
	Tools []fantasy.AgentTool
}
129
130func NewSessionAgent(
131 opts SessionAgentOptions,
132) SessionAgent {
133 return &sessionAgent{
134 largeModel: csync.NewValue(opts.LargeModel),
135 smallModel: csync.NewValue(opts.SmallModel),
136 systemPromptPrefix: csync.NewValue(opts.SystemPromptPrefix),
137 systemPrompt: csync.NewValue(opts.SystemPrompt),
138 isSubAgent: opts.IsSubAgent,
139 sessions: opts.Sessions,
140 messages: opts.Messages,
141 disableAutoSummarize: opts.DisableAutoSummarize,
142 tools: csync.NewSliceFrom(opts.Tools),
143 isYolo: opts.IsYolo,
144 messageQueue: csync.NewMap[string, []SessionAgentCall](),
145 activeRequests: csync.NewMap[string, context.CancelFunc](),
146 }
147}
148
// Run submits call to the large model for the given session and streams the
// response into the message store.
//
// The call is rejected with ErrEmptyPrompt when it has neither prompt text
// nor a text attachment, and with ErrSessionMissing when SessionID is empty.
// If the session already has an active request the call is appended to the
// session's queue and Run returns (nil, nil).
//
// Otherwise Run stores the user message, registers a cancel function for the
// session, and streams the model response, persisting the assistant message
// and tool results as they arrive. On a streaming error it closes any
// unfinished tool calls, records error tool results for tool calls that have
// none, stores a finish reason on the assistant message, and returns the
// error. When the context-window stop condition fired, the session is
// summarized and, if the assistant still had tool calls pending, the original
// request is re-queued. Finally, if prompts were queued in the meantime, Run
// calls itself with the first of them.
func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
	if call.Prompt == "" && !message.ContainsTextAttachment(call.Attachments) {
		return nil, ErrEmptyPrompt
	}
	if call.SessionID == "" {
		return nil, ErrSessionMissing
	}

	// Queue the message if busy
	if a.IsSessionBusy(call.SessionID) {
		existing, ok := a.messageQueue.Get(call.SessionID)
		if !ok {
			existing = []SessionAgentCall{}
		}
		existing = append(existing, call)
		a.messageQueue.Set(call.SessionID, existing)
		return nil, nil
	}

	// Copy mutable fields under lock to avoid races with SetTools/SetModels.
	agentTools := a.tools.Copy()
	largeModel := a.largeModel.Get()
	systemPrompt := a.systemPrompt.Get()
	promptPrefix := a.systemPromptPrefix.Get()
	var instructions strings.Builder

	// Collect the instructions advertised by every connected MCP server.
	for _, server := range mcp.GetStates() {
		if server.State != mcp.StateConnected {
			continue
		}
		if s := server.Client.InitializeResult().Instructions; s != "" {
			instructions.WriteString(s)
			instructions.WriteString("\n\n")
		}
	}

	// Append the collected MCP instructions to the system prompt.
	if s := instructions.String(); s != "" {
		systemPrompt += "\n\n<mcp-instructions>\n" + s + "\n</mcp-instructions>"
	}

	if len(agentTools) > 0 {
		// Add Anthropic caching to the last tool.
		agentTools[len(agentTools)-1].SetProviderOptions(a.getCacheControlOptions())
	}

	agent := fantasy.NewAgent(
		largeModel.Model,
		fantasy.WithSystemPrompt(systemPrompt),
		fantasy.WithTools(agentTools...),
		fantasy.WithUserAgent("Charm Crush/"+version.Version),
	)

	// sessionLock guards the read-modify-write of the session in
	// OnStepFinish.
	sessionLock := sync.Mutex{}
	currentSession, err := a.sessions.Get(ctx, call.SessionID)
	if err != nil {
		return nil, fmt.Errorf("failed to get session: %w", err)
	}

	msgs, err := a.getSessionMessages(ctx, currentSession)
	if err != nil {
		return nil, fmt.Errorf("failed to get session messages: %w", err)
	}

	var wg sync.WaitGroup
	// Generate title if first message.
	if len(msgs) == 0 {
		titleCtx := ctx // Copy to avoid race with ctx reassignment below.
		wg.Go(func() {
			a.generateTitle(titleCtx, call.SessionID, call.Prompt)
		})
	}
	// Do not return before the title goroutine (if any) has finished.
	defer wg.Wait()

	// Add the user message to the session.
	_, err = a.createUserMessage(ctx, call)
	if err != nil {
		return nil, err
	}

	// Add the session to the context.
	ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)

	// Registering the cancel function is what marks the session as busy.
	genCtx, cancel := context.WithCancel(ctx)
	a.activeRequests.Set(call.SessionID, cancel)

	defer cancel()
	defer a.activeRequests.Del(call.SessionID)

	history, files := a.preparePrompt(msgs, call.Attachments...)

	startTime := time.Now()
	a.eventPromptSent(call.SessionID)

	// currentAssistant is the assistant message of the step in progress; it
	// is assigned in PrepareStep and mutated by the stream callbacks.
	var currentAssistant *message.Message
	// shouldSummarize is set by the context-window stop condition.
	var shouldSummarize bool
	result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
		Prompt:           message.PromptWithTextAttachments(call.Prompt, call.Attachments),
		Files:            files,
		Messages:         history,
		ProviderOptions:  call.ProviderOptions,
		MaxOutputTokens:  &call.MaxOutputTokens,
		TopP:             call.TopP,
		Temperature:      call.Temperature,
		PresencePenalty:  call.PresencePenalty,
		TopK:             call.TopK,
		FrequencyPenalty: call.FrequencyPenalty,
		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
			// Start from the step's messages with all provider options
			// cleared; cache control is re-applied below.
			prepared.Messages = options.Messages
			for i := range prepared.Messages {
				prepared.Messages[i].ProviderOptions = nil
			}

			// Drain prompts queued while this request was running and append
			// them to the conversation as user messages.
			queuedCalls, _ := a.messageQueue.Get(call.SessionID)
			a.messageQueue.Del(call.SessionID)
			for _, queued := range queuedCalls {
				userMessage, createErr := a.createUserMessage(callContext, queued)
				if createErr != nil {
					return callContext, prepared, createErr
				}
				prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
			}

			prepared.Messages = a.workaroundProviderMediaLimitations(prepared.Messages, largeModel)

			lastSystemRoleInx := 0
			systemMessageUpdated := false
			for i, msg := range prepared.Messages {
				// Only add cache control to the last system message: the
				// first non-system message encountered triggers the update.
				if msg.Role == fantasy.MessageRoleSystem {
					lastSystemRoleInx = i
				} else if !systemMessageUpdated {
					prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
					systemMessageUpdated = true
				}
				// Then add cache control to the last 2 messages.
				if i > len(prepared.Messages)-3 {
					prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
				}
			}

			if promptPrefix != "" {
				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(promptPrefix)}, prepared.Messages...)
			}

			// Create the (initially empty) assistant message that the stream
			// callbacks will fill in for this step.
			var assistantMsg message.Message
			assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
				Role:     message.Assistant,
				Parts:    []message.ContentPart{},
				Model:    largeModel.ModelCfg.Model,
				Provider: largeModel.ModelCfg.Provider,
			})
			if err != nil {
				return callContext, prepared, err
			}
			// Expose the message ID and model details through the returned
			// step context.
			callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
			callContext = context.WithValue(callContext, tools.SupportsImagesContextKey, largeModel.CatwalkCfg.SupportsImages)
			callContext = context.WithValue(callContext, tools.ModelNameContextKey, largeModel.CatwalkCfg.Name)
			currentAssistant = &assistantMsg
			return callContext, prepared, err
		},
		OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
			currentAssistant.AppendReasoningContent(reasoning.Text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnReasoningDelta: func(id string, text string) error {
			currentAssistant.AppendReasoningContent(text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
			// handle anthropic signature
			if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
				if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
					currentAssistant.AppendReasoningSignature(reasoning.Signature)
				}
			}
			// handle google thought signature
			if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
				if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
					currentAssistant.AppendThoughtSignature(reasoning.Signature, reasoning.ToolID)
				}
			}
			// handle openai responses reasoning metadata
			if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
				if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
					currentAssistant.SetReasoningResponsesData(reasoning)
				}
			}
			currentAssistant.FinishThinking()
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnTextDelta: func(id string, text string) error {
			// Strip leading newline from initial text content. This is
			// particularly important in non-interactive mode where leading
			// newlines are very visible.
			if len(currentAssistant.Parts) == 0 {
				text = strings.TrimPrefix(text, "\n")
			}

			currentAssistant.AppendContent(text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnToolInputStart: func(id string, toolName string) error {
			// Record the tool call as started but not yet finished.
			toolCall := message.ToolCall{
				ID:               id,
				Name:             toolName,
				ProviderExecuted: false,
				Finished:         false,
			}
			currentAssistant.AddToolCall(toolCall)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnRetry: func(err *fantasy.ProviderError, delay time.Duration) {
			// TODO: implement
		},
		OnToolCall: func(tc fantasy.ToolCallContent) error {
			// Record the tool call with its input, marked as finished.
			toolCall := message.ToolCall{
				ID:               tc.ToolCallID,
				Name:             tc.ToolName,
				Input:            tc.Input,
				ProviderExecuted: false,
				Finished:         true,
			}
			currentAssistant.AddToolCall(toolCall)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnToolResult: func(result fantasy.ToolResultContent) error {
			// Persist each tool result as its own tool-role message.
			toolResult := a.convertToToolResult(result)
			_, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
				Role: message.Tool,
				Parts: []message.ContentPart{
					toolResult,
				},
			})
			return createMsgErr
		},
		OnStepFinish: func(stepResult fantasy.StepResult) error {
			// Map the provider finish reason onto the message finish reason;
			// anything unrecognized stays unknown.
			finishReason := message.FinishReasonUnknown
			switch stepResult.FinishReason {
			case fantasy.FinishReasonLength:
				finishReason = message.FinishReasonMaxTokens
			case fantasy.FinishReasonStop:
				finishReason = message.FinishReasonEndTurn
			case fantasy.FinishReasonToolCalls:
				finishReason = message.FinishReasonToolUse
			}
			currentAssistant.AddFinish(finishReason, "", "")
			sessionLock.Lock()
			defer sessionLock.Unlock()

			// Re-read the session before saving so the usage update is
			// applied to the latest stored state.
			updatedSession, getSessionErr := a.sessions.Get(ctx, call.SessionID)
			if getSessionErr != nil {
				return getSessionErr
			}
			a.updateSessionUsage(largeModel, &updatedSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
			_, sessionErr := a.sessions.Save(ctx, updatedSession)
			if sessionErr != nil {
				return sessionErr
			}
			currentSession = updatedSession
			return a.messages.Update(genCtx, *currentAssistant)
		},
		StopWhen: []fantasy.StopCondition{
			// Stop (and request summarization) when the remaining context
			// window drops to the threshold, unless auto-summarize is off.
			func(_ []fantasy.StepResult) bool {
				cw := int64(largeModel.CatwalkCfg.ContextWindow)
				tokens := currentSession.CompletionTokens + currentSession.PromptTokens
				remaining := cw - tokens
				var threshold int64
				if cw > largeContextWindowThreshold {
					threshold = largeContextWindowBuffer
				} else {
					threshold = int64(float64(cw) * smallContextWindowRatio)
				}
				if (remaining <= threshold) && !a.disableAutoSummarize {
					shouldSummarize = true
					return true
				}
				return false
			},
			// Stop when the recent steps repeat the same tool calls.
			func(steps []fantasy.StepResult) bool {
				return hasRepeatedToolCalls(steps, loopDetectionWindowSize, loopDetectionMaxRepeats)
			},
		},
	})

	a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))

	if err != nil {
		isCancelErr := errors.Is(err, context.Canceled)
		isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
		// No assistant message was created, so there is nothing to repair.
		if currentAssistant == nil {
			return result, err
		}
		// Ensure we finish thinking on error to close the reasoning state.
		currentAssistant.FinishThinking()
		toolCalls := currentAssistant.ToolCalls()
		// INFO: we use the parent context here because the genCtx has been cancelled.
		msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
		if createErr != nil {
			return nil, createErr
		}
		for _, tc := range toolCalls {
			// Close tool calls that never received their full input.
			if !tc.Finished {
				tc.Finished = true
				tc.Input = "{}"
				currentAssistant.AddToolCall(tc)
				updateErr := a.messages.Update(ctx, *currentAssistant)
				if updateErr != nil {
					return nil, updateErr
				}
			}

			// Look for an existing tool result for this tool call.
			found := false
			for _, msg := range msgs {
				if msg.Role == message.Tool {
					for _, tr := range msg.ToolResults() {
						if tr.ToolCallID == tc.ID {
							found = true
							break
						}
					}
				}
				if found {
					break
				}
			}
			if found {
				continue
			}
			// No result exists: record an error result describing why.
			content := "There was an error while executing the tool"
			if isCancelErr {
				content = "Tool execution canceled by user"
			} else if isPermissionErr {
				content = "User denied permission"
			}
			toolResult := message.ToolResult{
				ToolCallID: tc.ID,
				Name:       tc.Name,
				Content:    content,
				IsError:    true,
			}
			_, createErr = a.messages.Create(ctx, currentAssistant.SessionID, message.CreateMessageParams{
				Role: message.Tool,
				Parts: []message.ContentPart{
					toolResult,
				},
			})
			if createErr != nil {
				return nil, createErr
			}
		}
		// Record a finish reason on the assistant message that reflects the
		// kind of error.
		var fantasyErr *fantasy.Error
		var providerErr *fantasy.ProviderError
		const defaultTitle = "Provider Error"
		linkStyle := lipgloss.NewStyle().Foreground(charmtone.Guac).Underline(true)
		if isCancelErr {
			currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
		} else if isPermissionErr {
			currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "")
		} else if errors.Is(err, hyper.ErrNoCredits) {
			url := hyper.BaseURL()
			link := linkStyle.Hyperlink(url, "id=hyper").Render(url)
			currentAssistant.AddFinish(message.FinishReasonError, "No credits", "You're out of credits. Add more at "+link)
		} else if errors.As(err, &providerErr) {
			if providerErr.Message == "The requested model is not supported." {
				url := "https://github.com/settings/copilot/features"
				link := linkStyle.Hyperlink(url, "id=copilot").Render(url)
				currentAssistant.AddFinish(
					message.FinishReasonError,
					"Copilot model not enabled",
					fmt.Sprintf("%q is not enabled in Copilot. Go to the following page to enable it. Then, wait 5 minutes before trying again. %s", largeModel.CatwalkCfg.Name, link),
				)
			} else {
				currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message)
			}
		} else if errors.As(err, &fantasyErr) {
			currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(fantasyErr.Title), defaultTitle), fantasyErr.Message)
		} else {
			currentAssistant.AddFinish(message.FinishReasonError, defaultTitle, err.Error())
		}
		// Note: we use the parent context here because the genCtx has been
		// cancelled.
		updateErr := a.messages.Update(ctx, *currentAssistant)
		if updateErr != nil {
			return nil, updateErr
		}
		return nil, err
	}

	if shouldSummarize {
		// Summarize refuses to run while the session is marked busy, so
		// release the active request first.
		a.activeRequests.Del(call.SessionID)
		if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
			return nil, summarizeErr
		}
		// If the agent wasn't done...
		if len(currentAssistant.ToolCalls()) > 0 {
			existing, ok := a.messageQueue.Get(call.SessionID)
			if !ok {
				existing = []SessionAgentCall{}
			}
			call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
			existing = append(existing, call)
			a.messageQueue.Set(call.SessionID, existing)
		}
	}

	// Release active request before processing queued messages.
	a.activeRequests.Del(call.SessionID)
	cancel()

	queuedMessages, ok := a.messageQueue.Get(call.SessionID)
	if !ok || len(queuedMessages) == 0 {
		return result, err
	}
	// There are queued messages restart the loop.
	firstQueuedMessage := queuedMessages[0]
	a.messageQueue.Set(call.SessionID, queuedMessages[1:])
	return a.Run(ctx, firstQueuedMessage)
}
565
566func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
567 if a.IsSessionBusy(sessionID) {
568 return ErrSessionBusy
569 }
570
571 // Copy mutable fields under lock to avoid races with SetModels.
572 largeModel := a.largeModel.Get()
573 systemPromptPrefix := a.systemPromptPrefix.Get()
574
575 currentSession, err := a.sessions.Get(ctx, sessionID)
576 if err != nil {
577 return fmt.Errorf("failed to get session: %w", err)
578 }
579 msgs, err := a.getSessionMessages(ctx, currentSession)
580 if err != nil {
581 return err
582 }
583 if len(msgs) == 0 {
584 // Nothing to summarize.
585 return nil
586 }
587
588 aiMsgs, _ := a.preparePrompt(msgs)
589
590 genCtx, cancel := context.WithCancel(ctx)
591 a.activeRequests.Set(sessionID, cancel)
592 defer a.activeRequests.Del(sessionID)
593 defer cancel()
594
595 agent := fantasy.NewAgent(largeModel.Model,
596 fantasy.WithSystemPrompt(string(summaryPrompt)),
597 fantasy.WithUserAgent("Charm Crush/"+version.Version),
598 )
599 summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
600 Role: message.Assistant,
601 Model: largeModel.Model.Model(),
602 Provider: largeModel.Model.Provider(),
603 IsSummaryMessage: true,
604 })
605 if err != nil {
606 return err
607 }
608
609 summaryPromptText := buildSummaryPrompt(currentSession.Todos)
610
611 resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
612 Prompt: summaryPromptText,
613 Messages: aiMsgs,
614 ProviderOptions: opts,
615 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
616 prepared.Messages = options.Messages
617 if systemPromptPrefix != "" {
618 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(systemPromptPrefix)}, prepared.Messages...)
619 }
620 return callContext, prepared, nil
621 },
622 OnReasoningDelta: func(id string, text string) error {
623 summaryMessage.AppendReasoningContent(text)
624 return a.messages.Update(genCtx, summaryMessage)
625 },
626 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
627 // Handle anthropic signature.
628 if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
629 if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
630 summaryMessage.AppendReasoningSignature(signature.Signature)
631 }
632 }
633 summaryMessage.FinishThinking()
634 return a.messages.Update(genCtx, summaryMessage)
635 },
636 OnTextDelta: func(id, text string) error {
637 summaryMessage.AppendContent(text)
638 return a.messages.Update(genCtx, summaryMessage)
639 },
640 })
641 if err != nil {
642 isCancelErr := errors.Is(err, context.Canceled)
643 if isCancelErr {
644 // User cancelled summarize we need to remove the summary message.
645 deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
646 return deleteErr
647 }
648 return err
649 }
650
651 summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
652 err = a.messages.Update(genCtx, summaryMessage)
653 if err != nil {
654 return err
655 }
656
657 var openrouterCost *float64
658 for _, step := range resp.Steps {
659 stepCost := a.openrouterCost(step.ProviderMetadata)
660 if stepCost != nil {
661 newCost := *stepCost
662 if openrouterCost != nil {
663 newCost += *openrouterCost
664 }
665 openrouterCost = &newCost
666 }
667 }
668
669 a.updateSessionUsage(largeModel, ¤tSession, resp.TotalUsage, openrouterCost)
670
671 // Just in case, get just the last usage info.
672 usage := resp.Response.Usage
673 currentSession.SummaryMessageID = summaryMessage.ID
674 currentSession.CompletionTokens = usage.OutputTokens
675 currentSession.PromptTokens = 0
676 _, err = a.sessions.Save(genCtx, currentSession)
677 return err
678}
679
680func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
681 if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
682 return fantasy.ProviderOptions{}
683 }
684 return fantasy.ProviderOptions{
685 anthropic.Name: &anthropic.ProviderCacheControlOptions{
686 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
687 },
688 bedrock.Name: &anthropic.ProviderCacheControlOptions{
689 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
690 },
691 vercel.Name: &anthropic.ProviderCacheControlOptions{
692 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
693 },
694 }
695}
696
697func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
698 parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
699 var attachmentParts []message.ContentPart
700 for _, attachment := range call.Attachments {
701 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
702 }
703 parts = append(parts, attachmentParts...)
704 msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
705 Role: message.User,
706 Parts: parts,
707 })
708 if err != nil {
709 return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
710 }
711 return msg, nil
712}
713
714func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
715 var history []fantasy.Message
716 if !a.isSubAgent {
717 history = append(history, fantasy.NewUserMessage(
718 fmt.Sprintf("<system_reminder>%s</system_reminder>",
719 `This is a reminder that your todo list is currently empty. DO NOT mention this to the user explicitly because they are already aware.
720If you are working on tasks that would benefit from a todo list please use the "todos" tool to create one.
721If not, please feel free to ignore. Again do not mention this message to the user.`,
722 ),
723 ))
724 }
725 for _, m := range msgs {
726 if len(m.Parts) == 0 {
727 continue
728 }
729 // Assistant message without content or tool calls (cancelled before it
730 // returned anything).
731 if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
732 continue
733 }
734 history = append(history, m.ToAIMessage()...)
735 }
736
737 var files []fantasy.FilePart
738 for _, attachment := range attachments {
739 if attachment.IsText() {
740 continue
741 }
742 files = append(files, fantasy.FilePart{
743 Filename: attachment.FileName,
744 Data: attachment.Content,
745 MediaType: attachment.MimeType,
746 })
747 }
748
749 return history, files
750}
751
752func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
753 msgs, err := a.messages.List(ctx, session.ID)
754 if err != nil {
755 return nil, fmt.Errorf("failed to list messages: %w", err)
756 }
757
758 if session.SummaryMessageID != "" {
759 summaryMsgIndex := -1
760 for i, msg := range msgs {
761 if msg.ID == session.SummaryMessageID {
762 summaryMsgIndex = i
763 break
764 }
765 }
766 if summaryMsgIndex != -1 {
767 msgs = msgs[summaryMsgIndex:]
768 msgs[0].Role = message.User
769 }
770 }
771 return msgs, nil
772}
773
774// generateTitle generates a session titled based on the initial prompt.
775func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, userPrompt string) {
776 if userPrompt == "" {
777 return
778 }
779
780 smallModel := a.smallModel.Get()
781 largeModel := a.largeModel.Get()
782 systemPromptPrefix := a.systemPromptPrefix.Get()
783
784 var maxOutputTokens int64 = 40
785 if smallModel.CatwalkCfg.CanReason {
786 maxOutputTokens = smallModel.CatwalkCfg.DefaultMaxTokens
787 }
788
789 newAgent := func(m fantasy.LanguageModel, p []byte, tok int64) fantasy.Agent {
790 return fantasy.NewAgent(m,
791 fantasy.WithSystemPrompt(string(p)+"\n /no_think"),
792 fantasy.WithMaxOutputTokens(tok),
793 fantasy.WithUserAgent("Charm Crush/"+version.Version),
794 )
795 }
796
797 streamCall := fantasy.AgentStreamCall{
798 Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", userPrompt),
799 PrepareStep: func(callCtx context.Context, opts fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
800 prepared.Messages = opts.Messages
801 if systemPromptPrefix != "" {
802 prepared.Messages = append([]fantasy.Message{
803 fantasy.NewSystemMessage(systemPromptPrefix),
804 }, prepared.Messages...)
805 }
806 return callCtx, prepared, nil
807 },
808 }
809
810 // Use the small model to generate the title.
811 model := smallModel
812 agent := newAgent(model.Model, titlePrompt, maxOutputTokens)
813 resp, err := agent.Stream(ctx, streamCall)
814 if err == nil {
815 // We successfully generated a title with the small model.
816 slog.Debug("Generated title with small model")
817 } else {
818 // It didn't work. Let's try with the big model.
819 slog.Error("Error generating title with small model; trying big model", "err", err)
820 model = largeModel
821 agent = newAgent(model.Model, titlePrompt, maxOutputTokens)
822 resp, err = agent.Stream(ctx, streamCall)
823 if err == nil {
824 slog.Debug("Generated title with large model")
825 } else {
826 // Welp, the large model didn't work either. Use the default
827 // session name and return.
828 slog.Error("Error generating title with large model", "err", err)
829 saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, DefaultSessionName, 0, 0, 0)
830 if saveErr != nil {
831 slog.Error("Failed to save session title and usage", "error", saveErr)
832 }
833 return
834 }
835 }
836
837 if resp == nil {
838 // Actually, we didn't get a response so we can't. Use the default
839 // session name and return.
840 slog.Error("Response is nil; can't generate title")
841 saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, DefaultSessionName, 0, 0, 0)
842 if saveErr != nil {
843 slog.Error("Failed to save session title and usage", "error", saveErr)
844 }
845 return
846 }
847
848 // Clean up title.
849 var title string
850 title = strings.ReplaceAll(resp.Response.Content.Text(), "\n", " ")
851
852 // Remove thinking tags if present.
853 title = thinkTagRegex.ReplaceAllString(title, "")
854
855 title = strings.TrimSpace(title)
856 title = cmp.Or(title, DefaultSessionName)
857
858 // Calculate usage and cost.
859 var openrouterCost *float64
860 for _, step := range resp.Steps {
861 stepCost := a.openrouterCost(step.ProviderMetadata)
862 if stepCost != nil {
863 newCost := *stepCost
864 if openrouterCost != nil {
865 newCost += *openrouterCost
866 }
867 openrouterCost = &newCost
868 }
869 }
870
871 modelConfig := model.CatwalkCfg
872 cost := modelConfig.CostPer1MInCached/1e6*float64(resp.TotalUsage.CacheCreationTokens) +
873 modelConfig.CostPer1MOutCached/1e6*float64(resp.TotalUsage.CacheReadTokens) +
874 modelConfig.CostPer1MIn/1e6*float64(resp.TotalUsage.InputTokens) +
875 modelConfig.CostPer1MOut/1e6*float64(resp.TotalUsage.OutputTokens)
876
877 // Use override cost if available (e.g., from OpenRouter).
878 if openrouterCost != nil {
879 cost = *openrouterCost
880 }
881
882 promptTokens := resp.TotalUsage.InputTokens + resp.TotalUsage.CacheCreationTokens
883 completionTokens := resp.TotalUsage.OutputTokens
884
885 // Atomically update only title and usage fields to avoid overriding other
886 // concurrent session updates.
887 saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, title, promptTokens, completionTokens, cost)
888 if saveErr != nil {
889 slog.Error("Failed to save session title and usage", "error", saveErr)
890 return
891 }
892}
893
894func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
895 openrouterMetadata, ok := metadata[openrouter.Name]
896 if !ok {
897 return nil
898 }
899
900 opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
901 if !ok {
902 return nil
903 }
904 return &opts.Usage.Cost
905}
906
907func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
908 modelConfig := model.CatwalkCfg
909 cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
910 modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
911 modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
912 modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
913
914 a.eventTokensUsed(session.ID, model, usage, cost)
915
916 if overrideCost != nil {
917 session.Cost += *overrideCost
918 } else {
919 session.Cost += cost
920 }
921
922 session.CompletionTokens = usage.OutputTokens
923 session.PromptTokens = usage.InputTokens + usage.CacheReadTokens
924}
925
926func (a *sessionAgent) Cancel(sessionID string) {
927 // Cancel regular requests. Don't use Take() here - we need the entry to
928 // remain in activeRequests so IsBusy() returns true until the goroutine
929 // fully completes (including error handling that may access the DB).
930 // The defer in processRequest will clean up the entry.
931 if cancel, ok := a.activeRequests.Get(sessionID); ok && cancel != nil {
932 slog.Debug("Request cancellation initiated", "session_id", sessionID)
933 cancel()
934 }
935
936 // Also check for summarize requests.
937 if cancel, ok := a.activeRequests.Get(sessionID + "-summarize"); ok && cancel != nil {
938 slog.Debug("Summarize cancellation initiated", "session_id", sessionID)
939 cancel()
940 }
941
942 if a.QueuedPrompts(sessionID) > 0 {
943 slog.Debug("Clearing queued prompts", "session_id", sessionID)
944 a.messageQueue.Del(sessionID)
945 }
946}
947
948func (a *sessionAgent) ClearQueue(sessionID string) {
949 if a.QueuedPrompts(sessionID) > 0 {
950 slog.Debug("Clearing queued prompts", "session_id", sessionID)
951 a.messageQueue.Del(sessionID)
952 }
953}
954
955func (a *sessionAgent) CancelAll() {
956 if !a.IsBusy() {
957 return
958 }
959 for key := range a.activeRequests.Seq2() {
960 a.Cancel(key) // key is sessionID
961 }
962
963 timeout := time.After(5 * time.Second)
964 for a.IsBusy() {
965 select {
966 case <-timeout:
967 return
968 default:
969 time.Sleep(200 * time.Millisecond)
970 }
971 }
972}
973
974func (a *sessionAgent) IsBusy() bool {
975 var busy bool
976 for cancelFunc := range a.activeRequests.Seq() {
977 if cancelFunc != nil {
978 busy = true
979 break
980 }
981 }
982 return busy
983}
984
985func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
986 _, busy := a.activeRequests.Get(sessionID)
987 return busy
988}
989
990func (a *sessionAgent) QueuedPrompts(sessionID string) int {
991 l, ok := a.messageQueue.Get(sessionID)
992 if !ok {
993 return 0
994 }
995 return len(l)
996}
997
998func (a *sessionAgent) QueuedPromptsList(sessionID string) []string {
999 l, ok := a.messageQueue.Get(sessionID)
1000 if !ok {
1001 return nil
1002 }
1003 prompts := make([]string, len(l))
1004 for i, call := range l {
1005 prompts[i] = call.Prompt
1006 }
1007 return prompts
1008}
1009
1010func (a *sessionAgent) SetModels(large Model, small Model) {
1011 a.largeModel.Set(large)
1012 a.smallModel.Set(small)
1013}
1014
1015func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
1016 a.tools.SetSlice(tools)
1017}
1018
1019func (a *sessionAgent) SetSystemPrompt(systemPrompt string) {
1020 a.systemPrompt.Set(systemPrompt)
1021}
1022
1023func (a *sessionAgent) Model() Model {
1024 return a.largeModel.Get()
1025}
1026
1027// convertToToolResult converts a fantasy tool result to a message tool result.
1028func (a *sessionAgent) convertToToolResult(result fantasy.ToolResultContent) message.ToolResult {
1029 baseResult := message.ToolResult{
1030 ToolCallID: result.ToolCallID,
1031 Name: result.ToolName,
1032 Metadata: result.ClientMetadata,
1033 }
1034
1035 switch result.Result.GetType() {
1036 case fantasy.ToolResultContentTypeText:
1037 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result); ok {
1038 baseResult.Content = r.Text
1039 }
1040 case fantasy.ToolResultContentTypeError:
1041 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result); ok {
1042 baseResult.Content = r.Error.Error()
1043 baseResult.IsError = true
1044 }
1045 case fantasy.ToolResultContentTypeMedia:
1046 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](result.Result); ok {
1047 content := r.Text
1048 if content == "" {
1049 content = fmt.Sprintf("Loaded %s content", r.MediaType)
1050 }
1051 baseResult.Content = content
1052 baseResult.Data = r.Data
1053 baseResult.MIMEType = r.MediaType
1054 }
1055 }
1056
1057 return baseResult
1058}
1059
1060// workaroundProviderMediaLimitations converts media content in tool results to
1061// user messages for providers that don't natively support images in tool results.
1062//
1063// Problem: OpenAI, Google, OpenRouter, and other OpenAI-compatible providers
1064// don't support sending images/media in tool result messages - they only accept
1065// text in tool results. However, they DO support images in user messages.
1066//
1067// If we send media in tool results to these providers, the API returns an error.
1068//
1069// Solution: For these providers, we:
1070// 1. Replace the media in the tool result with a text placeholder
1071// 2. Inject a user message immediately after with the image as a file attachment
1072// 3. This maintains the tool execution flow while working around API limitations
1073//
1074// Anthropic and Bedrock support images natively in tool results, so we skip
1075// this workaround for them.
1076//
1077// Example transformation:
1078//
1079// BEFORE: [tool result: image data]
1080// AFTER: [tool result: "Image loaded - see attached"], [user: image attachment]
1081func (a *sessionAgent) workaroundProviderMediaLimitations(messages []fantasy.Message, largeModel Model) []fantasy.Message {
1082 providerSupportsMedia := largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderAnthropic) ||
1083 largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderBedrock)
1084
1085 if providerSupportsMedia {
1086 return messages
1087 }
1088
1089 convertedMessages := make([]fantasy.Message, 0, len(messages))
1090
1091 for _, msg := range messages {
1092 if msg.Role != fantasy.MessageRoleTool {
1093 convertedMessages = append(convertedMessages, msg)
1094 continue
1095 }
1096
1097 textParts := make([]fantasy.MessagePart, 0, len(msg.Content))
1098 var mediaFiles []fantasy.FilePart
1099
1100 for _, part := range msg.Content {
1101 toolResult, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
1102 if !ok {
1103 textParts = append(textParts, part)
1104 continue
1105 }
1106
1107 if media, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](toolResult.Output); ok {
1108 decoded, err := base64.StdEncoding.DecodeString(media.Data)
1109 if err != nil {
1110 slog.Warn("Failed to decode media data", "error", err)
1111 textParts = append(textParts, part)
1112 continue
1113 }
1114
1115 mediaFiles = append(mediaFiles, fantasy.FilePart{
1116 Data: decoded,
1117 MediaType: media.MediaType,
1118 Filename: fmt.Sprintf("tool-result-%s", toolResult.ToolCallID),
1119 })
1120
1121 textParts = append(textParts, fantasy.ToolResultPart{
1122 ToolCallID: toolResult.ToolCallID,
1123 Output: fantasy.ToolResultOutputContentText{
1124 Text: "[Image/media content loaded - see attached file]",
1125 },
1126 ProviderOptions: toolResult.ProviderOptions,
1127 })
1128 } else {
1129 textParts = append(textParts, part)
1130 }
1131 }
1132
1133 convertedMessages = append(convertedMessages, fantasy.Message{
1134 Role: fantasy.MessageRoleTool,
1135 Content: textParts,
1136 })
1137
1138 if len(mediaFiles) > 0 {
1139 convertedMessages = append(convertedMessages, fantasy.NewUserMessage(
1140 "Here is the media content from the tool result:",
1141 mediaFiles...,
1142 ))
1143 }
1144 }
1145
1146 return convertedMessages
1147}
1148
1149// buildSummaryPrompt constructs the prompt text for session summarization.
1150func buildSummaryPrompt(todos []session.Todo) string {
1151 var sb strings.Builder
1152 sb.WriteString("Provide a detailed summary of our conversation above.")
1153 if len(todos) > 0 {
1154 sb.WriteString("\n\n## Current Todo List\n\n")
1155 for _, t := range todos {
1156 fmt.Fprintf(&sb, "- [%s] %s\n", t.Status, t.Content)
1157 }
1158 sb.WriteString("\nInclude these tasks and their statuses in your summary. ")
1159 sb.WriteString("Instruct the resuming assistant to use the `todos` tool to continue tracking progress on these tasks.")
1160 }
1161 return sb.String()
1162}