1// Package agent is the core orchestration layer for Crush AI agents.
2//
3// It provides session-based AI agent functionality for managing
4// conversations, tool execution, and message handling. It coordinates
5// interactions between language models, messages, sessions, and tools while
6// handling features like automatic summarization, queuing, and token
7// management.
8package agent
9
10import (
11 "cmp"
12 "context"
13 _ "embed"
14 "encoding/base64"
15 "errors"
16 "fmt"
17 "log/slog"
18 "os"
19 "strconv"
20 "strings"
21 "sync"
22 "time"
23
24 "charm.land/fantasy"
25 "charm.land/fantasy/providers/anthropic"
26 "charm.land/fantasy/providers/bedrock"
27 "charm.land/fantasy/providers/google"
28 "charm.land/fantasy/providers/openai"
29 "charm.land/fantasy/providers/openrouter"
30 "charm.land/lipgloss/v2"
31 "github.com/charmbracelet/catwalk/pkg/catwalk"
32 "github.com/charmbracelet/crush/internal/agent/hyper"
33 "github.com/charmbracelet/crush/internal/agent/tools"
34 "github.com/charmbracelet/crush/internal/config"
35 "github.com/charmbracelet/crush/internal/csync"
36 "github.com/charmbracelet/crush/internal/message"
37 "github.com/charmbracelet/crush/internal/permission"
38 "github.com/charmbracelet/crush/internal/session"
39 "github.com/charmbracelet/crush/internal/stringext"
40)
41
// titlePrompt holds the embedded contents of templates/title.md. It is used
// as the system prompt when generating a session title.
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt holds the embedded contents of templates/summary.md. It is
// used as the system prompt when summarizing a session.
//
//go:embed templates/summary.md
var summaryPrompt []byte
47
// SessionAgentCall describes a single prompt submitted to a session agent,
// together with the per-call generation settings that are forwarded to the
// model.
type SessionAgentCall struct {
	// SessionID identifies the session the prompt belongs to. Run rejects
	// calls where it is empty.
	SessionID string
	// Prompt is the user's text. Run rejects calls where it is empty.
	Prompt string
	// ProviderOptions is forwarded to the underlying stream call.
	ProviderOptions fantasy.ProviderOptions
	// Attachments are stored as binary parts of the user message and sent to
	// the model as files.
	Attachments []message.Attachment
	// MaxOutputTokens is forwarded to the stream call as the output limit.
	MaxOutputTokens int64
	// The remaining fields are optional sampling parameters that are passed
	// through to the stream call as-is.
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
}
60
// SessionAgent is the interface exposed by session-based agents. The
// descriptions below reflect the implementation returned by NewSessionAgent.
type SessionAgent interface {
	// Run submits a prompt for a session. If the session already has an
	// active request the call is queued and (nil, nil) is returned.
	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
	// SetModels replaces the large and small models.
	SetModels(large Model, small Model)
	// SetTools replaces the tools offered to the model.
	SetTools(tools []fantasy.AgentTool)
	// Cancel cancels the active request for the session and drops its queued
	// prompts.
	Cancel(sessionID string)
	// CancelAll cancels every active request and waits a bounded amount of
	// time for them to wind down.
	CancelAll()
	// IsSessionBusy reports whether the session has an active request.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any session has an active request.
	IsBusy() bool
	// QueuedPrompts returns the number of prompts queued for the session.
	QueuedPrompts(sessionID string) int
	// QueuedPromptsList returns the text of the prompts queued for the
	// session, in queue order.
	QueuedPromptsList(sessionID string) []string
	// ClearQueue drops the prompts queued for the session.
	ClearQueue(sessionID string)
	// Summarize replaces the session's history with a model-generated
	// summary message.
	Summarize(context.Context, string, fantasy.ProviderOptions) error
	// Model returns the large model.
	Model() Model
}
75
// Model bundles a language model with the catalog entry and the selected
// configuration that describe it.
type Model struct {
	// Model is the language model handed to fantasy.NewAgent.
	Model fantasy.LanguageModel
	// CatwalkCfg is the catalog entry for the model. This package reads its
	// context window, pricing, name, and capability flags.
	CatwalkCfg catwalk.Model
	// ModelCfg is the selected-model configuration. This package reads its
	// Model and Provider identifiers.
	ModelCfg config.SelectedModel
}
81
// sessionAgent is the SessionAgent implementation returned by
// NewSessionAgent.
type sessionAgent struct {
	// largeModel handles prompts and summaries; smallModel generates titles.
	largeModel Model
	smallModel Model
	// systemPromptPrefix is prepended as an extra system message on each
	// step (see promptPrefix for the override case).
	systemPromptPrefix string
	// systemPrompt is the system prompt given to the main agent.
	systemPrompt string
	// isSubAgent suppresses the todo-list reminder in preparePrompt.
	isSubAgent bool
	// tools are the tools offered to the model in Run.
	tools    []fantasy.AgentTool
	sessions session.Service
	messages message.Service
	// disableAutoSummarize turns off the context-window stop condition in
	// Run.
	disableAutoSummarize bool
	isYolo               bool

	// messageQueue holds prompts received while a session was busy, keyed by
	// session ID.
	messageQueue *csync.Map[string, []SessionAgentCall]
	// activeRequests maps a session ID to the cancel function of its running
	// request; presence of a key is what marks a session as busy.
	activeRequests *csync.Map[string, context.CancelFunc]
}
97
// SessionAgentOptions carries the dependencies and settings consumed by
// NewSessionAgent. Each field is copied onto the corresponding field of the
// created agent.
type SessionAgentOptions struct {
	LargeModel           Model
	SmallModel           Model
	SystemPromptPrefix   string
	SystemPrompt         string
	IsSubAgent           bool
	DisableAutoSummarize bool
	IsYolo               bool
	Sessions             session.Service
	Messages             message.Service
	Tools                []fantasy.AgentTool
}
110
111func NewSessionAgent(
112 opts SessionAgentOptions,
113) SessionAgent {
114 return &sessionAgent{
115 largeModel: opts.LargeModel,
116 smallModel: opts.SmallModel,
117 systemPromptPrefix: opts.SystemPromptPrefix,
118 systemPrompt: opts.SystemPrompt,
119 isSubAgent: opts.IsSubAgent,
120 sessions: opts.Sessions,
121 messages: opts.Messages,
122 disableAutoSummarize: opts.DisableAutoSummarize,
123 tools: opts.Tools,
124 isYolo: opts.IsYolo,
125 messageQueue: csync.NewMap[string, []SessionAgentCall](),
126 activeRequests: csync.NewMap[string, context.CancelFunc](),
127 }
128}
129
130func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
131 if call.Prompt == "" {
132 return nil, ErrEmptyPrompt
133 }
134 if call.SessionID == "" {
135 return nil, ErrSessionMissing
136 }
137
138 // Queue the message if busy
139 if a.IsSessionBusy(call.SessionID) {
140 existing, ok := a.messageQueue.Get(call.SessionID)
141 if !ok {
142 existing = []SessionAgentCall{}
143 }
144 existing = append(existing, call)
145 a.messageQueue.Set(call.SessionID, existing)
146 return nil, nil
147 }
148
149 if len(a.tools) > 0 {
150 // Add Anthropic caching to the last tool.
151 a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
152 }
153
154 agent := fantasy.NewAgent(
155 a.largeModel.Model,
156 fantasy.WithSystemPrompt(a.systemPrompt),
157 fantasy.WithTools(a.tools...),
158 )
159
160 sessionLock := sync.Mutex{}
161 currentSession, err := a.sessions.Get(ctx, call.SessionID)
162 if err != nil {
163 return nil, fmt.Errorf("failed to get session: %w", err)
164 }
165
166 msgs, err := a.getSessionMessages(ctx, currentSession)
167 if err != nil {
168 return nil, fmt.Errorf("failed to get session messages: %w", err)
169 }
170
171 var wg sync.WaitGroup
172 // Generate title if first message.
173 if len(msgs) == 0 {
174 wg.Go(func() {
175 sessionLock.Lock()
176 a.generateTitle(ctx, ¤tSession, call.Prompt)
177 sessionLock.Unlock()
178 })
179 }
180
181 // Add the user message to the session.
182 _, err = a.createUserMessage(ctx, call)
183 if err != nil {
184 return nil, err
185 }
186
187 // Add the session to the context.
188 ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)
189
190 genCtx, cancel := context.WithCancel(ctx)
191 a.activeRequests.Set(call.SessionID, cancel)
192
193 defer cancel()
194 defer a.activeRequests.Del(call.SessionID)
195
196 history, files := a.preparePrompt(msgs, call.Attachments...)
197
198 startTime := time.Now()
199 a.eventPromptSent(call.SessionID)
200
201 var currentAssistant *message.Message
202 var shouldSummarize bool
203 result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
204 Prompt: call.Prompt,
205 Files: files,
206 Messages: history,
207 ProviderOptions: call.ProviderOptions,
208 MaxOutputTokens: &call.MaxOutputTokens,
209 TopP: call.TopP,
210 Temperature: call.Temperature,
211 PresencePenalty: call.PresencePenalty,
212 TopK: call.TopK,
213 FrequencyPenalty: call.FrequencyPenalty,
214 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
215 prepared.Messages = options.Messages
216 for i := range prepared.Messages {
217 prepared.Messages[i].ProviderOptions = nil
218 }
219
220 queuedCalls, _ := a.messageQueue.Get(call.SessionID)
221 a.messageQueue.Del(call.SessionID)
222 for _, queued := range queuedCalls {
223 userMessage, createErr := a.createUserMessage(callContext, queued)
224 if createErr != nil {
225 return callContext, prepared, createErr
226 }
227 prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
228 }
229
230 prepared.Messages = a.workaroundProviderMediaLimitations(prepared.Messages)
231
232 lastSystemRoleInx := 0
233 systemMessageUpdated := false
234 for i, msg := range prepared.Messages {
235 // Only add cache control to the last message.
236 if msg.Role == fantasy.MessageRoleSystem {
237 lastSystemRoleInx = i
238 } else if !systemMessageUpdated {
239 prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
240 systemMessageUpdated = true
241 }
242 // Than add cache control to the last 2 messages.
243 if i > len(prepared.Messages)-3 {
244 prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
245 }
246 }
247
248 if promptPrefix := a.promptPrefix(); promptPrefix != "" {
249 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(promptPrefix)}, prepared.Messages...)
250 }
251
252 var assistantMsg message.Message
253 assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
254 Role: message.Assistant,
255 Parts: []message.ContentPart{},
256 Model: a.largeModel.ModelCfg.Model,
257 Provider: a.largeModel.ModelCfg.Provider,
258 })
259 if err != nil {
260 return callContext, prepared, err
261 }
262 callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
263 callContext = context.WithValue(callContext, tools.SupportsImagesContextKey, a.largeModel.CatwalkCfg.SupportsImages)
264 callContext = context.WithValue(callContext, tools.ModelNameContextKey, a.largeModel.CatwalkCfg.Name)
265 currentAssistant = &assistantMsg
266 return callContext, prepared, err
267 },
268 OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
269 currentAssistant.AppendReasoningContent(reasoning.Text)
270 return a.messages.Update(genCtx, *currentAssistant)
271 },
272 OnReasoningDelta: func(id string, text string) error {
273 currentAssistant.AppendReasoningContent(text)
274 return a.messages.Update(genCtx, *currentAssistant)
275 },
276 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
277 // handle anthropic signature
278 if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
279 if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
280 currentAssistant.AppendReasoningSignature(reasoning.Signature)
281 }
282 }
283 if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
284 if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
285 currentAssistant.AppendThoughtSignature(reasoning.Signature, reasoning.ToolID)
286 }
287 }
288 if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
289 if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
290 currentAssistant.SetReasoningResponsesData(reasoning)
291 }
292 }
293 currentAssistant.FinishThinking()
294 return a.messages.Update(genCtx, *currentAssistant)
295 },
296 OnTextDelta: func(id string, text string) error {
297 // Strip leading newline from initial text content. This is is
298 // particularly important in non-interactive mode where leading
299 // newlines are very visible.
300 if len(currentAssistant.Parts) == 0 {
301 text = strings.TrimPrefix(text, "\n")
302 }
303
304 currentAssistant.AppendContent(text)
305 return a.messages.Update(genCtx, *currentAssistant)
306 },
307 OnToolInputStart: func(id string, toolName string) error {
308 toolCall := message.ToolCall{
309 ID: id,
310 Name: toolName,
311 ProviderExecuted: false,
312 Finished: false,
313 }
314 currentAssistant.AddToolCall(toolCall)
315 return a.messages.Update(genCtx, *currentAssistant)
316 },
317 OnRetry: func(err *fantasy.ProviderError, delay time.Duration) {
318 // TODO: implement
319 },
320 OnToolCall: func(tc fantasy.ToolCallContent) error {
321 toolCall := message.ToolCall{
322 ID: tc.ToolCallID,
323 Name: tc.ToolName,
324 Input: tc.Input,
325 ProviderExecuted: false,
326 Finished: true,
327 }
328 currentAssistant.AddToolCall(toolCall)
329 return a.messages.Update(genCtx, *currentAssistant)
330 },
331 OnToolResult: func(result fantasy.ToolResultContent) error {
332 toolResult := a.convertToToolResult(result)
333 _, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
334 Role: message.Tool,
335 Parts: []message.ContentPart{
336 toolResult,
337 },
338 })
339 return createMsgErr
340 },
341 OnStepFinish: func(stepResult fantasy.StepResult) error {
342 finishReason := message.FinishReasonUnknown
343 switch stepResult.FinishReason {
344 case fantasy.FinishReasonLength:
345 finishReason = message.FinishReasonMaxTokens
346 case fantasy.FinishReasonStop:
347 finishReason = message.FinishReasonEndTurn
348 case fantasy.FinishReasonToolCalls:
349 finishReason = message.FinishReasonToolUse
350 }
351 currentAssistant.AddFinish(finishReason, "", "")
352 sessionLock.Lock()
353 updatedSession, getSessionErr := a.sessions.Get(genCtx, call.SessionID)
354 if getSessionErr != nil {
355 sessionLock.Unlock()
356 return getSessionErr
357 }
358 a.updateSessionUsage(a.largeModel, &updatedSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
359 _, sessionErr := a.sessions.Save(genCtx, updatedSession)
360 sessionLock.Unlock()
361 if sessionErr != nil {
362 return sessionErr
363 }
364 return a.messages.Update(genCtx, *currentAssistant)
365 },
366 StopWhen: []fantasy.StopCondition{
367 func(_ []fantasy.StepResult) bool {
368 cw := int64(a.largeModel.CatwalkCfg.ContextWindow)
369 tokens := currentSession.CompletionTokens + currentSession.PromptTokens
370 remaining := cw - tokens
371 var threshold int64
372 if cw > 200_000 {
373 threshold = 20_000
374 } else {
375 threshold = int64(float64(cw) * 0.2)
376 }
377 if (remaining <= threshold) && !a.disableAutoSummarize {
378 shouldSummarize = true
379 return true
380 }
381 return false
382 },
383 },
384 })
385
386 a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))
387
388 if err != nil {
389 isCancelErr := errors.Is(err, context.Canceled)
390 isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
391 if currentAssistant == nil {
392 return result, err
393 }
394 // Ensure we finish thinking on error to close the reasoning state.
395 currentAssistant.FinishThinking()
396 toolCalls := currentAssistant.ToolCalls()
397 // INFO: we use the parent context here because the genCtx has been cancelled.
398 msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
399 if createErr != nil {
400 return nil, createErr
401 }
402 for _, tc := range toolCalls {
403 if !tc.Finished {
404 tc.Finished = true
405 tc.Input = "{}"
406 currentAssistant.AddToolCall(tc)
407 updateErr := a.messages.Update(ctx, *currentAssistant)
408 if updateErr != nil {
409 return nil, updateErr
410 }
411 }
412
413 found := false
414 for _, msg := range msgs {
415 if msg.Role == message.Tool {
416 for _, tr := range msg.ToolResults() {
417 if tr.ToolCallID == tc.ID {
418 found = true
419 break
420 }
421 }
422 }
423 if found {
424 break
425 }
426 }
427 if found {
428 continue
429 }
430 content := "There was an error while executing the tool"
431 if isCancelErr {
432 content = "Tool execution canceled by user"
433 } else if isPermissionErr {
434 content = "User denied permission"
435 }
436 toolResult := message.ToolResult{
437 ToolCallID: tc.ID,
438 Name: tc.Name,
439 Content: content,
440 IsError: true,
441 }
442 _, createErr = a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
443 Role: message.Tool,
444 Parts: []message.ContentPart{
445 toolResult,
446 },
447 })
448 if createErr != nil {
449 return nil, createErr
450 }
451 }
452 var fantasyErr *fantasy.Error
453 var providerErr *fantasy.ProviderError
454 const defaultTitle = "Provider Error"
455 if isCancelErr {
456 currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
457 } else if isPermissionErr {
458 currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "")
459 } else if errors.Is(err, hyper.ErrNoCredits) {
460 url := hyper.BaseURL()
461 link := lipgloss.NewStyle().Hyperlink(url, "id=hyper").Render(url)
462 currentAssistant.AddFinish(message.FinishReasonError, "No credits", "You're out of credits. Add more at "+link)
463 } else if errors.As(err, &providerErr) {
464 currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message)
465 } else if errors.As(err, &fantasyErr) {
466 currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(fantasyErr.Title), defaultTitle), fantasyErr.Message)
467 } else {
468 currentAssistant.AddFinish(message.FinishReasonError, defaultTitle, err.Error())
469 }
470 // Note: we use the parent context here because the genCtx has been
471 // cancelled.
472 updateErr := a.messages.Update(ctx, *currentAssistant)
473 if updateErr != nil {
474 return nil, updateErr
475 }
476 return nil, err
477 }
478 wg.Wait()
479
480 if shouldSummarize {
481 a.activeRequests.Del(call.SessionID)
482 if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
483 return nil, summarizeErr
484 }
485 // If the agent wasn't done...
486 if len(currentAssistant.ToolCalls()) > 0 {
487 existing, ok := a.messageQueue.Get(call.SessionID)
488 if !ok {
489 existing = []SessionAgentCall{}
490 }
491 call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
492 existing = append(existing, call)
493 a.messageQueue.Set(call.SessionID, existing)
494 }
495 }
496
497 // Release active request before processing queued messages.
498 a.activeRequests.Del(call.SessionID)
499 cancel()
500
501 queuedMessages, ok := a.messageQueue.Get(call.SessionID)
502 if !ok || len(queuedMessages) == 0 {
503 return result, err
504 }
505 // There are queued messages restart the loop.
506 firstQueuedMessage := queuedMessages[0]
507 a.messageQueue.Set(call.SessionID, queuedMessages[1:])
508 return a.Run(ctx, firstQueuedMessage)
509}
510
511func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
512 if a.IsSessionBusy(sessionID) {
513 return ErrSessionBusy
514 }
515
516 currentSession, err := a.sessions.Get(ctx, sessionID)
517 if err != nil {
518 return fmt.Errorf("failed to get session: %w", err)
519 }
520 msgs, err := a.getSessionMessages(ctx, currentSession)
521 if err != nil {
522 return err
523 }
524 if len(msgs) == 0 {
525 // Nothing to summarize.
526 return nil
527 }
528
529 aiMsgs, _ := a.preparePrompt(msgs)
530
531 genCtx, cancel := context.WithCancel(ctx)
532 a.activeRequests.Set(sessionID, cancel)
533 defer a.activeRequests.Del(sessionID)
534 defer cancel()
535
536 agent := fantasy.NewAgent(a.largeModel.Model,
537 fantasy.WithSystemPrompt(string(summaryPrompt)),
538 )
539 summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
540 Role: message.Assistant,
541 Model: a.largeModel.Model.Model(),
542 Provider: a.largeModel.Model.Provider(),
543 IsSummaryMessage: true,
544 })
545 if err != nil {
546 return err
547 }
548
549 summaryPromptText := "Provide a detailed summary of our conversation above."
550 if len(currentSession.Todos) > 0 {
551 summaryPromptText += "\n\n## Current Todo List\n\n"
552 for _, t := range currentSession.Todos {
553 summaryPromptText += fmt.Sprintf("- [%s] %s\n", t.Status, t.Content)
554 }
555 summaryPromptText += "\nInclude these tasks and their statuses in your summary. "
556 summaryPromptText += "Instruct the resuming assistant to use the `todos` tool to continue tracking progress on these tasks."
557 }
558
559 resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
560 Prompt: summaryPromptText,
561 Messages: aiMsgs,
562 ProviderOptions: opts,
563 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
564 prepared.Messages = options.Messages
565 if a.systemPromptPrefix != "" {
566 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
567 }
568 return callContext, prepared, nil
569 },
570 OnReasoningDelta: func(id string, text string) error {
571 summaryMessage.AppendReasoningContent(text)
572 return a.messages.Update(genCtx, summaryMessage)
573 },
574 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
575 // Handle anthropic signature.
576 if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
577 if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
578 summaryMessage.AppendReasoningSignature(signature.Signature)
579 }
580 }
581 summaryMessage.FinishThinking()
582 return a.messages.Update(genCtx, summaryMessage)
583 },
584 OnTextDelta: func(id, text string) error {
585 summaryMessage.AppendContent(text)
586 return a.messages.Update(genCtx, summaryMessage)
587 },
588 })
589 if err != nil {
590 isCancelErr := errors.Is(err, context.Canceled)
591 if isCancelErr {
592 // User cancelled summarize we need to remove the summary message.
593 deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
594 return deleteErr
595 }
596 return err
597 }
598
599 summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
600 err = a.messages.Update(genCtx, summaryMessage)
601 if err != nil {
602 return err
603 }
604
605 var openrouterCost *float64
606 for _, step := range resp.Steps {
607 stepCost := a.openrouterCost(step.ProviderMetadata)
608 if stepCost != nil {
609 newCost := *stepCost
610 if openrouterCost != nil {
611 newCost += *openrouterCost
612 }
613 openrouterCost = &newCost
614 }
615 }
616
617 a.updateSessionUsage(a.largeModel, ¤tSession, resp.TotalUsage, openrouterCost)
618
619 // Just in case, get just the last usage info.
620 usage := resp.Response.Usage
621 currentSession.SummaryMessageID = summaryMessage.ID
622 currentSession.CompletionTokens = usage.OutputTokens
623 currentSession.PromptTokens = 0
624 _, err = a.sessions.Save(genCtx, currentSession)
625 return err
626}
627
628func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
629 if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
630 return fantasy.ProviderOptions{}
631 }
632 return fantasy.ProviderOptions{
633 anthropic.Name: &anthropic.ProviderCacheControlOptions{
634 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
635 },
636 bedrock.Name: &anthropic.ProviderCacheControlOptions{
637 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
638 },
639 }
640}
641
642func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
643 var attachmentParts []message.ContentPart
644 for _, attachment := range call.Attachments {
645 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
646 }
647 parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
648 parts = append(parts, attachmentParts...)
649 msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
650 Role: message.User,
651 Parts: parts,
652 })
653 if err != nil {
654 return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
655 }
656 return msg, nil
657}
658
659func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
660 var history []fantasy.Message
661 if !a.isSubAgent {
662 history = append(history, fantasy.NewUserMessage(
663 fmt.Sprintf("<system_reminder>%s</system_reminder>",
664 `This is a reminder that your todo list is currently empty. DO NOT mention this to the user explicitly because they are already aware.
665If you are working on tasks that would benefit from a todo list please use the "todos" tool to create one.
666If not, please feel free to ignore. Again do not mention this message to the user.`,
667 ),
668 ))
669 }
670 for _, m := range msgs {
671 if len(m.Parts) == 0 {
672 continue
673 }
674 // Assistant message without content or tool calls (cancelled before it
675 // returned anything).
676 if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
677 continue
678 }
679 history = append(history, m.ToAIMessage()...)
680 }
681
682 var files []fantasy.FilePart
683 for _, attachment := range attachments {
684 files = append(files, fantasy.FilePart{
685 Filename: attachment.FileName,
686 Data: attachment.Content,
687 MediaType: attachment.MimeType,
688 })
689 }
690
691 return history, files
692}
693
694func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
695 msgs, err := a.messages.List(ctx, session.ID)
696 if err != nil {
697 return nil, fmt.Errorf("failed to list messages: %w", err)
698 }
699
700 if session.SummaryMessageID != "" {
701 summaryMsgInex := -1
702 for i, msg := range msgs {
703 if msg.ID == session.SummaryMessageID {
704 summaryMsgInex = i
705 break
706 }
707 }
708 if summaryMsgInex != -1 {
709 msgs = msgs[summaryMsgInex:]
710 msgs[0].Role = message.User
711 }
712 }
713 return msgs, nil
714}
715
716func (a *sessionAgent) generateTitle(ctx context.Context, session *session.Session, prompt string) {
717 if prompt == "" {
718 return
719 }
720
721 var maxOutput int64 = 40
722 if a.smallModel.CatwalkCfg.CanReason {
723 maxOutput = a.smallModel.CatwalkCfg.DefaultMaxTokens
724 }
725
726 agent := fantasy.NewAgent(a.smallModel.Model,
727 fantasy.WithSystemPrompt(string(titlePrompt)+"\n /no_think"),
728 fantasy.WithMaxOutputTokens(maxOutput),
729 )
730
731 resp, err := agent.Stream(ctx, fantasy.AgentStreamCall{
732 Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", prompt),
733 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
734 prepared.Messages = options.Messages
735 if a.systemPromptPrefix != "" {
736 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
737 }
738 return callContext, prepared, nil
739 },
740 })
741 if err != nil {
742 slog.Error("error generating title", "err", err)
743 return
744 }
745
746 title := resp.Response.Content.Text()
747
748 title = strings.ReplaceAll(title, "\n", " ")
749
750 // Remove thinking tags if present.
751 if idx := strings.Index(title, "</think>"); idx > 0 {
752 title = title[idx+len("</think>"):]
753 }
754
755 title = strings.TrimSpace(title)
756 if title == "" {
757 slog.Warn("failed to generate title", "warn", "empty title")
758 return
759 }
760
761 session.Title = title
762
763 var openrouterCost *float64
764 for _, step := range resp.Steps {
765 stepCost := a.openrouterCost(step.ProviderMetadata)
766 if stepCost != nil {
767 newCost := *stepCost
768 if openrouterCost != nil {
769 newCost += *openrouterCost
770 }
771 openrouterCost = &newCost
772 }
773 }
774
775 a.updateSessionUsage(a.smallModel, session, resp.TotalUsage, openrouterCost)
776 _, saveErr := a.sessions.Save(ctx, *session)
777 if saveErr != nil {
778 slog.Error("failed to save session title & usage", "error", saveErr)
779 return
780 }
781}
782
783func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
784 openrouterMetadata, ok := metadata[openrouter.Name]
785 if !ok {
786 return nil
787 }
788
789 opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
790 if !ok {
791 return nil
792 }
793 return &opts.Usage.Cost
794}
795
796func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
797 modelConfig := model.CatwalkCfg
798 cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
799 modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
800 modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
801 modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
802
803 if a.isClaudeCode() {
804 cost = 0
805 }
806
807 a.eventTokensUsed(session.ID, model, usage, cost)
808
809 if overrideCost != nil {
810 session.Cost += *overrideCost
811 } else {
812 session.Cost += cost
813 }
814
815 session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
816 session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
817}
818
819func (a *sessionAgent) Cancel(sessionID string) {
820 // Cancel regular requests.
821 if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
822 slog.Info("Request cancellation initiated", "session_id", sessionID)
823 cancel()
824 }
825
826 // Also check for summarize requests.
827 if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
828 slog.Info("Summarize cancellation initiated", "session_id", sessionID)
829 cancel()
830 }
831
832 if a.QueuedPrompts(sessionID) > 0 {
833 slog.Info("Clearing queued prompts", "session_id", sessionID)
834 a.messageQueue.Del(sessionID)
835 }
836}
837
838func (a *sessionAgent) ClearQueue(sessionID string) {
839 if a.QueuedPrompts(sessionID) > 0 {
840 slog.Info("Clearing queued prompts", "session_id", sessionID)
841 a.messageQueue.Del(sessionID)
842 }
843}
844
845func (a *sessionAgent) CancelAll() {
846 if !a.IsBusy() {
847 return
848 }
849 for key := range a.activeRequests.Seq2() {
850 a.Cancel(key) // key is sessionID
851 }
852
853 timeout := time.After(5 * time.Second)
854 for a.IsBusy() {
855 select {
856 case <-timeout:
857 return
858 default:
859 time.Sleep(200 * time.Millisecond)
860 }
861 }
862}
863
864func (a *sessionAgent) IsBusy() bool {
865 var busy bool
866 for cancelFunc := range a.activeRequests.Seq() {
867 if cancelFunc != nil {
868 busy = true
869 break
870 }
871 }
872 return busy
873}
874
875func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
876 _, busy := a.activeRequests.Get(sessionID)
877 return busy
878}
879
880func (a *sessionAgent) QueuedPrompts(sessionID string) int {
881 l, ok := a.messageQueue.Get(sessionID)
882 if !ok {
883 return 0
884 }
885 return len(l)
886}
887
888func (a *sessionAgent) QueuedPromptsList(sessionID string) []string {
889 l, ok := a.messageQueue.Get(sessionID)
890 if !ok {
891 return nil
892 }
893 prompts := make([]string, len(l))
894 for i, call := range l {
895 prompts[i] = call.Prompt
896 }
897 return prompts
898}
899
900func (a *sessionAgent) SetModels(large Model, small Model) {
901 a.largeModel = large
902 a.smallModel = small
903}
904
// SetTools replaces the tools offered to the model. The slice is stored as
// given, not copied.
func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
	a.tools = tools
}
908
// Model returns the agent's large model.
func (a *sessionAgent) Model() Model {
	return a.largeModel
}
912
913func (a *sessionAgent) promptPrefix() string {
914 if a.isClaudeCode() {
915 return "You are Claude Code, Anthropic's official CLI for Claude."
916 }
917 return a.systemPromptPrefix
918}
919
920func (a *sessionAgent) isClaudeCode() bool {
921 cfg := config.Get()
922 pc, ok := cfg.Providers.Get(a.largeModel.ModelCfg.Provider)
923 return ok && pc.ID == string(catwalk.InferenceProviderAnthropic) && pc.OAuthToken != nil
924}
925
926// convertToToolResult converts a fantasy tool result to a message tool result.
927func (a *sessionAgent) convertToToolResult(result fantasy.ToolResultContent) message.ToolResult {
928 baseResult := message.ToolResult{
929 ToolCallID: result.ToolCallID,
930 Name: result.ToolName,
931 Metadata: result.ClientMetadata,
932 }
933
934 switch result.Result.GetType() {
935 case fantasy.ToolResultContentTypeText:
936 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result); ok {
937 baseResult.Content = r.Text
938 }
939 case fantasy.ToolResultContentTypeError:
940 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result); ok {
941 baseResult.Content = r.Error.Error()
942 baseResult.IsError = true
943 }
944 case fantasy.ToolResultContentTypeMedia:
945 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](result.Result); ok {
946 content := r.Text
947 if content == "" {
948 content = fmt.Sprintf("Loaded %s content", r.MediaType)
949 }
950 baseResult.Content = content
951 baseResult.Data = r.Data
952 baseResult.MIMEType = r.MediaType
953 }
954 }
955
956 return baseResult
957}
958
959// workaroundProviderMediaLimitations converts media content in tool results to
960// user messages for providers that don't natively support images in tool results.
961//
962// Problem: OpenAI, Google, OpenRouter, and other OpenAI-compatible providers
963// don't support sending images/media in tool result messages - they only accept
964// text in tool results. However, they DO support images in user messages.
965//
966// If we send media in tool results to these providers, the API returns an error.
967//
968// Solution: For these providers, we:
969// 1. Replace the media in the tool result with a text placeholder
970// 2. Inject a user message immediately after with the image as a file attachment
971// 3. This maintains the tool execution flow while working around API limitations
972//
973// Anthropic and Bedrock support images natively in tool results, so we skip
974// this workaround for them.
975//
976// Example transformation:
977//
978// BEFORE: [tool result: image data]
979// AFTER: [tool result: "Image loaded - see attached"], [user: image attachment]
980func (a *sessionAgent) workaroundProviderMediaLimitations(messages []fantasy.Message) []fantasy.Message {
981 providerSupportsMedia := a.largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderAnthropic) ||
982 a.largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderBedrock)
983
984 if providerSupportsMedia {
985 return messages
986 }
987
988 convertedMessages := make([]fantasy.Message, 0, len(messages))
989
990 for _, msg := range messages {
991 if msg.Role != fantasy.MessageRoleTool {
992 convertedMessages = append(convertedMessages, msg)
993 continue
994 }
995
996 textParts := make([]fantasy.MessagePart, 0, len(msg.Content))
997 var mediaFiles []fantasy.FilePart
998
999 for _, part := range msg.Content {
1000 toolResult, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
1001 if !ok {
1002 textParts = append(textParts, part)
1003 continue
1004 }
1005
1006 if media, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](toolResult.Output); ok {
1007 decoded, err := base64.StdEncoding.DecodeString(media.Data)
1008 if err != nil {
1009 slog.Warn("failed to decode media data", "error", err)
1010 textParts = append(textParts, part)
1011 continue
1012 }
1013
1014 mediaFiles = append(mediaFiles, fantasy.FilePart{
1015 Data: decoded,
1016 MediaType: media.MediaType,
1017 Filename: fmt.Sprintf("tool-result-%s", toolResult.ToolCallID),
1018 })
1019
1020 textParts = append(textParts, fantasy.ToolResultPart{
1021 ToolCallID: toolResult.ToolCallID,
1022 Output: fantasy.ToolResultOutputContentText{
1023 Text: "[Image/media content loaded - see attached file]",
1024 },
1025 ProviderOptions: toolResult.ProviderOptions,
1026 })
1027 } else {
1028 textParts = append(textParts, part)
1029 }
1030 }
1031
1032 convertedMessages = append(convertedMessages, fantasy.Message{
1033 Role: fantasy.MessageRoleTool,
1034 Content: textParts,
1035 })
1036
1037 if len(mediaFiles) > 0 {
1038 convertedMessages = append(convertedMessages, fantasy.NewUserMessage(
1039 "Here is the media content from the tool result:",
1040 mediaFiles...,
1041 ))
1042 }
1043 }
1044
1045 return convertedMessages
1046}