1// Package agent is the core orchestration layer for Crush AI agents.
2//
3// It provides session-based AI agent functionality for managing
4// conversations, tool execution, and message handling. It coordinates
5// interactions between language models, messages, sessions, and tools while
6// handling features like automatic summarization, queuing, and token
7// management.
8package agent
9
10import (
11 "cmp"
12 "context"
13 _ "embed"
14 "encoding/base64"
15 "errors"
16 "fmt"
17 "log/slog"
18 "os"
19 "regexp"
20 "strconv"
21 "strings"
22 "sync"
23 "time"
24
25 "charm.land/fantasy"
26 "charm.land/fantasy/providers/anthropic"
27 "charm.land/fantasy/providers/bedrock"
28 "charm.land/fantasy/providers/google"
29 "charm.land/fantasy/providers/openai"
30 "charm.land/fantasy/providers/openrouter"
31 "charm.land/lipgloss/v2"
32 "git.secluded.site/crush/internal/agent/hyper"
33 "git.secluded.site/crush/internal/agent/tools"
34 "git.secluded.site/crush/internal/config"
35 "git.secluded.site/crush/internal/csync"
36 "git.secluded.site/crush/internal/message"
37 "git.secluded.site/crush/internal/notification"
38 "git.secluded.site/crush/internal/permission"
39 "git.secluded.site/crush/internal/session"
40 "git.secluded.site/crush/internal/stringext"
41 "github.com/charmbracelet/catwalk/pkg/catwalk"
42)
43
// defaultSessionName is the title stored for a session when title generation
// fails, returns no response, or yields an empty title.
const defaultSessionName = "Untitled Session"

