1// Package agent is the core orchestration layer for Crush AI agents.
2//
3// It provides session-based AI agent functionality for managing
4// conversations, tool execution, and message handling. It coordinates
5// interactions between language models, messages, sessions, and tools while
6// handling features like automatic summarization, queuing, and token
7// management.
8package agent
9
10import (
11 "cmp"
12 "context"
13 _ "embed"
14 "encoding/base64"
15 "errors"
16 "fmt"
17 "log/slog"
18 "os"
19 "regexp"
20 "strconv"
21 "strings"
22 "sync"
23 "time"
24
25 "charm.land/fantasy"
26 "charm.land/fantasy/providers/anthropic"
27 "charm.land/fantasy/providers/bedrock"
28 "charm.land/fantasy/providers/google"
29 "charm.land/fantasy/providers/openai"
30 "charm.land/fantasy/providers/openrouter"
31 "charm.land/lipgloss/v2"
32 "github.com/charmbracelet/catwalk/pkg/catwalk"
33 "github.com/charmbracelet/crush/internal/agent/hyper"
34 "github.com/charmbracelet/crush/internal/agent/tools"
35 "github.com/charmbracelet/crush/internal/config"
36 "github.com/charmbracelet/crush/internal/csync"
37 "github.com/charmbracelet/crush/internal/message"
38 "github.com/charmbracelet/crush/internal/permission"
39 "github.com/charmbracelet/crush/internal/session"
40 "github.com/charmbracelet/crush/internal/stringext"
41 "github.com/charmbracelet/x/exp/charmtone"
42)
43
// defaultSessionName is the title given to a session when no title could be
// generated (both models failed, no response, or the generated title was
// empty after cleanup).
const defaultSessionName = "Untitled Session"

// Auto-summarization thresholds used by Run's stop condition. For context
// windows larger than largeContextWindowThreshold tokens, a run stops once
// at most largeContextWindowBuffer tokens remain. For smaller windows it
// stops once the remaining tokens drop to smallContextWindowRatio of the
// window.
const (
	largeContextWindowThreshold = 200_000
	largeContextWindowBuffer    = 20_000
	smallContextWindowRatio     = 0.2
)
52
// titlePrompt is the system prompt used by generateTitle.
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt is the system prompt used by Summarize.
//
//go:embed templates/summary.md
var summaryPrompt []byte

