// Package agent is the core orchestration layer for Crush AI agents.
//
// It provides session-based AI agent functionality for managing
// conversations, tool execution, and message handling. It coordinates
// interactions between language models, messages, sessions, and tools while
// handling features like automatic summarization, prompt queuing, and token
// usage tracking.
package agent
9
10import (
11 "cmp"
12 "context"
13 _ "embed"
14 "encoding/base64"
15 "errors"
16 "fmt"
17 "log/slog"
18 "os"
19 "strconv"
20 "strings"
21 "sync"
22 "time"
23
24 "charm.land/fantasy"
25 "charm.land/fantasy/providers/anthropic"
26 "charm.land/fantasy/providers/bedrock"
27 "charm.land/fantasy/providers/google"
28 "charm.land/fantasy/providers/openai"
29 "charm.land/fantasy/providers/openrouter"
30 "charm.land/lipgloss/v2"
31 "github.com/charmbracelet/catwalk/pkg/catwalk"
32 "github.com/charmbracelet/crush/internal/agent/hyper"
33 "github.com/charmbracelet/crush/internal/agent/tools"
34 "github.com/charmbracelet/crush/internal/config"
35 "github.com/charmbracelet/crush/internal/csync"
36 "github.com/charmbracelet/crush/internal/message"
37 "github.com/charmbracelet/crush/internal/permission"
38 "github.com/charmbracelet/crush/internal/session"
39 "github.com/charmbracelet/crush/internal/stringext"
40)
41
// titlePrompt is the system prompt used by generateTitle when asking the
// small model for a session title.
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt is the system prompt used by Summarize when condensing a
// session's conversation.
//
//go:embed templates/summary.md
var summaryPrompt []byte
47
// SessionAgentCall describes a single prompt submitted to a SessionAgent.
//
// SessionID and Prompt are required (Run rejects empty values). Attachments
// are stored on the user message and sent to the model as file parts. The
// remaining fields are forwarded unchanged to the model request; the
// pointer-typed sampling parameters are presumably left to the provider's
// default when nil — TODO confirm against fantasy.AgentStreamCall.
type SessionAgentCall struct {
	SessionID        string
	Prompt           string
	ProviderOptions  fantasy.ProviderOptions
	Attachments      []message.Attachment
	MaxOutputTokens  int64
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
}
60
// SessionAgent runs prompts against a language model on behalf of chat
// sessions. In this package's implementation, prompts submitted while a
// session already has an active request are queued and processed later.
type SessionAgent interface {
	// Run processes a prompt for call.SessionID. If the session is busy the
	// call is queued and Run returns (nil, nil).
	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
	// SetModels replaces the large (main conversation) and small (title
	// generation) models.
	SetModels(large Model, small Model)
	// SetTools replaces the tools offered to the model.
	SetTools(tools []fantasy.AgentTool)
	// Cancel aborts the session's active request and drops its queue.
	Cancel(sessionID string)
	// CancelAll cancels every active request and waits briefly for them
	// to stop.
	CancelAll()
	// IsSessionBusy reports whether the session has an active request.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any session has an active request.
	IsBusy() bool
	// QueuedPrompts returns the number of prompts queued for a session.
	QueuedPrompts(sessionID string) int
	// QueuedPromptsList returns the text of the prompts queued for a session.
	QueuedPromptsList(sessionID string) []string
	// ClearQueue discards the prompts queued for a session.
	ClearQueue(sessionID string)
	// Summarize condenses the session's conversation into a summary message.
	Summarize(context.Context, string, fantasy.ProviderOptions) error
	// Model returns the large model currently in use.
	Model() Model
}
75
// Model bundles a language model with its catwalk metadata (context window,
// per-token pricing, capabilities such as image support and reasoning) and
// the user's selected-model configuration (model and provider IDs).
type Model struct {
	Model      fantasy.LanguageModel
	CatwalkCfg catwalk.Model
	ModelCfg   config.SelectedModel
}
81
// sessionAgent is the SessionAgent implementation returned by
// NewSessionAgent.
type sessionAgent struct {
	largeModel           Model
	smallModel           Model
	systemPromptPrefix   string
	systemPrompt         string
	isSubAgent           bool
	tools                []fantasy.AgentTool
	sessions             session.Service
	messages             message.Service
	disableAutoSummarize bool
	isYolo               bool

	// messageQueue holds prompts submitted while a session was busy, keyed
	// by session ID.
	messageQueue *csync.Map[string, []SessionAgentCall]

	// activeRequests maps a session ID to the cancel func of its in-flight
	// request; the presence of a key is what makes a session "busy".
	activeRequests *csync.Map[string, context.CancelFunc]
}
97
// SessionAgentOptions configures NewSessionAgent. Each field is copied into
// the corresponding field of the returned agent.
type SessionAgentOptions struct {
	LargeModel           Model
	SmallModel           Model
	SystemPromptPrefix   string
	SystemPrompt         string
	IsSubAgent           bool
	DisableAutoSummarize bool
	IsYolo               bool
	Sessions             session.Service
	Messages             message.Service
	Tools                []fantasy.AgentTool
}
110
111func NewSessionAgent(
112 opts SessionAgentOptions,
113) SessionAgent {
114 return &sessionAgent{
115 largeModel: opts.LargeModel,
116 smallModel: opts.SmallModel,
117 systemPromptPrefix: opts.SystemPromptPrefix,
118 systemPrompt: opts.SystemPrompt,
119 isSubAgent: opts.IsSubAgent,
120 sessions: opts.Sessions,
121 messages: opts.Messages,
122 disableAutoSummarize: opts.DisableAutoSummarize,
123 tools: opts.Tools,
124 isYolo: opts.IsYolo,
125 messageQueue: csync.NewMap[string, []SessionAgentCall](),
126 activeRequests: csync.NewMap[string, context.CancelFunc](),
127 }
128}
129
130func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
131 if call.Prompt == "" {
132 return nil, ErrEmptyPrompt
133 }
134 if call.SessionID == "" {
135 return nil, ErrSessionMissing
136 }
137
138 // Queue the message if busy
139 if a.IsSessionBusy(call.SessionID) {
140 existing, ok := a.messageQueue.Get(call.SessionID)
141 if !ok {
142 existing = []SessionAgentCall{}
143 }
144 existing = append(existing, call)
145 a.messageQueue.Set(call.SessionID, existing)
146 return nil, nil
147 }
148
149 if len(a.tools) > 0 {
150 // Add Anthropic caching to the last tool.
151 a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
152 }
153
154 agent := fantasy.NewAgent(
155 a.largeModel.Model,
156 fantasy.WithSystemPrompt(a.systemPrompt),
157 fantasy.WithTools(a.tools...),
158 )
159
160 sessionLock := sync.Mutex{}
161 currentSession, err := a.sessions.Get(ctx, call.SessionID)
162 if err != nil {
163 return nil, fmt.Errorf("failed to get session: %w", err)
164 }
165
166 msgs, err := a.getSessionMessages(ctx, currentSession)
167 if err != nil {
168 return nil, fmt.Errorf("failed to get session messages: %w", err)
169 }
170
171 var wg sync.WaitGroup
172 // Generate title if first message.
173 if len(msgs) == 0 {
174 wg.Go(func() {
175 sessionLock.Lock()
176 a.generateTitle(ctx, ¤tSession, call.Prompt)
177 sessionLock.Unlock()
178 })
179 }
180
181 // Add the user message to the session.
182 _, err = a.createUserMessage(ctx, call)
183 if err != nil {
184 return nil, err
185 }
186
187 // Add the session to the context.
188 ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)
189
190 genCtx, cancel := context.WithCancel(ctx)
191 a.activeRequests.Set(call.SessionID, cancel)
192
193 defer cancel()
194 defer a.activeRequests.Del(call.SessionID)
195
196 history, files := a.preparePrompt(msgs, call.Attachments...)
197
198 startTime := time.Now()
199 a.eventPromptSent(call.SessionID)
200
201 var currentAssistant *message.Message
202 var shouldSummarize bool
203 result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
204 Prompt: call.Prompt,
205 Files: files,
206 Messages: history,
207 ProviderOptions: call.ProviderOptions,
208 MaxOutputTokens: &call.MaxOutputTokens,
209 TopP: call.TopP,
210 Temperature: call.Temperature,
211 PresencePenalty: call.PresencePenalty,
212 TopK: call.TopK,
213 FrequencyPenalty: call.FrequencyPenalty,
214 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
215 prepared.Messages = options.Messages
216 for i := range prepared.Messages {
217 prepared.Messages[i].ProviderOptions = nil
218 }
219
220 queuedCalls, _ := a.messageQueue.Get(call.SessionID)
221 a.messageQueue.Del(call.SessionID)
222 for _, queued := range queuedCalls {
223 userMessage, createErr := a.createUserMessage(callContext, queued)
224 if createErr != nil {
225 return callContext, prepared, createErr
226 }
227 prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
228 }
229
230 prepared.Messages = a.workaroundProviderMediaLimitations(prepared.Messages)
231
232 lastSystemRoleInx := 0
233 systemMessageUpdated := false
234 for i, msg := range prepared.Messages {
235 // Only add cache control to the last message.
236 if msg.Role == fantasy.MessageRoleSystem {
237 lastSystemRoleInx = i
238 } else if !systemMessageUpdated {
239 prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
240 systemMessageUpdated = true
241 }
242 // Than add cache control to the last 2 messages.
243 if i > len(prepared.Messages)-3 {
244 prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
245 }
246 }
247
248 if promptPrefix := a.promptPrefix(); promptPrefix != "" {
249 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(promptPrefix)}, prepared.Messages...)
250 }
251
252 var assistantMsg message.Message
253 assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
254 Role: message.Assistant,
255 Parts: []message.ContentPart{},
256 Model: a.largeModel.ModelCfg.Model,
257 Provider: a.largeModel.ModelCfg.Provider,
258 })
259 if err != nil {
260 return callContext, prepared, err
261 }
262 callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
263 callContext = context.WithValue(callContext, tools.SupportsImagesContextKey, a.largeModel.CatwalkCfg.SupportsImages)
264 callContext = context.WithValue(callContext, tools.ModelNameContextKey, a.largeModel.CatwalkCfg.Name)
265 currentAssistant = &assistantMsg
266 return callContext, prepared, err
267 },
268 OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
269 currentAssistant.AppendReasoningContent(reasoning.Text)
270 return a.messages.Update(genCtx, *currentAssistant)
271 },
272 OnReasoningDelta: func(id string, text string) error {
273 currentAssistant.AppendReasoningContent(text)
274 return a.messages.Update(genCtx, *currentAssistant)
275 },
276 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
277 // handle anthropic signature
278 if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
279 if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
280 currentAssistant.AppendReasoningSignature(reasoning.Signature)
281 }
282 }
283 if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
284 if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
285 currentAssistant.AppendThoughtSignature(reasoning.Signature, reasoning.ToolID)
286 }
287 }
288 if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
289 if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
290 currentAssistant.SetReasoningResponsesData(reasoning)
291 }
292 }
293 currentAssistant.FinishThinking()
294 return a.messages.Update(genCtx, *currentAssistant)
295 },
296 OnTextDelta: func(id string, text string) error {
297 // Strip leading newline from initial text content. This is is
298 // particularly important in non-interactive mode where leading
299 // newlines are very visible.
300 if len(currentAssistant.Parts) == 0 {
301 text = strings.TrimPrefix(text, "\n")
302 }
303
304 currentAssistant.AppendContent(text)
305 return a.messages.Update(genCtx, *currentAssistant)
306 },
307 OnToolInputStart: func(id string, toolName string) error {
308 toolCall := message.ToolCall{
309 ID: id,
310 Name: toolName,
311 ProviderExecuted: false,
312 Finished: false,
313 }
314 currentAssistant.AddToolCall(toolCall)
315 return a.messages.Update(genCtx, *currentAssistant)
316 },
317 OnRetry: func(err *fantasy.ProviderError, delay time.Duration) {
318 // TODO: implement
319 },
320 OnToolCall: func(tc fantasy.ToolCallContent) error {
321 toolCall := message.ToolCall{
322 ID: tc.ToolCallID,
323 Name: tc.ToolName,
324 Input: tc.Input,
325 ProviderExecuted: false,
326 Finished: true,
327 }
328 currentAssistant.AddToolCall(toolCall)
329 return a.messages.Update(genCtx, *currentAssistant)
330 },
331 OnToolResult: func(result fantasy.ToolResultContent) error {
332 toolResult := a.convertToToolResult(result)
333 _, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
334 Role: message.Tool,
335 Parts: []message.ContentPart{
336 toolResult,
337 },
338 })
339 return createMsgErr
340 },
341 OnStepFinish: func(stepResult fantasy.StepResult) error {
342 finishReason := message.FinishReasonUnknown
343 switch stepResult.FinishReason {
344 case fantasy.FinishReasonLength:
345 finishReason = message.FinishReasonMaxTokens
346 case fantasy.FinishReasonStop:
347 finishReason = message.FinishReasonEndTurn
348 case fantasy.FinishReasonToolCalls:
349 finishReason = message.FinishReasonToolUse
350 }
351 currentAssistant.AddFinish(finishReason, "", "")
352 sessionLock.Lock()
353 updatedSession, getSessionErr := a.sessions.Get(genCtx, call.SessionID)
354 if getSessionErr != nil {
355 sessionLock.Unlock()
356 return getSessionErr
357 }
358 a.updateSessionUsage(a.largeModel, &updatedSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
359 _, sessionErr := a.sessions.Save(genCtx, updatedSession)
360 sessionLock.Unlock()
361 if sessionErr != nil {
362 return sessionErr
363 }
364 return a.messages.Update(genCtx, *currentAssistant)
365 },
366 StopWhen: []fantasy.StopCondition{
367 func(_ []fantasy.StepResult) bool {
368 cw := int64(a.largeModel.CatwalkCfg.ContextWindow)
369 tokens := currentSession.CompletionTokens + currentSession.PromptTokens
370 remaining := cw - tokens
371 var threshold int64
372 if cw > 200_000 {
373 threshold = 20_000
374 } else {
375 threshold = int64(float64(cw) * 0.2)
376 }
377 if (remaining <= threshold) && !a.disableAutoSummarize {
378 shouldSummarize = true
379 return true
380 }
381 return false
382 },
383 },
384 })
385
386 a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))
387
388 if err != nil {
389 isCancelErr := errors.Is(err, context.Canceled)
390 isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
391 if currentAssistant == nil {
392 return result, err
393 }
394 // Ensure we finish thinking on error to close the reasoning state.
395 currentAssistant.FinishThinking()
396 toolCalls := currentAssistant.ToolCalls()
397 // INFO: we use the parent context here because the genCtx has been cancelled.
398 msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
399 if createErr != nil {
400 return nil, createErr
401 }
402 for _, tc := range toolCalls {
403 if !tc.Finished {
404 tc.Finished = true
405 tc.Input = "{}"
406 currentAssistant.AddToolCall(tc)
407 updateErr := a.messages.Update(ctx, *currentAssistant)
408 if updateErr != nil {
409 return nil, updateErr
410 }
411 }
412
413 found := false
414 for _, msg := range msgs {
415 if msg.Role == message.Tool {
416 for _, tr := range msg.ToolResults() {
417 if tr.ToolCallID == tc.ID {
418 found = true
419 break
420 }
421 }
422 }
423 if found {
424 break
425 }
426 }
427 if found {
428 continue
429 }
430 content := "There was an error while executing the tool"
431 if isCancelErr {
432 content = "Tool execution canceled by user"
433 } else if isPermissionErr {
434 content = "User denied permission"
435 }
436 toolResult := message.ToolResult{
437 ToolCallID: tc.ID,
438 Name: tc.Name,
439 Content: content,
440 IsError: true,
441 }
442 _, createErr = a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
443 Role: message.Tool,
444 Parts: []message.ContentPart{
445 toolResult,
446 },
447 })
448 if createErr != nil {
449 return nil, createErr
450 }
451 }
452 var fantasyErr *fantasy.Error
453 var providerErr *fantasy.ProviderError
454 const defaultTitle = "Provider Error"
455 if isCancelErr {
456 currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
457 } else if isPermissionErr {
458 currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "")
459 } else if errors.Is(err, hyper.ErrNoCredits) {
460 url := hyper.BaseURL()
461 link := lipgloss.NewStyle().Hyperlink(url, "id=hyper").Render(url)
462 currentAssistant.AddFinish(message.FinishReasonError, "No credits", "You're out of credits. Add more at "+link)
463 } else if errors.As(err, &providerErr) {
464 if providerErr.Message == "The requested model is not supported." {
465 url := "https://github.com/settings/copilot/features"
466 link := lipgloss.NewStyle().Hyperlink(url, "id=hyper").Render(url)
467 currentAssistant.AddFinish(
468 message.FinishReasonError,
469 "Copilot model not enabled",
470 fmt.Sprintf("%q is not enabled in Copilot. Go to the following page to enable it. Then, wait a minute before trying again. %s", a.largeModel.CatwalkCfg.Name, link),
471 )
472 } else {
473 currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message)
474 }
475 } else if errors.As(err, &fantasyErr) {
476 currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(fantasyErr.Title), defaultTitle), fantasyErr.Message)
477 } else {
478 currentAssistant.AddFinish(message.FinishReasonError, defaultTitle, err.Error())
479 }
480 // Note: we use the parent context here because the genCtx has been
481 // cancelled.
482 updateErr := a.messages.Update(ctx, *currentAssistant)
483 if updateErr != nil {
484 return nil, updateErr
485 }
486 return nil, err
487 }
488 wg.Wait()
489
490 if shouldSummarize {
491 a.activeRequests.Del(call.SessionID)
492 if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
493 return nil, summarizeErr
494 }
495 // If the agent wasn't done...
496 if len(currentAssistant.ToolCalls()) > 0 {
497 existing, ok := a.messageQueue.Get(call.SessionID)
498 if !ok {
499 existing = []SessionAgentCall{}
500 }
501 call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
502 existing = append(existing, call)
503 a.messageQueue.Set(call.SessionID, existing)
504 }
505 }
506
507 // Release active request before processing queued messages.
508 a.activeRequests.Del(call.SessionID)
509 cancel()
510
511 queuedMessages, ok := a.messageQueue.Get(call.SessionID)
512 if !ok || len(queuedMessages) == 0 {
513 return result, err
514 }
515 // There are queued messages restart the loop.
516 firstQueuedMessage := queuedMessages[0]
517 a.messageQueue.Set(call.SessionID, queuedMessages[1:])
518 return a.Run(ctx, firstQueuedMessage)
519}
520
521func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
522 if a.IsSessionBusy(sessionID) {
523 return ErrSessionBusy
524 }
525
526 currentSession, err := a.sessions.Get(ctx, sessionID)
527 if err != nil {
528 return fmt.Errorf("failed to get session: %w", err)
529 }
530 msgs, err := a.getSessionMessages(ctx, currentSession)
531 if err != nil {
532 return err
533 }
534 if len(msgs) == 0 {
535 // Nothing to summarize.
536 return nil
537 }
538
539 aiMsgs, _ := a.preparePrompt(msgs)
540
541 genCtx, cancel := context.WithCancel(ctx)
542 a.activeRequests.Set(sessionID, cancel)
543 defer a.activeRequests.Del(sessionID)
544 defer cancel()
545
546 agent := fantasy.NewAgent(a.largeModel.Model,
547 fantasy.WithSystemPrompt(string(summaryPrompt)),
548 )
549 summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
550 Role: message.Assistant,
551 Model: a.largeModel.Model.Model(),
552 Provider: a.largeModel.Model.Provider(),
553 IsSummaryMessage: true,
554 })
555 if err != nil {
556 return err
557 }
558
559 summaryPromptText := "Provide a detailed summary of our conversation above."
560 if len(currentSession.Todos) > 0 {
561 summaryPromptText += "\n\n## Current Todo List\n\n"
562 for _, t := range currentSession.Todos {
563 summaryPromptText += fmt.Sprintf("- [%s] %s\n", t.Status, t.Content)
564 }
565 summaryPromptText += "\nInclude these tasks and their statuses in your summary. "
566 summaryPromptText += "Instruct the resuming assistant to use the `todos` tool to continue tracking progress on these tasks."
567 }
568
569 resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
570 Prompt: summaryPromptText,
571 Messages: aiMsgs,
572 ProviderOptions: opts,
573 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
574 prepared.Messages = options.Messages
575 if a.systemPromptPrefix != "" {
576 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
577 }
578 return callContext, prepared, nil
579 },
580 OnReasoningDelta: func(id string, text string) error {
581 summaryMessage.AppendReasoningContent(text)
582 return a.messages.Update(genCtx, summaryMessage)
583 },
584 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
585 // Handle anthropic signature.
586 if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
587 if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
588 summaryMessage.AppendReasoningSignature(signature.Signature)
589 }
590 }
591 summaryMessage.FinishThinking()
592 return a.messages.Update(genCtx, summaryMessage)
593 },
594 OnTextDelta: func(id, text string) error {
595 summaryMessage.AppendContent(text)
596 return a.messages.Update(genCtx, summaryMessage)
597 },
598 })
599 if err != nil {
600 isCancelErr := errors.Is(err, context.Canceled)
601 if isCancelErr {
602 // User cancelled summarize we need to remove the summary message.
603 deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
604 return deleteErr
605 }
606 return err
607 }
608
609 summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
610 err = a.messages.Update(genCtx, summaryMessage)
611 if err != nil {
612 return err
613 }
614
615 var openrouterCost *float64
616 for _, step := range resp.Steps {
617 stepCost := a.openrouterCost(step.ProviderMetadata)
618 if stepCost != nil {
619 newCost := *stepCost
620 if openrouterCost != nil {
621 newCost += *openrouterCost
622 }
623 openrouterCost = &newCost
624 }
625 }
626
627 a.updateSessionUsage(a.largeModel, ¤tSession, resp.TotalUsage, openrouterCost)
628
629 // Just in case, get just the last usage info.
630 usage := resp.Response.Usage
631 currentSession.SummaryMessageID = summaryMessage.ID
632 currentSession.CompletionTokens = usage.OutputTokens
633 currentSession.PromptTokens = 0
634 _, err = a.sessions.Save(genCtx, currentSession)
635 return err
636}
637
638func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
639 if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
640 return fantasy.ProviderOptions{}
641 }
642 return fantasy.ProviderOptions{
643 anthropic.Name: &anthropic.ProviderCacheControlOptions{
644 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
645 },
646 bedrock.Name: &anthropic.ProviderCacheControlOptions{
647 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
648 },
649 }
650}
651
652func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
653 var attachmentParts []message.ContentPart
654 for _, attachment := range call.Attachments {
655 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
656 }
657 parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
658 parts = append(parts, attachmentParts...)
659 msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
660 Role: message.User,
661 Parts: parts,
662 })
663 if err != nil {
664 return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
665 }
666 return msg, nil
667}
668
669func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
670 var history []fantasy.Message
671 if !a.isSubAgent {
672 history = append(history, fantasy.NewUserMessage(
673 fmt.Sprintf("<system_reminder>%s</system_reminder>",
674 `This is a reminder that your todo list is currently empty. DO NOT mention this to the user explicitly because they are already aware.
675If you are working on tasks that would benefit from a todo list please use the "todos" tool to create one.
676If not, please feel free to ignore. Again do not mention this message to the user.`,
677 ),
678 ))
679 }
680 for _, m := range msgs {
681 if len(m.Parts) == 0 {
682 continue
683 }
684 // Assistant message without content or tool calls (cancelled before it
685 // returned anything).
686 if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
687 continue
688 }
689 history = append(history, m.ToAIMessage()...)
690 }
691
692 var files []fantasy.FilePart
693 for _, attachment := range attachments {
694 files = append(files, fantasy.FilePart{
695 Filename: attachment.FileName,
696 Data: attachment.Content,
697 MediaType: attachment.MimeType,
698 })
699 }
700
701 return history, files
702}
703
704func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
705 msgs, err := a.messages.List(ctx, session.ID)
706 if err != nil {
707 return nil, fmt.Errorf("failed to list messages: %w", err)
708 }
709
710 if session.SummaryMessageID != "" {
711 summaryMsgInex := -1
712 for i, msg := range msgs {
713 if msg.ID == session.SummaryMessageID {
714 summaryMsgInex = i
715 break
716 }
717 }
718 if summaryMsgInex != -1 {
719 msgs = msgs[summaryMsgInex:]
720 msgs[0].Role = message.User
721 }
722 }
723 return msgs, nil
724}
725
726func (a *sessionAgent) generateTitle(ctx context.Context, session *session.Session, prompt string) {
727 if prompt == "" {
728 return
729 }
730
731 var maxOutput int64 = 40
732 if a.smallModel.CatwalkCfg.CanReason {
733 maxOutput = a.smallModel.CatwalkCfg.DefaultMaxTokens
734 }
735
736 agent := fantasy.NewAgent(a.smallModel.Model,
737 fantasy.WithSystemPrompt(string(titlePrompt)+"\n /no_think"),
738 fantasy.WithMaxOutputTokens(maxOutput),
739 )
740
741 resp, err := agent.Stream(ctx, fantasy.AgentStreamCall{
742 Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", prompt),
743 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
744 prepared.Messages = options.Messages
745 if a.systemPromptPrefix != "" {
746 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
747 }
748 return callContext, prepared, nil
749 },
750 })
751 if err != nil {
752 slog.Error("error generating title", "err", err)
753 return
754 }
755
756 title := resp.Response.Content.Text()
757
758 title = strings.ReplaceAll(title, "\n", " ")
759
760 // Remove thinking tags if present.
761 if idx := strings.Index(title, "</think>"); idx > 0 {
762 title = title[idx+len("</think>"):]
763 }
764
765 title = strings.TrimSpace(title)
766 if title == "" {
767 slog.Warn("failed to generate title", "warn", "empty title")
768 return
769 }
770
771 session.Title = title
772
773 var openrouterCost *float64
774 for _, step := range resp.Steps {
775 stepCost := a.openrouterCost(step.ProviderMetadata)
776 if stepCost != nil {
777 newCost := *stepCost
778 if openrouterCost != nil {
779 newCost += *openrouterCost
780 }
781 openrouterCost = &newCost
782 }
783 }
784
785 a.updateSessionUsage(a.smallModel, session, resp.TotalUsage, openrouterCost)
786 _, saveErr := a.sessions.Save(ctx, *session)
787 if saveErr != nil {
788 slog.Error("failed to save session title & usage", "error", saveErr)
789 return
790 }
791}
792
793func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
794 openrouterMetadata, ok := metadata[openrouter.Name]
795 if !ok {
796 return nil
797 }
798
799 opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
800 if !ok {
801 return nil
802 }
803 return &opts.Usage.Cost
804}
805
806func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
807 modelConfig := model.CatwalkCfg
808 cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
809 modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
810 modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
811 modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
812
813 if a.isClaudeCode() {
814 cost = 0
815 }
816
817 a.eventTokensUsed(session.ID, model, usage, cost)
818
819 if overrideCost != nil {
820 session.Cost += *overrideCost
821 } else {
822 session.Cost += cost
823 }
824
825 session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
826 session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
827}
828
829func (a *sessionAgent) Cancel(sessionID string) {
830 // Cancel regular requests.
831 if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
832 slog.Info("Request cancellation initiated", "session_id", sessionID)
833 cancel()
834 }
835
836 // Also check for summarize requests.
837 if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
838 slog.Info("Summarize cancellation initiated", "session_id", sessionID)
839 cancel()
840 }
841
842 if a.QueuedPrompts(sessionID) > 0 {
843 slog.Info("Clearing queued prompts", "session_id", sessionID)
844 a.messageQueue.Del(sessionID)
845 }
846}
847
848func (a *sessionAgent) ClearQueue(sessionID string) {
849 if a.QueuedPrompts(sessionID) > 0 {
850 slog.Info("Clearing queued prompts", "session_id", sessionID)
851 a.messageQueue.Del(sessionID)
852 }
853}
854
855func (a *sessionAgent) CancelAll() {
856 if !a.IsBusy() {
857 return
858 }
859 for key := range a.activeRequests.Seq2() {
860 a.Cancel(key) // key is sessionID
861 }
862
863 timeout := time.After(5 * time.Second)
864 for a.IsBusy() {
865 select {
866 case <-timeout:
867 return
868 default:
869 time.Sleep(200 * time.Millisecond)
870 }
871 }
872}
873
874func (a *sessionAgent) IsBusy() bool {
875 var busy bool
876 for cancelFunc := range a.activeRequests.Seq() {
877 if cancelFunc != nil {
878 busy = true
879 break
880 }
881 }
882 return busy
883}
884
885func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
886 _, busy := a.activeRequests.Get(sessionID)
887 return busy
888}
889
890func (a *sessionAgent) QueuedPrompts(sessionID string) int {
891 l, ok := a.messageQueue.Get(sessionID)
892 if !ok {
893 return 0
894 }
895 return len(l)
896}
897
898func (a *sessionAgent) QueuedPromptsList(sessionID string) []string {
899 l, ok := a.messageQueue.Get(sessionID)
900 if !ok {
901 return nil
902 }
903 prompts := make([]string, len(l))
904 for i, call := range l {
905 prompts[i] = call.Prompt
906 }
907 return prompts
908}
909
910func (a *sessionAgent) SetModels(large Model, small Model) {
911 a.largeModel = large
912 a.smallModel = small
913}
914
915func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
916 a.tools = tools
917}
918
919func (a *sessionAgent) Model() Model {
920 return a.largeModel
921}
922
923func (a *sessionAgent) promptPrefix() string {
924 if a.isClaudeCode() {
925 return "You are Claude Code, Anthropic's official CLI for Claude."
926 }
927 return a.systemPromptPrefix
928}
929
930func (a *sessionAgent) isClaudeCode() bool {
931 cfg := config.Get()
932 pc, ok := cfg.Providers.Get(a.largeModel.ModelCfg.Provider)
933 return ok && pc.ID == string(catwalk.InferenceProviderAnthropic) && pc.OAuthToken != nil
934}
935
936// convertToToolResult converts a fantasy tool result to a message tool result.
937func (a *sessionAgent) convertToToolResult(result fantasy.ToolResultContent) message.ToolResult {
938 baseResult := message.ToolResult{
939 ToolCallID: result.ToolCallID,
940 Name: result.ToolName,
941 Metadata: result.ClientMetadata,
942 }
943
944 switch result.Result.GetType() {
945 case fantasy.ToolResultContentTypeText:
946 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result); ok {
947 baseResult.Content = r.Text
948 }
949 case fantasy.ToolResultContentTypeError:
950 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result); ok {
951 baseResult.Content = r.Error.Error()
952 baseResult.IsError = true
953 }
954 case fantasy.ToolResultContentTypeMedia:
955 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](result.Result); ok {
956 content := r.Text
957 if content == "" {
958 content = fmt.Sprintf("Loaded %s content", r.MediaType)
959 }
960 baseResult.Content = content
961 baseResult.Data = r.Data
962 baseResult.MIMEType = r.MediaType
963 }
964 }
965
966 return baseResult
967}
968
969// workaroundProviderMediaLimitations converts media content in tool results to
970// user messages for providers that don't natively support images in tool results.
971//
972// Problem: OpenAI, Google, OpenRouter, and other OpenAI-compatible providers
973// don't support sending images/media in tool result messages - they only accept
974// text in tool results. However, they DO support images in user messages.
975//
976// If we send media in tool results to these providers, the API returns an error.
977//
978// Solution: For these providers, we:
979// 1. Replace the media in the tool result with a text placeholder
980// 2. Inject a user message immediately after with the image as a file attachment
981// 3. This maintains the tool execution flow while working around API limitations
982//
983// Anthropic and Bedrock support images natively in tool results, so we skip
984// this workaround for them.
985//
986// Example transformation:
987//
988// BEFORE: [tool result: image data]
989// AFTER: [tool result: "Image loaded - see attached"], [user: image attachment]
990func (a *sessionAgent) workaroundProviderMediaLimitations(messages []fantasy.Message) []fantasy.Message {
991 providerSupportsMedia := a.largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderAnthropic) ||
992 a.largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderBedrock)
993
994 if providerSupportsMedia {
995 return messages
996 }
997
998 convertedMessages := make([]fantasy.Message, 0, len(messages))
999
1000 for _, msg := range messages {
1001 if msg.Role != fantasy.MessageRoleTool {
1002 convertedMessages = append(convertedMessages, msg)
1003 continue
1004 }
1005
1006 textParts := make([]fantasy.MessagePart, 0, len(msg.Content))
1007 var mediaFiles []fantasy.FilePart
1008
1009 for _, part := range msg.Content {
1010 toolResult, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
1011 if !ok {
1012 textParts = append(textParts, part)
1013 continue
1014 }
1015
1016 if media, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](toolResult.Output); ok {
1017 decoded, err := base64.StdEncoding.DecodeString(media.Data)
1018 if err != nil {
1019 slog.Warn("failed to decode media data", "error", err)
1020 textParts = append(textParts, part)
1021 continue
1022 }
1023
1024 mediaFiles = append(mediaFiles, fantasy.FilePart{
1025 Data: decoded,
1026 MediaType: media.MediaType,
1027 Filename: fmt.Sprintf("tool-result-%s", toolResult.ToolCallID),
1028 })
1029
1030 textParts = append(textParts, fantasy.ToolResultPart{
1031 ToolCallID: toolResult.ToolCallID,
1032 Output: fantasy.ToolResultOutputContentText{
1033 Text: "[Image/media content loaded - see attached file]",
1034 },
1035 ProviderOptions: toolResult.ProviderOptions,
1036 })
1037 } else {
1038 textParts = append(textParts, part)
1039 }
1040 }
1041
1042 convertedMessages = append(convertedMessages, fantasy.Message{
1043 Role: fantasy.MessageRoleTool,
1044 Content: textParts,
1045 })
1046
1047 if len(mediaFiles) > 0 {
1048 convertedMessages = append(convertedMessages, fantasy.NewUserMessage(
1049 "Here is the media content from the tool result:",
1050 mediaFiles...,
1051 ))
1052 }
1053 }
1054
1055 return convertedMessages
1056}