1// Package agent is the core orchestration layer for Crush AI agents.
2//
3// It provides session-based AI agent functionality for managing
4// conversations, tool execution, and message handling. It coordinates
5// interactions between language models, messages, sessions, and tools while
6// handling features like automatic summarization, queuing, and token
7// management.
8package agent
9
10import (
11 "cmp"
12 "context"
13 _ "embed"
14 "encoding/base64"
15 "errors"
16 "fmt"
17 "log/slog"
18 "os"
19 "strconv"
20 "strings"
21 "sync"
22 "time"
23
24 "charm.land/fantasy"
25 "charm.land/fantasy/providers/anthropic"
26 "charm.land/fantasy/providers/bedrock"
27 "charm.land/fantasy/providers/google"
28 "charm.land/fantasy/providers/openai"
29 "charm.land/fantasy/providers/openrouter"
30 "charm.land/lipgloss/v2"
31 "github.com/charmbracelet/catwalk/pkg/catwalk"
32 "github.com/charmbracelet/crush/internal/agent/hyper"
33 "github.com/charmbracelet/crush/internal/agent/tools"
34 "github.com/charmbracelet/crush/internal/config"
35 "github.com/charmbracelet/crush/internal/csync"
36 "github.com/charmbracelet/crush/internal/message"
37 "github.com/charmbracelet/crush/internal/permission"
38 "github.com/charmbracelet/crush/internal/session"
39 "github.com/charmbracelet/crush/internal/stringext"
40)
41
// titlePrompt holds the contents of templates/title.md. It is used as the
// system prompt when generating a session title.
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt holds the contents of templates/summary.md. It is used as the
// system prompt when summarizing a session.
//
//go:embed templates/summary.md
var summaryPrompt []byte
47
// SessionAgentCall describes a single prompt submitted to a session agent.
//
// SessionID and Prompt are required by Run. The sampling fields
// (MaxOutputTokens, Temperature, TopP, TopK, FrequencyPenalty and
// PresencePenalty) are forwarded to the underlying stream call.
type SessionAgentCall struct {
	// SessionID identifies the session the prompt belongs to.
	SessionID string
	// Prompt is the user's text prompt.
	Prompt string
	// ProviderOptions are passed through to the model call.
	ProviderOptions fantasy.ProviderOptions
	// Attachments are stored as binary parts of the user message and sent to
	// the model as files.
	Attachments     []message.Attachment
	MaxOutputTokens int64
	Temperature     *float64
	TopP            *float64
	TopK            *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
}
60
// SessionAgent is the interface exposed by a session-based agent.
type SessionAgent interface {
	// Run submits a prompt for a session. If the session already has an
	// active request the prompt is queued instead.
	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
	// SetModels replaces the large and small models.
	SetModels(large Model, small Model)
	// SetTools replaces the tools available to the agent.
	SetTools(tools []fantasy.AgentTool)
	// Cancel cancels the active request of a session and drops its queued
	// prompts.
	Cancel(sessionID string)
	// CancelAll cancels the active requests of every session.
	CancelAll()
	// IsSessionBusy reports whether the session has an active request.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any session has an active request.
	IsBusy() bool
	// QueuedPrompts returns the number of prompts queued for the session.
	QueuedPrompts(sessionID string) int
	// QueuedPromptsList returns the text of the prompts queued for the
	// session.
	QueuedPromptsList(sessionID string) []string
	// ClearQueue drops the prompts queued for the session.
	ClearQueue(sessionID string)
	// Summarize replaces the session's history with a generated summary
	// message.
	Summarize(context.Context, string, fantasy.ProviderOptions) error
	// Model returns the large model.
	Model() Model
}
75
// Model bundles a language model with its catalog and configuration
// metadata.
type Model struct {
	// Model is the language model used to create agents.
	Model fantasy.LanguageModel
	// CatwalkCfg carries catalog data read by this package, such as pricing,
	// context window size and capability flags.
	CatwalkCfg catwalk.Model
	// ModelCfg is the selected model configuration (model and provider IDs).
	ModelCfg config.SelectedModel
}
81
// sessionAgent is the SessionAgent implementation returned by
// NewSessionAgent.
type sessionAgent struct {
	// largeModel handles prompts and summaries; smallModel generates titles.
	largeModel Model
	smallModel Model
	// systemPromptPrefix, when non-empty, is prepended as an extra system
	// message.
	systemPromptPrefix string
	systemPrompt       string
	// isSubAgent disables the todo-list reminder added to the history.
	isSubAgent bool
	tools      []fantasy.AgentTool
	sessions   session.Service
	messages   message.Service
	// disableAutoSummarize turns off the context-window stop condition in
	// Run.
	disableAutoSummarize bool
	isYolo               bool

	// messageQueue holds prompts received while a session was busy, keyed by
	// session ID.
	messageQueue *csync.Map[string, []SessionAgentCall]
	// activeRequests holds the cancel function of each in-flight request,
	// keyed by session ID. Presence of a key marks the session as busy.
	activeRequests *csync.Map[string, context.CancelFunc]
}
97
// SessionAgentOptions holds the settings NewSessionAgent copies into a new
// session agent.
type SessionAgentOptions struct {
	LargeModel           Model
	SmallModel           Model
	SystemPromptPrefix   string
	SystemPrompt         string
	IsSubAgent           bool
	DisableAutoSummarize bool
	IsYolo               bool
	Sessions             session.Service
	Messages             message.Service
	Tools                []fantasy.AgentTool
}
110
111func NewSessionAgent(
112 opts SessionAgentOptions,
113) SessionAgent {
114 return &sessionAgent{
115 largeModel: opts.LargeModel,
116 smallModel: opts.SmallModel,
117 systemPromptPrefix: opts.SystemPromptPrefix,
118 systemPrompt: opts.SystemPrompt,
119 isSubAgent: opts.IsSubAgent,
120 sessions: opts.Sessions,
121 messages: opts.Messages,
122 disableAutoSummarize: opts.DisableAutoSummarize,
123 tools: opts.Tools,
124 isYolo: opts.IsYolo,
125 messageQueue: csync.NewMap[string, []SessionAgentCall](),
126 activeRequests: csync.NewMap[string, context.CancelFunc](),
127 }
128}
129
130func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
131 if call.Prompt == "" {
132 return nil, ErrEmptyPrompt
133 }
134 if call.SessionID == "" {
135 return nil, ErrSessionMissing
136 }
137
138 // Queue the message if busy
139 if a.IsSessionBusy(call.SessionID) {
140 existing, ok := a.messageQueue.Get(call.SessionID)
141 if !ok {
142 existing = []SessionAgentCall{}
143 }
144 existing = append(existing, call)
145 a.messageQueue.Set(call.SessionID, existing)
146 return nil, nil
147 }
148
149 if len(a.tools) > 0 {
150 // Add Anthropic caching to the last tool.
151 a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
152 }
153
154 agent := fantasy.NewAgent(
155 a.largeModel.Model,
156 fantasy.WithSystemPrompt(a.systemPrompt),
157 fantasy.WithTools(a.tools...),
158 )
159
160 sessionLock := sync.Mutex{}
161 currentSession, err := a.sessions.Get(ctx, call.SessionID)
162 if err != nil {
163 return nil, fmt.Errorf("failed to get session: %w", err)
164 }
165
166 msgs, err := a.getSessionMessages(ctx, currentSession)
167 if err != nil {
168 return nil, fmt.Errorf("failed to get session messages: %w", err)
169 }
170
171 var wg sync.WaitGroup
172 // Generate title if first message.
173 if len(msgs) == 0 {
174 titleCtx := ctx // Copy to avoid race with ctx reassignment below.
175 wg.Go(func() {
176 a.generateTitle(titleCtx, call.SessionID, call.Prompt)
177 })
178 }
179
180 // Add the user message to the session.
181 _, err = a.createUserMessage(ctx, call)
182 if err != nil {
183 return nil, err
184 }
185
186 // Add the session to the context.
187 ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)
188
189 genCtx, cancel := context.WithCancel(ctx)
190 a.activeRequests.Set(call.SessionID, cancel)
191
192 defer cancel()
193 defer a.activeRequests.Del(call.SessionID)
194
195 history, files := a.preparePrompt(msgs, call.Attachments...)
196
197 startTime := time.Now()
198 a.eventPromptSent(call.SessionID)
199
200 var currentAssistant *message.Message
201 var shouldSummarize bool
202 result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
203 Prompt: call.Prompt,
204 Files: files,
205 Messages: history,
206 ProviderOptions: call.ProviderOptions,
207 MaxOutputTokens: &call.MaxOutputTokens,
208 TopP: call.TopP,
209 Temperature: call.Temperature,
210 PresencePenalty: call.PresencePenalty,
211 TopK: call.TopK,
212 FrequencyPenalty: call.FrequencyPenalty,
213 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
214 prepared.Messages = options.Messages
215 for i := range prepared.Messages {
216 prepared.Messages[i].ProviderOptions = nil
217 }
218
219 queuedCalls, _ := a.messageQueue.Get(call.SessionID)
220 a.messageQueue.Del(call.SessionID)
221 for _, queued := range queuedCalls {
222 userMessage, createErr := a.createUserMessage(callContext, queued)
223 if createErr != nil {
224 return callContext, prepared, createErr
225 }
226 prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
227 }
228
229 prepared.Messages = a.workaroundProviderMediaLimitations(prepared.Messages)
230
231 lastSystemRoleInx := 0
232 systemMessageUpdated := false
233 for i, msg := range prepared.Messages {
234 // Only add cache control to the last message.
235 if msg.Role == fantasy.MessageRoleSystem {
236 lastSystemRoleInx = i
237 } else if !systemMessageUpdated {
238 prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
239 systemMessageUpdated = true
240 }
241 // Than add cache control to the last 2 messages.
242 if i > len(prepared.Messages)-3 {
243 prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
244 }
245 }
246
247 if promptPrefix := a.promptPrefix(); promptPrefix != "" {
248 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(promptPrefix)}, prepared.Messages...)
249 }
250
251 var assistantMsg message.Message
252 assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
253 Role: message.Assistant,
254 Parts: []message.ContentPart{},
255 Model: a.largeModel.ModelCfg.Model,
256 Provider: a.largeModel.ModelCfg.Provider,
257 })
258 if err != nil {
259 return callContext, prepared, err
260 }
261 callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
262 callContext = context.WithValue(callContext, tools.SupportsImagesContextKey, a.largeModel.CatwalkCfg.SupportsImages)
263 callContext = context.WithValue(callContext, tools.ModelNameContextKey, a.largeModel.CatwalkCfg.Name)
264 currentAssistant = &assistantMsg
265 return callContext, prepared, err
266 },
267 OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
268 currentAssistant.AppendReasoningContent(reasoning.Text)
269 return a.messages.Update(genCtx, *currentAssistant)
270 },
271 OnReasoningDelta: func(id string, text string) error {
272 currentAssistant.AppendReasoningContent(text)
273 return a.messages.Update(genCtx, *currentAssistant)
274 },
275 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
276 // handle anthropic signature
277 if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
278 if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
279 currentAssistant.AppendReasoningSignature(reasoning.Signature)
280 }
281 }
282 if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
283 if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
284 currentAssistant.AppendThoughtSignature(reasoning.Signature, reasoning.ToolID)
285 }
286 }
287 if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
288 if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
289 currentAssistant.SetReasoningResponsesData(reasoning)
290 }
291 }
292 currentAssistant.FinishThinking()
293 return a.messages.Update(genCtx, *currentAssistant)
294 },
295 OnTextDelta: func(id string, text string) error {
296 // Strip leading newline from initial text content. This is is
297 // particularly important in non-interactive mode where leading
298 // newlines are very visible.
299 if len(currentAssistant.Parts) == 0 {
300 text = strings.TrimPrefix(text, "\n")
301 }
302
303 currentAssistant.AppendContent(text)
304 return a.messages.Update(genCtx, *currentAssistant)
305 },
306 OnToolInputStart: func(id string, toolName string) error {
307 toolCall := message.ToolCall{
308 ID: id,
309 Name: toolName,
310 ProviderExecuted: false,
311 Finished: false,
312 }
313 currentAssistant.AddToolCall(toolCall)
314 return a.messages.Update(genCtx, *currentAssistant)
315 },
316 OnRetry: func(err *fantasy.ProviderError, delay time.Duration) {
317 // TODO: implement
318 },
319 OnToolCall: func(tc fantasy.ToolCallContent) error {
320 toolCall := message.ToolCall{
321 ID: tc.ToolCallID,
322 Name: tc.ToolName,
323 Input: tc.Input,
324 ProviderExecuted: false,
325 Finished: true,
326 }
327 currentAssistant.AddToolCall(toolCall)
328 return a.messages.Update(genCtx, *currentAssistant)
329 },
330 OnToolResult: func(result fantasy.ToolResultContent) error {
331 toolResult := a.convertToToolResult(result)
332 _, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
333 Role: message.Tool,
334 Parts: []message.ContentPart{
335 toolResult,
336 },
337 })
338 return createMsgErr
339 },
340 OnStepFinish: func(stepResult fantasy.StepResult) error {
341 finishReason := message.FinishReasonUnknown
342 switch stepResult.FinishReason {
343 case fantasy.FinishReasonLength:
344 finishReason = message.FinishReasonMaxTokens
345 case fantasy.FinishReasonStop:
346 finishReason = message.FinishReasonEndTurn
347 case fantasy.FinishReasonToolCalls:
348 finishReason = message.FinishReasonToolUse
349 }
350 currentAssistant.AddFinish(finishReason, "", "")
351 sessionLock.Lock()
352 updatedSession, getSessionErr := a.sessions.Get(genCtx, call.SessionID)
353 if getSessionErr != nil {
354 sessionLock.Unlock()
355 return getSessionErr
356 }
357 a.updateSessionUsage(a.largeModel, &updatedSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
358 _, sessionErr := a.sessions.Save(genCtx, updatedSession)
359 sessionLock.Unlock()
360 if sessionErr != nil {
361 return sessionErr
362 }
363 return a.messages.Update(genCtx, *currentAssistant)
364 },
365 StopWhen: []fantasy.StopCondition{
366 func(_ []fantasy.StepResult) bool {
367 cw := int64(a.largeModel.CatwalkCfg.ContextWindow)
368 tokens := currentSession.CompletionTokens + currentSession.PromptTokens
369 remaining := cw - tokens
370 var threshold int64
371 if cw > 200_000 {
372 threshold = 20_000
373 } else {
374 threshold = int64(float64(cw) * 0.2)
375 }
376 if (remaining <= threshold) && !a.disableAutoSummarize {
377 shouldSummarize = true
378 return true
379 }
380 return false
381 },
382 },
383 })
384
385 a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))
386
387 if err != nil {
388 isCancelErr := errors.Is(err, context.Canceled)
389 isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
390 if currentAssistant == nil {
391 return result, err
392 }
393 // Ensure we finish thinking on error to close the reasoning state.
394 currentAssistant.FinishThinking()
395 toolCalls := currentAssistant.ToolCalls()
396 // INFO: we use the parent context here because the genCtx has been cancelled.
397 msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
398 if createErr != nil {
399 return nil, createErr
400 }
401 for _, tc := range toolCalls {
402 if !tc.Finished {
403 tc.Finished = true
404 tc.Input = "{}"
405 currentAssistant.AddToolCall(tc)
406 updateErr := a.messages.Update(ctx, *currentAssistant)
407 if updateErr != nil {
408 return nil, updateErr
409 }
410 }
411
412 found := false
413 for _, msg := range msgs {
414 if msg.Role == message.Tool {
415 for _, tr := range msg.ToolResults() {
416 if tr.ToolCallID == tc.ID {
417 found = true
418 break
419 }
420 }
421 }
422 if found {
423 break
424 }
425 }
426 if found {
427 continue
428 }
429 content := "There was an error while executing the tool"
430 if isCancelErr {
431 content = "Tool execution canceled by user"
432 } else if isPermissionErr {
433 content = "User denied permission"
434 }
435 toolResult := message.ToolResult{
436 ToolCallID: tc.ID,
437 Name: tc.Name,
438 Content: content,
439 IsError: true,
440 }
441 _, createErr = a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
442 Role: message.Tool,
443 Parts: []message.ContentPart{
444 toolResult,
445 },
446 })
447 if createErr != nil {
448 return nil, createErr
449 }
450 }
451 var fantasyErr *fantasy.Error
452 var providerErr *fantasy.ProviderError
453 const defaultTitle = "Provider Error"
454 if isCancelErr {
455 currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
456 } else if isPermissionErr {
457 currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "")
458 } else if errors.Is(err, hyper.ErrNoCredits) {
459 url := hyper.BaseURL()
460 link := lipgloss.NewStyle().Hyperlink(url, "id=hyper").Render(url)
461 currentAssistant.AddFinish(message.FinishReasonError, "No credits", "You're out of credits. Add more at "+link)
462 } else if errors.As(err, &providerErr) {
463 if providerErr.Message == "The requested model is not supported." {
464 url := "https://github.com/settings/copilot/features"
465 link := lipgloss.NewStyle().Hyperlink(url, "id=hyper").Render(url)
466 currentAssistant.AddFinish(
467 message.FinishReasonError,
468 "Copilot model not enabled",
469 fmt.Sprintf("%q is not enabled in Copilot. Go to the following page to enable it. Then, wait a minute before trying again. %s", a.largeModel.CatwalkCfg.Name, link),
470 )
471 } else {
472 currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message)
473 }
474 } else if errors.As(err, &fantasyErr) {
475 currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(fantasyErr.Title), defaultTitle), fantasyErr.Message)
476 } else {
477 currentAssistant.AddFinish(message.FinishReasonError, defaultTitle, err.Error())
478 }
479 // Note: we use the parent context here because the genCtx has been
480 // cancelled.
481 updateErr := a.messages.Update(ctx, *currentAssistant)
482 if updateErr != nil {
483 return nil, updateErr
484 }
485 return nil, err
486 }
487 wg.Wait()
488
489 if shouldSummarize {
490 a.activeRequests.Del(call.SessionID)
491 if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
492 return nil, summarizeErr
493 }
494 // If the agent wasn't done...
495 if len(currentAssistant.ToolCalls()) > 0 {
496 existing, ok := a.messageQueue.Get(call.SessionID)
497 if !ok {
498 existing = []SessionAgentCall{}
499 }
500 call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
501 existing = append(existing, call)
502 a.messageQueue.Set(call.SessionID, existing)
503 }
504 }
505
506 // Release active request before processing queued messages.
507 a.activeRequests.Del(call.SessionID)
508 cancel()
509
510 queuedMessages, ok := a.messageQueue.Get(call.SessionID)
511 if !ok || len(queuedMessages) == 0 {
512 return result, err
513 }
514 // There are queued messages restart the loop.
515 firstQueuedMessage := queuedMessages[0]
516 a.messageQueue.Set(call.SessionID, queuedMessages[1:])
517 return a.Run(ctx, firstQueuedMessage)
518}
519
520func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
521 if a.IsSessionBusy(sessionID) {
522 return ErrSessionBusy
523 }
524
525 currentSession, err := a.sessions.Get(ctx, sessionID)
526 if err != nil {
527 return fmt.Errorf("failed to get session: %w", err)
528 }
529 msgs, err := a.getSessionMessages(ctx, currentSession)
530 if err != nil {
531 return err
532 }
533 if len(msgs) == 0 {
534 // Nothing to summarize.
535 return nil
536 }
537
538 aiMsgs, _ := a.preparePrompt(msgs)
539
540 genCtx, cancel := context.WithCancel(ctx)
541 a.activeRequests.Set(sessionID, cancel)
542 defer a.activeRequests.Del(sessionID)
543 defer cancel()
544
545 agent := fantasy.NewAgent(a.largeModel.Model,
546 fantasy.WithSystemPrompt(string(summaryPrompt)),
547 )
548 summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
549 Role: message.Assistant,
550 Model: a.largeModel.Model.Model(),
551 Provider: a.largeModel.Model.Provider(),
552 IsSummaryMessage: true,
553 })
554 if err != nil {
555 return err
556 }
557
558 summaryPromptText := "Provide a detailed summary of our conversation above."
559 if len(currentSession.Todos) > 0 {
560 summaryPromptText += "\n\n## Current Todo List\n\n"
561 for _, t := range currentSession.Todos {
562 summaryPromptText += fmt.Sprintf("- [%s] %s\n", t.Status, t.Content)
563 }
564 summaryPromptText += "\nInclude these tasks and their statuses in your summary. "
565 summaryPromptText += "Instruct the resuming assistant to use the `todos` tool to continue tracking progress on these tasks."
566 }
567
568 resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
569 Prompt: summaryPromptText,
570 Messages: aiMsgs,
571 ProviderOptions: opts,
572 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
573 prepared.Messages = options.Messages
574 if a.systemPromptPrefix != "" {
575 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
576 }
577 return callContext, prepared, nil
578 },
579 OnReasoningDelta: func(id string, text string) error {
580 summaryMessage.AppendReasoningContent(text)
581 return a.messages.Update(genCtx, summaryMessage)
582 },
583 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
584 // Handle anthropic signature.
585 if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
586 if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
587 summaryMessage.AppendReasoningSignature(signature.Signature)
588 }
589 }
590 summaryMessage.FinishThinking()
591 return a.messages.Update(genCtx, summaryMessage)
592 },
593 OnTextDelta: func(id, text string) error {
594 summaryMessage.AppendContent(text)
595 return a.messages.Update(genCtx, summaryMessage)
596 },
597 })
598 if err != nil {
599 isCancelErr := errors.Is(err, context.Canceled)
600 if isCancelErr {
601 // User cancelled summarize we need to remove the summary message.
602 deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
603 return deleteErr
604 }
605 return err
606 }
607
608 summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
609 err = a.messages.Update(genCtx, summaryMessage)
610 if err != nil {
611 return err
612 }
613
614 var openrouterCost *float64
615 for _, step := range resp.Steps {
616 stepCost := a.openrouterCost(step.ProviderMetadata)
617 if stepCost != nil {
618 newCost := *stepCost
619 if openrouterCost != nil {
620 newCost += *openrouterCost
621 }
622 openrouterCost = &newCost
623 }
624 }
625
626 a.updateSessionUsage(a.largeModel, ¤tSession, resp.TotalUsage, openrouterCost)
627
628 // Just in case, get just the last usage info.
629 usage := resp.Response.Usage
630 currentSession.SummaryMessageID = summaryMessage.ID
631 currentSession.CompletionTokens = usage.OutputTokens
632 currentSession.PromptTokens = 0
633 _, err = a.sessions.Save(genCtx, currentSession)
634 return err
635}
636
637func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
638 if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
639 return fantasy.ProviderOptions{}
640 }
641 return fantasy.ProviderOptions{
642 anthropic.Name: &anthropic.ProviderCacheControlOptions{
643 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
644 },
645 bedrock.Name: &anthropic.ProviderCacheControlOptions{
646 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
647 },
648 }
649}
650
651func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
652 var attachmentParts []message.ContentPart
653 for _, attachment := range call.Attachments {
654 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
655 }
656 parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
657 parts = append(parts, attachmentParts...)
658 msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
659 Role: message.User,
660 Parts: parts,
661 })
662 if err != nil {
663 return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
664 }
665 return msg, nil
666}
667
668func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
669 var history []fantasy.Message
670 if !a.isSubAgent {
671 history = append(history, fantasy.NewUserMessage(
672 fmt.Sprintf("<system_reminder>%s</system_reminder>",
673 `This is a reminder that your todo list is currently empty. DO NOT mention this to the user explicitly because they are already aware.
674If you are working on tasks that would benefit from a todo list please use the "todos" tool to create one.
675If not, please feel free to ignore. Again do not mention this message to the user.`,
676 ),
677 ))
678 }
679 for _, m := range msgs {
680 if len(m.Parts) == 0 {
681 continue
682 }
683 // Assistant message without content or tool calls (cancelled before it
684 // returned anything).
685 if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
686 continue
687 }
688 history = append(history, m.ToAIMessage()...)
689 }
690
691 var files []fantasy.FilePart
692 for _, attachment := range attachments {
693 files = append(files, fantasy.FilePart{
694 Filename: attachment.FileName,
695 Data: attachment.Content,
696 MediaType: attachment.MimeType,
697 })
698 }
699
700 return history, files
701}
702
703func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
704 msgs, err := a.messages.List(ctx, session.ID)
705 if err != nil {
706 return nil, fmt.Errorf("failed to list messages: %w", err)
707 }
708
709 if session.SummaryMessageID != "" {
710 summaryMsgInex := -1
711 for i, msg := range msgs {
712 if msg.ID == session.SummaryMessageID {
713 summaryMsgInex = i
714 break
715 }
716 }
717 if summaryMsgInex != -1 {
718 msgs = msgs[summaryMsgInex:]
719 msgs[0].Role = message.User
720 }
721 }
722 return msgs, nil
723}
724
725func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, prompt string) {
726 if prompt == "" {
727 return
728 }
729
730 var maxOutput int64 = 40
731 if a.smallModel.CatwalkCfg.CanReason {
732 maxOutput = a.smallModel.CatwalkCfg.DefaultMaxTokens
733 }
734
735 agent := fantasy.NewAgent(a.smallModel.Model,
736 fantasy.WithSystemPrompt(string(titlePrompt)+"\n /no_think"),
737 fantasy.WithMaxOutputTokens(maxOutput),
738 )
739
740 resp, err := agent.Stream(ctx, fantasy.AgentStreamCall{
741 Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", prompt),
742 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
743 prepared.Messages = options.Messages
744 if a.systemPromptPrefix != "" {
745 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
746 }
747 return callContext, prepared, nil
748 },
749 })
750 if err != nil {
751 slog.Error("error generating title", "err", err)
752 return
753 }
754
755 title := resp.Response.Content.Text()
756
757 title = strings.ReplaceAll(title, "\n", " ")
758
759 // Remove thinking tags if present.
760 if idx := strings.Index(title, "</think>"); idx > 0 {
761 title = title[idx+len("</think>"):]
762 }
763
764 title = strings.TrimSpace(title)
765 if title == "" {
766 slog.Warn("failed to generate title", "warn", "empty title")
767 return
768 }
769
770 // Calculate usage and cost.
771 var openrouterCost *float64
772 for _, step := range resp.Steps {
773 stepCost := a.openrouterCost(step.ProviderMetadata)
774 if stepCost != nil {
775 newCost := *stepCost
776 if openrouterCost != nil {
777 newCost += *openrouterCost
778 }
779 openrouterCost = &newCost
780 }
781 }
782
783 modelConfig := a.smallModel.CatwalkCfg
784 cost := modelConfig.CostPer1MInCached/1e6*float64(resp.TotalUsage.CacheCreationTokens) +
785 modelConfig.CostPer1MOutCached/1e6*float64(resp.TotalUsage.CacheReadTokens) +
786 modelConfig.CostPer1MIn/1e6*float64(resp.TotalUsage.InputTokens) +
787 modelConfig.CostPer1MOut/1e6*float64(resp.TotalUsage.OutputTokens)
788
789 if a.isClaudeCode() {
790 cost = 0
791 }
792
793 // Use override cost if available (e.g., from OpenRouter).
794 if openrouterCost != nil {
795 cost = *openrouterCost
796 }
797
798 promptTokens := resp.TotalUsage.InputTokens + resp.TotalUsage.CacheCreationTokens
799 completionTokens := resp.TotalUsage.OutputTokens + resp.TotalUsage.CacheReadTokens
800
801 // Atomically update only title and usage fields to avoid overriding other
802 // concurrent session updates.
803 saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, title, promptTokens, completionTokens, cost)
804 if saveErr != nil {
805 slog.Error("failed to save session title & usage", "error", saveErr)
806 return
807 }
808}
809
810func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
811 openrouterMetadata, ok := metadata[openrouter.Name]
812 if !ok {
813 return nil
814 }
815
816 opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
817 if !ok {
818 return nil
819 }
820 return &opts.Usage.Cost
821}
822
823func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
824 modelConfig := model.CatwalkCfg
825 cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
826 modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
827 modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
828 modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
829
830 if a.isClaudeCode() {
831 cost = 0
832 }
833
834 a.eventTokensUsed(session.ID, model, usage, cost)
835
836 if overrideCost != nil {
837 session.Cost += *overrideCost
838 } else {
839 session.Cost += cost
840 }
841
842 session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
843 session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
844}
845
846func (a *sessionAgent) Cancel(sessionID string) {
847 // Cancel regular requests.
848 if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
849 slog.Info("Request cancellation initiated", "session_id", sessionID)
850 cancel()
851 }
852
853 // Also check for summarize requests.
854 if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
855 slog.Info("Summarize cancellation initiated", "session_id", sessionID)
856 cancel()
857 }
858
859 if a.QueuedPrompts(sessionID) > 0 {
860 slog.Info("Clearing queued prompts", "session_id", sessionID)
861 a.messageQueue.Del(sessionID)
862 }
863}
864
865func (a *sessionAgent) ClearQueue(sessionID string) {
866 if a.QueuedPrompts(sessionID) > 0 {
867 slog.Info("Clearing queued prompts", "session_id", sessionID)
868 a.messageQueue.Del(sessionID)
869 }
870}
871
872func (a *sessionAgent) CancelAll() {
873 if !a.IsBusy() {
874 return
875 }
876 for key := range a.activeRequests.Seq2() {
877 a.Cancel(key) // key is sessionID
878 }
879
880 timeout := time.After(5 * time.Second)
881 for a.IsBusy() {
882 select {
883 case <-timeout:
884 return
885 default:
886 time.Sleep(200 * time.Millisecond)
887 }
888 }
889}
890
891func (a *sessionAgent) IsBusy() bool {
892 var busy bool
893 for cancelFunc := range a.activeRequests.Seq() {
894 if cancelFunc != nil {
895 busy = true
896 break
897 }
898 }
899 return busy
900}
901
902func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
903 _, busy := a.activeRequests.Get(sessionID)
904 return busy
905}
906
907func (a *sessionAgent) QueuedPrompts(sessionID string) int {
908 l, ok := a.messageQueue.Get(sessionID)
909 if !ok {
910 return 0
911 }
912 return len(l)
913}
914
915func (a *sessionAgent) QueuedPromptsList(sessionID string) []string {
916 l, ok := a.messageQueue.Get(sessionID)
917 if !ok {
918 return nil
919 }
920 prompts := make([]string, len(l))
921 for i, call := range l {
922 prompts[i] = call.Prompt
923 }
924 return prompts
925}
926
927func (a *sessionAgent) SetModels(large Model, small Model) {
928 a.largeModel = large
929 a.smallModel = small
930}
931
// SetTools replaces the agent's tools. The slice is stored as given, not
// copied.
func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
	a.tools = tools
}
935
// Model returns the agent's large model.
func (a *sessionAgent) Model() Model {
	return a.largeModel
}
939
940func (a *sessionAgent) promptPrefix() string {
941 if a.isClaudeCode() {
942 return "You are Claude Code, Anthropic's official CLI for Claude."
943 }
944 return a.systemPromptPrefix
945}
946
947func (a *sessionAgent) isClaudeCode() bool {
948 cfg := config.Get()
949 pc, ok := cfg.Providers.Get(a.largeModel.ModelCfg.Provider)
950 return ok && pc.ID == string(catwalk.InferenceProviderAnthropic) && pc.OAuthToken != nil
951}
952
953// convertToToolResult converts a fantasy tool result to a message tool result.
954func (a *sessionAgent) convertToToolResult(result fantasy.ToolResultContent) message.ToolResult {
955 baseResult := message.ToolResult{
956 ToolCallID: result.ToolCallID,
957 Name: result.ToolName,
958 Metadata: result.ClientMetadata,
959 }
960
961 switch result.Result.GetType() {
962 case fantasy.ToolResultContentTypeText:
963 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result); ok {
964 baseResult.Content = r.Text
965 }
966 case fantasy.ToolResultContentTypeError:
967 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result); ok {
968 baseResult.Content = r.Error.Error()
969 baseResult.IsError = true
970 }
971 case fantasy.ToolResultContentTypeMedia:
972 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](result.Result); ok {
973 content := r.Text
974 if content == "" {
975 content = fmt.Sprintf("Loaded %s content", r.MediaType)
976 }
977 baseResult.Content = content
978 baseResult.Data = r.Data
979 baseResult.MIMEType = r.MediaType
980 }
981 }
982
983 return baseResult
984}
985
986// workaroundProviderMediaLimitations converts media content in tool results to
987// user messages for providers that don't natively support images in tool results.
988//
989// Problem: OpenAI, Google, OpenRouter, and other OpenAI-compatible providers
990// don't support sending images/media in tool result messages - they only accept
991// text in tool results. However, they DO support images in user messages.
992//
993// If we send media in tool results to these providers, the API returns an error.
994//
995// Solution: For these providers, we:
996// 1. Replace the media in the tool result with a text placeholder
997// 2. Inject a user message immediately after with the image as a file attachment
998// 3. This maintains the tool execution flow while working around API limitations
999//
1000// Anthropic and Bedrock support images natively in tool results, so we skip
1001// this workaround for them.
1002//
1003// Example transformation:
1004//
1005// BEFORE: [tool result: image data]
1006// AFTER: [tool result: "Image loaded - see attached"], [user: image attachment]
1007func (a *sessionAgent) workaroundProviderMediaLimitations(messages []fantasy.Message) []fantasy.Message {
1008 providerSupportsMedia := a.largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderAnthropic) ||
1009 a.largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderBedrock)
1010
1011 if providerSupportsMedia {
1012 return messages
1013 }
1014
1015 convertedMessages := make([]fantasy.Message, 0, len(messages))
1016
1017 for _, msg := range messages {
1018 if msg.Role != fantasy.MessageRoleTool {
1019 convertedMessages = append(convertedMessages, msg)
1020 continue
1021 }
1022
1023 textParts := make([]fantasy.MessagePart, 0, len(msg.Content))
1024 var mediaFiles []fantasy.FilePart
1025
1026 for _, part := range msg.Content {
1027 toolResult, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
1028 if !ok {
1029 textParts = append(textParts, part)
1030 continue
1031 }
1032
1033 if media, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](toolResult.Output); ok {
1034 decoded, err := base64.StdEncoding.DecodeString(media.Data)
1035 if err != nil {
1036 slog.Warn("failed to decode media data", "error", err)
1037 textParts = append(textParts, part)
1038 continue
1039 }
1040
1041 mediaFiles = append(mediaFiles, fantasy.FilePart{
1042 Data: decoded,
1043 MediaType: media.MediaType,
1044 Filename: fmt.Sprintf("tool-result-%s", toolResult.ToolCallID),
1045 })
1046
1047 textParts = append(textParts, fantasy.ToolResultPart{
1048 ToolCallID: toolResult.ToolCallID,
1049 Output: fantasy.ToolResultOutputContentText{
1050 Text: "[Image/media content loaded - see attached file]",
1051 },
1052 ProviderOptions: toolResult.ProviderOptions,
1053 })
1054 } else {
1055 textParts = append(textParts, part)
1056 }
1057 }
1058
1059 convertedMessages = append(convertedMessages, fantasy.Message{
1060 Role: fantasy.MessageRoleTool,
1061 Content: textParts,
1062 })
1063
1064 if len(mediaFiles) > 0 {
1065 convertedMessages = append(convertedMessages, fantasy.NewUserMessage(
1066 "Here is the media content from the tool result:",
1067 mediaFiles...,
1068 ))
1069 }
1070 }
1071
1072 return convertedMessages
1073}