// thinkTagRegex matches a <think>...</think> span (non-greedy) so it can be
// removed from generated titles. Without the (?s) flag `.` does not match
// newlines, which is why generateTitle replaces newlines with spaces before
// applying it.
var thinkTagRegex = regexp.MustCompile(`<think>.*?</think>`)
61
// SessionAgentCall describes a single prompt submitted to a SessionAgent.
type SessionAgentCall struct {
	// SessionID identifies the session the prompt belongs to; Run rejects an
	// empty value with ErrSessionMissing.
	SessionID string
	// Prompt is the user's text. It may be empty only when Attachments
	// contains a text attachment.
	Prompt string
	// ProviderOptions are passed through to the model stream call (and to
	// Summarize when Run triggers auto-summarization).
	ProviderOptions fantasy.ProviderOptions
	// Attachments are stored on the user message. Binary attachments are sent
	// as file parts; text attachments are merged into the prompt.
	Attachments []message.Attachment
	// MaxOutputTokens is always passed to the stream call by pointer, even when
	// zero. NOTE(review): confirm how providers treat a zero value.
	MaxOutputTokens int64
	// The sampling parameters below are passed through unchanged; nil
	// presumably leaves the provider default in place.
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
}
74
// SessionAgent runs prompts against language models on behalf of chat
// sessions, persisting the conversation through the session and message
// services. As implemented by sessionAgent, at most one request per session
// is active at a time and further Run calls for a busy session are queued.
type SessionAgent interface {
	// Run sends a prompt for call.SessionID and streams the reply into the
	// message store. If the session is busy the call is queued and Run
	// returns (nil, nil).
	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
	// SetModels replaces the large and small models used by later requests.
	SetModels(large Model, small Model)
	// SetTools replaces the tools offered to the model by later Run calls.
	SetTools(tools []fantasy.AgentTool)
	// SetSystemPrompt replaces the system prompt used by later Run calls.
	SetSystemPrompt(systemPrompt string)
	// Cancel aborts the session's active request and drops its queued prompts.
	Cancel(sessionID string)
	// CancelAll cancels every active request and waits, for a bounded time,
	// for them to finish.
	CancelAll()
	// IsSessionBusy reports whether the session has an active request.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any session has an active request.
	IsBusy() bool
	// QueuedPrompts returns how many calls are queued for the session.
	QueuedPrompts(sessionID string) int
	// QueuedPromptsList returns the prompt text of the session's queued calls.
	QueuedPromptsList(sessionID string) []string
	// ClearQueue drops the session's queued calls without cancelling its
	// active request.
	ClearQueue(sessionID string)
	// Summarize condenses the session's history (session ID, provider
	// options) into a summary message. It fails with ErrSessionBusy while the
	// session has an active request.
	Summarize(context.Context, string, fantasy.ProviderOptions) error
	// Model returns the current large model.
	Model() Model
}
90
// Model bundles a language model client with its catalog metadata and the
// user's model selection.
type Model struct {
	// Model is the client used to generate responses.
	Model fantasy.LanguageModel
	// CatwalkCfg is the model's catalog entry: context window, pricing,
	// capabilities such as SupportsImages and CanReason, and display name.
	CatwalkCfg catwalk.Model
	// ModelCfg is the selected model configuration; its Model and Provider
	// fields are recorded on assistant messages created by Run.
	ModelCfg config.SelectedModel
}
96
// sessionAgent is the SessionAgent implementation. Models, prompts and tools
// are held in csync containers so they can be replaced at runtime; each
// request reads them once when it starts.
type sessionAgent struct {
	// largeModel is used by Run and Summarize, and as the fallback for title
	// generation.
	largeModel *csync.Value[Model]
	// smallModel is tried first for title generation.
	smallModel *csync.Value[Model]
	// systemPromptPrefix, when non-empty, is prepended as a separate system
	// message to Run, Summarize and title requests.
	systemPromptPrefix *csync.Value[string]
	// systemPrompt is the main system prompt used by Run.
	systemPrompt *csync.Value[string]
	// tools are offered to the model by Run.
	tools *csync.Slice[fantasy.AgentTool]

	// isSubAgent suppresses the todo-list reminder added by preparePrompt.
	isSubAgent bool
	sessions   session.Service
	messages   message.Service
	// disableAutoSummarize disables Run's context-window stop condition and
	// the summarization it triggers.
	disableAutoSummarize bool
	// isYolo is not read in the portion of this file shown here.
	// NOTE(review): presumably consumed elsewhere — confirm.
	isYolo bool

	// messageQueue holds calls received while a session was busy, keyed by
	// session ID.
	messageQueue *csync.Map[string, []SessionAgentCall]
	// activeRequests maps a session ID to the cancel function of its in-flight
	// Run or Summarize; the presence of a key means the session is busy.
	activeRequests *csync.Map[string, context.CancelFunc]
}
113
// SessionAgentOptions configures NewSessionAgent. Each field initializes the
// sessionAgent field of the same purpose.
type SessionAgentOptions struct {
	// LargeModel is used for prompts and summaries.
	LargeModel Model
	// SmallModel is tried first when generating session titles.
	SmallModel Model
	// SystemPromptPrefix, when non-empty, is sent as an extra leading system
	// message.
	SystemPromptPrefix string
	// SystemPrompt is the agent's main system prompt.
	SystemPrompt string
	// IsSubAgent omits the todo-list reminder from prompts.
	IsSubAgent bool
	// DisableAutoSummarize turns off automatic summarization near the
	// context-window limit.
	DisableAutoSummarize bool
	// IsYolo is stored on the agent; it is not read in the code shown here.
	IsYolo bool
	// Sessions and Messages persist sessions and their messages.
	Sessions session.Service
	Messages message.Service
	// Tools are the tools offered to the model.
	Tools []fantasy.AgentTool
}
126
127func NewSessionAgent(
128 opts SessionAgentOptions,
129) SessionAgent {
130 return &sessionAgent{
131 largeModel: csync.NewValue(opts.LargeModel),
132 smallModel: csync.NewValue(opts.SmallModel),
133 systemPromptPrefix: csync.NewValue(opts.SystemPromptPrefix),
134 systemPrompt: csync.NewValue(opts.SystemPrompt),
135 isSubAgent: opts.IsSubAgent,
136 sessions: opts.Sessions,
137 messages: opts.Messages,
138 disableAutoSummarize: opts.DisableAutoSummarize,
139 tools: csync.NewSliceFrom(opts.Tools),
140 isYolo: opts.IsYolo,
141 messageQueue: csync.NewMap[string, []SessionAgentCall](),
142 activeRequests: csync.NewMap[string, context.CancelFunc](),
143 }
144}
145
// Run sends call.Prompt (plus attachments) to the session's large model and
// streams the reply into the message store.
//
// Behavior, as implemented below:
//   - If the session already has an active request, call is appended to the
//     session's queue and Run returns (nil, nil). Queued calls are injected as
//     extra user messages at the start of the next step (PrepareStep), or are
//     processed by a recursive Run once this request finishes.
//   - For the first message of a session a title is generated concurrently;
//     Run waits for that goroutine before returning.
//   - The request registers its cancel function in activeRequests so Cancel
//     can abort it; the entry is removed when Run returns.
//   - When the session's token usage gets close to the model's context window
//     (see StopWhen) and auto-summarization is enabled, the stream is stopped
//     and the session is summarized. If the assistant was in the middle of
//     tool use, the original prompt is re-queued with a note that the session
//     was interrupted.
//
// It returns ErrEmptyPrompt when there is neither a prompt nor a text
// attachment, and ErrSessionMissing when call.SessionID is empty. On a stream
// error the assistant message is finalized before the error is returned:
// unfinished tool calls are closed, missing tool results are recorded as
// errors, and a finish reason describing the failure is added.
func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
	if call.Prompt == "" && !message.ContainsTextAttachment(call.Attachments) {
		return nil, ErrEmptyPrompt
	}
	if call.SessionID == "" {
		return nil, ErrSessionMissing
	}

	// Queue the call if this session already has an active request. It is
	// picked up by the running request's PrepareStep or by the recursive Run
	// at the end of that request.
	if a.IsSessionBusy(call.SessionID) {
		existing, ok := a.messageQueue.Get(call.SessionID)
		if !ok {
			existing = []SessionAgentCall{}
		}
		existing = append(existing, call)
		a.messageQueue.Set(call.SessionID, existing)
		return nil, nil
	}

	// Copy mutable fields under lock to avoid races with SetTools/SetModels.
	agentTools := a.tools.Copy()
	largeModel := a.largeModel.Get()
	systemPrompt := a.systemPrompt.Get()
	promptPrefix := a.systemPromptPrefix.Get()

	if len(agentTools) > 0 {
		// Add Anthropic caching to the last tool.
		agentTools[len(agentTools)-1].SetProviderOptions(a.getCacheControlOptions())
	}

	agent := fantasy.NewAgent(
		largeModel.Model,
		fantasy.WithSystemPrompt(systemPrompt),
		fantasy.WithTools(agentTools...),
	)

	// sessionLock serializes the read-modify-write of the session's usage in
	// OnStepFinish, which also updates currentSession.
	sessionLock := sync.Mutex{}
	currentSession, err := a.sessions.Get(ctx, call.SessionID)
	if err != nil {
		return nil, fmt.Errorf("failed to get session: %w", err)
	}

	msgs, err := a.getSessionMessages(ctx, currentSession)
	if err != nil {
		return nil, fmt.Errorf("failed to get session messages: %w", err)
	}

	var wg sync.WaitGroup
	// Generate title if first message.
	if len(msgs) == 0 {
		titleCtx := ctx // Copy to avoid race with ctx reassignment below.
		wg.Go(func() {
			a.generateTitle(titleCtx, call.SessionID, call.Prompt)
		})
	}
	// Don't return until title generation (if started) has finished.
	defer wg.Wait()

	// Add the user message to the session.
	_, err = a.createUserMessage(ctx, call)
	if err != nil {
		return nil, err
	}

	// Add the session to the context.
	ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)

	// Register the cancel func so Cancel/IsSessionBusy see this request.
	genCtx, cancel := context.WithCancel(ctx)
	a.activeRequests.Set(call.SessionID, cancel)

	defer cancel()
	defer a.activeRequests.Del(call.SessionID)

	history, files := a.preparePrompt(msgs, call.Attachments...)

	startTime := time.Now()
	a.eventPromptSent(call.SessionID)

	// currentAssistant is the assistant message of the current step; it is
	// created in PrepareStep and updated by the streaming callbacks.
	var currentAssistant *message.Message
	// shouldSummarize is set by the StopWhen condition.
	var shouldSummarize bool
	result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
		Prompt:           message.PromptWithTextAttachments(call.Prompt, call.Attachments),
		Files:            files,
		Messages:         history,
		ProviderOptions:  call.ProviderOptions,
		MaxOutputTokens:  &call.MaxOutputTokens,
		TopP:             call.TopP,
		Temperature:      call.Temperature,
		PresencePenalty:  call.PresencePenalty,
		TopK:             call.TopK,
		FrequencyPenalty: call.FrequencyPenalty,
		// PrepareStep runs before each model step. It resets and re-applies
		// cache-control options, appends prompts queued since the last step,
		// prepends the system prompt prefix, and creates the assistant
		// message that this step's callbacks write into.
		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
			prepared.Messages = options.Messages
			for i := range prepared.Messages {
				prepared.Messages[i].ProviderOptions = nil
			}

			// Drain the queue: prompts sent while this request was running
			// become user messages of the next step.
			queuedCalls, _ := a.messageQueue.Get(call.SessionID)
			a.messageQueue.Del(call.SessionID)
			for _, queued := range queuedCalls {
				userMessage, createErr := a.createUserMessage(callContext, queued)
				if createErr != nil {
					return callContext, prepared, createErr
				}
				prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
			}

			prepared.Messages = a.workaroundProviderMediaLimitations(prepared.Messages, largeModel)

			lastSystemRoleInx := 0
			systemMessageUpdated := false
			for i, msg := range prepared.Messages {
				// Mark the last system message of the leading run of system
				// messages as a cache breakpoint, once, when the first
				// non-system message is reached. If the first message is not
				// a system message, index 0 is marked.
				if msg.Role == fantasy.MessageRoleSystem {
					lastSystemRoleInx = i
				} else if !systemMessageUpdated {
					prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
					systemMessageUpdated = true
				}
				// Then add cache control to the last 2 messages.
				if i > len(prepared.Messages)-3 {
					prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
				}
			}

			if promptPrefix != "" {
				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(promptPrefix)}, prepared.Messages...)
			}

			var assistantMsg message.Message
			assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
				Role:     message.Assistant,
				Parts:    []message.ContentPart{},
				Model:    largeModel.ModelCfg.Model,
				Provider: largeModel.ModelCfg.Provider,
			})
			if err != nil {
				return callContext, prepared, err
			}
			// Expose the step's message ID and model capabilities to tools.
			callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
			callContext = context.WithValue(callContext, tools.SupportsImagesContextKey, largeModel.CatwalkCfg.SupportsImages)
			callContext = context.WithValue(callContext, tools.ModelNameContextKey, largeModel.CatwalkCfg.Name)
			currentAssistant = &assistantMsg
			return callContext, prepared, err
		},
		OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
			currentAssistant.AppendReasoningContent(reasoning.Text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnReasoningDelta: func(id string, text string) error {
			currentAssistant.AppendReasoningContent(text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
			// Persist provider-specific reasoning metadata (Anthropic
			// signature, Google thought signature, OpenAI Responses data).
			// handle anthropic signature
			if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
				if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
					currentAssistant.AppendReasoningSignature(reasoning.Signature)
				}
			}
			if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
				if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
					currentAssistant.AppendThoughtSignature(reasoning.Signature, reasoning.ToolID)
				}
			}
			if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
				if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
					currentAssistant.SetReasoningResponsesData(reasoning)
				}
			}
			currentAssistant.FinishThinking()
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnTextDelta: func(id string, text string) error {
			// Strip leading newline from initial text content. This is
			// particularly important in non-interactive mode where leading
			// newlines are very visible.
			if len(currentAssistant.Parts) == 0 {
				text = strings.TrimPrefix(text, "\n")
			}

			currentAssistant.AppendContent(text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnToolInputStart: func(id string, toolName string) error {
			// Record the tool call as soon as it starts; it is completed in
			// OnToolCall.
			toolCall := message.ToolCall{
				ID:               id,
				Name:             toolName,
				ProviderExecuted: false,
				Finished:         false,
			}
			currentAssistant.AddToolCall(toolCall)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnRetry: func(err *fantasy.ProviderError, delay time.Duration) {
			// TODO: implement
		},
		OnToolCall: func(tc fantasy.ToolCallContent) error {
			toolCall := message.ToolCall{
				ID:               tc.ToolCallID,
				Name:             tc.ToolName,
				Input:            tc.Input,
				ProviderExecuted: false,
				Finished:         true,
			}
			currentAssistant.AddToolCall(toolCall)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnToolResult: func(result fantasy.ToolResultContent) error {
			// Each tool result is stored as its own Tool-role message.
			toolResult := a.convertToToolResult(result)
			_, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
				Role: message.Tool,
				Parts: []message.ContentPart{
					toolResult,
				},
			})
			return createMsgErr
		},
		OnStepFinish: func(stepResult fantasy.StepResult) error {
			finishReason := message.FinishReasonUnknown
			switch stepResult.FinishReason {
			case fantasy.FinishReasonLength:
				finishReason = message.FinishReasonMaxTokens
			case fantasy.FinishReasonStop:
				finishReason = message.FinishReasonEndTurn
			case fantasy.FinishReasonToolCalls:
				finishReason = message.FinishReasonToolUse
			}
			currentAssistant.AddFinish(finishReason, "", "")
			// Re-read the session so usage is applied to its latest stored
			// state, then publish it to currentSession for StopWhen.
			sessionLock.Lock()
			updatedSession, getSessionErr := a.sessions.Get(genCtx, call.SessionID)
			if getSessionErr != nil {
				sessionLock.Unlock()
				return getSessionErr
			}
			a.updateSessionUsage(largeModel, &updatedSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
			_, sessionErr := a.sessions.Save(genCtx, updatedSession)
			if sessionErr == nil {
				currentSession = updatedSession
			}
			sessionLock.Unlock()
			if sessionErr != nil {
				return sessionErr
			}
			return a.messages.Update(genCtx, *currentAssistant)
		},
		StopWhen: []fantasy.StopCondition{
			// Stop once the remaining context window drops to the threshold:
			// a fixed buffer for large windows, a fraction of the window for
			// small ones. Setting shouldSummarize makes Run summarize after
			// the stream ends.
			// NOTE(review): currentSession is read here without sessionLock;
			// presumably safe if fantasy evaluates stop conditions on the
			// same goroutine after OnStepFinish — verify.
			func(_ []fantasy.StepResult) bool {
				cw := int64(largeModel.CatwalkCfg.ContextWindow)
				tokens := currentSession.CompletionTokens + currentSession.PromptTokens
				remaining := cw - tokens
				var threshold int64
				if cw > largeContextWindowThreshold {
					threshold = largeContextWindowBuffer
				} else {
					threshold = int64(float64(cw) * smallContextWindowRatio)
				}
				if (remaining <= threshold) && !a.disableAutoSummarize {
					shouldSummarize = true
					return true
				}
				return false
			},
		},
	})

	a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))

	if err != nil {
		isCancelErr := errors.Is(err, context.Canceled)
		isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
		// No assistant message was created (the first PrepareStep never
		// completed), so there is nothing to finalize.
		if currentAssistant == nil {
			return result, err
		}
		// Ensure we finish thinking on error to close the reasoning state.
		currentAssistant.FinishThinking()
		toolCalls := currentAssistant.ToolCalls()
		// INFO: we use the parent context here because the genCtx has been cancelled.
		// The list is used to avoid adding a second result for tool calls
		// that already have one.
		msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
		if createErr != nil {
			return nil, createErr
		}
		for _, tc := range toolCalls {
			// Close tool calls that never received their full input.
			if !tc.Finished {
				tc.Finished = true
				tc.Input = "{}"
				currentAssistant.AddToolCall(tc)
				updateErr := a.messages.Update(ctx, *currentAssistant)
				if updateErr != nil {
					return nil, updateErr
				}
			}

			found := false
			for _, msg := range msgs {
				if msg.Role == message.Tool {
					for _, tr := range msg.ToolResults() {
						if tr.ToolCallID == tc.ID {
							found = true
							break
						}
					}
				}
				if found {
					break
				}
			}
			if found {
				continue
			}
			// Record an error result for each tool call that has none.
			content := "There was an error while executing the tool"
			if isCancelErr {
				content = "Tool execution canceled by user"
			} else if isPermissionErr {
				content = "User denied permission"
			}
			toolResult := message.ToolResult{
				ToolCallID: tc.ID,
				Name:       tc.Name,
				Content:    content,
				IsError:    true,
			}
			_, createErr = a.messages.Create(ctx, currentAssistant.SessionID, message.CreateMessageParams{
				Role: message.Tool,
				Parts: []message.ContentPart{
					toolResult,
				},
			})
			if createErr != nil {
				return nil, createErr
			}
		}
		// Map the error to a user-facing finish reason, title and message.
		var fantasyErr *fantasy.Error
		var providerErr *fantasy.ProviderError
		const defaultTitle = "Provider Error"
		linkStyle := lipgloss.NewStyle().Foreground(charmtone.Guac).Underline(true)
		if isCancelErr {
			currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
		} else if isPermissionErr {
			currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "")
		} else if errors.Is(err, hyper.ErrNoCredits) {
			url := hyper.BaseURL()
			link := linkStyle.Hyperlink(url, "id=hyper").Render(url)
			currentAssistant.AddFinish(message.FinishReasonError, "No credits", "You're out of credits. Add more at "+link)
		} else if errors.As(err, &providerErr) {
			if providerErr.Message == "The requested model is not supported." {
				url := "https://github.com/settings/copilot/features"
				link := linkStyle.Hyperlink(url, "id=copilot").Render(url)
				currentAssistant.AddFinish(
					message.FinishReasonError,
					"Copilot model not enabled",
					fmt.Sprintf("%q is not enabled in Copilot. Go to the following page to enable it. Then, wait 5 minutes before trying again. %s", largeModel.CatwalkCfg.Name, link),
				)
			} else {
				currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message)
			}
		} else if errors.As(err, &fantasyErr) {
			currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(fantasyErr.Title), defaultTitle), fantasyErr.Message)
		} else {
			currentAssistant.AddFinish(message.FinishReasonError, defaultTitle, err.Error())
		}
		// Note: we use the parent context here because the genCtx has been
		// cancelled.
		updateErr := a.messages.Update(ctx, *currentAssistant)
		if updateErr != nil {
			return nil, updateErr
		}
		return nil, err
	}

	if shouldSummarize {
		// Summarize refuses to run while the session is marked busy, so drop
		// this request's registration first.
		a.activeRequests.Del(call.SessionID)
		if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
			return nil, summarizeErr
		}
		// If the agent wasn't done...
		if len(currentAssistant.ToolCalls()) > 0 {
			existing, ok := a.messageQueue.Get(call.SessionID)
			if !ok {
				existing = []SessionAgentCall{}
			}
			call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
			existing = append(existing, call)
			a.messageQueue.Set(call.SessionID, existing)
		}
	}

	// Release active request before processing queued messages.
	a.activeRequests.Del(call.SessionID)
	cancel()

	queuedMessages, ok := a.messageQueue.Get(call.SessionID)
	if !ok || len(queuedMessages) == 0 {
		return result, err
	}
	// There are queued messages restart the loop.
	firstQueuedMessage := queuedMessages[0]
	a.messageQueue.Set(call.SessionID, queuedMessages[1:])
	return a.Run(ctx, firstQueuedMessage)
}
545
546func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
547 if a.IsSessionBusy(sessionID) {
548 return ErrSessionBusy
549 }
550
551 // Copy mutable fields under lock to avoid races with SetModels.
552 largeModel := a.largeModel.Get()
553 systemPromptPrefix := a.systemPromptPrefix.Get()
554
555 currentSession, err := a.sessions.Get(ctx, sessionID)
556 if err != nil {
557 return fmt.Errorf("failed to get session: %w", err)
558 }
559 msgs, err := a.getSessionMessages(ctx, currentSession)
560 if err != nil {
561 return err
562 }
563 if len(msgs) == 0 {
564 // Nothing to summarize.
565 return nil
566 }
567
568 aiMsgs, _ := a.preparePrompt(msgs)
569
570 genCtx, cancel := context.WithCancel(ctx)
571 a.activeRequests.Set(sessionID, cancel)
572 defer a.activeRequests.Del(sessionID)
573 defer cancel()
574
575 agent := fantasy.NewAgent(largeModel.Model,
576 fantasy.WithSystemPrompt(string(summaryPrompt)),
577 )
578 summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
579 Role: message.Assistant,
580 Model: largeModel.Model.Model(),
581 Provider: largeModel.Model.Provider(),
582 IsSummaryMessage: true,
583 })
584 if err != nil {
585 return err
586 }
587
588 summaryPromptText := buildSummaryPrompt(currentSession.Todos)
589
590 resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
591 Prompt: summaryPromptText,
592 Messages: aiMsgs,
593 ProviderOptions: opts,
594 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
595 prepared.Messages = options.Messages
596 if systemPromptPrefix != "" {
597 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(systemPromptPrefix)}, prepared.Messages...)
598 }
599 return callContext, prepared, nil
600 },
601 OnReasoningDelta: func(id string, text string) error {
602 summaryMessage.AppendReasoningContent(text)
603 return a.messages.Update(genCtx, summaryMessage)
604 },
605 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
606 // Handle anthropic signature.
607 if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
608 if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
609 summaryMessage.AppendReasoningSignature(signature.Signature)
610 }
611 }
612 summaryMessage.FinishThinking()
613 return a.messages.Update(genCtx, summaryMessage)
614 },
615 OnTextDelta: func(id, text string) error {
616 summaryMessage.AppendContent(text)
617 return a.messages.Update(genCtx, summaryMessage)
618 },
619 })
620 if err != nil {
621 isCancelErr := errors.Is(err, context.Canceled)
622 if isCancelErr {
623 // User cancelled summarize we need to remove the summary message.
624 deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
625 return deleteErr
626 }
627 return err
628 }
629
630 summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
631 err = a.messages.Update(genCtx, summaryMessage)
632 if err != nil {
633 return err
634 }
635
636 var openrouterCost *float64
637 for _, step := range resp.Steps {
638 stepCost := a.openrouterCost(step.ProviderMetadata)
639 if stepCost != nil {
640 newCost := *stepCost
641 if openrouterCost != nil {
642 newCost += *openrouterCost
643 }
644 openrouterCost = &newCost
645 }
646 }
647
648 a.updateSessionUsage(largeModel, ¤tSession, resp.TotalUsage, openrouterCost)
649
650 // Just in case, get just the last usage info.
651 usage := resp.Response.Usage
652 currentSession.SummaryMessageID = summaryMessage.ID
653 currentSession.CompletionTokens = usage.OutputTokens
654 currentSession.PromptTokens = 0
655 _, err = a.sessions.Save(genCtx, currentSession)
656 return err
657}
658
659func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
660 if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
661 return fantasy.ProviderOptions{}
662 }
663 return fantasy.ProviderOptions{
664 anthropic.Name: &anthropic.ProviderCacheControlOptions{
665 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
666 },
667 bedrock.Name: &anthropic.ProviderCacheControlOptions{
668 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
669 },
670 }
671}
672
673func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
674 parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
675 var attachmentParts []message.ContentPart
676 for _, attachment := range call.Attachments {
677 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
678 }
679 parts = append(parts, attachmentParts...)
680 msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
681 Role: message.User,
682 Parts: parts,
683 })
684 if err != nil {
685 return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
686 }
687 return msg, nil
688}
689
690func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
691 var history []fantasy.Message
692 if !a.isSubAgent {
693 history = append(history, fantasy.NewUserMessage(
694 fmt.Sprintf("<system_reminder>%s</system_reminder>",
695 `This is a reminder that your todo list is currently empty. DO NOT mention this to the user explicitly because they are already aware.
696If you are working on tasks that would benefit from a todo list please use the "todos" tool to create one.
697If not, please feel free to ignore. Again do not mention this message to the user.`,
698 ),
699 ))
700 }
701 for _, m := range msgs {
702 if len(m.Parts) == 0 {
703 continue
704 }
705 // Assistant message without content or tool calls (cancelled before it
706 // returned anything).
707 if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
708 continue
709 }
710 history = append(history, m.ToAIMessage()...)
711 }
712
713 var files []fantasy.FilePart
714 for _, attachment := range attachments {
715 if attachment.IsText() {
716 continue
717 }
718 files = append(files, fantasy.FilePart{
719 Filename: attachment.FileName,
720 Data: attachment.Content,
721 MediaType: attachment.MimeType,
722 })
723 }
724
725 return history, files
726}
727
728func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
729 msgs, err := a.messages.List(ctx, session.ID)
730 if err != nil {
731 return nil, fmt.Errorf("failed to list messages: %w", err)
732 }
733
734 if session.SummaryMessageID != "" {
735 summaryMsgIndex := -1
736 for i, msg := range msgs {
737 if msg.ID == session.SummaryMessageID {
738 summaryMsgIndex = i
739 break
740 }
741 }
742 if summaryMsgIndex != -1 {
743 msgs = msgs[summaryMsgIndex:]
744 msgs[0].Role = message.User
745 }
746 }
747 return msgs, nil
748}
749
750// generateTitle generates a session titled based on the initial prompt.
751func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, userPrompt string) {
752 if userPrompt == "" {
753 return
754 }
755
756 smallModel := a.smallModel.Get()
757 largeModel := a.largeModel.Get()
758 systemPromptPrefix := a.systemPromptPrefix.Get()
759
760 var maxOutputTokens int64 = 40
761 if smallModel.CatwalkCfg.CanReason {
762 maxOutputTokens = smallModel.CatwalkCfg.DefaultMaxTokens
763 }
764
765 newAgent := func(m fantasy.LanguageModel, p []byte, tok int64) fantasy.Agent {
766 return fantasy.NewAgent(m,
767 fantasy.WithSystemPrompt(string(p)+"\n /no_think"),
768 fantasy.WithMaxOutputTokens(tok),
769 )
770 }
771
772 streamCall := fantasy.AgentStreamCall{
773 Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", userPrompt),
774 PrepareStep: func(callCtx context.Context, opts fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
775 prepared.Messages = opts.Messages
776 if systemPromptPrefix != "" {
777 prepared.Messages = append([]fantasy.Message{
778 fantasy.NewSystemMessage(systemPromptPrefix),
779 }, prepared.Messages...)
780 }
781 return callCtx, prepared, nil
782 },
783 }
784
785 // Use the small model to generate the title.
786 model := smallModel
787 agent := newAgent(model.Model, titlePrompt, maxOutputTokens)
788 resp, err := agent.Stream(ctx, streamCall)
789 if err == nil {
790 // We successfully generated a title with the small model.
791 slog.Info("generated title with small model")
792 } else {
793 // It didn't work. Let's try with the big model.
794 slog.Error("error generating title with small model; trying big model", "err", err)
795 model = largeModel
796 agent = newAgent(model.Model, titlePrompt, maxOutputTokens)
797 resp, err = agent.Stream(ctx, streamCall)
798 if err == nil {
799 slog.Info("generated title with large model")
800 } else {
801 // Welp, the large model didn't work either. Use the default
802 // session name and return.
803 slog.Error("error generating title with large model", "err", err)
804 saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, defaultSessionName, 0, 0, 0)
805 if saveErr != nil {
806 slog.Error("failed to save session title and usage", "error", saveErr)
807 }
808 return
809 }
810 }
811
812 if resp == nil {
813 // Actually, we didn't get a response so we can't. Use the default
814 // session name and return.
815 slog.Error("response is nil; can't generate title")
816 saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, defaultSessionName, 0, 0, 0)
817 if saveErr != nil {
818 slog.Error("failed to save session title and usage", "error", saveErr)
819 }
820 return
821 }
822
823 // Clean up title.
824 var title string
825 title = strings.ReplaceAll(resp.Response.Content.Text(), "\n", " ")
826
827 // Remove thinking tags if present.
828 title = thinkTagRegex.ReplaceAllString(title, "")
829
830 title = strings.TrimSpace(title)
831 if title == "" {
832 slog.Warn("empty title; using fallback")
833 title = defaultSessionName
834 }
835
836 // Calculate usage and cost.
837 var openrouterCost *float64
838 for _, step := range resp.Steps {
839 stepCost := a.openrouterCost(step.ProviderMetadata)
840 if stepCost != nil {
841 newCost := *stepCost
842 if openrouterCost != nil {
843 newCost += *openrouterCost
844 }
845 openrouterCost = &newCost
846 }
847 }
848
849 modelConfig := model.CatwalkCfg
850 cost := modelConfig.CostPer1MInCached/1e6*float64(resp.TotalUsage.CacheCreationTokens) +
851 modelConfig.CostPer1MOutCached/1e6*float64(resp.TotalUsage.CacheReadTokens) +
852 modelConfig.CostPer1MIn/1e6*float64(resp.TotalUsage.InputTokens) +
853 modelConfig.CostPer1MOut/1e6*float64(resp.TotalUsage.OutputTokens)
854
855 // Use override cost if available (e.g., from OpenRouter).
856 if openrouterCost != nil {
857 cost = *openrouterCost
858 }
859
860 promptTokens := resp.TotalUsage.InputTokens + resp.TotalUsage.CacheCreationTokens
861 completionTokens := resp.TotalUsage.OutputTokens
862
863 // Atomically update only title and usage fields to avoid overriding other
864 // concurrent session updates.
865 saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, title, promptTokens, completionTokens, cost)
866 if saveErr != nil {
867 slog.Error("failed to save session title and usage", "error", saveErr)
868 return
869 }
870}
871
872func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
873 openrouterMetadata, ok := metadata[openrouter.Name]
874 if !ok {
875 return nil
876 }
877
878 opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
879 if !ok {
880 return nil
881 }
882 return &opts.Usage.Cost
883}
884
885func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
886 modelConfig := model.CatwalkCfg
887 cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
888 modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
889 modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
890 modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
891
892 a.eventTokensUsed(session.ID, model, usage, cost)
893
894 if overrideCost != nil {
895 session.Cost += *overrideCost
896 } else {
897 session.Cost += cost
898 }
899
900 session.CompletionTokens = usage.OutputTokens
901 session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
902}
903
904func (a *sessionAgent) Cancel(sessionID string) {
905 // Cancel regular requests. Don't use Take() here - we need the entry to
906 // remain in activeRequests so IsBusy() returns true until the goroutine
907 // fully completes (including error handling that may access the DB).
908 // The defer in processRequest will clean up the entry.
909 if cancel, ok := a.activeRequests.Get(sessionID); ok && cancel != nil {
910 slog.Info("Request cancellation initiated", "session_id", sessionID)
911 cancel()
912 }
913
914 // Also check for summarize requests.
915 if cancel, ok := a.activeRequests.Get(sessionID + "-summarize"); ok && cancel != nil {
916 slog.Info("Summarize cancellation initiated", "session_id", sessionID)
917 cancel()
918 }
919
920 if a.QueuedPrompts(sessionID) > 0 {
921 slog.Info("Clearing queued prompts", "session_id", sessionID)
922 a.messageQueue.Del(sessionID)
923 }
924}
925
926func (a *sessionAgent) ClearQueue(sessionID string) {
927 if a.QueuedPrompts(sessionID) > 0 {
928 slog.Info("Clearing queued prompts", "session_id", sessionID)
929 a.messageQueue.Del(sessionID)
930 }
931}
932
933func (a *sessionAgent) CancelAll() {
934 if !a.IsBusy() {
935 return
936 }
937 for key := range a.activeRequests.Seq2() {
938 a.Cancel(key) // key is sessionID
939 }
940
941 timeout := time.After(5 * time.Second)
942 for a.IsBusy() {
943 select {
944 case <-timeout:
945 return
946 default:
947 time.Sleep(200 * time.Millisecond)
948 }
949 }
950}
951
952func (a *sessionAgent) IsBusy() bool {
953 var busy bool
954 for cancelFunc := range a.activeRequests.Seq() {
955 if cancelFunc != nil {
956 busy = true
957 break
958 }
959 }
960 return busy
961}
962
963func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
964 _, busy := a.activeRequests.Get(sessionID)
965 return busy
966}
967
968func (a *sessionAgent) QueuedPrompts(sessionID string) int {
969 l, ok := a.messageQueue.Get(sessionID)
970 if !ok {
971 return 0
972 }
973 return len(l)
974}
975
976func (a *sessionAgent) QueuedPromptsList(sessionID string) []string {
977 l, ok := a.messageQueue.Get(sessionID)
978 if !ok {
979 return nil
980 }
981 prompts := make([]string, len(l))
982 for i, call := range l {
983 prompts[i] = call.Prompt
984 }
985 return prompts
986}
987
988func (a *sessionAgent) SetModels(large Model, small Model) {
989 a.largeModel.Set(large)
990 a.smallModel.Set(small)
991}
992
993func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
994 a.tools.SetSlice(tools)
995}
996
997func (a *sessionAgent) SetSystemPrompt(systemPrompt string) {
998 a.systemPrompt.Set(systemPrompt)
999}
1000
1001func (a *sessionAgent) Model() Model {
1002 return a.largeModel.Get()
1003}
1004
1005// convertToToolResult converts a fantasy tool result to a message tool result.
1006func (a *sessionAgent) convertToToolResult(result fantasy.ToolResultContent) message.ToolResult {
1007 baseResult := message.ToolResult{
1008 ToolCallID: result.ToolCallID,
1009 Name: result.ToolName,
1010 Metadata: result.ClientMetadata,
1011 }
1012
1013 switch result.Result.GetType() {
1014 case fantasy.ToolResultContentTypeText:
1015 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result); ok {
1016 baseResult.Content = r.Text
1017 }
1018 case fantasy.ToolResultContentTypeError:
1019 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result); ok {
1020 baseResult.Content = r.Error.Error()
1021 baseResult.IsError = true
1022 }
1023 case fantasy.ToolResultContentTypeMedia:
1024 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](result.Result); ok {
1025 content := r.Text
1026 if content == "" {
1027 content = fmt.Sprintf("Loaded %s content", r.MediaType)
1028 }
1029 baseResult.Content = content
1030 baseResult.Data = r.Data
1031 baseResult.MIMEType = r.MediaType
1032 }
1033 }
1034
1035 return baseResult
1036}
1037
1038// workaroundProviderMediaLimitations converts media content in tool results to
1039// user messages for providers that don't natively support images in tool results.
1040//
1041// Problem: OpenAI, Google, OpenRouter, and other OpenAI-compatible providers
1042// don't support sending images/media in tool result messages - they only accept
1043// text in tool results. However, they DO support images in user messages.
1044//
1045// If we send media in tool results to these providers, the API returns an error.
1046//
1047// Solution: For these providers, we:
1048// 1. Replace the media in the tool result with a text placeholder
1049// 2. Inject a user message immediately after with the image as a file attachment
1050// 3. This maintains the tool execution flow while working around API limitations
1051//
1052// Anthropic and Bedrock support images natively in tool results, so we skip
1053// this workaround for them.
1054//
1055// Example transformation:
1056//
1057// BEFORE: [tool result: image data]
1058// AFTER: [tool result: "Image loaded - see attached"], [user: image attachment]
1059func (a *sessionAgent) workaroundProviderMediaLimitations(messages []fantasy.Message, largeModel Model) []fantasy.Message {
1060 providerSupportsMedia := largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderAnthropic) ||
1061 largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderBedrock)
1062
1063 if providerSupportsMedia {
1064 return messages
1065 }
1066
1067 convertedMessages := make([]fantasy.Message, 0, len(messages))
1068
1069 for _, msg := range messages {
1070 if msg.Role != fantasy.MessageRoleTool {
1071 convertedMessages = append(convertedMessages, msg)
1072 continue
1073 }
1074
1075 textParts := make([]fantasy.MessagePart, 0, len(msg.Content))
1076 var mediaFiles []fantasy.FilePart
1077
1078 for _, part := range msg.Content {
1079 toolResult, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
1080 if !ok {
1081 textParts = append(textParts, part)
1082 continue
1083 }
1084
1085 if media, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](toolResult.Output); ok {
1086 decoded, err := base64.StdEncoding.DecodeString(media.Data)
1087 if err != nil {
1088 slog.Warn("failed to decode media data", "error", err)
1089 textParts = append(textParts, part)
1090 continue
1091 }
1092
1093 mediaFiles = append(mediaFiles, fantasy.FilePart{
1094 Data: decoded,
1095 MediaType: media.MediaType,
1096 Filename: fmt.Sprintf("tool-result-%s", toolResult.ToolCallID),
1097 })
1098
1099 textParts = append(textParts, fantasy.ToolResultPart{
1100 ToolCallID: toolResult.ToolCallID,
1101 Output: fantasy.ToolResultOutputContentText{
1102 Text: "[Image/media content loaded - see attached file]",
1103 },
1104 ProviderOptions: toolResult.ProviderOptions,
1105 })
1106 } else {
1107 textParts = append(textParts, part)
1108 }
1109 }
1110
1111 convertedMessages = append(convertedMessages, fantasy.Message{
1112 Role: fantasy.MessageRoleTool,
1113 Content: textParts,
1114 })
1115
1116 if len(mediaFiles) > 0 {
1117 convertedMessages = append(convertedMessages, fantasy.NewUserMessage(
1118 "Here is the media content from the tool result:",
1119 mediaFiles...,
1120 ))
1121 }
1122 }
1123
1124 return convertedMessages
1125}
1126
1127// buildSummaryPrompt constructs the prompt text for session summarization.
1128func buildSummaryPrompt(todos []session.Todo) string {
1129 var sb strings.Builder
1130 sb.WriteString("Provide a detailed summary of our conversation above.")
1131 if len(todos) > 0 {
1132 sb.WriteString("\n\n## Current Todo List\n\n")
1133 for _, t := range todos {
1134 fmt.Fprintf(&sb, "- [%s] %s\n", t.Status, t.Content)
1135 }
1136 sb.WriteString("\nInclude these tasks and their statuses in your summary. ")
1137 sb.WriteString("Instruct the resuming assistant to use the `todos` tool to continue tracking progress on these tasks.")
1138 }
1139 return sb.String()
1140}