1// Package agent is the core orchestration layer for Crush AI agents.
2//
3// It provides session-based AI agent functionality for managing
4// conversations, tool execution, and message handling. It coordinates
5// interactions between language models, messages, sessions, and tools while
6// handling features like automatic summarization, queuing, and token
7// management.
8package agent
9
10import (
11 "cmp"
12 "context"
13 _ "embed"
14 "encoding/base64"
15 "errors"
16 "fmt"
17 "log/slog"
18 "os"
19 "regexp"
20 "strconv"
21 "strings"
22 "sync"
23 "time"
24
25 "charm.land/catwalk/pkg/catwalk"
26 "charm.land/fantasy"
27 "charm.land/fantasy/providers/anthropic"
28 "charm.land/fantasy/providers/bedrock"
29 "charm.land/fantasy/providers/google"
30 "charm.land/fantasy/providers/openai"
31 "charm.land/fantasy/providers/openrouter"
32 "charm.land/lipgloss/v2"
33 "git.secluded.site/crush/internal/agent/hyper"
34 "git.secluded.site/crush/internal/agent/tools"
35 "git.secluded.site/crush/internal/agent/tools/mcp"
36 "git.secluded.site/crush/internal/config"
37 "git.secluded.site/crush/internal/csync"
38 "git.secluded.site/crush/internal/message"
39 "git.secluded.site/crush/internal/permission"
40 "git.secluded.site/crush/internal/session"
41 "git.secluded.site/crush/internal/stringext"
42 "git.secluded.site/crush/internal/ui/notification"
43 "github.com/charmbracelet/x/exp/charmtone"
44)
45
// defaultSessionName is the title stored when a session title cannot be
// generated (both models failed, no response, or an empty cleaned title).
const defaultSessionName = "Untitled Session"

