1// Package agent is the core orchestration layer for Crush AI agents.
2//
3// It provides session-based AI agent functionality for managing
4// conversations, tool execution, and message handling. It coordinates
5// interactions between language models, messages, sessions, and tools while
6// handling features like automatic summarization, queuing, and token
7// management.
8package agent
9
10import (
11 "cmp"
12 "context"
13 _ "embed"
14 "encoding/base64"
15 "errors"
16 "fmt"
17 "log/slog"
18 "net/http"
19 "os"
20 "regexp"
21 "strconv"
22 "strings"
23 "sync"
24 "time"
25
26 "charm.land/catwalk/pkg/catwalk"
27 "charm.land/fantasy"
28 "charm.land/fantasy/providers/anthropic"
29 "charm.land/fantasy/providers/bedrock"
30 "charm.land/fantasy/providers/google"
31 "charm.land/fantasy/providers/openai"
32 "charm.land/fantasy/providers/openrouter"
33 "charm.land/fantasy/providers/vercel"
34 "charm.land/lipgloss/v2"
35 "github.com/charmbracelet/crush/internal/agent/hyper"
36 "github.com/charmbracelet/crush/internal/agent/notify"
37 "github.com/charmbracelet/crush/internal/agent/tools"
38 "github.com/charmbracelet/crush/internal/agent/tools/mcp"
39 "github.com/charmbracelet/crush/internal/config"
40 "github.com/charmbracelet/crush/internal/csync"
41 "github.com/charmbracelet/crush/internal/message"
42 "github.com/charmbracelet/crush/internal/pubsub"
43 "github.com/charmbracelet/crush/internal/session"
44 "github.com/charmbracelet/crush/internal/stringext"
45 "github.com/charmbracelet/crush/internal/version"
46 "github.com/charmbracelet/x/exp/charmtone"
47)
48
const (
	// DefaultSessionName is the default display name for a session.
	// NOTE(review): presumably applied to sessions that have not been given
	// a generated title yet; the usage is outside this part of the file.
	DefaultSessionName = "Untitled Session"

	// Constants for auto-summarization thresholds.
	//
	// Models whose context window is larger than largeContextWindowThreshold
	// tokens are summarized once largeContextWindowBuffer tokens or fewer
	// remain. All other models are summarized once the remaining tokens drop
	// to smallContextWindowRatio of the context window or below.
	largeContextWindowThreshold = 200_000
	largeContextWindowBuffer    = 20_000
	smallContextWindowRatio     = 0.2
)
57
// userAgent is the User-Agent string handed to every fantasy agent created
// in this file. It embeds the current Crush version.
var userAgent = fmt.Sprintf("Charm-Crush/%s (https://charm.land/crush)", version.Version)

