1// Package agent is the core orchestration layer for Crush AI agents.
2//
3// It provides session-based AI agent functionality for managing
4// conversations, tool execution, and message handling. It coordinates
5// interactions between language models, messages, sessions, and tools while
6// handling features like automatic summarization, queuing, and token
7// management.
8package agent
9
10import (
11 "cmp"
12 "context"
13 _ "embed"
14 "encoding/base64"
15 "errors"
16 "fmt"
17 "log/slog"
18 "net/http"
19 "os"
20 "regexp"
21 "strconv"
22 "strings"
23 "sync"
24 "time"
25
26 "charm.land/catwalk/pkg/catwalk"
27 "charm.land/fantasy"
28 "charm.land/fantasy/providers/anthropic"
29 "charm.land/fantasy/providers/bedrock"
30 "charm.land/fantasy/providers/google"
31 "charm.land/fantasy/providers/openai"
32 "charm.land/fantasy/providers/openrouter"
33 "charm.land/fantasy/providers/vercel"
34 "charm.land/lipgloss/v2"
35 "github.com/charmbracelet/crush/internal/agent/hyper"
36 "github.com/charmbracelet/crush/internal/agent/notify"
37 "github.com/charmbracelet/crush/internal/agent/tools"
38 "github.com/charmbracelet/crush/internal/agent/tools/mcp"
39 "github.com/charmbracelet/crush/internal/config"
40 "github.com/charmbracelet/crush/internal/csync"
41 "github.com/charmbracelet/crush/internal/message"
42 "github.com/charmbracelet/crush/internal/pubsub"
43 "github.com/charmbracelet/crush/internal/session"
44 "github.com/charmbracelet/crush/internal/stringext"
45 "github.com/charmbracelet/crush/internal/version"
46 "github.com/charmbracelet/x/exp/charmtone"
47)
48
const (
	// DefaultSessionName is the default name for a session (presumably
	// shown until a title has been generated — usage not visible here).
	DefaultSessionName = "Untitled Session"

	// Constants for auto-summarization thresholds. Run's stop condition
	// triggers summarization once the tokens remaining in the context window
	// drop to a threshold: a fixed largeContextWindowBuffer tokens for
	// windows larger than largeContextWindowThreshold, otherwise
	// smallContextWindowRatio of the window.
	largeContextWindowThreshold = 200_000
	largeContextWindowBuffer    = 20_000
	smallContextWindowRatio     = 0.2
)

// userAgent identifies Crush (and its version) on model requests made via
// fantasy.WithUserAgent.
var userAgent = fmt.Sprintf("Charm-Crush/%s (https://charm.land/crush)", version.Version)

// titlePrompt is the embedded templates/title.md prompt, presumably used by
// session title generation (not visible in this part of the file).
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt is the embedded templates/summary.md prompt; Summarize uses
// it as the system prompt of the summarization agent.
//
//go:embed templates/summary.md
var summaryPrompt []byte