// Auto-summarization thresholds, evaluated by Run's stop condition.
//
// For context windows larger than largeContextWindowThreshold tokens, a
// summary is triggered once at most largeContextWindowBuffer tokens remain.
// For smaller windows, it triggers once the remaining tokens drop to
// smallContextWindowRatio of the window.
const (
	largeContextWindowThreshold = 200_000
	largeContextWindowBuffer    = 20_000
	smallContextWindowRatio     = 0.2
)
54
// titlePrompt is the system prompt (embedded from templates/title.md) used by
// generateTitle; a "/no_think" suffix is appended to it at call time.
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt is the system prompt (embedded from templates/summary.md) used
// by Summarize when condensing a session's history.
//
//go:embed templates/summary.md
var summaryPrompt []byte
60
// thinkTagRegex matches each <think>...</think> span (non-greedy, single
// line) so reasoning output can be stripped from generated session titles.
// Newlines are flattened before it is applied, so "." not matching "\n" is
// not a problem there.
var thinkTagRegex = regexp.MustCompile("<think>.*?</think>")
63
64type SessionAgentCall struct {
65 SessionID string
66 Prompt string
67 ProviderOptions fantasy.ProviderOptions
68 Attachments []message.Attachment
69 MaxOutputTokens int64
70 Temperature *float64
71 TopP *float64
72 TopK *int64
73 FrequencyPenalty *float64
74 PresencePenalty *float64
75 NonInteractive bool
76}
77
78type SessionAgent interface {
79 Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
80 SetModels(large Model, small Model)
81 SetTools(tools []fantasy.AgentTool)
82 SetSystemPrompt(systemPrompt string)
83 Cancel(sessionID string)
84 CancelAll()
85 IsSessionBusy(sessionID string) bool
86 IsBusy() bool
87 QueuedPrompts(sessionID string) int
88 QueuedPromptsList(sessionID string) []string
89 ClearQueue(sessionID string)
90 Summarize(context.Context, string, fantasy.ProviderOptions) error
91 Model() Model
92}
93
94type Model struct {
95 Model fantasy.LanguageModel
96 CatwalkCfg catwalk.Model
97 ModelCfg config.SelectedModel
98}
99
100type sessionAgent struct {
101 largeModel *csync.Value[Model]
102 smallModel *csync.Value[Model]
103 systemPromptPrefix *csync.Value[string]
104 systemPrompt *csync.Value[string]
105 tools *csync.Slice[fantasy.AgentTool]
106
107 isSubAgent bool
108 sessions session.Service
109 messages message.Service
110 disableAutoSummarize bool
111 isYolo bool
112 notify notification.Sink
113
114 messageQueue *csync.Map[string, []SessionAgentCall]
115 activeRequests *csync.Map[string, context.CancelFunc]
116}
117
118type SessionAgentOptions struct {
119 LargeModel Model
120 SmallModel Model
121 SystemPromptPrefix string
122 SystemPrompt string
123 IsSubAgent bool
124 DisableAutoSummarize bool
125 IsYolo bool
126 Sessions session.Service
127 Messages message.Service
128 Tools []fantasy.AgentTool
129 Notify notification.Sink
130}
131
132func NewSessionAgent(
133 opts SessionAgentOptions,
134) SessionAgent {
135 return &sessionAgent{
136 largeModel: csync.NewValue(opts.LargeModel),
137 smallModel: csync.NewValue(opts.SmallModel),
138 systemPromptPrefix: csync.NewValue(opts.SystemPromptPrefix),
139 systemPrompt: csync.NewValue(opts.SystemPrompt),
140 isSubAgent: opts.IsSubAgent,
141 sessions: opts.Sessions,
142 messages: opts.Messages,
143 disableAutoSummarize: opts.DisableAutoSummarize,
144 tools: csync.NewSliceFrom(opts.Tools),
145 isYolo: opts.IsYolo,
146 notify: opts.Notify,
147 messageQueue: csync.NewMap[string, []SessionAgentCall](),
148 activeRequests: csync.NewMap[string, context.CancelFunc](),
149 }
150}
151
// Run sends call.Prompt (plus attachments) to the large model for session
// call.SessionID and streams the response into the message store.
//
// It returns ErrEmptyPrompt when the prompt is empty and there is no text
// attachment, and ErrSessionMissing when SessionID is empty.
//
// If a request for the session is already active, call is appended to the
// session's messageQueue and Run returns (nil, nil). Queued calls are
// consumed either in the next PrepareStep of the active stream (as extra
// user messages) or, after the stream finishes, by re-entering Run with the
// first remaining queued call; in that case the recursive call's result is
// returned.
//
// While streaming, the request's cancel func is registered in activeRequests
// so Cancel can abort it; the entry is removed when Run returns. For the
// first message of a session a title is generated concurrently, and Run waits
// for it before returning.
//
// On a stream error, unfinished tool calls are closed, tool calls without a
// stored result get an error tool result, the assistant message is finished
// with a reason derived from the error, and the error is returned. When the
// context-window stop condition fired, the session is summarized after a
// successful stream.
func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
	if call.Prompt == "" && !message.ContainsTextAttachment(call.Attachments) {
		return nil, ErrEmptyPrompt
	}
	if call.SessionID == "" {
		return nil, ErrSessionMissing
	}

	// Queue the message if busy.
	// NOTE(review): this busy check and activeRequests.Set below are not
	// atomic (there is I/O in between), so two concurrent Runs for the same
	// session could both proceed.
	if a.IsSessionBusy(call.SessionID) {
		existing, ok := a.messageQueue.Get(call.SessionID)
		if !ok {
			existing = []SessionAgentCall{}
		}
		existing = append(existing, call)
		a.messageQueue.Set(call.SessionID, existing)
		return nil, nil
	}

	// Copy mutable fields under lock to avoid races with SetTools/SetModels.
	agentTools := a.tools.Copy()
	largeModel := a.largeModel.Get()
	systemPrompt := a.systemPrompt.Get()
	promptPrefix := a.systemPromptPrefix.Get()
	var instructions strings.Builder

	// Collect instructions advertised by connected MCP servers.
	for _, server := range mcp.GetStates() {
		if server.State != mcp.StateConnected {
			continue
		}
		if s := server.Client.InitializeResult().Instructions; s != "" {
			instructions.WriteString(s)
			instructions.WriteString("\n\n")
		}
	}

	if s := instructions.String(); s != "" {
		systemPrompt += "\n\n<mcp-instructions>\n" + s + "\n</mcp-instructions>"
	}

	if len(agentTools) > 0 {
		// Add Anthropic caching to the last tool.
		// NOTE(review): agentTools is a copy of the slice, but its elements
		// are the shared tool values, so this presumably mutates the tool
		// seen by other callers too; confirm SetProviderOptions semantics.
		agentTools[len(agentTools)-1].SetProviderOptions(a.getCacheControlOptions())
	}

	agent := fantasy.NewAgent(
		largeModel.Model,
		fantasy.WithSystemPrompt(systemPrompt),
		fantasy.WithTools(agentTools...),
	)

	// sessionLock serializes the session read-modify-write in OnStepFinish.
	sessionLock := sync.Mutex{}
	currentSession, err := a.sessions.Get(ctx, call.SessionID)
	if err != nil {
		return nil, fmt.Errorf("failed to get session: %w", err)
	}

	msgs, err := a.getSessionMessages(ctx, currentSession)
	if err != nil {
		return nil, fmt.Errorf("failed to get session messages: %w", err)
	}

	var wg sync.WaitGroup
	// Generate title if first message.
	if len(msgs) == 0 {
		titleCtx := ctx // Copy to avoid race with ctx reassignment below.
		wg.Go(func() {
			a.generateTitle(titleCtx, call.SessionID, call.Prompt)
		})
	}
	// Run does not return until title generation (if started) has finished.
	defer wg.Wait()

	// Add the user message to the session.
	_, err = a.createUserMessage(ctx, call)
	if err != nil {
		return nil, err
	}

	// Add the session to the context.
	ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)

	// Register the cancel func: Cancel uses it to abort this request, and
	// IsSessionBusy reports the session as busy while the entry exists.
	genCtx, cancel := context.WithCancel(ctx)
	a.activeRequests.Set(call.SessionID, cancel)

	defer cancel()
	defer a.activeRequests.Del(call.SessionID)

	// msgs was loaded before the user message was created, so the new prompt
	// is not part of history; it is passed separately as Prompt below.
	history, files := a.preparePrompt(msgs, call.Attachments...)

	startTime := time.Now()
	a.eventPromptSent(call.SessionID)

	// currentAssistant is the assistant message of the current step. It is
	// created in PrepareStep and updated by the streaming callbacks.
	var currentAssistant *message.Message
	// shouldSummarize is set by the StopWhen condition when the context
	// window is nearly full.
	var shouldSummarize bool
	result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
		Prompt:           message.PromptWithTextAttachments(call.Prompt, call.Attachments),
		Files:            files,
		Messages:         history,
		ProviderOptions:  call.ProviderOptions,
		MaxOutputTokens:  &call.MaxOutputTokens,
		TopP:             call.TopP,
		Temperature:      call.Temperature,
		PresencePenalty:  call.PresencePenalty,
		TopK:             call.TopK,
		FrequencyPenalty: call.FrequencyPenalty,
		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
			prepared.Messages = options.Messages
			// Drop provider options from earlier steps; cache markers are
			// re-applied below.
			for i := range prepared.Messages {
				prepared.Messages[i].ProviderOptions = nil
			}

			// Take the prompts queued while this request was running and
			// send them as extra user messages in this step.
			queuedCalls, _ := a.messageQueue.Get(call.SessionID)
			a.messageQueue.Del(call.SessionID)
			for _, queued := range queuedCalls {
				userMessage, createErr := a.createUserMessage(callContext, queued)
				if createErr != nil {
					return callContext, prepared, createErr
				}
				prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
			}

			prepared.Messages = a.workaroundProviderMediaLimitations(prepared.Messages, largeModel)

			lastSystemRoleInx := 0
			systemMessageUpdated := false
			for i, msg := range prepared.Messages {
				// At the first non-system message, add cache control to the
				// most recent preceding system message (index 0 if none).
				if msg.Role == fantasy.MessageRoleSystem {
					lastSystemRoleInx = i
				} else if !systemMessageUpdated {
					prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
					systemMessageUpdated = true
				}
				// Then add cache control to the last 2 messages.
				if i > len(prepared.Messages)-3 {
					prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
				}
			}

			// The prefix is prepended after cache markers are assigned, so it
			// is never marked itself.
			if promptPrefix != "" {
				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(promptPrefix)}, prepared.Messages...)
			}

			// Each step gets its own (initially empty) assistant message.
			var assistantMsg message.Message
			assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
				Role:     message.Assistant,
				Parts:    []message.ContentPart{},
				Model:    largeModel.ModelCfg.Model,
				Provider: largeModel.ModelCfg.Provider,
			})
			if err != nil {
				return callContext, prepared, err
			}
			// Expose the assistant message ID and model details through the
			// context keys defined by the tools package.
			callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
			callContext = context.WithValue(callContext, tools.SupportsImagesContextKey, largeModel.CatwalkCfg.SupportsImages)
			callContext = context.WithValue(callContext, tools.ModelNameContextKey, largeModel.CatwalkCfg.Name)
			currentAssistant = &assistantMsg
			return callContext, prepared, err
		},
		OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
			currentAssistant.AppendReasoningContent(reasoning.Text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnReasoningDelta: func(id string, text string) error {
			currentAssistant.AppendReasoningContent(text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
			// Store provider-specific reasoning metadata on the message.
			// handle anthropic signature
			if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
				if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
					currentAssistant.AppendReasoningSignature(reasoning.Signature)
				}
			}
			if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
				if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
					currentAssistant.AppendThoughtSignature(reasoning.Signature, reasoning.ToolID)
				}
			}
			if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
				if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
					currentAssistant.SetReasoningResponsesData(reasoning)
				}
			}
			currentAssistant.FinishThinking()
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnTextDelta: func(id string, text string) error {
			// Strip leading newline from initial text content. This is
			// particularly important in non-interactive mode where leading
			// newlines are very visible.
			if len(currentAssistant.Parts) == 0 {
				text = strings.TrimPrefix(text, "\n")
			}

			currentAssistant.AppendContent(text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnToolInputStart: func(id string, toolName string) error {
			// Record the tool call as soon as it starts; input arrives later
			// in OnToolCall.
			toolCall := message.ToolCall{
				ID:               id,
				Name:             toolName,
				ProviderExecuted: false,
				Finished:         false,
			}
			currentAssistant.AddToolCall(toolCall)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnRetry: func(err *fantasy.ProviderError, delay time.Duration) {
			// TODO: implement
		},
		OnToolCall: func(tc fantasy.ToolCallContent) error {
			toolCall := message.ToolCall{
				ID:               tc.ToolCallID,
				Name:             tc.ToolName,
				Input:            tc.Input,
				ProviderExecuted: false,
				Finished:         true,
			}
			currentAssistant.AddToolCall(toolCall)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnToolResult: func(result fantasy.ToolResultContent) error {
			// Each tool result is stored as its own Tool-role message.
			toolResult := a.convertToToolResult(result)
			_, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
				Role: message.Tool,
				Parts: []message.ContentPart{
					toolResult,
				},
			})
			return createMsgErr
		},
		OnStepFinish: func(stepResult fantasy.StepResult) error {
			finishReason := message.FinishReasonUnknown
			switch stepResult.FinishReason {
			case fantasy.FinishReasonLength:
				finishReason = message.FinishReasonMaxTokens
			case fantasy.FinishReasonStop:
				finishReason = message.FinishReasonEndTurn
			case fantasy.FinishReasonToolCalls:
				finishReason = message.FinishReasonToolUse
			}
			currentAssistant.AddFinish(finishReason, "", "")
			sessionLock.Lock()
			defer sessionLock.Unlock()

			// Re-read the session so changes made elsewhere (e.g. the title
			// stored by generateTitle) are not overwritten by the Save below.
			updatedSession, getSessionErr := a.sessions.Get(ctx, call.SessionID)
			if getSessionErr != nil {
				return getSessionErr
			}
			a.updateSessionUsage(largeModel, &updatedSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
			_, sessionErr := a.sessions.Save(ctx, updatedSession)
			if sessionErr != nil {
				return sessionErr
			}
			// StopWhen reads the token counts from currentSession.
			currentSession = updatedSession
			return a.messages.Update(genCtx, *currentAssistant)
		},
		StopWhen: []fantasy.StopCondition{
			// Stop once the remaining context window reaches the threshold:
			// a fixed buffer for large windows, a fraction of the window
			// otherwise. The summary itself runs after the stream returns.
			func(_ []fantasy.StepResult) bool {
				cw := int64(largeModel.CatwalkCfg.ContextWindow)
				tokens := currentSession.CompletionTokens + currentSession.PromptTokens
				remaining := cw - tokens
				var threshold int64
				if cw > largeContextWindowThreshold {
					threshold = largeContextWindowBuffer
				} else {
					threshold = int64(float64(cw) * smallContextWindowRatio)
				}
				if (remaining <= threshold) && !a.disableAutoSummarize {
					shouldSummarize = true
					return true
				}
				return false
			},
		},
	})

	a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))

	if err != nil {
		isCancelErr := errors.Is(err, context.Canceled)
		isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
		// No assistant message was created (failure before the first step),
		// so there is nothing to close out.
		if currentAssistant == nil {
			return result, err
		}
		// Ensure we finish thinking on error to close the reasoning state.
		currentAssistant.FinishThinking()
		toolCalls := currentAssistant.ToolCalls()
		// INFO: we use the parent context here because the genCtx has been cancelled.
		msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
		if createErr != nil {
			return nil, createErr
		}
		// Close every tool call: mark unfinished ones finished with empty
		// input, and add an error result for calls that have no stored
		// result.
		for _, tc := range toolCalls {
			if !tc.Finished {
				tc.Finished = true
				tc.Input = "{}"
				currentAssistant.AddToolCall(tc)
				updateErr := a.messages.Update(ctx, *currentAssistant)
				if updateErr != nil {
					return nil, updateErr
				}
			}

			found := false
			for _, msg := range msgs {
				if msg.Role == message.Tool {
					for _, tr := range msg.ToolResults() {
						if tr.ToolCallID == tc.ID {
							found = true
							break
						}
					}
				}
				if found {
					break
				}
			}
			if found {
				continue
			}
			content := "There was an error while executing the tool"
			if isCancelErr {
				content = "Tool execution canceled by user"
			} else if isPermissionErr {
				content = "User denied permission"
			}
			toolResult := message.ToolResult{
				ToolCallID: tc.ID,
				Name:       tc.Name,
				Content:    content,
				IsError:    true,
			}
			_, createErr = a.messages.Create(ctx, currentAssistant.SessionID, message.CreateMessageParams{
				Role: message.Tool,
				Parts: []message.ContentPart{
					toolResult,
				},
			})
			if createErr != nil {
				return nil, createErr
			}
		}
		// Record a finish reason and user-facing title/message derived from
		// the error type.
		var fantasyErr *fantasy.Error
		var providerErr *fantasy.ProviderError
		const defaultTitle = "Provider Error"
		linkStyle := lipgloss.NewStyle().Foreground(charmtone.Guac).Underline(true)
		if isCancelErr {
			currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
		} else if isPermissionErr {
			currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "")
		} else if errors.Is(err, hyper.ErrNoCredits) {
			url := hyper.BaseURL()
			link := linkStyle.Hyperlink(url, "id=hyper").Render(url)
			currentAssistant.AddFinish(message.FinishReasonError, "No credits", "You're out of credits. Add more at "+link)
		} else if errors.As(err, &providerErr) {
			if providerErr.Message == "The requested model is not supported." {
				url := "https://github.com/settings/copilot/features"
				link := linkStyle.Hyperlink(url, "id=copilot").Render(url)
				currentAssistant.AddFinish(
					message.FinishReasonError,
					"Copilot model not enabled",
					fmt.Sprintf("%q is not enabled in Copilot. Go to the following page to enable it. Then, wait 5 minutes before trying again. %s", largeModel.CatwalkCfg.Name, link),
				)
			} else {
				currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message)
			}
		} else if errors.As(err, &fantasyErr) {
			currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(fantasyErr.Title), defaultTitle), fantasyErr.Message)
		} else {
			currentAssistant.AddFinish(message.FinishReasonError, defaultTitle, err.Error())
		}
		// Note: we use the parent context here because the genCtx has been
		// cancelled.
		updateErr := a.messages.Update(ctx, *currentAssistant)
		if updateErr != nil {
			return nil, updateErr
		}
		return nil, err
	}

	// Send notification that agent has finished its turn (skip for
	// nested/non-interactive sessions).
	if !call.NonInteractive && a.notify != nil {
		a.notify(notification.Notification{
			Title:   "Crush is waiting...",
			Message: fmt.Sprintf("Agent's turn completed in \"%s\"", currentSession.Title),
		})
	}

	if shouldSummarize {
		// Release the session first; Summarize returns ErrSessionBusy
		// otherwise.
		a.activeRequests.Del(call.SessionID)
		if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
			return nil, summarizeErr
		}
		// If the agent wasn't done (the last step had tool calls), re-queue
		// the original request so work resumes from the summary.
		if len(currentAssistant.ToolCalls()) > 0 {
			existing, ok := a.messageQueue.Get(call.SessionID)
			if !ok {
				existing = []SessionAgentCall{}
			}
			call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
			existing = append(existing, call)
			a.messageQueue.Set(call.SessionID, existing)
		}
	}

	// Release active request before processing queued messages.
	a.activeRequests.Del(call.SessionID)
	cancel()

	queuedMessages, ok := a.messageQueue.Get(call.SessionID)
	if !ok || len(queuedMessages) == 0 {
		return result, err
	}
	// There are queued messages: restart the loop with the first one and
	// leave the rest queued. The recursive call's result is returned in
	// place of this one.
	firstQueuedMessage := queuedMessages[0]
	a.messageQueue.Set(call.SessionID, queuedMessages[1:])
	return a.Run(ctx, firstQueuedMessage)
}
573
574func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
575 if a.IsSessionBusy(sessionID) {
576 return ErrSessionBusy
577 }
578
579 // Copy mutable fields under lock to avoid races with SetModels.
580 largeModel := a.largeModel.Get()
581 systemPromptPrefix := a.systemPromptPrefix.Get()
582
583 currentSession, err := a.sessions.Get(ctx, sessionID)
584 if err != nil {
585 return fmt.Errorf("failed to get session: %w", err)
586 }
587 msgs, err := a.getSessionMessages(ctx, currentSession)
588 if err != nil {
589 return err
590 }
591 if len(msgs) == 0 {
592 // Nothing to summarize.
593 return nil
594 }
595
596 aiMsgs, _ := a.preparePrompt(msgs)
597
598 genCtx, cancel := context.WithCancel(ctx)
599 a.activeRequests.Set(sessionID, cancel)
600 defer a.activeRequests.Del(sessionID)
601 defer cancel()
602
603 agent := fantasy.NewAgent(largeModel.Model,
604 fantasy.WithSystemPrompt(string(summaryPrompt)),
605 )
606 summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
607 Role: message.Assistant,
608 Model: largeModel.Model.Model(),
609 Provider: largeModel.Model.Provider(),
610 IsSummaryMessage: true,
611 })
612 if err != nil {
613 return err
614 }
615
616 summaryPromptText := buildSummaryPrompt(currentSession.Todos)
617
618 resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
619 Prompt: summaryPromptText,
620 Messages: aiMsgs,
621 ProviderOptions: opts,
622 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
623 prepared.Messages = options.Messages
624 if systemPromptPrefix != "" {
625 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(systemPromptPrefix)}, prepared.Messages...)
626 }
627 return callContext, prepared, nil
628 },
629 OnReasoningDelta: func(id string, text string) error {
630 summaryMessage.AppendReasoningContent(text)
631 return a.messages.Update(genCtx, summaryMessage)
632 },
633 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
634 // Handle anthropic signature.
635 if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
636 if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
637 summaryMessage.AppendReasoningSignature(signature.Signature)
638 }
639 }
640 summaryMessage.FinishThinking()
641 return a.messages.Update(genCtx, summaryMessage)
642 },
643 OnTextDelta: func(id, text string) error {
644 summaryMessage.AppendContent(text)
645 return a.messages.Update(genCtx, summaryMessage)
646 },
647 })
648 if err != nil {
649 isCancelErr := errors.Is(err, context.Canceled)
650 if isCancelErr {
651 // User cancelled summarize we need to remove the summary message.
652 deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
653 return deleteErr
654 }
655 return err
656 }
657
658 summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
659 err = a.messages.Update(genCtx, summaryMessage)
660 if err != nil {
661 return err
662 }
663
664 var openrouterCost *float64
665 for _, step := range resp.Steps {
666 stepCost := a.openrouterCost(step.ProviderMetadata)
667 if stepCost != nil {
668 newCost := *stepCost
669 if openrouterCost != nil {
670 newCost += *openrouterCost
671 }
672 openrouterCost = &newCost
673 }
674 }
675
676 a.updateSessionUsage(largeModel, ¤tSession, resp.TotalUsage, openrouterCost)
677
678 // Just in case, get just the last usage info.
679 usage := resp.Response.Usage
680 currentSession.SummaryMessageID = summaryMessage.ID
681 currentSession.CompletionTokens = usage.OutputTokens
682 currentSession.PromptTokens = 0
683 _, err = a.sessions.Save(genCtx, currentSession)
684 return err
685}
686
687func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
688 if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
689 return fantasy.ProviderOptions{}
690 }
691 return fantasy.ProviderOptions{
692 anthropic.Name: &anthropic.ProviderCacheControlOptions{
693 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
694 },
695 bedrock.Name: &anthropic.ProviderCacheControlOptions{
696 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
697 },
698 }
699}
700
701func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
702 parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
703 var attachmentParts []message.ContentPart
704 for _, attachment := range call.Attachments {
705 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
706 }
707 parts = append(parts, attachmentParts...)
708 msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
709 Role: message.User,
710 Parts: parts,
711 })
712 if err != nil {
713 return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
714 }
715 return msg, nil
716}
717
718func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
719 var history []fantasy.Message
720 if !a.isSubAgent {
721 history = append(history, fantasy.NewUserMessage(
722 fmt.Sprintf("<system_reminder>%s</system_reminder>",
723 `This is a reminder that your todo list is currently empty. DO NOT mention this to the user explicitly because they are already aware.
724If you are working on tasks that would benefit from a todo list please use the "todos" tool to create one.
725If not, please feel free to ignore. Again do not mention this message to the user.`,
726 ),
727 ))
728 }
729 for _, m := range msgs {
730 if len(m.Parts) == 0 {
731 continue
732 }
733 // Assistant message without content or tool calls (cancelled before it
734 // returned anything).
735 if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
736 continue
737 }
738 history = append(history, m.ToAIMessage()...)
739 }
740
741 var files []fantasy.FilePart
742 for _, attachment := range attachments {
743 if attachment.IsText() {
744 continue
745 }
746 files = append(files, fantasy.FilePart{
747 Filename: attachment.FileName,
748 Data: attachment.Content,
749 MediaType: attachment.MimeType,
750 })
751 }
752
753 return history, files
754}
755
756func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
757 msgs, err := a.messages.List(ctx, session.ID)
758 if err != nil {
759 return nil, fmt.Errorf("failed to list messages: %w", err)
760 }
761
762 if session.SummaryMessageID != "" {
763 summaryMsgIndex := -1
764 for i, msg := range msgs {
765 if msg.ID == session.SummaryMessageID {
766 summaryMsgIndex = i
767 break
768 }
769 }
770 if summaryMsgIndex != -1 {
771 msgs = msgs[summaryMsgIndex:]
772 msgs[0].Role = message.User
773 }
774 }
775 return msgs, nil
776}
777
778// generateTitle generates a session titled based on the initial prompt.
779func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, userPrompt string) {
780 if userPrompt == "" {
781 return
782 }
783
784 smallModel := a.smallModel.Get()
785 largeModel := a.largeModel.Get()
786 systemPromptPrefix := a.systemPromptPrefix.Get()
787
788 var maxOutputTokens int64 = 40
789 if smallModel.CatwalkCfg.CanReason {
790 maxOutputTokens = smallModel.CatwalkCfg.DefaultMaxTokens
791 }
792
793 newAgent := func(m fantasy.LanguageModel, p []byte, tok int64) fantasy.Agent {
794 return fantasy.NewAgent(m,
795 fantasy.WithSystemPrompt(string(p)+"\n /no_think"),
796 fantasy.WithMaxOutputTokens(tok),
797 )
798 }
799
800 streamCall := fantasy.AgentStreamCall{
801 Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", userPrompt),
802 PrepareStep: func(callCtx context.Context, opts fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
803 prepared.Messages = opts.Messages
804 if systemPromptPrefix != "" {
805 prepared.Messages = append([]fantasy.Message{
806 fantasy.NewSystemMessage(systemPromptPrefix),
807 }, prepared.Messages...)
808 }
809 return callCtx, prepared, nil
810 },
811 }
812
813 // Use the small model to generate the title.
814 model := smallModel
815 agent := newAgent(model.Model, titlePrompt, maxOutputTokens)
816 resp, err := agent.Stream(ctx, streamCall)
817 if err == nil {
818 // We successfully generated a title with the small model.
819 slog.Debug("Generated title with small model")
820 } else {
821 // It didn't work. Let's try with the big model.
822 slog.Error("Error generating title with small model; trying big model", "err", err)
823 model = largeModel
824 agent = newAgent(model.Model, titlePrompt, maxOutputTokens)
825 resp, err = agent.Stream(ctx, streamCall)
826 if err == nil {
827 slog.Debug("Generated title with large model")
828 } else {
829 // Welp, the large model didn't work either. Use the default
830 // session name and return.
831 slog.Error("Error generating title with large model", "err", err)
832 saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, defaultSessionName, 0, 0, 0)
833 if saveErr != nil {
834 slog.Error("Failed to save session title and usage", "error", saveErr)
835 }
836 return
837 }
838 }
839
840 if resp == nil {
841 // Actually, we didn't get a response so we can't. Use the default
842 // session name and return.
843 slog.Error("Response is nil; can't generate title")
844 saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, defaultSessionName, 0, 0, 0)
845 if saveErr != nil {
846 slog.Error("Failed to save session title and usage", "error", saveErr)
847 }
848 return
849 }
850
851 // Clean up title.
852 var title string
853 title = strings.ReplaceAll(resp.Response.Content.Text(), "\n", " ")
854
855 // Remove thinking tags if present.
856 title = thinkTagRegex.ReplaceAllString(title, "")
857
858 title = strings.TrimSpace(title)
859 if title == "" {
860 slog.Debug("Empty title; using fallback")
861 title = defaultSessionName
862 }
863
864 // Calculate usage and cost.
865 var openrouterCost *float64
866 for _, step := range resp.Steps {
867 stepCost := a.openrouterCost(step.ProviderMetadata)
868 if stepCost != nil {
869 newCost := *stepCost
870 if openrouterCost != nil {
871 newCost += *openrouterCost
872 }
873 openrouterCost = &newCost
874 }
875 }
876
877 modelConfig := model.CatwalkCfg
878 cost := modelConfig.CostPer1MInCached/1e6*float64(resp.TotalUsage.CacheCreationTokens) +
879 modelConfig.CostPer1MOutCached/1e6*float64(resp.TotalUsage.CacheReadTokens) +
880 modelConfig.CostPer1MIn/1e6*float64(resp.TotalUsage.InputTokens) +
881 modelConfig.CostPer1MOut/1e6*float64(resp.TotalUsage.OutputTokens)
882
883 // Use override cost if available (e.g., from OpenRouter).
884 if openrouterCost != nil {
885 cost = *openrouterCost
886 }
887
888 promptTokens := resp.TotalUsage.InputTokens + resp.TotalUsage.CacheCreationTokens
889 completionTokens := resp.TotalUsage.OutputTokens
890
891 // Atomically update only title and usage fields to avoid overriding other
892 // concurrent session updates.
893 saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, title, promptTokens, completionTokens, cost)
894 if saveErr != nil {
895 slog.Error("Failed to save session title and usage", "error", saveErr)
896 return
897 }
898}
899
900func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
901 openrouterMetadata, ok := metadata[openrouter.Name]
902 if !ok {
903 return nil
904 }
905
906 opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
907 if !ok {
908 return nil
909 }
910 return &opts.Usage.Cost
911}
912
913func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
914 modelConfig := model.CatwalkCfg
915 cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
916 modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
917 modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
918 modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
919
920 a.eventTokensUsed(session.ID, model, usage, cost)
921
922 if overrideCost != nil {
923 session.Cost += *overrideCost
924 } else {
925 session.Cost += cost
926 }
927
928 session.CompletionTokens = usage.OutputTokens
929 session.PromptTokens = usage.InputTokens + usage.CacheReadTokens
930}
931
932func (a *sessionAgent) Cancel(sessionID string) {
933 // Cancel regular requests. Don't use Take() here - we need the entry to
934 // remain in activeRequests so IsBusy() returns true until the goroutine
935 // fully completes (including error handling that may access the DB).
936 // The defer in processRequest will clean up the entry.
937 if cancel, ok := a.activeRequests.Get(sessionID); ok && cancel != nil {
938 slog.Debug("Request cancellation initiated", "session_id", sessionID)
939 cancel()
940 }
941
942 // Also check for summarize requests.
943 if cancel, ok := a.activeRequests.Get(sessionID + "-summarize"); ok && cancel != nil {
944 slog.Debug("Summarize cancellation initiated", "session_id", sessionID)
945 cancel()
946 }
947
948 if a.QueuedPrompts(sessionID) > 0 {
949 slog.Debug("Clearing queued prompts", "session_id", sessionID)
950 a.messageQueue.Del(sessionID)
951 }
952}
953
954func (a *sessionAgent) ClearQueue(sessionID string) {
955 if a.QueuedPrompts(sessionID) > 0 {
956 slog.Debug("Clearing queued prompts", "session_id", sessionID)
957 a.messageQueue.Del(sessionID)
958 }
959}
960
961func (a *sessionAgent) CancelAll() {
962 if !a.IsBusy() {
963 return
964 }
965 for key := range a.activeRequests.Seq2() {
966 a.Cancel(key) // key is sessionID
967 }
968
969 timeout := time.After(5 * time.Second)
970 for a.IsBusy() {
971 select {
972 case <-timeout:
973 return
974 default:
975 time.Sleep(200 * time.Millisecond)
976 }
977 }
978}
979
980func (a *sessionAgent) IsBusy() bool {
981 var busy bool
982 for cancelFunc := range a.activeRequests.Seq() {
983 if cancelFunc != nil {
984 busy = true
985 break
986 }
987 }
988 return busy
989}
990
991func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
992 _, busy := a.activeRequests.Get(sessionID)
993 return busy
994}
995
996func (a *sessionAgent) QueuedPrompts(sessionID string) int {
997 l, ok := a.messageQueue.Get(sessionID)
998 if !ok {
999 return 0
1000 }
1001 return len(l)
1002}
1003
1004func (a *sessionAgent) QueuedPromptsList(sessionID string) []string {
1005 l, ok := a.messageQueue.Get(sessionID)
1006 if !ok {
1007 return nil
1008 }
1009 prompts := make([]string, len(l))
1010 for i, call := range l {
1011 prompts[i] = call.Prompt
1012 }
1013 return prompts
1014}
1015
1016func (a *sessionAgent) SetModels(large Model, small Model) {
1017 a.largeModel.Set(large)
1018 a.smallModel.Set(small)
1019}
1020
1021func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
1022 a.tools.SetSlice(tools)
1023}
1024
1025func (a *sessionAgent) SetSystemPrompt(systemPrompt string) {
1026 a.systemPrompt.Set(systemPrompt)
1027}
1028
1029func (a *sessionAgent) Model() Model {
1030 return a.largeModel.Get()
1031}
1032
1033// convertToToolResult converts a fantasy tool result to a message tool result.
1034func (a *sessionAgent) convertToToolResult(result fantasy.ToolResultContent) message.ToolResult {
1035 baseResult := message.ToolResult{
1036 ToolCallID: result.ToolCallID,
1037 Name: result.ToolName,
1038 Metadata: result.ClientMetadata,
1039 }
1040
1041 switch result.Result.GetType() {
1042 case fantasy.ToolResultContentTypeText:
1043 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result); ok {
1044 baseResult.Content = r.Text
1045 }
1046 case fantasy.ToolResultContentTypeError:
1047 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result); ok {
1048 baseResult.Content = r.Error.Error()
1049 baseResult.IsError = true
1050 }
1051 case fantasy.ToolResultContentTypeMedia:
1052 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](result.Result); ok {
1053 content := r.Text
1054 if content == "" {
1055 content = fmt.Sprintf("Loaded %s content", r.MediaType)
1056 }
1057 baseResult.Content = content
1058 baseResult.Data = r.Data
1059 baseResult.MIMEType = r.MediaType
1060 }
1061 }
1062
1063 return baseResult
1064}
1065
1066// workaroundProviderMediaLimitations converts media content in tool results to
1067// user messages for providers that don't natively support images in tool results.
1068//
1069// Problem: OpenAI, Google, OpenRouter, and other OpenAI-compatible providers
1070// don't support sending images/media in tool result messages - they only accept
1071// text in tool results. However, they DO support images in user messages.
1072//
1073// If we send media in tool results to these providers, the API returns an error.
1074//
1075// Solution: For these providers, we:
1076// 1. Replace the media in the tool result with a text placeholder
1077// 2. Inject a user message immediately after with the image as a file attachment
1078// 3. This maintains the tool execution flow while working around API limitations
1079//
1080// Anthropic and Bedrock support images natively in tool results, so we skip
1081// this workaround for them.
1082//
1083// Example transformation:
1084//
1085// BEFORE: [tool result: image data]
1086// AFTER: [tool result: "Image loaded - see attached"], [user: image attachment]
1087func (a *sessionAgent) workaroundProviderMediaLimitations(messages []fantasy.Message, largeModel Model) []fantasy.Message {
1088 providerSupportsMedia := largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderAnthropic) ||
1089 largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderBedrock)
1090
1091 if providerSupportsMedia {
1092 return messages
1093 }
1094
1095 convertedMessages := make([]fantasy.Message, 0, len(messages))
1096
1097 for _, msg := range messages {
1098 if msg.Role != fantasy.MessageRoleTool {
1099 convertedMessages = append(convertedMessages, msg)
1100 continue
1101 }
1102
1103 textParts := make([]fantasy.MessagePart, 0, len(msg.Content))
1104 var mediaFiles []fantasy.FilePart
1105
1106 for _, part := range msg.Content {
1107 toolResult, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
1108 if !ok {
1109 textParts = append(textParts, part)
1110 continue
1111 }
1112
1113 if media, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](toolResult.Output); ok {
1114 decoded, err := base64.StdEncoding.DecodeString(media.Data)
1115 if err != nil {
1116 slog.Warn("Failed to decode media data", "error", err)
1117 textParts = append(textParts, part)
1118 continue
1119 }
1120
1121 mediaFiles = append(mediaFiles, fantasy.FilePart{
1122 Data: decoded,
1123 MediaType: media.MediaType,
1124 Filename: fmt.Sprintf("tool-result-%s", toolResult.ToolCallID),
1125 })
1126
1127 textParts = append(textParts, fantasy.ToolResultPart{
1128 ToolCallID: toolResult.ToolCallID,
1129 Output: fantasy.ToolResultOutputContentText{
1130 Text: "[Image/media content loaded - see attached file]",
1131 },
1132 ProviderOptions: toolResult.ProviderOptions,
1133 })
1134 } else {
1135 textParts = append(textParts, part)
1136 }
1137 }
1138
1139 convertedMessages = append(convertedMessages, fantasy.Message{
1140 Role: fantasy.MessageRoleTool,
1141 Content: textParts,
1142 })
1143
1144 if len(mediaFiles) > 0 {
1145 convertedMessages = append(convertedMessages, fantasy.NewUserMessage(
1146 "Here is the media content from the tool result:",
1147 mediaFiles...,
1148 ))
1149 }
1150 }
1151
1152 return convertedMessages
1153}
1154
1155// buildSummaryPrompt constructs the prompt text for session summarization.
1156func buildSummaryPrompt(todos []session.Todo) string {
1157 var sb strings.Builder
1158 sb.WriteString("Provide a detailed summary of our conversation above.")
1159 if len(todos) > 0 {
1160 sb.WriteString("\n\n## Current Todo List\n\n")
1161 for _, t := range todos {
1162 fmt.Fprintf(&sb, "- [%s] %s\n", t.Status, t.Content)
1163 }
1164 sb.WriteString("\nInclude these tasks and their statuses in your summary. ")
1165 sb.WriteString("Instruct the resuming assistant to use the `todos` tool to continue tracking progress on these tasks.")
1166 }
1167 return sb.String()
1168}