1// Package agent is the core orchestration layer for Crush AI agents.
2//
3// It provides session-based AI agent functionality for managing
4// conversations, tool execution, and message handling. It coordinates
5// interactions between language models, messages, sessions, and tools while
6// handling features like automatic summarization, queuing, and token
7// management.
8package agent
9
10import (
11 "cmp"
12 "context"
13 _ "embed"
14 "encoding/base64"
15 "errors"
16 "fmt"
17 "log/slog"
18 "os"
19 "strconv"
20 "strings"
21 "sync"
22 "time"
23
24 "charm.land/fantasy"
25 "charm.land/fantasy/providers/anthropic"
26 "charm.land/fantasy/providers/bedrock"
27 "charm.land/fantasy/providers/google"
28 "charm.land/fantasy/providers/openai"
29 "charm.land/fantasy/providers/openrouter"
30 "github.com/charmbracelet/catwalk/pkg/catwalk"
31 "github.com/charmbracelet/crush/internal/agent/tools"
32 "github.com/charmbracelet/crush/internal/config"
33 "github.com/charmbracelet/crush/internal/csync"
34 "github.com/charmbracelet/crush/internal/message"
35 "github.com/charmbracelet/crush/internal/permission"
36 "github.com/charmbracelet/crush/internal/session"
37 "github.com/charmbracelet/crush/internal/stringext"
38)
39
// titlePrompt is the system prompt generateTitle sends to the small model
// when naming a new session (contents of templates/title.md).
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt is the system prompt Summarize sends to the large model when
// condensing a session's history (contents of templates/summary.md).
//
//go:embed templates/summary.md
var summaryPrompt []byte
45
// SessionAgentCall is a single prompt submitted to SessionAgent.Run.
type SessionAgentCall struct {
	// SessionID identifies the session the prompt belongs to; Run rejects an
	// empty value with ErrSessionMissing.
	SessionID string
	// Prompt is the user's text; Run rejects an empty value with ErrEmptyPrompt.
	Prompt string
	// ProviderOptions are forwarded to the model request and, when the run
	// triggers auto-summarization, to Summarize.
	ProviderOptions fantasy.ProviderOptions
	// Attachments are stored as binary parts of the user message and sent to
	// the model as file parts.
	Attachments []message.Attachment
	// MaxOutputTokens and the sampling parameters below are forwarded as-is
	// to fantasy.AgentStreamCall.
	MaxOutputTokens  int64
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
}
58
// SessionAgent runs prompts against a language model on behalf of chat
// sessions, queueing prompts for sessions that already have a request in
// flight and exposing controls to cancel, inspect and summarize sessions.
type SessionAgent interface {
	// Run sends a prompt for a session (queueing it if the session is busy).
	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
	// SetModels replaces the large model (conversation and summaries) and the
	// small model (used in this file for title generation).
	SetModels(large Model, small Model)
	// SetTools replaces the tools offered to the model on subsequent runs.
	SetTools(tools []fantasy.AgentTool)
	// Cancel aborts the session's in-flight request(s) and clears its queue.
	Cancel(sessionID string)
	// CancelAll cancels every in-flight request and waits up to 5 seconds
	// for them to stop.
	CancelAll()
	// IsSessionBusy reports whether the session has an in-flight request.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any session has an in-flight request.
	IsBusy() bool
	// QueuedPrompts returns how many prompts are waiting for the session.
	QueuedPrompts(sessionID string) int
	// QueuedPromptsList returns the text of the session's queued prompts.
	QueuedPromptsList(sessionID string) []string
	// ClearQueue drops all prompts queued for the session.
	ClearQueue(sessionID string)
	// Summarize condenses the session's history into a summary message.
	Summarize(context.Context, string, fantasy.ProviderOptions) error
	// Model returns the current large model.
	Model() Model
}
73
// Model bundles a language model client with the catalog and selection data
// the agent reads about it.
type Model struct {
	// Model is the client requests are sent through.
	Model fantasy.LanguageModel
	// CatwalkCfg supplies pricing, context window, image support, display
	// name and reasoning/max-token settings for the model.
	CatwalkCfg catwalk.Model
	// ModelCfg is the user's model selection; its Model and Provider IDs are
	// recorded on assistant messages and used to look up provider config.
	ModelCfg config.SelectedModel
}
79
// sessionAgent is the SessionAgent implementation returned by NewSessionAgent.
type sessionAgent struct {
	// largeModel runs conversations and summaries; smallModel generates titles.
	largeModel Model
	smallModel Model
	// systemPromptPrefix, when non-empty, is sent as an extra leading system
	// message (Run replaces it with a fixed line for Claude Code OAuth).
	systemPromptPrefix string
	systemPrompt       string
	// isSubAgent suppresses the todo-list reminder added by preparePrompt.
	isSubAgent bool
	tools      []fantasy.AgentTool
	sessions   session.Service
	messages   message.Service
	// disableAutoSummarize prevents Run from stopping and summarizing when
	// the context window runs low.
	disableAutoSummarize bool
	// isYolo is copied from the options; it is not read in the code shown here.
	isYolo bool

	// messageQueue holds prompts received while a session was busy, keyed by
	// session ID.
	messageQueue *csync.Map[string, []SessionAgentCall]
	// activeRequests maps the key of each in-flight request (the session ID)
	// to the CancelFunc of its generation context.
	activeRequests *csync.Map[string, context.CancelFunc]
}
95
// SessionAgentOptions configures NewSessionAgent; each field is copied onto
// the corresponding field of the new agent.
type SessionAgentOptions struct {
	// LargeModel runs conversations and summaries.
	LargeModel Model
	// SmallModel generates session titles.
	SmallModel Model
	// SystemPromptPrefix, when non-empty, is sent as an extra leading system message.
	SystemPromptPrefix string
	// SystemPrompt is the main system prompt for conversation runs.
	SystemPrompt string
	// IsSubAgent suppresses the todo-list reminder in the prompt history.
	IsSubAgent bool
	// DisableAutoSummarize turns off summarization when the context window runs low.
	DisableAutoSummarize bool
	IsYolo               bool
	Sessions             session.Service
	Messages             message.Service
	// Tools are offered to the large model during Run.
	Tools []fantasy.AgentTool
}
108
109func NewSessionAgent(
110 opts SessionAgentOptions,
111) SessionAgent {
112 return &sessionAgent{
113 largeModel: opts.LargeModel,
114 smallModel: opts.SmallModel,
115 systemPromptPrefix: opts.SystemPromptPrefix,
116 systemPrompt: opts.SystemPrompt,
117 isSubAgent: opts.IsSubAgent,
118 sessions: opts.Sessions,
119 messages: opts.Messages,
120 disableAutoSummarize: opts.DisableAutoSummarize,
121 tools: opts.Tools,
122 isYolo: opts.IsYolo,
123 messageQueue: csync.NewMap[string, []SessionAgentCall](),
124 activeRequests: csync.NewMap[string, context.CancelFunc](),
125 }
126}
127
128func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
129 if call.Prompt == "" {
130 return nil, ErrEmptyPrompt
131 }
132 if call.SessionID == "" {
133 return nil, ErrSessionMissing
134 }
135
136 // Queue the message if busy
137 if a.IsSessionBusy(call.SessionID) {
138 existing, ok := a.messageQueue.Get(call.SessionID)
139 if !ok {
140 existing = []SessionAgentCall{}
141 }
142 existing = append(existing, call)
143 a.messageQueue.Set(call.SessionID, existing)
144 return nil, nil
145 }
146
147 if len(a.tools) > 0 {
148 // Add Anthropic caching to the last tool.
149 a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
150 }
151
152 agent := fantasy.NewAgent(
153 a.largeModel.Model,
154 fantasy.WithSystemPrompt(a.systemPrompt),
155 fantasy.WithTools(a.tools...),
156 )
157
158 sessionLock := sync.Mutex{}
159 currentSession, err := a.sessions.Get(ctx, call.SessionID)
160 if err != nil {
161 return nil, fmt.Errorf("failed to get session: %w", err)
162 }
163
164 msgs, err := a.getSessionMessages(ctx, currentSession)
165 if err != nil {
166 return nil, fmt.Errorf("failed to get session messages: %w", err)
167 }
168
169 var wg sync.WaitGroup
170 // Generate title if first message.
171 if len(msgs) == 0 {
172 wg.Go(func() {
173 sessionLock.Lock()
174 a.generateTitle(ctx, ¤tSession, call.Prompt)
175 sessionLock.Unlock()
176 })
177 }
178
179 // Add the user message to the session.
180 _, err = a.createUserMessage(ctx, call)
181 if err != nil {
182 return nil, err
183 }
184
185 // Add the session to the context.
186 ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)
187
188 genCtx, cancel := context.WithCancel(ctx)
189 a.activeRequests.Set(call.SessionID, cancel)
190
191 defer cancel()
192 defer a.activeRequests.Del(call.SessionID)
193
194 history, files := a.preparePrompt(msgs, call.Attachments...)
195
196 startTime := time.Now()
197 a.eventPromptSent(call.SessionID)
198
199 var currentAssistant *message.Message
200 var shouldSummarize bool
201 result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
202 Prompt: call.Prompt,
203 Files: files,
204 Messages: history,
205 ProviderOptions: call.ProviderOptions,
206 MaxOutputTokens: &call.MaxOutputTokens,
207 TopP: call.TopP,
208 Temperature: call.Temperature,
209 PresencePenalty: call.PresencePenalty,
210 TopK: call.TopK,
211 FrequencyPenalty: call.FrequencyPenalty,
212 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
213 prepared.Messages = options.Messages
214 for i := range prepared.Messages {
215 prepared.Messages[i].ProviderOptions = nil
216 }
217
218 queuedCalls, _ := a.messageQueue.Get(call.SessionID)
219 a.messageQueue.Del(call.SessionID)
220 for _, queued := range queuedCalls {
221 userMessage, createErr := a.createUserMessage(callContext, queued)
222 if createErr != nil {
223 return callContext, prepared, createErr
224 }
225 prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
226 }
227
228 prepared.Messages = a.workaroundProviderMediaLimitations(prepared.Messages)
229
230 lastSystemRoleInx := 0
231 systemMessageUpdated := false
232 for i, msg := range prepared.Messages {
233 // Only add cache control to the last message.
234 if msg.Role == fantasy.MessageRoleSystem {
235 lastSystemRoleInx = i
236 } else if !systemMessageUpdated {
237 prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
238 systemMessageUpdated = true
239 }
240 // Than add cache control to the last 2 messages.
241 if i > len(prepared.Messages)-3 {
242 prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
243 }
244 }
245
246 if promptPrefix := a.promptPrefix(); promptPrefix != "" {
247 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(promptPrefix)}, prepared.Messages...)
248 }
249
250 var assistantMsg message.Message
251 assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
252 Role: message.Assistant,
253 Parts: []message.ContentPart{},
254 Model: a.largeModel.ModelCfg.Model,
255 Provider: a.largeModel.ModelCfg.Provider,
256 })
257 if err != nil {
258 return callContext, prepared, err
259 }
260 callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
261 callContext = context.WithValue(callContext, tools.SupportsImagesContextKey, a.largeModel.CatwalkCfg.SupportsImages)
262 callContext = context.WithValue(callContext, tools.ModelNameContextKey, a.largeModel.CatwalkCfg.Name)
263 currentAssistant = &assistantMsg
264 return callContext, prepared, err
265 },
266 OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
267 currentAssistant.AppendReasoningContent(reasoning.Text)
268 return a.messages.Update(genCtx, *currentAssistant)
269 },
270 OnReasoningDelta: func(id string, text string) error {
271 currentAssistant.AppendReasoningContent(text)
272 return a.messages.Update(genCtx, *currentAssistant)
273 },
274 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
275 // handle anthropic signature
276 if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
277 if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
278 currentAssistant.AppendReasoningSignature(reasoning.Signature)
279 }
280 }
281 if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
282 if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
283 currentAssistant.AppendThoughtSignature(reasoning.Signature, reasoning.ToolID)
284 }
285 }
286 if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
287 if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
288 currentAssistant.SetReasoningResponsesData(reasoning)
289 }
290 }
291 currentAssistant.FinishThinking()
292 return a.messages.Update(genCtx, *currentAssistant)
293 },
294 OnTextDelta: func(id string, text string) error {
295 // Strip leading newline from initial text content. This is is
296 // particularly important in non-interactive mode where leading
297 // newlines are very visible.
298 if len(currentAssistant.Parts) == 0 {
299 text = strings.TrimPrefix(text, "\n")
300 }
301
302 currentAssistant.AppendContent(text)
303 return a.messages.Update(genCtx, *currentAssistant)
304 },
305 OnToolInputStart: func(id string, toolName string) error {
306 toolCall := message.ToolCall{
307 ID: id,
308 Name: toolName,
309 ProviderExecuted: false,
310 Finished: false,
311 }
312 currentAssistant.AddToolCall(toolCall)
313 return a.messages.Update(genCtx, *currentAssistant)
314 },
315 OnRetry: func(err *fantasy.ProviderError, delay time.Duration) {
316 // TODO: implement
317 },
318 OnToolCall: func(tc fantasy.ToolCallContent) error {
319 toolCall := message.ToolCall{
320 ID: tc.ToolCallID,
321 Name: tc.ToolName,
322 Input: tc.Input,
323 ProviderExecuted: false,
324 Finished: true,
325 }
326 currentAssistant.AddToolCall(toolCall)
327 return a.messages.Update(genCtx, *currentAssistant)
328 },
329 OnToolResult: func(result fantasy.ToolResultContent) error {
330 toolResult := a.convertToToolResult(result)
331 _, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
332 Role: message.Tool,
333 Parts: []message.ContentPart{
334 toolResult,
335 },
336 })
337 return createMsgErr
338 },
339 OnStepFinish: func(stepResult fantasy.StepResult) error {
340 finishReason := message.FinishReasonUnknown
341 switch stepResult.FinishReason {
342 case fantasy.FinishReasonLength:
343 finishReason = message.FinishReasonMaxTokens
344 case fantasy.FinishReasonStop:
345 finishReason = message.FinishReasonEndTurn
346 case fantasy.FinishReasonToolCalls:
347 finishReason = message.FinishReasonToolUse
348 }
349 currentAssistant.AddFinish(finishReason, "", "")
350 sessionLock.Lock()
351 updatedSession, getSessionErr := a.sessions.Get(genCtx, call.SessionID)
352 if getSessionErr != nil {
353 sessionLock.Unlock()
354 return getSessionErr
355 }
356 a.updateSessionUsage(a.largeModel, &updatedSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
357 _, sessionErr := a.sessions.Save(genCtx, updatedSession)
358 sessionLock.Unlock()
359 if sessionErr != nil {
360 return sessionErr
361 }
362 return a.messages.Update(genCtx, *currentAssistant)
363 },
364 StopWhen: []fantasy.StopCondition{
365 func(_ []fantasy.StepResult) bool {
366 cw := int64(a.largeModel.CatwalkCfg.ContextWindow)
367 tokens := currentSession.CompletionTokens + currentSession.PromptTokens
368 remaining := cw - tokens
369 var threshold int64
370 if cw > 200_000 {
371 threshold = 20_000
372 } else {
373 threshold = int64(float64(cw) * 0.2)
374 }
375 if (remaining <= threshold) && !a.disableAutoSummarize {
376 shouldSummarize = true
377 return true
378 }
379 return false
380 },
381 },
382 })
383
384 a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))
385
386 if err != nil {
387 isCancelErr := errors.Is(err, context.Canceled)
388 isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
389 if currentAssistant == nil {
390 return result, err
391 }
392 // Ensure we finish thinking on error to close the reasoning state.
393 currentAssistant.FinishThinking()
394 toolCalls := currentAssistant.ToolCalls()
395 // INFO: we use the parent context here because the genCtx has been cancelled.
396 msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
397 if createErr != nil {
398 return nil, createErr
399 }
400 for _, tc := range toolCalls {
401 if !tc.Finished {
402 tc.Finished = true
403 tc.Input = "{}"
404 currentAssistant.AddToolCall(tc)
405 updateErr := a.messages.Update(ctx, *currentAssistant)
406 if updateErr != nil {
407 return nil, updateErr
408 }
409 }
410
411 found := false
412 for _, msg := range msgs {
413 if msg.Role == message.Tool {
414 for _, tr := range msg.ToolResults() {
415 if tr.ToolCallID == tc.ID {
416 found = true
417 break
418 }
419 }
420 }
421 if found {
422 break
423 }
424 }
425 if found {
426 continue
427 }
428 content := "There was an error while executing the tool"
429 if isCancelErr {
430 content = "Tool execution canceled by user"
431 } else if isPermissionErr {
432 content = "User denied permission"
433 }
434 toolResult := message.ToolResult{
435 ToolCallID: tc.ID,
436 Name: tc.Name,
437 Content: content,
438 IsError: true,
439 }
440 _, createErr = a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
441 Role: message.Tool,
442 Parts: []message.ContentPart{
443 toolResult,
444 },
445 })
446 if createErr != nil {
447 return nil, createErr
448 }
449 }
450 var fantasyErr *fantasy.Error
451 var providerErr *fantasy.ProviderError
452 const defaultTitle = "Provider Error"
453 if isCancelErr {
454 currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
455 } else if isPermissionErr {
456 currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "")
457 } else if errors.As(err, &providerErr) {
458 currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message)
459 } else if errors.As(err, &fantasyErr) {
460 currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(fantasyErr.Title), defaultTitle), fantasyErr.Message)
461 } else {
462 currentAssistant.AddFinish(message.FinishReasonError, defaultTitle, err.Error())
463 }
464 // Note: we use the parent context here because the genCtx has been
465 // cancelled.
466 updateErr := a.messages.Update(ctx, *currentAssistant)
467 if updateErr != nil {
468 return nil, updateErr
469 }
470 return nil, err
471 }
472 wg.Wait()
473
474 if shouldSummarize {
475 a.activeRequests.Del(call.SessionID)
476 if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
477 return nil, summarizeErr
478 }
479 // If the agent wasn't done...
480 if len(currentAssistant.ToolCalls()) > 0 {
481 existing, ok := a.messageQueue.Get(call.SessionID)
482 if !ok {
483 existing = []SessionAgentCall{}
484 }
485 call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
486 existing = append(existing, call)
487 a.messageQueue.Set(call.SessionID, existing)
488 }
489 }
490
491 // Release active request before processing queued messages.
492 a.activeRequests.Del(call.SessionID)
493 cancel()
494
495 queuedMessages, ok := a.messageQueue.Get(call.SessionID)
496 if !ok || len(queuedMessages) == 0 {
497 return result, err
498 }
499 // There are queued messages restart the loop.
500 firstQueuedMessage := queuedMessages[0]
501 a.messageQueue.Set(call.SessionID, queuedMessages[1:])
502 return a.Run(ctx, firstQueuedMessage)
503}
504
505func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
506 if a.IsSessionBusy(sessionID) {
507 return ErrSessionBusy
508 }
509
510 currentSession, err := a.sessions.Get(ctx, sessionID)
511 if err != nil {
512 return fmt.Errorf("failed to get session: %w", err)
513 }
514 msgs, err := a.getSessionMessages(ctx, currentSession)
515 if err != nil {
516 return err
517 }
518 if len(msgs) == 0 {
519 // Nothing to summarize.
520 return nil
521 }
522
523 aiMsgs, _ := a.preparePrompt(msgs)
524
525 genCtx, cancel := context.WithCancel(ctx)
526 a.activeRequests.Set(sessionID, cancel)
527 defer a.activeRequests.Del(sessionID)
528 defer cancel()
529
530 agent := fantasy.NewAgent(a.largeModel.Model,
531 fantasy.WithSystemPrompt(string(summaryPrompt)),
532 )
533 summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
534 Role: message.Assistant,
535 Model: a.largeModel.Model.Model(),
536 Provider: a.largeModel.Model.Provider(),
537 IsSummaryMessage: true,
538 })
539 if err != nil {
540 return err
541 }
542
543 summaryPromptText := "Provide a detailed summary of our conversation above."
544 if len(currentSession.Todos) > 0 {
545 summaryPromptText += "\n\n## Current Todo List\n\n"
546 for _, t := range currentSession.Todos {
547 summaryPromptText += fmt.Sprintf("- [%s] %s\n", t.Status, t.Content)
548 }
549 summaryPromptText += "\nInclude these tasks and their statuses in your summary. "
550 summaryPromptText += "Instruct the resuming assistant to use the `todos` tool to continue tracking progress on these tasks."
551 }
552
553 resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
554 Prompt: summaryPromptText,
555 Messages: aiMsgs,
556 ProviderOptions: opts,
557 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
558 prepared.Messages = options.Messages
559 if a.systemPromptPrefix != "" {
560 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
561 }
562 return callContext, prepared, nil
563 },
564 OnReasoningDelta: func(id string, text string) error {
565 summaryMessage.AppendReasoningContent(text)
566 return a.messages.Update(genCtx, summaryMessage)
567 },
568 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
569 // Handle anthropic signature.
570 if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
571 if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
572 summaryMessage.AppendReasoningSignature(signature.Signature)
573 }
574 }
575 summaryMessage.FinishThinking()
576 return a.messages.Update(genCtx, summaryMessage)
577 },
578 OnTextDelta: func(id, text string) error {
579 summaryMessage.AppendContent(text)
580 return a.messages.Update(genCtx, summaryMessage)
581 },
582 })
583 if err != nil {
584 isCancelErr := errors.Is(err, context.Canceled)
585 if isCancelErr {
586 // User cancelled summarize we need to remove the summary message.
587 deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
588 return deleteErr
589 }
590 return err
591 }
592
593 summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
594 err = a.messages.Update(genCtx, summaryMessage)
595 if err != nil {
596 return err
597 }
598
599 var openrouterCost *float64
600 for _, step := range resp.Steps {
601 stepCost := a.openrouterCost(step.ProviderMetadata)
602 if stepCost != nil {
603 newCost := *stepCost
604 if openrouterCost != nil {
605 newCost += *openrouterCost
606 }
607 openrouterCost = &newCost
608 }
609 }
610
611 a.updateSessionUsage(a.largeModel, ¤tSession, resp.TotalUsage, openrouterCost)
612
613 // Just in case, get just the last usage info.
614 usage := resp.Response.Usage
615 currentSession.SummaryMessageID = summaryMessage.ID
616 currentSession.CompletionTokens = usage.OutputTokens
617 currentSession.PromptTokens = 0
618 _, err = a.sessions.Save(genCtx, currentSession)
619 return err
620}
621
622func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
623 if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
624 return fantasy.ProviderOptions{}
625 }
626 return fantasy.ProviderOptions{
627 anthropic.Name: &anthropic.ProviderCacheControlOptions{
628 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
629 },
630 bedrock.Name: &anthropic.ProviderCacheControlOptions{
631 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
632 },
633 }
634}
635
636func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
637 var attachmentParts []message.ContentPart
638 for _, attachment := range call.Attachments {
639 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
640 }
641 parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
642 parts = append(parts, attachmentParts...)
643 msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
644 Role: message.User,
645 Parts: parts,
646 })
647 if err != nil {
648 return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
649 }
650 return msg, nil
651}
652
653func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
654 var history []fantasy.Message
655 if !a.isSubAgent {
656 history = append(history, fantasy.NewUserMessage(
657 fmt.Sprintf("<system_reminder>%s</system_reminder>",
658 `This is a reminder that your todo list is currently empty. DO NOT mention this to the user explicitly because they are already aware.
659If you are working on tasks that would benefit from a todo list please use the "todos" tool to create one.
660If not, please feel free to ignore. Again do not mention this message to the user.`,
661 ),
662 ))
663 }
664 for _, m := range msgs {
665 if len(m.Parts) == 0 {
666 continue
667 }
668 // Assistant message without content or tool calls (cancelled before it
669 // returned anything).
670 if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
671 continue
672 }
673 history = append(history, m.ToAIMessage()...)
674 }
675
676 var files []fantasy.FilePart
677 for _, attachment := range attachments {
678 files = append(files, fantasy.FilePart{
679 Filename: attachment.FileName,
680 Data: attachment.Content,
681 MediaType: attachment.MimeType,
682 })
683 }
684
685 return history, files
686}
687
688func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
689 msgs, err := a.messages.List(ctx, session.ID)
690 if err != nil {
691 return nil, fmt.Errorf("failed to list messages: %w", err)
692 }
693
694 if session.SummaryMessageID != "" {
695 summaryMsgInex := -1
696 for i, msg := range msgs {
697 if msg.ID == session.SummaryMessageID {
698 summaryMsgInex = i
699 break
700 }
701 }
702 if summaryMsgInex != -1 {
703 msgs = msgs[summaryMsgInex:]
704 msgs[0].Role = message.User
705 }
706 }
707 return msgs, nil
708}
709
710func (a *sessionAgent) generateTitle(ctx context.Context, session *session.Session, prompt string) {
711 if prompt == "" {
712 return
713 }
714
715 var maxOutput int64 = 40
716 if a.smallModel.CatwalkCfg.CanReason {
717 maxOutput = a.smallModel.CatwalkCfg.DefaultMaxTokens
718 }
719
720 agent := fantasy.NewAgent(a.smallModel.Model,
721 fantasy.WithSystemPrompt(string(titlePrompt)+"\n /no_think"),
722 fantasy.WithMaxOutputTokens(maxOutput),
723 )
724
725 resp, err := agent.Stream(ctx, fantasy.AgentStreamCall{
726 Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", prompt),
727 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
728 prepared.Messages = options.Messages
729 if a.systemPromptPrefix != "" {
730 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
731 }
732 return callContext, prepared, nil
733 },
734 })
735 if err != nil {
736 slog.Error("error generating title", "err", err)
737 return
738 }
739
740 title := resp.Response.Content.Text()
741
742 title = strings.ReplaceAll(title, "\n", " ")
743
744 // Remove thinking tags if present.
745 if idx := strings.Index(title, "</think>"); idx > 0 {
746 title = title[idx+len("</think>"):]
747 }
748
749 title = strings.TrimSpace(title)
750 if title == "" {
751 slog.Warn("failed to generate title", "warn", "empty title")
752 return
753 }
754
755 session.Title = title
756
757 var openrouterCost *float64
758 for _, step := range resp.Steps {
759 stepCost := a.openrouterCost(step.ProviderMetadata)
760 if stepCost != nil {
761 newCost := *stepCost
762 if openrouterCost != nil {
763 newCost += *openrouterCost
764 }
765 openrouterCost = &newCost
766 }
767 }
768
769 a.updateSessionUsage(a.smallModel, session, resp.TotalUsage, openrouterCost)
770 _, saveErr := a.sessions.Save(ctx, *session)
771 if saveErr != nil {
772 slog.Error("failed to save session title & usage", "error", saveErr)
773 return
774 }
775}
776
777func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
778 openrouterMetadata, ok := metadata[openrouter.Name]
779 if !ok {
780 return nil
781 }
782
783 opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
784 if !ok {
785 return nil
786 }
787 return &opts.Usage.Cost
788}
789
790func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
791 modelConfig := model.CatwalkCfg
792 cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
793 modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
794 modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
795 modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
796
797 if a.isClaudeCode() {
798 cost = 0
799 }
800
801 a.eventTokensUsed(session.ID, model, usage, cost)
802
803 if overrideCost != nil {
804 session.Cost += *overrideCost
805 } else {
806 session.Cost += cost
807 }
808
809 session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
810 session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
811}
812
813func (a *sessionAgent) Cancel(sessionID string) {
814 // Cancel regular requests.
815 if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
816 slog.Info("Request cancellation initiated", "session_id", sessionID)
817 cancel()
818 }
819
820 // Also check for summarize requests.
821 if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
822 slog.Info("Summarize cancellation initiated", "session_id", sessionID)
823 cancel()
824 }
825
826 if a.QueuedPrompts(sessionID) > 0 {
827 slog.Info("Clearing queued prompts", "session_id", sessionID)
828 a.messageQueue.Del(sessionID)
829 }
830}
831
832func (a *sessionAgent) ClearQueue(sessionID string) {
833 if a.QueuedPrompts(sessionID) > 0 {
834 slog.Info("Clearing queued prompts", "session_id", sessionID)
835 a.messageQueue.Del(sessionID)
836 }
837}
838
839func (a *sessionAgent) CancelAll() {
840 if !a.IsBusy() {
841 return
842 }
843 for key := range a.activeRequests.Seq2() {
844 a.Cancel(key) // key is sessionID
845 }
846
847 timeout := time.After(5 * time.Second)
848 for a.IsBusy() {
849 select {
850 case <-timeout:
851 return
852 default:
853 time.Sleep(200 * time.Millisecond)
854 }
855 }
856}
857
858func (a *sessionAgent) IsBusy() bool {
859 var busy bool
860 for cancelFunc := range a.activeRequests.Seq() {
861 if cancelFunc != nil {
862 busy = true
863 break
864 }
865 }
866 return busy
867}
868
869func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
870 _, busy := a.activeRequests.Get(sessionID)
871 return busy
872}
873
874func (a *sessionAgent) QueuedPrompts(sessionID string) int {
875 l, ok := a.messageQueue.Get(sessionID)
876 if !ok {
877 return 0
878 }
879 return len(l)
880}
881
882func (a *sessionAgent) QueuedPromptsList(sessionID string) []string {
883 l, ok := a.messageQueue.Get(sessionID)
884 if !ok {
885 return nil
886 }
887 prompts := make([]string, len(l))
888 for i, call := range l {
889 prompts[i] = call.Prompt
890 }
891 return prompts
892}
893
894func (a *sessionAgent) SetModels(large Model, small Model) {
895 a.largeModel = large
896 a.smallModel = small
897}
898
899func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
900 a.tools = tools
901}
902
903func (a *sessionAgent) Model() Model {
904 return a.largeModel
905}
906
907func (a *sessionAgent) promptPrefix() string {
908 if a.isClaudeCode() {
909 return "You are Claude Code, Anthropic's official CLI for Claude."
910 }
911 return a.systemPromptPrefix
912}
913
914func (a *sessionAgent) isClaudeCode() bool {
915 cfg := config.Get()
916 pc, ok := cfg.Providers.Get(a.largeModel.ModelCfg.Provider)
917 return ok && pc.ID == string(catwalk.InferenceProviderAnthropic) && pc.OAuthToken != nil
918}
919
920// convertToToolResult converts a fantasy tool result to a message tool result.
921func (a *sessionAgent) convertToToolResult(result fantasy.ToolResultContent) message.ToolResult {
922 baseResult := message.ToolResult{
923 ToolCallID: result.ToolCallID,
924 Name: result.ToolName,
925 Metadata: result.ClientMetadata,
926 }
927
928 switch result.Result.GetType() {
929 case fantasy.ToolResultContentTypeText:
930 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result); ok {
931 baseResult.Content = r.Text
932 }
933 case fantasy.ToolResultContentTypeError:
934 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result); ok {
935 baseResult.Content = r.Error.Error()
936 baseResult.IsError = true
937 }
938 case fantasy.ToolResultContentTypeMedia:
939 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](result.Result); ok {
940 content := r.Text
941 if content == "" {
942 content = fmt.Sprintf("Loaded %s content", r.MediaType)
943 }
944 baseResult.Content = content
945 baseResult.Data = r.Data
946 baseResult.MIMEType = r.MediaType
947 }
948 }
949
950 return baseResult
951}
952
953// workaroundProviderMediaLimitations converts media content in tool results to
954// user messages for providers that don't natively support images in tool results.
955//
956// Problem: OpenAI, Google, OpenRouter, and other OpenAI-compatible providers
957// don't support sending images/media in tool result messages - they only accept
958// text in tool results. However, they DO support images in user messages.
959//
960// If we send media in tool results to these providers, the API returns an error.
961//
962// Solution: For these providers, we:
963// 1. Replace the media in the tool result with a text placeholder
964// 2. Inject a user message immediately after with the image as a file attachment
965// 3. This maintains the tool execution flow while working around API limitations
966//
967// Anthropic and Bedrock support images natively in tool results, so we skip
968// this workaround for them.
969//
970// Example transformation:
971//
972// BEFORE: [tool result: image data]
973// AFTER: [tool result: "Image loaded - see attached"], [user: image attachment]
974func (a *sessionAgent) workaroundProviderMediaLimitations(messages []fantasy.Message) []fantasy.Message {
975 providerSupportsMedia := a.largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderAnthropic) ||
976 a.largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderBedrock)
977
978 if providerSupportsMedia {
979 return messages
980 }
981
982 convertedMessages := make([]fantasy.Message, 0, len(messages))
983
984 for _, msg := range messages {
985 if msg.Role != fantasy.MessageRoleTool {
986 convertedMessages = append(convertedMessages, msg)
987 continue
988 }
989
990 textParts := make([]fantasy.MessagePart, 0, len(msg.Content))
991 var mediaFiles []fantasy.FilePart
992
993 for _, part := range msg.Content {
994 toolResult, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
995 if !ok {
996 textParts = append(textParts, part)
997 continue
998 }
999
1000 if media, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](toolResult.Output); ok {
1001 decoded, err := base64.StdEncoding.DecodeString(media.Data)
1002 if err != nil {
1003 slog.Warn("failed to decode media data", "error", err)
1004 textParts = append(textParts, part)
1005 continue
1006 }
1007
1008 mediaFiles = append(mediaFiles, fantasy.FilePart{
1009 Data: decoded,
1010 MediaType: media.MediaType,
1011 Filename: fmt.Sprintf("tool-result-%s", toolResult.ToolCallID),
1012 })
1013
1014 textParts = append(textParts, fantasy.ToolResultPart{
1015 ToolCallID: toolResult.ToolCallID,
1016 Output: fantasy.ToolResultOutputContentText{
1017 Text: "[Image/media content loaded - see attached file]",
1018 },
1019 ProviderOptions: toolResult.ProviderOptions,
1020 })
1021 } else {
1022 textParts = append(textParts, part)
1023 }
1024 }
1025
1026 convertedMessages = append(convertedMessages, fantasy.Message{
1027 Role: fantasy.MessageRoleTool,
1028 Content: textParts,
1029 })
1030
1031 if len(mediaFiles) > 0 {
1032 convertedMessages = append(convertedMessages, fantasy.NewUserMessage(
1033 "Here is the media content from the tool result:",
1034 mediaFiles...,
1035 ))
1036 }
1037 }
1038
1039 return convertedMessages
1040}