// titlePrompt holds the embedded contents of templates/title.md.
// NOTE(review): presumably the system prompt for title generation; the
// consumer is outside this part of the file.
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt holds the embedded contents of templates/summary.md. It is
// used as the system prompt when summarizing a session.
//
//go:embed templates/summary.md
var summaryPrompt []byte
65
// Used to remove <think> tags from generated titles.
var (
	// thinkTagRegex matches a complete <think>...</think> section. The (?s)
	// flag lets the section span multiple lines and the non-greedy body stops
	// at the first closing tag.
	thinkTagRegex = regexp.MustCompile(`(?s)<think>.*?</think>`)
	// orphanThinkTagRegex matches a single opening or closing think tag on
	// its own.
	orphanThinkTagRegex = regexp.MustCompile(`</?think>`)
)
71
// SessionAgentCall describes a single prompt to run against a session.
type SessionAgentCall struct {
	// SessionID identifies the target session. Run rejects an empty ID with
	// ErrSessionMissing.
	SessionID string
	// Prompt is the user's text. It may only be empty when Attachments
	// contains a text attachment; otherwise Run returns ErrEmptyPrompt.
	Prompt string
	// ProviderOptions is forwarded to the model call (and to the automatic
	// summarization triggered from Run).
	ProviderOptions fantasy.ProviderOptions
	// Attachments are stored with the user message; non-text attachments are
	// also sent to the model as files.
	Attachments []message.Attachment
	// MaxOutputTokens is only sent to the provider when greater than zero.
	MaxOutputTokens int64
	// Sampling parameters, forwarded to the model call as-is.
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
	// NonInteractive suppresses the "agent finished" notification that Run
	// otherwise publishes at the end of a turn.
	NonInteractive bool
}
85
// SessionAgent is the session-scoped agent API exposed by this package.
//
// NOTE(review): only Run and Summarize are implemented in the visible part
// of the file; the remaining method descriptions are limited to what their
// callers here demonstrate.
type SessionAgent interface {
	// Run executes the call against its session, or queues it and returns
	// (nil, nil) when the session is already busy.
	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
	SetModels(large Model, small Model)
	SetTools(tools []fantasy.AgentTool)
	SetSystemPrompt(systemPrompt string)
	Cancel(sessionID string)
	CancelAll()
	// IsSessionBusy reports whether the session is busy. Run queues calls
	// and Summarize returns ErrSessionBusy while it reports true.
	IsSessionBusy(sessionID string) bool
	IsBusy() bool
	QueuedPrompts(sessionID string) int
	QueuedPromptsList(sessionID string) []string
	ClearQueue(sessionID string)
	// Summarize replaces the session's history with a model-written summary.
	// The string argument is the session ID.
	Summarize(context.Context, string, fantasy.ProviderOptions) error
	Model() Model
}
101
// Model bundles a language model with the metadata the agent needs to use
// it.
type Model struct {
	// Model is the language model handed to fantasy.NewAgent.
	Model fantasy.LanguageModel
	// CatwalkCfg carries model metadata; this file reads its ContextWindow,
	// SupportsImages and Name fields.
	CatwalkCfg catwalk.Model
	// ModelCfg is the selected-model configuration; its Model and Provider
	// values are recorded on assistant messages.
	ModelCfg config.SelectedModel
	// FlatRate is not read in the visible part of the file.
	// NOTE(review): presumably marks flat-rate billing for cost tracking —
	// confirm against updateSessionUsage.
	FlatRate bool
}
108
// sessionAgent is the SessionAgent implementation returned by
// NewSessionAgent.
type sessionAgent struct {
	// Settings that can be replaced while the agent is in use; Run and
	// Summarize take a copy of them before starting work.
	largeModel         *csync.Value[Model]
	smallModel         *csync.Value[Model]
	systemPromptPrefix *csync.Value[string]
	systemPrompt       *csync.Value[string]
	tools              *csync.Slice[fantasy.AgentTool]

	// isSubAgent disables the todo-list reminder that preparePrompt
	// otherwise prepends to the history.
	isSubAgent bool
	sessions   session.Service
	messages   message.Service
	// disableAutoSummarize turns off the context-window stop condition in
	// Run.
	disableAutoSummarize bool
	isYolo               bool
	// notify may be nil; every publish in this file is guarded by a nil
	// check.
	notify pubsub.Publisher[notify.Notification]

	// messageQueue holds, per session ID, the calls that arrived while the
	// session was busy.
	messageQueue *csync.Map[string, []SessionAgentCall]
	// activeRequests maps a session ID to the cancel function of its
	// in-flight request.
	activeRequests *csync.Map[string, context.CancelFunc]
}
126
// SessionAgentOptions holds the values NewSessionAgent uses to build a
// SessionAgent. Each field initializes the sessionAgent field of the same
// name.
type SessionAgentOptions struct {
	LargeModel           Model
	SmallModel           Model
	SystemPromptPrefix   string
	SystemPrompt         string
	IsSubAgent           bool
	DisableAutoSummarize bool
	IsYolo               bool
	Sessions             session.Service
	Messages             message.Service
	Tools                []fantasy.AgentTool
	// Notify is optional; when nil, no notifications are published.
	Notify pubsub.Publisher[notify.Notification]
}
140
141func NewSessionAgent(
142 opts SessionAgentOptions,
143) SessionAgent {
144 return &sessionAgent{
145 largeModel: csync.NewValue(opts.LargeModel),
146 smallModel: csync.NewValue(opts.SmallModel),
147 systemPromptPrefix: csync.NewValue(opts.SystemPromptPrefix),
148 systemPrompt: csync.NewValue(opts.SystemPrompt),
149 isSubAgent: opts.IsSubAgent,
150 sessions: opts.Sessions,
151 messages: opts.Messages,
152 disableAutoSummarize: opts.DisableAutoSummarize,
153 tools: csync.NewSliceFrom(opts.Tools),
154 isYolo: opts.IsYolo,
155 notify: opts.Notify,
156 messageQueue: csync.NewMap[string, []SessionAgentCall](),
157 activeRequests: csync.NewMap[string, context.CancelFunc](),
158 }
159}
160
// Run executes one agent turn for call.SessionID.
//
// It returns ErrEmptyPrompt when the call has neither a prompt nor a text
// attachment, and ErrSessionMissing when the session ID is empty. If the
// session is busy, the call is appended to the session's queue and Run
// returns (nil, nil).
//
// Otherwise Run stores the user message, streams the large model's response
// and persists assistant messages, tool results and session usage as the
// stream progresses. Calls queued while the turn is running are folded into
// the next step as additional user messages.
//
// On a stream error the current assistant message is finalized with a
// matching finish reason, unfinished tool calls are closed, tool calls
// without a stored result get a synthetic error result, and the error is
// returned.
//
// When the context-window stop condition fired, the session is summarized
// and, if the assistant still had tool calls, the original request is
// re-queued. Finally, if the queue is non-empty, Run calls itself with the
// first queued call and returns that result.
func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
	if call.Prompt == "" && !message.ContainsTextAttachment(call.Attachments) {
		return nil, ErrEmptyPrompt
	}
	if call.SessionID == "" {
		return nil, ErrSessionMissing
	}

	// Queue the message if busy
	if a.IsSessionBusy(call.SessionID) {
		existing, ok := a.messageQueue.Get(call.SessionID)
		if !ok {
			existing = []SessionAgentCall{}
		}
		existing = append(existing, call)
		a.messageQueue.Set(call.SessionID, existing)
		return nil, nil
	}

	// Copy mutable fields under lock to avoid races with SetTools/SetModels.
	agentTools := a.tools.Copy()
	largeModel := a.largeModel.Get()
	systemPrompt := a.systemPrompt.Get()
	promptPrefix := a.systemPromptPrefix.Get()
	var instructions strings.Builder

	// Collect the instructions advertised by every connected MCP server.
	for _, server := range mcp.GetStates() {
		if server.State != mcp.StateConnected {
			continue
		}
		if s := server.Client.InitializeResult().Instructions; s != "" {
			instructions.WriteString(s)
			instructions.WriteString("\n\n")
		}
	}

	if s := instructions.String(); s != "" {
		systemPrompt += "\n\n<mcp-instructions>\n" + s + "\n</mcp-instructions>"
	}

	if len(agentTools) > 0 {
		// Add Anthropic caching to the last tool.
		agentTools[len(agentTools)-1].SetProviderOptions(a.getCacheControlOptions())
	}

	agent := fantasy.NewAgent(
		largeModel.Model,
		fantasy.WithSystemPrompt(systemPrompt),
		fantasy.WithTools(agentTools...),
		fantasy.WithUserAgent(userAgent),
	)

	// sessionLock guards stepMessages and the session update performed in
	// OnStepFinish.
	sessionLock := sync.Mutex{}
	currentSession, err := a.sessions.Get(ctx, call.SessionID)
	if err != nil {
		return nil, fmt.Errorf("failed to get session: %w", err)
	}

	msgs, err := a.getSessionMessages(ctx, currentSession)
	if err != nil {
		return nil, fmt.Errorf("failed to get session messages: %w", err)
	}

	var wg sync.WaitGroup
	// Generate title if first message.
	if len(msgs) == 0 {
		titleCtx := ctx // Copy to avoid race with ctx reassignment below.
		wg.Go(func() {
			a.generateTitle(titleCtx, call.SessionID, call.Prompt)
		})
	}
	// Run does not return before title generation has finished.
	defer wg.Wait()

	// Add the user message to the session.
	_, err = a.createUserMessage(ctx, call)
	if err != nil {
		return nil, err
	}

	// Add the session to the context.
	ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)

	// genCtx is the cancellable context for the model stream; registering
	// its cancel function in activeRequests is what lets the request be
	// cancelled from outside Run.
	genCtx, cancel := context.WithCancel(ctx)
	a.activeRequests.Set(call.SessionID, cancel)

	defer cancel()
	defer a.activeRequests.Del(call.SessionID)
	// Drain any debounced message updates before returning. message.Service
	// already flushes synchronously on terminal updates, but a defer here
	// guarantees the contract at every Run exit (success, error, panic
	// recovery upstream) without callers needing to know.
	defer func() {
		if flushErr := a.messages.FlushAll(ctx); flushErr != nil {
			slog.Error("Failed to flush pending message updates after run", "error", flushErr)
		}
	}()

	history, files := a.preparePrompt(msgs, largeModel.CatwalkCfg.SupportsImages, call.Attachments...)

	startTime := time.Now()
	a.eventPromptSent(call.SessionID)

	// State shared between the stream callbacks below.
	var currentAssistant *message.Message
	var stepMessages []fantasy.Message
	var shouldSummarize bool
	// Don't send MaxOutputTokens if 0 — some providers (e.g. LM Studio) reject it
	var maxOutputTokens *int64
	if call.MaxOutputTokens > 0 {
		maxOutputTokens = &call.MaxOutputTokens
	}
	result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
		Prompt:           message.PromptWithTextAttachments(call.Prompt, call.Attachments),
		Files:            files,
		Messages:         history,
		ProviderOptions:  call.ProviderOptions,
		MaxOutputTokens:  maxOutputTokens,
		TopP:             call.TopP,
		Temperature:      call.Temperature,
		PresencePenalty:  call.PresencePenalty,
		TopK:             call.TopK,
		FrequencyPenalty: call.FrequencyPenalty,
		// PrepareStep builds the message list for the upcoming step and
		// creates the assistant message that the streaming callbacks fill in.
		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
			prepared.Messages = options.Messages
			// Clear per-message provider options; cache control is
			// re-applied to selected messages below.
			for i := range prepared.Messages {
				prepared.Messages[i].ProviderOptions = nil
			}

			// Use latest tools (updated by SetTools when MCP tools change).
			prepared.Tools = a.tools.Copy()

			// Fold prompts that were queued while this turn was running into
			// the conversation as additional user messages.
			queuedCalls, _ := a.messageQueue.Get(call.SessionID)
			a.messageQueue.Del(call.SessionID)
			for _, queued := range queuedCalls {
				userMessage, createErr := a.createUserMessage(callContext, queued)
				if createErr != nil {
					return callContext, prepared, createErr
				}
				prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
			}

			prepared.Messages = a.workaroundProviderMediaLimitations(prepared.Messages, largeModel)

			lastSystemRoleInx := 0
			systemMessageUpdated := false
			for i, msg := range prepared.Messages {
				// Add cache control to the last system message seen before
				// the first non-system message (index 0 if there was none).
				if msg.Role == fantasy.MessageRoleSystem {
					lastSystemRoleInx = i
				} else if !systemMessageUpdated {
					prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
					systemMessageUpdated = true
				}
				// Then add cache control to the last 2 messages.
				if i > len(prepared.Messages)-3 {
					prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
				}
			}

			if promptPrefix != "" {
				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(promptPrefix)}, prepared.Messages...)
			}

			// Keep a copy of what was sent; OnStepFinish passes it to
			// fallbackStepUsage.
			sessionLock.Lock()
			stepMessages = cloneFantasyMessages(prepared.Messages)
			sessionLock.Unlock()

			var assistantMsg message.Message
			assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
				Role:     message.Assistant,
				Parts:    []message.ContentPart{},
				Model:    largeModel.ModelCfg.Model,
				Provider: largeModel.ModelCfg.Provider,
			})
			if err != nil {
				return callContext, prepared, err
			}
			// Expose the assistant message ID and model details to tools
			// through the step context.
			callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
			callContext = context.WithValue(callContext, tools.SupportsImagesContextKey, largeModel.CatwalkCfg.SupportsImages)
			callContext = context.WithValue(callContext, tools.ModelNameContextKey, largeModel.CatwalkCfg.Name)
			currentAssistant = &assistantMsg
			return callContext, prepared, err
		},
		OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
			currentAssistant.AppendReasoningContent(reasoning.Text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnReasoningDelta: func(id string, text string) error {
			currentAssistant.AppendReasoningContent(text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
			// handle anthropic signature
			if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
				if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
					currentAssistant.AppendReasoningSignature(reasoning.Signature)
				}
			}
			// Handle the Google thought signature.
			if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
				if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
					currentAssistant.AppendThoughtSignature(reasoning.Signature, reasoning.ToolID)
				}
			}
			// Handle OpenAI Responses reasoning metadata.
			if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
				if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
					currentAssistant.SetReasoningResponsesData(reasoning)
				}
			}
			currentAssistant.FinishThinking()
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnTextDelta: func(id string, text string) error {
			// Strip leading newline from initial text content. This is
			// particularly important in non-interactive mode where leading
			// newlines are very visible.
			if len(currentAssistant.Parts) == 0 {
				text = strings.TrimPrefix(text, "\n")
			}

			currentAssistant.AppendContent(text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnToolInputStart: func(id string, toolName string) error {
			toolCall := message.ToolCall{
				ID:               id,
				Name:             toolName,
				ProviderExecuted: false,
				Finished:         false,
			}
			currentAssistant.AddToolCall(toolCall)
			// Use parent ctx instead of genCtx to ensure the update succeeds
			// even if the request is canceled mid-stream
			return a.messages.Update(ctx, *currentAssistant)
		},
		OnRetry: func(err *fantasy.ProviderError, delay time.Duration) {
			slog.Warn("Provider request failed, retrying", providerRetryLogFields(err, delay)...)
		},
		OnToolCall: func(tc fantasy.ToolCallContent) error {
			toolCall := message.ToolCall{
				ID:               tc.ToolCallID,
				Name:             tc.ToolName,
				Input:            tc.Input,
				ProviderExecuted: false,
				Finished:         true,
			}
			currentAssistant.AddToolCall(toolCall)
			// Use parent ctx instead of genCtx to ensure the update succeeds
			// even if the request is canceled mid-stream
			return a.messages.Update(ctx, *currentAssistant)
		},
		OnToolResult: func(result fantasy.ToolResultContent) error {
			toolResult := a.convertToToolResult(result)
			// Use parent ctx instead of genCtx to ensure the message is created
			// even if the request is canceled mid-stream
			_, createMsgErr := a.messages.Create(ctx, currentAssistant.SessionID, message.CreateMessageParams{
				Role: message.Tool,
				Parts: []message.ContentPart{
					toolResult,
				},
			})
			return createMsgErr
		},
		OnStepFinish: func(stepResult fantasy.StepResult) error {
			// Map the provider finish reason onto the message finish reason;
			// anything unmapped stays FinishReasonUnknown.
			finishReason := message.FinishReasonUnknown
			switch stepResult.FinishReason {
			case fantasy.FinishReasonLength:
				finishReason = message.FinishReasonMaxTokens
			case fantasy.FinishReasonStop:
				finishReason = message.FinishReasonEndTurn
			case fantasy.FinishReasonToolCalls:
				finishReason = message.FinishReasonToolUse
			}
			// If a tool result halted the turn (e.g. a hook halt or a
			// permission denial), the step ends on FinishReasonToolCalls but
			// the model will not be called again. Treat it as the end of the
			// turn so the UI can render the assistant footer.
			if finishReason == message.FinishReasonToolUse {
				for _, tr := range stepResult.Content.ToolResults() {
					if tr.StopTurn {
						finishReason = message.FinishReasonEndTurn
						break
					}
				}
			}
			currentAssistant.AddFinish(finishReason, "", "")
			sessionLock.Lock()
			defer sessionLock.Unlock()

			// Re-read the session before updating usage so the saved record
			// is based on the stored state rather than the copy loaded at
			// the start of Run.
			updatedSession, getSessionErr := a.sessions.Get(ctx, call.SessionID)
			if getSessionErr != nil {
				return getSessionErr
			}
			usage, estimated := fallbackStepUsage(stepMessages, stepResult)
			a.updateSessionUsage(largeModel, &updatedSession, usage, a.openrouterCost(stepResult.ProviderMetadata), estimated)
			_, sessionErr := a.sessions.Save(ctx, updatedSession)
			if sessionErr != nil {
				return sessionErr
			}
			currentSession = updatedSession
			return a.messages.Update(genCtx, *currentAssistant)
		},
		StopWhen: []fantasy.StopCondition{
			// Stop when the remaining context window falls to the
			// auto-summarize threshold or below.
			func(_ []fantasy.StepResult) bool {
				cw := int64(largeModel.CatwalkCfg.ContextWindow)
				// If context window is unknown (0), skip auto-summarize
				// to avoid immediately truncating custom/local models.
				if cw == 0 {
					return false
				}
				tokens := currentSession.CompletionTokens + currentSession.PromptTokens
				remaining := cw - tokens
				var threshold int64
				if cw > largeContextWindowThreshold {
					threshold = largeContextWindowBuffer
				} else {
					threshold = int64(float64(cw) * smallContextWindowRatio)
				}
				if (remaining <= threshold) && !a.disableAutoSummarize {
					shouldSummarize = true
					return true
				}
				return false
			},
			// Stop when the loop detector reports repeated tool calls.
			func(steps []fantasy.StepResult) bool {
				return hasRepeatedToolCalls(steps, loopDetectionWindowSize, loopDetectionMaxRepeats)
			},
		},
	})

	a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))

	if err != nil {
		isHyper := largeModel.ModelCfg.Provider == hyper.Name
		isCancelErr := errors.Is(err, context.Canceled)
		// No assistant message was created, so there is nothing to finalize.
		if currentAssistant == nil {
			return result, err
		}
		// Ensure we finish thinking on error to close the reasoning state.
		currentAssistant.FinishThinking()
		toolCalls := currentAssistant.ToolCalls()
		// INFO: we use the parent context here because the genCtx has been cancelled.
		msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
		if createErr != nil {
			return nil, createErr
		}
		for _, tc := range toolCalls {
			// Close tool calls whose input never finished streaming.
			if !tc.Finished {
				tc.Finished = true
				tc.Input = "{}"
				currentAssistant.AddToolCall(tc)
				updateErr := a.messages.Update(ctx, *currentAssistant)
				if updateErr != nil {
					return nil, updateErr
				}
			}

			// Look for an already stored result for this tool call.
			found := false
			for _, msg := range msgs {
				if msg.Role == message.Tool {
					for _, tr := range msg.ToolResults() {
						if tr.ToolCallID == tc.ID {
							found = true
							break
						}
					}
				}
				if found {
					break
				}
			}
			if found {
				continue
			}
			// No result was stored: record a synthetic error result for the
			// tool call.
			content := "There was an error while executing the tool"
			if isCancelErr {
				content = "Error: user cancelled assistant tool calling"
			}
			toolResult := message.ToolResult{
				ToolCallID: tc.ID,
				Name:       tc.Name,
				Content:    content,
				IsError:    true,
			}
			_, createErr = a.messages.Create(ctx, currentAssistant.SessionID, message.CreateMessageParams{
				Role: message.Tool,
				Parts: []message.ContentPart{
					toolResult,
				},
			})
			if createErr != nil {
				return nil, createErr
			}
		}
		var fantasyErr *fantasy.Error
		var providerErr *fantasy.ProviderError
		const defaultTitle = "Provider Error"
		linkStyle := lipgloss.NewStyle().Foreground(charmtone.Guac).Underline(true)
		// Pick the finish reason, title and details shown for the error,
		// from most to least specific.
		if isCancelErr {
			currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
		} else if isHyper && errors.As(err, &providerErr) && providerErr.StatusCode == http.StatusUnauthorized {
			currentAssistant.AddFinish(message.FinishReasonError, "Unauthorized", `Please re-authenticate with Hyper. You can also run "crush auth" to re-authenticate.`)
			if a.notify != nil {
				a.notify.Publish(pubsub.CreatedEvent, notify.Notification{
					SessionID:    call.SessionID,
					SessionTitle: currentSession.Title,
					Type:         notify.TypeReAuthenticate,
					ProviderID:   largeModel.ModelCfg.Provider,
				})
			}
		} else if isHyper && errors.As(err, &providerErr) && providerErr.StatusCode == http.StatusPaymentRequired {
			url := hyper.BaseURL()
			link := linkStyle.Hyperlink(url, "id=hyper").Render(url)
			currentAssistant.AddFinish(message.FinishReasonError, "No credits", "You're out of credits. Add more at "+link)
		} else if errors.As(err, &providerErr) {
			if providerErr.Message == "The requested model is not supported." {
				url := "https://github.com/settings/copilot/features"
				link := linkStyle.Hyperlink(url, "id=copilot").Render(url)
				currentAssistant.AddFinish(
					message.FinishReasonError,
					"Copilot model not enabled",
					fmt.Sprintf("%q is not enabled in Copilot. Go to the following page to enable it. Then, wait 5 minutes before trying again. %s", largeModel.CatwalkCfg.Name, link),
				)
			} else {
				currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message)
			}
		} else if errors.As(err, &fantasyErr) {
			currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(fantasyErr.Title), defaultTitle), fantasyErr.Message)
		} else {
			currentAssistant.AddFinish(message.FinishReasonError, defaultTitle, err.Error())
		}
		// Note: we use the parent context here because the genCtx has been
		// cancelled.
		updateErr := a.messages.Update(ctx, *currentAssistant)
		if updateErr != nil {
			return nil, updateErr
		}
		return nil, err
	}

	if shouldSummarize {
		// Summarize returns ErrSessionBusy for a busy session, so release
		// the active request first.
		a.activeRequests.Del(call.SessionID)
		if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
			return nil, summarizeErr
		}
		// If the agent wasn't done...
		if len(currentAssistant.ToolCalls()) > 0 {
			existing, ok := a.messageQueue.Get(call.SessionID)
			if !ok {
				existing = []SessionAgentCall{}
			}
			call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
			existing = append(existing, call)
			a.messageQueue.Set(call.SessionID, existing)
		}
	}

	// Release active request before publishing the notification.
	// TUI handlers poll IsSessionBusy() and only re-evaluate when a
	// tea.Msg arrives, so the cleanup must precede the notify or
	// subscribers see stale busy state at the moment of receipt.
	a.activeRequests.Del(call.SessionID)
	cancel()

	// Send notification that agent has finished its turn (skip for
	// nested/non-interactive sessions).
	if !call.NonInteractive && a.notify != nil {
		a.notify.Publish(pubsub.CreatedEvent, notify.Notification{
			SessionID:    call.SessionID,
			SessionTitle: currentSession.Title,
			Type:         notify.TypeAgentFinished,
		})
	}

	queuedMessages, ok := a.messageQueue.Get(call.SessionID)
	if !ok || len(queuedMessages) == 0 {
		return result, err
	}
	// There are queued messages, restart the loop.
	firstQueuedMessage := queuedMessages[0]
	a.messageQueue.Set(call.SessionID, queuedMessages[1:])
	return a.Run(ctx, firstQueuedMessage)
}
642
643func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
644 if a.IsSessionBusy(sessionID) {
645 return ErrSessionBusy
646 }
647
648 // Copy mutable fields under lock to avoid races with SetModels.
649 largeModel := a.largeModel.Get()
650 systemPromptPrefix := a.systemPromptPrefix.Get()
651
652 currentSession, err := a.sessions.Get(ctx, sessionID)
653 if err != nil {
654 return fmt.Errorf("failed to get session: %w", err)
655 }
656 msgs, err := a.getSessionMessages(ctx, currentSession)
657 if err != nil {
658 return err
659 }
660 if len(msgs) == 0 {
661 // Nothing to summarize.
662 return nil
663 }
664
665 aiMsgs, _ := a.preparePrompt(msgs, largeModel.CatwalkCfg.SupportsImages)
666
667 genCtx, cancel := context.WithCancel(ctx)
668 a.activeRequests.Set(sessionID, cancel)
669 defer a.activeRequests.Del(sessionID)
670 defer cancel()
671 defer func() {
672 if flushErr := a.messages.FlushAll(ctx); flushErr != nil {
673 slog.Error("Failed to flush pending message updates after summarize", "error", flushErr)
674 }
675 }()
676
677 agent := fantasy.NewAgent(
678 largeModel.Model,
679 fantasy.WithSystemPrompt(string(summaryPrompt)),
680 fantasy.WithUserAgent(userAgent),
681 )
682 summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
683 Role: message.Assistant,
684 Model: largeModel.Model.Model(),
685 Provider: largeModel.Model.Provider(),
686 IsSummaryMessage: true,
687 })
688 if err != nil {
689 return err
690 }
691
692 summaryPromptText := buildSummaryPrompt(currentSession.Todos)
693
694 resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
695 Prompt: summaryPromptText,
696 Messages: aiMsgs,
697 ProviderOptions: opts,
698 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
699 prepared.Messages = options.Messages
700 if systemPromptPrefix != "" {
701 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(systemPromptPrefix)}, prepared.Messages...)
702 }
703 return callContext, prepared, nil
704 },
705 OnReasoningDelta: func(id string, text string) error {
706 summaryMessage.AppendReasoningContent(text)
707 return a.messages.Update(genCtx, summaryMessage)
708 },
709 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
710 // Handle anthropic signature.
711 if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
712 if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
713 summaryMessage.AppendReasoningSignature(signature.Signature)
714 }
715 }
716 summaryMessage.FinishThinking()
717 return a.messages.Update(genCtx, summaryMessage)
718 },
719 OnTextDelta: func(id, text string) error {
720 summaryMessage.AppendContent(text)
721 return a.messages.Update(genCtx, summaryMessage)
722 },
723 })
724 if err != nil {
725 isCancelErr := errors.Is(err, context.Canceled)
726 if isCancelErr {
727 // User cancelled summarize we need to remove the summary message.
728 deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
729 return deleteErr
730 }
731 // Mark the summary message as finished with an error so the UI
732 // stops spinning.
733 summaryMessage.AddFinish(message.FinishReasonError, "Summarization Error", err.Error())
734 if updateErr := a.messages.Update(ctx, summaryMessage); updateErr != nil {
735 return updateErr
736 }
737 return err
738 }
739
740 summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
741 err = a.messages.Update(genCtx, summaryMessage)
742 if err != nil {
743 return err
744 }
745
746 var openrouterCost *float64
747 for _, step := range resp.Steps {
748 stepCost := a.openrouterCost(step.ProviderMetadata)
749 if stepCost != nil {
750 newCost := *stepCost
751 if openrouterCost != nil {
752 newCost += *openrouterCost
753 }
754 openrouterCost = &newCost
755 }
756 }
757
758 a.updateSessionUsage(largeModel, ¤tSession, resp.TotalUsage, openrouterCost, false)
759
760 // Just in case, get just the last usage info.
761 usage := resp.Response.Usage
762 currentSession.SummaryMessageID = summaryMessage.ID
763 currentSession.CompletionTokens = summaryCompletionTokens(usage, summaryMessage)
764 currentSession.PromptTokens = 0
765 _, err = a.sessions.Save(genCtx, currentSession)
766 if err != nil {
767 return err
768 }
769
770 // Release the active request before processing queued messages so that
771 // Run() does not see the session as busy.
772 a.activeRequests.Del(sessionID)
773 cancel()
774
775 // Process any messages that were queued while summarizing.
776 queuedMessages, ok := a.messageQueue.Get(sessionID)
777 if !ok || len(queuedMessages) == 0 {
778 return nil
779 }
780 firstQueuedMessage := queuedMessages[0]
781 a.messageQueue.Set(sessionID, queuedMessages[1:])
782 _, qErr := a.Run(ctx, firstQueuedMessage)
783 return qErr
784}
785
786func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
787 if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
788 return fantasy.ProviderOptions{}
789 }
790 return fantasy.ProviderOptions{
791 anthropic.Name: &anthropic.ProviderCacheControlOptions{
792 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
793 },
794 bedrock.Name: &anthropic.ProviderCacheControlOptions{
795 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
796 },
797 vercel.Name: &anthropic.ProviderCacheControlOptions{
798 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
799 },
800 }
801}
802
803func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
804 parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
805 var attachmentParts []message.ContentPart
806 for _, attachment := range call.Attachments {
807 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
808 }
809 parts = append(parts, attachmentParts...)
810 msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
811 Role: message.User,
812 Parts: parts,
813 })
814 if err != nil {
815 return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
816 }
817 return msg, nil
818}
819
820func (a *sessionAgent) preparePrompt(msgs []message.Message, supportsImages bool, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
821 var history []fantasy.Message
822 if !a.isSubAgent {
823 history = append(history, fantasy.NewUserMessage(
824 fmt.Sprintf(
825 "<system_reminder>%s</system_reminder>",
826 `This is a reminder that your todo list is currently empty. DO NOT mention this to the user explicitly because they are already aware.
827If you are working on tasks that would benefit from a todo list please use the "todos" tool to create one.
828If not, please feel free to ignore. Again do not mention this message to the user.`,
829 ),
830 ))
831 }
832 // Collect all tool call IDs present in assistant messages and all tool
833 // result IDs present in tool messages. This lets us detect both orphaned
834 // tool results (result without a call) and orphaned tool calls (call
835 // without a result).
836 knownToolCallIDs := make(map[string]struct{})
837 knownToolResultIDs := make(map[string]struct{})
838 for _, m := range msgs {
839 switch m.Role {
840 case message.Assistant:
841 for _, tc := range m.ToolCalls() {
842 knownToolCallIDs[tc.ID] = struct{}{}
843 }
844 case message.Tool:
845 for _, tr := range m.ToolResults() {
846 knownToolResultIDs[tr.ToolCallID] = struct{}{}
847 }
848 }
849 }
850
851 for _, m := range msgs {
852 if len(m.Parts) == 0 {
853 continue
854 }
855 // Assistant message without content or tool calls (cancelled before it returned anything).
856 if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
857 continue
858 }
859 if m.Role == message.Tool {
860 if msg, ok := filterOrphanedToolResults(m, knownToolCallIDs); ok {
861 history = append(history, msg)
862 }
863 continue
864 }
865 aiMsgs := m.ToAIMessage()
866 if !supportsImages {
867 for i := range aiMsgs {
868 if aiMsgs[i].Role == fantasy.MessageRoleUser {
869 aiMsgs[i].Content = filterFileParts(aiMsgs[i].Content)
870 }
871 }
872 }
873 history = append(history, aiMsgs...)
874
875 if m.Role == message.Assistant {
876 if msg, ok := syntheticToolResultsForOrphanedCalls(m, knownToolResultIDs); ok {
877 history = append(history, msg)
878 }
879 }
880 }
881
882 var files []fantasy.FilePart
883 for _, attachment := range attachments {
884 if attachment.IsText() {
885 continue
886 }
887 files = append(files, fantasy.FilePart{
888 Filename: attachment.FileName,
889 Data: attachment.Content,
890 MediaType: attachment.MimeType,
891 })
892 }
893
894 return history, files
895}
896
897// filterFileParts removes fantasy.FilePart entries from a slice of message
898// parts. Used to strip image attachments from historical user messages when
899// the current model does not support them.
900func filterFileParts(parts []fantasy.MessagePart) []fantasy.MessagePart {
901 filtered := make([]fantasy.MessagePart, 0, len(parts))
902 for _, part := range parts {
903 if _, ok := fantasy.AsMessagePart[fantasy.FilePart](part); ok {
904 continue
905 }
906 filtered = append(filtered, part)
907 }
908 return filtered
909}
910
911// filterOrphanedToolResults converts a tool message to a fantasy.Message,
912// dropping any tool result parts whose tool_call_id has no matching tool call
913// in the known set. An orphaned result causes API validation to fail on every
914// subsequent turn, permanently locking the session. Returns the filtered
915// message and true if at least one valid part remains.
916func filterOrphanedToolResults(m message.Message, knownToolCallIDs map[string]struct{}) (fantasy.Message, bool) {
917 aiMsgs := m.ToAIMessage()
918 if len(aiMsgs) == 0 {
919 return fantasy.Message{}, false
920 }
921 var validParts []fantasy.MessagePart
922 for _, part := range aiMsgs[0].Content {
923 tr, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
924 if !ok {
925 validParts = append(validParts, part)
926 continue
927 }
928 if _, known := knownToolCallIDs[tr.ToolCallID]; known {
929 validParts = append(validParts, part)
930 } else {
931 slog.Warn(
932 "Dropping orphaned tool result with no matching tool call",
933 "tool_call_id", tr.ToolCallID,
934 )
935 }
936 }
937 if len(validParts) == 0 {
938 return fantasy.Message{}, false
939 }
940 msg := aiMsgs[0]
941 msg.Content = validParts
942 return msg, true
943}
944
945// syntheticToolResultsForOrphanedCalls returns a tool message containing
946// synthetic tool results for any tool calls in the assistant message that
947// have no matching result in knownToolResultIDs. LLM APIs require every
948// tool_use to be immediately followed by a tool_result; an interrupted
949// session can leave orphaned tool_use blocks that permanently lock the
950// conversation. Returns the message and true if any synthetic results were
951// produced.
952func syntheticToolResultsForOrphanedCalls(m message.Message, knownToolResultIDs map[string]struct{}) (fantasy.Message, bool) {
953 var syntheticParts []fantasy.MessagePart
954 for _, tc := range m.ToolCalls() {
955 if _, hasResult := knownToolResultIDs[tc.ID]; hasResult {
956 continue
957 }
958 slog.Warn(
959 "Injecting synthetic tool result for orphaned tool call",
960 "tool_call_id", tc.ID,
961 "tool_name", tc.Name,
962 )
963 syntheticParts = append(syntheticParts, fantasy.ToolResultPart{
964 ToolCallID: tc.ID,
965 Output: fantasy.ToolResultOutputContentError{
966 Error: errors.New("tool call was interrupted and did not produce a result, you may retry this call if the result is still needed"),
967 },
968 })
969 }
970 if len(syntheticParts) == 0 {
971 return fantasy.Message{}, false
972 }
973 return fantasy.Message{
974 Role: fantasy.MessageRoleTool,
975 Content: syntheticParts,
976 }, true
977}
978
979func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
980 msgs, err := a.messages.List(ctx, session.ID)
981 if err != nil {
982 return nil, fmt.Errorf("failed to list messages: %w", err)
983 }
984
985 if session.SummaryMessageID != "" {
986 summaryMsgIndex := -1
987 for i, msg := range msgs {
988 if msg.ID == session.SummaryMessageID {
989 summaryMsgIndex = i
990 break
991 }
992 }
993 if summaryMsgIndex != -1 {
994 msgs = msgs[summaryMsgIndex:]
995 msgs[0].Role = message.User
996 }
997 }
998 return msgs, nil
999}
1000
1001// generateTitle generates a session titled based on the initial prompt.
1002func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, userPrompt string) {
1003 if userPrompt == "" {
1004 return
1005 }
1006
1007 smallModel := a.smallModel.Get()
1008 largeModel := a.largeModel.Get()
1009 systemPromptPrefix := a.systemPromptPrefix.Get()
1010
1011 var maxOutputTokens int64 = 40
1012 if smallModel.CatwalkCfg.CanReason {
1013 maxOutputTokens = smallModel.CatwalkCfg.DefaultMaxTokens
1014 }
1015
1016 newAgent := func(m fantasy.LanguageModel, p []byte, tok int64) fantasy.Agent {
1017 return fantasy.NewAgent(
1018 m,
1019 fantasy.WithSystemPrompt(string(p)+"\n /no_think"),
1020 fantasy.WithMaxOutputTokens(tok),
1021 fantasy.WithUserAgent(userAgent),
1022 )
1023 }
1024
1025 streamCall := fantasy.AgentStreamCall{
1026 Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", userPrompt),
1027 PrepareStep: func(callCtx context.Context, opts fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
1028 prepared.Messages = opts.Messages
1029 if systemPromptPrefix != "" {
1030 prepared.Messages = append([]fantasy.Message{
1031 fantasy.NewSystemMessage(systemPromptPrefix),
1032 }, prepared.Messages...)
1033 }
1034 return callCtx, prepared, nil
1035 },
1036 }
1037
1038 // Use the small model to generate the title.
1039 model := smallModel
1040 agent := newAgent(model.Model, titlePrompt, maxOutputTokens)
1041 resp, err := agent.Stream(ctx, streamCall)
1042 if err == nil {
1043 // We successfully generated a title with the small model.
1044 slog.Debug("Generated title with small model")
1045 } else {
1046 // It didn't work. Let's try with the big model.
1047 slog.Error("Error generating title with small model; trying big model", "err", err)
1048 model = largeModel
1049 agent = newAgent(model.Model, titlePrompt, maxOutputTokens)
1050 resp, err = agent.Stream(ctx, streamCall)
1051 if err == nil {
1052 slog.Debug("Generated title with large model")
1053 } else {
1054 // Welp, the large model didn't work either. Use the default
1055 // session name and return.
1056 slog.Error("Error generating title with large model", "err", err)
1057 saveErr := a.sessions.Rename(ctx, sessionID, DefaultSessionName)
1058 if saveErr != nil {
1059 slog.Error("Failed to save session title", "error", saveErr)
1060 }
1061 return
1062 }
1063 }
1064
1065 if resp == nil {
1066 // Actually, we didn't get a response so we can't. Use the default
1067 // session name and return.
1068 slog.Error("Response is nil; can't generate title")
1069 saveErr := a.sessions.Rename(ctx, sessionID, DefaultSessionName)
1070 if saveErr != nil {
1071 slog.Error("Failed to save session title", "error", saveErr)
1072 }
1073 return
1074 }
1075
1076 // Clean up title.
1077 var title string
1078 title = strings.ReplaceAll(resp.Response.Content.Text(), "\n", " ")
1079
1080 // Remove thinking tags if present.
1081 title = thinkTagRegex.ReplaceAllString(title, "")
1082 title = orphanThinkTagRegex.ReplaceAllString(title, "")
1083
1084 title = strings.TrimSpace(title)
1085 title = cmp.Or(title, DefaultSessionName)
1086
1087 // Calculate usage and cost.
1088 var openrouterCost *float64
1089 for _, step := range resp.Steps {
1090 stepCost := a.openrouterCost(step.ProviderMetadata)
1091 if stepCost != nil {
1092 newCost := *stepCost
1093 if openrouterCost != nil {
1094 newCost += *openrouterCost
1095 }
1096 openrouterCost = &newCost
1097 }
1098 }
1099
1100 modelConfig := model.CatwalkCfg
1101 cost := modelConfig.CostPer1MInCached/1e6*float64(resp.TotalUsage.CacheCreationTokens) +
1102 modelConfig.CostPer1MOutCached/1e6*float64(resp.TotalUsage.CacheReadTokens) +
1103 modelConfig.CostPer1MIn/1e6*float64(resp.TotalUsage.InputTokens) +
1104 modelConfig.CostPer1MOut/1e6*float64(resp.TotalUsage.OutputTokens)
1105
1106 // Use override cost if available (e.g., from OpenRouter).
1107 if openrouterCost != nil {
1108 cost = *openrouterCost
1109 }
1110
1111 // Skip cost accumulation
1112 if model.FlatRate {
1113 cost = 0
1114 }
1115
1116 promptTokens := resp.TotalUsage.InputTokens + resp.TotalUsage.CacheCreationTokens
1117 completionTokens := resp.TotalUsage.OutputTokens
1118
1119 // Atomically update only title and usage fields to avoid overriding other
1120 // concurrent session updates.
1121 saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, title, promptTokens, completionTokens, cost)
1122 if saveErr != nil {
1123 slog.Error("Failed to save session title and usage", "error", saveErr)
1124 return
1125 }
1126}
1127
1128func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
1129 openrouterMetadata, ok := metadata[openrouter.Name]
1130 if !ok {
1131 return nil
1132 }
1133
1134 opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
1135 if !ok {
1136 return nil
1137 }
1138 return &opts.Usage.Cost
1139}
1140
1141func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64, estimated bool) {
1142 modelConfig := model.CatwalkCfg
1143 cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
1144 modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
1145 modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
1146 modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
1147
1148 if !estimated {
1149 a.eventTokensUsed(session.ID, model, usage, cost)
1150 }
1151
1152 if estimated {
1153 cost = 0
1154 } else {
1155 // Use override cost if available (e.g., from OpenRouter).
1156 if overrideCost != nil {
1157 cost = *overrideCost
1158 }
1159
1160 // Skip cost accumulation
1161 if model.FlatRate {
1162 cost = 0
1163 }
1164 }
1165
1166 session.Cost += cost
1167 updateSessionTokenCounters(session, usage)
1168}
1169
1170func updateSessionTokenCounters(session *session.Session, usage fantasy.Usage) {
1171 if usage.OutputTokens != 0 {
1172 session.CompletionTokens = usage.OutputTokens
1173 }
1174 if promptTokens := usage.InputTokens + usage.CacheReadTokens; promptTokens != 0 {
1175 session.PromptTokens = promptTokens
1176 }
1177}
1178
1179func summaryCompletionTokens(usage fantasy.Usage, summaryMessage message.Message) int64 {
1180 if usage.OutputTokens != 0 {
1181 return usage.OutputTokens
1182 }
1183 return approxTokenCount(summaryMessage.Content().Text) + approxTokenCount(summaryMessage.ReasoningContent().String())
1184}
1185
1186func (a *sessionAgent) Cancel(sessionID string) {
1187 // Cancel regular requests. Don't use Take() here - we need the entry to
1188 // remain in activeRequests so IsBusy() returns true until the goroutine
1189 // fully completes (including error handling that may access the DB).
1190 // The defer in processRequest will clean up the entry.
1191 if cancel, ok := a.activeRequests.Get(sessionID); ok && cancel != nil {
1192 slog.Debug("Request cancellation initiated", "session_id", sessionID)
1193 cancel()
1194 }
1195
1196 // Also check for summarize requests.
1197 if cancel, ok := a.activeRequests.Get(sessionID + "-summarize"); ok && cancel != nil {
1198 slog.Debug("Summarize cancellation initiated", "session_id", sessionID)
1199 cancel()
1200 }
1201
1202 if a.QueuedPrompts(sessionID) > 0 {
1203 slog.Debug("Clearing queued prompts", "session_id", sessionID)
1204 a.messageQueue.Del(sessionID)
1205 }
1206}
1207
1208func (a *sessionAgent) ClearQueue(sessionID string) {
1209 if a.QueuedPrompts(sessionID) > 0 {
1210 slog.Debug("Clearing queued prompts", "session_id", sessionID)
1211 a.messageQueue.Del(sessionID)
1212 }
1213}
1214
1215func (a *sessionAgent) CancelAll() {
1216 if !a.IsBusy() {
1217 return
1218 }
1219 for key := range a.activeRequests.Seq2() {
1220 a.Cancel(key) // key is sessionID
1221 }
1222
1223 timeout := time.After(5 * time.Second)
1224 for a.IsBusy() {
1225 select {
1226 case <-timeout:
1227 return
1228 default:
1229 time.Sleep(200 * time.Millisecond)
1230 }
1231 }
1232}
1233
1234func (a *sessionAgent) IsBusy() bool {
1235 var busy bool
1236 for cancelFunc := range a.activeRequests.Seq() {
1237 if cancelFunc != nil {
1238 busy = true
1239 break
1240 }
1241 }
1242 return busy
1243}
1244
1245func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
1246 _, busy := a.activeRequests.Get(sessionID)
1247 return busy
1248}
1249
1250func (a *sessionAgent) QueuedPrompts(sessionID string) int {
1251 l, ok := a.messageQueue.Get(sessionID)
1252 if !ok {
1253 return 0
1254 }
1255 return len(l)
1256}
1257
1258func (a *sessionAgent) QueuedPromptsList(sessionID string) []string {
1259 l, ok := a.messageQueue.Get(sessionID)
1260 if !ok {
1261 return nil
1262 }
1263 prompts := make([]string, len(l))
1264 for i, call := range l {
1265 prompts[i] = call.Prompt
1266 }
1267 return prompts
1268}
1269
// SetModels replaces the agent's large and small models with the given
// values. The two values are stored one after the other, large first.
func (a *sessionAgent) SetModels(large Model, small Model) {
	a.largeModel.Set(large)
	a.smallModel.Set(small)
}
1274
// SetTools hands the given tools to the agent's tool collection via
// SetSlice (presumably replacing the previous set; confirm against the
// collection type).
func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
	a.tools.SetSlice(tools)
}
1278
// SetSystemPrompt stores systemPrompt as the agent's system prompt.
func (a *sessionAgent) SetSystemPrompt(systemPrompt string) {
	a.systemPrompt.Set(systemPrompt)
}
1282
// Model returns the agent's current large model.
func (a *sessionAgent) Model() Model {
	return a.largeModel.Get()
}
1286
1287// convertToToolResult converts a fantasy tool result to a message tool result.
1288func (a *sessionAgent) convertToToolResult(result fantasy.ToolResultContent) message.ToolResult {
1289 baseResult := message.ToolResult{
1290 ToolCallID: result.ToolCallID,
1291 Name: result.ToolName,
1292 Metadata: result.ClientMetadata,
1293 }
1294
1295 switch result.Result.GetType() {
1296 case fantasy.ToolResultContentTypeText:
1297 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result); ok {
1298 baseResult.Content = r.Text
1299 }
1300 case fantasy.ToolResultContentTypeError:
1301 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result); ok {
1302 baseResult.Content = r.Error.Error()
1303 baseResult.IsError = true
1304 }
1305 case fantasy.ToolResultContentTypeMedia:
1306 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](result.Result); ok {
1307 if !stringext.IsValidBase64(r.Data) {
1308 slog.Warn(
1309 "Tool returned media with invalid base64 data, discarding image",
1310 "tool", result.ToolName,
1311 "tool_call_id", result.ToolCallID,
1312 )
1313 baseResult.Content = "Tool returned image data with invalid encoding"
1314 baseResult.IsError = true
1315 } else {
1316 content := r.Text
1317 if content == "" {
1318 content = fmt.Sprintf("Loaded %s content", r.MediaType)
1319 }
1320 baseResult.Content = content
1321 baseResult.Data = r.Data
1322 baseResult.MIMEType = r.MediaType
1323 }
1324 }
1325 }
1326
1327 return baseResult
1328}
1329
1330// workaroundProviderMediaLimitations converts media content in tool results to
1331// user messages for providers that don't natively support images in tool results.
1332//
1333// Problem: OpenAI, Google, OpenRouter, and other OpenAI-compatible providers
1334// don't support sending images/media in tool result messages - they only accept
1335// text in tool results. However, they DO support images in user messages.
1336//
1337// If we send media in tool results to these providers, the API returns an error.
1338//
1339// Solution: For these providers, we:
1340// 1. Replace the media in the tool result with a text placeholder
1341// 2. Inject a user message immediately after with the image as a file attachment
1342// 3. This maintains the tool execution flow while working around API limitations
1343//
1344// Anthropic and Bedrock support images natively in tool results, so we skip
1345// this workaround for them.
1346//
1347// Example transformation:
1348//
1349// BEFORE: [tool result: image data]
1350// AFTER: [tool result: "Image loaded - see attached"], [user: image attachment]
1351func (a *sessionAgent) workaroundProviderMediaLimitations(messages []fantasy.Message, largeModel Model) []fantasy.Message {
1352 providerSupportsMedia := largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderAnthropic) ||
1353 largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderBedrock)
1354
1355 if providerSupportsMedia {
1356 return messages
1357 }
1358
1359 convertedMessages := make([]fantasy.Message, 0, len(messages))
1360
1361 for _, msg := range messages {
1362 if msg.Role != fantasy.MessageRoleTool {
1363 convertedMessages = append(convertedMessages, msg)
1364 continue
1365 }
1366
1367 textParts := make([]fantasy.MessagePart, 0, len(msg.Content))
1368 var mediaFiles []fantasy.FilePart
1369
1370 for _, part := range msg.Content {
1371 toolResult, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
1372 if !ok {
1373 textParts = append(textParts, part)
1374 continue
1375 }
1376
1377 if media, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](toolResult.Output); ok {
1378 decoded, err := base64.StdEncoding.DecodeString(media.Data)
1379 if err != nil {
1380 slog.Warn("Failed to decode media data", "error", err)
1381 textParts = append(textParts, part)
1382 continue
1383 }
1384
1385 mediaFiles = append(mediaFiles, fantasy.FilePart{
1386 Data: decoded,
1387 MediaType: media.MediaType,
1388 Filename: fmt.Sprintf("tool-result-%s", toolResult.ToolCallID),
1389 })
1390
1391 textParts = append(textParts, fantasy.ToolResultPart{
1392 ToolCallID: toolResult.ToolCallID,
1393 Output: fantasy.ToolResultOutputContentText{
1394 Text: "[Image/media content loaded - see attached file]",
1395 },
1396 ProviderOptions: toolResult.ProviderOptions,
1397 })
1398 } else {
1399 textParts = append(textParts, part)
1400 }
1401 }
1402
1403 convertedMessages = append(convertedMessages, fantasy.Message{
1404 Role: fantasy.MessageRoleTool,
1405 Content: textParts,
1406 })
1407
1408 if len(mediaFiles) > 0 {
1409 convertedMessages = append(convertedMessages, fantasy.NewUserMessage(
1410 "Here is the media content from the tool result:",
1411 mediaFiles...,
1412 ))
1413 }
1414 }
1415
1416 return convertedMessages
1417}
1418
1419// buildSummaryPrompt constructs the prompt text for session summarization.
1420func buildSummaryPrompt(todos []session.Todo) string {
1421 var sb strings.Builder
1422 sb.WriteString("Provide a detailed summary of our conversation above.")
1423 if len(todos) > 0 {
1424 sb.WriteString("\n\n## Current Todo List\n\n")
1425 for _, t := range todos {
1426 fmt.Fprintf(&sb, "- [%s] %s\n", t.Status, t.Content)
1427 }
1428 sb.WriteString("\nInclude these tasks and their statuses in your summary. ")
1429 sb.WriteString("Instruct the resuming assistant to use the `todos` tool to continue tracking progress on these tasks.")
1430 }
1431 return sb.String()
1432}
1433
1434func providerRetryLogFields(err *fantasy.ProviderError, delay time.Duration) []any {
1435 fields := []any{
1436 "retry_delay", delay.String(),
1437 }
1438 if err == nil {
1439 return fields
1440 }
1441 fields = append(fields, "status_code", err.StatusCode)
1442 if err.Title != "" {
1443 fields = append(fields, "title", err.Title)
1444 }
1445 if err.Message != "" {
1446 fields = append(fields, "message", err.Message)
1447 }
1448 return fields
1449}