// Used to remove <think> tags from generated titles. thinkTagRegex matches a
// complete <think>...</think> block, including across newlines ((?s)), and
// orphanThinkTagRegex matches any leftover lone opening or closing tag.
var (
	thinkTagRegex       = regexp.MustCompile(`(?s)<think>.*?</think>`)
	orphanThinkTagRegex = regexp.MustCompile(`</?think>`)
)
71
// SessionAgentCall describes a single prompt to run against a session via
// SessionAgent.Run.
type SessionAgentCall struct {
	// SessionID identifies the target session; Run rejects an empty value
	// with ErrSessionMissing.
	SessionID string

	// Prompt is the user's text. Run rejects an empty Prompt with
	// ErrEmptyPrompt unless a text attachment is present.
	Prompt string

	// ProviderOptions are forwarded to the provider unchanged.
	ProviderOptions fantasy.ProviderOptions

	// Attachments are stored with the user message; non-text attachments
	// are also sent to the model as file parts.
	Attachments []message.Attachment

	// MaxOutputTokens limits the response length. 0 means no limit is sent,
	// because some providers reject a zero value.
	MaxOutputTokens int64

	// Optional sampling parameters, passed through as-is; nil presumably
	// leaves the provider default in place.
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64

	// NonInteractive suppresses the "agent finished" notification at the
	// end of a successful turn.
	NonInteractive bool
}
85
// SessionAgent runs prompts against chat sessions, queueing prompts that
// arrive while a session is busy, and supports cancellation and
// summarization. Descriptions below follow the method names and the
// sessionAgent implementation where it is visible.
type SessionAgent interface {
	// Run executes a prompt for a session. When the session is busy the
	// call is queued and (nil, nil) is returned.
	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
	// SetModels replaces the large and small models.
	SetModels(large Model, small Model)
	// SetTools replaces the tool set; running turns pick it up on the next
	// step.
	SetTools(tools []fantasy.AgentTool)
	// SetSystemPrompt replaces the system prompt.
	SetSystemPrompt(systemPrompt string)
	// Cancel aborts the in-flight request for a session (presumably via the
	// cancel func registered while it runs).
	Cancel(sessionID string)
	// CancelAll aborts every in-flight request.
	CancelAll()
	// IsSessionBusy reports whether a request is running for the session.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any session is busy.
	IsBusy() bool
	// QueuedPrompts returns how many prompts are queued for the session.
	QueuedPrompts(sessionID string) int
	// QueuedPromptsList returns the queued prompts for the session.
	QueuedPromptsList(sessionID string) []string
	// ClearQueue discards the session's queued prompts.
	ClearQueue(sessionID string)
	// Summarize compacts the session's history into a summary message.
	Summarize(context.Context, string, fantasy.ProviderOptions) error
	// Model returns the agent's current model (presumably the large one —
	// implementation not visible here).
	Model() Model
}
101
// Model bundles a language-model client with its catalog metadata and the
// user's model selection.
type Model struct {
	// Model is the fantasy client used to issue requests.
	Model fantasy.LanguageModel

	// CatwalkCfg is the catalog entry; this file reads its ContextWindow,
	// SupportsImages and Name.
	CatwalkCfg catwalk.Model

	// ModelCfg is the selected model configuration; this file records its
	// Model and Provider on assistant messages.
	ModelCfg config.SelectedModel

	// FlatRate is not read in this part of the file; presumably it marks
	// flat-rate billing (TODO confirm).
	FlatRate bool
}
108
// sessionAgent is the SessionAgent implementation. Settings that can change
// at runtime are held in csync containers so Run and Summarize can snapshot
// them without racing SetTools/SetModels.
type sessionAgent struct {
	largeModel         *csync.Value[Model]
	smallModel         *csync.Value[Model]
	systemPromptPrefix *csync.Value[string]
	systemPrompt       *csync.Value[string]
	tools              *csync.Slice[fantasy.AgentTool]

	// isSubAgent omits the todo-list reminder from the prompt history.
	isSubAgent           bool
	sessions             session.Service
	messages             message.Service
	disableAutoSummarize bool
	// isYolo is not read in this part of the file (TODO confirm its use).
	isYolo bool
	// notify, when non-nil, receives agent-finished and re-authentication
	// notifications.
	notify pubsub.Publisher[notify.Notification]

	// messageQueue holds, per session ID, calls that arrived while the
	// session was busy.
	messageQueue *csync.Map[string, []SessionAgentCall]
	// activeRequests maps a session ID to the cancel func of its in-flight
	// generation; presumably what IsSessionBusy checks.
	activeRequests *csync.Map[string, context.CancelFunc]
}
126
// SessionAgentOptions configures NewSessionAgent.
//
// SystemPromptPrefix, when non-empty, is sent as its own leading system
// message on every request. DisableAutoSummarize turns off automatic
// summarization when the context window runs low. IsSubAgent drops the
// todo-list reminder that otherwise starts the prompt history. Notify may be
// nil, in which case no notifications are published.
type SessionAgentOptions struct {
	LargeModel           Model
	SmallModel           Model
	SystemPromptPrefix   string
	SystemPrompt         string
	IsSubAgent           bool
	DisableAutoSummarize bool
	IsYolo               bool
	Sessions             session.Service
	Messages             message.Service
	Tools                []fantasy.AgentTool
	Notify               pubsub.Publisher[notify.Notification]
}
140
141func NewSessionAgent(
142 opts SessionAgentOptions,
143) SessionAgent {
144 return &sessionAgent{
145 largeModel: csync.NewValue(opts.LargeModel),
146 smallModel: csync.NewValue(opts.SmallModel),
147 systemPromptPrefix: csync.NewValue(opts.SystemPromptPrefix),
148 systemPrompt: csync.NewValue(opts.SystemPrompt),
149 isSubAgent: opts.IsSubAgent,
150 sessions: opts.Sessions,
151 messages: opts.Messages,
152 disableAutoSummarize: opts.DisableAutoSummarize,
153 tools: csync.NewSliceFrom(opts.Tools),
154 isYolo: opts.IsYolo,
155 notify: opts.Notify,
156 messageQueue: csync.NewMap[string, []SessionAgentCall](),
157 activeRequests: csync.NewMap[string, context.CancelFunc](),
158 }
159}
160
// Run executes call against the large model for call.SessionID, streaming
// the model's output into the session's messages.
//
// What this function does, in order:
//   - Validates the call: ErrEmptyPrompt if the prompt is empty and no text
//     attachment is present, ErrSessionMissing if SessionID is empty.
//   - If the session is busy, appends the call to the session's queue and
//     returns (nil, nil). Queued calls are injected as user messages at the
//     next step (see PrepareStep) or run once the current turn ends.
//   - Builds the system prompt, appending instructions from connected MCP
//     servers, and marks the last tool as an Anthropic cache breakpoint.
//   - Stores the user message and, for the first message of a session,
//     generates a title concurrently; Run waits for it before returning.
//   - Streams the response. Each step gets its own assistant message that is
//     updated as reasoning, text, and tool calls arrive; tool results are
//     stored as tool messages; session usage is saved after every step.
//   - Stops early when auto-summarization is enabled and the remaining
//     context window drops to the threshold (the session is then
//     summarized), or when repeated tool calls are detected.
//   - On stream error, closes unfinished tool calls, adds error results for
//     tool calls that have none, finishes the assistant message with a
//     reason describing the failure, and returns the error.
//   - Otherwise publishes a TypeAgentFinished notification (unless
//     NonInteractive or no notifier is configured) and runs the next queued
//     call, if any, returning that call's result.
func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
	if call.Prompt == "" && !message.ContainsTextAttachment(call.Attachments) {
		return nil, ErrEmptyPrompt
	}
	if call.SessionID == "" {
		return nil, ErrSessionMissing
	}

	// Queue the message if busy
	if a.IsSessionBusy(call.SessionID) {
		existing, ok := a.messageQueue.Get(call.SessionID)
		if !ok {
			existing = []SessionAgentCall{}
		}
		existing = append(existing, call)
		a.messageQueue.Set(call.SessionID, existing)
		return nil, nil
	}

	// Copy mutable fields under lock to avoid races with SetTools/SetModels.
	agentTools := a.tools.Copy()
	largeModel := a.largeModel.Get()
	systemPrompt := a.systemPrompt.Get()
	promptPrefix := a.systemPromptPrefix.Get()
	var instructions strings.Builder

	// Collect the instructions advertised by connected MCP servers.
	for _, server := range mcp.GetStates() {
		if server.State != mcp.StateConnected {
			continue
		}
		if s := server.Client.InitializeResult().Instructions; s != "" {
			instructions.WriteString(s)
			instructions.WriteString("\n\n")
		}
	}

	if s := instructions.String(); s != "" {
		systemPrompt += "\n\n<mcp-instructions>\n" + s + "\n</mcp-instructions>"
	}

	if len(agentTools) > 0 {
		// Add Anthropic caching to the last tool.
		agentTools[len(agentTools)-1].SetProviderOptions(a.getCacheControlOptions())
	}

	agent := fantasy.NewAgent(
		largeModel.Model,
		fantasy.WithSystemPrompt(systemPrompt),
		fantasy.WithTools(agentTools...),
		fantasy.WithUserAgent(userAgent),
	)

	// sessionLock serializes the step callbacks' writes to stepMessages and
	// currentSession.
	sessionLock := sync.Mutex{}
	currentSession, err := a.sessions.Get(ctx, call.SessionID)
	if err != nil {
		return nil, fmt.Errorf("failed to get session: %w", err)
	}

	msgs, err := a.getSessionMessages(ctx, currentSession)
	if err != nil {
		return nil, fmt.Errorf("failed to get session messages: %w", err)
	}

	var wg sync.WaitGroup
	// Generate title if first message.
	if len(msgs) == 0 {
		titleCtx := ctx // Copy to avoid race with ctx reassignment below.
		wg.Go(func() {
			a.generateTitle(titleCtx, call.SessionID, call.Prompt)
		})
	}
	// Registered before every other defer, so it runs last: Run never
	// returns while title generation is still in flight.
	defer wg.Wait()

	// Add the user message to the session.
	_, err = a.createUserMessage(ctx, call)
	if err != nil {
		return nil, err
	}

	// Add the session to the context.
	ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)

	// genCtx is the cancellable context for this generation. Registering its
	// cancel func in activeRequests is what marks the session as busy.
	genCtx, cancel := context.WithCancel(ctx)
	a.activeRequests.Set(call.SessionID, cancel)

	defer cancel()
	defer a.activeRequests.Del(call.SessionID)
	// Drain any debounced message updates before returning. message.Service
	// already flushes synchronously on terminal updates, but a defer here
	// guarantees the contract at every Run exit (success, error, panic
	// recovery upstream) without callers needing to know.
	defer func() {
		if flushErr := a.messages.FlushAll(ctx); flushErr != nil {
			slog.Error("Failed to flush pending message updates after run", "error", flushErr)
		}
	}()

	history, files := a.preparePrompt(msgs, largeModel.CatwalkCfg.SupportsImages, call.Attachments...)

	startTime := time.Now()
	a.eventPromptSent(call.SessionID)

	// State shared by the streaming callbacks below. currentAssistant is the
	// assistant message of the step in progress (nil until the first
	// PrepareStep succeeds).
	var currentAssistant *message.Message
	var stepMessages []fantasy.Message
	var shouldSummarize bool
	// Don't send MaxOutputTokens if 0 — some providers (e.g. LM Studio) reject it
	var maxOutputTokens *int64
	if call.MaxOutputTokens > 0 {
		maxOutputTokens = &call.MaxOutputTokens
	}
	result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
		Prompt:           message.PromptWithTextAttachments(call.Prompt, call.Attachments),
		Files:            files,
		Messages:         history,
		ProviderOptions:  call.ProviderOptions,
		MaxOutputTokens:  maxOutputTokens,
		TopP:             call.TopP,
		Temperature:      call.Temperature,
		PresencePenalty:  call.PresencePenalty,
		TopK:             call.TopK,
		FrequencyPenalty: call.FrequencyPenalty,
		// PrepareStep runs before each model call. It resets per-message
		// provider options, picks up the latest tools, injects queued
		// prompts, applies cache-control breakpoints, prepends the optional
		// system prompt prefix, and creates the assistant message this
		// step's output streams into.
		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
			prepared.Messages = options.Messages
			for i := range prepared.Messages {
				prepared.Messages[i].ProviderOptions = nil
			}

			// Use latest tools (updated by SetTools when MCP tools change).
			prepared.Tools = a.tools.Copy()

			// Send prompts queued while this turn was running as additional
			// user messages in this step.
			// NOTE(review): Get and Del are separate operations, so a call
			// queued by Run between them could be dropped. Confirm whether
			// csync.Map offers an atomic take.
			queuedCalls, _ := a.messageQueue.Get(call.SessionID)
			a.messageQueue.Del(call.SessionID)
			for _, queued := range queuedCalls {
				userMessage, createErr := a.createUserMessage(callContext, queued)
				if createErr != nil {
					return callContext, prepared, createErr
				}
				prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
			}

			prepared.Messages = a.workaroundProviderMediaLimitations(prepared.Messages, largeModel)

			lastSystemRoleInx := 0
			systemMessageUpdated := false
			for i, msg := range prepared.Messages {
				// Add cache control to the last system message seen before
				// the first non-system message (message 0 if the list does
				// not start with a system message).
				if msg.Role == fantasy.MessageRoleSystem {
					lastSystemRoleInx = i
				} else if !systemMessageUpdated {
					prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
					systemMessageUpdated = true
				}
				// Then add cache control to the last 2 messages.
				if i > len(prepared.Messages)-3 {
					prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
				}
			}

			// The prefix is prepended after the cache breakpoints are set, so
			// it never receives cache control itself.
			if promptPrefix != "" {
				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(promptPrefix)}, prepared.Messages...)
			}

			// Snapshot the exact messages sent this step; OnStepFinish uses
			// them to estimate usage when the provider reports none.
			sessionLock.Lock()
			stepMessages = cloneFantasyMessages(prepared.Messages)
			sessionLock.Unlock()

			var assistantMsg message.Message
			assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
				Role:     message.Assistant,
				Parts:    []message.ContentPart{},
				Model:    largeModel.ModelCfg.Model,
				Provider: largeModel.ModelCfg.Provider,
			})
			if err != nil {
				return callContext, prepared, err
			}
			// Expose the message ID and model capabilities to tools.
			callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
			callContext = context.WithValue(callContext, tools.SupportsImagesContextKey, largeModel.CatwalkCfg.SupportsImages)
			callContext = context.WithValue(callContext, tools.ModelNameContextKey, largeModel.CatwalkCfg.Name)
			currentAssistant = &assistantMsg
			return callContext, prepared, err
		},
		OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
			currentAssistant.AppendReasoningContent(reasoning.Text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnReasoningDelta: func(id string, text string) error {
			currentAssistant.AppendReasoningContent(text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
			// Keep provider-specific reasoning metadata (Anthropic
			// signature, Google thought signature, OpenAI Responses data)
			// on the message.
			// handle anthropic signature
			if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
				if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
					currentAssistant.AppendReasoningSignature(reasoning.Signature)
				}
			}
			if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
				if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
					currentAssistant.AppendThoughtSignature(reasoning.Signature, reasoning.ToolID)
				}
			}
			if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
				if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
					currentAssistant.SetReasoningResponsesData(reasoning)
				}
			}
			currentAssistant.FinishThinking()
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnTextDelta: func(id string, text string) error {
			// Strip leading newline from initial text content. This is
			// particularly important in non-interactive mode where leading
			// newlines are very visible.
			if len(currentAssistant.Parts) == 0 {
				text = strings.TrimPrefix(text, "\n")
			}

			currentAssistant.AppendContent(text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnToolInputStart: func(id string, toolName string) error {
			toolCall := message.ToolCall{
				ID:               id,
				Name:             toolName,
				ProviderExecuted: false,
				Finished:         false,
			}
			currentAssistant.AddToolCall(toolCall)
			// Use parent ctx instead of genCtx to ensure the update succeeds
			// even if the request is canceled mid-stream
			return a.messages.Update(ctx, *currentAssistant)
		},
		OnRetry: func(err *fantasy.ProviderError, delay time.Duration) {
			slog.Warn("Provider request failed, retrying", providerRetryLogFields(err, delay)...)
		},
		OnToolCall: func(tc fantasy.ToolCallContent) error {
			toolCall := message.ToolCall{
				ID:               tc.ToolCallID,
				Name:             tc.ToolName,
				Input:            tc.Input,
				ProviderExecuted: false,
				Finished:         true,
			}
			currentAssistant.AddToolCall(toolCall)
			// Use parent ctx instead of genCtx to ensure the update succeeds
			// even if the request is canceled mid-stream
			return a.messages.Update(ctx, *currentAssistant)
		},
		OnToolResult: func(result fantasy.ToolResultContent) error {
			toolResult := a.convertToToolResult(result)
			// Use parent ctx instead of genCtx to ensure the message is created
			// even if the request is canceled mid-stream
			_, createMsgErr := a.messages.Create(ctx, currentAssistant.SessionID, message.CreateMessageParams{
				Role: message.Tool,
				Parts: []message.ContentPart{
					toolResult,
				},
			})
			return createMsgErr
		},
		// OnStepFinish records the finish reason on the step's assistant
		// message and adds the step's usage (estimated from stepMessages
		// when the provider reports none) and OpenRouter cost to the
		// session.
		OnStepFinish: func(stepResult fantasy.StepResult) error {
			finishReason := message.FinishReasonUnknown
			switch stepResult.FinishReason {
			case fantasy.FinishReasonLength:
				finishReason = message.FinishReasonMaxTokens
			case fantasy.FinishReasonStop:
				finishReason = message.FinishReasonEndTurn
			case fantasy.FinishReasonToolCalls:
				finishReason = message.FinishReasonToolUse
			}
			// If a tool result halted the turn (e.g. a hook halt or a
			// permission denial), the step ends on FinishReasonToolCalls but
			// the model will not be called again. Treat it as the end of the
			// turn so the UI can render the assistant footer.
			if finishReason == message.FinishReasonToolUse {
				for _, tr := range stepResult.Content.ToolResults() {
					if tr.StopTurn {
						finishReason = message.FinishReasonEndTurn
						break
					}
				}
			}
			currentAssistant.AddFinish(finishReason, "", "")
			sessionLock.Lock()
			defer sessionLock.Unlock()

			updatedSession, getSessionErr := a.sessions.Get(ctx, call.SessionID)
			if getSessionErr != nil {
				return getSessionErr
			}
			usage, estimated := fallbackStepUsage(stepMessages, stepResult)
			a.updateSessionUsage(largeModel, &updatedSession, usage, a.openrouterCost(stepResult.ProviderMetadata), estimated)
			_, sessionErr := a.sessions.Save(ctx, updatedSession)
			if sessionErr != nil {
				return sessionErr
			}
			currentSession = updatedSession
			return a.messages.Update(genCtx, *currentAssistant)
		},
		StopWhen: []fantasy.StopCondition{
			// Stop (and request summarization) once the tokens left in the
			// context window fall to the threshold, unless auto-summarize is
			// disabled.
			func(_ []fantasy.StepResult) bool {
				cw := int64(largeModel.CatwalkCfg.ContextWindow)
				// If context window is unknown (0), skip auto-summarize
				// to avoid immediately truncating custom/local models.
				if cw == 0 {
					return false
				}
				tokens := currentSession.CompletionTokens + currentSession.PromptTokens
				remaining := cw - tokens
				var threshold int64
				if cw > largeContextWindowThreshold {
					threshold = largeContextWindowBuffer
				} else {
					threshold = int64(float64(cw) * smallContextWindowRatio)
				}
				if (remaining <= threshold) && !a.disableAutoSummarize {
					shouldSummarize = true
					return true
				}
				return false
			},
			// Stop when the model keeps repeating the same tool calls.
			func(steps []fantasy.StepResult) bool {
				return hasRepeatedToolCalls(steps, loopDetectionWindowSize, loopDetectionMaxRepeats)
			},
		},
	})

	a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))

	if err != nil {
		isHyper := largeModel.ModelCfg.Provider == hyper.Name
		isCancelErr := errors.Is(err, context.Canceled)
		// No assistant message was created, so there is nothing to finish.
		if currentAssistant == nil {
			return result, err
		}
		// Ensure we finish thinking on error to close the reasoning state.
		currentAssistant.FinishThinking()
		toolCalls := currentAssistant.ToolCalls()
		// INFO: we use the parent context here because the genCtx has been cancelled.
		msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
		if createErr != nil {
			return nil, createErr
		}
		// Close every unfinished tool call and give each tool call without
		// a stored result an error result, so the history stays valid.
		for _, tc := range toolCalls {
			if !tc.Finished {
				tc.Finished = true
				tc.Input = "{}"
				currentAssistant.AddToolCall(tc)
				updateErr := a.messages.Update(ctx, *currentAssistant)
				if updateErr != nil {
					return nil, updateErr
				}
			}

			found := false
			for _, msg := range msgs {
				if msg.Role == message.Tool {
					for _, tr := range msg.ToolResults() {
						if tr.ToolCallID == tc.ID {
							found = true
							break
						}
					}
				}
				if found {
					break
				}
			}
			if found {
				continue
			}
			content := "There was an error while executing the tool"
			if isCancelErr {
				content = "Error: user cancelled assistant tool calling"
			}
			toolResult := message.ToolResult{
				ToolCallID: tc.ID,
				Name:       tc.Name,
				Content:    content,
				IsError:    true,
			}
			_, createErr = a.messages.Create(ctx, currentAssistant.SessionID, message.CreateMessageParams{
				Role: message.Tool,
				Parts: []message.ContentPart{
					toolResult,
				},
			})
			if createErr != nil {
				return nil, createErr
			}
		}
		// Map the error to a user-facing finish reason, title and message.
		var fantasyErr *fantasy.Error
		var providerErr *fantasy.ProviderError
		const defaultTitle = "Provider Error"
		linkStyle := lipgloss.NewStyle().Foreground(charmtone.Guac).Underline(true)
		if isCancelErr {
			currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
		} else if isHyper && errors.As(err, &providerErr) && providerErr.StatusCode == http.StatusUnauthorized {
			currentAssistant.AddFinish(message.FinishReasonError, "Unauthorized", `Please re-authenticate with Hyper. You can also run "crush auth" to re-authenticate.`)
			if a.notify != nil {
				a.notify.Publish(pubsub.CreatedEvent, notify.Notification{
					SessionID:    call.SessionID,
					SessionTitle: currentSession.Title,
					Type:         notify.TypeReAuthenticate,
					ProviderID:   largeModel.ModelCfg.Provider,
				})
			}
		} else if isHyper && errors.As(err, &providerErr) && providerErr.StatusCode == http.StatusPaymentRequired {
			url := hyper.BaseURL()
			link := linkStyle.Hyperlink(url, "id=hyper").Render(url)
			currentAssistant.AddFinish(message.FinishReasonError, "No credits", "You're out of credits. Add more at "+link)
		} else if errors.As(err, &providerErr) {
			if providerErr.Message == "The requested model is not supported." {
				url := "https://github.com/settings/copilot/features"
				link := linkStyle.Hyperlink(url, "id=copilot").Render(url)
				currentAssistant.AddFinish(
					message.FinishReasonError,
					"Copilot model not enabled",
					fmt.Sprintf("%q is not enabled in Copilot. Go to the following page to enable it. Then, wait 5 minutes before trying again. %s", largeModel.CatwalkCfg.Name, link),
				)
			} else {
				currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message)
			}
		} else if errors.As(err, &fantasyErr) {
			currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(fantasyErr.Title), defaultTitle), fantasyErr.Message)
		} else {
			currentAssistant.AddFinish(message.FinishReasonError, defaultTitle, err.Error())
		}
		// Note: we use the parent context here because the genCtx has been
		// cancelled.
		updateErr := a.messages.Update(ctx, *currentAssistant)
		if updateErr != nil {
			return nil, updateErr
		}
		return nil, err
	}

	if shouldSummarize {
		// Summarize refuses to run on a busy session, so release this
		// request's registration first.
		a.activeRequests.Del(call.SessionID)
		if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
			return nil, summarizeErr
		}
		// If the last step still ended with tool calls, the agent wasn't
		// done: queue a follow-up prompt so it resumes after the summary.
		if len(currentAssistant.ToolCalls()) > 0 {
			existing, ok := a.messageQueue.Get(call.SessionID)
			if !ok {
				existing = []SessionAgentCall{}
			}
			call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
			existing = append(existing, call)
			a.messageQueue.Set(call.SessionID, existing)
		}
	}

	// Release active request before publishing the notification.
	// TUI handlers poll IsSessionBusy() and only re-evaluate when a
	// tea.Msg arrives, so the cleanup must precede the notify or
	// subscribers see stale busy state at the moment of receipt.
	a.activeRequests.Del(call.SessionID)
	cancel()

	// Send notification that agent has finished its turn (skip for
	// nested/non-interactive sessions).
	if !call.NonInteractive && a.notify != nil {
		a.notify.Publish(pubsub.CreatedEvent, notify.Notification{
			SessionID:    call.SessionID,
			SessionTitle: currentSession.Title,
			Type:         notify.TypeAgentFinished,
		})
	}

	queuedMessages, ok := a.messageQueue.Get(call.SessionID)
	if !ok || len(queuedMessages) == 0 {
		return result, err
	}
	// There are queued messages: run the next one (recursively) and return
	// its result.
	firstQueuedMessage := queuedMessages[0]
	a.messageQueue.Set(call.SessionID, queuedMessages[1:])
	return a.Run(ctx, firstQueuedMessage)
}
642
643func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
644 if a.IsSessionBusy(sessionID) {
645 return ErrSessionBusy
646 }
647
648 // Copy mutable fields under lock to avoid races with SetModels.
649 largeModel := a.largeModel.Get()
650 systemPromptPrefix := a.systemPromptPrefix.Get()
651
652 currentSession, err := a.sessions.Get(ctx, sessionID)
653 if err != nil {
654 return fmt.Errorf("failed to get session: %w", err)
655 }
656 msgs, err := a.getSessionMessages(ctx, currentSession)
657 if err != nil {
658 return err
659 }
660 if len(msgs) == 0 {
661 // Nothing to summarize.
662 return nil
663 }
664
665 aiMsgs, _ := a.preparePrompt(msgs, largeModel.CatwalkCfg.SupportsImages)
666
667 genCtx, cancel := context.WithCancel(ctx)
668 a.activeRequests.Set(sessionID, cancel)
669 defer a.activeRequests.Del(sessionID)
670 defer cancel()
671 defer func() {
672 if flushErr := a.messages.FlushAll(ctx); flushErr != nil {
673 slog.Error("Failed to flush pending message updates after summarize", "error", flushErr)
674 }
675 }()
676
677 agent := fantasy.NewAgent(
678 largeModel.Model,
679 fantasy.WithSystemPrompt(string(summaryPrompt)),
680 fantasy.WithUserAgent(userAgent),
681 )
682 summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
683 Role: message.Assistant,
684 Model: largeModel.Model.Model(),
685 Provider: largeModel.Model.Provider(),
686 IsSummaryMessage: true,
687 })
688 if err != nil {
689 return err
690 }
691
692 summaryPromptText := buildSummaryPrompt(currentSession.Todos)
693
694 resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
695 Prompt: summaryPromptText,
696 Messages: aiMsgs,
697 ProviderOptions: opts,
698 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
699 prepared.Messages = options.Messages
700 if systemPromptPrefix != "" {
701 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(systemPromptPrefix)}, prepared.Messages...)
702 }
703 return callContext, prepared, nil
704 },
705 OnReasoningDelta: func(id string, text string) error {
706 summaryMessage.AppendReasoningContent(text)
707 return a.messages.Update(genCtx, summaryMessage)
708 },
709 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
710 // Handle anthropic signature.
711 if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
712 if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
713 summaryMessage.AppendReasoningSignature(signature.Signature)
714 }
715 }
716 summaryMessage.FinishThinking()
717 return a.messages.Update(genCtx, summaryMessage)
718 },
719 OnTextDelta: func(id, text string) error {
720 summaryMessage.AppendContent(text)
721 return a.messages.Update(genCtx, summaryMessage)
722 },
723 })
724 if err != nil {
725 isCancelErr := errors.Is(err, context.Canceled)
726 if isCancelErr {
727 // User cancelled summarize we need to remove the summary message.
728 deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
729 return deleteErr
730 }
731 // Mark the summary message as finished with an error so the UI
732 // stops spinning.
733 summaryMessage.AddFinish(message.FinishReasonError, "Summarization Error", err.Error())
734 if updateErr := a.messages.Update(ctx, summaryMessage); updateErr != nil {
735 return updateErr
736 }
737 return err
738 }
739
740 summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
741 err = a.messages.Update(genCtx, summaryMessage)
742 if err != nil {
743 return err
744 }
745
746 var openrouterCost *float64
747 for _, step := range resp.Steps {
748 stepCost := a.openrouterCost(step.ProviderMetadata)
749 if stepCost != nil {
750 newCost := *stepCost
751 if openrouterCost != nil {
752 newCost += *openrouterCost
753 }
754 openrouterCost = &newCost
755 }
756 }
757
758 a.updateSessionUsage(largeModel, ¤tSession, resp.TotalUsage, openrouterCost, false)
759
760 // Just in case, get just the last usage info.
761 usage := resp.Response.Usage
762 currentSession.SummaryMessageID = summaryMessage.ID
763 if completionTokens := summaryCompletionTokens(usage, summaryMessage); completionTokens != 0 {
764 currentSession.CompletionTokens = completionTokens
765 }
766 currentSession.PromptTokens = 0
767 _, err = a.sessions.Save(genCtx, currentSession)
768 if err != nil {
769 return err
770 }
771
772 // Release the active request before processing queued messages so that
773 // Run() does not see the session as busy.
774 a.activeRequests.Del(sessionID)
775 cancel()
776
777 // Process any messages that were queued while summarizing.
778 queuedMessages, ok := a.messageQueue.Get(sessionID)
779 if !ok || len(queuedMessages) == 0 {
780 return nil
781 }
782 firstQueuedMessage := queuedMessages[0]
783 a.messageQueue.Set(sessionID, queuedMessages[1:])
784 _, qErr := a.Run(ctx, firstQueuedMessage)
785 return qErr
786}
787
788func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
789 if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
790 return fantasy.ProviderOptions{}
791 }
792 return fantasy.ProviderOptions{
793 anthropic.Name: &anthropic.ProviderCacheControlOptions{
794 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
795 },
796 bedrock.Name: &anthropic.ProviderCacheControlOptions{
797 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
798 },
799 vercel.Name: &anthropic.ProviderCacheControlOptions{
800 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
801 },
802 }
803}
804
805func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
806 parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
807 var attachmentParts []message.ContentPart
808 for _, attachment := range call.Attachments {
809 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
810 }
811 parts = append(parts, attachmentParts...)
812 msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
813 Role: message.User,
814 Parts: parts,
815 })
816 if err != nil {
817 return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
818 }
819 return msg, nil
820}
821
822func (a *sessionAgent) preparePrompt(msgs []message.Message, supportsImages bool, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
823 var history []fantasy.Message
824 if !a.isSubAgent {
825 history = append(history, fantasy.NewUserMessage(
826 fmt.Sprintf(
827 "<system_reminder>%s</system_reminder>",
828 `This is a reminder that your todo list is currently empty. DO NOT mention this to the user explicitly because they are already aware.
829If you are working on tasks that would benefit from a todo list please use the "todos" tool to create one.
830If not, please feel free to ignore. Again do not mention this message to the user.`,
831 ),
832 ))
833 }
834 // Collect all tool call IDs present in assistant messages and all tool
835 // result IDs present in tool messages. This lets us detect both orphaned
836 // tool results (result without a call) and orphaned tool calls (call
837 // without a result).
838 knownToolCallIDs := make(map[string]struct{})
839 knownToolResultIDs := make(map[string]struct{})
840 for _, m := range msgs {
841 switch m.Role {
842 case message.Assistant:
843 for _, tc := range m.ToolCalls() {
844 knownToolCallIDs[tc.ID] = struct{}{}
845 }
846 case message.Tool:
847 for _, tr := range m.ToolResults() {
848 knownToolResultIDs[tr.ToolCallID] = struct{}{}
849 }
850 }
851 }
852
853 for _, m := range msgs {
854 if len(m.Parts) == 0 {
855 continue
856 }
857 // Assistant message without content or tool calls (cancelled before it returned anything).
858 if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
859 continue
860 }
861 if m.Role == message.Tool {
862 if msg, ok := filterOrphanedToolResults(m, knownToolCallIDs); ok {
863 history = append(history, msg)
864 }
865 continue
866 }
867 aiMsgs := m.ToAIMessage()
868 if !supportsImages {
869 for i := range aiMsgs {
870 if aiMsgs[i].Role == fantasy.MessageRoleUser {
871 aiMsgs[i].Content = filterFileParts(aiMsgs[i].Content)
872 }
873 }
874 }
875 history = append(history, aiMsgs...)
876
877 if m.Role == message.Assistant {
878 if msg, ok := syntheticToolResultsForOrphanedCalls(m, knownToolResultIDs); ok {
879 history = append(history, msg)
880 }
881 }
882 }
883
884 var files []fantasy.FilePart
885 for _, attachment := range attachments {
886 if attachment.IsText() {
887 continue
888 }
889 files = append(files, fantasy.FilePart{
890 Filename: attachment.FileName,
891 Data: attachment.Content,
892 MediaType: attachment.MimeType,
893 })
894 }
895
896 return history, files
897}
898
899// filterFileParts removes fantasy.FilePart entries from a slice of message
900// parts. Used to strip image attachments from historical user messages when
901// the current model does not support them.
902func filterFileParts(parts []fantasy.MessagePart) []fantasy.MessagePart {
903 filtered := make([]fantasy.MessagePart, 0, len(parts))
904 for _, part := range parts {
905 if _, ok := fantasy.AsMessagePart[fantasy.FilePart](part); ok {
906 continue
907 }
908 filtered = append(filtered, part)
909 }
910 return filtered
911}
912
913// filterOrphanedToolResults converts a tool message to a fantasy.Message,
914// dropping any tool result parts whose tool_call_id has no matching tool call
915// in the known set. An orphaned result causes API validation to fail on every
916// subsequent turn, permanently locking the session. Returns the filtered
917// message and true if at least one valid part remains.
918func filterOrphanedToolResults(m message.Message, knownToolCallIDs map[string]struct{}) (fantasy.Message, bool) {
919 aiMsgs := m.ToAIMessage()
920 if len(aiMsgs) == 0 {
921 return fantasy.Message{}, false
922 }
923 var validParts []fantasy.MessagePart
924 for _, part := range aiMsgs[0].Content {
925 tr, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
926 if !ok {
927 validParts = append(validParts, part)
928 continue
929 }
930 if _, known := knownToolCallIDs[tr.ToolCallID]; known {
931 validParts = append(validParts, part)
932 } else {
933 slog.Warn(
934 "Dropping orphaned tool result with no matching tool call",
935 "tool_call_id", tr.ToolCallID,
936 )
937 }
938 }
939 if len(validParts) == 0 {
940 return fantasy.Message{}, false
941 }
942 msg := aiMsgs[0]
943 msg.Content = validParts
944 return msg, true
945}
946
947// syntheticToolResultsForOrphanedCalls returns a tool message containing
948// synthetic tool results for any tool calls in the assistant message that
949// have no matching result in knownToolResultIDs. LLM APIs require every
950// tool_use to be immediately followed by a tool_result; an interrupted
951// session can leave orphaned tool_use blocks that permanently lock the
952// conversation. Returns the message and true if any synthetic results were
953// produced.
954func syntheticToolResultsForOrphanedCalls(m message.Message, knownToolResultIDs map[string]struct{}) (fantasy.Message, bool) {
955 var syntheticParts []fantasy.MessagePart
956 for _, tc := range m.ToolCalls() {
957 if _, hasResult := knownToolResultIDs[tc.ID]; hasResult {
958 continue
959 }
960 slog.Warn(
961 "Injecting synthetic tool result for orphaned tool call",
962 "tool_call_id", tc.ID,
963 "tool_name", tc.Name,
964 )
965 syntheticParts = append(syntheticParts, fantasy.ToolResultPart{
966 ToolCallID: tc.ID,
967 Output: fantasy.ToolResultOutputContentError{
968 Error: errors.New("tool call was interrupted and did not produce a result, you may retry this call if the result is still needed"),
969 },
970 })
971 }
972 if len(syntheticParts) == 0 {
973 return fantasy.Message{}, false
974 }
975 return fantasy.Message{
976 Role: fantasy.MessageRoleTool,
977 Content: syntheticParts,
978 }, true
979}
980
981func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
982 msgs, err := a.messages.List(ctx, session.ID)
983 if err != nil {
984 return nil, fmt.Errorf("failed to list messages: %w", err)
985 }
986
987 if session.SummaryMessageID != "" {
988 summaryMsgIndex := -1
989 for i, msg := range msgs {
990 if msg.ID == session.SummaryMessageID {
991 summaryMsgIndex = i
992 break
993 }
994 }
995 if summaryMsgIndex != -1 {
996 msgs = msgs[summaryMsgIndex:]
997 msgs[0].Role = message.User
998 }
999 }
1000 return msgs, nil
1001}
1002
1003// generateTitle generates a session titled based on the initial prompt.
1004func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, userPrompt string) {
1005 if userPrompt == "" {
1006 return
1007 }
1008
1009 smallModel := a.smallModel.Get()
1010 largeModel := a.largeModel.Get()
1011 systemPromptPrefix := a.systemPromptPrefix.Get()
1012
1013 var maxOutputTokens int64 = 40
1014 if smallModel.CatwalkCfg.CanReason {
1015 maxOutputTokens = smallModel.CatwalkCfg.DefaultMaxTokens
1016 }
1017
1018 newAgent := func(m fantasy.LanguageModel, p []byte, tok int64) fantasy.Agent {
1019 return fantasy.NewAgent(
1020 m,
1021 fantasy.WithSystemPrompt(string(p)+"\n /no_think"),
1022 fantasy.WithMaxOutputTokens(tok),
1023 fantasy.WithUserAgent(userAgent),
1024 )
1025 }
1026
1027 streamCall := fantasy.AgentStreamCall{
1028 Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", userPrompt),
1029 PrepareStep: func(callCtx context.Context, opts fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
1030 prepared.Messages = opts.Messages
1031 if systemPromptPrefix != "" {
1032 prepared.Messages = append([]fantasy.Message{
1033 fantasy.NewSystemMessage(systemPromptPrefix),
1034 }, prepared.Messages...)
1035 }
1036 return callCtx, prepared, nil
1037 },
1038 }
1039
1040 // Use the small model to generate the title.
1041 model := smallModel
1042 agent := newAgent(model.Model, titlePrompt, maxOutputTokens)
1043 resp, err := agent.Stream(ctx, streamCall)
1044 if err == nil {
1045 // We successfully generated a title with the small model.
1046 slog.Debug("Generated title with small model")
1047 } else {
1048 // It didn't work. Let's try with the big model.
1049 slog.Error("Error generating title with small model; trying big model", "err", err)
1050 model = largeModel
1051 agent = newAgent(model.Model, titlePrompt, maxOutputTokens)
1052 resp, err = agent.Stream(ctx, streamCall)
1053 if err == nil {
1054 slog.Debug("Generated title with large model")
1055 } else {
1056 // Welp, the large model didn't work either. Use the default
1057 // session name and return.
1058 slog.Error("Error generating title with large model", "err", err)
1059 saveErr := a.sessions.Rename(ctx, sessionID, DefaultSessionName)
1060 if saveErr != nil {
1061 slog.Error("Failed to save session title", "error", saveErr)
1062 }
1063 return
1064 }
1065 }
1066
1067 if resp == nil {
1068 // Actually, we didn't get a response so we can't. Use the default
1069 // session name and return.
1070 slog.Error("Response is nil; can't generate title")
1071 saveErr := a.sessions.Rename(ctx, sessionID, DefaultSessionName)
1072 if saveErr != nil {
1073 slog.Error("Failed to save session title", "error", saveErr)
1074 }
1075 return
1076 }
1077
1078 // Clean up title.
1079 var title string
1080 title = strings.ReplaceAll(resp.Response.Content.Text(), "\n", " ")
1081
1082 // Remove thinking tags if present.
1083 title = thinkTagRegex.ReplaceAllString(title, "")
1084 title = orphanThinkTagRegex.ReplaceAllString(title, "")
1085
1086 title = strings.TrimSpace(title)
1087 title = cmp.Or(title, DefaultSessionName)
1088
1089 // Calculate usage and cost.
1090 var openrouterCost *float64
1091 for _, step := range resp.Steps {
1092 stepCost := a.openrouterCost(step.ProviderMetadata)
1093 if stepCost != nil {
1094 newCost := *stepCost
1095 if openrouterCost != nil {
1096 newCost += *openrouterCost
1097 }
1098 openrouterCost = &newCost
1099 }
1100 }
1101
1102 modelConfig := model.CatwalkCfg
1103 cost := modelConfig.CostPer1MInCached/1e6*float64(resp.TotalUsage.CacheCreationTokens) +
1104 modelConfig.CostPer1MOutCached/1e6*float64(resp.TotalUsage.CacheReadTokens) +
1105 modelConfig.CostPer1MIn/1e6*float64(resp.TotalUsage.InputTokens) +
1106 modelConfig.CostPer1MOut/1e6*float64(resp.TotalUsage.OutputTokens)
1107
1108 // Use override cost if available (e.g., from OpenRouter).
1109 if openrouterCost != nil {
1110 cost = *openrouterCost
1111 }
1112
1113 // Skip cost accumulation
1114 if model.FlatRate {
1115 cost = 0
1116 }
1117
1118 promptTokens := resp.TotalUsage.InputTokens + resp.TotalUsage.CacheCreationTokens
1119 completionTokens := resp.TotalUsage.OutputTokens
1120
1121 // Atomically update only title and usage fields to avoid overriding other
1122 // concurrent session updates.
1123 saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, title, promptTokens, completionTokens, cost)
1124 if saveErr != nil {
1125 slog.Error("Failed to save session title and usage", "error", saveErr)
1126 return
1127 }
1128}
1129
1130func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
1131 openrouterMetadata, ok := metadata[openrouter.Name]
1132 if !ok {
1133 return nil
1134 }
1135
1136 opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
1137 if !ok {
1138 return nil
1139 }
1140 return &opts.Usage.Cost
1141}
1142
1143func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64, estimated bool) {
1144 modelConfig := model.CatwalkCfg
1145 cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
1146 modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
1147 modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
1148 modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
1149
1150 if !estimated {
1151 a.eventTokensUsed(session.ID, model, usage, cost)
1152 }
1153
1154 if estimated {
1155 cost = 0
1156 } else {
1157 // Use override cost if available (e.g., from OpenRouter).
1158 if overrideCost != nil {
1159 cost = *overrideCost
1160 }
1161
1162 // Skip cost accumulation
1163 if model.FlatRate {
1164 cost = 0
1165 }
1166 }
1167
1168 session.Cost += cost
1169 updateSessionTokenCounters(session, usage)
1170}
1171
1172func updateSessionTokenCounters(session *session.Session, usage fantasy.Usage) {
1173 if usage.OutputTokens != 0 {
1174 session.CompletionTokens = usage.OutputTokens
1175 }
1176 if promptTokens := usage.InputTokens + usage.CacheReadTokens; promptTokens != 0 {
1177 session.PromptTokens = promptTokens
1178 }
1179}
1180
1181func summaryCompletionTokens(usage fantasy.Usage, summaryMessage message.Message) int64 {
1182 if usage.OutputTokens != 0 {
1183 return usage.OutputTokens
1184 }
1185 return approxTokenCount(summaryMessage.Content().Text) + approxTokenCount(summaryMessage.ReasoningContent().String())
1186}
1187
1188func (a *sessionAgent) Cancel(sessionID string) {
1189 // Cancel regular requests. Don't use Take() here - we need the entry to
1190 // remain in activeRequests so IsBusy() returns true until the goroutine
1191 // fully completes (including error handling that may access the DB).
1192 // The defer in processRequest will clean up the entry.
1193 if cancel, ok := a.activeRequests.Get(sessionID); ok && cancel != nil {
1194 slog.Debug("Request cancellation initiated", "session_id", sessionID)
1195 cancel()
1196 }
1197
1198 // Also check for summarize requests.
1199 if cancel, ok := a.activeRequests.Get(sessionID + "-summarize"); ok && cancel != nil {
1200 slog.Debug("Summarize cancellation initiated", "session_id", sessionID)
1201 cancel()
1202 }
1203
1204 if a.QueuedPrompts(sessionID) > 0 {
1205 slog.Debug("Clearing queued prompts", "session_id", sessionID)
1206 a.messageQueue.Del(sessionID)
1207 }
1208}
1209
1210func (a *sessionAgent) ClearQueue(sessionID string) {
1211 if a.QueuedPrompts(sessionID) > 0 {
1212 slog.Debug("Clearing queued prompts", "session_id", sessionID)
1213 a.messageQueue.Del(sessionID)
1214 }
1215}
1216
1217func (a *sessionAgent) CancelAll() {
1218 if !a.IsBusy() {
1219 return
1220 }
1221 for key := range a.activeRequests.Seq2() {
1222 a.Cancel(key) // key is sessionID
1223 }
1224
1225 timeout := time.After(5 * time.Second)
1226 for a.IsBusy() {
1227 select {
1228 case <-timeout:
1229 return
1230 default:
1231 time.Sleep(200 * time.Millisecond)
1232 }
1233 }
1234}
1235
1236func (a *sessionAgent) IsBusy() bool {
1237 var busy bool
1238 for cancelFunc := range a.activeRequests.Seq() {
1239 if cancelFunc != nil {
1240 busy = true
1241 break
1242 }
1243 }
1244 return busy
1245}
1246
1247func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
1248 _, busy := a.activeRequests.Get(sessionID)
1249 return busy
1250}
1251
1252func (a *sessionAgent) QueuedPrompts(sessionID string) int {
1253 l, ok := a.messageQueue.Get(sessionID)
1254 if !ok {
1255 return 0
1256 }
1257 return len(l)
1258}
1259
1260func (a *sessionAgent) QueuedPromptsList(sessionID string) []string {
1261 l, ok := a.messageQueue.Get(sessionID)
1262 if !ok {
1263 return nil
1264 }
1265 prompts := make([]string, len(l))
1266 for i, call := range l {
1267 prompts[i] = call.Prompt
1268 }
1269 return prompts
1270}
1271
1272func (a *sessionAgent) SetModels(large Model, small Model) {
1273 a.largeModel.Set(large)
1274 a.smallModel.Set(small)
1275}
1276
1277func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
1278 a.tools.SetSlice(tools)
1279}
1280
1281func (a *sessionAgent) SetSystemPrompt(systemPrompt string) {
1282 a.systemPrompt.Set(systemPrompt)
1283}
1284
1285func (a *sessionAgent) Model() Model {
1286 return a.largeModel.Get()
1287}
1288
1289// convertToToolResult converts a fantasy tool result to a message tool result.
1290func (a *sessionAgent) convertToToolResult(result fantasy.ToolResultContent) message.ToolResult {
1291 baseResult := message.ToolResult{
1292 ToolCallID: result.ToolCallID,
1293 Name: result.ToolName,
1294 Metadata: result.ClientMetadata,
1295 }
1296
1297 switch result.Result.GetType() {
1298 case fantasy.ToolResultContentTypeText:
1299 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result); ok {
1300 baseResult.Content = r.Text
1301 }
1302 case fantasy.ToolResultContentTypeError:
1303 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result); ok {
1304 baseResult.Content = r.Error.Error()
1305 baseResult.IsError = true
1306 }
1307 case fantasy.ToolResultContentTypeMedia:
1308 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](result.Result); ok {
1309 if !stringext.IsValidBase64(r.Data) {
1310 slog.Warn(
1311 "Tool returned media with invalid base64 data, discarding image",
1312 "tool", result.ToolName,
1313 "tool_call_id", result.ToolCallID,
1314 )
1315 baseResult.Content = "Tool returned image data with invalid encoding"
1316 baseResult.IsError = true
1317 } else {
1318 content := r.Text
1319 if content == "" {
1320 content = fmt.Sprintf("Loaded %s content", r.MediaType)
1321 }
1322 baseResult.Content = content
1323 baseResult.Data = r.Data
1324 baseResult.MIMEType = r.MediaType
1325 }
1326 }
1327 }
1328
1329 return baseResult
1330}
1331
1332// workaroundProviderMediaLimitations converts media content in tool results to
1333// user messages for providers that don't natively support images in tool results.
1334//
1335// Problem: OpenAI, Google, OpenRouter, and other OpenAI-compatible providers
1336// don't support sending images/media in tool result messages - they only accept
1337// text in tool results. However, they DO support images in user messages.
1338//
1339// If we send media in tool results to these providers, the API returns an error.
1340//
1341// Solution: For these providers, we:
1342// 1. Replace the media in the tool result with a text placeholder
1343// 2. Inject a user message immediately after with the image as a file attachment
1344// 3. This maintains the tool execution flow while working around API limitations
1345//
1346// Anthropic and Bedrock support images natively in tool results, so we skip
1347// this workaround for them.
1348//
1349// Example transformation:
1350//
1351// BEFORE: [tool result: image data]
1352// AFTER: [tool result: "Image loaded - see attached"], [user: image attachment]
1353func (a *sessionAgent) workaroundProviderMediaLimitations(messages []fantasy.Message, largeModel Model) []fantasy.Message {
1354 providerSupportsMedia := largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderAnthropic) ||
1355 largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderBedrock)
1356
1357 if providerSupportsMedia {
1358 return messages
1359 }
1360
1361 convertedMessages := make([]fantasy.Message, 0, len(messages))
1362
1363 for _, msg := range messages {
1364 if msg.Role != fantasy.MessageRoleTool {
1365 convertedMessages = append(convertedMessages, msg)
1366 continue
1367 }
1368
1369 textParts := make([]fantasy.MessagePart, 0, len(msg.Content))
1370 var mediaFiles []fantasy.FilePart
1371
1372 for _, part := range msg.Content {
1373 toolResult, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
1374 if !ok {
1375 textParts = append(textParts, part)
1376 continue
1377 }
1378
1379 if media, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](toolResult.Output); ok {
1380 decoded, err := base64.StdEncoding.DecodeString(media.Data)
1381 if err != nil {
1382 slog.Warn("Failed to decode media data", "error", err)
1383 textParts = append(textParts, part)
1384 continue
1385 }
1386
1387 mediaFiles = append(mediaFiles, fantasy.FilePart{
1388 Data: decoded,
1389 MediaType: media.MediaType,
1390 Filename: fmt.Sprintf("tool-result-%s", toolResult.ToolCallID),
1391 })
1392
1393 textParts = append(textParts, fantasy.ToolResultPart{
1394 ToolCallID: toolResult.ToolCallID,
1395 Output: fantasy.ToolResultOutputContentText{
1396 Text: "[Image/media content loaded - see attached file]",
1397 },
1398 ProviderOptions: toolResult.ProviderOptions,
1399 })
1400 } else {
1401 textParts = append(textParts, part)
1402 }
1403 }
1404
1405 convertedMessages = append(convertedMessages, fantasy.Message{
1406 Role: fantasy.MessageRoleTool,
1407 Content: textParts,
1408 })
1409
1410 if len(mediaFiles) > 0 {
1411 convertedMessages = append(convertedMessages, fantasy.NewUserMessage(
1412 "Here is the media content from the tool result:",
1413 mediaFiles...,
1414 ))
1415 }
1416 }
1417
1418 return convertedMessages
1419}
1420
1421// buildSummaryPrompt constructs the prompt text for session summarization.
1422func buildSummaryPrompt(todos []session.Todo) string {
1423 var sb strings.Builder
1424 sb.WriteString("Provide a detailed summary of our conversation above.")
1425 if len(todos) > 0 {
1426 sb.WriteString("\n\n## Current Todo List\n\n")
1427 for _, t := range todos {
1428 fmt.Fprintf(&sb, "- [%s] %s\n", t.Status, t.Content)
1429 }
1430 sb.WriteString("\nInclude these tasks and their statuses in your summary. ")
1431 sb.WriteString("Instruct the resuming assistant to use the `todos` tool to continue tracking progress on these tasks.")
1432 }
1433 return sb.String()
1434}
1435
1436func providerRetryLogFields(err *fantasy.ProviderError, delay time.Duration) []any {
1437 fields := []any{
1438 "retry_delay", delay.String(),
1439 }
1440 if err == nil {
1441 return fields
1442 }
1443 fields = append(fields, "status_code", err.StatusCode)
1444 if err.Title != "" {
1445 fields = append(fields, "title", err.Title)
1446 }
1447 if err.Message != "" {
1448 fields = append(fields, "message", err.Message)
1449 }
1450 return fields
1451}