1// Package agent is the core orchestration layer for Crush AI agents.
2//
3// It provides session-based AI agent functionality for managing
4// conversations, tool execution, and message handling. It coordinates
5// interactions between language models, messages, sessions, and tools while
6// handling features like automatic summarization, queuing, and token
7// management.
8package agent
9
10import (
11 "cmp"
12 "context"
13 _ "embed"
14 "encoding/base64"
15 "errors"
16 "fmt"
17 "log/slog"
18 "os"
19 "regexp"
20 "strconv"
21 "strings"
22 "sync"
23 "time"
24
25 "charm.land/catwalk/pkg/catwalk"
26 "charm.land/fantasy"
27 "charm.land/fantasy/providers/anthropic"
28 "charm.land/fantasy/providers/bedrock"
29 "charm.land/fantasy/providers/google"
30 "charm.land/fantasy/providers/openai"
31 "charm.land/fantasy/providers/openrouter"
32 "charm.land/fantasy/providers/vercel"
33 "charm.land/lipgloss/v2"
34 "github.com/charmbracelet/crush/internal/agent/hyper"
35 "github.com/charmbracelet/crush/internal/agent/tools"
36 "github.com/charmbracelet/crush/internal/agent/tools/mcp"
37 "github.com/charmbracelet/crush/internal/config"
38 "github.com/charmbracelet/crush/internal/csync"
39 "github.com/charmbracelet/crush/internal/message"
40 "github.com/charmbracelet/crush/internal/permission"
41 "github.com/charmbracelet/crush/internal/session"
42 "github.com/charmbracelet/crush/internal/stringext"
43 "github.com/charmbracelet/x/exp/charmtone"
44)
45
const (
	// defaultSessionName is the title stored for a session when title
	// generation fails or produces an empty title.
	defaultSessionName = "Untitled Session"

	// Constants for auto-summarization thresholds.
	//
	// Models whose context window is larger than largeContextWindowThreshold
	// keep a fixed reserve of largeContextWindowBuffer tokens; all other
	// models keep smallContextWindowRatio of their context window. When the
	// remaining tokens fall to or below that reserve the running agent stops
	// so the session can be summarized.
	largeContextWindowThreshold = 200_000
	largeContextWindowBuffer    = 20_000
	smallContextWindowRatio     = 0.2
)
54
// titlePrompt is the embedded system prompt used when generating a session
// title.
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt is the embedded system prompt used when summarizing a
// session.
//
//go:embed templates/summary.md
var summaryPrompt []byte

