1// Package agent is the core orchestration layer for Crush AI agents.
2//
3// It provides session-based AI agent functionality for managing
4// conversations, tool execution, and message handling. It coordinates
5// interactions between language models, messages, sessions, and tools while
6// handling features like automatic summarization, queuing, and token
7// management.
8package agent
9
10import (
11 "cmp"
12 "context"
13 _ "embed"
14 "encoding/base64"
15 "errors"
16 "fmt"
17 "log/slog"
18 "net/http"
19 "os"
20 "regexp"
21 "strconv"
22 "strings"
23 "sync"
24 "time"
25
26 "charm.land/catwalk/pkg/catwalk"
27 "charm.land/fantasy"
28 "charm.land/fantasy/providers/anthropic"
29 "charm.land/fantasy/providers/bedrock"
30 "charm.land/fantasy/providers/google"
31 "charm.land/fantasy/providers/openai"
32 "charm.land/fantasy/providers/openrouter"
33 "charm.land/fantasy/providers/vercel"
34 "charm.land/lipgloss/v2"
35 "github.com/charmbracelet/crush/internal/agent/hyper"
36 "github.com/charmbracelet/crush/internal/agent/notify"
37 "github.com/charmbracelet/crush/internal/agent/tools"
38 "github.com/charmbracelet/crush/internal/agent/tools/mcp"
39 "github.com/charmbracelet/crush/internal/config"
40 "github.com/charmbracelet/crush/internal/csync"
41 "github.com/charmbracelet/crush/internal/message"
42 "github.com/charmbracelet/crush/internal/pubsub"
43 "github.com/charmbracelet/crush/internal/session"
44 "github.com/charmbracelet/crush/internal/stringext"
45 "github.com/charmbracelet/crush/internal/version"
46 "github.com/charmbracelet/x/exp/charmtone"
47)
48
const (
	// DefaultSessionName is the placeholder session name.
	// NOTE(review): its consumers are outside this chunk — presumably it is
	// the title a session carries until one is generated; confirm.
	DefaultSessionName = "Untitled Session"

	// Constants for auto-summarization thresholds. Run's stop condition
	// compares the tokens remaining in the model's context window against a
	// threshold chosen from these values.

	// largeContextWindowThreshold is the context-window size (in tokens)
	// above which the fixed largeContextWindowBuffer is used as the
	// threshold instead of a ratio of the window.
	largeContextWindowThreshold = 200_000
	// largeContextWindowBuffer is the remaining-token threshold applied to
	// context windows larger than largeContextWindowThreshold.
	largeContextWindowBuffer = 20_000
	// smallContextWindowRatio is the fraction of the context window used as
	// the remaining-token threshold for all other (smaller) windows.
	smallContextWindowRatio = 0.2
)
57
// userAgent is the User-Agent string handed to every fantasy agent created in
// this file (via fantasy.WithUserAgent). It embeds the build's version.Version.
var userAgent = fmt.Sprintf("Charm-Crush/%s (https://charm.land/crush)", version.Version)
59
// titlePrompt holds the contents of templates/title.md, embedded at build
// time.
// NOTE(review): presumably the system prompt for session-title generation;
// its use is outside this chunk — confirm.
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt holds the contents of templates/summary.md, embedded at build
// time. Summarize passes it as the system prompt of the summarization agent.
//
//go:embed templates/summary.md
var summaryPrompt []byte
65
// Used to remove <think> tags from generated titles.
var (
	// thinkTagRegex matches one complete <think>...</think> span. The (?s)
	// flag lets "." match newlines, and the lazy ".*?" stops at the first
	// closing tag.
	thinkTagRegex = regexp.MustCompile(`(?s)<think>.*?</think>`)
	// orphanThinkTagRegex matches a single opening or closing think tag on
	// its own.
	orphanThinkTagRegex = regexp.MustCompile(`</?think>`)
)
71
// SessionAgentCall describes a single turn submitted to SessionAgent.Run: the
// target session, the user input, the sampling parameters forwarded to the
// model, and the hooks that control how the turn's completion is reported.
type SessionAgentCall struct {
	// SessionID names the session the turn belongs to. ValidateCall rejects
	// a call whose SessionID is empty.
	SessionID string
	// RunID, when non-empty, is the caller-supplied correlator that
	// gets echoed back on the notify.RunComplete event emitted for
	// this turn. It is preserved when the call is enqueued behind a
	// busy session so the queued turn's terminal event is still
	// recognisable to the original caller. Callers that need a
	// reliable completion contract (e.g. `crush run` against a
	// session that may be busy) MUST set it; SessionID alone is
	// ambiguous when concurrent turns share the same session.
	RunID string
	// Prompt is the user's text. It may be empty only when Attachments
	// contains a text attachment (see ValidateCall).
	Prompt string
	// ProviderOptions is forwarded unchanged to the model stream call, and
	// to Summarize when Run triggers auto-summarization.
	ProviderOptions fantasy.ProviderOptions
	// Attachments are the files accompanying the prompt.
	Attachments []message.Attachment
	// MaxOutputTokens caps the response length. Values <= 0 are not sent to
	// the provider at all.
	MaxOutputTokens int64
	// Temperature is forwarded as-is to the model stream call.
	Temperature *float64
	// TopP is forwarded as-is to the model stream call.
	TopP *float64
	// TopK is forwarded as-is to the model stream call.
	TopK *int64
	// FrequencyPenalty is forwarded as-is to the model stream call.
	FrequencyPenalty *float64
	// PresencePenalty is forwarded as-is to the model stream call.
	PresencePenalty *float64
	// NonInteractive suppresses the "agent finished" notification that Run
	// otherwise publishes at the end of a successful turn.
	NonInteractive bool
	// OnComplete, when non-nil, replaces the default RunComplete
	// publish path: the inner Run hands the terminal payload to this
	// callback instead of emitting it on the RunComplete broker. The
	// coordinator uses this hook to coalesce the unauthorized →
	// re-auth → retry chain into a single user-visible terminal
	// event, so non-interactive clients (e.g. `crush run`) don't
	// exit on a stale failed-attempt RunComplete before the
	// successful retry. It is intentionally stripped when queueing
	// a busy-session call (see Run): the originating
	// coordinator.Run has long returned by the time the queued
	// recursion drains, so falling back to the default broker
	// publish keeps the event visible to subscribers.
	OnComplete func(notify.RunComplete)
}
107
// SessionAgent is the interface returned by NewSessionAgent. It runs agent
// turns against sessions and exposes the controls around them (configuration
// updates, cancellation, busy state, and the per-session prompt queue).
//
// NOTE(review): only Run and Summarize are implemented in the visible part of
// this file; the remaining method notes are inferred from names and call
// sites and should be confirmed against the implementations.
type SessionAgent interface {
	// Run executes one turn for the call's session. When the session is
	// already busy the call is queued instead and Run returns (nil, nil).
	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
	// SetModels replaces the large and small models (presumably used by
	// subsequent runs — confirm).
	SetModels(large Model, small Model)
	// SetTools replaces the tool set. Run re-reads the tools before every
	// step, so the tools it reads are refreshed each step.
	SetTools(tools []fantasy.AgentTool)
	// SetSystemPrompt replaces the system prompt (presumably for subsequent
	// runs — confirm).
	SetSystemPrompt(systemPrompt string)
	// Cancel presumably cancels the in-flight request of the given session.
	Cancel(sessionID string)
	// CancelAll presumably cancels every in-flight request.
	CancelAll()
	// IsSessionBusy reports whether the given session is busy. Run queues
	// the call and Summarize returns ErrSessionBusy when it reports true.
	IsSessionBusy(sessionID string) bool
	// IsBusy presumably reports whether any session is busy.
	IsBusy() bool
	// QueuedPrompts presumably returns the number of calls queued for the
	// session.
	QueuedPrompts(sessionID string) int
	// QueuedPromptsList presumably returns the prompts queued for the
	// session.
	QueuedPromptsList(sessionID string) []string
	// ClearQueue presumably drops the calls queued for the session.
	ClearQueue(sessionID string)
	// Summarize replaces the session's history with a generated summary
	// message. It returns ErrSessionBusy when the session is busy.
	Summarize(context.Context, string, fantasy.ProviderOptions) error
	// Model presumably returns the model used for runs (the large model) —
	// confirm.
	Model() Model
}
123
// Model bundles a language model with the catalog and configuration metadata
// the agent needs alongside it.
type Model struct {
	// Model is the language model handed to fantasy.NewAgent.
	Model fantasy.LanguageModel
	// CatwalkCfg is the catalog description of the model. This file reads
	// its SupportsImages, ContextWindow and Name fields.
	CatwalkCfg catwalk.Model
	// ModelCfg is the selected-model configuration. Its Model and Provider
	// values are recorded on the assistant messages created for a run.
	ModelCfg config.SelectedModel
	// FlatRate is not read in the visible part of this file.
	// NOTE(review): presumably marks models billed at a flat rate so that
	// per-token cost is not accumulated — confirm.
	FlatRate bool
}
130
// sessionAgent is the SessionAgent implementation returned by
// NewSessionAgent.
type sessionAgent struct {
	// largeModel is the model used by Run and Summarize.
	largeModel *csync.Value[Model]
	// smallModel is not read in the visible part of this file.
	// NOTE(review): presumably used for cheap auxiliary work such as title
	// generation — confirm.
	smallModel *csync.Value[Model]
	// systemPromptPrefix, when non-empty, is prepended as an extra system
	// message to every step's message list.
	systemPromptPrefix *csync.Value[string]
	// systemPrompt is the base system prompt for runs.
	systemPrompt *csync.Value[string]
	// tools is the current tool set; Run copies it at start and again
	// before every step.
	tools *csync.Slice[fantasy.AgentTool]

	// isSubAgent is not read in the visible part of this file.
	isSubAgent bool
	// sessions loads and saves session records.
	sessions session.Service
	// messages creates, updates, lists, deletes and flushes messages.
	messages message.Service
	// disableAutoSummarize turns off the context-window stop condition that
	// triggers summarization in Run.
	disableAutoSummarize bool
	// isYolo is not read in the visible part of this file.
	isYolo bool
	// notify publishes user-facing notifications; it may be nil.
	notify pubsub.Publisher[notify.Notification]
	// runComplete publishes the terminal event of each turn; it may be nil.
	runComplete pubsub.Publisher[notify.RunComplete]

	// messageQueue holds, per session ID, the calls that arrived while the
	// session was busy.
	messageQueue *csync.Map[string, []SessionAgentCall]
	// activeRequests holds, per session ID, the cancel function of the
	// in-flight Run or Summarize.
	activeRequests *csync.Map[string, context.CancelFunc]
}
149
// SessionAgentOptions carries the dependencies and initial configuration that
// NewSessionAgent copies into a new agent.
type SessionAgentOptions struct {
	// LargeModel is the initial model used for runs and summarization.
	LargeModel Model
	// SmallModel is the initial small model.
	SmallModel Model
	// SystemPromptPrefix is the initial system prompt prefix.
	SystemPromptPrefix string
	// SystemPrompt is the initial system prompt.
	SystemPrompt string
	// IsSubAgent is stored on the agent as its isSubAgent flag.
	IsSubAgent bool
	// DisableAutoSummarize disables automatic summarization in Run.
	DisableAutoSummarize bool
	// IsYolo is stored on the agent as its isYolo flag.
	IsYolo bool
	// Sessions is the session store.
	Sessions session.Service
	// Messages is the message store.
	Messages message.Service
	// Tools is the initial tool set.
	Tools []fantasy.AgentTool
	// Notify publishes user-facing notifications; may be nil.
	Notify pubsub.Publisher[notify.Notification]
	// RunComplete publishes per-turn terminal events; may be nil.
	RunComplete pubsub.Publisher[notify.RunComplete]
}
164
165func NewSessionAgent(
166 opts SessionAgentOptions,
167) SessionAgent {
168 return &sessionAgent{
169 largeModel: csync.NewValue(opts.LargeModel),
170 smallModel: csync.NewValue(opts.SmallModel),
171 systemPromptPrefix: csync.NewValue(opts.SystemPromptPrefix),
172 systemPrompt: csync.NewValue(opts.SystemPrompt),
173 isSubAgent: opts.IsSubAgent,
174 sessions: opts.Sessions,
175 messages: opts.Messages,
176 disableAutoSummarize: opts.DisableAutoSummarize,
177 tools: csync.NewSliceFrom(opts.Tools),
178 isYolo: opts.IsYolo,
179 notify: opts.Notify,
180 runComplete: opts.RunComplete,
181 messageQueue: csync.NewMap[string, []SessionAgentCall](),
182 activeRequests: csync.NewMap[string, context.CancelFunc](),
183 }
184}
185
186// ValidateCall performs the cheap structural validation that
187// sessionAgent.Run requires before a call can be dispatched: a call must
188// carry either a non-empty prompt or a text attachment, and it must name a
189// session. It is exported so callers that accept a run before dispatching it
190// (e.g. backend.SendMessage) can apply the same checks and keep the error
191// contract consistent.
192func ValidateCall(call SessionAgentCall) error {
193 if call.Prompt == "" && !message.ContainsTextAttachment(call.Attachments) {
194 return ErrEmptyPrompt
195 }
196 if call.SessionID == "" {
197 return ErrSessionMissing
198 }
199 return nil
200}
201
202func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (result *fantasy.AgentResult, retErr error) {
203 if err := ValidateCall(call); err != nil {
204 return nil, err
205 }
206
207 // Queue the message if busy. Strip OnComplete: the caller that
208 // supplied the hook (typically coordinator.Run) has its own
209 // retry/coalesce scope that ends when it returns, so by the time
210 // the queue drains nobody is left to consume the buffered
211 // terminal event. The recursive Run will fall back to the
212 // default broker publish, which is what existing subscribers
213 // expect for queued turns.
214 if a.IsSessionBusy(call.SessionID) {
215 existing, ok := a.messageQueue.Get(call.SessionID)
216 if !ok {
217 existing = []SessionAgentCall{}
218 }
219 queued := call
220 queued.OnComplete = nil
221 existing = append(existing, queued)
222 a.messageQueue.Set(call.SessionID, existing)
223 return nil, nil
224 }
225
226 // Copy mutable fields under lock to avoid races with SetTools/SetModels.
227 agentTools := a.tools.Copy()
228 largeModel := a.largeModel.Get()
229 systemPrompt := a.systemPrompt.Get()
230 promptPrefix := a.systemPromptPrefix.Get()
231 var instructions strings.Builder
232
233 for _, server := range mcp.GetStates() {
234 if server.State != mcp.StateConnected {
235 continue
236 }
237 if s := server.Client.InitializeResult().Instructions; s != "" {
238 instructions.WriteString(s)
239 instructions.WriteString("\n\n")
240 }
241 }
242
243 if s := instructions.String(); s != "" {
244 systemPrompt += "\n\n<mcp-instructions>\n" + s + "\n</mcp-instructions>"
245 }
246
247 if len(agentTools) > 0 {
248 // Add Anthropic caching to the last tool.
249 agentTools[len(agentTools)-1].SetProviderOptions(a.getCacheControlOptions())
250 }
251
252 agent := fantasy.NewAgent(
253 largeModel.Model,
254 fantasy.WithSystemPrompt(systemPrompt),
255 fantasy.WithTools(agentTools...),
256 fantasy.WithUserAgent(userAgent),
257 )
258
259 sessionLock := sync.Mutex{}
260 currentSession, err := a.sessions.Get(ctx, call.SessionID)
261 if err != nil {
262 return nil, fmt.Errorf("failed to get session: %w", err)
263 }
264
265 msgs, err := a.getSessionMessages(ctx, currentSession)
266 if err != nil {
267 return nil, fmt.Errorf("failed to get session messages: %w", err)
268 }
269
270 var wg sync.WaitGroup
271 // Generate title if first message.
272 if len(msgs) == 0 {
273 titleCtx := ctx // Copy to avoid race with ctx reassignment below.
274 wg.Go(func() {
275 a.generateTitle(titleCtx, call.SessionID, call.Prompt)
276 })
277 }
278 defer wg.Wait()
279
280 // Add the user message to the session.
281 _, err = a.createUserMessage(ctx, call)
282 if err != nil {
283 return nil, err
284 }
285
286 // Add the session to the context.
287 ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)
288
289 genCtx, cancel := context.WithCancel(ctx)
290 a.activeRequests.Set(call.SessionID, cancel)
291
292 defer cancel()
293 defer a.activeRequests.Del(call.SessionID)
294 // skipRunComplete is set just before the queued-recursion path so
295 // the outer Run doesn't publish a RunComplete that would race
296 // with — and be superseded by — the recursive call's own
297 // RunComplete (each queued user prompt is its own turn and
298 // publishes exactly one terminal event).
299 var skipRunComplete bool
300 // currentAssistant is declared here so the deferred RunComplete
301 // publish below can capture the pointer that PrepareStep will
302 // later (re)assign for each streaming step. The final assistant
303 // message of the turn is the value reachable through this
304 // pointer when the defer runs.
305 var currentAssistant *message.Message
306 // Drain any debounced message updates before returning. message.Service
307 // already flushes synchronously on terminal updates, but a defer here
308 // guarantees the contract at every Run exit (success, error, panic
309 // recovery upstream) without callers needing to know.
310 //
311 // After the flush completes — meaning all per-message
312 // Publish(UpdatedEvent) calls have fired and been buffered into
313 // every subscriber's channel — publish the authoritative
314 // RunComplete event for this turn. The flush-then-publish order
315 // gives well-behaved clients the best chance of seeing the final
316 // message event before RunComplete; the embedded Text field
317 // reconciles for clients that observe the events out of order
318 // (the pubsub broker fan-in does not serialize publishes from
319 // different upstream brokers).
320 defer func() {
321 // Use a context detached from the run context: workspace
322 // shutdown cancels ctx before this goroutine returns, but the
323 // buffered streaming deltas must still land before the DB is
324 // closed. A short timeout bounds the flush.
325 flushCtx, flushCancel := context.WithTimeout(context.WithoutCancel(ctx), 5*time.Second)
326 defer flushCancel()
327 if flushErr := a.messages.FlushAll(flushCtx); flushErr != nil {
328 slog.Error("Failed to flush pending message updates after run", "error", flushErr)
329 }
330 if skipRunComplete {
331 return
332 }
333 complete := notify.RunComplete{SessionID: call.SessionID, RunID: call.RunID}
334 if currentAssistant != nil {
335 complete.MessageID = currentAssistant.ID
336 complete.Text = currentAssistant.Content().String()
337 }
338 if retErr != nil {
339 complete.Error = retErr.Error()
340 complete.Cancelled = errors.Is(retErr, context.Canceled)
341 } else if ctx.Err() != nil {
342 complete.Cancelled = true
343 }
344 // Prefer the per-call hook when supplied so the coordinator
345 // can coalesce retries (e.g. unauthorized → re-auth → retry)
346 // into a single user-visible terminal event. The fallback
347 // must-deliver publish applies bounded-blocking semantics to
348 // the authoritative terminal event so a momentarily-full
349 // subscriber channel can't silently drop it and hang
350 // non-interactive clients waiting on RunComplete.
351 if call.OnComplete != nil {
352 call.OnComplete(complete)
353 return
354 }
355 if a.runComplete == nil {
356 return
357 }
358 a.runComplete.PublishMustDeliver(ctx, pubsub.UpdatedEvent, complete)
359 }()
360
361 history, files := a.preparePrompt(msgs, largeModel.CatwalkCfg.SupportsImages, call.Attachments...)
362
363 startTime := time.Now()
364 a.eventPromptSent(call.SessionID)
365
366 var stepMessages []fantasy.Message
367 var shouldSummarize bool
368 // Don't send MaxOutputTokens if 0 — some providers (e.g. LM Studio) reject it
369 var maxOutputTokens *int64
370 if call.MaxOutputTokens > 0 {
371 maxOutputTokens = &call.MaxOutputTokens
372 }
373 result, err = agent.Stream(genCtx, fantasy.AgentStreamCall{
374 Prompt: message.PromptWithTextAttachments(call.Prompt, call.Attachments),
375 Files: files,
376 Messages: history,
377 ProviderOptions: call.ProviderOptions,
378 MaxOutputTokens: maxOutputTokens,
379 TopP: call.TopP,
380 Temperature: call.Temperature,
381 PresencePenalty: call.PresencePenalty,
382 TopK: call.TopK,
383 FrequencyPenalty: call.FrequencyPenalty,
384 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
385 prepared.Messages = options.Messages
386 for i := range prepared.Messages {
387 prepared.Messages[i].ProviderOptions = nil
388 }
389
390 // Use latest tools (updated by SetTools when MCP tools change).
391 prepared.Tools = a.tools.Copy()
392
393 queuedCalls, _ := a.messageQueue.Get(call.SessionID)
394 a.messageQueue.Del(call.SessionID)
395 for _, queued := range queuedCalls {
396 userMessage, createErr := a.createUserMessage(callContext, queued)
397 if createErr != nil {
398 return callContext, prepared, createErr
399 }
400 prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
401 }
402
403 prepared.Messages = a.workaroundProviderMediaLimitations(prepared.Messages, largeModel)
404
405 lastSystemRoleInx := 0
406 systemMessageUpdated := false
407 for i, msg := range prepared.Messages {
408 // Only add cache control to the last message.
409 if msg.Role == fantasy.MessageRoleSystem {
410 lastSystemRoleInx = i
411 } else if !systemMessageUpdated {
412 prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
413 systemMessageUpdated = true
414 }
415 // Than add cache control to the last 2 messages.
416 if i > len(prepared.Messages)-3 {
417 prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
418 }
419 }
420
421 if promptPrefix != "" {
422 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(promptPrefix)}, prepared.Messages...)
423 }
424
425 sessionLock.Lock()
426 stepMessages = cloneFantasyMessages(prepared.Messages)
427 sessionLock.Unlock()
428
429 var assistantMsg message.Message
430 assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
431 Role: message.Assistant,
432 Parts: []message.ContentPart{},
433 Model: largeModel.ModelCfg.Model,
434 Provider: largeModel.ModelCfg.Provider,
435 })
436 if err != nil {
437 return callContext, prepared, err
438 }
439 callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
440 callContext = context.WithValue(callContext, tools.SupportsImagesContextKey, largeModel.CatwalkCfg.SupportsImages)
441 callContext = context.WithValue(callContext, tools.ModelNameContextKey, largeModel.CatwalkCfg.Name)
442 currentAssistant = &assistantMsg
443 return callContext, prepared, err
444 },
445 OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
446 currentAssistant.AppendReasoningContent(reasoning.Text)
447 return a.messages.Update(genCtx, *currentAssistant)
448 },
449 OnReasoningDelta: func(id string, text string) error {
450 currentAssistant.AppendReasoningContent(text)
451 return a.messages.Update(genCtx, *currentAssistant)
452 },
453 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
454 // handle anthropic signature
455 if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
456 if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
457 currentAssistant.AppendReasoningSignature(reasoning.Signature)
458 }
459 }
460 if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
461 if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
462 currentAssistant.AppendThoughtSignature(reasoning.Signature, reasoning.ToolID)
463 }
464 }
465 if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
466 if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
467 currentAssistant.SetReasoningResponsesData(reasoning)
468 }
469 }
470 currentAssistant.FinishThinking()
471 return a.messages.Update(genCtx, *currentAssistant)
472 },
473 OnTextDelta: func(id string, text string) error {
474 // Strip leading newline from initial text content. This is is
475 // particularly important in non-interactive mode where leading
476 // newlines are very visible.
477 if len(currentAssistant.Parts) == 0 {
478 text = strings.TrimPrefix(text, "\n")
479 }
480
481 currentAssistant.AppendContent(text)
482 return a.messages.Update(genCtx, *currentAssistant)
483 },
484 OnToolInputStart: func(id string, toolName string) error {
485 toolCall := message.ToolCall{
486 ID: id,
487 Name: toolName,
488 ProviderExecuted: false,
489 Finished: false,
490 }
491 currentAssistant.AddToolCall(toolCall)
492 // Use parent ctx instead of genCtx to ensure the update succeeds
493 // even if the request is canceled mid-stream
494 return a.messages.Update(ctx, *currentAssistant)
495 },
496 OnRetry: func(err *fantasy.ProviderError, delay time.Duration) {
497 slog.Warn("Provider request failed, retrying", providerRetryLogFields(err, delay)...)
498 },
499 OnToolCall: func(tc fantasy.ToolCallContent) error {
500 toolCall := message.ToolCall{
501 ID: tc.ToolCallID,
502 Name: tc.ToolName,
503 Input: tc.Input,
504 ProviderExecuted: false,
505 Finished: true,
506 }
507 currentAssistant.AddToolCall(toolCall)
508 // Use parent ctx instead of genCtx to ensure the update succeeds
509 // even if the request is canceled mid-stream
510 return a.messages.Update(ctx, *currentAssistant)
511 },
512 OnToolResult: func(result fantasy.ToolResultContent) error {
513 toolResult := a.convertToToolResult(result)
514 // Use parent ctx instead of genCtx to ensure the message is created
515 // even if the request is canceled mid-stream
516 _, createMsgErr := a.messages.Create(ctx, currentAssistant.SessionID, message.CreateMessageParams{
517 Role: message.Tool,
518 Parts: []message.ContentPart{
519 toolResult,
520 },
521 })
522 return createMsgErr
523 },
524 OnStepFinish: func(stepResult fantasy.StepResult) error {
525 finishReason := message.FinishReasonUnknown
526 switch stepResult.FinishReason {
527 case fantasy.FinishReasonLength:
528 finishReason = message.FinishReasonMaxTokens
529 case fantasy.FinishReasonStop:
530 finishReason = message.FinishReasonEndTurn
531 case fantasy.FinishReasonToolCalls:
532 finishReason = message.FinishReasonToolUse
533 }
534 // If a tool result halted the turn (e.g. a hook halt or a
535 // permission denial), the step ends on FinishReasonToolCalls but
536 // the model will not be called again. Treat it as the end of the
537 // turn so the UI can render the assistant footer.
538 if finishReason == message.FinishReasonToolUse {
539 for _, tr := range stepResult.Content.ToolResults() {
540 if tr.StopTurn {
541 finishReason = message.FinishReasonEndTurn
542 break
543 }
544 }
545 }
546 currentAssistant.AddFinish(finishReason, "", "")
547 sessionLock.Lock()
548 defer sessionLock.Unlock()
549
550 updatedSession, getSessionErr := a.sessions.Get(ctx, call.SessionID)
551 if getSessionErr != nil {
552 return getSessionErr
553 }
554 usage, estimated := fallbackStepUsage(stepMessages, stepResult)
555 a.updateSessionUsage(largeModel, &updatedSession, usage, a.openrouterCost(stepResult.ProviderMetadata), estimated)
556 _, sessionErr := a.sessions.Save(ctx, updatedSession)
557 if sessionErr != nil {
558 return sessionErr
559 }
560 currentSession = updatedSession
561 return a.messages.Update(genCtx, *currentAssistant)
562 },
563 StopWhen: []fantasy.StopCondition{
564 func(_ []fantasy.StepResult) bool {
565 cw := int64(largeModel.CatwalkCfg.ContextWindow)
566 // If context window is unknown (0), skip auto-summarize
567 // to avoid immediately truncating custom/local models.
568 if cw == 0 {
569 return false
570 }
571 tokens := currentSession.CompletionTokens + currentSession.PromptTokens
572 remaining := cw - tokens
573 var threshold int64
574 if cw > largeContextWindowThreshold {
575 threshold = largeContextWindowBuffer
576 } else {
577 threshold = int64(float64(cw) * smallContextWindowRatio)
578 }
579 if (remaining <= threshold) && !a.disableAutoSummarize {
580 shouldSummarize = true
581 return true
582 }
583 return false
584 },
585 func(steps []fantasy.StepResult) bool {
586 return hasRepeatedToolCalls(steps, loopDetectionWindowSize, loopDetectionMaxRepeats)
587 },
588 },
589 })
590
591 a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))
592
593 if err != nil {
594 isHyper := largeModel.ModelCfg.Provider == hyper.Name
595 isCancelErr := errors.Is(err, context.Canceled)
596 if currentAssistant == nil {
597 return result, err
598 }
599 // Persist final state with a context detached from the run
600 // context. The run context (ctx) is derived from the
601 // workspace context, which workspace shutdown cancels before
602 // agent goroutines finish; using ctx here would drop the
603 // final assistant state. WithoutCancel keeps the values
604 // (e.g. session ID) while ignoring cancellation, and a short
605 // timeout bounds the cleanup writes.
606 cleanupCtx, cleanupCancel := context.WithTimeout(context.WithoutCancel(ctx), 5*time.Second)
607 defer cleanupCancel()
608 // Ensure we finish thinking on error to close the reasoning state.
609 currentAssistant.FinishThinking()
610 toolCalls := currentAssistant.ToolCalls()
611 // INFO: we use the cleanup context here because the genCtx has been cancelled.
612 msgs, createErr := a.messages.List(cleanupCtx, currentAssistant.SessionID)
613 if createErr != nil {
614 return nil, createErr
615 }
616 for _, tc := range toolCalls {
617 if !tc.Finished {
618 tc.Finished = true
619 tc.Input = "{}"
620 currentAssistant.AddToolCall(tc)
621 updateErr := a.messages.Update(cleanupCtx, *currentAssistant)
622 if updateErr != nil {
623 return nil, updateErr
624 }
625 }
626
627 found := false
628 for _, msg := range msgs {
629 if msg.Role == message.Tool {
630 for _, tr := range msg.ToolResults() {
631 if tr.ToolCallID == tc.ID {
632 found = true
633 break
634 }
635 }
636 }
637 if found {
638 break
639 }
640 }
641 if found {
642 continue
643 }
644 content := "There was an error while executing the tool"
645 if isCancelErr {
646 content = "Error: user cancelled assistant tool calling"
647 }
648 toolResult := message.ToolResult{
649 ToolCallID: tc.ID,
650 Name: tc.Name,
651 Content: content,
652 IsError: true,
653 }
654 _, createErr = a.messages.Create(cleanupCtx, currentAssistant.SessionID, message.CreateMessageParams{
655 Role: message.Tool,
656 Parts: []message.ContentPart{
657 toolResult,
658 },
659 })
660 if createErr != nil {
661 return nil, createErr
662 }
663 }
664 var fantasyErr *fantasy.Error
665 var providerErr *fantasy.ProviderError
666 const defaultTitle = "Provider Error"
667 linkStyle := lipgloss.NewStyle().Foreground(charmtone.Guac).Underline(true)
668 if isCancelErr {
669 currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
670 } else if isHyper && errors.As(err, &providerErr) && providerErr.StatusCode == http.StatusUnauthorized {
671 currentAssistant.AddFinish(message.FinishReasonError, "Unauthorized", `Please re-authenticate with Hyper. You can also run "crush auth" to re-authenticate.`)
672 if a.notify != nil {
673 a.notify.Publish(pubsub.CreatedEvent, notify.Notification{
674 SessionID: call.SessionID,
675 SessionTitle: currentSession.Title,
676 Type: notify.TypeReAuthenticate,
677 ProviderID: largeModel.ModelCfg.Provider,
678 })
679 }
680 } else if isHyper && errors.As(err, &providerErr) && providerErr.StatusCode == http.StatusPaymentRequired {
681 url := hyper.BaseURL()
682 link := linkStyle.Hyperlink(url, "id=hyper").Render(url)
683 currentAssistant.AddFinish(message.FinishReasonError, "No credits", "You're out of credits. Add more at "+link)
684 } else if errors.As(err, &providerErr) {
685 if providerErr.Message == "The requested model is not supported." {
686 url := "https://github.com/settings/copilot/features"
687 link := linkStyle.Hyperlink(url, "id=copilot").Render(url)
688 currentAssistant.AddFinish(
689 message.FinishReasonError,
690 "Copilot model not enabled",
691 fmt.Sprintf("%q is not enabled in Copilot. Go to the following page to enable it. Then, wait 5 minutes before trying again. %s", largeModel.CatwalkCfg.Name, link),
692 )
693 } else {
694 currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message)
695 }
696 } else if errors.As(err, &fantasyErr) {
697 currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(fantasyErr.Title), defaultTitle), fantasyErr.Message)
698 } else {
699 currentAssistant.AddFinish(message.FinishReasonError, defaultTitle, err.Error())
700 }
701 // Note: we use the cleanup context here because the genCtx has been
702 // cancelled.
703 updateErr := a.messages.Update(cleanupCtx, *currentAssistant)
704 if updateErr != nil {
705 return nil, updateErr
706 }
707 return nil, err
708 }
709
710 if shouldSummarize {
711 a.activeRequests.Del(call.SessionID)
712 if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
713 return nil, summarizeErr
714 }
715 // If the agent wasn't done...
716 if len(currentAssistant.ToolCalls()) > 0 {
717 existing, ok := a.messageQueue.Get(call.SessionID)
718 if !ok {
719 existing = []SessionAgentCall{}
720 }
721 call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
722 existing = append(existing, call)
723 a.messageQueue.Set(call.SessionID, existing)
724 }
725 }
726
727 // Release active request before publishing the notification.
728 // TUI handlers poll IsSessionBusy() and only re-evaluate when a
729 // tea.Msg arrives, so the cleanup must precede the notify or
730 // subscribers see stale busy state at the moment of receipt.
731 a.activeRequests.Del(call.SessionID)
732 cancel()
733
734 // Send notification that agent has finished its turn (skip for
735 // nested/non-interactive sessions).
736 if !call.NonInteractive && a.notify != nil {
737 a.notify.Publish(pubsub.CreatedEvent, notify.Notification{
738 SessionID: call.SessionID,
739 SessionTitle: currentSession.Title,
740 Type: notify.TypeAgentFinished,
741 })
742 }
743
744 queuedMessages, ok := a.messageQueue.Get(call.SessionID)
745 if !ok || len(queuedMessages) == 0 {
746 return result, err
747 }
748 // There are queued messages restart the loop. The recursive Run
749 // publishes its own RunComplete for the queued prompt, so suppress
750 // the outer defer's emit to avoid a duplicate event whose Error
751 // field would belong to the recursive turn but whose MessageID/Text
752 // would belong to the outer turn.
753 skipRunComplete = true
754 firstQueuedMessage := queuedMessages[0]
755 a.messageQueue.Set(call.SessionID, queuedMessages[1:])
756 return a.Run(ctx, firstQueuedMessage)
757}
758
759func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
760 if a.IsSessionBusy(sessionID) {
761 return ErrSessionBusy
762 }
763
764 // Copy mutable fields under lock to avoid races with SetModels.
765 largeModel := a.largeModel.Get()
766 systemPromptPrefix := a.systemPromptPrefix.Get()
767
768 currentSession, err := a.sessions.Get(ctx, sessionID)
769 if err != nil {
770 return fmt.Errorf("failed to get session: %w", err)
771 }
772 msgs, err := a.getSessionMessages(ctx, currentSession)
773 if err != nil {
774 return err
775 }
776 if len(msgs) == 0 {
777 // Nothing to summarize.
778 return nil
779 }
780
781 aiMsgs, _ := a.preparePrompt(msgs, largeModel.CatwalkCfg.SupportsImages)
782
783 genCtx, cancel := context.WithCancel(ctx)
784 a.activeRequests.Set(sessionID, cancel)
785 defer a.activeRequests.Del(sessionID)
786 defer cancel()
787 defer func() {
788 if flushErr := a.messages.FlushAll(ctx); flushErr != nil {
789 slog.Error("Failed to flush pending message updates after summarize", "error", flushErr)
790 }
791 }()
792
793 agent := fantasy.NewAgent(
794 largeModel.Model,
795 fantasy.WithSystemPrompt(string(summaryPrompt)),
796 fantasy.WithUserAgent(userAgent),
797 )
798 summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
799 Role: message.Assistant,
800 Model: largeModel.ModelCfg.Model,
801 Provider: largeModel.ModelCfg.Provider,
802 IsSummaryMessage: true,
803 })
804 if err != nil {
805 return err
806 }
807
808 summaryPromptText := buildSummaryPrompt(currentSession.Todos)
809
810 resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
811 Prompt: summaryPromptText,
812 Messages: aiMsgs,
813 ProviderOptions: opts,
814 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
815 prepared.Messages = options.Messages
816 if systemPromptPrefix != "" {
817 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(systemPromptPrefix)}, prepared.Messages...)
818 }
819 return callContext, prepared, nil
820 },
821 OnReasoningDelta: func(id string, text string) error {
822 summaryMessage.AppendReasoningContent(text)
823 return a.messages.Update(genCtx, summaryMessage)
824 },
825 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
826 // Handle anthropic signature.
827 if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
828 if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
829 summaryMessage.AppendReasoningSignature(signature.Signature)
830 }
831 }
832 summaryMessage.FinishThinking()
833 return a.messages.Update(genCtx, summaryMessage)
834 },
835 OnTextDelta: func(id, text string) error {
836 summaryMessage.AppendContent(text)
837 return a.messages.Update(genCtx, summaryMessage)
838 },
839 })
840 if err != nil {
841 isCancelErr := errors.Is(err, context.Canceled)
842 if isCancelErr {
843 // User cancelled summarize we need to remove the summary message.
844 deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
845 return deleteErr
846 }
847 // Mark the summary message as finished with an error so the UI
848 // stops spinning.
849 summaryMessage.AddFinish(message.FinishReasonError, "Summarization Error", err.Error())
850 if updateErr := a.messages.Update(ctx, summaryMessage); updateErr != nil {
851 return updateErr
852 }
853 return err
854 }
855
856 summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
857 err = a.messages.Update(genCtx, summaryMessage)
858 if err != nil {
859 return err
860 }
861
862 var openrouterCost *float64
863 for _, step := range resp.Steps {
864 stepCost := a.openrouterCost(step.ProviderMetadata)
865 if stepCost != nil {
866 newCost := *stepCost
867 if openrouterCost != nil {
868 newCost += *openrouterCost
869 }
870 openrouterCost = &newCost
871 }
872 }
873
874 a.updateSessionUsage(largeModel, ¤tSession, resp.TotalUsage, openrouterCost, false)
875
876 // Just in case, get just the last usage info.
877 usage := resp.Response.Usage
878 currentSession.SummaryMessageID = summaryMessage.ID
879 currentSession.CompletionTokens = summaryCompletionTokens(usage, summaryMessage)
880 currentSession.PromptTokens = 0
881 currentSession.EstimatedUsage = usageIsZero(usage)
882 _, err = a.sessions.Save(genCtx, currentSession)
883 if err != nil {
884 return err
885 }
886
887 // Release the active request before processing queued messages so that
888 // Run() does not see the session as busy.
889 a.activeRequests.Del(sessionID)
890 cancel()
891
892 // Process any messages that were queued while summarizing.
893 queuedMessages, ok := a.messageQueue.Get(sessionID)
894 if !ok || len(queuedMessages) == 0 {
895 return nil
896 }
897 firstQueuedMessage := queuedMessages[0]
898 a.messageQueue.Set(sessionID, queuedMessages[1:])
899 _, qErr := a.Run(ctx, firstQueuedMessage)
900 return qErr
901}
902
903func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
904 if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
905 return fantasy.ProviderOptions{}
906 }
907 return fantasy.ProviderOptions{
908 anthropic.Name: &anthropic.ProviderCacheControlOptions{
909 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
910 },
911 bedrock.Name: &anthropic.ProviderCacheControlOptions{
912 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
913 },
914 vercel.Name: &anthropic.ProviderCacheControlOptions{
915 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
916 },
917 }
918}
919
920func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
921 parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
922 var attachmentParts []message.ContentPart
923 for _, attachment := range call.Attachments {
924 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
925 }
926 parts = append(parts, attachmentParts...)
927 msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
928 Role: message.User,
929 Parts: parts,
930 })
931 if err != nil {
932 return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
933 }
934 return msg, nil
935}
936
937func (a *sessionAgent) preparePrompt(msgs []message.Message, supportsImages bool, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
938 var history []fantasy.Message
939 if !a.isSubAgent {
940 history = append(history, fantasy.NewUserMessage(
941 fmt.Sprintf(
942 "<system_reminder>%s</system_reminder>",
943 `This is a reminder that your todo list is currently empty. DO NOT mention this to the user explicitly because they are already aware.
944If you are working on tasks that would benefit from a todo list please use the "todos" tool to create one.
945If not, please feel free to ignore. Again do not mention this message to the user.`,
946 ),
947 ))
948 }
949 // Collect all tool call IDs present in assistant messages and all tool
950 // result IDs present in tool messages. This lets us detect both orphaned
951 // tool results (result without a call) and orphaned tool calls (call
952 // without a result).
953 knownToolCallIDs := make(map[string]struct{})
954 knownToolResultIDs := make(map[string]struct{})
955 for _, m := range msgs {
956 switch m.Role {
957 case message.Assistant:
958 for _, tc := range m.ToolCalls() {
959 knownToolCallIDs[tc.ID] = struct{}{}
960 }
961 case message.Tool:
962 for _, tr := range m.ToolResults() {
963 knownToolResultIDs[tr.ToolCallID] = struct{}{}
964 }
965 }
966 }
967
968 for _, m := range msgs {
969 if len(m.Parts) == 0 {
970 continue
971 }
972 // Assistant message without content or tool calls (cancelled before it returned anything).
973 if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
974 continue
975 }
976 if m.Role == message.Tool {
977 if msg, ok := filterOrphanedToolResults(m, knownToolCallIDs); ok {
978 history = append(history, msg)
979 }
980 continue
981 }
982 aiMsgs := m.ToAIMessage()
983 if !supportsImages {
984 for i := range aiMsgs {
985 if aiMsgs[i].Role == fantasy.MessageRoleUser {
986 aiMsgs[i].Content = filterFileParts(aiMsgs[i].Content)
987 }
988 }
989 }
990 history = append(history, aiMsgs...)
991
992 if m.Role == message.Assistant {
993 if msg, ok := syntheticToolResultsForOrphanedCalls(m, knownToolResultIDs); ok {
994 history = append(history, msg)
995 }
996 }
997 }
998
999 var files []fantasy.FilePart
1000 for _, attachment := range attachments {
1001 if attachment.IsText() {
1002 continue
1003 }
1004 files = append(files, fantasy.FilePart{
1005 Filename: attachment.FileName,
1006 Data: attachment.Content,
1007 MediaType: attachment.MimeType,
1008 })
1009 }
1010
1011 return history, files
1012}
1013
1014// filterFileParts removes fantasy.FilePart entries from a slice of message
1015// parts. Used to strip image attachments from historical user messages when
1016// the current model does not support them.
1017func filterFileParts(parts []fantasy.MessagePart) []fantasy.MessagePart {
1018 filtered := make([]fantasy.MessagePart, 0, len(parts))
1019 for _, part := range parts {
1020 if _, ok := fantasy.AsMessagePart[fantasy.FilePart](part); ok {
1021 continue
1022 }
1023 filtered = append(filtered, part)
1024 }
1025 return filtered
1026}
1027
1028// filterOrphanedToolResults converts a tool message to a fantasy.Message,
1029// dropping any tool result parts whose tool_call_id has no matching tool call
1030// in the known set. An orphaned result causes API validation to fail on every
1031// subsequent turn, permanently locking the session. Returns the filtered
1032// message and true if at least one valid part remains.
1033func filterOrphanedToolResults(m message.Message, knownToolCallIDs map[string]struct{}) (fantasy.Message, bool) {
1034 aiMsgs := m.ToAIMessage()
1035 if len(aiMsgs) == 0 {
1036 return fantasy.Message{}, false
1037 }
1038 var validParts []fantasy.MessagePart
1039 for _, part := range aiMsgs[0].Content {
1040 tr, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
1041 if !ok {
1042 validParts = append(validParts, part)
1043 continue
1044 }
1045 if _, known := knownToolCallIDs[tr.ToolCallID]; known {
1046 validParts = append(validParts, part)
1047 } else {
1048 slog.Warn(
1049 "Dropping orphaned tool result with no matching tool call",
1050 "tool_call_id", tr.ToolCallID,
1051 )
1052 }
1053 }
1054 if len(validParts) == 0 {
1055 return fantasy.Message{}, false
1056 }
1057 msg := aiMsgs[0]
1058 msg.Content = validParts
1059 return msg, true
1060}
1061
1062// syntheticToolResultsForOrphanedCalls returns a tool message containing
1063// synthetic tool results for any tool calls in the assistant message that
1064// have no matching result in knownToolResultIDs. LLM APIs require every
1065// tool_use to be immediately followed by a tool_result; an interrupted
1066// session can leave orphaned tool_use blocks that permanently lock the
1067// conversation. Returns the message and true if any synthetic results were
1068// produced.
1069func syntheticToolResultsForOrphanedCalls(m message.Message, knownToolResultIDs map[string]struct{}) (fantasy.Message, bool) {
1070 var syntheticParts []fantasy.MessagePart
1071 for _, tc := range m.ToolCalls() {
1072 if _, hasResult := knownToolResultIDs[tc.ID]; hasResult {
1073 continue
1074 }
1075 slog.Warn(
1076 "Injecting synthetic tool result for orphaned tool call",
1077 "tool_call_id", tc.ID,
1078 "tool_name", tc.Name,
1079 )
1080 syntheticParts = append(syntheticParts, fantasy.ToolResultPart{
1081 ToolCallID: tc.ID,
1082 Output: fantasy.ToolResultOutputContentError{
1083 Error: errors.New("tool call was interrupted and did not produce a result, you may retry this call if the result is still needed"),
1084 },
1085 })
1086 }
1087 if len(syntheticParts) == 0 {
1088 return fantasy.Message{}, false
1089 }
1090 return fantasy.Message{
1091 Role: fantasy.MessageRoleTool,
1092 Content: syntheticParts,
1093 }, true
1094}
1095
1096func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
1097 msgs, err := a.messages.List(ctx, session.ID)
1098 if err != nil {
1099 return nil, fmt.Errorf("failed to list messages: %w", err)
1100 }
1101
1102 if session.SummaryMessageID != "" {
1103 summaryMsgIndex := -1
1104 for i, msg := range msgs {
1105 if msg.ID == session.SummaryMessageID {
1106 summaryMsgIndex = i
1107 break
1108 }
1109 }
1110 if summaryMsgIndex != -1 {
1111 msgs = msgs[summaryMsgIndex:]
1112 msgs[0].Role = message.User
1113 }
1114 }
1115 return msgs, nil
1116}
1117
1118// generateTitle generates a session titled based on the initial prompt.
1119func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, userPrompt string) {
1120 if userPrompt == "" {
1121 return
1122 }
1123
1124 smallModel := a.smallModel.Get()
1125 largeModel := a.largeModel.Get()
1126 systemPromptPrefix := a.systemPromptPrefix.Get()
1127
1128 var maxOutputTokens int64 = 40
1129 if smallModel.CatwalkCfg.CanReason {
1130 maxOutputTokens = smallModel.CatwalkCfg.DefaultMaxTokens
1131 }
1132
1133 newAgent := func(m fantasy.LanguageModel, p []byte, tok int64) fantasy.Agent {
1134 return fantasy.NewAgent(
1135 m,
1136 fantasy.WithSystemPrompt(string(p)+"\n /no_think"),
1137 fantasy.WithMaxOutputTokens(tok),
1138 fantasy.WithUserAgent(userAgent),
1139 )
1140 }
1141
1142 streamCall := fantasy.AgentStreamCall{
1143 Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", userPrompt),
1144 PrepareStep: func(callCtx context.Context, opts fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
1145 prepared.Messages = opts.Messages
1146 if systemPromptPrefix != "" {
1147 prepared.Messages = append([]fantasy.Message{
1148 fantasy.NewSystemMessage(systemPromptPrefix),
1149 }, prepared.Messages...)
1150 }
1151 return callCtx, prepared, nil
1152 },
1153 }
1154
1155 // Use the small model to generate the title.
1156 model := smallModel
1157 agent := newAgent(model.Model, titlePrompt, maxOutputTokens)
1158 resp, err := agent.Stream(ctx, streamCall)
1159 if err == nil {
1160 // We successfully generated a title with the small model.
1161 slog.Debug("Generated title with small model")
1162 } else {
1163 // It didn't work. Let's try with the big model.
1164 slog.Error("Error generating title with small model; trying big model", "err", err)
1165 model = largeModel
1166 agent = newAgent(model.Model, titlePrompt, maxOutputTokens)
1167 resp, err = agent.Stream(ctx, streamCall)
1168 if err == nil {
1169 slog.Debug("Generated title with large model")
1170 } else {
1171 // Welp, the large model didn't work either. Use the default
1172 // session name and return.
1173 slog.Error("Error generating title with large model", "err", err)
1174 saveErr := a.sessions.Rename(ctx, sessionID, DefaultSessionName)
1175 if saveErr != nil {
1176 slog.Error("Failed to save session title", "error", saveErr)
1177 }
1178 return
1179 }
1180 }
1181
1182 if resp == nil {
1183 // Actually, we didn't get a response so we can't. Use the default
1184 // session name and return.
1185 slog.Error("Response is nil; can't generate title")
1186 saveErr := a.sessions.Rename(ctx, sessionID, DefaultSessionName)
1187 if saveErr != nil {
1188 slog.Error("Failed to save session title", "error", saveErr)
1189 }
1190 return
1191 }
1192
1193 // Clean up title.
1194 var title string
1195 title = strings.ReplaceAll(resp.Response.Content.Text(), "\n", " ")
1196
1197 // Remove thinking tags if present.
1198 title = thinkTagRegex.ReplaceAllString(title, "")
1199 title = orphanThinkTagRegex.ReplaceAllString(title, "")
1200
1201 title = strings.TrimSpace(title)
1202 title = cmp.Or(title, DefaultSessionName)
1203
1204 // Calculate usage and cost.
1205 var openrouterCost *float64
1206 for _, step := range resp.Steps {
1207 stepCost := a.openrouterCost(step.ProviderMetadata)
1208 if stepCost != nil {
1209 newCost := *stepCost
1210 if openrouterCost != nil {
1211 newCost += *openrouterCost
1212 }
1213 openrouterCost = &newCost
1214 }
1215 }
1216
1217 modelConfig := model.CatwalkCfg
1218 cost := modelConfig.CostPer1MInCached/1e6*float64(resp.TotalUsage.CacheCreationTokens) +
1219 modelConfig.CostPer1MOutCached/1e6*float64(resp.TotalUsage.CacheReadTokens) +
1220 modelConfig.CostPer1MIn/1e6*float64(resp.TotalUsage.InputTokens) +
1221 modelConfig.CostPer1MOut/1e6*float64(resp.TotalUsage.OutputTokens)
1222
1223 // Use override cost if available (e.g., from OpenRouter).
1224 if openrouterCost != nil {
1225 cost = *openrouterCost
1226 }
1227
1228 // Skip cost accumulation
1229 if model.FlatRate {
1230 cost = 0
1231 }
1232
1233 promptTokens := resp.TotalUsage.InputTokens + resp.TotalUsage.CacheCreationTokens
1234 completionTokens := resp.TotalUsage.OutputTokens
1235
1236 // Atomically update only title and usage fields to avoid overriding other
1237 // concurrent session updates.
1238 saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, title, promptTokens, completionTokens, cost)
1239 if saveErr != nil {
1240 slog.Error("Failed to save session title and usage", "error", saveErr)
1241 return
1242 }
1243}
1244
1245func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
1246 openrouterMetadata, ok := metadata[openrouter.Name]
1247 if !ok {
1248 return nil
1249 }
1250
1251 opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
1252 if !ok {
1253 return nil
1254 }
1255 return &opts.Usage.Cost
1256}
1257
1258func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64, estimated bool) {
1259 if !usageIsZero(usage) {
1260 session.EstimatedUsage = estimated
1261 }
1262
1263 modelConfig := model.CatwalkCfg
1264 cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
1265 modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
1266 modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
1267 modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
1268
1269 if !estimated {
1270 a.eventTokensUsed(session.ID, model, usage, cost)
1271 }
1272
1273 if estimated {
1274 cost = 0
1275 } else {
1276 // Use override cost if available (e.g., from OpenRouter).
1277 if overrideCost != nil {
1278 cost = *overrideCost
1279 }
1280
1281 // Skip cost accumulation
1282 if model.FlatRate {
1283 cost = 0
1284 }
1285 }
1286
1287 session.Cost += cost
1288 updateSessionTokenCounters(session, usage)
1289}
1290
1291func updateSessionTokenCounters(session *session.Session, usage fantasy.Usage) {
1292 if usage.OutputTokens != 0 {
1293 session.CompletionTokens = usage.OutputTokens
1294 }
1295 if promptTokens := usage.InputTokens + usage.CacheReadTokens; promptTokens != 0 {
1296 session.PromptTokens = promptTokens
1297 }
1298}
1299
1300func summaryCompletionTokens(usage fantasy.Usage, summaryMessage message.Message) int64 {
1301 if usage.OutputTokens != 0 {
1302 return usage.OutputTokens
1303 }
1304 return approxTokenCount(summaryMessage.Content().Text) + approxTokenCount(summaryMessage.ReasoningContent().String())
1305}
1306
1307func (a *sessionAgent) Cancel(sessionID string) {
1308 // Cancel regular requests. Don't use Take() here - we need the entry to
1309 // remain in activeRequests so IsBusy() returns true until the goroutine
1310 // fully completes (including error handling that may access the DB).
1311 // The defer in processRequest will clean up the entry.
1312 if cancel, ok := a.activeRequests.Get(sessionID); ok && cancel != nil {
1313 slog.Debug("Request cancellation initiated", "session_id", sessionID)
1314 cancel()
1315 }
1316
1317 // Also check for summarize requests.
1318 if cancel, ok := a.activeRequests.Get(sessionID + "-summarize"); ok && cancel != nil {
1319 slog.Debug("Summarize cancellation initiated", "session_id", sessionID)
1320 cancel()
1321 }
1322
1323 if a.QueuedPrompts(sessionID) > 0 {
1324 slog.Debug("Clearing queued prompts", "session_id", sessionID)
1325 a.messageQueue.Del(sessionID)
1326 }
1327}
1328
1329func (a *sessionAgent) ClearQueue(sessionID string) {
1330 if a.QueuedPrompts(sessionID) > 0 {
1331 slog.Debug("Clearing queued prompts", "session_id", sessionID)
1332 a.messageQueue.Del(sessionID)
1333 }
1334}
1335
1336func (a *sessionAgent) CancelAll() {
1337 if !a.IsBusy() {
1338 return
1339 }
1340 for key := range a.activeRequests.Seq2() {
1341 a.Cancel(key) // key is sessionID
1342 }
1343
1344 timeout := time.After(5 * time.Second)
1345 for a.IsBusy() {
1346 select {
1347 case <-timeout:
1348 return
1349 default:
1350 time.Sleep(200 * time.Millisecond)
1351 }
1352 }
1353}
1354
1355func (a *sessionAgent) IsBusy() bool {
1356 var busy bool
1357 for cancelFunc := range a.activeRequests.Seq() {
1358 if cancelFunc != nil {
1359 busy = true
1360 break
1361 }
1362 }
1363 return busy
1364}
1365
1366func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
1367 _, busy := a.activeRequests.Get(sessionID)
1368 return busy
1369}
1370
1371func (a *sessionAgent) QueuedPrompts(sessionID string) int {
1372 l, ok := a.messageQueue.Get(sessionID)
1373 if !ok {
1374 return 0
1375 }
1376 return len(l)
1377}
1378
1379func (a *sessionAgent) QueuedPromptsList(sessionID string) []string {
1380 l, ok := a.messageQueue.Get(sessionID)
1381 if !ok {
1382 return nil
1383 }
1384 prompts := make([]string, len(l))
1385 for i, call := range l {
1386 prompts[i] = call.Prompt
1387 }
1388 return prompts
1389}
1390
1391func (a *sessionAgent) SetModels(large Model, small Model) {
1392 a.largeModel.Set(large)
1393 a.smallModel.Set(small)
1394}
1395
1396func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
1397 a.tools.SetSlice(tools)
1398}
1399
1400func (a *sessionAgent) SetSystemPrompt(systemPrompt string) {
1401 a.systemPrompt.Set(systemPrompt)
1402}
1403
1404func (a *sessionAgent) Model() Model {
1405 return a.largeModel.Get()
1406}
1407
1408// convertToToolResult converts a fantasy tool result to a message tool result.
1409func (a *sessionAgent) convertToToolResult(result fantasy.ToolResultContent) message.ToolResult {
1410 baseResult := message.ToolResult{
1411 ToolCallID: result.ToolCallID,
1412 Name: result.ToolName,
1413 Metadata: result.ClientMetadata,
1414 }
1415
1416 switch result.Result.GetType() {
1417 case fantasy.ToolResultContentTypeText:
1418 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result); ok {
1419 baseResult.Content = r.Text
1420 }
1421 case fantasy.ToolResultContentTypeError:
1422 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result); ok {
1423 baseResult.Content = r.Error.Error()
1424 baseResult.IsError = true
1425 }
1426 case fantasy.ToolResultContentTypeMedia:
1427 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](result.Result); ok {
1428 if !stringext.IsValidBase64(r.Data) {
1429 slog.Warn(
1430 "Tool returned media with invalid base64 data, discarding image",
1431 "tool", result.ToolName,
1432 "tool_call_id", result.ToolCallID,
1433 )
1434 baseResult.Content = "Tool returned image data with invalid encoding"
1435 baseResult.IsError = true
1436 } else {
1437 content := r.Text
1438 if content == "" {
1439 content = fmt.Sprintf("Loaded %s content", r.MediaType)
1440 }
1441 baseResult.Content = content
1442 baseResult.Data = r.Data
1443 baseResult.MIMEType = r.MediaType
1444 }
1445 }
1446 }
1447
1448 return baseResult
1449}
1450
1451// workaroundProviderMediaLimitations converts media content in tool results to
1452// user messages for providers that don't natively support images in tool results.
1453//
1454// Problem: OpenAI, Google, OpenRouter, and other OpenAI-compatible providers
1455// don't support sending images/media in tool result messages - they only accept
1456// text in tool results. However, they DO support images in user messages.
1457//
1458// If we send media in tool results to these providers, the API returns an error.
1459//
1460// Solution: For these providers, we:
1461// 1. Replace the media in the tool result with a text placeholder
1462// 2. Inject a user message immediately after with the image as a file attachment
1463// 3. This maintains the tool execution flow while working around API limitations
1464//
1465// Anthropic and Bedrock support images natively in tool results, so we skip
1466// this workaround for them.
1467//
1468// Example transformation:
1469//
1470// BEFORE: [tool result: image data]
1471// AFTER: [tool result: "Image loaded - see attached"], [user: image attachment]
1472func (a *sessionAgent) workaroundProviderMediaLimitations(messages []fantasy.Message, largeModel Model) []fantasy.Message {
1473 providerSupportsMedia := largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderAnthropic) ||
1474 largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderBedrock) ||
1475 largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderBedrockEurope)
1476
1477 if providerSupportsMedia {
1478 return messages
1479 }
1480
1481 convertedMessages := make([]fantasy.Message, 0, len(messages))
1482
1483 for _, msg := range messages {
1484 if msg.Role != fantasy.MessageRoleTool {
1485 convertedMessages = append(convertedMessages, msg)
1486 continue
1487 }
1488
1489 textParts := make([]fantasy.MessagePart, 0, len(msg.Content))
1490 var mediaFiles []fantasy.FilePart
1491
1492 for _, part := range msg.Content {
1493 toolResult, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
1494 if !ok {
1495 textParts = append(textParts, part)
1496 continue
1497 }
1498
1499 if media, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](toolResult.Output); ok {
1500 decoded, err := base64.StdEncoding.DecodeString(media.Data)
1501 if err != nil {
1502 slog.Warn("Failed to decode media data", "error", err)
1503 textParts = append(textParts, part)
1504 continue
1505 }
1506
1507 mediaFiles = append(mediaFiles, fantasy.FilePart{
1508 Data: decoded,
1509 MediaType: media.MediaType,
1510 Filename: fmt.Sprintf("tool-result-%s", toolResult.ToolCallID),
1511 })
1512
1513 textParts = append(textParts, fantasy.ToolResultPart{
1514 ToolCallID: toolResult.ToolCallID,
1515 Output: fantasy.ToolResultOutputContentText{
1516 Text: "[Image/media content loaded - see attached file]",
1517 },
1518 ProviderOptions: toolResult.ProviderOptions,
1519 })
1520 } else {
1521 textParts = append(textParts, part)
1522 }
1523 }
1524
1525 convertedMessages = append(convertedMessages, fantasy.Message{
1526 Role: fantasy.MessageRoleTool,
1527 Content: textParts,
1528 })
1529
1530 if len(mediaFiles) > 0 {
1531 convertedMessages = append(convertedMessages, fantasy.NewUserMessage(
1532 "Here is the media content from the tool result:",
1533 mediaFiles...,
1534 ))
1535 }
1536 }
1537
1538 return convertedMessages
1539}
1540
1541// buildSummaryPrompt constructs the prompt text for session summarization.
1542func buildSummaryPrompt(todos []session.Todo) string {
1543 var sb strings.Builder
1544 sb.WriteString("Provide a detailed summary of our conversation above.")
1545 if len(todos) > 0 {
1546 sb.WriteString("\n\n## Current Todo List\n\n")
1547 for _, t := range todos {
1548 fmt.Fprintf(&sb, "- [%s] %s\n", t.Status, t.Content)
1549 }
1550 sb.WriteString("\nInclude these tasks and their statuses in your summary. ")
1551 sb.WriteString("Instruct the resuming assistant to use the `todos` tool to continue tracking progress on these tasks.")
1552 }
1553 return sb.String()
1554}
1555
1556func providerRetryLogFields(err *fantasy.ProviderError, delay time.Duration) []any {
1557 fields := []any{
1558 "retry_delay", delay.String(),
1559 }
1560 if err == nil {
1561 return fields
1562 }
1563 fields = append(fields, "status_code", err.StatusCode)
1564 if err.Title != "" {
1565 fields = append(fields, "title", err.Title)
1566 }
1567 if err.Message != "" {
1568 fields = append(fields, "message", err.Message)
1569 }
1570 return fields
1571}