1// Package agent is the core orchestration layer for Crush AI agents.
2//
3// It provides session-based AI agent functionality for managing
4// conversations, tool execution, and message handling. It coordinates
5// interactions between language models, messages, sessions, and tools while
6// handling features like automatic summarization, queuing, and token
7// management.
8package agent
9
10import (
11 "cmp"
12 "context"
13 _ "embed"
14 "encoding/base64"
15 "errors"
16 "fmt"
17 "log/slog"
18 "os"
19 "regexp"
20 "strconv"
21 "strings"
22 "sync"
23 "time"
24
25 "charm.land/catwalk/pkg/catwalk"
26 "charm.land/fantasy"
27 "charm.land/fantasy/providers/anthropic"
28 "charm.land/fantasy/providers/bedrock"
29 "charm.land/fantasy/providers/google"
30 "charm.land/fantasy/providers/openai"
31 "charm.land/fantasy/providers/openrouter"
32 "charm.land/fantasy/providers/vercel"
33 "charm.land/lipgloss/v2"
34 "github.com/charmbracelet/crush/internal/agent/hyper"
35 "github.com/charmbracelet/crush/internal/agent/tools"
36 "github.com/charmbracelet/crush/internal/agent/tools/mcp"
37 "github.com/charmbracelet/crush/internal/config"
38 "github.com/charmbracelet/crush/internal/csync"
39 "github.com/charmbracelet/crush/internal/message"
40 "github.com/charmbracelet/crush/internal/permission"
41 "github.com/charmbracelet/crush/internal/session"
42 "github.com/charmbracelet/crush/internal/stringext"
43 "github.com/charmbracelet/x/exp/charmtone"
44)
45
const (
	// defaultSessionName is the title saved for a session when title
	// generation fails or produces an empty title.
	defaultSessionName = "Untitled Session"

	// Constants for auto-summarization thresholds.

	// largeContextWindowThreshold is the context window size above which the
	// fixed largeContextWindowBuffer is used as the summarization threshold.
	largeContextWindowThreshold = 200_000
	// largeContextWindowBuffer is the number of remaining tokens at or below
	// which a large-context model triggers summarization.
	largeContextWindowBuffer = 20_000
	// smallContextWindowRatio is the fraction of the context window used as
	// the summarization threshold for models at or below
	// largeContextWindowThreshold.
	smallContextWindowRatio = 0.2
)
54
// titlePrompt is the embedded system prompt used when generating session
// titles.
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt is the embedded system prompt used when summarizing a
// session.
//
//go:embed templates/summary.md
var summaryPrompt []byte