// titlePrompt is the system prompt used when generating session titles,
// embedded from templates/title.md at build time.
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt is the system prompt used by Summarize, embedded from
// templates/summary.md at build time.
//
//go:embed templates/summary.md
var summaryPrompt []byte
51
// thinkTagRegex matches a complete <think>...</think> block, including its
// contents, so it can be removed from generated titles.
//
// The (?s) flag lets "." match newlines, so multi-line reasoning blocks are
// stripped even when the caller has not flattened the text first. The lazy
// ".*?" makes each block match on its own, so text between two separate
// blocks is kept.
var thinkTagRegex = regexp.MustCompile(`(?s)<think>.*?</think>`)
54
// SessionAgentCall describes a single prompt to run against a session.
type SessionAgentCall struct {
	// SessionID identifies the target session. sessionAgent.Run rejects an
	// empty value with ErrSessionMissing.
	SessionID string
	// Prompt is the user's text. sessionAgent.Run rejects an empty value with
	// ErrEmptyPrompt.
	Prompt string
	// ProviderOptions are forwarded to the model stream call, and to
	// Summarize when auto-summarization is triggered.
	ProviderOptions fantasy.ProviderOptions
	// Attachments are stored on the user message. Non-text attachments are
	// also sent to the model as files. Text attachments are passed to
	// message.PromptWithTextAttachments (presumably inlined into the prompt).
	Attachments []message.Attachment
	// MaxOutputTokens is always passed to the stream call by pointer, so a
	// zero value is forwarded as 0.
	// NOTE(review): confirm whether the provider treats 0 as "unset".
	MaxOutputTokens int64
	// The sampling parameters below are forwarded unchanged; nil leaves them
	// unset.
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
	// NonInteractive suppresses the notification sent when a turn completes.
	NonInteractive bool
}
68
69type SessionAgent interface {
70 Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
71 SetModels(large Model, small Model)
72 SetTools(tools []fantasy.AgentTool)
73 Cancel(sessionID string)
74 CancelAll()
75 IsSessionBusy(sessionID string) bool
76 IsBusy() bool
77 QueuedPrompts(sessionID string) int
78 QueuedPromptsList(sessionID string) []string
79 ClearQueue(sessionID string)
80 Summarize(context.Context, string, fantasy.ProviderOptions) error
81 Model() Model
82}
83
// Model bundles a language model client with its catalog metadata and the
// user's configured selection for it.
type Model struct {
	// Model is the client used to stream completions.
	Model fantasy.LanguageModel
	// CatwalkCfg is catalog metadata. This file reads its context window,
	// per-token pricing, CanReason, DefaultMaxTokens, SupportsImages and Name.
	CatwalkCfg catwalk.Model
	// ModelCfg is the configured selection. Run records its Model and
	// Provider fields on assistant messages.
	ModelCfg config.SelectedModel
}
89
// sessionAgent is the SessionAgent implementation returned by NewSessionAgent.
type sessionAgent struct {
	// largeModel drives conversations and summaries. It is also the fallback
	// for title generation.
	largeModel Model
	// smallModel is tried first for title generation.
	smallModel Model
	// systemPromptPrefix, when non-empty, is prepended to each step's messages
	// as an extra system message.
	systemPromptPrefix string
	// systemPrompt is the main system prompt for conversation turns.
	systemPrompt string
	// isSubAgent suppresses the todo-list system reminder in preparePrompt.
	isSubAgent bool
	// tools are offered to the model. Run overwrites the provider options of
	// the last tool to add cache control.
	tools []fantasy.AgentTool
	sessions session.Service
	messages message.Service
	// disableAutoSummarize stops Run from summarizing when the context window
	// is nearly full.
	disableAutoSummarize bool
	// isYolo is stored but not read in the visible part of this file.
	// NOTE(review): confirm where it is used.
	isYolo bool

	// messageQueue holds prompts that arrived while their session was busy,
	// keyed by session ID.
	messageQueue *csync.Map[string, []SessionAgentCall]
	// activeRequests maps a session ID to the cancel func of its in-flight
	// request. The presence of an entry marks the session busy.
	activeRequests *csync.Map[string, context.CancelFunc]
}
105
// SessionAgentOptions configures NewSessionAgent. Each field is copied into
// the corresponding field of the returned agent.
type SessionAgentOptions struct {
	// LargeModel drives conversations and summaries.
	LargeModel Model
	// SmallModel is tried first for title generation.
	SmallModel Model
	// SystemPromptPrefix, when non-empty, is sent as an extra leading system
	// message.
	SystemPromptPrefix string
	// SystemPrompt is the main system prompt.
	SystemPrompt string
	// IsSubAgent disables the todo-list reminder injected into history.
	IsSubAgent bool
	// DisableAutoSummarize turns off automatic summarization near the
	// context-window limit.
	DisableAutoSummarize bool
	IsYolo               bool
	Sessions             session.Service
	Messages             message.Service
	// Tools are offered to the model. The slice is stored without copying.
	Tools []fantasy.AgentTool
}
118
119func NewSessionAgent(
120 opts SessionAgentOptions,
121) SessionAgent {
122 return &sessionAgent{
123 largeModel: opts.LargeModel,
124 smallModel: opts.SmallModel,
125 systemPromptPrefix: opts.SystemPromptPrefix,
126 systemPrompt: opts.SystemPrompt,
127 isSubAgent: opts.IsSubAgent,
128 sessions: opts.Sessions,
129 messages: opts.Messages,
130 disableAutoSummarize: opts.DisableAutoSummarize,
131 tools: opts.Tools,
132 isYolo: opts.IsYolo,
133 messageQueue: csync.NewMap[string, []SessionAgentCall](),
134 activeRequests: csync.NewMap[string, context.CancelFunc](),
135 }
136}
137
138func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
139 if call.Prompt == "" {
140 return nil, ErrEmptyPrompt
141 }
142 if call.SessionID == "" {
143 return nil, ErrSessionMissing
144 }
145
146 // Queue the message if busy
147 if a.IsSessionBusy(call.SessionID) {
148 existing, ok := a.messageQueue.Get(call.SessionID)
149 if !ok {
150 existing = []SessionAgentCall{}
151 }
152 existing = append(existing, call)
153 a.messageQueue.Set(call.SessionID, existing)
154 return nil, nil
155 }
156
157 if len(a.tools) > 0 {
158 // Add Anthropic caching to the last tool.
159 a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
160 }
161
162 agent := fantasy.NewAgent(
163 a.largeModel.Model,
164 fantasy.WithSystemPrompt(a.systemPrompt),
165 fantasy.WithTools(a.tools...),
166 )
167
168 sessionLock := sync.Mutex{}
169 currentSession, err := a.sessions.Get(ctx, call.SessionID)
170 if err != nil {
171 return nil, fmt.Errorf("failed to get session: %w", err)
172 }
173
174 msgs, err := a.getSessionMessages(ctx, currentSession)
175 if err != nil {
176 return nil, fmt.Errorf("failed to get session messages: %w", err)
177 }
178
179 var wg sync.WaitGroup
180 // Generate title if first message.
181 if len(msgs) == 0 {
182 titleCtx := ctx // Copy to avoid race with ctx reassignment below.
183 wg.Go(func() {
184 a.generateTitle(titleCtx, call.SessionID, call.Prompt)
185 })
186 }
187
188 // Add the user message to the session.
189 _, err = a.createUserMessage(ctx, call)
190 if err != nil {
191 return nil, err
192 }
193
194 // Add the session to the context.
195 ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)
196
197 genCtx, cancel := context.WithCancel(ctx)
198 a.activeRequests.Set(call.SessionID, cancel)
199
200 defer cancel()
201 defer a.activeRequests.Del(call.SessionID)
202
203 history, files := a.preparePrompt(msgs, call.Attachments...)
204
205 startTime := time.Now()
206 a.eventPromptSent(call.SessionID)
207
208 var currentAssistant *message.Message
209 var shouldSummarize bool
210 result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
211 Prompt: message.PromptWithTextAttachments(call.Prompt, call.Attachments),
212 Files: files,
213 Messages: history,
214 ProviderOptions: call.ProviderOptions,
215 MaxOutputTokens: &call.MaxOutputTokens,
216 TopP: call.TopP,
217 Temperature: call.Temperature,
218 PresencePenalty: call.PresencePenalty,
219 TopK: call.TopK,
220 FrequencyPenalty: call.FrequencyPenalty,
221 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
222 prepared.Messages = options.Messages
223 for i := range prepared.Messages {
224 prepared.Messages[i].ProviderOptions = nil
225 }
226
227 queuedCalls, _ := a.messageQueue.Get(call.SessionID)
228 a.messageQueue.Del(call.SessionID)
229 for _, queued := range queuedCalls {
230 userMessage, createErr := a.createUserMessage(callContext, queued)
231 if createErr != nil {
232 return callContext, prepared, createErr
233 }
234 prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
235 }
236
237 prepared.Messages = a.workaroundProviderMediaLimitations(prepared.Messages)
238
239 lastSystemRoleInx := 0
240 systemMessageUpdated := false
241 for i, msg := range prepared.Messages {
242 // Only add cache control to the last message.
243 if msg.Role == fantasy.MessageRoleSystem {
244 lastSystemRoleInx = i
245 } else if !systemMessageUpdated {
246 prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
247 systemMessageUpdated = true
248 }
249 // Than add cache control to the last 2 messages.
250 if i > len(prepared.Messages)-3 {
251 prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
252 }
253 }
254
255 if promptPrefix := a.promptPrefix(); promptPrefix != "" {
256 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(promptPrefix)}, prepared.Messages...)
257 }
258
259 var assistantMsg message.Message
260 assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
261 Role: message.Assistant,
262 Parts: []message.ContentPart{},
263 Model: a.largeModel.ModelCfg.Model,
264 Provider: a.largeModel.ModelCfg.Provider,
265 })
266 if err != nil {
267 return callContext, prepared, err
268 }
269 callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
270 callContext = context.WithValue(callContext, tools.SupportsImagesContextKey, a.largeModel.CatwalkCfg.SupportsImages)
271 callContext = context.WithValue(callContext, tools.ModelNameContextKey, a.largeModel.CatwalkCfg.Name)
272 currentAssistant = &assistantMsg
273 return callContext, prepared, err
274 },
275 OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
276 currentAssistant.AppendReasoningContent(reasoning.Text)
277 return a.messages.Update(genCtx, *currentAssistant)
278 },
279 OnReasoningDelta: func(id string, text string) error {
280 currentAssistant.AppendReasoningContent(text)
281 return a.messages.Update(genCtx, *currentAssistant)
282 },
283 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
284 // handle anthropic signature
285 if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
286 if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
287 currentAssistant.AppendReasoningSignature(reasoning.Signature)
288 }
289 }
290 if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
291 if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
292 currentAssistant.AppendThoughtSignature(reasoning.Signature, reasoning.ToolID)
293 }
294 }
295 if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
296 if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
297 currentAssistant.SetReasoningResponsesData(reasoning)
298 }
299 }
300 currentAssistant.FinishThinking()
301 return a.messages.Update(genCtx, *currentAssistant)
302 },
303 OnTextDelta: func(id string, text string) error {
304 // Strip leading newline from initial text content. This is is
305 // particularly important in non-interactive mode where leading
306 // newlines are very visible.
307 if len(currentAssistant.Parts) == 0 {
308 text = strings.TrimPrefix(text, "\n")
309 }
310
311 currentAssistant.AppendContent(text)
312 return a.messages.Update(genCtx, *currentAssistant)
313 },
314 OnToolInputStart: func(id string, toolName string) error {
315 toolCall := message.ToolCall{
316 ID: id,
317 Name: toolName,
318 ProviderExecuted: false,
319 Finished: false,
320 }
321 currentAssistant.AddToolCall(toolCall)
322 return a.messages.Update(genCtx, *currentAssistant)
323 },
324 OnRetry: func(err *fantasy.ProviderError, delay time.Duration) {
325 // TODO: implement
326 },
327 OnToolCall: func(tc fantasy.ToolCallContent) error {
328 toolCall := message.ToolCall{
329 ID: tc.ToolCallID,
330 Name: tc.ToolName,
331 Input: tc.Input,
332 ProviderExecuted: false,
333 Finished: true,
334 }
335 currentAssistant.AddToolCall(toolCall)
336 return a.messages.Update(genCtx, *currentAssistant)
337 },
338 OnToolResult: func(result fantasy.ToolResultContent) error {
339 toolResult := a.convertToToolResult(result)
340 _, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
341 Role: message.Tool,
342 Parts: []message.ContentPart{
343 toolResult,
344 },
345 })
346 return createMsgErr
347 },
348 OnStepFinish: func(stepResult fantasy.StepResult) error {
349 finishReason := message.FinishReasonUnknown
350 switch stepResult.FinishReason {
351 case fantasy.FinishReasonLength:
352 finishReason = message.FinishReasonMaxTokens
353 case fantasy.FinishReasonStop:
354 finishReason = message.FinishReasonEndTurn
355 case fantasy.FinishReasonToolCalls:
356 finishReason = message.FinishReasonToolUse
357 }
358 currentAssistant.AddFinish(finishReason, "", "")
359 sessionLock.Lock()
360 updatedSession, getSessionErr := a.sessions.Get(genCtx, call.SessionID)
361 if getSessionErr != nil {
362 sessionLock.Unlock()
363 return getSessionErr
364 }
365 a.updateSessionUsage(a.largeModel, &updatedSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
366 _, sessionErr := a.sessions.Save(genCtx, updatedSession)
367 sessionLock.Unlock()
368 if sessionErr != nil {
369 return sessionErr
370 }
371 return a.messages.Update(genCtx, *currentAssistant)
372 },
373 StopWhen: []fantasy.StopCondition{
374 func(steps []fantasy.StepResult) bool {
375 if len(steps) == 0 {
376 return false
377 }
378 lastStep := steps[len(steps)-1]
379 usage := lastStep.Usage
380 tokens := usage.InputTokens + usage.OutputTokens + usage.CacheCreationTokens + usage.CacheReadTokens
381 cw := int64(a.largeModel.CatwalkCfg.ContextWindow)
382 remaining := cw - tokens
383 var threshold int64
384 if cw > 200_000 {
385 threshold = 20_000
386 } else {
387 threshold = int64(float64(cw) * 0.2)
388 }
389 if (remaining <= threshold) && !a.disableAutoSummarize {
390 shouldSummarize = true
391 return true
392 }
393 return false
394 },
395 },
396 })
397
398 a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))
399
400 if err != nil {
401 isCancelErr := errors.Is(err, context.Canceled)
402 isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
403 if currentAssistant == nil {
404 return result, err
405 }
406 // Ensure we finish thinking on error to close the reasoning state.
407 currentAssistant.FinishThinking()
408 toolCalls := currentAssistant.ToolCalls()
409 // INFO: we use the parent context here because the genCtx has been cancelled.
410 msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
411 if createErr != nil {
412 return nil, createErr
413 }
414 for _, tc := range toolCalls {
415 if !tc.Finished {
416 tc.Finished = true
417 tc.Input = "{}"
418 currentAssistant.AddToolCall(tc)
419 updateErr := a.messages.Update(ctx, *currentAssistant)
420 if updateErr != nil {
421 return nil, updateErr
422 }
423 }
424
425 found := false
426 for _, msg := range msgs {
427 if msg.Role == message.Tool {
428 for _, tr := range msg.ToolResults() {
429 if tr.ToolCallID == tc.ID {
430 found = true
431 break
432 }
433 }
434 }
435 if found {
436 break
437 }
438 }
439 if found {
440 continue
441 }
442 content := "There was an error while executing the tool"
443 if isCancelErr {
444 content = "Tool execution canceled by user"
445 } else if isPermissionErr {
446 content = "User denied permission"
447 }
448 toolResult := message.ToolResult{
449 ToolCallID: tc.ID,
450 Name: tc.Name,
451 Content: content,
452 IsError: true,
453 }
454 _, createErr = a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
455 Role: message.Tool,
456 Parts: []message.ContentPart{
457 toolResult,
458 },
459 })
460 if createErr != nil {
461 return nil, createErr
462 }
463 }
464 var fantasyErr *fantasy.Error
465 var providerErr *fantasy.ProviderError
466 const defaultTitle = "Provider Error"
467 if isCancelErr {
468 currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
469 } else if isPermissionErr {
470 currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "")
471 } else if errors.Is(err, hyper.ErrNoCredits) {
472 url := hyper.BaseURL()
473 link := lipgloss.NewStyle().Hyperlink(url, "id=hyper").Render(url)
474 currentAssistant.AddFinish(message.FinishReasonError, "No credits", "You're out of credits. Add more at "+link)
475 } else if errors.As(err, &providerErr) {
476 if providerErr.Message == "The requested model is not supported." {
477 url := "https://github.com/settings/copilot/features"
478 link := lipgloss.NewStyle().Hyperlink(url, "id=hyper").Render(url)
479 currentAssistant.AddFinish(
480 message.FinishReasonError,
481 "Copilot model not enabled",
482 fmt.Sprintf("%q is not enabled in Copilot. Go to the following page to enable it. Then, wait 5 minutes before trying again. %s", a.largeModel.CatwalkCfg.Name, link),
483 )
484 } else {
485 currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message)
486 }
487 } else if errors.As(err, &fantasyErr) {
488 currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(fantasyErr.Title), defaultTitle), fantasyErr.Message)
489 } else {
490 currentAssistant.AddFinish(message.FinishReasonError, defaultTitle, err.Error())
491 }
492 // Note: we use the parent context here because the genCtx has been
493 // cancelled.
494 updateErr := a.messages.Update(ctx, *currentAssistant)
495 if updateErr != nil {
496 return nil, updateErr
497 }
498 return nil, err
499 }
500 wg.Wait()
501
502 // Send notification that agent has finished its turn (skip for nested/non-interactive sessions).
503 if !call.NonInteractive {
504 notifBody := fmt.Sprintf("Agent's turn completed in \"%s\"", currentSession.Title)
505 _ = notification.Send("Crush is waiting...", notifBody)
506 }
507
508 if shouldSummarize {
509 a.activeRequests.Del(call.SessionID)
510 if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
511 return nil, summarizeErr
512 }
513 // If the agent wasn't done...
514 if len(currentAssistant.ToolCalls()) > 0 {
515 existing, ok := a.messageQueue.Get(call.SessionID)
516 if !ok {
517 existing = []SessionAgentCall{}
518 }
519 call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
520 existing = append(existing, call)
521 a.messageQueue.Set(call.SessionID, existing)
522 }
523 }
524
525 // Release active request before processing queued messages.
526 a.activeRequests.Del(call.SessionID)
527 cancel()
528
529 queuedMessages, ok := a.messageQueue.Get(call.SessionID)
530 if !ok || len(queuedMessages) == 0 {
531 return result, err
532 }
533 // There are queued messages restart the loop.
534 firstQueuedMessage := queuedMessages[0]
535 a.messageQueue.Set(call.SessionID, queuedMessages[1:])
536 return a.Run(ctx, firstQueuedMessage)
537}
538
539func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
540 if a.IsSessionBusy(sessionID) {
541 return ErrSessionBusy
542 }
543
544 currentSession, err := a.sessions.Get(ctx, sessionID)
545 if err != nil {
546 return fmt.Errorf("failed to get session: %w", err)
547 }
548 msgs, err := a.getSessionMessages(ctx, currentSession)
549 if err != nil {
550 return err
551 }
552 if len(msgs) == 0 {
553 // Nothing to summarize.
554 return nil
555 }
556
557 aiMsgs, _ := a.preparePrompt(msgs)
558
559 genCtx, cancel := context.WithCancel(ctx)
560 a.activeRequests.Set(sessionID, cancel)
561 defer a.activeRequests.Del(sessionID)
562 defer cancel()
563
564 agent := fantasy.NewAgent(a.largeModel.Model,
565 fantasy.WithSystemPrompt(string(summaryPrompt)),
566 )
567 summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
568 Role: message.Assistant,
569 Model: a.largeModel.Model.Model(),
570 Provider: a.largeModel.Model.Provider(),
571 IsSummaryMessage: true,
572 })
573 if err != nil {
574 return err
575 }
576
577 summaryPromptText := "Provide a detailed summary of our conversation above."
578 if len(currentSession.Todos) > 0 {
579 summaryPromptText += "\n\n## Current Todo List\n\n"
580 for _, t := range currentSession.Todos {
581 summaryPromptText += fmt.Sprintf("- [%s] %s\n", t.Status, t.Content)
582 }
583 summaryPromptText += "\nInclude these tasks and their statuses in your summary. "
584 summaryPromptText += "Instruct the resuming assistant to use the `todos` tool to continue tracking progress on these tasks."
585 }
586
587 resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
588 Prompt: summaryPromptText,
589 Messages: aiMsgs,
590 ProviderOptions: opts,
591 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
592 prepared.Messages = options.Messages
593 if a.systemPromptPrefix != "" {
594 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
595 }
596 return callContext, prepared, nil
597 },
598 OnReasoningDelta: func(id string, text string) error {
599 summaryMessage.AppendReasoningContent(text)
600 return a.messages.Update(genCtx, summaryMessage)
601 },
602 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
603 // Handle anthropic signature.
604 if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
605 if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
606 summaryMessage.AppendReasoningSignature(signature.Signature)
607 }
608 }
609 summaryMessage.FinishThinking()
610 return a.messages.Update(genCtx, summaryMessage)
611 },
612 OnTextDelta: func(id, text string) error {
613 summaryMessage.AppendContent(text)
614 return a.messages.Update(genCtx, summaryMessage)
615 },
616 })
617 if err != nil {
618 isCancelErr := errors.Is(err, context.Canceled)
619 if isCancelErr {
620 // User cancelled summarize we need to remove the summary message.
621 deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
622 return deleteErr
623 }
624 return err
625 }
626
627 summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
628 err = a.messages.Update(genCtx, summaryMessage)
629 if err != nil {
630 return err
631 }
632
633 var openrouterCost *float64
634 for _, step := range resp.Steps {
635 stepCost := a.openrouterCost(step.ProviderMetadata)
636 if stepCost != nil {
637 newCost := *stepCost
638 if openrouterCost != nil {
639 newCost += *openrouterCost
640 }
641 openrouterCost = &newCost
642 }
643 }
644
645 a.updateSessionUsage(a.largeModel, ¤tSession, resp.TotalUsage, openrouterCost)
646
647 // Just in case, get just the last usage info.
648 usage := resp.Response.Usage
649 currentSession.SummaryMessageID = summaryMessage.ID
650 currentSession.CompletionTokens = usage.OutputTokens
651 currentSession.PromptTokens = 0
652 _, err = a.sessions.Save(genCtx, currentSession)
653 return err
654}
655
656func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
657 if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
658 return fantasy.ProviderOptions{}
659 }
660 return fantasy.ProviderOptions{
661 anthropic.Name: &anthropic.ProviderCacheControlOptions{
662 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
663 },
664 bedrock.Name: &anthropic.ProviderCacheControlOptions{
665 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
666 },
667 }
668}
669
670func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
671 parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
672 var attachmentParts []message.ContentPart
673 for _, attachment := range call.Attachments {
674 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
675 }
676 parts = append(parts, attachmentParts...)
677 msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
678 Role: message.User,
679 Parts: parts,
680 })
681 if err != nil {
682 return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
683 }
684 return msg, nil
685}
686
687func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
688 var history []fantasy.Message
689 if !a.isSubAgent {
690 history = append(history, fantasy.NewUserMessage(
691 fmt.Sprintf("<system_reminder>%s</system_reminder>",
692 `This is a reminder that your todo list is currently empty. DO NOT mention this to the user explicitly because they are already aware.
693If you are working on tasks that would benefit from a todo list please use the "todos" tool to create one.
694If not, please feel free to ignore. Again do not mention this message to the user.`,
695 ),
696 ))
697 }
698 for _, m := range msgs {
699 if len(m.Parts) == 0 {
700 continue
701 }
702 // Assistant message without content or tool calls (cancelled before it
703 // returned anything).
704 if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
705 continue
706 }
707 history = append(history, m.ToAIMessage()...)
708 }
709
710 var files []fantasy.FilePart
711 for _, attachment := range attachments {
712 if attachment.IsText() {
713 continue
714 }
715 files = append(files, fantasy.FilePart{
716 Filename: attachment.FileName,
717 Data: attachment.Content,
718 MediaType: attachment.MimeType,
719 })
720 }
721
722 return history, files
723}
724
725func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
726 msgs, err := a.messages.List(ctx, session.ID)
727 if err != nil {
728 return nil, fmt.Errorf("failed to list messages: %w", err)
729 }
730
731 if session.SummaryMessageID != "" {
732 summaryMsgInex := -1
733 for i, msg := range msgs {
734 if msg.ID == session.SummaryMessageID {
735 summaryMsgInex = i
736 break
737 }
738 }
739 if summaryMsgInex != -1 {
740 msgs = msgs[summaryMsgInex:]
741 msgs[0].Role = message.User
742 }
743 }
744 return msgs, nil
745}
746
747// generateTitle generates a session titled based on the initial prompt.
748func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, userPrompt string) {
749 if userPrompt == "" {
750 return
751 }
752
753 var maxOutputTokens int64 = 40
754 if a.smallModel.CatwalkCfg.CanReason {
755 maxOutputTokens = a.smallModel.CatwalkCfg.DefaultMaxTokens
756 }
757
758 newAgent := func(m fantasy.LanguageModel, p []byte, tok int64) fantasy.Agent {
759 return fantasy.NewAgent(m,
760 fantasy.WithSystemPrompt(string(p)+"\n /no_think"),
761 fantasy.WithMaxOutputTokens(tok),
762 )
763 }
764
765 streamCall := fantasy.AgentStreamCall{
766 Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", userPrompt),
767 PrepareStep: func(callCtx context.Context, opts fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
768 prepared.Messages = opts.Messages
769 if a.systemPromptPrefix != "" {
770 prepared.Messages = append([]fantasy.Message{
771 fantasy.NewSystemMessage(a.systemPromptPrefix),
772 }, prepared.Messages...)
773 }
774 return callCtx, prepared, nil
775 },
776 }
777
778 // Use the small model to generate the title.
779 model := &a.smallModel
780 agent := newAgent(model.Model, titlePrompt, maxOutputTokens)
781 resp, err := agent.Stream(ctx, streamCall)
782 if err == nil {
783 // We successfully generated a title with the small model.
784 slog.Info("generated title with small model")
785 } else {
786 // It didn't work. Let's try with the big model.
787 slog.Error("error generating title with small model; trying big model", "err", err)
788 model = &a.largeModel
789 agent = newAgent(model.Model, titlePrompt, maxOutputTokens)
790 resp, err = agent.Stream(ctx, streamCall)
791 if err == nil {
792 slog.Info("generated title with large model")
793 } else {
794 // Welp, the large model didn't work either. Use the default
795 // session name and return.
796 slog.Error("error generating title with large model", "err", err)
797 saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, defaultSessionName, 0, 0, 0)
798 if saveErr != nil {
799 slog.Error("failed to save session title and usage", "error", saveErr)
800 }
801 return
802 }
803 }
804
805 if resp == nil {
806 // Actually, we didn't get a response so we can't. Use the default
807 // session name and return.
808 slog.Error("response is nil; can't generate title")
809 saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, defaultSessionName, 0, 0, 0)
810 if saveErr != nil {
811 slog.Error("failed to save session title and usage", "error", saveErr)
812 }
813 return
814 }
815
816 // Clean up title.
817 var title string
818 title = strings.ReplaceAll(resp.Response.Content.Text(), "\n", " ")
819 slog.Info("generated title", "title", title)
820
821 // Remove thinking tags if present.
822 title = thinkTagRegex.ReplaceAllString(title, "")
823
824 title = strings.TrimSpace(title)
825 if title == "" {
826 slog.Warn("empty title; using fallback")
827 title = defaultSessionName
828 }
829
830 // Calculate usage and cost.
831 var openrouterCost *float64
832 for _, step := range resp.Steps {
833 stepCost := a.openrouterCost(step.ProviderMetadata)
834 if stepCost != nil {
835 newCost := *stepCost
836 if openrouterCost != nil {
837 newCost += *openrouterCost
838 }
839 openrouterCost = &newCost
840 }
841 }
842
843 modelConfig := model.CatwalkCfg
844 cost := modelConfig.CostPer1MInCached/1e6*float64(resp.TotalUsage.CacheCreationTokens) +
845 modelConfig.CostPer1MOutCached/1e6*float64(resp.TotalUsage.CacheReadTokens) +
846 modelConfig.CostPer1MIn/1e6*float64(resp.TotalUsage.InputTokens) +
847 modelConfig.CostPer1MOut/1e6*float64(resp.TotalUsage.OutputTokens)
848
849 // Use override cost if available (e.g., from OpenRouter).
850 if openrouterCost != nil {
851 cost = *openrouterCost
852 }
853
854 promptTokens := resp.TotalUsage.InputTokens + resp.TotalUsage.CacheCreationTokens
855 completionTokens := resp.TotalUsage.OutputTokens + resp.TotalUsage.CacheReadTokens
856
857 // Atomically update only title and usage fields to avoid overriding other
858 // concurrent session updates.
859 saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, title, promptTokens, completionTokens, cost)
860 if saveErr != nil {
861 slog.Error("failed to save session title and usage", "error", saveErr)
862 return
863 }
864}
865
866func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
867 openrouterMetadata, ok := metadata[openrouter.Name]
868 if !ok {
869 return nil
870 }
871
872 opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
873 if !ok {
874 return nil
875 }
876 return &opts.Usage.Cost
877}
878
879func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
880 modelConfig := model.CatwalkCfg
881 cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
882 modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
883 modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
884 modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
885
886 a.eventTokensUsed(session.ID, model, usage, cost)
887
888 if overrideCost != nil {
889 session.Cost += *overrideCost
890 } else {
891 session.Cost += cost
892 }
893
894 session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
895 session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
896}
897
898func (a *sessionAgent) Cancel(sessionID string) {
899 // Cancel regular requests.
900 if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
901 slog.Info("Request cancellation initiated", "session_id", sessionID)
902 cancel()
903 }
904
905 // Also check for summarize requests.
906 if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
907 slog.Info("Summarize cancellation initiated", "session_id", sessionID)
908 cancel()
909 }
910
911 if a.QueuedPrompts(sessionID) > 0 {
912 slog.Info("Clearing queued prompts", "session_id", sessionID)
913 a.messageQueue.Del(sessionID)
914 }
915}
916
917func (a *sessionAgent) ClearQueue(sessionID string) {
918 if a.QueuedPrompts(sessionID) > 0 {
919 slog.Info("Clearing queued prompts", "session_id", sessionID)
920 a.messageQueue.Del(sessionID)
921 }
922}
923
924func (a *sessionAgent) CancelAll() {
925 if !a.IsBusy() {
926 return
927 }
928 for key := range a.activeRequests.Seq2() {
929 a.Cancel(key) // key is sessionID
930 }
931
932 timeout := time.After(5 * time.Second)
933 for a.IsBusy() {
934 select {
935 case <-timeout:
936 return
937 default:
938 time.Sleep(200 * time.Millisecond)
939 }
940 }
941}
942
943func (a *sessionAgent) IsBusy() bool {
944 var busy bool
945 for cancelFunc := range a.activeRequests.Seq() {
946 if cancelFunc != nil {
947 busy = true
948 break
949 }
950 }
951 return busy
952}
953
954func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
955 _, busy := a.activeRequests.Get(sessionID)
956 return busy
957}
958
959func (a *sessionAgent) QueuedPrompts(sessionID string) int {
960 l, ok := a.messageQueue.Get(sessionID)
961 if !ok {
962 return 0
963 }
964 return len(l)
965}
966
967func (a *sessionAgent) QueuedPromptsList(sessionID string) []string {
968 l, ok := a.messageQueue.Get(sessionID)
969 if !ok {
970 return nil
971 }
972 prompts := make([]string, len(l))
973 for i, call := range l {
974 prompts[i] = call.Prompt
975 }
976 return prompts
977}
978
979func (a *sessionAgent) SetModels(large Model, small Model) {
980 a.largeModel = large
981 a.smallModel = small
982}
983
984func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
985 a.tools = tools
986}
987
988func (a *sessionAgent) Model() Model {
989 return a.largeModel
990}
991
992func (a *sessionAgent) promptPrefix() string {
993 return a.systemPromptPrefix
994}
995
996// convertToToolResult converts a fantasy tool result to a message tool result.
997func (a *sessionAgent) convertToToolResult(result fantasy.ToolResultContent) message.ToolResult {
998 baseResult := message.ToolResult{
999 ToolCallID: result.ToolCallID,
1000 Name: result.ToolName,
1001 Metadata: result.ClientMetadata,
1002 }
1003
1004 switch result.Result.GetType() {
1005 case fantasy.ToolResultContentTypeText:
1006 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result); ok {
1007 baseResult.Content = r.Text
1008 }
1009 case fantasy.ToolResultContentTypeError:
1010 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result); ok {
1011 baseResult.Content = r.Error.Error()
1012 baseResult.IsError = true
1013 }
1014 case fantasy.ToolResultContentTypeMedia:
1015 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](result.Result); ok {
1016 content := r.Text
1017 if content == "" {
1018 content = fmt.Sprintf("Loaded %s content", r.MediaType)
1019 }
1020 baseResult.Content = content
1021 baseResult.Data = r.Data
1022 baseResult.MIMEType = r.MediaType
1023 }
1024 }
1025
1026 return baseResult
1027}
1028
1029// workaroundProviderMediaLimitations converts media content in tool results to
1030// user messages for providers that don't natively support images in tool results.
1031//
1032// Problem: OpenAI, Google, OpenRouter, and other OpenAI-compatible providers
1033// don't support sending images/media in tool result messages - they only accept
1034// text in tool results. However, they DO support images in user messages.
1035//
1036// If we send media in tool results to these providers, the API returns an error.
1037//
1038// Solution: For these providers, we:
1039// 1. Replace the media in the tool result with a text placeholder
1040// 2. Inject a user message immediately after with the image as a file attachment
1041// 3. This maintains the tool execution flow while working around API limitations
1042//
1043// Anthropic and Bedrock support images natively in tool results, so we skip
1044// this workaround for them.
1045//
1046// Example transformation:
1047//
1048// BEFORE: [tool result: image data]
1049// AFTER: [tool result: "Image loaded - see attached"], [user: image attachment]
1050func (a *sessionAgent) workaroundProviderMediaLimitations(messages []fantasy.Message) []fantasy.Message {
1051 providerSupportsMedia := a.largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderAnthropic) ||
1052 a.largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderBedrock)
1053
1054 if providerSupportsMedia {
1055 return messages
1056 }
1057
1058 convertedMessages := make([]fantasy.Message, 0, len(messages))
1059
1060 for _, msg := range messages {
1061 if msg.Role != fantasy.MessageRoleTool {
1062 convertedMessages = append(convertedMessages, msg)
1063 continue
1064 }
1065
1066 textParts := make([]fantasy.MessagePart, 0, len(msg.Content))
1067 var mediaFiles []fantasy.FilePart
1068
1069 for _, part := range msg.Content {
1070 toolResult, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
1071 if !ok {
1072 textParts = append(textParts, part)
1073 continue
1074 }
1075
1076 if media, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](toolResult.Output); ok {
1077 decoded, err := base64.StdEncoding.DecodeString(media.Data)
1078 if err != nil {
1079 slog.Warn("failed to decode media data", "error", err)
1080 textParts = append(textParts, part)
1081 continue
1082 }
1083
1084 mediaFiles = append(mediaFiles, fantasy.FilePart{
1085 Data: decoded,
1086 MediaType: media.MediaType,
1087 Filename: fmt.Sprintf("tool-result-%s", toolResult.ToolCallID),
1088 })
1089
1090 textParts = append(textParts, fantasy.ToolResultPart{
1091 ToolCallID: toolResult.ToolCallID,
1092 Output: fantasy.ToolResultOutputContentText{
1093 Text: "[Image/media content loaded - see attached file]",
1094 },
1095 ProviderOptions: toolResult.ProviderOptions,
1096 })
1097 } else {
1098 textParts = append(textParts, part)
1099 }
1100 }
1101
1102 convertedMessages = append(convertedMessages, fantasy.Message{
1103 Role: fantasy.MessageRoleTool,
1104 Content: textParts,
1105 })
1106
1107 if len(mediaFiles) > 0 {
1108 convertedMessages = append(convertedMessages, fantasy.NewUserMessage(
1109 "Here is the media content from the tool result:",
1110 mediaFiles...,
1111 ))
1112 }
1113 }
1114
1115 return convertedMessages
1116}