1// Package agent is the core orchestration layer for Crush AI agents.
2//
3// It provides session-based AI agent functionality for managing
4// conversations, tool execution, and message handling. It coordinates
5// interactions between language models, messages, sessions, and tools while
6// handling features like automatic summarization, queuing, and token
7// management.
8package agent
9
10import (
11 "cmp"
12 "context"
13 _ "embed"
14 "encoding/base64"
15 "errors"
16 "fmt"
17 "log/slog"
18 "os"
19 "strconv"
20 "strings"
21 "sync"
22 "time"
23
24 "charm.land/fantasy"
25 "charm.land/fantasy/providers/anthropic"
26 "charm.land/fantasy/providers/bedrock"
27 "charm.land/fantasy/providers/google"
28 "charm.land/fantasy/providers/openai"
29 "charm.land/fantasy/providers/openrouter"
30 "charm.land/lipgloss/v2"
31 "github.com/charmbracelet/catwalk/pkg/catwalk"
32 "github.com/charmbracelet/crush/internal/agent/hyper"
33 "github.com/charmbracelet/crush/internal/agent/tools"
34 "github.com/charmbracelet/crush/internal/config"
35 "github.com/charmbracelet/crush/internal/csync"
36 "github.com/charmbracelet/crush/internal/message"
37 "github.com/charmbracelet/crush/internal/permission"
38 "github.com/charmbracelet/crush/internal/session"
39 "github.com/charmbracelet/crush/internal/stringext"
40)
41
// titlePrompt holds the contents of templates/title.md, embedded at build
// time. It is used as the base system prompt when generating session titles.
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt holds the contents of templates/summary.md, embedded at build
// time. It is used as the system prompt when summarizing a session.
//
//go:embed templates/summary.md
var summaryPrompt []byte
47
// SessionAgentCall describes a single prompt submitted to a session agent via
// Run.
type SessionAgentCall struct {
	// SessionID identifies the session the prompt belongs to. Run rejects
	// calls where it is empty.
	SessionID string
	// Prompt is the user's text. Run rejects calls where it is empty.
	Prompt string
	// ProviderOptions is forwarded unchanged to the model stream call.
	ProviderOptions fantasy.ProviderOptions
	// Attachments are stored on the user message; non-text attachments are
	// additionally sent to the model as files.
	Attachments []message.Attachment
	// MaxOutputTokens and the sampling parameters below are forwarded
	// unchanged to the model stream call.
	MaxOutputTokens  int64
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
}
60
// SessionAgent is the interface through which callers drive a session-based
// agent: running prompts, cancelling work, inspecting the per-session prompt
// queue, and swapping models or tools. The method notes below describe the
// behavior of the implementation returned by NewSessionAgent.
type SessionAgent interface {
	// Run submits a prompt for a session. If the session is busy the call is
	// queued and (nil, nil) is returned.
	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
	// SetModels replaces the large and small models.
	SetModels(large Model, small Model)
	// SetTools replaces the tool set.
	SetTools(tools []fantasy.AgentTool)
	// Cancel cancels the session's active request, if any, and drops its
	// queued prompts.
	Cancel(sessionID string)
	// CancelAll cancels every active request and waits a bounded time for
	// the agent to become idle.
	CancelAll()
	// IsSessionBusy reports whether the session has an active request.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any session has an active request.
	IsBusy() bool
	// QueuedPrompts returns how many prompts are queued for the session.
	QueuedPrompts(sessionID string) int
	// QueuedPromptsList returns the text of the session's queued prompts.
	QueuedPromptsList(sessionID string) []string
	// ClearQueue drops the session's queued prompts.
	ClearQueue(sessionID string)
	// Summarize replaces the session's working history with a model-written
	// summary message.
	Summarize(context.Context, string, fantasy.ProviderOptions) error
	// Model returns the large model.
	Model() Model
}
75
// Model bundles a language model with the catalog and configuration entries
// that describe it.
type Model struct {
	// Model is the language model handed to fantasy.NewAgent.
	Model fantasy.LanguageModel
	// CatwalkCfg is the catwalk description of the model; this file reads its
	// context window, pricing, name and capability flags from it.
	CatwalkCfg catwalk.Model
	// ModelCfg is the selected-model configuration; this file reads its model
	// and provider identifiers from it.
	ModelCfg config.SelectedModel
}
81
// sessionAgent is the SessionAgent implementation returned by NewSessionAgent.
type sessionAgent struct {
	// largeModel drives conversations and summaries; smallModel generates
	// session titles.
	largeModel Model
	smallModel Model
	// systemPromptPrefix, when non-empty, is sent as an extra leading system
	// message; systemPrompt is the main system prompt for conversations.
	systemPromptPrefix string
	systemPrompt       string
	// isSubAgent suppresses the todo-list reminder that is otherwise placed
	// at the start of the prompt history.
	isSubAgent bool
	tools      []fantasy.AgentTool
	sessions   session.Service
	messages   message.Service
	// disableAutoSummarize turns off the context-window check that stops a
	// run and summarizes the session.
	disableAutoSummarize bool
	// isYolo is copied from SessionAgentOptions; it is not read by any code
	// visible in this file.
	isYolo bool

	// messageQueue holds, per session ID, the calls received while that
	// session was busy.
	messageQueue *csync.Map[string, []SessionAgentCall]
	// activeRequests maps a session ID to the cancel function of its running
	// request; the presence of a key is what marks a session as busy.
	activeRequests *csync.Map[string, context.CancelFunc]
}
97
// SessionAgentOptions carries the dependencies and settings NewSessionAgent
// copies into a new agent.
type SessionAgentOptions struct {
	// LargeModel is used for conversations and summaries; SmallModel is used
	// for title generation.
	LargeModel Model
	SmallModel Model
	// SystemPromptPrefix, when non-empty, is sent as an extra leading system
	// message; SystemPrompt is the main system prompt.
	SystemPromptPrefix string
	SystemPrompt       string
	// IsSubAgent suppresses the todo-list reminder in the prompt history.
	IsSubAgent bool
	// DisableAutoSummarize turns off automatic summarization when the context
	// window is nearly full.
	DisableAutoSummarize bool
	// IsYolo is stored on the agent; it is not read by any code visible in
	// this file.
	IsYolo bool
	// Sessions and Messages are the persistence services for sessions and
	// their messages.
	Sessions session.Service
	Messages message.Service
	// Tools is the initial tool set offered to the model.
	Tools []fantasy.AgentTool
}
110
111func NewSessionAgent(
112 opts SessionAgentOptions,
113) SessionAgent {
114 return &sessionAgent{
115 largeModel: opts.LargeModel,
116 smallModel: opts.SmallModel,
117 systemPromptPrefix: opts.SystemPromptPrefix,
118 systemPrompt: opts.SystemPrompt,
119 isSubAgent: opts.IsSubAgent,
120 sessions: opts.Sessions,
121 messages: opts.Messages,
122 disableAutoSummarize: opts.DisableAutoSummarize,
123 tools: opts.Tools,
124 isYolo: opts.IsYolo,
125 messageQueue: csync.NewMap[string, []SessionAgentCall](),
126 activeRequests: csync.NewMap[string, context.CancelFunc](),
127 }
128}
129
130func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
131 if call.Prompt == "" {
132 return nil, ErrEmptyPrompt
133 }
134 if call.SessionID == "" {
135 return nil, ErrSessionMissing
136 }
137
138 // Queue the message if busy
139 if a.IsSessionBusy(call.SessionID) {
140 existing, ok := a.messageQueue.Get(call.SessionID)
141 if !ok {
142 existing = []SessionAgentCall{}
143 }
144 existing = append(existing, call)
145 a.messageQueue.Set(call.SessionID, existing)
146 return nil, nil
147 }
148
149 if len(a.tools) > 0 {
150 // Add Anthropic caching to the last tool.
151 a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
152 }
153
154 agent := fantasy.NewAgent(
155 a.largeModel.Model,
156 fantasy.WithSystemPrompt(a.systemPrompt),
157 fantasy.WithTools(a.tools...),
158 )
159
160 sessionLock := sync.Mutex{}
161 currentSession, err := a.sessions.Get(ctx, call.SessionID)
162 if err != nil {
163 return nil, fmt.Errorf("failed to get session: %w", err)
164 }
165
166 msgs, err := a.getSessionMessages(ctx, currentSession)
167 if err != nil {
168 return nil, fmt.Errorf("failed to get session messages: %w", err)
169 }
170
171 var wg sync.WaitGroup
172 // Generate title if first message.
173 if len(msgs) == 0 {
174 titleCtx := ctx // Copy to avoid race with ctx reassignment below.
175 wg.Go(func() {
176 a.generateTitle(titleCtx, call.SessionID, call.Prompt)
177 })
178 }
179
180 // Add the user message to the session.
181 _, err = a.createUserMessage(ctx, call)
182 if err != nil {
183 return nil, err
184 }
185
186 // Add the session to the context.
187 ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)
188
189 genCtx, cancel := context.WithCancel(ctx)
190 a.activeRequests.Set(call.SessionID, cancel)
191
192 defer cancel()
193 defer a.activeRequests.Del(call.SessionID)
194
195 history, files := a.preparePrompt(msgs, call.Attachments...)
196
197 startTime := time.Now()
198 a.eventPromptSent(call.SessionID)
199
200 var currentAssistant *message.Message
201 var shouldSummarize bool
202 result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
203 Prompt: message.PromptWithTextAttachments(call.Prompt, call.Attachments),
204 Files: files,
205 Messages: history,
206 ProviderOptions: call.ProviderOptions,
207 MaxOutputTokens: &call.MaxOutputTokens,
208 TopP: call.TopP,
209 Temperature: call.Temperature,
210 PresencePenalty: call.PresencePenalty,
211 TopK: call.TopK,
212 FrequencyPenalty: call.FrequencyPenalty,
213 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
214 prepared.Messages = options.Messages
215 for i := range prepared.Messages {
216 prepared.Messages[i].ProviderOptions = nil
217 }
218
219 queuedCalls, _ := a.messageQueue.Get(call.SessionID)
220 a.messageQueue.Del(call.SessionID)
221 for _, queued := range queuedCalls {
222 userMessage, createErr := a.createUserMessage(callContext, queued)
223 if createErr != nil {
224 return callContext, prepared, createErr
225 }
226 prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
227 }
228
229 prepared.Messages = a.workaroundProviderMediaLimitations(prepared.Messages)
230
231 lastSystemRoleInx := 0
232 systemMessageUpdated := false
233 for i, msg := range prepared.Messages {
234 // Only add cache control to the last message.
235 if msg.Role == fantasy.MessageRoleSystem {
236 lastSystemRoleInx = i
237 } else if !systemMessageUpdated {
238 prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
239 systemMessageUpdated = true
240 }
241 // Than add cache control to the last 2 messages.
242 if i > len(prepared.Messages)-3 {
243 prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
244 }
245 }
246
247 if promptPrefix := a.promptPrefix(); promptPrefix != "" {
248 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(promptPrefix)}, prepared.Messages...)
249 }
250
251 var assistantMsg message.Message
252 assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
253 Role: message.Assistant,
254 Parts: []message.ContentPart{},
255 Model: a.largeModel.ModelCfg.Model,
256 Provider: a.largeModel.ModelCfg.Provider,
257 })
258 if err != nil {
259 return callContext, prepared, err
260 }
261 callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
262 callContext = context.WithValue(callContext, tools.SupportsImagesContextKey, a.largeModel.CatwalkCfg.SupportsImages)
263 callContext = context.WithValue(callContext, tools.ModelNameContextKey, a.largeModel.CatwalkCfg.Name)
264 currentAssistant = &assistantMsg
265 return callContext, prepared, err
266 },
267 OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
268 currentAssistant.AppendReasoningContent(reasoning.Text)
269 return a.messages.Update(genCtx, *currentAssistant)
270 },
271 OnReasoningDelta: func(id string, text string) error {
272 currentAssistant.AppendReasoningContent(text)
273 return a.messages.Update(genCtx, *currentAssistant)
274 },
275 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
276 // handle anthropic signature
277 if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
278 if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
279 currentAssistant.AppendReasoningSignature(reasoning.Signature)
280 }
281 }
282 if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
283 if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
284 currentAssistant.AppendThoughtSignature(reasoning.Signature, reasoning.ToolID)
285 }
286 }
287 if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
288 if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
289 currentAssistant.SetReasoningResponsesData(reasoning)
290 }
291 }
292 currentAssistant.FinishThinking()
293 return a.messages.Update(genCtx, *currentAssistant)
294 },
295 OnTextDelta: func(id string, text string) error {
296 // Strip leading newline from initial text content. This is is
297 // particularly important in non-interactive mode where leading
298 // newlines are very visible.
299 if len(currentAssistant.Parts) == 0 {
300 text = strings.TrimPrefix(text, "\n")
301 }
302
303 currentAssistant.AppendContent(text)
304 return a.messages.Update(genCtx, *currentAssistant)
305 },
306 OnToolInputStart: func(id string, toolName string) error {
307 toolCall := message.ToolCall{
308 ID: id,
309 Name: toolName,
310 ProviderExecuted: false,
311 Finished: false,
312 }
313 currentAssistant.AddToolCall(toolCall)
314 return a.messages.Update(genCtx, *currentAssistant)
315 },
316 OnRetry: func(err *fantasy.ProviderError, delay time.Duration) {
317 // TODO: implement
318 },
319 OnToolCall: func(tc fantasy.ToolCallContent) error {
320 toolCall := message.ToolCall{
321 ID: tc.ToolCallID,
322 Name: tc.ToolName,
323 Input: tc.Input,
324 ProviderExecuted: false,
325 Finished: true,
326 }
327 currentAssistant.AddToolCall(toolCall)
328 return a.messages.Update(genCtx, *currentAssistant)
329 },
330 OnToolResult: func(result fantasy.ToolResultContent) error {
331 toolResult := a.convertToToolResult(result)
332 _, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
333 Role: message.Tool,
334 Parts: []message.ContentPart{
335 toolResult,
336 },
337 })
338 return createMsgErr
339 },
340 OnStepFinish: func(stepResult fantasy.StepResult) error {
341 finishReason := message.FinishReasonUnknown
342 switch stepResult.FinishReason {
343 case fantasy.FinishReasonLength:
344 finishReason = message.FinishReasonMaxTokens
345 case fantasy.FinishReasonStop:
346 finishReason = message.FinishReasonEndTurn
347 case fantasy.FinishReasonToolCalls:
348 finishReason = message.FinishReasonToolUse
349 }
350 currentAssistant.AddFinish(finishReason, "", "")
351 sessionLock.Lock()
352 updatedSession, getSessionErr := a.sessions.Get(genCtx, call.SessionID)
353 if getSessionErr != nil {
354 sessionLock.Unlock()
355 return getSessionErr
356 }
357 a.updateSessionUsage(a.largeModel, &updatedSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
358 _, sessionErr := a.sessions.Save(genCtx, updatedSession)
359 sessionLock.Unlock()
360 if sessionErr != nil {
361 return sessionErr
362 }
363 return a.messages.Update(genCtx, *currentAssistant)
364 },
365 StopWhen: []fantasy.StopCondition{
366 func(_ []fantasy.StepResult) bool {
367 cw := int64(a.largeModel.CatwalkCfg.ContextWindow)
368 tokens := currentSession.CompletionTokens + currentSession.PromptTokens
369 remaining := cw - tokens
370 var threshold int64
371 if cw > 200_000 {
372 threshold = 20_000
373 } else {
374 threshold = int64(float64(cw) * 0.2)
375 }
376 if (remaining <= threshold) && !a.disableAutoSummarize {
377 shouldSummarize = true
378 return true
379 }
380 return false
381 },
382 },
383 })
384
385 a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))
386
387 if err != nil {
388 isCancelErr := errors.Is(err, context.Canceled)
389 isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
390 if currentAssistant == nil {
391 return result, err
392 }
393 // Ensure we finish thinking on error to close the reasoning state.
394 currentAssistant.FinishThinking()
395 toolCalls := currentAssistant.ToolCalls()
396 // INFO: we use the parent context here because the genCtx has been cancelled.
397 msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
398 if createErr != nil {
399 return nil, createErr
400 }
401 for _, tc := range toolCalls {
402 if !tc.Finished {
403 tc.Finished = true
404 tc.Input = "{}"
405 currentAssistant.AddToolCall(tc)
406 updateErr := a.messages.Update(ctx, *currentAssistant)
407 if updateErr != nil {
408 return nil, updateErr
409 }
410 }
411
412 found := false
413 for _, msg := range msgs {
414 if msg.Role == message.Tool {
415 for _, tr := range msg.ToolResults() {
416 if tr.ToolCallID == tc.ID {
417 found = true
418 break
419 }
420 }
421 }
422 if found {
423 break
424 }
425 }
426 if found {
427 continue
428 }
429 content := "There was an error while executing the tool"
430 if isCancelErr {
431 content = "Tool execution canceled by user"
432 } else if isPermissionErr {
433 content = "User denied permission"
434 }
435 toolResult := message.ToolResult{
436 ToolCallID: tc.ID,
437 Name: tc.Name,
438 Content: content,
439 IsError: true,
440 }
441 _, createErr = a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
442 Role: message.Tool,
443 Parts: []message.ContentPart{
444 toolResult,
445 },
446 })
447 if createErr != nil {
448 return nil, createErr
449 }
450 }
451 var fantasyErr *fantasy.Error
452 var providerErr *fantasy.ProviderError
453 const defaultTitle = "Provider Error"
454 if isCancelErr {
455 currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
456 } else if isPermissionErr {
457 currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "")
458 } else if errors.Is(err, hyper.ErrNoCredits) {
459 url := hyper.BaseURL()
460 link := lipgloss.NewStyle().Hyperlink(url, "id=hyper").Render(url)
461 currentAssistant.AddFinish(message.FinishReasonError, "No credits", "You're out of credits. Add more at "+link)
462 } else if errors.As(err, &providerErr) {
463 if providerErr.Message == "The requested model is not supported." {
464 url := "https://github.com/settings/copilot/features"
465 link := lipgloss.NewStyle().Hyperlink(url, "id=hyper").Render(url)
466 currentAssistant.AddFinish(
467 message.FinishReasonError,
468 "Copilot model not enabled",
469 fmt.Sprintf("%q is not enabled in Copilot. Go to the following page to enable it. Then, wait a minute before trying again. %s", a.largeModel.CatwalkCfg.Name, link),
470 )
471 } else {
472 currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message)
473 }
474 } else if errors.As(err, &fantasyErr) {
475 currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(fantasyErr.Title), defaultTitle), fantasyErr.Message)
476 } else {
477 currentAssistant.AddFinish(message.FinishReasonError, defaultTitle, err.Error())
478 }
479 // Note: we use the parent context here because the genCtx has been
480 // cancelled.
481 updateErr := a.messages.Update(ctx, *currentAssistant)
482 if updateErr != nil {
483 return nil, updateErr
484 }
485 return nil, err
486 }
487 wg.Wait()
488
489 if shouldSummarize {
490 a.activeRequests.Del(call.SessionID)
491 if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
492 return nil, summarizeErr
493 }
494 // If the agent wasn't done...
495 if len(currentAssistant.ToolCalls()) > 0 {
496 existing, ok := a.messageQueue.Get(call.SessionID)
497 if !ok {
498 existing = []SessionAgentCall{}
499 }
500 call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
501 existing = append(existing, call)
502 a.messageQueue.Set(call.SessionID, existing)
503 }
504 }
505
506 // Release active request before processing queued messages.
507 a.activeRequests.Del(call.SessionID)
508 cancel()
509
510 queuedMessages, ok := a.messageQueue.Get(call.SessionID)
511 if !ok || len(queuedMessages) == 0 {
512 return result, err
513 }
514 // There are queued messages restart the loop.
515 firstQueuedMessage := queuedMessages[0]
516 a.messageQueue.Set(call.SessionID, queuedMessages[1:])
517 return a.Run(ctx, firstQueuedMessage)
518}
519
520func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
521 if a.IsSessionBusy(sessionID) {
522 return ErrSessionBusy
523 }
524
525 currentSession, err := a.sessions.Get(ctx, sessionID)
526 if err != nil {
527 return fmt.Errorf("failed to get session: %w", err)
528 }
529 msgs, err := a.getSessionMessages(ctx, currentSession)
530 if err != nil {
531 return err
532 }
533 if len(msgs) == 0 {
534 // Nothing to summarize.
535 return nil
536 }
537
538 aiMsgs, _ := a.preparePrompt(msgs)
539
540 genCtx, cancel := context.WithCancel(ctx)
541 a.activeRequests.Set(sessionID, cancel)
542 defer a.activeRequests.Del(sessionID)
543 defer cancel()
544
545 agent := fantasy.NewAgent(a.largeModel.Model,
546 fantasy.WithSystemPrompt(string(summaryPrompt)),
547 )
548 summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
549 Role: message.Assistant,
550 Model: a.largeModel.Model.Model(),
551 Provider: a.largeModel.Model.Provider(),
552 IsSummaryMessage: true,
553 })
554 if err != nil {
555 return err
556 }
557
558 summaryPromptText := "Provide a detailed summary of our conversation above."
559 if len(currentSession.Todos) > 0 {
560 summaryPromptText += "\n\n## Current Todo List\n\n"
561 for _, t := range currentSession.Todos {
562 summaryPromptText += fmt.Sprintf("- [%s] %s\n", t.Status, t.Content)
563 }
564 summaryPromptText += "\nInclude these tasks and their statuses in your summary. "
565 summaryPromptText += "Instruct the resuming assistant to use the `todos` tool to continue tracking progress on these tasks."
566 }
567
568 resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
569 Prompt: summaryPromptText,
570 Messages: aiMsgs,
571 ProviderOptions: opts,
572 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
573 prepared.Messages = options.Messages
574 if a.systemPromptPrefix != "" {
575 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
576 }
577 return callContext, prepared, nil
578 },
579 OnReasoningDelta: func(id string, text string) error {
580 summaryMessage.AppendReasoningContent(text)
581 return a.messages.Update(genCtx, summaryMessage)
582 },
583 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
584 // Handle anthropic signature.
585 if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
586 if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
587 summaryMessage.AppendReasoningSignature(signature.Signature)
588 }
589 }
590 summaryMessage.FinishThinking()
591 return a.messages.Update(genCtx, summaryMessage)
592 },
593 OnTextDelta: func(id, text string) error {
594 summaryMessage.AppendContent(text)
595 return a.messages.Update(genCtx, summaryMessage)
596 },
597 })
598 if err != nil {
599 isCancelErr := errors.Is(err, context.Canceled)
600 if isCancelErr {
601 // User cancelled summarize we need to remove the summary message.
602 deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
603 return deleteErr
604 }
605 return err
606 }
607
608 summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
609 err = a.messages.Update(genCtx, summaryMessage)
610 if err != nil {
611 return err
612 }
613
614 var openrouterCost *float64
615 for _, step := range resp.Steps {
616 stepCost := a.openrouterCost(step.ProviderMetadata)
617 if stepCost != nil {
618 newCost := *stepCost
619 if openrouterCost != nil {
620 newCost += *openrouterCost
621 }
622 openrouterCost = &newCost
623 }
624 }
625
626 a.updateSessionUsage(a.largeModel, ¤tSession, resp.TotalUsage, openrouterCost)
627
628 // Just in case, get just the last usage info.
629 usage := resp.Response.Usage
630 currentSession.SummaryMessageID = summaryMessage.ID
631 currentSession.CompletionTokens = usage.OutputTokens
632 currentSession.PromptTokens = 0
633 _, err = a.sessions.Save(genCtx, currentSession)
634 return err
635}
636
637func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
638 if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
639 return fantasy.ProviderOptions{}
640 }
641 return fantasy.ProviderOptions{
642 anthropic.Name: &anthropic.ProviderCacheControlOptions{
643 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
644 },
645 bedrock.Name: &anthropic.ProviderCacheControlOptions{
646 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
647 },
648 }
649}
650
651func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
652 parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
653 var attachmentParts []message.ContentPart
654 for _, attachment := range call.Attachments {
655 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
656 }
657 parts = append(parts, attachmentParts...)
658 msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
659 Role: message.User,
660 Parts: parts,
661 })
662 if err != nil {
663 return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
664 }
665 return msg, nil
666}
667
668func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
669 var history []fantasy.Message
670 if !a.isSubAgent {
671 history = append(history, fantasy.NewUserMessage(
672 fmt.Sprintf("<system_reminder>%s</system_reminder>",
673 `This is a reminder that your todo list is currently empty. DO NOT mention this to the user explicitly because they are already aware.
674If you are working on tasks that would benefit from a todo list please use the "todos" tool to create one.
675If not, please feel free to ignore. Again do not mention this message to the user.`,
676 ),
677 ))
678 }
679 for _, m := range msgs {
680 if len(m.Parts) == 0 {
681 continue
682 }
683 // Assistant message without content or tool calls (cancelled before it
684 // returned anything).
685 if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
686 continue
687 }
688 history = append(history, m.ToAIMessage()...)
689 }
690
691 var files []fantasy.FilePart
692 for _, attachment := range attachments {
693 if attachment.IsText() {
694 continue
695 }
696 files = append(files, fantasy.FilePart{
697 Filename: attachment.FileName,
698 Data: attachment.Content,
699 MediaType: attachment.MimeType,
700 })
701 }
702
703 return history, files
704}
705
706func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
707 msgs, err := a.messages.List(ctx, session.ID)
708 if err != nil {
709 return nil, fmt.Errorf("failed to list messages: %w", err)
710 }
711
712 if session.SummaryMessageID != "" {
713 summaryMsgInex := -1
714 for i, msg := range msgs {
715 if msg.ID == session.SummaryMessageID {
716 summaryMsgInex = i
717 break
718 }
719 }
720 if summaryMsgInex != -1 {
721 msgs = msgs[summaryMsgInex:]
722 msgs[0].Role = message.User
723 }
724 }
725 return msgs, nil
726}
727
728func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, prompt string) {
729 if prompt == "" {
730 return
731 }
732
733 var maxOutput int64 = 40
734 if a.smallModel.CatwalkCfg.CanReason {
735 maxOutput = a.smallModel.CatwalkCfg.DefaultMaxTokens
736 }
737
738 agent := fantasy.NewAgent(a.smallModel.Model,
739 fantasy.WithSystemPrompt(string(titlePrompt)+"\n /no_think"),
740 fantasy.WithMaxOutputTokens(maxOutput),
741 )
742
743 resp, err := agent.Stream(ctx, fantasy.AgentStreamCall{
744 Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", prompt),
745 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
746 prepared.Messages = options.Messages
747 if a.systemPromptPrefix != "" {
748 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
749 }
750 return callContext, prepared, nil
751 },
752 })
753 if err != nil {
754 slog.Error("error generating title", "err", err)
755 return
756 }
757
758 title := resp.Response.Content.Text()
759
760 title = strings.ReplaceAll(title, "\n", " ")
761
762 // Remove thinking tags if present.
763 if idx := strings.Index(title, "</think>"); idx > 0 {
764 title = title[idx+len("</think>"):]
765 }
766
767 title = strings.TrimSpace(title)
768 if title == "" {
769 slog.Warn("failed to generate title", "warn", "empty title")
770 return
771 }
772
773 // Calculate usage and cost.
774 var openrouterCost *float64
775 for _, step := range resp.Steps {
776 stepCost := a.openrouterCost(step.ProviderMetadata)
777 if stepCost != nil {
778 newCost := *stepCost
779 if openrouterCost != nil {
780 newCost += *openrouterCost
781 }
782 openrouterCost = &newCost
783 }
784 }
785
786 modelConfig := a.smallModel.CatwalkCfg
787 cost := modelConfig.CostPer1MInCached/1e6*float64(resp.TotalUsage.CacheCreationTokens) +
788 modelConfig.CostPer1MOutCached/1e6*float64(resp.TotalUsage.CacheReadTokens) +
789 modelConfig.CostPer1MIn/1e6*float64(resp.TotalUsage.InputTokens) +
790 modelConfig.CostPer1MOut/1e6*float64(resp.TotalUsage.OutputTokens)
791
792 if a.isClaudeCode() {
793 cost = 0
794 }
795
796 // Use override cost if available (e.g., from OpenRouter).
797 if openrouterCost != nil {
798 cost = *openrouterCost
799 }
800
801 promptTokens := resp.TotalUsage.InputTokens + resp.TotalUsage.CacheCreationTokens
802 completionTokens := resp.TotalUsage.OutputTokens + resp.TotalUsage.CacheReadTokens
803
804 // Atomically update only title and usage fields to avoid overriding other
805 // concurrent session updates.
806 saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, title, promptTokens, completionTokens, cost)
807 if saveErr != nil {
808 slog.Error("failed to save session title & usage", "error", saveErr)
809 return
810 }
811}
812
813func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
814 openrouterMetadata, ok := metadata[openrouter.Name]
815 if !ok {
816 return nil
817 }
818
819 opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
820 if !ok {
821 return nil
822 }
823 return &opts.Usage.Cost
824}
825
826func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
827 modelConfig := model.CatwalkCfg
828 cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
829 modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
830 modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
831 modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
832
833 if a.isClaudeCode() {
834 cost = 0
835 }
836
837 a.eventTokensUsed(session.ID, model, usage, cost)
838
839 if overrideCost != nil {
840 session.Cost += *overrideCost
841 } else {
842 session.Cost += cost
843 }
844
845 session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
846 session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
847}
848
849func (a *sessionAgent) Cancel(sessionID string) {
850 // Cancel regular requests.
851 if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
852 slog.Info("Request cancellation initiated", "session_id", sessionID)
853 cancel()
854 }
855
856 // Also check for summarize requests.
857 if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
858 slog.Info("Summarize cancellation initiated", "session_id", sessionID)
859 cancel()
860 }
861
862 if a.QueuedPrompts(sessionID) > 0 {
863 slog.Info("Clearing queued prompts", "session_id", sessionID)
864 a.messageQueue.Del(sessionID)
865 }
866}
867
868func (a *sessionAgent) ClearQueue(sessionID string) {
869 if a.QueuedPrompts(sessionID) > 0 {
870 slog.Info("Clearing queued prompts", "session_id", sessionID)
871 a.messageQueue.Del(sessionID)
872 }
873}
874
875func (a *sessionAgent) CancelAll() {
876 if !a.IsBusy() {
877 return
878 }
879 for key := range a.activeRequests.Seq2() {
880 a.Cancel(key) // key is sessionID
881 }
882
883 timeout := time.After(5 * time.Second)
884 for a.IsBusy() {
885 select {
886 case <-timeout:
887 return
888 default:
889 time.Sleep(200 * time.Millisecond)
890 }
891 }
892}
893
894func (a *sessionAgent) IsBusy() bool {
895 var busy bool
896 for cancelFunc := range a.activeRequests.Seq() {
897 if cancelFunc != nil {
898 busy = true
899 break
900 }
901 }
902 return busy
903}
904
905func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
906 _, busy := a.activeRequests.Get(sessionID)
907 return busy
908}
909
910func (a *sessionAgent) QueuedPrompts(sessionID string) int {
911 l, ok := a.messageQueue.Get(sessionID)
912 if !ok {
913 return 0
914 }
915 return len(l)
916}
917
918func (a *sessionAgent) QueuedPromptsList(sessionID string) []string {
919 l, ok := a.messageQueue.Get(sessionID)
920 if !ok {
921 return nil
922 }
923 prompts := make([]string, len(l))
924 for i, call := range l {
925 prompts[i] = call.Prompt
926 }
927 return prompts
928}
929
930func (a *sessionAgent) SetModels(large Model, small Model) {
931 a.largeModel = large
932 a.smallModel = small
933}
934
// SetTools replaces the agent's tool set. The slice is stored as given, not
// copied.
func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
	a.tools = tools
}
938
// Model returns the agent's large model.
func (a *sessionAgent) Model() Model {
	return a.largeModel
}
942
943func (a *sessionAgent) promptPrefix() string {
944 if a.isClaudeCode() {
945 return "You are Claude Code, Anthropic's official CLI for Claude."
946 }
947 return a.systemPromptPrefix
948}
949
950func (a *sessionAgent) isClaudeCode() bool {
951 cfg := config.Get()
952 pc, ok := cfg.Providers.Get(a.largeModel.ModelCfg.Provider)
953 return ok && pc.ID == string(catwalk.InferenceProviderAnthropic) && pc.OAuthToken != nil
954}
955
956// convertToToolResult converts a fantasy tool result to a message tool result.
957func (a *sessionAgent) convertToToolResult(result fantasy.ToolResultContent) message.ToolResult {
958 baseResult := message.ToolResult{
959 ToolCallID: result.ToolCallID,
960 Name: result.ToolName,
961 Metadata: result.ClientMetadata,
962 }
963
964 switch result.Result.GetType() {
965 case fantasy.ToolResultContentTypeText:
966 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result); ok {
967 baseResult.Content = r.Text
968 }
969 case fantasy.ToolResultContentTypeError:
970 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result); ok {
971 baseResult.Content = r.Error.Error()
972 baseResult.IsError = true
973 }
974 case fantasy.ToolResultContentTypeMedia:
975 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](result.Result); ok {
976 content := r.Text
977 if content == "" {
978 content = fmt.Sprintf("Loaded %s content", r.MediaType)
979 }
980 baseResult.Content = content
981 baseResult.Data = r.Data
982 baseResult.MIMEType = r.MediaType
983 }
984 }
985
986 return baseResult
987}
988
989// workaroundProviderMediaLimitations converts media content in tool results to
990// user messages for providers that don't natively support images in tool results.
991//
992// Problem: OpenAI, Google, OpenRouter, and other OpenAI-compatible providers
993// don't support sending images/media in tool result messages - they only accept
994// text in tool results. However, they DO support images in user messages.
995//
996// If we send media in tool results to these providers, the API returns an error.
997//
998// Solution: For these providers, we:
999// 1. Replace the media in the tool result with a text placeholder
1000// 2. Inject a user message immediately after with the image as a file attachment
1001// 3. This maintains the tool execution flow while working around API limitations
1002//
1003// Anthropic and Bedrock support images natively in tool results, so we skip
1004// this workaround for them.
1005//
1006// Example transformation:
1007//
1008// BEFORE: [tool result: image data]
1009// AFTER: [tool result: "Image loaded - see attached"], [user: image attachment]
1010func (a *sessionAgent) workaroundProviderMediaLimitations(messages []fantasy.Message) []fantasy.Message {
1011 providerSupportsMedia := a.largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderAnthropic) ||
1012 a.largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderBedrock)
1013
1014 if providerSupportsMedia {
1015 return messages
1016 }
1017
1018 convertedMessages := make([]fantasy.Message, 0, len(messages))
1019
1020 for _, msg := range messages {
1021 if msg.Role != fantasy.MessageRoleTool {
1022 convertedMessages = append(convertedMessages, msg)
1023 continue
1024 }
1025
1026 textParts := make([]fantasy.MessagePart, 0, len(msg.Content))
1027 var mediaFiles []fantasy.FilePart
1028
1029 for _, part := range msg.Content {
1030 toolResult, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
1031 if !ok {
1032 textParts = append(textParts, part)
1033 continue
1034 }
1035
1036 if media, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](toolResult.Output); ok {
1037 decoded, err := base64.StdEncoding.DecodeString(media.Data)
1038 if err != nil {
1039 slog.Warn("failed to decode media data", "error", err)
1040 textParts = append(textParts, part)
1041 continue
1042 }
1043
1044 mediaFiles = append(mediaFiles, fantasy.FilePart{
1045 Data: decoded,
1046 MediaType: media.MediaType,
1047 Filename: fmt.Sprintf("tool-result-%s", toolResult.ToolCallID),
1048 })
1049
1050 textParts = append(textParts, fantasy.ToolResultPart{
1051 ToolCallID: toolResult.ToolCallID,
1052 Output: fantasy.ToolResultOutputContentText{
1053 Text: "[Image/media content loaded - see attached file]",
1054 },
1055 ProviderOptions: toolResult.ProviderOptions,
1056 })
1057 } else {
1058 textParts = append(textParts, part)
1059 }
1060 }
1061
1062 convertedMessages = append(convertedMessages, fantasy.Message{
1063 Role: fantasy.MessageRoleTool,
1064 Content: textParts,
1065 })
1066
1067 if len(mediaFiles) > 0 {
1068 convertedMessages = append(convertedMessages, fantasy.NewUserMessage(
1069 "Here is the media content from the tool result:",
1070 mediaFiles...,
1071 ))
1072 }
1073 }
1074
1075 return convertedMessages
1076}