// thinkTagRegex matches a <think>...</think> span so it can be removed from
// generated titles. The pattern does not match across newlines; generateTitle
// replaces newlines with spaces before applying it.
var thinkTagRegex = regexp.MustCompile(`<think>.*?</think>`)
63
// SessionAgentCall describes a single prompt submitted to a session agent,
// together with the per-call generation settings that Run forwards to the
// underlying agent stream call.
type SessionAgentCall struct {
	// SessionID identifies the session the prompt belongs to. Run rejects
	// calls where it is empty.
	SessionID string
	// Prompt is the user's text. It may be empty only when Attachments
	// contains a text attachment.
	Prompt string
	// ProviderOptions is passed through to the stream call (and to Summarize
	// when auto-summarization kicks in).
	ProviderOptions fantasy.ProviderOptions
	// Attachments are stored with the user message; non-text attachments are
	// also sent to the model as files.
	Attachments []message.Attachment
	// The remaining fields are passed through to the stream call unchanged.
	MaxOutputTokens  int64
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
}
76
// SessionAgent is the interface exposed by the session-based agent.
type SessionAgent interface {
	// Run submits a prompt to a session. When the session is already busy the
	// call is queued and Run returns a nil result and a nil error.
	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
	SetModels(large Model, small Model)
	SetTools(tools []fantasy.AgentTool)
	SetSystemPrompt(systemPrompt string)
	// Cancel cancels the active request for the session, if any, and drops
	// its queued prompts.
	Cancel(sessionID string)
	// CancelAll cancels every active request and waits up to five seconds
	// for the agent to become idle.
	CancelAll()
	// IsSessionBusy reports whether the session has an active request.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any session has an active request.
	IsBusy() bool
	// QueuedPrompts returns the number of prompts queued for the session.
	QueuedPrompts(sessionID string) int
	// QueuedPromptsList returns the text of the prompts queued for the
	// session, or nil when the session has no queue entry.
	QueuedPromptsList(sessionID string) []string
	// ClearQueue drops the prompts queued for the session.
	ClearQueue(sessionID string)
	// Summarize replaces the session's working history with a generated
	// summary message. It fails with ErrSessionBusy while the session has an
	// active request.
	Summarize(context.Context, string, fantasy.ProviderOptions) error
	Model() Model
}
92
// Model bundles a language model with the metadata the agent needs to use it.
type Model struct {
	// Model is the language model used to create agents.
	Model fantasy.LanguageModel
	// CatwalkCfg is the catalog entry for the model; the agent reads its
	// context window, pricing, display name and capability flags.
	CatwalkCfg catwalk.Model
	// ModelCfg is the selected model configuration; its Model and Provider
	// values are recorded on assistant messages.
	ModelCfg config.SelectedModel
}
98
// sessionAgent is the SessionAgent implementation returned by
// NewSessionAgent.
type sessionAgent struct {
	// largeModel drives Run and Summarize; smallModel is tried first for
	// title generation, with largeModel as the fallback.
	largeModel *csync.Value[Model]
	smallModel *csync.Value[Model]
	// systemPromptPrefix, when non-empty, is prepended as an extra system
	// message on every step.
	systemPromptPrefix *csync.Value[string]
	systemPrompt       *csync.Value[string]
	// tools is copied at the start of every Run.
	tools *csync.Slice[fantasy.AgentTool]

	// isSubAgent suppresses the todo-list reminder that preparePrompt
	// otherwise adds at the start of the history.
	isSubAgent bool
	sessions   session.Service
	messages   message.Service
	// disableAutoSummarize turns off the context-window stop condition in Run.
	disableAutoSummarize bool
	isYolo               bool

	// messageQueue holds prompts submitted while a session was busy, keyed
	// by session ID.
	messageQueue *csync.Map[string, []SessionAgentCall]
	// activeRequests holds the cancel function of the in-flight request for
	// each busy session, keyed by session ID.
	activeRequests *csync.Map[string, context.CancelFunc]
}
115
// SessionAgentOptions carries the dependencies and settings NewSessionAgent
// copies into a new agent.
type SessionAgentOptions struct {
	LargeModel Model
	SmallModel Model
	// SystemPromptPrefix, when non-empty, is sent as an additional leading
	// system message.
	SystemPromptPrefix string
	SystemPrompt       string
	// IsSubAgent disables the todo-list reminder added to the history.
	IsSubAgent bool
	// DisableAutoSummarize disables automatic summarization when the context
	// window is nearly full.
	DisableAutoSummarize bool
	IsYolo               bool
	Sessions             session.Service
	Messages             message.Service
	Tools                []fantasy.AgentTool
}
128
129func NewSessionAgent(
130 opts SessionAgentOptions,
131) SessionAgent {
132 return &sessionAgent{
133 largeModel: csync.NewValue(opts.LargeModel),
134 smallModel: csync.NewValue(opts.SmallModel),
135 systemPromptPrefix: csync.NewValue(opts.SystemPromptPrefix),
136 systemPrompt: csync.NewValue(opts.SystemPrompt),
137 isSubAgent: opts.IsSubAgent,
138 sessions: opts.Sessions,
139 messages: opts.Messages,
140 disableAutoSummarize: opts.DisableAutoSummarize,
141 tools: csync.NewSliceFrom(opts.Tools),
142 isYolo: opts.IsYolo,
143 messageQueue: csync.NewMap[string, []SessionAgentCall](),
144 activeRequests: csync.NewMap[string, context.CancelFunc](),
145 }
146}
147
// Run submits call to its session and streams the large model's response,
// persisting the user, assistant and tool messages as it goes.
//
// It returns ErrEmptyPrompt when the prompt is empty and no text attachment
// is present, and ErrSessionMissing when the session ID is empty. If the
// session already has an active request the call is appended to the session's
// queue and Run returns (nil, nil).
//
// When the stream fails, the in-progress assistant message is finalized with
// a finish reason describing the failure, and unfinished tool calls receive
// an error tool result. When the context window is nearly exhausted the
// session is summarized. After a successful run, any queued prompt is
// processed by calling Run again.
func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
	if call.Prompt == "" && !message.ContainsTextAttachment(call.Attachments) {
		return nil, ErrEmptyPrompt
	}
	if call.SessionID == "" {
		return nil, ErrSessionMissing
	}

	// Queue the message if busy
	if a.IsSessionBusy(call.SessionID) {
		existing, ok := a.messageQueue.Get(call.SessionID)
		if !ok {
			existing = []SessionAgentCall{}
		}
		existing = append(existing, call)
		a.messageQueue.Set(call.SessionID, existing)
		return nil, nil
	}

	// Copy mutable fields under lock to avoid races with SetTools/SetModels.
	agentTools := a.tools.Copy()
	largeModel := a.largeModel.Get()
	systemPrompt := a.systemPrompt.Get()
	promptPrefix := a.systemPromptPrefix.Get()
	var instructions strings.Builder

	// Collect the instructions advertised by every connected MCP server.
	for _, server := range mcp.GetStates() {
		if server.State != mcp.StateConnected {
			continue
		}
		if s := server.Client.InitializeResult().Instructions; s != "" {
			instructions.WriteString(s)
			instructions.WriteString("\n\n")
		}
	}

	if s := instructions.String(); s != "" {
		systemPrompt += "\n\n<mcp-instructions>\n" + s + "\n</mcp-instructions>"
	}

	if len(agentTools) > 0 {
		// Add Anthropic caching to the last tool.
		agentTools[len(agentTools)-1].SetProviderOptions(a.getCacheControlOptions())
	}

	agent := fantasy.NewAgent(
		largeModel.Model,
		fantasy.WithSystemPrompt(systemPrompt),
		fantasy.WithTools(agentTools...),
	)

	// sessionLock serializes the session read-modify-write in OnStepFinish.
	sessionLock := sync.Mutex{}
	currentSession, err := a.sessions.Get(ctx, call.SessionID)
	if err != nil {
		return nil, fmt.Errorf("failed to get session: %w", err)
	}

	msgs, err := a.getSessionMessages(ctx, currentSession)
	if err != nil {
		return nil, fmt.Errorf("failed to get session messages: %w", err)
	}

	var wg sync.WaitGroup
	// Generate title if first message.
	if len(msgs) == 0 {
		titleCtx := ctx // Copy to avoid race with ctx reassignment below.
		wg.Go(func() {
			a.generateTitle(titleCtx, call.SessionID, call.Prompt)
		})
	}
	// Do not return before title generation has finished.
	defer wg.Wait()

	// Add the user message to the session.
	_, err = a.createUserMessage(ctx, call)
	if err != nil {
		return nil, err
	}

	// Add the session to the context.
	ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)

	// Registering the cancel function is what marks the session as busy.
	genCtx, cancel := context.WithCancel(ctx)
	a.activeRequests.Set(call.SessionID, cancel)

	defer cancel()
	defer a.activeRequests.Del(call.SessionID)

	history, files := a.preparePrompt(msgs, call.Attachments...)

	startTime := time.Now()
	a.eventPromptSent(call.SessionID)

	// currentAssistant is the assistant message of the step in progress; it is
	// set by PrepareStep and mutated by the streaming callbacks below.
	var currentAssistant *message.Message
	// shouldSummarize is set by the context-window stop condition.
	var shouldSummarize bool
	result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
		Prompt:           message.PromptWithTextAttachments(call.Prompt, call.Attachments),
		Files:            files,
		Messages:         history,
		ProviderOptions:  call.ProviderOptions,
		MaxOutputTokens:  &call.MaxOutputTokens,
		TopP:             call.TopP,
		Temperature:      call.Temperature,
		PresencePenalty:  call.PresencePenalty,
		TopK:             call.TopK,
		FrequencyPenalty: call.FrequencyPenalty,
		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
			// Start from a clean slate: cache control is re-applied below.
			prepared.Messages = options.Messages
			for i := range prepared.Messages {
				prepared.Messages[i].ProviderOptions = nil
			}

			// Drain prompts queued while this request was running and append
			// them to the conversation as user messages.
			queuedCalls, _ := a.messageQueue.Get(call.SessionID)
			a.messageQueue.Del(call.SessionID)
			for _, queued := range queuedCalls {
				userMessage, createErr := a.createUserMessage(callContext, queued)
				if createErr != nil {
					return callContext, prepared, createErr
				}
				prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
			}

			prepared.Messages = a.workaroundProviderMediaLimitations(prepared.Messages, largeModel)

			lastSystemRoleInx := 0
			systemMessageUpdated := false
			for i, msg := range prepared.Messages {
				// Only add cache control to the last message.
				if msg.Role == fantasy.MessageRoleSystem {
					lastSystemRoleInx = i
				} else if !systemMessageUpdated {
					prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
					systemMessageUpdated = true
				}
				// Then add cache control to the last 2 messages.
				if i > len(prepared.Messages)-3 {
					prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
				}
			}

			if promptPrefix != "" {
				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(promptPrefix)}, prepared.Messages...)
			}

			// Create the (initially empty) assistant message for this step and
			// expose its ID and model details to tools through the context.
			var assistantMsg message.Message
			assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
				Role:     message.Assistant,
				Parts:    []message.ContentPart{},
				Model:    largeModel.ModelCfg.Model,
				Provider: largeModel.ModelCfg.Provider,
			})
			if err != nil {
				return callContext, prepared, err
			}
			callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
			callContext = context.WithValue(callContext, tools.SupportsImagesContextKey, largeModel.CatwalkCfg.SupportsImages)
			callContext = context.WithValue(callContext, tools.ModelNameContextKey, largeModel.CatwalkCfg.Name)
			currentAssistant = &assistantMsg
			return callContext, prepared, err
		},
		OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
			currentAssistant.AppendReasoningContent(reasoning.Text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnReasoningDelta: func(id string, text string) error {
			currentAssistant.AppendReasoningContent(text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
			// handle anthropic signature
			if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
				if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
					currentAssistant.AppendReasoningSignature(reasoning.Signature)
				}
			}
			// Handle the Google thought signature.
			if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
				if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
					currentAssistant.AppendThoughtSignature(reasoning.Signature, reasoning.ToolID)
				}
			}
			// Handle OpenAI Responses reasoning metadata.
			if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
				if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
					currentAssistant.SetReasoningResponsesData(reasoning)
				}
			}
			currentAssistant.FinishThinking()
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnTextDelta: func(id string, text string) error {
			// Strip leading newline from initial text content. This is
			// particularly important in non-interactive mode where leading
			// newlines are very visible.
			if len(currentAssistant.Parts) == 0 {
				text = strings.TrimPrefix(text, "\n")
			}

			currentAssistant.AppendContent(text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnToolInputStart: func(id string, toolName string) error {
			// Record the tool call as unfinished as soon as its input starts
			// streaming.
			toolCall := message.ToolCall{
				ID:               id,
				Name:             toolName,
				ProviderExecuted: false,
				Finished:         false,
			}
			currentAssistant.AddToolCall(toolCall)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnRetry: func(err *fantasy.ProviderError, delay time.Duration) {
			// TODO: implement
		},
		OnToolCall: func(tc fantasy.ToolCallContent) error {
			// The complete tool call is now known; record it as finished.
			toolCall := message.ToolCall{
				ID:               tc.ToolCallID,
				Name:             tc.ToolName,
				Input:            tc.Input,
				ProviderExecuted: false,
				Finished:         true,
			}
			currentAssistant.AddToolCall(toolCall)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnToolResult: func(result fantasy.ToolResultContent) error {
			// Each tool result is stored as its own tool message.
			toolResult := a.convertToToolResult(result)
			_, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
				Role: message.Tool,
				Parts: []message.ContentPart{
					toolResult,
				},
			})
			return createMsgErr
		},
		OnStepFinish: func(stepResult fantasy.StepResult) error {
			// Map the provider finish reason onto the message finish reason.
			finishReason := message.FinishReasonUnknown
			switch stepResult.FinishReason {
			case fantasy.FinishReasonLength:
				finishReason = message.FinishReasonMaxTokens
			case fantasy.FinishReasonStop:
				finishReason = message.FinishReasonEndTurn
			case fantasy.FinishReasonToolCalls:
				finishReason = message.FinishReasonToolUse
			}
			currentAssistant.AddFinish(finishReason, "", "")
			sessionLock.Lock()
			defer sessionLock.Unlock()

			// Re-read the session before updating usage and saving it.
			updatedSession, getSessionErr := a.sessions.Get(ctx, call.SessionID)
			if getSessionErr != nil {
				return getSessionErr
			}
			a.updateSessionUsage(largeModel, &updatedSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
			_, sessionErr := a.sessions.Save(ctx, updatedSession)
			if sessionErr != nil {
				return sessionErr
			}
			currentSession = updatedSession
			return a.messages.Update(genCtx, *currentAssistant)
		},
		StopWhen: []fantasy.StopCondition{
			// Stop (and request summarization) when the remaining context
			// window falls to or below the threshold, unless auto-summarize
			// is disabled.
			func(_ []fantasy.StepResult) bool {
				cw := int64(largeModel.CatwalkCfg.ContextWindow)
				tokens := currentSession.CompletionTokens + currentSession.PromptTokens
				remaining := cw - tokens
				var threshold int64
				if cw > largeContextWindowThreshold {
					threshold = largeContextWindowBuffer
				} else {
					threshold = int64(float64(cw) * smallContextWindowRatio)
				}
				if (remaining <= threshold) && !a.disableAutoSummarize {
					shouldSummarize = true
					return true
				}
				return false
			},
			// Stop when the loop detector reports repeated tool calls.
			func(steps []fantasy.StepResult) bool {
				return hasRepeatedToolCalls(steps, loopDetectionWindowSize, loopDetectionMaxRepeats)
			},
		},
	})

	a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))

	if err != nil {
		isCancelErr := errors.Is(err, context.Canceled)
		isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
		// No assistant message was created, so there is nothing to finalize.
		if currentAssistant == nil {
			return result, err
		}
		// Ensure we finish thinking on error to close the reasoning state.
		currentAssistant.FinishThinking()
		toolCalls := currentAssistant.ToolCalls()
		// INFO: we use the parent context here because the genCtx has been cancelled.
		msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
		if createErr != nil {
			return nil, createErr
		}
		for _, tc := range toolCalls {
			// Close tool calls whose input never finished streaming.
			if !tc.Finished {
				tc.Finished = true
				tc.Input = "{}"
				currentAssistant.AddToolCall(tc)
				updateErr := a.messages.Update(ctx, *currentAssistant)
				if updateErr != nil {
					return nil, updateErr
				}
			}

			// Look for an already stored result for this tool call.
			found := false
			for _, msg := range msgs {
				if msg.Role == message.Tool {
					for _, tr := range msg.ToolResults() {
						if tr.ToolCallID == tc.ID {
							found = true
							break
						}
					}
				}
				if found {
					break
				}
			}
			if found {
				continue
			}
			// No result exists: store an error result explaining why.
			content := "There was an error while executing the tool"
			if isCancelErr {
				content = "Tool execution canceled by user"
			} else if isPermissionErr {
				content = "User denied permission"
			}
			toolResult := message.ToolResult{
				ToolCallID: tc.ID,
				Name:       tc.Name,
				Content:    content,
				IsError:    true,
			}
			_, createErr = a.messages.Create(ctx, currentAssistant.SessionID, message.CreateMessageParams{
				Role: message.Tool,
				Parts: []message.ContentPart{
					toolResult,
				},
			})
			if createErr != nil {
				return nil, createErr
			}
		}
		// Record a finish reason on the assistant message that matches the
		// kind of failure.
		var fantasyErr *fantasy.Error
		var providerErr *fantasy.ProviderError
		const defaultTitle = "Provider Error"
		linkStyle := lipgloss.NewStyle().Foreground(charmtone.Guac).Underline(true)
		if isCancelErr {
			currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
		} else if isPermissionErr {
			currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "")
		} else if errors.Is(err, hyper.ErrNoCredits) {
			url := hyper.BaseURL()
			link := linkStyle.Hyperlink(url, "id=hyper").Render(url)
			currentAssistant.AddFinish(message.FinishReasonError, "No credits", "You're out of credits. Add more at "+link)
		} else if errors.As(err, &providerErr) {
			if providerErr.Message == "The requested model is not supported." {
				url := "https://github.com/settings/copilot/features"
				link := linkStyle.Hyperlink(url, "id=copilot").Render(url)
				currentAssistant.AddFinish(
					message.FinishReasonError,
					"Copilot model not enabled",
					fmt.Sprintf("%q is not enabled in Copilot. Go to the following page to enable it. Then, wait 5 minutes before trying again. %s", largeModel.CatwalkCfg.Name, link),
				)
			} else {
				currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message)
			}
		} else if errors.As(err, &fantasyErr) {
			currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(fantasyErr.Title), defaultTitle), fantasyErr.Message)
		} else {
			currentAssistant.AddFinish(message.FinishReasonError, defaultTitle, err.Error())
		}
		// Note: we use the parent context here because the genCtx has been
		// cancelled.
		updateErr := a.messages.Update(ctx, *currentAssistant)
		if updateErr != nil {
			return nil, updateErr
		}
		return nil, err
	}

	if shouldSummarize {
		// Summarize returns ErrSessionBusy while the session is registered as
		// active, so release the registration first.
		a.activeRequests.Del(call.SessionID)
		if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
			return nil, summarizeErr
		}
		// If the agent wasn't done...
		if len(currentAssistant.ToolCalls()) > 0 {
			existing, ok := a.messageQueue.Get(call.SessionID)
			if !ok {
				existing = []SessionAgentCall{}
			}
			call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
			existing = append(existing, call)
			a.messageQueue.Set(call.SessionID, existing)
		}
	}

	// Release active request before processing queued messages.
	a.activeRequests.Del(call.SessionID)
	cancel()

	queuedMessages, ok := a.messageQueue.Get(call.SessionID)
	if !ok || len(queuedMessages) == 0 {
		return result, err
	}
	// There are queued messages restart the loop.
	firstQueuedMessage := queuedMessages[0]
	a.messageQueue.Set(call.SessionID, queuedMessages[1:])
	return a.Run(ctx, firstQueuedMessage)
}
563
564func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
565 if a.IsSessionBusy(sessionID) {
566 return ErrSessionBusy
567 }
568
569 // Copy mutable fields under lock to avoid races with SetModels.
570 largeModel := a.largeModel.Get()
571 systemPromptPrefix := a.systemPromptPrefix.Get()
572
573 currentSession, err := a.sessions.Get(ctx, sessionID)
574 if err != nil {
575 return fmt.Errorf("failed to get session: %w", err)
576 }
577 msgs, err := a.getSessionMessages(ctx, currentSession)
578 if err != nil {
579 return err
580 }
581 if len(msgs) == 0 {
582 // Nothing to summarize.
583 return nil
584 }
585
586 aiMsgs, _ := a.preparePrompt(msgs)
587
588 genCtx, cancel := context.WithCancel(ctx)
589 a.activeRequests.Set(sessionID, cancel)
590 defer a.activeRequests.Del(sessionID)
591 defer cancel()
592
593 agent := fantasy.NewAgent(largeModel.Model,
594 fantasy.WithSystemPrompt(string(summaryPrompt)),
595 )
596 summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
597 Role: message.Assistant,
598 Model: largeModel.Model.Model(),
599 Provider: largeModel.Model.Provider(),
600 IsSummaryMessage: true,
601 })
602 if err != nil {
603 return err
604 }
605
606 summaryPromptText := buildSummaryPrompt(currentSession.Todos)
607
608 resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
609 Prompt: summaryPromptText,
610 Messages: aiMsgs,
611 ProviderOptions: opts,
612 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
613 prepared.Messages = options.Messages
614 if systemPromptPrefix != "" {
615 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(systemPromptPrefix)}, prepared.Messages...)
616 }
617 return callContext, prepared, nil
618 },
619 OnReasoningDelta: func(id string, text string) error {
620 summaryMessage.AppendReasoningContent(text)
621 return a.messages.Update(genCtx, summaryMessage)
622 },
623 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
624 // Handle anthropic signature.
625 if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
626 if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
627 summaryMessage.AppendReasoningSignature(signature.Signature)
628 }
629 }
630 summaryMessage.FinishThinking()
631 return a.messages.Update(genCtx, summaryMessage)
632 },
633 OnTextDelta: func(id, text string) error {
634 summaryMessage.AppendContent(text)
635 return a.messages.Update(genCtx, summaryMessage)
636 },
637 })
638 if err != nil {
639 isCancelErr := errors.Is(err, context.Canceled)
640 if isCancelErr {
641 // User cancelled summarize we need to remove the summary message.
642 deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
643 return deleteErr
644 }
645 return err
646 }
647
648 summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
649 err = a.messages.Update(genCtx, summaryMessage)
650 if err != nil {
651 return err
652 }
653
654 var openrouterCost *float64
655 for _, step := range resp.Steps {
656 stepCost := a.openrouterCost(step.ProviderMetadata)
657 if stepCost != nil {
658 newCost := *stepCost
659 if openrouterCost != nil {
660 newCost += *openrouterCost
661 }
662 openrouterCost = &newCost
663 }
664 }
665
666 a.updateSessionUsage(largeModel, ¤tSession, resp.TotalUsage, openrouterCost)
667
668 // Just in case, get just the last usage info.
669 usage := resp.Response.Usage
670 currentSession.SummaryMessageID = summaryMessage.ID
671 currentSession.CompletionTokens = usage.OutputTokens
672 currentSession.PromptTokens = 0
673 _, err = a.sessions.Save(genCtx, currentSession)
674 return err
675}
676
677func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
678 if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
679 return fantasy.ProviderOptions{}
680 }
681 return fantasy.ProviderOptions{
682 anthropic.Name: &anthropic.ProviderCacheControlOptions{
683 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
684 },
685 bedrock.Name: &anthropic.ProviderCacheControlOptions{
686 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
687 },
688 vercel.Name: &anthropic.ProviderCacheControlOptions{
689 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
690 },
691 }
692}
693
694func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
695 parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
696 var attachmentParts []message.ContentPart
697 for _, attachment := range call.Attachments {
698 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
699 }
700 parts = append(parts, attachmentParts...)
701 msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
702 Role: message.User,
703 Parts: parts,
704 })
705 if err != nil {
706 return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
707 }
708 return msg, nil
709}
710
711func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
712 var history []fantasy.Message
713 if !a.isSubAgent {
714 history = append(history, fantasy.NewUserMessage(
715 fmt.Sprintf("<system_reminder>%s</system_reminder>",
716 `This is a reminder that your todo list is currently empty. DO NOT mention this to the user explicitly because they are already aware.
717If you are working on tasks that would benefit from a todo list please use the "todos" tool to create one.
718If not, please feel free to ignore. Again do not mention this message to the user.`,
719 ),
720 ))
721 }
722 for _, m := range msgs {
723 if len(m.Parts) == 0 {
724 continue
725 }
726 // Assistant message without content or tool calls (cancelled before it
727 // returned anything).
728 if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
729 continue
730 }
731 history = append(history, m.ToAIMessage()...)
732 }
733
734 var files []fantasy.FilePart
735 for _, attachment := range attachments {
736 if attachment.IsText() {
737 continue
738 }
739 files = append(files, fantasy.FilePart{
740 Filename: attachment.FileName,
741 Data: attachment.Content,
742 MediaType: attachment.MimeType,
743 })
744 }
745
746 return history, files
747}
748
749func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
750 msgs, err := a.messages.List(ctx, session.ID)
751 if err != nil {
752 return nil, fmt.Errorf("failed to list messages: %w", err)
753 }
754
755 if session.SummaryMessageID != "" {
756 summaryMsgIndex := -1
757 for i, msg := range msgs {
758 if msg.ID == session.SummaryMessageID {
759 summaryMsgIndex = i
760 break
761 }
762 }
763 if summaryMsgIndex != -1 {
764 msgs = msgs[summaryMsgIndex:]
765 msgs[0].Role = message.User
766 }
767 }
768 return msgs, nil
769}
770
771// generateTitle generates a session titled based on the initial prompt.
772func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, userPrompt string) {
773 if userPrompt == "" {
774 return
775 }
776
777 smallModel := a.smallModel.Get()
778 largeModel := a.largeModel.Get()
779 systemPromptPrefix := a.systemPromptPrefix.Get()
780
781 var maxOutputTokens int64 = 40
782 if smallModel.CatwalkCfg.CanReason {
783 maxOutputTokens = smallModel.CatwalkCfg.DefaultMaxTokens
784 }
785
786 newAgent := func(m fantasy.LanguageModel, p []byte, tok int64) fantasy.Agent {
787 return fantasy.NewAgent(m,
788 fantasy.WithSystemPrompt(string(p)+"\n /no_think"),
789 fantasy.WithMaxOutputTokens(tok),
790 )
791 }
792
793 streamCall := fantasy.AgentStreamCall{
794 Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", userPrompt),
795 PrepareStep: func(callCtx context.Context, opts fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
796 prepared.Messages = opts.Messages
797 if systemPromptPrefix != "" {
798 prepared.Messages = append([]fantasy.Message{
799 fantasy.NewSystemMessage(systemPromptPrefix),
800 }, prepared.Messages...)
801 }
802 return callCtx, prepared, nil
803 },
804 }
805
806 // Use the small model to generate the title.
807 model := smallModel
808 agent := newAgent(model.Model, titlePrompt, maxOutputTokens)
809 resp, err := agent.Stream(ctx, streamCall)
810 if err == nil {
811 // We successfully generated a title with the small model.
812 slog.Debug("Generated title with small model")
813 } else {
814 // It didn't work. Let's try with the big model.
815 slog.Error("Error generating title with small model; trying big model", "err", err)
816 model = largeModel
817 agent = newAgent(model.Model, titlePrompt, maxOutputTokens)
818 resp, err = agent.Stream(ctx, streamCall)
819 if err == nil {
820 slog.Debug("Generated title with large model")
821 } else {
822 // Welp, the large model didn't work either. Use the default
823 // session name and return.
824 slog.Error("Error generating title with large model", "err", err)
825 saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, defaultSessionName, 0, 0, 0)
826 if saveErr != nil {
827 slog.Error("Failed to save session title and usage", "error", saveErr)
828 }
829 return
830 }
831 }
832
833 if resp == nil {
834 // Actually, we didn't get a response so we can't. Use the default
835 // session name and return.
836 slog.Error("Response is nil; can't generate title")
837 saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, defaultSessionName, 0, 0, 0)
838 if saveErr != nil {
839 slog.Error("Failed to save session title and usage", "error", saveErr)
840 }
841 return
842 }
843
844 // Clean up title.
845 var title string
846 title = strings.ReplaceAll(resp.Response.Content.Text(), "\n", " ")
847
848 // Remove thinking tags if present.
849 title = thinkTagRegex.ReplaceAllString(title, "")
850
851 title = strings.TrimSpace(title)
852 if title == "" {
853 slog.Debug("Empty title; using fallback")
854 title = defaultSessionName
855 }
856
857 // Calculate usage and cost.
858 var openrouterCost *float64
859 for _, step := range resp.Steps {
860 stepCost := a.openrouterCost(step.ProviderMetadata)
861 if stepCost != nil {
862 newCost := *stepCost
863 if openrouterCost != nil {
864 newCost += *openrouterCost
865 }
866 openrouterCost = &newCost
867 }
868 }
869
870 modelConfig := model.CatwalkCfg
871 cost := modelConfig.CostPer1MInCached/1e6*float64(resp.TotalUsage.CacheCreationTokens) +
872 modelConfig.CostPer1MOutCached/1e6*float64(resp.TotalUsage.CacheReadTokens) +
873 modelConfig.CostPer1MIn/1e6*float64(resp.TotalUsage.InputTokens) +
874 modelConfig.CostPer1MOut/1e6*float64(resp.TotalUsage.OutputTokens)
875
876 // Use override cost if available (e.g., from OpenRouter).
877 if openrouterCost != nil {
878 cost = *openrouterCost
879 }
880
881 promptTokens := resp.TotalUsage.InputTokens + resp.TotalUsage.CacheCreationTokens
882 completionTokens := resp.TotalUsage.OutputTokens
883
884 // Atomically update only title and usage fields to avoid overriding other
885 // concurrent session updates.
886 saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, title, promptTokens, completionTokens, cost)
887 if saveErr != nil {
888 slog.Error("Failed to save session title and usage", "error", saveErr)
889 return
890 }
891}
892
893func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
894 openrouterMetadata, ok := metadata[openrouter.Name]
895 if !ok {
896 return nil
897 }
898
899 opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
900 if !ok {
901 return nil
902 }
903 return &opts.Usage.Cost
904}
905
906func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
907 modelConfig := model.CatwalkCfg
908 cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
909 modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
910 modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
911 modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
912
913 a.eventTokensUsed(session.ID, model, usage, cost)
914
915 if overrideCost != nil {
916 session.Cost += *overrideCost
917 } else {
918 session.Cost += cost
919 }
920
921 session.CompletionTokens = usage.OutputTokens
922 session.PromptTokens = usage.InputTokens + usage.CacheReadTokens
923}
924
925func (a *sessionAgent) Cancel(sessionID string) {
926 // Cancel regular requests. Don't use Take() here - we need the entry to
927 // remain in activeRequests so IsBusy() returns true until the goroutine
928 // fully completes (including error handling that may access the DB).
929 // The defer in processRequest will clean up the entry.
930 if cancel, ok := a.activeRequests.Get(sessionID); ok && cancel != nil {
931 slog.Debug("Request cancellation initiated", "session_id", sessionID)
932 cancel()
933 }
934
935 // Also check for summarize requests.
936 if cancel, ok := a.activeRequests.Get(sessionID + "-summarize"); ok && cancel != nil {
937 slog.Debug("Summarize cancellation initiated", "session_id", sessionID)
938 cancel()
939 }
940
941 if a.QueuedPrompts(sessionID) > 0 {
942 slog.Debug("Clearing queued prompts", "session_id", sessionID)
943 a.messageQueue.Del(sessionID)
944 }
945}
946
947func (a *sessionAgent) ClearQueue(sessionID string) {
948 if a.QueuedPrompts(sessionID) > 0 {
949 slog.Debug("Clearing queued prompts", "session_id", sessionID)
950 a.messageQueue.Del(sessionID)
951 }
952}
953
954func (a *sessionAgent) CancelAll() {
955 if !a.IsBusy() {
956 return
957 }
958 for key := range a.activeRequests.Seq2() {
959 a.Cancel(key) // key is sessionID
960 }
961
962 timeout := time.After(5 * time.Second)
963 for a.IsBusy() {
964 select {
965 case <-timeout:
966 return
967 default:
968 time.Sleep(200 * time.Millisecond)
969 }
970 }
971}
972
973func (a *sessionAgent) IsBusy() bool {
974 var busy bool
975 for cancelFunc := range a.activeRequests.Seq() {
976 if cancelFunc != nil {
977 busy = true
978 break
979 }
980 }
981 return busy
982}
983
984func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
985 _, busy := a.activeRequests.Get(sessionID)
986 return busy
987}
988
989func (a *sessionAgent) QueuedPrompts(sessionID string) int {
990 l, ok := a.messageQueue.Get(sessionID)
991 if !ok {
992 return 0
993 }
994 return len(l)
995}
996
997func (a *sessionAgent) QueuedPromptsList(sessionID string) []string {
998 l, ok := a.messageQueue.Get(sessionID)
999 if !ok {
1000 return nil
1001 }
1002 prompts := make([]string, len(l))
1003 for i, call := range l {
1004 prompts[i] = call.Prompt
1005 }
1006 return prompts
1007}
1008
1009func (a *sessionAgent) SetModels(large Model, small Model) {
1010 a.largeModel.Set(large)
1011 a.smallModel.Set(small)
1012}
1013
1014func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
1015 a.tools.SetSlice(tools)
1016}
1017
1018func (a *sessionAgent) SetSystemPrompt(systemPrompt string) {
1019 a.systemPrompt.Set(systemPrompt)
1020}
1021
1022func (a *sessionAgent) Model() Model {
1023 return a.largeModel.Get()
1024}
1025
1026// convertToToolResult converts a fantasy tool result to a message tool result.
1027func (a *sessionAgent) convertToToolResult(result fantasy.ToolResultContent) message.ToolResult {
1028 baseResult := message.ToolResult{
1029 ToolCallID: result.ToolCallID,
1030 Name: result.ToolName,
1031 Metadata: result.ClientMetadata,
1032 }
1033
1034 switch result.Result.GetType() {
1035 case fantasy.ToolResultContentTypeText:
1036 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result); ok {
1037 baseResult.Content = r.Text
1038 }
1039 case fantasy.ToolResultContentTypeError:
1040 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result); ok {
1041 baseResult.Content = r.Error.Error()
1042 baseResult.IsError = true
1043 }
1044 case fantasy.ToolResultContentTypeMedia:
1045 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](result.Result); ok {
1046 content := r.Text
1047 if content == "" {
1048 content = fmt.Sprintf("Loaded %s content", r.MediaType)
1049 }
1050 baseResult.Content = content
1051 baseResult.Data = r.Data
1052 baseResult.MIMEType = r.MediaType
1053 }
1054 }
1055
1056 return baseResult
1057}
1058
1059// workaroundProviderMediaLimitations converts media content in tool results to
1060// user messages for providers that don't natively support images in tool results.
1061//
1062// Problem: OpenAI, Google, OpenRouter, and other OpenAI-compatible providers
1063// don't support sending images/media in tool result messages - they only accept
1064// text in tool results. However, they DO support images in user messages.
1065//
1066// If we send media in tool results to these providers, the API returns an error.
1067//
1068// Solution: For these providers, we:
1069// 1. Replace the media in the tool result with a text placeholder
1070// 2. Inject a user message immediately after with the image as a file attachment
1071// 3. This maintains the tool execution flow while working around API limitations
1072//
1073// Anthropic and Bedrock support images natively in tool results, so we skip
1074// this workaround for them.
1075//
1076// Example transformation:
1077//
1078// BEFORE: [tool result: image data]
1079// AFTER: [tool result: "Image loaded - see attached"], [user: image attachment]
1080func (a *sessionAgent) workaroundProviderMediaLimitations(messages []fantasy.Message, largeModel Model) []fantasy.Message {
1081 providerSupportsMedia := largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderAnthropic) ||
1082 largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderBedrock)
1083
1084 if providerSupportsMedia {
1085 return messages
1086 }
1087
1088 convertedMessages := make([]fantasy.Message, 0, len(messages))
1089
1090 for _, msg := range messages {
1091 if msg.Role != fantasy.MessageRoleTool {
1092 convertedMessages = append(convertedMessages, msg)
1093 continue
1094 }
1095
1096 textParts := make([]fantasy.MessagePart, 0, len(msg.Content))
1097 var mediaFiles []fantasy.FilePart
1098
1099 for _, part := range msg.Content {
1100 toolResult, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
1101 if !ok {
1102 textParts = append(textParts, part)
1103 continue
1104 }
1105
1106 if media, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](toolResult.Output); ok {
1107 decoded, err := base64.StdEncoding.DecodeString(media.Data)
1108 if err != nil {
1109 slog.Warn("Failed to decode media data", "error", err)
1110 textParts = append(textParts, part)
1111 continue
1112 }
1113
1114 mediaFiles = append(mediaFiles, fantasy.FilePart{
1115 Data: decoded,
1116 MediaType: media.MediaType,
1117 Filename: fmt.Sprintf("tool-result-%s", toolResult.ToolCallID),
1118 })
1119
1120 textParts = append(textParts, fantasy.ToolResultPart{
1121 ToolCallID: toolResult.ToolCallID,
1122 Output: fantasy.ToolResultOutputContentText{
1123 Text: "[Image/media content loaded - see attached file]",
1124 },
1125 ProviderOptions: toolResult.ProviderOptions,
1126 })
1127 } else {
1128 textParts = append(textParts, part)
1129 }
1130 }
1131
1132 convertedMessages = append(convertedMessages, fantasy.Message{
1133 Role: fantasy.MessageRoleTool,
1134 Content: textParts,
1135 })
1136
1137 if len(mediaFiles) > 0 {
1138 convertedMessages = append(convertedMessages, fantasy.NewUserMessage(
1139 "Here is the media content from the tool result:",
1140 mediaFiles...,
1141 ))
1142 }
1143 }
1144
1145 return convertedMessages
1146}
1147
1148// buildSummaryPrompt constructs the prompt text for session summarization.
1149func buildSummaryPrompt(todos []session.Todo) string {
1150 var sb strings.Builder
1151 sb.WriteString("Provide a detailed summary of our conversation above.")
1152 if len(todos) > 0 {
1153 sb.WriteString("\n\n## Current Todo List\n\n")
1154 for _, t := range todos {
1155 fmt.Fprintf(&sb, "- [%s] %s\n", t.Status, t.Content)
1156 }
1157 sb.WriteString("\nInclude these tasks and their statuses in your summary. ")
1158 sb.WriteString("Instruct the resuming assistant to use the `todos` tool to continue tracking progress on these tasks.")
1159 }
1160 return sb.String()
1161}