1// Package agent is the core orchestration layer for Crush AI agents.
2//
3// It provides session-based AI agent functionality for managing
4// conversations, tool execution, and message handling. It coordinates
5// interactions between language models, messages, sessions, and tools while
6// handling features like automatic summarization, queuing, and token
7// management.
8package agent
9
10import (
11 "cmp"
12 "context"
13 _ "embed"
14 "encoding/base64"
15 "errors"
16 "fmt"
17 "log/slog"
18 "os"
19 "regexp"
20 "strconv"
21 "strings"
22 "sync"
23 "time"
24
25 "charm.land/fantasy"
26 "charm.land/fantasy/providers/anthropic"
27 "charm.land/fantasy/providers/bedrock"
28 "charm.land/fantasy/providers/google"
29 "charm.land/fantasy/providers/openai"
30 "charm.land/fantasy/providers/openrouter"
31 "charm.land/lipgloss/v2"
32 "git.secluded.site/crush/internal/agent/hyper"
33 "git.secluded.site/crush/internal/agent/tools"
34 "git.secluded.site/crush/internal/config"
35 "git.secluded.site/crush/internal/csync"
36 "git.secluded.site/crush/internal/message"
37 "git.secluded.site/crush/internal/notification"
38 "git.secluded.site/crush/internal/permission"
39 "git.secluded.site/crush/internal/session"
40 "git.secluded.site/crush/internal/stringext"
41 "github.com/charmbracelet/catwalk/pkg/catwalk"
42)
43
// defaultSessionName is the title stored for a session when title generation
// fails or yields an empty string.
const defaultSessionName = "Untitled Session"

// titlePrompt holds the embedded contents of templates/title.md. It is used
// as the system prompt when generating a session title.
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt holds the embedded contents of templates/summary.md. It is
// used as the system prompt when summarizing a session.
//
//go:embed templates/summary.md
var summaryPrompt []byte

