1// Package agent is the core orchestration layer for Crush AI agents.
2//
3// It provides session-based AI agent functionality for managing
4// conversations, tool execution, and message handling. It coordinates
5// interactions between language models, messages, sessions, and tools while
6// handling features like automatic summarization, queuing, and token
7// management.
8package agent
9
10import (
11 "cmp"
12 "context"
13 _ "embed"
14 "encoding/base64"
15 "errors"
16 "fmt"
17 "log/slog"
18 "os"
19 "strconv"
20 "strings"
21 "sync"
22 "time"
23
24 "charm.land/fantasy"
25 "charm.land/fantasy/providers/anthropic"
26 "charm.land/fantasy/providers/bedrock"
27 "charm.land/fantasy/providers/google"
28 "charm.land/fantasy/providers/openai"
29 "charm.land/fantasy/providers/openrouter"
30 "github.com/charmbracelet/catwalk/pkg/catwalk"
31 "github.com/charmbracelet/crush/internal/agent/tools"
32 "github.com/charmbracelet/crush/internal/config"
33 "github.com/charmbracelet/crush/internal/csync"
34 "github.com/charmbracelet/crush/internal/message"
35 "github.com/charmbracelet/crush/internal/permission"
36 "github.com/charmbracelet/crush/internal/session"
37 "github.com/charmbracelet/crush/internal/stringext"
38)
39
// titlePrompt holds the contents of templates/title.md, embedded at build
// time. generateTitle uses it as the base of its system prompt.
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt holds the contents of templates/summary.md, embedded at build
// time. Summarize uses it as its system prompt.
//
//go:embed templates/summary.md
var summaryPrompt []byte
45
// SessionAgentCall describes a single prompt submitted to a session agent,
// together with the per-call generation parameters.
type SessionAgentCall struct {
	// SessionID identifies the target session. Run returns ErrSessionMissing
	// when it is empty.
	SessionID string
	// Prompt is the user's text. Run returns ErrEmptyPrompt when it is empty.
	Prompt string
	// ProviderOptions is passed through to the agent stream call.
	ProviderOptions fantasy.ProviderOptions
	// Attachments are stored on the user message and sent to the model as
	// files.
	Attachments []message.Attachment
	// MaxOutputTokens and the sampling parameters below are forwarded
	// unchanged to the agent stream call.
	MaxOutputTokens  int64
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
}
58
// SessionAgent is the interface for an agent that runs prompts against
// sessions. The method notes below describe the sessionAgent implementation
// in this file, which is what NewSessionAgent returns.
type SessionAgent interface {
	// Run submits a prompt for a session. If the session is already busy the
	// call is queued and Run returns (nil, nil).
	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
	// SetModels replaces the large and small models.
	SetModels(large Model, small Model)
	// SetTools replaces the tool set used by subsequent runs.
	SetTools(tools []fantasy.AgentTool)
	// Cancel cancels the active request for the session, if any, and drops
	// its queued prompts.
	Cancel(sessionID string)
	// CancelAll cancels every active request and waits up to five seconds
	// for the agent to become idle.
	CancelAll()
	// IsSessionBusy reports whether the session has an active request.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any session has an active request.
	IsBusy() bool
	// QueuedPrompts returns the number of prompts queued for the session.
	QueuedPrompts(sessionID string) int
	// ClearQueue drops the prompts queued for the session.
	ClearQueue(sessionID string)
	// Summarize replaces the session history with a model-generated summary.
	// It returns ErrSessionBusy when the session has an active request.
	Summarize(context.Context, string, fantasy.ProviderOptions) error
	// Model returns the large model.
	Model() Model
}
72
// Model bundles a language model with the two configuration records the agent
// consults alongside it.
type Model struct {
	// Model is the language model handed to fantasy.NewAgent.
	Model fantasy.LanguageModel
	// CatwalkCfg supplies the pricing, context window, reasoning and image
	// capability fields read in this file.
	CatwalkCfg catwalk.Model
	// ModelCfg supplies the model and provider identifiers recorded on
	// assistant messages.
	ModelCfg config.SelectedModel
}
78
// sessionAgent is the SessionAgent implementation returned by
// NewSessionAgent.
type sessionAgent struct {
	largeModel           Model // used for Run and Summarize
	smallModel           Model // used for title generation
	systemPromptPrefix   string
	systemPrompt         string
	tools                []fantasy.AgentTool
	sessions             session.Service
	messages             message.Service
	disableAutoSummarize bool
	isYolo               bool // NOTE(review): not read by any code visible in this file

	// messageQueue holds prompts received while a session was busy, keyed by
	// session ID.
	messageQueue *csync.Map[string, []SessionAgentCall]
	// activeRequests maps a session ID to the cancel function of its
	// in-flight request. Presence of a key is what marks a session as busy.
	activeRequests *csync.Map[string, context.CancelFunc]
}
93
// SessionAgentOptions carries the dependencies and settings that
// NewSessionAgent copies into a new agent.
type SessionAgentOptions struct {
	LargeModel           Model
	SmallModel           Model
	SystemPromptPrefix   string
	SystemPrompt         string
	DisableAutoSummarize bool
	IsYolo               bool
	Sessions             session.Service
	Messages             message.Service
	Tools                []fantasy.AgentTool
}
105
106func NewSessionAgent(
107 opts SessionAgentOptions,
108) SessionAgent {
109 return &sessionAgent{
110 largeModel: opts.LargeModel,
111 smallModel: opts.SmallModel,
112 systemPromptPrefix: opts.SystemPromptPrefix,
113 systemPrompt: opts.SystemPrompt,
114 sessions: opts.Sessions,
115 messages: opts.Messages,
116 disableAutoSummarize: opts.DisableAutoSummarize,
117 tools: opts.Tools,
118 isYolo: opts.IsYolo,
119 messageQueue: csync.NewMap[string, []SessionAgentCall](),
120 activeRequests: csync.NewMap[string, context.CancelFunc](),
121 }
122}
123
124func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
125 if call.Prompt == "" {
126 return nil, ErrEmptyPrompt
127 }
128 if call.SessionID == "" {
129 return nil, ErrSessionMissing
130 }
131
132 // Queue the message if busy
133 if a.IsSessionBusy(call.SessionID) {
134 existing, ok := a.messageQueue.Get(call.SessionID)
135 if !ok {
136 existing = []SessionAgentCall{}
137 }
138 existing = append(existing, call)
139 a.messageQueue.Set(call.SessionID, existing)
140 return nil, nil
141 }
142
143 if len(a.tools) > 0 {
144 // Add Anthropic caching to the last tool.
145 a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
146 }
147
148 agent := fantasy.NewAgent(
149 a.largeModel.Model,
150 fantasy.WithSystemPrompt(a.systemPrompt),
151 fantasy.WithTools(a.tools...),
152 )
153
154 sessionLock := sync.Mutex{}
155 currentSession, err := a.sessions.Get(ctx, call.SessionID)
156 if err != nil {
157 return nil, fmt.Errorf("failed to get session: %w", err)
158 }
159
160 msgs, err := a.getSessionMessages(ctx, currentSession)
161 if err != nil {
162 return nil, fmt.Errorf("failed to get session messages: %w", err)
163 }
164
165 var wg sync.WaitGroup
166 // Generate title if first message.
167 if len(msgs) == 0 {
168 wg.Go(func() {
169 sessionLock.Lock()
170 a.generateTitle(ctx, ¤tSession, call.Prompt)
171 sessionLock.Unlock()
172 })
173 }
174
175 // Add the user message to the session.
176 _, err = a.createUserMessage(ctx, call)
177 if err != nil {
178 return nil, err
179 }
180
181 // Add the session to the context.
182 ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)
183
184 genCtx, cancel := context.WithCancel(ctx)
185 a.activeRequests.Set(call.SessionID, cancel)
186
187 defer cancel()
188 defer a.activeRequests.Del(call.SessionID)
189
190 history, files := a.preparePrompt(msgs, call.Attachments...)
191
192 startTime := time.Now()
193 a.eventPromptSent(call.SessionID)
194
195 var currentAssistant *message.Message
196 var shouldSummarize bool
197 result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
198 Prompt: call.Prompt,
199 Files: files,
200 Messages: history,
201 ProviderOptions: call.ProviderOptions,
202 MaxOutputTokens: &call.MaxOutputTokens,
203 TopP: call.TopP,
204 Temperature: call.Temperature,
205 PresencePenalty: call.PresencePenalty,
206 TopK: call.TopK,
207 FrequencyPenalty: call.FrequencyPenalty,
208 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
209 prepared.Messages = options.Messages
210 for i := range prepared.Messages {
211 prepared.Messages[i].ProviderOptions = nil
212 }
213
214 queuedCalls, _ := a.messageQueue.Get(call.SessionID)
215 a.messageQueue.Del(call.SessionID)
216 for _, queued := range queuedCalls {
217 userMessage, createErr := a.createUserMessage(callContext, queued)
218 if createErr != nil {
219 return callContext, prepared, createErr
220 }
221 prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
222 }
223
224 prepared.Messages = a.workaroundProviderMediaLimitations(prepared.Messages)
225
226 lastSystemRoleInx := 0
227 systemMessageUpdated := false
228 for i, msg := range prepared.Messages {
229 // Only add cache control to the last message.
230 if msg.Role == fantasy.MessageRoleSystem {
231 lastSystemRoleInx = i
232 } else if !systemMessageUpdated {
233 prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
234 systemMessageUpdated = true
235 }
236 // Than add cache control to the last 2 messages.
237 if i > len(prepared.Messages)-3 {
238 prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
239 }
240 }
241
242 if promptPrefix := a.promptPrefix(); promptPrefix != "" {
243 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(promptPrefix)}, prepared.Messages...)
244 }
245
246 var assistantMsg message.Message
247 assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
248 Role: message.Assistant,
249 Parts: []message.ContentPart{},
250 Model: a.largeModel.ModelCfg.Model,
251 Provider: a.largeModel.ModelCfg.Provider,
252 })
253 if err != nil {
254 return callContext, prepared, err
255 }
256 callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
257 callContext = context.WithValue(callContext, tools.SupportsImagesContextKey, a.largeModel.CatwalkCfg.SupportsImages)
258 callContext = context.WithValue(callContext, tools.ModelNameContextKey, a.largeModel.CatwalkCfg.Name)
259 currentAssistant = &assistantMsg
260 return callContext, prepared, err
261 },
262 OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
263 currentAssistant.AppendReasoningContent(reasoning.Text)
264 return a.messages.Update(genCtx, *currentAssistant)
265 },
266 OnReasoningDelta: func(id string, text string) error {
267 currentAssistant.AppendReasoningContent(text)
268 return a.messages.Update(genCtx, *currentAssistant)
269 },
270 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
271 // handle anthropic signature
272 if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
273 if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
274 currentAssistant.AppendReasoningSignature(reasoning.Signature)
275 }
276 }
277 if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
278 if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
279 currentAssistant.AppendThoughtSignature(reasoning.Signature, reasoning.ToolID)
280 }
281 }
282 if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
283 if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
284 currentAssistant.SetReasoningResponsesData(reasoning)
285 }
286 }
287 currentAssistant.FinishThinking()
288 return a.messages.Update(genCtx, *currentAssistant)
289 },
290 OnTextDelta: func(id string, text string) error {
291 // Strip leading newline from initial text content. This is is
292 // particularly important in non-interactive mode where leading
293 // newlines are very visible.
294 if len(currentAssistant.Parts) == 0 {
295 text = strings.TrimPrefix(text, "\n")
296 }
297
298 currentAssistant.AppendContent(text)
299 return a.messages.Update(genCtx, *currentAssistant)
300 },
301 OnToolInputStart: func(id string, toolName string) error {
302 toolCall := message.ToolCall{
303 ID: id,
304 Name: toolName,
305 ProviderExecuted: false,
306 Finished: false,
307 }
308 currentAssistant.AddToolCall(toolCall)
309 return a.messages.Update(genCtx, *currentAssistant)
310 },
311 OnRetry: func(err *fantasy.ProviderError, delay time.Duration) {
312 // TODO: implement
313 },
314 OnToolCall: func(tc fantasy.ToolCallContent) error {
315 toolCall := message.ToolCall{
316 ID: tc.ToolCallID,
317 Name: tc.ToolName,
318 Input: tc.Input,
319 ProviderExecuted: false,
320 Finished: true,
321 }
322 currentAssistant.AddToolCall(toolCall)
323 return a.messages.Update(genCtx, *currentAssistant)
324 },
325 OnToolResult: func(result fantasy.ToolResultContent) error {
326 toolResult := a.convertToToolResult(result)
327 _, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
328 Role: message.Tool,
329 Parts: []message.ContentPart{
330 toolResult,
331 },
332 })
333 return createMsgErr
334 },
335 OnStepFinish: func(stepResult fantasy.StepResult) error {
336 finishReason := message.FinishReasonUnknown
337 switch stepResult.FinishReason {
338 case fantasy.FinishReasonLength:
339 finishReason = message.FinishReasonMaxTokens
340 case fantasy.FinishReasonStop:
341 finishReason = message.FinishReasonEndTurn
342 case fantasy.FinishReasonToolCalls:
343 finishReason = message.FinishReasonToolUse
344 }
345 currentAssistant.AddFinish(finishReason, "", "")
346 a.updateSessionUsage(a.largeModel, ¤tSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
347 sessionLock.Lock()
348 _, sessionErr := a.sessions.Save(genCtx, currentSession)
349 sessionLock.Unlock()
350 if sessionErr != nil {
351 return sessionErr
352 }
353 return a.messages.Update(genCtx, *currentAssistant)
354 },
355 StopWhen: []fantasy.StopCondition{
356 func(_ []fantasy.StepResult) bool {
357 cw := int64(a.largeModel.CatwalkCfg.ContextWindow)
358 tokens := currentSession.CompletionTokens + currentSession.PromptTokens
359 remaining := cw - tokens
360 var threshold int64
361 if cw > 200_000 {
362 threshold = 20_000
363 } else {
364 threshold = int64(float64(cw) * 0.2)
365 }
366 if (remaining <= threshold) && !a.disableAutoSummarize {
367 shouldSummarize = true
368 return true
369 }
370 return false
371 },
372 },
373 })
374
375 a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))
376
377 if err != nil {
378 isCancelErr := errors.Is(err, context.Canceled)
379 isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
380 if currentAssistant == nil {
381 return result, err
382 }
383 // Ensure we finish thinking on error to close the reasoning state.
384 currentAssistant.FinishThinking()
385 toolCalls := currentAssistant.ToolCalls()
386 // INFO: we use the parent context here because the genCtx has been cancelled.
387 msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
388 if createErr != nil {
389 return nil, createErr
390 }
391 for _, tc := range toolCalls {
392 if !tc.Finished {
393 tc.Finished = true
394 tc.Input = "{}"
395 currentAssistant.AddToolCall(tc)
396 updateErr := a.messages.Update(ctx, *currentAssistant)
397 if updateErr != nil {
398 return nil, updateErr
399 }
400 }
401
402 found := false
403 for _, msg := range msgs {
404 if msg.Role == message.Tool {
405 for _, tr := range msg.ToolResults() {
406 if tr.ToolCallID == tc.ID {
407 found = true
408 break
409 }
410 }
411 }
412 if found {
413 break
414 }
415 }
416 if found {
417 continue
418 }
419 content := "There was an error while executing the tool"
420 if isCancelErr {
421 content = "Tool execution canceled by user"
422 } else if isPermissionErr {
423 content = "User denied permission"
424 }
425 toolResult := message.ToolResult{
426 ToolCallID: tc.ID,
427 Name: tc.Name,
428 Content: content,
429 IsError: true,
430 }
431 _, createErr = a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
432 Role: message.Tool,
433 Parts: []message.ContentPart{
434 toolResult,
435 },
436 })
437 if createErr != nil {
438 return nil, createErr
439 }
440 }
441 var fantasyErr *fantasy.Error
442 var providerErr *fantasy.ProviderError
443 const defaultTitle = "Provider Error"
444 if isCancelErr {
445 currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
446 } else if isPermissionErr {
447 currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "")
448 } else if errors.As(err, &providerErr) {
449 currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message)
450 } else if errors.As(err, &fantasyErr) {
451 currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(fantasyErr.Title), defaultTitle), fantasyErr.Message)
452 } else {
453 currentAssistant.AddFinish(message.FinishReasonError, defaultTitle, err.Error())
454 }
455 // Note: we use the parent context here because the genCtx has been
456 // cancelled.
457 updateErr := a.messages.Update(ctx, *currentAssistant)
458 if updateErr != nil {
459 return nil, updateErr
460 }
461 return nil, err
462 }
463 wg.Wait()
464
465 if shouldSummarize {
466 a.activeRequests.Del(call.SessionID)
467 if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
468 return nil, summarizeErr
469 }
470 // If the agent wasn't done...
471 if len(currentAssistant.ToolCalls()) > 0 {
472 existing, ok := a.messageQueue.Get(call.SessionID)
473 if !ok {
474 existing = []SessionAgentCall{}
475 }
476 call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
477 existing = append(existing, call)
478 a.messageQueue.Set(call.SessionID, existing)
479 }
480 }
481
482 // Release active request before processing queued messages.
483 a.activeRequests.Del(call.SessionID)
484 cancel()
485
486 queuedMessages, ok := a.messageQueue.Get(call.SessionID)
487 if !ok || len(queuedMessages) == 0 {
488 return result, err
489 }
490 // There are queued messages restart the loop.
491 firstQueuedMessage := queuedMessages[0]
492 a.messageQueue.Set(call.SessionID, queuedMessages[1:])
493 return a.Run(ctx, firstQueuedMessage)
494}
495
496func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
497 if a.IsSessionBusy(sessionID) {
498 return ErrSessionBusy
499 }
500
501 currentSession, err := a.sessions.Get(ctx, sessionID)
502 if err != nil {
503 return fmt.Errorf("failed to get session: %w", err)
504 }
505 msgs, err := a.getSessionMessages(ctx, currentSession)
506 if err != nil {
507 return err
508 }
509 if len(msgs) == 0 {
510 // Nothing to summarize.
511 return nil
512 }
513
514 aiMsgs, _ := a.preparePrompt(msgs)
515
516 genCtx, cancel := context.WithCancel(ctx)
517 a.activeRequests.Set(sessionID, cancel)
518 defer a.activeRequests.Del(sessionID)
519 defer cancel()
520
521 agent := fantasy.NewAgent(a.largeModel.Model,
522 fantasy.WithSystemPrompt(string(summaryPrompt)),
523 )
524 summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
525 Role: message.Assistant,
526 Model: a.largeModel.Model.Model(),
527 Provider: a.largeModel.Model.Provider(),
528 IsSummaryMessage: true,
529 })
530 if err != nil {
531 return err
532 }
533
534 resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
535 Prompt: "Provide a detailed summary of our conversation above.",
536 Messages: aiMsgs,
537 ProviderOptions: opts,
538 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
539 prepared.Messages = options.Messages
540 if a.systemPromptPrefix != "" {
541 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
542 }
543 return callContext, prepared, nil
544 },
545 OnReasoningDelta: func(id string, text string) error {
546 summaryMessage.AppendReasoningContent(text)
547 return a.messages.Update(genCtx, summaryMessage)
548 },
549 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
550 // Handle anthropic signature.
551 if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
552 if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
553 summaryMessage.AppendReasoningSignature(signature.Signature)
554 }
555 }
556 summaryMessage.FinishThinking()
557 return a.messages.Update(genCtx, summaryMessage)
558 },
559 OnTextDelta: func(id, text string) error {
560 summaryMessage.AppendContent(text)
561 return a.messages.Update(genCtx, summaryMessage)
562 },
563 })
564 if err != nil {
565 isCancelErr := errors.Is(err, context.Canceled)
566 if isCancelErr {
567 // User cancelled summarize we need to remove the summary message.
568 deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
569 return deleteErr
570 }
571 return err
572 }
573
574 summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
575 err = a.messages.Update(genCtx, summaryMessage)
576 if err != nil {
577 return err
578 }
579
580 var openrouterCost *float64
581 for _, step := range resp.Steps {
582 stepCost := a.openrouterCost(step.ProviderMetadata)
583 if stepCost != nil {
584 newCost := *stepCost
585 if openrouterCost != nil {
586 newCost += *openrouterCost
587 }
588 openrouterCost = &newCost
589 }
590 }
591
592 a.updateSessionUsage(a.largeModel, ¤tSession, resp.TotalUsage, openrouterCost)
593
594 // Just in case, get just the last usage info.
595 usage := resp.Response.Usage
596 currentSession.SummaryMessageID = summaryMessage.ID
597 currentSession.CompletionTokens = usage.OutputTokens
598 currentSession.PromptTokens = 0
599 _, err = a.sessions.Save(genCtx, currentSession)
600 return err
601}
602
603func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
604 if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
605 return fantasy.ProviderOptions{}
606 }
607 return fantasy.ProviderOptions{
608 anthropic.Name: &anthropic.ProviderCacheControlOptions{
609 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
610 },
611 bedrock.Name: &anthropic.ProviderCacheControlOptions{
612 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
613 },
614 }
615}
616
617func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
618 var attachmentParts []message.ContentPart
619 for _, attachment := range call.Attachments {
620 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
621 }
622 parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
623 parts = append(parts, attachmentParts...)
624 msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
625 Role: message.User,
626 Parts: parts,
627 })
628 if err != nil {
629 return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
630 }
631 return msg, nil
632}
633
634func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
635 var history []fantasy.Message
636 for _, m := range msgs {
637 if len(m.Parts) == 0 {
638 continue
639 }
640 // Assistant message without content or tool calls (cancelled before it
641 // returned anything).
642 if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
643 continue
644 }
645 history = append(history, m.ToAIMessage()...)
646 }
647
648 var files []fantasy.FilePart
649 for _, attachment := range attachments {
650 files = append(files, fantasy.FilePart{
651 Filename: attachment.FileName,
652 Data: attachment.Content,
653 MediaType: attachment.MimeType,
654 })
655 }
656
657 return history, files
658}
659
660func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
661 msgs, err := a.messages.List(ctx, session.ID)
662 if err != nil {
663 return nil, fmt.Errorf("failed to list messages: %w", err)
664 }
665
666 if session.SummaryMessageID != "" {
667 summaryMsgInex := -1
668 for i, msg := range msgs {
669 if msg.ID == session.SummaryMessageID {
670 summaryMsgInex = i
671 break
672 }
673 }
674 if summaryMsgInex != -1 {
675 msgs = msgs[summaryMsgInex:]
676 msgs[0].Role = message.User
677 }
678 }
679 return msgs, nil
680}
681
682func (a *sessionAgent) generateTitle(ctx context.Context, session *session.Session, prompt string) {
683 if prompt == "" {
684 return
685 }
686
687 var maxOutput int64 = 40
688 if a.smallModel.CatwalkCfg.CanReason {
689 maxOutput = a.smallModel.CatwalkCfg.DefaultMaxTokens
690 }
691
692 agent := fantasy.NewAgent(a.smallModel.Model,
693 fantasy.WithSystemPrompt(string(titlePrompt)+"\n /no_think"),
694 fantasy.WithMaxOutputTokens(maxOutput),
695 )
696
697 resp, err := agent.Stream(ctx, fantasy.AgentStreamCall{
698 Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", prompt),
699 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
700 prepared.Messages = options.Messages
701 if a.systemPromptPrefix != "" {
702 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
703 }
704 return callContext, prepared, nil
705 },
706 })
707 if err != nil {
708 slog.Error("error generating title", "err", err)
709 return
710 }
711
712 title := resp.Response.Content.Text()
713
714 title = strings.ReplaceAll(title, "\n", " ")
715
716 // Remove thinking tags if present.
717 if idx := strings.Index(title, "</think>"); idx > 0 {
718 title = title[idx+len("</think>"):]
719 }
720
721 title = strings.TrimSpace(title)
722 if title == "" {
723 slog.Warn("failed to generate title", "warn", "empty title")
724 return
725 }
726
727 session.Title = title
728
729 var openrouterCost *float64
730 for _, step := range resp.Steps {
731 stepCost := a.openrouterCost(step.ProviderMetadata)
732 if stepCost != nil {
733 newCost := *stepCost
734 if openrouterCost != nil {
735 newCost += *openrouterCost
736 }
737 openrouterCost = &newCost
738 }
739 }
740
741 a.updateSessionUsage(a.smallModel, session, resp.TotalUsage, openrouterCost)
742 _, saveErr := a.sessions.Save(ctx, *session)
743 if saveErr != nil {
744 slog.Error("failed to save session title & usage", "error", saveErr)
745 return
746 }
747}
748
749func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
750 openrouterMetadata, ok := metadata[openrouter.Name]
751 if !ok {
752 return nil
753 }
754
755 opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
756 if !ok {
757 return nil
758 }
759 return &opts.Usage.Cost
760}
761
762func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
763 modelConfig := model.CatwalkCfg
764 cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
765 modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
766 modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
767 modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
768
769 if a.isClaudeCode() {
770 cost = 0
771 }
772
773 a.eventTokensUsed(session.ID, model, usage, cost)
774
775 if overrideCost != nil {
776 session.Cost += *overrideCost
777 } else {
778 session.Cost += cost
779 }
780
781 session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
782 session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
783}
784
785func (a *sessionAgent) Cancel(sessionID string) {
786 // Cancel regular requests.
787 if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
788 slog.Info("Request cancellation initiated", "session_id", sessionID)
789 cancel()
790 }
791
792 // Also check for summarize requests.
793 if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
794 slog.Info("Summarize cancellation initiated", "session_id", sessionID)
795 cancel()
796 }
797
798 if a.QueuedPrompts(sessionID) > 0 {
799 slog.Info("Clearing queued prompts", "session_id", sessionID)
800 a.messageQueue.Del(sessionID)
801 }
802}
803
804func (a *sessionAgent) ClearQueue(sessionID string) {
805 if a.QueuedPrompts(sessionID) > 0 {
806 slog.Info("Clearing queued prompts", "session_id", sessionID)
807 a.messageQueue.Del(sessionID)
808 }
809}
810
811func (a *sessionAgent) CancelAll() {
812 if !a.IsBusy() {
813 return
814 }
815 for key := range a.activeRequests.Seq2() {
816 a.Cancel(key) // key is sessionID
817 }
818
819 timeout := time.After(5 * time.Second)
820 for a.IsBusy() {
821 select {
822 case <-timeout:
823 return
824 default:
825 time.Sleep(200 * time.Millisecond)
826 }
827 }
828}
829
830func (a *sessionAgent) IsBusy() bool {
831 var busy bool
832 for cancelFunc := range a.activeRequests.Seq() {
833 if cancelFunc != nil {
834 busy = true
835 break
836 }
837 }
838 return busy
839}
840
841func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
842 _, busy := a.activeRequests.Get(sessionID)
843 return busy
844}
845
846func (a *sessionAgent) QueuedPrompts(sessionID string) int {
847 l, ok := a.messageQueue.Get(sessionID)
848 if !ok {
849 return 0
850 }
851 return len(l)
852}
853
854func (a *sessionAgent) SetModels(large Model, small Model) {
855 a.largeModel = large
856 a.smallModel = small
857}
858
// SetTools replaces the agent's tool set. The slice is stored as given, not
// copied.
func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
	a.tools = tools
}
862
// Model returns the agent's large model.
func (a *sessionAgent) Model() Model {
	return a.largeModel
}
866
867func (a *sessionAgent) promptPrefix() string {
868 if a.isClaudeCode() {
869 return "You are Claude Code, Anthropic's official CLI for Claude."
870 }
871 return a.systemPromptPrefix
872}
873
874func (a *sessionAgent) isClaudeCode() bool {
875 cfg := config.Get()
876 pc, ok := cfg.Providers.Get(a.largeModel.ModelCfg.Provider)
877 return ok && pc.ID == string(catwalk.InferenceProviderAnthropic) && pc.OAuthToken != nil
878}
879
880// convertToToolResult converts a fantasy tool result to a message tool result.
881func (a *sessionAgent) convertToToolResult(result fantasy.ToolResultContent) message.ToolResult {
882 baseResult := message.ToolResult{
883 ToolCallID: result.ToolCallID,
884 Name: result.ToolName,
885 Metadata: result.ClientMetadata,
886 }
887
888 switch result.Result.GetType() {
889 case fantasy.ToolResultContentTypeText:
890 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result); ok {
891 baseResult.Content = r.Text
892 }
893 case fantasy.ToolResultContentTypeError:
894 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result); ok {
895 baseResult.Content = r.Error.Error()
896 baseResult.IsError = true
897 }
898 case fantasy.ToolResultContentTypeMedia:
899 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](result.Result); ok {
900 content := r.Text
901 if content == "" {
902 content = fmt.Sprintf("Loaded %s content", r.MediaType)
903 }
904 baseResult.Content = content
905 baseResult.Data = r.Data
906 baseResult.MIMEType = r.MediaType
907 }
908 }
909
910 return baseResult
911}
912
913// workaroundProviderMediaLimitations converts media content in tool results to
914// user messages for providers that don't natively support images in tool results.
915//
916// Problem: OpenAI, Google, OpenRouter, and other OpenAI-compatible providers
917// don't support sending images/media in tool result messages - they only accept
918// text in tool results. However, they DO support images in user messages.
919//
920// If we send media in tool results to these providers, the API returns an error.
921//
922// Solution: For these providers, we:
923// 1. Replace the media in the tool result with a text placeholder
924// 2. Inject a user message immediately after with the image as a file attachment
925// 3. This maintains the tool execution flow while working around API limitations
926//
927// Anthropic and Bedrock support images natively in tool results, so we skip
928// this workaround for them.
929//
930// Example transformation:
931//
932// BEFORE: [tool result: image data]
933// AFTER: [tool result: "Image loaded - see attached"], [user: image attachment]
934func (a *sessionAgent) workaroundProviderMediaLimitations(messages []fantasy.Message) []fantasy.Message {
935 providerSupportsMedia := a.largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderAnthropic) ||
936 a.largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderBedrock)
937
938 if providerSupportsMedia {
939 return messages
940 }
941
942 convertedMessages := make([]fantasy.Message, 0, len(messages))
943
944 for _, msg := range messages {
945 if msg.Role != fantasy.MessageRoleTool {
946 convertedMessages = append(convertedMessages, msg)
947 continue
948 }
949
950 textParts := make([]fantasy.MessagePart, 0, len(msg.Content))
951 var mediaFiles []fantasy.FilePart
952
953 for _, part := range msg.Content {
954 toolResult, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
955 if !ok {
956 textParts = append(textParts, part)
957 continue
958 }
959
960 if media, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](toolResult.Output); ok {
961 decoded, err := base64.StdEncoding.DecodeString(media.Data)
962 if err != nil {
963 slog.Warn("failed to decode media data", "error", err)
964 textParts = append(textParts, part)
965 continue
966 }
967
968 mediaFiles = append(mediaFiles, fantasy.FilePart{
969 Data: decoded,
970 MediaType: media.MediaType,
971 Filename: fmt.Sprintf("tool-result-%s", toolResult.ToolCallID),
972 })
973
974 textParts = append(textParts, fantasy.ToolResultPart{
975 ToolCallID: toolResult.ToolCallID,
976 Output: fantasy.ToolResultOutputContentText{
977 Text: "[Image/media content loaded - see attached file]",
978 },
979 ProviderOptions: toolResult.ProviderOptions,
980 })
981 } else {
982 textParts = append(textParts, part)
983 }
984 }
985
986 convertedMessages = append(convertedMessages, fantasy.Message{
987 Role: fantasy.MessageRoleTool,
988 Content: textParts,
989 })
990
991 if len(mediaFiles) > 0 {
992 convertedMessages = append(convertedMessages, fantasy.NewUserMessage(
993 "Here is the media content from the tool result:",
994 mediaFiles...,
995 ))
996 }
997 }
998
999 return convertedMessages
1000}