1// Package agent is the core orchestration layer for Crush AI agents.
2//
3// It provides session-based AI agent functionality for managing
4// conversations, tool execution, and message handling. It coordinates
5// interactions between language models, messages, sessions, and tools while
6// handling features like automatic summarization, queuing, and token
7// management.
8package agent
9
10import (
11 "cmp"
12 "context"
13 _ "embed"
14 "encoding/base64"
15 "errors"
16 "fmt"
17 "log/slog"
18 "os"
19 "regexp"
20 "strconv"
21 "strings"
22 "sync"
23 "time"
24
25 "charm.land/fantasy"
26 "charm.land/fantasy/providers/anthropic"
27 "charm.land/fantasy/providers/bedrock"
28 "charm.land/fantasy/providers/google"
29 "charm.land/fantasy/providers/openai"
30 "charm.land/fantasy/providers/openrouter"
31 "charm.land/lipgloss/v2"
32 "github.com/charmbracelet/catwalk/pkg/catwalk"
33 "github.com/charmbracelet/crush/internal/agent/hyper"
34 "github.com/charmbracelet/crush/internal/agent/tools"
35 "github.com/charmbracelet/crush/internal/config"
36 "github.com/charmbracelet/crush/internal/csync"
37 "github.com/charmbracelet/crush/internal/message"
38 "github.com/charmbracelet/crush/internal/notification"
39 "github.com/charmbracelet/crush/internal/permission"
40 "github.com/charmbracelet/crush/internal/session"
41 "github.com/charmbracelet/crush/internal/stringext"
42)
43
// defaultSessionName is the title stored for a session when title generation
// fails, returns no response, or produces an empty title.
const defaultSessionName = "Untitled Session"

// titlePrompt is the system prompt used by generateTitle.
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt is the system prompt used by Summarize.
//
//go:embed templates/summary.md
var summaryPrompt []byte