// thinkTagRegex matches a <think>...</think> span so it can be removed from
// generated titles. The pattern does not cross line breaks; generateTitle
// replaces newlines with spaces before applying it.
var thinkTagRegex = regexp.MustCompile(`<think>.*?</think>`)
63
// SessionAgentCall describes a single prompt submitted to a SessionAgent.
type SessionAgentCall struct {
	// SessionID identifies the session the prompt belongs to. Run rejects an
	// empty value.
	SessionID string
	// Prompt is the user's text. Run rejects an empty prompt unless the call
	// carries a text attachment.
	Prompt string
	// ProviderOptions is forwarded unchanged to the model stream call.
	ProviderOptions fantasy.ProviderOptions
	// Attachments are stored on the user message; non-text attachments are
	// also sent to the model as file parts.
	Attachments []message.Attachment
	// The remaining fields are sampling parameters forwarded unchanged to the
	// model stream call.
	MaxOutputTokens  int64
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
}
76
// SessionAgent runs prompts against a language model on behalf of sessions,
// and exposes control over in-flight and queued requests.
type SessionAgent interface {
	// Run sends the call's prompt to the model for the call's session. If the
	// session already has an active request the call is queued instead and
	// Run returns (nil, nil).
	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
	// SetModels replaces the large and small models.
	SetModels(large Model, small Model)
	// SetTools replaces the tools offered to the model.
	SetTools(tools []fantasy.AgentTool)
	// SetSystemPrompt replaces the system prompt.
	SetSystemPrompt(systemPrompt string)
	// Cancel cancels the active request for the session, if any, and drops
	// the session's queued prompts.
	Cancel(sessionID string)
	// CancelAll cancels every active request and waits a bounded time for
	// them to finish.
	CancelAll()
	// IsSessionBusy reports whether the session has an active request.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any session has an active request.
	IsBusy() bool
	// QueuedPrompts returns the number of prompts queued for the session.
	QueuedPrompts(sessionID string) int
	// QueuedPromptsList returns the text of the prompts queued for the
	// session, in queue order.
	QueuedPromptsList(sessionID string) []string
	// ClearQueue drops the session's queued prompts.
	ClearQueue(sessionID string)
	// Summarize replaces the session's working history with a model-written
	// summary. The string argument is the session ID.
	Summarize(context.Context, string, fantasy.ProviderOptions) error
	// Model returns a model currently configured on the agent.
	// NOTE(review): the implementation is outside this view; presumably the
	// large model — confirm.
	Model() Model
}
92
// Model bundles a language model with the metadata the agent needs about it.
type Model struct {
	// Model is the language model handed to fantasy.NewAgent.
	Model fantasy.LanguageModel
	// CatwalkCfg supplies the model's name, context window, pricing and
	// capability flags used for cost accounting and auto-summarization.
	CatwalkCfg catwalk.Model
	// ModelCfg supplies the model and provider identifiers recorded on
	// assistant messages.
	ModelCfg config.SelectedModel
}
98
// sessionAgent is the SessionAgent implementation returned by
// NewSessionAgent.
type sessionAgent struct {
	// Settings that can be replaced while the agent is in use. Run and
	// Summarize take a snapshot of them at the start of each call.
	largeModel         *csync.Value[Model]
	smallModel         *csync.Value[Model]
	systemPromptPrefix *csync.Value[string]
	systemPrompt       *csync.Value[string]
	tools              *csync.Slice[fantasy.AgentTool]

	// isSubAgent suppresses the todo-list reminder that is otherwise
	// prepended to the history sent to the model.
	isSubAgent bool
	sessions   session.Service
	messages   message.Service
	// disableAutoSummarize turns off the context-window stop condition in Run.
	disableAutoSummarize bool
	isYolo               bool

	// messageQueue holds, per session ID, the calls received while that
	// session was busy.
	messageQueue *csync.Map[string, []SessionAgentCall]
	// activeRequests holds, per session ID, the cancel function of the
	// in-flight request. Presence of a key is what marks a session as busy.
	activeRequests *csync.Map[string, context.CancelFunc]
}
115
// SessionAgentOptions carries the initial configuration and dependencies
// passed to NewSessionAgent.
type SessionAgentOptions struct {
	// LargeModel is used for conversations and summaries; SmallModel is tried
	// first for title generation.
	LargeModel Model
	SmallModel Model
	// SystemPromptPrefix, when non-empty, is sent as an extra system message
	// ahead of all other messages.
	SystemPromptPrefix string
	SystemPrompt       string
	// IsSubAgent suppresses the todo-list reminder in the prepared history.
	IsSubAgent bool
	// DisableAutoSummarize turns off automatic summarization when the context
	// window is nearly full.
	DisableAutoSummarize bool
	IsYolo               bool
	// Sessions and Messages are the persistence services for sessions and
	// their messages.
	Sessions session.Service
	Messages message.Service
	// Tools are the tools offered to the model.
	Tools []fantasy.AgentTool
}
128
129func NewSessionAgent(
130 opts SessionAgentOptions,
131) SessionAgent {
132 return &sessionAgent{
133 largeModel: csync.NewValue(opts.LargeModel),
134 smallModel: csync.NewValue(opts.SmallModel),
135 systemPromptPrefix: csync.NewValue(opts.SystemPromptPrefix),
136 systemPrompt: csync.NewValue(opts.SystemPrompt),
137 isSubAgent: opts.IsSubAgent,
138 sessions: opts.Sessions,
139 messages: opts.Messages,
140 disableAutoSummarize: opts.DisableAutoSummarize,
141 tools: csync.NewSliceFrom(opts.Tools),
142 isYolo: opts.IsYolo,
143 messageQueue: csync.NewMap[string, []SessionAgentCall](),
144 activeRequests: csync.NewMap[string, context.CancelFunc](),
145 }
146}
147
// Run sends call.Prompt (plus attachments) to the large model for the given
// session and streams the response into the message store.
//
// It returns ErrEmptyPrompt when there is neither a prompt nor a text
// attachment, and ErrSessionMissing when call.SessionID is empty. If the
// session already has an active request, the call is appended to the
// session's queue and Run returns (nil, nil).
//
// Otherwise Run snapshots the current tools, model and prompts, stores the
// user message, registers a cancel function for the session and starts the
// stream. Streaming callbacks persist assistant content, tool calls, tool
// results and session usage as they arrive. On error, unfinished tool calls
// are closed out with synthetic error results and a finish reason is recorded
// on the assistant message. On success, the session is summarized when the
// context-window stop condition fired, and any queued calls are then
// processed by calling Run again.
func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
	if call.Prompt == "" && !message.ContainsTextAttachment(call.Attachments) {
		return nil, ErrEmptyPrompt
	}
	if call.SessionID == "" {
		return nil, ErrSessionMissing
	}

	// Queue the message if busy
	if a.IsSessionBusy(call.SessionID) {
		existing, ok := a.messageQueue.Get(call.SessionID)
		if !ok {
			existing = []SessionAgentCall{}
		}
		existing = append(existing, call)
		a.messageQueue.Set(call.SessionID, existing)
		return nil, nil
	}

	// Copy mutable fields under lock to avoid races with SetTools/SetModels.
	agentTools := a.tools.Copy()
	largeModel := a.largeModel.Get()
	systemPrompt := a.systemPrompt.Get()
	promptPrefix := a.systemPromptPrefix.Get()
	var instructions strings.Builder

	// Collect the instructions advertised by every connected MCP server.
	for _, server := range mcp.GetStates() {
		if server.State != mcp.StateConnected {
			continue
		}
		if s := server.Client.InitializeResult().Instructions; s != "" {
			instructions.WriteString(s)
			instructions.WriteString("\n\n")
		}
	}

	if s := instructions.String(); s != "" {
		systemPrompt += "\n\n<mcp-instructions>\n" + s + "\n</mcp-instructions>"
	}

	if len(agentTools) > 0 {
		// Add Anthropic caching to the last tool.
		agentTools[len(agentTools)-1].SetProviderOptions(a.getCacheControlOptions())
	}

	agent := fantasy.NewAgent(
		largeModel.Model,
		fantasy.WithSystemPrompt(systemPrompt),
		fantasy.WithTools(agentTools...),
	)

	sessionLock := sync.Mutex{}
	currentSession, err := a.sessions.Get(ctx, call.SessionID)
	if err != nil {
		return nil, fmt.Errorf("failed to get session: %w", err)
	}

	msgs, err := a.getSessionMessages(ctx, currentSession)
	if err != nil {
		return nil, fmt.Errorf("failed to get session messages: %w", err)
	}

	var wg sync.WaitGroup
	// Generate title if first message.
	if len(msgs) == 0 {
		titleCtx := ctx // Copy to avoid race with ctx reassignment below.
		wg.Go(func() {
			a.generateTitle(titleCtx, call.SessionID, call.Prompt)
		})
	}
	// Do not return before title generation has finished.
	defer wg.Wait()

	// Add the user message to the session.
	_, err = a.createUserMessage(ctx, call)
	if err != nil {
		return nil, err
	}

	// Add the session to the context.
	ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)

	// Registering the cancel function is what marks the session as busy.
	genCtx, cancel := context.WithCancel(ctx)
	a.activeRequests.Set(call.SessionID, cancel)

	defer cancel()
	defer a.activeRequests.Del(call.SessionID)

	history, files := a.preparePrompt(msgs, call.Attachments...)

	startTime := time.Now()
	a.eventPromptSent(call.SessionID)

	// currentAssistant is (re)assigned by PrepareStep at the start of every
	// step and mutated by the streaming callbacks below.
	var currentAssistant *message.Message
	var shouldSummarize bool
	result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
		Prompt:           message.PromptWithTextAttachments(call.Prompt, call.Attachments),
		Files:            files,
		Messages:         history,
		ProviderOptions:  call.ProviderOptions,
		MaxOutputTokens:  &call.MaxOutputTokens,
		TopP:             call.TopP,
		Temperature:      call.Temperature,
		PresencePenalty:  call.PresencePenalty,
		TopK:             call.TopK,
		FrequencyPenalty: call.FrequencyPenalty,
		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
			// Start from the step's messages with provider options cleared;
			// cache control is re-applied further down.
			prepared.Messages = options.Messages
			for i := range prepared.Messages {
				prepared.Messages[i].ProviderOptions = nil
			}

			// Drain prompts queued while this request was running and append
			// them to the conversation as user messages.
			queuedCalls, _ := a.messageQueue.Get(call.SessionID)
			a.messageQueue.Del(call.SessionID)
			for _, queued := range queuedCalls {
				userMessage, createErr := a.createUserMessage(callContext, queued)
				if createErr != nil {
					return callContext, prepared, createErr
				}
				prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
			}

			prepared.Messages = a.workaroundProviderMediaLimitations(prepared.Messages, largeModel)

			lastSystemRoleInx := 0
			systemMessageUpdated := false
			for i, msg := range prepared.Messages {
				// Only add cache control to the last message.
				if msg.Role == fantasy.MessageRoleSystem {
					lastSystemRoleInx = i
				} else if !systemMessageUpdated {
					prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
					systemMessageUpdated = true
				}
				// Then add cache control to the last 2 messages.
				if i > len(prepared.Messages)-3 {
					prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
				}
			}

			if promptPrefix != "" {
				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(promptPrefix)}, prepared.Messages...)
			}

			// Create the empty assistant message that the streaming callbacks
			// fill in for this step.
			var assistantMsg message.Message
			assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
				Role:     message.Assistant,
				Parts:    []message.ContentPart{},
				Model:    largeModel.ModelCfg.Model,
				Provider: largeModel.ModelCfg.Provider,
			})
			if err != nil {
				return callContext, prepared, err
			}
			callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
			callContext = context.WithValue(callContext, tools.SupportsImagesContextKey, largeModel.CatwalkCfg.SupportsImages)
			callContext = context.WithValue(callContext, tools.ModelNameContextKey, largeModel.CatwalkCfg.Name)
			currentAssistant = &assistantMsg
			return callContext, prepared, err
		},
		OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
			currentAssistant.AppendReasoningContent(reasoning.Text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnReasoningDelta: func(id string, text string) error {
			currentAssistant.AppendReasoningContent(text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
			// handle anthropic signature
			if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
				if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
					currentAssistant.AppendReasoningSignature(reasoning.Signature)
				}
			}
			if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
				if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
					currentAssistant.AppendThoughtSignature(reasoning.Signature, reasoning.ToolID)
				}
			}
			if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
				if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
					currentAssistant.SetReasoningResponsesData(reasoning)
				}
			}
			currentAssistant.FinishThinking()
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnTextDelta: func(id string, text string) error {
			// Strip leading newline from initial text content. This is
			// particularly important in non-interactive mode where leading
			// newlines are very visible.
			if len(currentAssistant.Parts) == 0 {
				text = strings.TrimPrefix(text, "\n")
			}

			currentAssistant.AppendContent(text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnToolInputStart: func(id string, toolName string) error {
			toolCall := message.ToolCall{
				ID:               id,
				Name:             toolName,
				ProviderExecuted: false,
				Finished:         false,
			}
			currentAssistant.AddToolCall(toolCall)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnRetry: func(err *fantasy.ProviderError, delay time.Duration) {
			// TODO: implement
		},
		OnToolCall: func(tc fantasy.ToolCallContent) error {
			toolCall := message.ToolCall{
				ID:               tc.ToolCallID,
				Name:             tc.ToolName,
				Input:            tc.Input,
				ProviderExecuted: false,
				Finished:         true,
			}
			currentAssistant.AddToolCall(toolCall)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnToolResult: func(result fantasy.ToolResultContent) error {
			toolResult := a.convertToToolResult(result)
			_, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
				Role: message.Tool,
				Parts: []message.ContentPart{
					toolResult,
				},
			})
			return createMsgErr
		},
		OnStepFinish: func(stepResult fantasy.StepResult) error {
			finishReason := message.FinishReasonUnknown
			switch stepResult.FinishReason {
			case fantasy.FinishReasonLength:
				finishReason = message.FinishReasonMaxTokens
			case fantasy.FinishReasonStop:
				finishReason = message.FinishReasonEndTurn
			case fantasy.FinishReasonToolCalls:
				finishReason = message.FinishReasonToolUse
			}
			currentAssistant.AddFinish(finishReason, "", "")
			sessionLock.Lock()
			defer sessionLock.Unlock()

			// Re-read the session before saving so that only the usage fields
			// are changed on top of its latest stored state.
			updatedSession, getSessionErr := a.sessions.Get(ctx, call.SessionID)
			if getSessionErr != nil {
				return getSessionErr
			}
			a.updateSessionUsage(largeModel, &updatedSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
			_, sessionErr := a.sessions.Save(ctx, updatedSession)
			if sessionErr != nil {
				return sessionErr
			}
			currentSession = updatedSession
			return a.messages.Update(genCtx, *currentAssistant)
		},
		StopWhen: []fantasy.StopCondition{
			// Stop when the remaining context window falls to or below the
			// reserve, unless auto-summarization is disabled.
			func(_ []fantasy.StepResult) bool {
				cw := int64(largeModel.CatwalkCfg.ContextWindow)
				tokens := currentSession.CompletionTokens + currentSession.PromptTokens
				remaining := cw - tokens
				var threshold int64
				if cw > largeContextWindowThreshold {
					threshold = largeContextWindowBuffer
				} else {
					threshold = int64(float64(cw) * smallContextWindowRatio)
				}
				if (remaining <= threshold) && !a.disableAutoSummarize {
					shouldSummarize = true
					return true
				}
				return false
			},
			func(steps []fantasy.StepResult) bool {
				return hasRepeatedToolCalls(steps, loopDetectionWindowSize, loopDetectionMaxRepeats)
			},
		},
	})

	a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))

	if err != nil {
		isCancelErr := errors.Is(err, context.Canceled)
		isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
		// No assistant message was created, so there is nothing to finalize.
		if currentAssistant == nil {
			return result, err
		}
		// Ensure we finish thinking on error to close the reasoning state.
		currentAssistant.FinishThinking()
		toolCalls := currentAssistant.ToolCalls()
		// INFO: we use the parent context here because the genCtx has been cancelled.
		msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
		if createErr != nil {
			return nil, createErr
		}
		// Make sure every tool call is marked finished and has a stored tool
		// result, adding a synthetic error result where none exists.
		for _, tc := range toolCalls {
			if !tc.Finished {
				tc.Finished = true
				tc.Input = "{}"
				currentAssistant.AddToolCall(tc)
				updateErr := a.messages.Update(ctx, *currentAssistant)
				if updateErr != nil {
					return nil, updateErr
				}
			}

			found := false
			for _, msg := range msgs {
				if msg.Role == message.Tool {
					for _, tr := range msg.ToolResults() {
						if tr.ToolCallID == tc.ID {
							found = true
							break
						}
					}
				}
				if found {
					break
				}
			}
			if found {
				continue
			}
			content := "There was an error while executing the tool"
			if isCancelErr {
				content = "Tool execution canceled by user"
			} else if isPermissionErr {
				content = "User denied permission"
			}
			toolResult := message.ToolResult{
				ToolCallID: tc.ID,
				Name:       tc.Name,
				Content:    content,
				IsError:    true,
			}
			_, createErr = a.messages.Create(ctx, currentAssistant.SessionID, message.CreateMessageParams{
				Role: message.Tool,
				Parts: []message.ContentPart{
					toolResult,
				},
			})
			if createErr != nil {
				return nil, createErr
			}
		}
		// Record a finish reason on the assistant message that matches the
		// kind of error.
		var fantasyErr *fantasy.Error
		var providerErr *fantasy.ProviderError
		const defaultTitle = "Provider Error"
		linkStyle := lipgloss.NewStyle().Foreground(charmtone.Guac).Underline(true)
		if isCancelErr {
			currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
		} else if isPermissionErr {
			currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "")
		} else if errors.Is(err, hyper.ErrNoCredits) {
			url := hyper.BaseURL()
			link := linkStyle.Hyperlink(url, "id=hyper").Render(url)
			currentAssistant.AddFinish(message.FinishReasonError, "No credits", "You're out of credits. Add more at "+link)
		} else if errors.As(err, &providerErr) {
			if providerErr.Message == "The requested model is not supported." {
				url := "https://github.com/settings/copilot/features"
				link := linkStyle.Hyperlink(url, "id=copilot").Render(url)
				currentAssistant.AddFinish(
					message.FinishReasonError,
					"Copilot model not enabled",
					fmt.Sprintf("%q is not enabled in Copilot. Go to the following page to enable it. Then, wait 5 minutes before trying again. %s", largeModel.CatwalkCfg.Name, link),
				)
			} else {
				currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message)
			}
		} else if errors.As(err, &fantasyErr) {
			currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(fantasyErr.Title), defaultTitle), fantasyErr.Message)
		} else {
			currentAssistant.AddFinish(message.FinishReasonError, defaultTitle, err.Error())
		}
		// Note: we use the parent context here because the genCtx has been
		// cancelled.
		updateErr := a.messages.Update(ctx, *currentAssistant)
		if updateErr != nil {
			return nil, updateErr
		}
		return nil, err
	}

	if shouldSummarize {
		// Summarize refuses to run while the session is marked busy, so the
		// active request is released first.
		a.activeRequests.Del(call.SessionID)
		if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
			return nil, summarizeErr
		}
		// If the agent wasn't done...
		if len(currentAssistant.ToolCalls()) > 0 {
			existing, ok := a.messageQueue.Get(call.SessionID)
			if !ok {
				existing = []SessionAgentCall{}
			}
			call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
			existing = append(existing, call)
			a.messageQueue.Set(call.SessionID, existing)
		}
	}

	// Release active request before processing queued messages.
	a.activeRequests.Del(call.SessionID)
	cancel()

	queuedMessages, ok := a.messageQueue.Get(call.SessionID)
	if !ok || len(queuedMessages) == 0 {
		return result, err
	}
	// There are queued messages restart the loop.
	firstQueuedMessage := queuedMessages[0]
	a.messageQueue.Set(call.SessionID, queuedMessages[1:])
	return a.Run(ctx, firstQueuedMessage)
}
563
564func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
565 if a.IsSessionBusy(sessionID) {
566 return ErrSessionBusy
567 }
568
569 // Copy mutable fields under lock to avoid races with SetModels.
570 largeModel := a.largeModel.Get()
571 systemPromptPrefix := a.systemPromptPrefix.Get()
572
573 currentSession, err := a.sessions.Get(ctx, sessionID)
574 if err != nil {
575 return fmt.Errorf("failed to get session: %w", err)
576 }
577 msgs, err := a.getSessionMessages(ctx, currentSession)
578 if err != nil {
579 return err
580 }
581 if len(msgs) == 0 {
582 // Nothing to summarize.
583 return nil
584 }
585
586 aiMsgs, _ := a.preparePrompt(msgs)
587
588 genCtx, cancel := context.WithCancel(ctx)
589 a.activeRequests.Set(sessionID, cancel)
590 defer a.activeRequests.Del(sessionID)
591 defer cancel()
592
593 agent := fantasy.NewAgent(largeModel.Model,
594 fantasy.WithSystemPrompt(string(summaryPrompt)),
595 )
596 summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
597 Role: message.Assistant,
598 Model: largeModel.Model.Model(),
599 Provider: largeModel.Model.Provider(),
600 IsSummaryMessage: true,
601 })
602 if err != nil {
603 return err
604 }
605
606 summaryPromptText := buildSummaryPrompt(currentSession.Todos)
607
608 resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
609 Prompt: summaryPromptText,
610 Messages: aiMsgs,
611 ProviderOptions: opts,
612 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
613 prepared.Messages = options.Messages
614 if systemPromptPrefix != "" {
615 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(systemPromptPrefix)}, prepared.Messages...)
616 }
617 return callContext, prepared, nil
618 },
619 OnReasoningDelta: func(id string, text string) error {
620 summaryMessage.AppendReasoningContent(text)
621 return a.messages.Update(genCtx, summaryMessage)
622 },
623 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
624 // Handle anthropic signature.
625 if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
626 if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
627 summaryMessage.AppendReasoningSignature(signature.Signature)
628 }
629 }
630 summaryMessage.FinishThinking()
631 return a.messages.Update(genCtx, summaryMessage)
632 },
633 OnTextDelta: func(id, text string) error {
634 summaryMessage.AppendContent(text)
635 return a.messages.Update(genCtx, summaryMessage)
636 },
637 })
638 if err != nil {
639 isCancelErr := errors.Is(err, context.Canceled)
640 if isCancelErr {
641 // User cancelled summarize we need to remove the summary message.
642 deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
643 return deleteErr
644 }
645 return err
646 }
647
648 summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
649 err = a.messages.Update(genCtx, summaryMessage)
650 if err != nil {
651 return err
652 }
653
654 var openrouterCost *float64
655 for _, step := range resp.Steps {
656 stepCost := a.openrouterCost(step.ProviderMetadata)
657 if stepCost != nil {
658 newCost := *stepCost
659 if openrouterCost != nil {
660 newCost += *openrouterCost
661 }
662 openrouterCost = &newCost
663 }
664 }
665
666 a.updateSessionUsage(largeModel, ¤tSession, resp.TotalUsage, openrouterCost)
667
668 // Just in case, get just the last usage info.
669 usage := resp.Response.Usage
670 currentSession.SummaryMessageID = summaryMessage.ID
671 currentSession.CompletionTokens = usage.OutputTokens
672 currentSession.PromptTokens = 0
673 _, err = a.sessions.Save(genCtx, currentSession)
674 return err
675}
676
677func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
678 if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
679 return fantasy.ProviderOptions{}
680 }
681 return fantasy.ProviderOptions{
682 anthropic.Name: &anthropic.ProviderCacheControlOptions{
683 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
684 },
685 bedrock.Name: &anthropic.ProviderCacheControlOptions{
686 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
687 },
688 vercel.Name: &anthropic.ProviderCacheControlOptions{
689 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
690 },
691 }
692}
693
694func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
695 parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
696 var attachmentParts []message.ContentPart
697 for _, attachment := range call.Attachments {
698 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
699 }
700 parts = append(parts, attachmentParts...)
701 msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
702 Role: message.User,
703 Parts: parts,
704 })
705 if err != nil {
706 return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
707 }
708 return msg, nil
709}
710
711func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
712 var history []fantasy.Message
713 if !a.isSubAgent {
714 history = append(history, fantasy.NewUserMessage(
715 fmt.Sprintf("<system_reminder>%s</system_reminder>",
716 `This is a reminder that your todo list is currently empty. DO NOT mention this to the user explicitly because they are already aware.
717If you are working on tasks that would benefit from a todo list please use the "todos" tool to create one.
718If not, please feel free to ignore. Again do not mention this message to the user.`,
719 ),
720 ))
721 }
722 for _, m := range msgs {
723 if len(m.Parts) == 0 {
724 continue
725 }
726 // Assistant message without content or tool calls (cancelled before it
727 // returned anything).
728 if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
729 continue
730 }
731 history = append(history, m.ToAIMessage()...)
732 }
733
734 var files []fantasy.FilePart
735 for _, attachment := range attachments {
736 if attachment.IsText() {
737 continue
738 }
739 files = append(files, fantasy.FilePart{
740 Filename: attachment.FileName,
741 Data: attachment.Content,
742 MediaType: attachment.MimeType,
743 })
744 }
745
746 return history, files
747}
748
749func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
750 msgs, err := a.messages.List(ctx, session.ID)
751 if err != nil {
752 return nil, fmt.Errorf("failed to list messages: %w", err)
753 }
754
755 if session.SummaryMessageID != "" {
756 summaryMsgIndex := -1
757 for i, msg := range msgs {
758 if msg.ID == session.SummaryMessageID {
759 summaryMsgIndex = i
760 break
761 }
762 }
763 if summaryMsgIndex != -1 {
764 msgs = msgs[summaryMsgIndex:]
765 msgs[0].Role = message.User
766 }
767 }
768 return msgs, nil
769}
770
771// generateTitle generates a session titled based on the initial prompt.
772func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, userPrompt string) {
773 if userPrompt == "" {
774 return
775 }
776
777 smallModel := a.smallModel.Get()
778 largeModel := a.largeModel.Get()
779 systemPromptPrefix := a.systemPromptPrefix.Get()
780
781 var maxOutputTokens int64 = 40
782 if smallModel.CatwalkCfg.CanReason {
783 maxOutputTokens = smallModel.CatwalkCfg.DefaultMaxTokens
784 }
785
786 newAgent := func(m fantasy.LanguageModel, p []byte, tok int64) fantasy.Agent {
787 return fantasy.NewAgent(m,
788 fantasy.WithSystemPrompt(string(p)+"\n /no_think"),
789 fantasy.WithMaxOutputTokens(tok),
790 )
791 }
792
793 streamCall := fantasy.AgentStreamCall{
794 Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", userPrompt),
795 PrepareStep: func(callCtx context.Context, opts fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
796 prepared.Messages = opts.Messages
797 if systemPromptPrefix != "" {
798 prepared.Messages = append([]fantasy.Message{
799 fantasy.NewSystemMessage(systemPromptPrefix),
800 }, prepared.Messages...)
801 }
802 return callCtx, prepared, nil
803 },
804 }
805
806 // Use the small model to generate the title.
807 model := smallModel
808 agent := newAgent(model.Model, titlePrompt, maxOutputTokens)
809 resp, err := agent.Stream(ctx, streamCall)
810 if err == nil {
811 // We successfully generated a title with the small model.
812 slog.Debug("Generated title with small model")
813 } else {
814 // It didn't work. Let's try with the big model.
815 slog.Error("Error generating title with small model; trying big model", "err", err)
816 model = largeModel
817 agent = newAgent(model.Model, titlePrompt, maxOutputTokens)
818 resp, err = agent.Stream(ctx, streamCall)
819 if err == nil {
820 slog.Debug("Generated title with large model")
821 } else {
822 // Welp, the large model didn't work either. Use the default
823 // session name and return.
824 slog.Error("Error generating title with large model", "err", err)
825 saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, defaultSessionName, 0, 0, 0)
826 if saveErr != nil {
827 slog.Error("Failed to save session title and usage", "error", saveErr)
828 }
829 return
830 }
831 }
832
833 if resp == nil {
834 // Actually, we didn't get a response so we can't. Use the default
835 // session name and return.
836 slog.Error("Response is nil; can't generate title")
837 saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, defaultSessionName, 0, 0, 0)
838 if saveErr != nil {
839 slog.Error("Failed to save session title and usage", "error", saveErr)
840 }
841 return
842 }
843
844 // Clean up title.
845 var title string
846 title = strings.ReplaceAll(resp.Response.Content.Text(), "\n", " ")
847
848 // Remove thinking tags if present.
849 title = thinkTagRegex.ReplaceAllString(title, "")
850
851 title = strings.TrimSpace(title)
852 title = cmp.Or(title, defaultSessionName)
853
854 // Calculate usage and cost.
855 var openrouterCost *float64
856 for _, step := range resp.Steps {
857 stepCost := a.openrouterCost(step.ProviderMetadata)
858 if stepCost != nil {
859 newCost := *stepCost
860 if openrouterCost != nil {
861 newCost += *openrouterCost
862 }
863 openrouterCost = &newCost
864 }
865 }
866
867 modelConfig := model.CatwalkCfg
868 cost := modelConfig.CostPer1MInCached/1e6*float64(resp.TotalUsage.CacheCreationTokens) +
869 modelConfig.CostPer1MOutCached/1e6*float64(resp.TotalUsage.CacheReadTokens) +
870 modelConfig.CostPer1MIn/1e6*float64(resp.TotalUsage.InputTokens) +
871 modelConfig.CostPer1MOut/1e6*float64(resp.TotalUsage.OutputTokens)
872
873 // Use override cost if available (e.g., from OpenRouter).
874 if openrouterCost != nil {
875 cost = *openrouterCost
876 }
877
878 promptTokens := resp.TotalUsage.InputTokens + resp.TotalUsage.CacheCreationTokens
879 completionTokens := resp.TotalUsage.OutputTokens
880
881 // Atomically update only title and usage fields to avoid overriding other
882 // concurrent session updates.
883 saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, title, promptTokens, completionTokens, cost)
884 if saveErr != nil {
885 slog.Error("Failed to save session title and usage", "error", saveErr)
886 return
887 }
888}
889
890func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
891 openrouterMetadata, ok := metadata[openrouter.Name]
892 if !ok {
893 return nil
894 }
895
896 opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
897 if !ok {
898 return nil
899 }
900 return &opts.Usage.Cost
901}
902
903func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
904 modelConfig := model.CatwalkCfg
905 cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
906 modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
907 modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
908 modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
909
910 a.eventTokensUsed(session.ID, model, usage, cost)
911
912 if overrideCost != nil {
913 session.Cost += *overrideCost
914 } else {
915 session.Cost += cost
916 }
917
918 session.CompletionTokens = usage.OutputTokens
919 session.PromptTokens = usage.InputTokens + usage.CacheReadTokens
920}
921
922func (a *sessionAgent) Cancel(sessionID string) {
923 // Cancel regular requests. Don't use Take() here - we need the entry to
924 // remain in activeRequests so IsBusy() returns true until the goroutine
925 // fully completes (including error handling that may access the DB).
926 // The defer in processRequest will clean up the entry.
927 if cancel, ok := a.activeRequests.Get(sessionID); ok && cancel != nil {
928 slog.Debug("Request cancellation initiated", "session_id", sessionID)
929 cancel()
930 }
931
932 // Also check for summarize requests.
933 if cancel, ok := a.activeRequests.Get(sessionID + "-summarize"); ok && cancel != nil {
934 slog.Debug("Summarize cancellation initiated", "session_id", sessionID)
935 cancel()
936 }
937
938 if a.QueuedPrompts(sessionID) > 0 {
939 slog.Debug("Clearing queued prompts", "session_id", sessionID)
940 a.messageQueue.Del(sessionID)
941 }
942}
943
944func (a *sessionAgent) ClearQueue(sessionID string) {
945 if a.QueuedPrompts(sessionID) > 0 {
946 slog.Debug("Clearing queued prompts", "session_id", sessionID)
947 a.messageQueue.Del(sessionID)
948 }
949}
950
951func (a *sessionAgent) CancelAll() {
952 if !a.IsBusy() {
953 return
954 }
955 for key := range a.activeRequests.Seq2() {
956 a.Cancel(key) // key is sessionID
957 }
958
959 timeout := time.After(5 * time.Second)
960 for a.IsBusy() {
961 select {
962 case <-timeout:
963 return
964 default:
965 time.Sleep(200 * time.Millisecond)
966 }
967 }
968}
969
970func (a *sessionAgent) IsBusy() bool {
971 var busy bool
972 for cancelFunc := range a.activeRequests.Seq() {
973 if cancelFunc != nil {
974 busy = true
975 break
976 }
977 }
978 return busy
979}
980
981func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
982 _, busy := a.activeRequests.Get(sessionID)
983 return busy
984}
985
986func (a *sessionAgent) QueuedPrompts(sessionID string) int {
987 l, ok := a.messageQueue.Get(sessionID)
988 if !ok {
989 return 0
990 }
991 return len(l)
992}
993
994func (a *sessionAgent) QueuedPromptsList(sessionID string) []string {
995 l, ok := a.messageQueue.Get(sessionID)
996 if !ok {
997 return nil
998 }
999 prompts := make([]string, len(l))
1000 for i, call := range l {
1001 prompts[i] = call.Prompt
1002 }
1003 return prompts
1004}
1005
1006func (a *sessionAgent) SetModels(large Model, small Model) {
1007 a.largeModel.Set(large)
1008 a.smallModel.Set(small)
1009}
1010
1011func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
1012 a.tools.SetSlice(tools)
1013}
1014
1015func (a *sessionAgent) SetSystemPrompt(systemPrompt string) {
1016 a.systemPrompt.Set(systemPrompt)
1017}
1018
1019func (a *sessionAgent) Model() Model {
1020 return a.largeModel.Get()
1021}
1022
1023// convertToToolResult converts a fantasy tool result to a message tool result.
1024func (a *sessionAgent) convertToToolResult(result fantasy.ToolResultContent) message.ToolResult {
1025 baseResult := message.ToolResult{
1026 ToolCallID: result.ToolCallID,
1027 Name: result.ToolName,
1028 Metadata: result.ClientMetadata,
1029 }
1030
1031 switch result.Result.GetType() {
1032 case fantasy.ToolResultContentTypeText:
1033 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result); ok {
1034 baseResult.Content = r.Text
1035 }
1036 case fantasy.ToolResultContentTypeError:
1037 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result); ok {
1038 baseResult.Content = r.Error.Error()
1039 baseResult.IsError = true
1040 }
1041 case fantasy.ToolResultContentTypeMedia:
1042 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](result.Result); ok {
1043 content := r.Text
1044 if content == "" {
1045 content = fmt.Sprintf("Loaded %s content", r.MediaType)
1046 }
1047 baseResult.Content = content
1048 baseResult.Data = r.Data
1049 baseResult.MIMEType = r.MediaType
1050 }
1051 }
1052
1053 return baseResult
1054}
1055
1056// workaroundProviderMediaLimitations converts media content in tool results to
1057// user messages for providers that don't natively support images in tool results.
1058//
1059// Problem: OpenAI, Google, OpenRouter, and other OpenAI-compatible providers
1060// don't support sending images/media in tool result messages - they only accept
1061// text in tool results. However, they DO support images in user messages.
1062//
1063// If we send media in tool results to these providers, the API returns an error.
1064//
1065// Solution: For these providers, we:
1066// 1. Replace the media in the tool result with a text placeholder
1067// 2. Inject a user message immediately after with the image as a file attachment
1068// 3. This maintains the tool execution flow while working around API limitations
1069//
1070// Anthropic and Bedrock support images natively in tool results, so we skip
1071// this workaround for them.
1072//
1073// Example transformation:
1074//
1075// BEFORE: [tool result: image data]
1076// AFTER: [tool result: "Image loaded - see attached"], [user: image attachment]
1077func (a *sessionAgent) workaroundProviderMediaLimitations(messages []fantasy.Message, largeModel Model) []fantasy.Message {
1078 providerSupportsMedia := largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderAnthropic) ||
1079 largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderBedrock)
1080
1081 if providerSupportsMedia {
1082 return messages
1083 }
1084
1085 convertedMessages := make([]fantasy.Message, 0, len(messages))
1086
1087 for _, msg := range messages {
1088 if msg.Role != fantasy.MessageRoleTool {
1089 convertedMessages = append(convertedMessages, msg)
1090 continue
1091 }
1092
1093 textParts := make([]fantasy.MessagePart, 0, len(msg.Content))
1094 var mediaFiles []fantasy.FilePart
1095
1096 for _, part := range msg.Content {
1097 toolResult, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
1098 if !ok {
1099 textParts = append(textParts, part)
1100 continue
1101 }
1102
1103 if media, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](toolResult.Output); ok {
1104 decoded, err := base64.StdEncoding.DecodeString(media.Data)
1105 if err != nil {
1106 slog.Warn("Failed to decode media data", "error", err)
1107 textParts = append(textParts, part)
1108 continue
1109 }
1110
1111 mediaFiles = append(mediaFiles, fantasy.FilePart{
1112 Data: decoded,
1113 MediaType: media.MediaType,
1114 Filename: fmt.Sprintf("tool-result-%s", toolResult.ToolCallID),
1115 })
1116
1117 textParts = append(textParts, fantasy.ToolResultPart{
1118 ToolCallID: toolResult.ToolCallID,
1119 Output: fantasy.ToolResultOutputContentText{
1120 Text: "[Image/media content loaded - see attached file]",
1121 },
1122 ProviderOptions: toolResult.ProviderOptions,
1123 })
1124 } else {
1125 textParts = append(textParts, part)
1126 }
1127 }
1128
1129 convertedMessages = append(convertedMessages, fantasy.Message{
1130 Role: fantasy.MessageRoleTool,
1131 Content: textParts,
1132 })
1133
1134 if len(mediaFiles) > 0 {
1135 convertedMessages = append(convertedMessages, fantasy.NewUserMessage(
1136 "Here is the media content from the tool result:",
1137 mediaFiles...,
1138 ))
1139 }
1140 }
1141
1142 return convertedMessages
1143}
1144
1145// buildSummaryPrompt constructs the prompt text for session summarization.
1146func buildSummaryPrompt(todos []session.Todo) string {
1147 var sb strings.Builder
1148 sb.WriteString("Provide a detailed summary of our conversation above.")
1149 if len(todos) > 0 {
1150 sb.WriteString("\n\n## Current Todo List\n\n")
1151 for _, t := range todos {
1152 fmt.Fprintf(&sb, "- [%s] %s\n", t.Status, t.Content)
1153 }
1154 sb.WriteString("\nInclude these tasks and their statuses in your summary. ")
1155 sb.WriteString("Instruct the resuming assistant to use the `todos` tool to continue tracking progress on these tasks.")
1156 }
1157 return sb.String()
1158}