// Used to remove <think> tags from generated titles. The pattern is
// non-greedy, so each <think>...</think> pair is matched separately.
var thinkTagRegex = regexp.MustCompile(`<think>.*?</think>`)
54
// SessionAgentCall describes a single prompt submitted to a session agent.
type SessionAgentCall struct {
	// SessionID identifies the session the prompt belongs to. Run rejects
	// calls where it is empty.
	SessionID string
	// Prompt is the user's text. Run rejects calls where it is empty.
	Prompt string
	// ProviderOptions is forwarded unchanged to the model stream call.
	ProviderOptions fantasy.ProviderOptions
	// Attachments are stored on the user message; non-text attachments are
	// also sent to the model as files.
	Attachments []message.Attachment
	// MaxOutputTokens is forwarded to the model stream call.
	MaxOutputTokens int64
	// The sampling parameters below are forwarded to the model stream call
	// as-is; nil pointers are passed through as nil.
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
	// NonInteractive suppresses the "turn completed" notification that Run
	// otherwise sends when the agent finishes.
	NonInteractive bool
}
68
// SessionAgent is the interface through which callers drive an agent that
// works on sessions: submitting prompts, cancelling work, inspecting the
// per-session prompt queue, and summarizing conversations.
type SessionAgent interface {
	// Run submits a prompt to a session and returns the agent result.
	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
	// SetModels replaces the large and small models used by the agent.
	SetModels(large Model, small Model)
	// SetTools replaces the set of tools offered to the model.
	SetTools(tools []fantasy.AgentTool)
	// Cancel stops the active request for the given session, if any.
	Cancel(sessionID string)
	// CancelAll stops the active requests of every session.
	CancelAll()
	// IsSessionBusy reports whether the given session has an active request.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any session has an active request.
	IsBusy() bool
	// QueuedPrompts returns the number of prompts queued for the session.
	QueuedPrompts(sessionID string) int
	// QueuedPromptsList returns the text of the prompts queued for the
	// session.
	QueuedPromptsList(sessionID string) []string
	// ClearQueue discards the prompts queued for the session.
	ClearQueue(sessionID string)
	// Summarize produces a summary of the session identified by the string
	// argument, using the given provider options.
	Summarize(context.Context, string, fantasy.ProviderOptions) error
	// Model returns the model the agent uses for its main requests.
	Model() Model
}
83
// Model bundles a language model with the catalog and configuration data the
// agent reads alongside it.
type Model struct {
	// Model is the language model that requests are sent to.
	Model fantasy.LanguageModel
	// CatwalkCfg is the catalog entry for the model; the agent reads its
	// name, context window, pricing, and capability flags from it.
	CatwalkCfg catwalk.Model
	// ModelCfg is the selected-model configuration; the agent reads the model
	// and provider identifiers from it.
	ModelCfg config.SelectedModel
}
89
// sessionAgent is the SessionAgent implementation returned by
// NewSessionAgent.
type sessionAgent struct {
	// largeModel serves main and summarize requests; smallModel is tried
	// first for title generation.
	largeModel Model
	smallModel Model
	// systemPromptPrefix, when non-empty, is prepended as a system message.
	systemPromptPrefix string
	// systemPrompt is the system prompt of the main agent.
	systemPrompt string
	// isSubAgent disables the todo-list reminder injected into the history.
	isSubAgent bool
	tools      []fantasy.AgentTool
	sessions   session.Service
	messages   message.Service
	// disableAutoSummarize turns off the automatic summarization triggered
	// when the context window is nearly full.
	disableAutoSummarize bool
	isYolo               bool

	// messageQueue holds, per session ID, the calls that arrived while that
	// session was busy.
	messageQueue *csync.Map[string, []SessionAgentCall]
	// activeRequests maps a session ID to the cancel function of its
	// in-flight request; presence of a key marks the session as busy.
	activeRequests *csync.Map[string, context.CancelFunc]
}
105
// SessionAgentOptions carries the dependencies and settings that
// NewSessionAgent copies into a new agent.
type SessionAgentOptions struct {
	// LargeModel and SmallModel are the initial models of the agent.
	LargeModel Model
	SmallModel Model
	// SystemPromptPrefix, when non-empty, is prepended as a system message.
	SystemPromptPrefix string
	// SystemPrompt is the system prompt of the main agent.
	SystemPrompt string
	// IsSubAgent marks the agent as a sub-agent.
	IsSubAgent bool
	// DisableAutoSummarize turns off automatic summarization.
	DisableAutoSummarize bool
	IsYolo               bool
	// Sessions and Messages are the persistence services the agent uses.
	Sessions session.Service
	Messages message.Service
	// Tools is the initial set of tools offered to the model.
	Tools []fantasy.AgentTool
}
118
119func NewSessionAgent(
120 opts SessionAgentOptions,
121) SessionAgent {
122 return &sessionAgent{
123 largeModel: opts.LargeModel,
124 smallModel: opts.SmallModel,
125 systemPromptPrefix: opts.SystemPromptPrefix,
126 systemPrompt: opts.SystemPrompt,
127 isSubAgent: opts.IsSubAgent,
128 sessions: opts.Sessions,
129 messages: opts.Messages,
130 disableAutoSummarize: opts.DisableAutoSummarize,
131 tools: opts.Tools,
132 isYolo: opts.IsYolo,
133 messageQueue: csync.NewMap[string, []SessionAgentCall](),
134 activeRequests: csync.NewMap[string, context.CancelFunc](),
135 }
136}
137
138func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
139 if call.Prompt == "" {
140 return nil, ErrEmptyPrompt
141 }
142 if call.SessionID == "" {
143 return nil, ErrSessionMissing
144 }
145
146 // Queue the message if busy
147 if a.IsSessionBusy(call.SessionID) {
148 existing, ok := a.messageQueue.Get(call.SessionID)
149 if !ok {
150 existing = []SessionAgentCall{}
151 }
152 existing = append(existing, call)
153 a.messageQueue.Set(call.SessionID, existing)
154 return nil, nil
155 }
156
157 if len(a.tools) > 0 {
158 // Add Anthropic caching to the last tool.
159 a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
160 }
161
162 agent := fantasy.NewAgent(
163 a.largeModel.Model,
164 fantasy.WithSystemPrompt(a.systemPrompt),
165 fantasy.WithTools(a.tools...),
166 )
167
168 sessionLock := sync.Mutex{}
169 currentSession, err := a.sessions.Get(ctx, call.SessionID)
170 if err != nil {
171 return nil, fmt.Errorf("failed to get session: %w", err)
172 }
173
174 msgs, err := a.getSessionMessages(ctx, currentSession)
175 if err != nil {
176 return nil, fmt.Errorf("failed to get session messages: %w", err)
177 }
178
179 var wg sync.WaitGroup
180 // Generate title if first message.
181 if len(msgs) == 0 {
182 titleCtx := ctx // Copy to avoid race with ctx reassignment below.
183 wg.Go(func() {
184 a.generateTitle(titleCtx, call.SessionID, call.Prompt)
185 })
186 }
187
188 // Add the user message to the session.
189 _, err = a.createUserMessage(ctx, call)
190 if err != nil {
191 return nil, err
192 }
193
194 // Add the session to the context.
195 ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)
196
197 genCtx, cancel := context.WithCancel(ctx)
198 a.activeRequests.Set(call.SessionID, cancel)
199
200 defer cancel()
201 defer a.activeRequests.Del(call.SessionID)
202
203 history, files := a.preparePrompt(msgs, call.Attachments...)
204
205 startTime := time.Now()
206 a.eventPromptSent(call.SessionID)
207
208 var currentAssistant *message.Message
209 var shouldSummarize bool
210 result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
211 Prompt: message.PromptWithTextAttachments(call.Prompt, call.Attachments),
212 Files: files,
213 Messages: history,
214 ProviderOptions: call.ProviderOptions,
215 MaxOutputTokens: &call.MaxOutputTokens,
216 TopP: call.TopP,
217 Temperature: call.Temperature,
218 PresencePenalty: call.PresencePenalty,
219 TopK: call.TopK,
220 FrequencyPenalty: call.FrequencyPenalty,
221 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
222 prepared.Messages = options.Messages
223 for i := range prepared.Messages {
224 prepared.Messages[i].ProviderOptions = nil
225 }
226
227 queuedCalls, _ := a.messageQueue.Get(call.SessionID)
228 a.messageQueue.Del(call.SessionID)
229 for _, queued := range queuedCalls {
230 userMessage, createErr := a.createUserMessage(callContext, queued)
231 if createErr != nil {
232 return callContext, prepared, createErr
233 }
234 prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
235 }
236
237 prepared.Messages = a.workaroundProviderMediaLimitations(prepared.Messages)
238
239 lastSystemRoleInx := 0
240 systemMessageUpdated := false
241 for i, msg := range prepared.Messages {
242 // Only add cache control to the last message.
243 if msg.Role == fantasy.MessageRoleSystem {
244 lastSystemRoleInx = i
245 } else if !systemMessageUpdated {
246 prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
247 systemMessageUpdated = true
248 }
249 // Than add cache control to the last 2 messages.
250 if i > len(prepared.Messages)-3 {
251 prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
252 }
253 }
254
255 if promptPrefix := a.promptPrefix(); promptPrefix != "" {
256 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(promptPrefix)}, prepared.Messages...)
257 }
258
259 var assistantMsg message.Message
260 assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
261 Role: message.Assistant,
262 Parts: []message.ContentPart{},
263 Model: a.largeModel.ModelCfg.Model,
264 Provider: a.largeModel.ModelCfg.Provider,
265 })
266 if err != nil {
267 return callContext, prepared, err
268 }
269 callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
270 callContext = context.WithValue(callContext, tools.SupportsImagesContextKey, a.largeModel.CatwalkCfg.SupportsImages)
271 callContext = context.WithValue(callContext, tools.ModelNameContextKey, a.largeModel.CatwalkCfg.Name)
272 currentAssistant = &assistantMsg
273 return callContext, prepared, err
274 },
275 OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
276 currentAssistant.AppendReasoningContent(reasoning.Text)
277 return a.messages.Update(genCtx, *currentAssistant)
278 },
279 OnReasoningDelta: func(id string, text string) error {
280 currentAssistant.AppendReasoningContent(text)
281 return a.messages.Update(genCtx, *currentAssistant)
282 },
283 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
284 // handle anthropic signature
285 if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
286 if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
287 currentAssistant.AppendReasoningSignature(reasoning.Signature)
288 }
289 }
290 if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
291 if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
292 currentAssistant.AppendThoughtSignature(reasoning.Signature, reasoning.ToolID)
293 }
294 }
295 if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
296 if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
297 currentAssistant.SetReasoningResponsesData(reasoning)
298 }
299 }
300 currentAssistant.FinishThinking()
301 return a.messages.Update(genCtx, *currentAssistant)
302 },
303 OnTextDelta: func(id string, text string) error {
304 // Strip leading newline from initial text content. This is is
305 // particularly important in non-interactive mode where leading
306 // newlines are very visible.
307 if len(currentAssistant.Parts) == 0 {
308 text = strings.TrimPrefix(text, "\n")
309 }
310
311 currentAssistant.AppendContent(text)
312 return a.messages.Update(genCtx, *currentAssistant)
313 },
314 OnToolInputStart: func(id string, toolName string) error {
315 toolCall := message.ToolCall{
316 ID: id,
317 Name: toolName,
318 ProviderExecuted: false,
319 Finished: false,
320 }
321 currentAssistant.AddToolCall(toolCall)
322 return a.messages.Update(genCtx, *currentAssistant)
323 },
324 OnRetry: func(err *fantasy.ProviderError, delay time.Duration) {
325 // TODO: implement
326 },
327 OnToolCall: func(tc fantasy.ToolCallContent) error {
328 toolCall := message.ToolCall{
329 ID: tc.ToolCallID,
330 Name: tc.ToolName,
331 Input: tc.Input,
332 ProviderExecuted: false,
333 Finished: true,
334 }
335 currentAssistant.AddToolCall(toolCall)
336 return a.messages.Update(genCtx, *currentAssistant)
337 },
338 OnToolResult: func(result fantasy.ToolResultContent) error {
339 toolResult := a.convertToToolResult(result)
340 _, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
341 Role: message.Tool,
342 Parts: []message.ContentPart{
343 toolResult,
344 },
345 })
346 return createMsgErr
347 },
348 OnStepFinish: func(stepResult fantasy.StepResult) error {
349 finishReason := message.FinishReasonUnknown
350 switch stepResult.FinishReason {
351 case fantasy.FinishReasonLength:
352 finishReason = message.FinishReasonMaxTokens
353 case fantasy.FinishReasonStop:
354 finishReason = message.FinishReasonEndTurn
355 case fantasy.FinishReasonToolCalls:
356 finishReason = message.FinishReasonToolUse
357 }
358 currentAssistant.AddFinish(finishReason, "", "")
359 sessionLock.Lock()
360 updatedSession, getSessionErr := a.sessions.Get(genCtx, call.SessionID)
361 if getSessionErr != nil {
362 sessionLock.Unlock()
363 return getSessionErr
364 }
365 a.updateSessionUsage(a.largeModel, &updatedSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
366 _, sessionErr := a.sessions.Save(genCtx, updatedSession)
367 sessionLock.Unlock()
368 if sessionErr != nil {
369 return sessionErr
370 }
371 return a.messages.Update(genCtx, *currentAssistant)
372 },
373 StopWhen: []fantasy.StopCondition{
374 func(steps []fantasy.StepResult) bool {
375 if len(steps) == 0 {
376 return false
377 }
378 lastStep := steps[len(steps)-1]
379 usage := lastStep.Usage
380 tokens := usage.InputTokens + usage.OutputTokens + usage.CacheCreationTokens + usage.CacheReadTokens
381 cw := int64(a.largeModel.CatwalkCfg.ContextWindow)
382 remaining := cw - tokens
383 var threshold int64
384 if cw > 200_000 {
385 threshold = 20_000
386 } else {
387 threshold = int64(float64(cw) * 0.2)
388 }
389 if (remaining <= threshold) && !a.disableAutoSummarize {
390 shouldSummarize = true
391 return true
392 }
393 return false
394 },
395 },
396 })
397
398 a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))
399
400 if err != nil {
401 isCancelErr := errors.Is(err, context.Canceled)
402 isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
403 if currentAssistant == nil {
404 return result, err
405 }
406 // Ensure we finish thinking on error to close the reasoning state.
407 currentAssistant.FinishThinking()
408 toolCalls := currentAssistant.ToolCalls()
409 // INFO: we use the parent context here because the genCtx has been cancelled.
410 msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
411 if createErr != nil {
412 return nil, createErr
413 }
414 for _, tc := range toolCalls {
415 if !tc.Finished {
416 tc.Finished = true
417 tc.Input = "{}"
418 currentAssistant.AddToolCall(tc)
419 updateErr := a.messages.Update(ctx, *currentAssistant)
420 if updateErr != nil {
421 return nil, updateErr
422 }
423 }
424
425 found := false
426 for _, msg := range msgs {
427 if msg.Role == message.Tool {
428 for _, tr := range msg.ToolResults() {
429 if tr.ToolCallID == tc.ID {
430 found = true
431 break
432 }
433 }
434 }
435 if found {
436 break
437 }
438 }
439 if found {
440 continue
441 }
442 content := "There was an error while executing the tool"
443 if isCancelErr {
444 content = "Tool execution canceled by user"
445 } else if isPermissionErr {
446 content = "User denied permission"
447 }
448 toolResult := message.ToolResult{
449 ToolCallID: tc.ID,
450 Name: tc.Name,
451 Content: content,
452 IsError: true,
453 }
454 _, createErr = a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
455 Role: message.Tool,
456 Parts: []message.ContentPart{
457 toolResult,
458 },
459 })
460 if createErr != nil {
461 return nil, createErr
462 }
463 }
464 var fantasyErr *fantasy.Error
465 var providerErr *fantasy.ProviderError
466 const defaultTitle = "Provider Error"
467 if isCancelErr {
468 currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
469 } else if isPermissionErr {
470 currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "")
471 } else if errors.Is(err, hyper.ErrNoCredits) {
472 url := hyper.BaseURL()
473 link := lipgloss.NewStyle().Hyperlink(url, "id=hyper").Render(url)
474 currentAssistant.AddFinish(message.FinishReasonError, "No credits", "You're out of credits. Add more at "+link)
475 } else if errors.As(err, &providerErr) {
476 if providerErr.Message == "The requested model is not supported." {
477 url := "https://github.com/settings/copilot/features"
478 link := lipgloss.NewStyle().Hyperlink(url, "id=hyper").Render(url)
479 currentAssistant.AddFinish(
480 message.FinishReasonError,
481 "Copilot model not enabled",
482 fmt.Sprintf("%q is not enabled in Copilot. Go to the following page to enable it. Then, wait 5 minutes before trying again. %s", a.largeModel.CatwalkCfg.Name, link),
483 )
484 } else {
485 currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message)
486 }
487 } else if errors.As(err, &fantasyErr) {
488 currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(fantasyErr.Title), defaultTitle), fantasyErr.Message)
489 } else {
490 currentAssistant.AddFinish(message.FinishReasonError, defaultTitle, err.Error())
491 }
492 // Note: we use the parent context here because the genCtx has been
493 // cancelled.
494 updateErr := a.messages.Update(ctx, *currentAssistant)
495 if updateErr != nil {
496 return nil, updateErr
497 }
498 return nil, err
499 }
500 wg.Wait()
501
502 // Send notification that agent has finished its turn (skip for nested/non-interactive sessions).
503 if !call.NonInteractive {
504 notifBody := fmt.Sprintf("Agent's turn completed in \"%s\"", currentSession.Title)
505 _ = notification.Send("Crush is waiting...", notifBody)
506 }
507
508 if shouldSummarize {
509 a.activeRequests.Del(call.SessionID)
510 if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
511 return nil, summarizeErr
512 }
513 // If the agent wasn't done...
514 if len(currentAssistant.ToolCalls()) > 0 {
515 existing, ok := a.messageQueue.Get(call.SessionID)
516 if !ok {
517 existing = []SessionAgentCall{}
518 }
519 call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
520 existing = append(existing, call)
521 a.messageQueue.Set(call.SessionID, existing)
522 }
523 }
524
525 // Release active request before processing queued messages.
526 a.activeRequests.Del(call.SessionID)
527 cancel()
528
529 queuedMessages, ok := a.messageQueue.Get(call.SessionID)
530 if !ok || len(queuedMessages) == 0 {
531 return result, err
532 }
533 // There are queued messages restart the loop.
534 firstQueuedMessage := queuedMessages[0]
535 a.messageQueue.Set(call.SessionID, queuedMessages[1:])
536 return a.Run(ctx, firstQueuedMessage)
537}
538
539func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
540 if a.IsSessionBusy(sessionID) {
541 return ErrSessionBusy
542 }
543
544 currentSession, err := a.sessions.Get(ctx, sessionID)
545 if err != nil {
546 return fmt.Errorf("failed to get session: %w", err)
547 }
548 msgs, err := a.getSessionMessages(ctx, currentSession)
549 if err != nil {
550 return err
551 }
552 if len(msgs) == 0 {
553 // Nothing to summarize.
554 return nil
555 }
556
557 aiMsgs, _ := a.preparePrompt(msgs)
558
559 genCtx, cancel := context.WithCancel(ctx)
560 a.activeRequests.Set(sessionID, cancel)
561 defer a.activeRequests.Del(sessionID)
562 defer cancel()
563
564 agent := fantasy.NewAgent(a.largeModel.Model,
565 fantasy.WithSystemPrompt(string(summaryPrompt)),
566 )
567 summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
568 Role: message.Assistant,
569 Model: a.largeModel.Model.Model(),
570 Provider: a.largeModel.Model.Provider(),
571 IsSummaryMessage: true,
572 })
573 if err != nil {
574 return err
575 }
576
577 summaryPromptText := "Provide a detailed summary of our conversation above."
578 if len(currentSession.Todos) > 0 {
579 summaryPromptText += "\n\n## Current Todo List\n\n"
580 for _, t := range currentSession.Todos {
581 summaryPromptText += fmt.Sprintf("- [%s] %s\n", t.Status, t.Content)
582 }
583 summaryPromptText += "\nInclude these tasks and their statuses in your summary. "
584 summaryPromptText += "Instruct the resuming assistant to use the `todos` tool to continue tracking progress on these tasks."
585 }
586
587 resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
588 Prompt: summaryPromptText,
589 Messages: aiMsgs,
590 ProviderOptions: opts,
591 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
592 prepared.Messages = options.Messages
593 if a.systemPromptPrefix != "" {
594 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
595 }
596 return callContext, prepared, nil
597 },
598 OnReasoningDelta: func(id string, text string) error {
599 summaryMessage.AppendReasoningContent(text)
600 return a.messages.Update(genCtx, summaryMessage)
601 },
602 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
603 // Handle anthropic signature.
604 if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
605 if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
606 summaryMessage.AppendReasoningSignature(signature.Signature)
607 }
608 }
609 summaryMessage.FinishThinking()
610 return a.messages.Update(genCtx, summaryMessage)
611 },
612 OnTextDelta: func(id, text string) error {
613 summaryMessage.AppendContent(text)
614 return a.messages.Update(genCtx, summaryMessage)
615 },
616 })
617 if err != nil {
618 isCancelErr := errors.Is(err, context.Canceled)
619 if isCancelErr {
620 // User cancelled summarize we need to remove the summary message.
621 deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
622 return deleteErr
623 }
624 return err
625 }
626
627 summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
628 err = a.messages.Update(genCtx, summaryMessage)
629 if err != nil {
630 return err
631 }
632
633 var openrouterCost *float64
634 for _, step := range resp.Steps {
635 stepCost := a.openrouterCost(step.ProviderMetadata)
636 if stepCost != nil {
637 newCost := *stepCost
638 if openrouterCost != nil {
639 newCost += *openrouterCost
640 }
641 openrouterCost = &newCost
642 }
643 }
644
645 a.updateSessionUsage(a.largeModel, ¤tSession, resp.TotalUsage, openrouterCost)
646
647 // Just in case, get just the last usage info.
648 usage := resp.Response.Usage
649 currentSession.SummaryMessageID = summaryMessage.ID
650 currentSession.CompletionTokens = usage.OutputTokens
651 currentSession.PromptTokens = 0
652 _, err = a.sessions.Save(genCtx, currentSession)
653 return err
654}
655
656func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
657 if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
658 return fantasy.ProviderOptions{}
659 }
660 return fantasy.ProviderOptions{
661 anthropic.Name: &anthropic.ProviderCacheControlOptions{
662 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
663 },
664 bedrock.Name: &anthropic.ProviderCacheControlOptions{
665 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
666 },
667 }
668}
669
670func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
671 parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
672 var attachmentParts []message.ContentPart
673 for _, attachment := range call.Attachments {
674 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
675 }
676 parts = append(parts, attachmentParts...)
677 msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
678 Role: message.User,
679 Parts: parts,
680 })
681 if err != nil {
682 return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
683 }
684 return msg, nil
685}
686
687func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
688 var history []fantasy.Message
689 if !a.isSubAgent {
690 history = append(history, fantasy.NewUserMessage(
691 fmt.Sprintf("<system_reminder>%s</system_reminder>",
692 `This is a reminder that your todo list is currently empty. DO NOT mention this to the user explicitly because they are already aware.
693If you are working on tasks that would benefit from a todo list please use the "todos" tool to create one.
694If not, please feel free to ignore. Again do not mention this message to the user.`,
695 ),
696 ))
697 }
698 for _, m := range msgs {
699 if len(m.Parts) == 0 {
700 continue
701 }
702 // Assistant message without content or tool calls (cancelled before it
703 // returned anything).
704 if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
705 continue
706 }
707 history = append(history, m.ToAIMessage()...)
708 }
709
710 var files []fantasy.FilePart
711 for _, attachment := range attachments {
712 if attachment.IsText() {
713 continue
714 }
715 files = append(files, fantasy.FilePart{
716 Filename: attachment.FileName,
717 Data: attachment.Content,
718 MediaType: attachment.MimeType,
719 })
720 }
721
722 return history, files
723}
724
725func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
726 msgs, err := a.messages.List(ctx, session.ID)
727 if err != nil {
728 return nil, fmt.Errorf("failed to list messages: %w", err)
729 }
730
731 if session.SummaryMessageID != "" {
732 summaryMsgInex := -1
733 for i, msg := range msgs {
734 if msg.ID == session.SummaryMessageID {
735 summaryMsgInex = i
736 break
737 }
738 }
739 if summaryMsgInex != -1 {
740 msgs = msgs[summaryMsgInex:]
741 msgs[0].Role = message.User
742 }
743 }
744 return msgs, nil
745}
746
747// generateTitle generates a session titled based on the initial prompt.
748func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, userPrompt string) {
749 if userPrompt == "" {
750 return
751 }
752
753 var maxOutputTokens int64 = 40
754 if a.smallModel.CatwalkCfg.CanReason {
755 maxOutputTokens = a.smallModel.CatwalkCfg.DefaultMaxTokens
756 }
757
758 newAgent := func(m fantasy.LanguageModel, p []byte, tok int64) fantasy.Agent {
759 return fantasy.NewAgent(m,
760 fantasy.WithSystemPrompt(string(p)+"\n /no_think"),
761 fantasy.WithMaxOutputTokens(tok),
762 )
763 }
764
765 streamCall := fantasy.AgentStreamCall{
766 Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", userPrompt),
767 PrepareStep: func(callCtx context.Context, opts fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
768 prepared.Messages = opts.Messages
769 if a.systemPromptPrefix != "" {
770 prepared.Messages = append([]fantasy.Message{
771 fantasy.NewSystemMessage(a.systemPromptPrefix),
772 }, prepared.Messages...)
773 }
774 return callCtx, prepared, nil
775 },
776 }
777
778 // Use the small model to generate the title.
779 model := &a.smallModel
780 agent := newAgent(model.Model, titlePrompt, maxOutputTokens)
781 resp, err := agent.Stream(ctx, streamCall)
782 if err == nil {
783 // We successfully generated a title with the small model.
784 slog.Info("generated title with small model")
785 } else {
786 // It didn't work. Let's try with the big model.
787 slog.Error("error generating title with small model; trying big model", "err", err)
788 model = &a.largeModel
789 agent = newAgent(model.Model, titlePrompt, maxOutputTokens)
790 resp, err = agent.Stream(ctx, streamCall)
791 if err == nil {
792 slog.Info("generated title with large model")
793 } else {
794 // Welp, the large model didn't work either. Use the default
795 // session name and return.
796 slog.Error("error generating title with large model", "err", err)
797 saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, defaultSessionName, 0, 0, 0)
798 if saveErr != nil {
799 slog.Error("failed to save session title and usage", "error", saveErr)
800 }
801 return
802 }
803 }
804
805 if resp == nil {
806 // Actually, we didn't get a response so we can't. Use the default
807 // session name and return.
808 slog.Error("response is nil; can't generate title")
809 saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, defaultSessionName, 0, 0, 0)
810 if saveErr != nil {
811 slog.Error("failed to save session title and usage", "error", saveErr)
812 }
813 return
814 }
815
816 // Clean up title.
817 var title string
818 title = strings.ReplaceAll(resp.Response.Content.Text(), "\n", " ")
819 slog.Info("generated title", "title", title)
820
821 // Remove thinking tags if present.
822 title = thinkTagRegex.ReplaceAllString(title, "")
823
824 title = strings.TrimSpace(title)
825 if title == "" {
826 slog.Warn("empty title; using fallback")
827 title = defaultSessionName
828 }
829
830 // Calculate usage and cost.
831 var openrouterCost *float64
832 for _, step := range resp.Steps {
833 stepCost := a.openrouterCost(step.ProviderMetadata)
834 if stepCost != nil {
835 newCost := *stepCost
836 if openrouterCost != nil {
837 newCost += *openrouterCost
838 }
839 openrouterCost = &newCost
840 }
841 }
842
843 modelConfig := model.CatwalkCfg
844 cost := modelConfig.CostPer1MInCached/1e6*float64(resp.TotalUsage.CacheCreationTokens) +
845 modelConfig.CostPer1MOutCached/1e6*float64(resp.TotalUsage.CacheReadTokens) +
846 modelConfig.CostPer1MIn/1e6*float64(resp.TotalUsage.InputTokens) +
847 modelConfig.CostPer1MOut/1e6*float64(resp.TotalUsage.OutputTokens)
848
849 if a.isClaudeCode() {
850 cost = 0
851 }
852
853 // Use override cost if available (e.g., from OpenRouter).
854 if openrouterCost != nil {
855 cost = *openrouterCost
856 }
857
858 promptTokens := resp.TotalUsage.InputTokens + resp.TotalUsage.CacheCreationTokens
859 completionTokens := resp.TotalUsage.OutputTokens + resp.TotalUsage.CacheReadTokens
860
861 // Atomically update only title and usage fields to avoid overriding other
862 // concurrent session updates.
863 saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, title, promptTokens, completionTokens, cost)
864 if saveErr != nil {
865 slog.Error("failed to save session title and usage", "error", saveErr)
866 return
867 }
868}
869
870func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
871 openrouterMetadata, ok := metadata[openrouter.Name]
872 if !ok {
873 return nil
874 }
875
876 opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
877 if !ok {
878 return nil
879 }
880 return &opts.Usage.Cost
881}
882
883func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
884 modelConfig := model.CatwalkCfg
885 cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
886 modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
887 modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
888 modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
889
890 if a.isClaudeCode() {
891 cost = 0
892 }
893
894 a.eventTokensUsed(session.ID, model, usage, cost)
895
896 if overrideCost != nil {
897 session.Cost += *overrideCost
898 } else {
899 session.Cost += cost
900 }
901
902 session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
903 session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
904}
905
906func (a *sessionAgent) Cancel(sessionID string) {
907 // Cancel regular requests.
908 if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
909 slog.Info("Request cancellation initiated", "session_id", sessionID)
910 cancel()
911 }
912
913 // Also check for summarize requests.
914 if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
915 slog.Info("Summarize cancellation initiated", "session_id", sessionID)
916 cancel()
917 }
918
919 if a.QueuedPrompts(sessionID) > 0 {
920 slog.Info("Clearing queued prompts", "session_id", sessionID)
921 a.messageQueue.Del(sessionID)
922 }
923}
924
925func (a *sessionAgent) ClearQueue(sessionID string) {
926 if a.QueuedPrompts(sessionID) > 0 {
927 slog.Info("Clearing queued prompts", "session_id", sessionID)
928 a.messageQueue.Del(sessionID)
929 }
930}
931
932func (a *sessionAgent) CancelAll() {
933 if !a.IsBusy() {
934 return
935 }
936 for key := range a.activeRequests.Seq2() {
937 a.Cancel(key) // key is sessionID
938 }
939
940 timeout := time.After(5 * time.Second)
941 for a.IsBusy() {
942 select {
943 case <-timeout:
944 return
945 default:
946 time.Sleep(200 * time.Millisecond)
947 }
948 }
949}
950
951func (a *sessionAgent) IsBusy() bool {
952 var busy bool
953 for cancelFunc := range a.activeRequests.Seq() {
954 if cancelFunc != nil {
955 busy = true
956 break
957 }
958 }
959 return busy
960}
961
962func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
963 _, busy := a.activeRequests.Get(sessionID)
964 return busy
965}
966
967func (a *sessionAgent) QueuedPrompts(sessionID string) int {
968 l, ok := a.messageQueue.Get(sessionID)
969 if !ok {
970 return 0
971 }
972 return len(l)
973}
974
975func (a *sessionAgent) QueuedPromptsList(sessionID string) []string {
976 l, ok := a.messageQueue.Get(sessionID)
977 if !ok {
978 return nil
979 }
980 prompts := make([]string, len(l))
981 for i, call := range l {
982 prompts[i] = call.Prompt
983 }
984 return prompts
985}
986
987func (a *sessionAgent) SetModels(large Model, small Model) {
988 a.largeModel = large
989 a.smallModel = small
990}
991
992func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
993 a.tools = tools
994}
995
996func (a *sessionAgent) Model() Model {
997 return a.largeModel
998}
999
1000func (a *sessionAgent) promptPrefix() string {
1001 if a.isClaudeCode() {
1002 return "You are Claude Code, Anthropic's official CLI for Claude."
1003 }
1004 return a.systemPromptPrefix
1005}
1006
1007// XXX: this should be generalized to cover other subscription plans, like Copilot.
1008func (a *sessionAgent) isClaudeCode() bool {
1009 cfg := config.Get()
1010 pc, ok := cfg.Providers.Get(a.largeModel.ModelCfg.Provider)
1011 return ok && pc.ID == string(catwalk.InferenceProviderAnthropic) && pc.OAuthToken != nil
1012}
1013
1014// convertToToolResult converts a fantasy tool result to a message tool result.
1015func (a *sessionAgent) convertToToolResult(result fantasy.ToolResultContent) message.ToolResult {
1016 baseResult := message.ToolResult{
1017 ToolCallID: result.ToolCallID,
1018 Name: result.ToolName,
1019 Metadata: result.ClientMetadata,
1020 }
1021
1022 switch result.Result.GetType() {
1023 case fantasy.ToolResultContentTypeText:
1024 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result); ok {
1025 baseResult.Content = r.Text
1026 }
1027 case fantasy.ToolResultContentTypeError:
1028 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result); ok {
1029 baseResult.Content = r.Error.Error()
1030 baseResult.IsError = true
1031 }
1032 case fantasy.ToolResultContentTypeMedia:
1033 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](result.Result); ok {
1034 content := r.Text
1035 if content == "" {
1036 content = fmt.Sprintf("Loaded %s content", r.MediaType)
1037 }
1038 baseResult.Content = content
1039 baseResult.Data = r.Data
1040 baseResult.MIMEType = r.MediaType
1041 }
1042 }
1043
1044 return baseResult
1045}
1046
1047// workaroundProviderMediaLimitations converts media content in tool results to
1048// user messages for providers that don't natively support images in tool results.
1049//
1050// Problem: OpenAI, Google, OpenRouter, and other OpenAI-compatible providers
1051// don't support sending images/media in tool result messages - they only accept
1052// text in tool results. However, they DO support images in user messages.
1053//
1054// If we send media in tool results to these providers, the API returns an error.
1055//
1056// Solution: For these providers, we:
1057// 1. Replace the media in the tool result with a text placeholder
1058// 2. Inject a user message immediately after with the image as a file attachment
1059// 3. This maintains the tool execution flow while working around API limitations
1060//
1061// Anthropic and Bedrock support images natively in tool results, so we skip
1062// this workaround for them.
1063//
1064// Example transformation:
1065//
1066// BEFORE: [tool result: image data]
1067// AFTER: [tool result: "Image loaded - see attached"], [user: image attachment]
1068func (a *sessionAgent) workaroundProviderMediaLimitations(messages []fantasy.Message) []fantasy.Message {
1069 providerSupportsMedia := a.largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderAnthropic) ||
1070 a.largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderBedrock)
1071
1072 if providerSupportsMedia {
1073 return messages
1074 }
1075
1076 convertedMessages := make([]fantasy.Message, 0, len(messages))
1077
1078 for _, msg := range messages {
1079 if msg.Role != fantasy.MessageRoleTool {
1080 convertedMessages = append(convertedMessages, msg)
1081 continue
1082 }
1083
1084 textParts := make([]fantasy.MessagePart, 0, len(msg.Content))
1085 var mediaFiles []fantasy.FilePart
1086
1087 for _, part := range msg.Content {
1088 toolResult, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
1089 if !ok {
1090 textParts = append(textParts, part)
1091 continue
1092 }
1093
1094 if media, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](toolResult.Output); ok {
1095 decoded, err := base64.StdEncoding.DecodeString(media.Data)
1096 if err != nil {
1097 slog.Warn("failed to decode media data", "error", err)
1098 textParts = append(textParts, part)
1099 continue
1100 }
1101
1102 mediaFiles = append(mediaFiles, fantasy.FilePart{
1103 Data: decoded,
1104 MediaType: media.MediaType,
1105 Filename: fmt.Sprintf("tool-result-%s", toolResult.ToolCallID),
1106 })
1107
1108 textParts = append(textParts, fantasy.ToolResultPart{
1109 ToolCallID: toolResult.ToolCallID,
1110 Output: fantasy.ToolResultOutputContentText{
1111 Text: "[Image/media content loaded - see attached file]",
1112 },
1113 ProviderOptions: toolResult.ProviderOptions,
1114 })
1115 } else {
1116 textParts = append(textParts, part)
1117 }
1118 }
1119
1120 convertedMessages = append(convertedMessages, fantasy.Message{
1121 Role: fantasy.MessageRoleTool,
1122 Content: textParts,
1123 })
1124
1125 if len(mediaFiles) > 0 {
1126 convertedMessages = append(convertedMessages, fantasy.NewUserMessage(
1127 "Here is the media content from the tool result:",
1128 mediaFiles...,
1129 ))
1130 }
1131 }
1132
1133 return convertedMessages
1134}