// Used to remove <think> tags from generated titles. Note that `.` does not
// match newlines here (no (?s) flag), so a block spanning lines is only
// matched if newlines were replaced beforehand.
var thinkTagRegex = regexp.MustCompile(`<think>.*?</think>`)
54
// SessionAgentCall describes a single prompt to run against a session via
// SessionAgent.Run.
type SessionAgentCall struct {
	// SessionID identifies the target session. Run rejects an empty value
	// with ErrSessionMissing.
	SessionID string
	// Prompt is the user's text. Run rejects an empty value with
	// ErrEmptyPrompt.
	Prompt string
	// ProviderOptions are forwarded to the stream call, and to Summarize when
	// automatic summarization is triggered.
	ProviderOptions fantasy.ProviderOptions
	// Attachments are stored as binary parts on the user message; non-text
	// attachments are additionally sent to the model as file parts.
	Attachments []message.Attachment
	// MaxOutputTokens is passed (by pointer) to the stream call.
	MaxOutputTokens int64
	// Sampling parameters passed through to the stream call as-is; nil means
	// "not set".
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
	// NonInteractive suppresses the desktop notification sent when the
	// agent's turn completes.
	NonInteractive bool
}
68
// SessionAgent runs prompts against chat sessions and tracks per-session
// request state: in-flight requests, queued prompts, and summarization.
type SessionAgent interface {
	// Run executes a prompt for a session. If the session already has a
	// request in flight, the call is queued and (nil, nil) is returned.
	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
	// SetModels replaces the large (main) and small (title generation)
	// models.
	SetModels(large Model, small Model)
	// SetTools replaces the tools offered to the model.
	SetTools(tools []fantasy.AgentTool)
	// Cancel cancels the in-flight request of a session and clears its queue.
	Cancel(sessionID string)
	// CancelAll cancels every in-flight request and waits briefly for them
	// to finish.
	CancelAll()
	// IsSessionBusy reports whether the session has an in-flight request.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any session has an in-flight request.
	IsBusy() bool
	// QueuedPrompts returns how many calls are queued for the session.
	QueuedPrompts(sessionID string) int
	// QueuedPromptsList returns the prompts of the queued calls, in order.
	QueuedPromptsList(sessionID string) []string
	// ClearQueue drops all queued calls for the session.
	ClearQueue(sessionID string)
	// Summarize replaces the session's history with a generated summary.
	Summarize(context.Context, string, fantasy.ProviderOptions) error
	// Model returns the large model.
	Model() Model
}
83
// Model bundles a language model with its configuration.
type Model struct {
	// Model is the language model used for generation.
	Model fantasy.LanguageModel
	// CatwalkCfg holds catalog metadata read by this package (context
	// window, pricing, reasoning and image support, display name).
	CatwalkCfg catwalk.Model
	// ModelCfg is the selected model configuration; its Model and Provider
	// are recorded on assistant messages.
	ModelCfg config.SelectedModel
}
89
// sessionAgent is the SessionAgent implementation returned by NewSessionAgent.
type sessionAgent struct {
	largeModel         Model
	smallModel         Model
	systemPromptPrefix string
	systemPrompt       string
	// isSubAgent disables the todo-list system reminder in preparePrompt.
	isSubAgent           bool
	tools                []fantasy.AgentTool
	sessions             session.Service
	messages             message.Service
	disableAutoSummarize bool
	isYolo               bool

	// messageQueue holds calls that arrived while their session was busy,
	// keyed by session ID.
	messageQueue *csync.Map[string, []SessionAgentCall]
	// activeRequests maps session IDs to the cancel function of their
	// in-flight request; presence of a key means the session is busy.
	activeRequests *csync.Map[string, context.CancelFunc]
}
105
// SessionAgentOptions configures NewSessionAgent. Each field is copied into
// the corresponding sessionAgent field.
type SessionAgentOptions struct {
	// LargeModel is used for conversation turns and summarization.
	LargeModel Model
	// SmallModel is tried first for title generation.
	SmallModel Model
	// SystemPromptPrefix, when non-empty, is prepended as a separate system
	// message to every request.
	SystemPromptPrefix string
	// SystemPrompt is the main system prompt for conversation turns.
	SystemPrompt string
	// IsSubAgent omits the todo-list system reminder from the history.
	IsSubAgent bool
	// DisableAutoSummarize prevents stopping and summarizing when the
	// context window is nearly full.
	DisableAutoSummarize bool
	IsYolo               bool
	Sessions             session.Service
	Messages             message.Service
	// Tools are offered to the model during conversation turns.
	Tools []fantasy.AgentTool
}
118
119func NewSessionAgent(
120 opts SessionAgentOptions,
121) SessionAgent {
122 return &sessionAgent{
123 largeModel: opts.LargeModel,
124 smallModel: opts.SmallModel,
125 systemPromptPrefix: opts.SystemPromptPrefix,
126 systemPrompt: opts.SystemPrompt,
127 isSubAgent: opts.IsSubAgent,
128 sessions: opts.Sessions,
129 messages: opts.Messages,
130 disableAutoSummarize: opts.DisableAutoSummarize,
131 tools: opts.Tools,
132 isYolo: opts.IsYolo,
133 messageQueue: csync.NewMap[string, []SessionAgentCall](),
134 activeRequests: csync.NewMap[string, context.CancelFunc](),
135 }
136}
137
138func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
139 if call.Prompt == "" {
140 return nil, ErrEmptyPrompt
141 }
142 if call.SessionID == "" {
143 return nil, ErrSessionMissing
144 }
145
146 // Queue the message if busy
147 if a.IsSessionBusy(call.SessionID) {
148 existing, ok := a.messageQueue.Get(call.SessionID)
149 if !ok {
150 existing = []SessionAgentCall{}
151 }
152 existing = append(existing, call)
153 a.messageQueue.Set(call.SessionID, existing)
154 return nil, nil
155 }
156
157 if len(a.tools) > 0 {
158 // Add Anthropic caching to the last tool.
159 a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
160 }
161
162 agent := fantasy.NewAgent(
163 a.largeModel.Model,
164 fantasy.WithSystemPrompt(a.systemPrompt),
165 fantasy.WithTools(a.tools...),
166 )
167
168 sessionLock := sync.Mutex{}
169 currentSession, err := a.sessions.Get(ctx, call.SessionID)
170 if err != nil {
171 return nil, fmt.Errorf("failed to get session: %w", err)
172 }
173
174 msgs, err := a.getSessionMessages(ctx, currentSession)
175 if err != nil {
176 return nil, fmt.Errorf("failed to get session messages: %w", err)
177 }
178
179 var wg sync.WaitGroup
180 // Generate title if first message.
181 if len(msgs) == 0 {
182 titleCtx := ctx // Copy to avoid race with ctx reassignment below.
183 wg.Go(func() {
184 a.generateTitle(titleCtx, call.SessionID, call.Prompt)
185 })
186 }
187
188 // Add the user message to the session.
189 _, err = a.createUserMessage(ctx, call)
190 if err != nil {
191 return nil, err
192 }
193
194 // Add the session to the context.
195 ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)
196
197 genCtx, cancel := context.WithCancel(ctx)
198 a.activeRequests.Set(call.SessionID, cancel)
199
200 defer cancel()
201 defer a.activeRequests.Del(call.SessionID)
202
203 history, files := a.preparePrompt(msgs, call.Attachments...)
204
205 startTime := time.Now()
206 a.eventPromptSent(call.SessionID)
207
208 var currentAssistant *message.Message
209 var shouldSummarize bool
210 result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
211 Prompt: message.PromptWithTextAttachments(call.Prompt, call.Attachments),
212 Files: files,
213 Messages: history,
214 ProviderOptions: call.ProviderOptions,
215 MaxOutputTokens: &call.MaxOutputTokens,
216 TopP: call.TopP,
217 Temperature: call.Temperature,
218 PresencePenalty: call.PresencePenalty,
219 TopK: call.TopK,
220 FrequencyPenalty: call.FrequencyPenalty,
221 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
222 prepared.Messages = options.Messages
223 for i := range prepared.Messages {
224 prepared.Messages[i].ProviderOptions = nil
225 }
226
227 queuedCalls, _ := a.messageQueue.Get(call.SessionID)
228 a.messageQueue.Del(call.SessionID)
229 for _, queued := range queuedCalls {
230 userMessage, createErr := a.createUserMessage(callContext, queued)
231 if createErr != nil {
232 return callContext, prepared, createErr
233 }
234 prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
235 }
236
237 prepared.Messages = a.workaroundProviderMediaLimitations(prepared.Messages)
238
239 lastSystemRoleInx := 0
240 systemMessageUpdated := false
241 for i, msg := range prepared.Messages {
242 // Only add cache control to the last message.
243 if msg.Role == fantasy.MessageRoleSystem {
244 lastSystemRoleInx = i
245 } else if !systemMessageUpdated {
246 prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
247 systemMessageUpdated = true
248 }
249 // Than add cache control to the last 2 messages.
250 if i > len(prepared.Messages)-3 {
251 prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
252 }
253 }
254
255 if promptPrefix := a.promptPrefix(); promptPrefix != "" {
256 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(promptPrefix)}, prepared.Messages...)
257 }
258
259 var assistantMsg message.Message
260 assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
261 Role: message.Assistant,
262 Parts: []message.ContentPart{},
263 Model: a.largeModel.ModelCfg.Model,
264 Provider: a.largeModel.ModelCfg.Provider,
265 })
266 if err != nil {
267 return callContext, prepared, err
268 }
269 callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
270 callContext = context.WithValue(callContext, tools.SupportsImagesContextKey, a.largeModel.CatwalkCfg.SupportsImages)
271 callContext = context.WithValue(callContext, tools.ModelNameContextKey, a.largeModel.CatwalkCfg.Name)
272 currentAssistant = &assistantMsg
273 return callContext, prepared, err
274 },
275 OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
276 currentAssistant.AppendReasoningContent(reasoning.Text)
277 return a.messages.Update(genCtx, *currentAssistant)
278 },
279 OnReasoningDelta: func(id string, text string) error {
280 currentAssistant.AppendReasoningContent(text)
281 return a.messages.Update(genCtx, *currentAssistant)
282 },
283 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
284 // handle anthropic signature
285 if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
286 if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
287 currentAssistant.AppendReasoningSignature(reasoning.Signature)
288 }
289 }
290 if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
291 if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
292 currentAssistant.AppendThoughtSignature(reasoning.Signature, reasoning.ToolID)
293 }
294 }
295 if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
296 if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
297 currentAssistant.SetReasoningResponsesData(reasoning)
298 }
299 }
300 currentAssistant.FinishThinking()
301 return a.messages.Update(genCtx, *currentAssistant)
302 },
303 OnTextDelta: func(id string, text string) error {
304 // Strip leading newline from initial text content. This is is
305 // particularly important in non-interactive mode where leading
306 // newlines are very visible.
307 if len(currentAssistant.Parts) == 0 {
308 text = strings.TrimPrefix(text, "\n")
309 }
310
311 currentAssistant.AppendContent(text)
312 return a.messages.Update(genCtx, *currentAssistant)
313 },
314 OnToolInputStart: func(id string, toolName string) error {
315 toolCall := message.ToolCall{
316 ID: id,
317 Name: toolName,
318 ProviderExecuted: false,
319 Finished: false,
320 }
321 currentAssistant.AddToolCall(toolCall)
322 return a.messages.Update(genCtx, *currentAssistant)
323 },
324 OnRetry: func(err *fantasy.ProviderError, delay time.Duration) {
325 // TODO: implement
326 },
327 OnToolCall: func(tc fantasy.ToolCallContent) error {
328 toolCall := message.ToolCall{
329 ID: tc.ToolCallID,
330 Name: tc.ToolName,
331 Input: tc.Input,
332 ProviderExecuted: false,
333 Finished: true,
334 }
335 currentAssistant.AddToolCall(toolCall)
336 return a.messages.Update(genCtx, *currentAssistant)
337 },
338 OnToolResult: func(result fantasy.ToolResultContent) error {
339 toolResult := a.convertToToolResult(result)
340 _, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
341 Role: message.Tool,
342 Parts: []message.ContentPart{
343 toolResult,
344 },
345 })
346 return createMsgErr
347 },
348 OnStepFinish: func(stepResult fantasy.StepResult) error {
349 finishReason := message.FinishReasonUnknown
350 switch stepResult.FinishReason {
351 case fantasy.FinishReasonLength:
352 finishReason = message.FinishReasonMaxTokens
353 case fantasy.FinishReasonStop:
354 finishReason = message.FinishReasonEndTurn
355 case fantasy.FinishReasonToolCalls:
356 finishReason = message.FinishReasonToolUse
357 }
358 currentAssistant.AddFinish(finishReason, "", "")
359 sessionLock.Lock()
360 updatedSession, getSessionErr := a.sessions.Get(genCtx, call.SessionID)
361 if getSessionErr != nil {
362 sessionLock.Unlock()
363 return getSessionErr
364 }
365 a.updateSessionUsage(a.largeModel, &updatedSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
366 _, sessionErr := a.sessions.Save(genCtx, updatedSession)
367 sessionLock.Unlock()
368 if sessionErr != nil {
369 return sessionErr
370 }
371 return a.messages.Update(genCtx, *currentAssistant)
372 },
373 StopWhen: []fantasy.StopCondition{
374 func(_ []fantasy.StepResult) bool {
375 cw := int64(a.largeModel.CatwalkCfg.ContextWindow)
376 tokens := currentSession.CompletionTokens + currentSession.PromptTokens
377 remaining := cw - tokens
378 var threshold int64
379 if cw > 200_000 {
380 threshold = 20_000
381 } else {
382 threshold = int64(float64(cw) * 0.2)
383 }
384 if (remaining <= threshold) && !a.disableAutoSummarize {
385 shouldSummarize = true
386 return true
387 }
388 return false
389 },
390 },
391 })
392
393 a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))
394
395 if err != nil {
396 isCancelErr := errors.Is(err, context.Canceled)
397 isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
398 if currentAssistant == nil {
399 return result, err
400 }
401 // Ensure we finish thinking on error to close the reasoning state.
402 currentAssistant.FinishThinking()
403 toolCalls := currentAssistant.ToolCalls()
404 // INFO: we use the parent context here because the genCtx has been cancelled.
405 msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
406 if createErr != nil {
407 return nil, createErr
408 }
409 for _, tc := range toolCalls {
410 if !tc.Finished {
411 tc.Finished = true
412 tc.Input = "{}"
413 currentAssistant.AddToolCall(tc)
414 updateErr := a.messages.Update(ctx, *currentAssistant)
415 if updateErr != nil {
416 return nil, updateErr
417 }
418 }
419
420 found := false
421 for _, msg := range msgs {
422 if msg.Role == message.Tool {
423 for _, tr := range msg.ToolResults() {
424 if tr.ToolCallID == tc.ID {
425 found = true
426 break
427 }
428 }
429 }
430 if found {
431 break
432 }
433 }
434 if found {
435 continue
436 }
437 content := "There was an error while executing the tool"
438 if isCancelErr {
439 content = "Tool execution canceled by user"
440 } else if isPermissionErr {
441 content = "User denied permission"
442 }
443 toolResult := message.ToolResult{
444 ToolCallID: tc.ID,
445 Name: tc.Name,
446 Content: content,
447 IsError: true,
448 }
449 _, createErr = a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
450 Role: message.Tool,
451 Parts: []message.ContentPart{
452 toolResult,
453 },
454 })
455 if createErr != nil {
456 return nil, createErr
457 }
458 }
459 var fantasyErr *fantasy.Error
460 var providerErr *fantasy.ProviderError
461 const defaultTitle = "Provider Error"
462 if isCancelErr {
463 currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
464 } else if isPermissionErr {
465 currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "")
466 } else if errors.Is(err, hyper.ErrNoCredits) {
467 url := hyper.BaseURL()
468 link := lipgloss.NewStyle().Hyperlink(url, "id=hyper").Render(url)
469 currentAssistant.AddFinish(message.FinishReasonError, "No credits", "You're out of credits. Add more at "+link)
470 } else if errors.As(err, &providerErr) {
471 if providerErr.Message == "The requested model is not supported." {
472 url := "https://github.com/settings/copilot/features"
473 link := lipgloss.NewStyle().Hyperlink(url, "id=hyper").Render(url)
474 currentAssistant.AddFinish(
475 message.FinishReasonError,
476 "Copilot model not enabled",
477 fmt.Sprintf("%q is not enabled in Copilot. Go to the following page to enable it. Then, wait 5 minutes before trying again. %s", a.largeModel.CatwalkCfg.Name, link),
478 )
479 } else {
480 currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message)
481 }
482 } else if errors.As(err, &fantasyErr) {
483 currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(fantasyErr.Title), defaultTitle), fantasyErr.Message)
484 } else {
485 currentAssistant.AddFinish(message.FinishReasonError, defaultTitle, err.Error())
486 }
487 // Note: we use the parent context here because the genCtx has been
488 // cancelled.
489 updateErr := a.messages.Update(ctx, *currentAssistant)
490 if updateErr != nil {
491 return nil, updateErr
492 }
493 return nil, err
494 }
495 wg.Wait()
496
497 // Send notification that agent has finished its turn (skip for nested/non-interactive sessions).
498 if !call.NonInteractive {
499 notifBody := fmt.Sprintf("Agent's turn completed in \"%s\"", currentSession.Title)
500 _ = notification.Send("Crush is waiting...", notifBody)
501 }
502
503 if shouldSummarize {
504 a.activeRequests.Del(call.SessionID)
505 if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
506 return nil, summarizeErr
507 }
508 // If the agent wasn't done...
509 if len(currentAssistant.ToolCalls()) > 0 {
510 existing, ok := a.messageQueue.Get(call.SessionID)
511 if !ok {
512 existing = []SessionAgentCall{}
513 }
514 call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
515 existing = append(existing, call)
516 a.messageQueue.Set(call.SessionID, existing)
517 }
518 }
519
520 // Release active request before processing queued messages.
521 a.activeRequests.Del(call.SessionID)
522 cancel()
523
524 queuedMessages, ok := a.messageQueue.Get(call.SessionID)
525 if !ok || len(queuedMessages) == 0 {
526 return result, err
527 }
528 // There are queued messages restart the loop.
529 firstQueuedMessage := queuedMessages[0]
530 a.messageQueue.Set(call.SessionID, queuedMessages[1:])
531 return a.Run(ctx, firstQueuedMessage)
532}
533
534func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
535 if a.IsSessionBusy(sessionID) {
536 return ErrSessionBusy
537 }
538
539 currentSession, err := a.sessions.Get(ctx, sessionID)
540 if err != nil {
541 return fmt.Errorf("failed to get session: %w", err)
542 }
543 msgs, err := a.getSessionMessages(ctx, currentSession)
544 if err != nil {
545 return err
546 }
547 if len(msgs) == 0 {
548 // Nothing to summarize.
549 return nil
550 }
551
552 aiMsgs, _ := a.preparePrompt(msgs)
553
554 genCtx, cancel := context.WithCancel(ctx)
555 a.activeRequests.Set(sessionID, cancel)
556 defer a.activeRequests.Del(sessionID)
557 defer cancel()
558
559 agent := fantasy.NewAgent(a.largeModel.Model,
560 fantasy.WithSystemPrompt(string(summaryPrompt)),
561 )
562 summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
563 Role: message.Assistant,
564 Model: a.largeModel.Model.Model(),
565 Provider: a.largeModel.Model.Provider(),
566 IsSummaryMessage: true,
567 })
568 if err != nil {
569 return err
570 }
571
572 summaryPromptText := "Provide a detailed summary of our conversation above."
573 if len(currentSession.Todos) > 0 {
574 summaryPromptText += "\n\n## Current Todo List\n\n"
575 for _, t := range currentSession.Todos {
576 summaryPromptText += fmt.Sprintf("- [%s] %s\n", t.Status, t.Content)
577 }
578 summaryPromptText += "\nInclude these tasks and their statuses in your summary. "
579 summaryPromptText += "Instruct the resuming assistant to use the `todos` tool to continue tracking progress on these tasks."
580 }
581
582 resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
583 Prompt: summaryPromptText,
584 Messages: aiMsgs,
585 ProviderOptions: opts,
586 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
587 prepared.Messages = options.Messages
588 if a.systemPromptPrefix != "" {
589 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
590 }
591 return callContext, prepared, nil
592 },
593 OnReasoningDelta: func(id string, text string) error {
594 summaryMessage.AppendReasoningContent(text)
595 return a.messages.Update(genCtx, summaryMessage)
596 },
597 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
598 // Handle anthropic signature.
599 if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
600 if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
601 summaryMessage.AppendReasoningSignature(signature.Signature)
602 }
603 }
604 summaryMessage.FinishThinking()
605 return a.messages.Update(genCtx, summaryMessage)
606 },
607 OnTextDelta: func(id, text string) error {
608 summaryMessage.AppendContent(text)
609 return a.messages.Update(genCtx, summaryMessage)
610 },
611 })
612 if err != nil {
613 isCancelErr := errors.Is(err, context.Canceled)
614 if isCancelErr {
615 // User cancelled summarize we need to remove the summary message.
616 deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
617 return deleteErr
618 }
619 return err
620 }
621
622 summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
623 err = a.messages.Update(genCtx, summaryMessage)
624 if err != nil {
625 return err
626 }
627
628 var openrouterCost *float64
629 for _, step := range resp.Steps {
630 stepCost := a.openrouterCost(step.ProviderMetadata)
631 if stepCost != nil {
632 newCost := *stepCost
633 if openrouterCost != nil {
634 newCost += *openrouterCost
635 }
636 openrouterCost = &newCost
637 }
638 }
639
640 a.updateSessionUsage(a.largeModel, ¤tSession, resp.TotalUsage, openrouterCost)
641
642 // Just in case, get just the last usage info.
643 usage := resp.Response.Usage
644 currentSession.SummaryMessageID = summaryMessage.ID
645 currentSession.CompletionTokens = usage.OutputTokens
646 currentSession.PromptTokens = 0
647 _, err = a.sessions.Save(genCtx, currentSession)
648 return err
649}
650
651func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
652 if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
653 return fantasy.ProviderOptions{}
654 }
655 return fantasy.ProviderOptions{
656 anthropic.Name: &anthropic.ProviderCacheControlOptions{
657 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
658 },
659 bedrock.Name: &anthropic.ProviderCacheControlOptions{
660 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
661 },
662 }
663}
664
665func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
666 parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
667 var attachmentParts []message.ContentPart
668 for _, attachment := range call.Attachments {
669 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
670 }
671 parts = append(parts, attachmentParts...)
672 msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
673 Role: message.User,
674 Parts: parts,
675 })
676 if err != nil {
677 return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
678 }
679 return msg, nil
680}
681
682func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
683 var history []fantasy.Message
684 if !a.isSubAgent {
685 history = append(history, fantasy.NewUserMessage(
686 fmt.Sprintf("<system_reminder>%s</system_reminder>",
687 `This is a reminder that your todo list is currently empty. DO NOT mention this to the user explicitly because they are already aware.
688If you are working on tasks that would benefit from a todo list please use the "todos" tool to create one.
689If not, please feel free to ignore. Again do not mention this message to the user.`,
690 ),
691 ))
692 }
693 for _, m := range msgs {
694 if len(m.Parts) == 0 {
695 continue
696 }
697 // Assistant message without content or tool calls (cancelled before it
698 // returned anything).
699 if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
700 continue
701 }
702 history = append(history, m.ToAIMessage()...)
703 }
704
705 var files []fantasy.FilePart
706 for _, attachment := range attachments {
707 if attachment.IsText() {
708 continue
709 }
710 files = append(files, fantasy.FilePart{
711 Filename: attachment.FileName,
712 Data: attachment.Content,
713 MediaType: attachment.MimeType,
714 })
715 }
716
717 return history, files
718}
719
720func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
721 msgs, err := a.messages.List(ctx, session.ID)
722 if err != nil {
723 return nil, fmt.Errorf("failed to list messages: %w", err)
724 }
725
726 if session.SummaryMessageID != "" {
727 summaryMsgInex := -1
728 for i, msg := range msgs {
729 if msg.ID == session.SummaryMessageID {
730 summaryMsgInex = i
731 break
732 }
733 }
734 if summaryMsgInex != -1 {
735 msgs = msgs[summaryMsgInex:]
736 msgs[0].Role = message.User
737 }
738 }
739 return msgs, nil
740}
741
742// generateTitle generates a session titled based on the initial prompt.
743func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, userPrompt string) {
744 if userPrompt == "" {
745 return
746 }
747
748 var maxOutputTokens int64 = 40
749 if a.smallModel.CatwalkCfg.CanReason {
750 maxOutputTokens = a.smallModel.CatwalkCfg.DefaultMaxTokens
751 }
752
753 newAgent := func(m fantasy.LanguageModel, p []byte, tok int64) fantasy.Agent {
754 return fantasy.NewAgent(m,
755 fantasy.WithSystemPrompt(string(p)+"\n /no_think"),
756 fantasy.WithMaxOutputTokens(tok),
757 )
758 }
759
760 streamCall := fantasy.AgentStreamCall{
761 Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", userPrompt),
762 PrepareStep: func(callCtx context.Context, opts fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
763 prepared.Messages = opts.Messages
764 if a.systemPromptPrefix != "" {
765 prepared.Messages = append([]fantasy.Message{
766 fantasy.NewSystemMessage(a.systemPromptPrefix),
767 }, prepared.Messages...)
768 }
769 return callCtx, prepared, nil
770 },
771 }
772
773 // Use the small model to generate the title.
774 model := &a.smallModel
775 agent := newAgent(model.Model, titlePrompt, maxOutputTokens)
776 resp, err := agent.Stream(ctx, streamCall)
777 if err == nil {
778 // We successfully generated a title with the small model.
779 slog.Info("generated title with small model")
780 } else {
781 // It didn't work. Let's try with the big model.
782 slog.Error("error generating title with small model; trying big model", "err", err)
783 model = &a.largeModel
784 agent = newAgent(model.Model, titlePrompt, maxOutputTokens)
785 resp, err = agent.Stream(ctx, streamCall)
786 if err == nil {
787 slog.Info("generated title with large model")
788 } else {
789 // Welp, the large model didn't work either. Use the default
790 // session name and return.
791 slog.Error("error generating title with large model", "err", err)
792 saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, defaultSessionName, 0, 0, 0)
793 if saveErr != nil {
794 slog.Error("failed to save session title and usage", "error", saveErr)
795 }
796 return
797 }
798 }
799
800 if resp == nil {
801 // Actually, we didn't get a response so we can't. Use the default
802 // session name and return.
803 slog.Error("response is nil; can't generate title")
804 saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, defaultSessionName, 0, 0, 0)
805 if saveErr != nil {
806 slog.Error("failed to save session title and usage", "error", saveErr)
807 }
808 return
809 }
810
811 // Clean up title.
812 var title string
813 title = strings.ReplaceAll(resp.Response.Content.Text(), "\n", " ")
814 slog.Info("generated title", "title", title)
815
816 // Remove thinking tags if present.
817 title = thinkTagRegex.ReplaceAllString(title, "")
818
819 title = strings.TrimSpace(title)
820 if title == "" {
821 slog.Warn("empty title; using fallback")
822 title = defaultSessionName
823 }
824
825 // Calculate usage and cost.
826 var openrouterCost *float64
827 for _, step := range resp.Steps {
828 stepCost := a.openrouterCost(step.ProviderMetadata)
829 if stepCost != nil {
830 newCost := *stepCost
831 if openrouterCost != nil {
832 newCost += *openrouterCost
833 }
834 openrouterCost = &newCost
835 }
836 }
837
838 modelConfig := model.CatwalkCfg
839 cost := modelConfig.CostPer1MInCached/1e6*float64(resp.TotalUsage.CacheCreationTokens) +
840 modelConfig.CostPer1MOutCached/1e6*float64(resp.TotalUsage.CacheReadTokens) +
841 modelConfig.CostPer1MIn/1e6*float64(resp.TotalUsage.InputTokens) +
842 modelConfig.CostPer1MOut/1e6*float64(resp.TotalUsage.OutputTokens)
843
844 // Use override cost if available (e.g., from OpenRouter).
845 if openrouterCost != nil {
846 cost = *openrouterCost
847 }
848
849 promptTokens := resp.TotalUsage.InputTokens + resp.TotalUsage.CacheCreationTokens
850 completionTokens := resp.TotalUsage.OutputTokens + resp.TotalUsage.CacheReadTokens
851
852 // Atomically update only title and usage fields to avoid overriding other
853 // concurrent session updates.
854 saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, title, promptTokens, completionTokens, cost)
855 if saveErr != nil {
856 slog.Error("failed to save session title and usage", "error", saveErr)
857 return
858 }
859}
860
861func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
862 openrouterMetadata, ok := metadata[openrouter.Name]
863 if !ok {
864 return nil
865 }
866
867 opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
868 if !ok {
869 return nil
870 }
871 return &opts.Usage.Cost
872}
873
874func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
875 modelConfig := model.CatwalkCfg
876 cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
877 modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
878 modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
879 modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
880
881 a.eventTokensUsed(session.ID, model, usage, cost)
882
883 if overrideCost != nil {
884 session.Cost += *overrideCost
885 } else {
886 session.Cost += cost
887 }
888
889 session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
890 session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
891}
892
893func (a *sessionAgent) Cancel(sessionID string) {
894 // Cancel regular requests.
895 if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
896 slog.Info("Request cancellation initiated", "session_id", sessionID)
897 cancel()
898 }
899
900 // Also check for summarize requests.
901 if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
902 slog.Info("Summarize cancellation initiated", "session_id", sessionID)
903 cancel()
904 }
905
906 if a.QueuedPrompts(sessionID) > 0 {
907 slog.Info("Clearing queued prompts", "session_id", sessionID)
908 a.messageQueue.Del(sessionID)
909 }
910}
911
912func (a *sessionAgent) ClearQueue(sessionID string) {
913 if a.QueuedPrompts(sessionID) > 0 {
914 slog.Info("Clearing queued prompts", "session_id", sessionID)
915 a.messageQueue.Del(sessionID)
916 }
917}
918
919func (a *sessionAgent) CancelAll() {
920 if !a.IsBusy() {
921 return
922 }
923 for key := range a.activeRequests.Seq2() {
924 a.Cancel(key) // key is sessionID
925 }
926
927 timeout := time.After(5 * time.Second)
928 for a.IsBusy() {
929 select {
930 case <-timeout:
931 return
932 default:
933 time.Sleep(200 * time.Millisecond)
934 }
935 }
936}
937
938func (a *sessionAgent) IsBusy() bool {
939 var busy bool
940 for cancelFunc := range a.activeRequests.Seq() {
941 if cancelFunc != nil {
942 busy = true
943 break
944 }
945 }
946 return busy
947}
948
949func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
950 _, busy := a.activeRequests.Get(sessionID)
951 return busy
952}
953
954func (a *sessionAgent) QueuedPrompts(sessionID string) int {
955 l, ok := a.messageQueue.Get(sessionID)
956 if !ok {
957 return 0
958 }
959 return len(l)
960}
961
962func (a *sessionAgent) QueuedPromptsList(sessionID string) []string {
963 l, ok := a.messageQueue.Get(sessionID)
964 if !ok {
965 return nil
966 }
967 prompts := make([]string, len(l))
968 for i, call := range l {
969 prompts[i] = call.Prompt
970 }
971 return prompts
972}
973
974func (a *sessionAgent) SetModels(large Model, small Model) {
975 a.largeModel = large
976 a.smallModel = small
977}
978
979func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
980 a.tools = tools
981}
982
983func (a *sessionAgent) Model() Model {
984 return a.largeModel
985}
986
987func (a *sessionAgent) promptPrefix() string {
988 return a.systemPromptPrefix
989}
990
991// convertToToolResult converts a fantasy tool result to a message tool result.
992func (a *sessionAgent) convertToToolResult(result fantasy.ToolResultContent) message.ToolResult {
993 baseResult := message.ToolResult{
994 ToolCallID: result.ToolCallID,
995 Name: result.ToolName,
996 Metadata: result.ClientMetadata,
997 }
998
999 switch result.Result.GetType() {
1000 case fantasy.ToolResultContentTypeText:
1001 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result); ok {
1002 baseResult.Content = r.Text
1003 }
1004 case fantasy.ToolResultContentTypeError:
1005 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result); ok {
1006 baseResult.Content = r.Error.Error()
1007 baseResult.IsError = true
1008 }
1009 case fantasy.ToolResultContentTypeMedia:
1010 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](result.Result); ok {
1011 content := r.Text
1012 if content == "" {
1013 content = fmt.Sprintf("Loaded %s content", r.MediaType)
1014 }
1015 baseResult.Content = content
1016 baseResult.Data = r.Data
1017 baseResult.MIMEType = r.MediaType
1018 }
1019 }
1020
1021 return baseResult
1022}
1023
1024// workaroundProviderMediaLimitations converts media content in tool results to
1025// user messages for providers that don't natively support images in tool results.
1026//
1027// Problem: OpenAI, Google, OpenRouter, and other OpenAI-compatible providers
1028// don't support sending images/media in tool result messages - they only accept
1029// text in tool results. However, they DO support images in user messages.
1030//
1031// If we send media in tool results to these providers, the API returns an error.
1032//
1033// Solution: For these providers, we:
1034// 1. Replace the media in the tool result with a text placeholder
1035// 2. Inject a user message immediately after with the image as a file attachment
1036// 3. This maintains the tool execution flow while working around API limitations
1037//
1038// Anthropic and Bedrock support images natively in tool results, so we skip
1039// this workaround for them.
1040//
1041// Example transformation:
1042//
1043// BEFORE: [tool result: image data]
1044// AFTER: [tool result: "Image loaded - see attached"], [user: image attachment]
1045func (a *sessionAgent) workaroundProviderMediaLimitations(messages []fantasy.Message) []fantasy.Message {
1046 providerSupportsMedia := a.largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderAnthropic) ||
1047 a.largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderBedrock)
1048
1049 if providerSupportsMedia {
1050 return messages
1051 }
1052
1053 convertedMessages := make([]fantasy.Message, 0, len(messages))
1054
1055 for _, msg := range messages {
1056 if msg.Role != fantasy.MessageRoleTool {
1057 convertedMessages = append(convertedMessages, msg)
1058 continue
1059 }
1060
1061 textParts := make([]fantasy.MessagePart, 0, len(msg.Content))
1062 var mediaFiles []fantasy.FilePart
1063
1064 for _, part := range msg.Content {
1065 toolResult, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
1066 if !ok {
1067 textParts = append(textParts, part)
1068 continue
1069 }
1070
1071 if media, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](toolResult.Output); ok {
1072 decoded, err := base64.StdEncoding.DecodeString(media.Data)
1073 if err != nil {
1074 slog.Warn("failed to decode media data", "error", err)
1075 textParts = append(textParts, part)
1076 continue
1077 }
1078
1079 mediaFiles = append(mediaFiles, fantasy.FilePart{
1080 Data: decoded,
1081 MediaType: media.MediaType,
1082 Filename: fmt.Sprintf("tool-result-%s", toolResult.ToolCallID),
1083 })
1084
1085 textParts = append(textParts, fantasy.ToolResultPart{
1086 ToolCallID: toolResult.ToolCallID,
1087 Output: fantasy.ToolResultOutputContentText{
1088 Text: "[Image/media content loaded - see attached file]",
1089 },
1090 ProviderOptions: toolResult.ProviderOptions,
1091 })
1092 } else {
1093 textParts = append(textParts, part)
1094 }
1095 }
1096
1097 convertedMessages = append(convertedMessages, fantasy.Message{
1098 Role: fantasy.MessageRoleTool,
1099 Content: textParts,
1100 })
1101
1102 if len(mediaFiles) > 0 {
1103 convertedMessages = append(convertedMessages, fantasy.NewUserMessage(
1104 "Here is the media content from the tool result:",
1105 mediaFiles...,
1106 ))
1107 }
1108 }
1109
1110 return convertedMessages
1111}