1// Package agent is the core orchestration layer for Crush AI agents.
2//
3// It provides session-based AI agent functionality for managing
4// conversations, tool execution, and message handling. It coordinates
5// interactions between language models, messages, sessions, and tools while
6// handling features like automatic summarization, queuing, and token
7// management.
8package agent
9
10import (
11 "cmp"
12 "context"
13 _ "embed"
14 "encoding/base64"
15 "errors"
16 "fmt"
17 "log/slog"
18 "os"
19 "strconv"
20 "strings"
21 "sync"
22 "time"
23
24 "charm.land/fantasy"
25 "charm.land/fantasy/providers/anthropic"
26 "charm.land/fantasy/providers/bedrock"
27 "charm.land/fantasy/providers/google"
28 "charm.land/fantasy/providers/openai"
29 "charm.land/fantasy/providers/openrouter"
30 "git.secluded.site/crush/internal/agent/tools"
31 "git.secluded.site/crush/internal/config"
32 "git.secluded.site/crush/internal/csync"
33 "git.secluded.site/crush/internal/message"
34 "git.secluded.site/crush/internal/notification"
35 "git.secluded.site/crush/internal/permission"
36 "git.secluded.site/crush/internal/session"
37 "git.secluded.site/crush/internal/stringext"
38 "github.com/charmbracelet/catwalk/pkg/catwalk"
39)
40
// titlePrompt holds the embedded contents of templates/title.md. It is used
// as the base of the system prompt when generating a session title.
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt holds the embedded contents of templates/summary.md. It is
// used as the system prompt when summarizing a session.
//
//go:embed templates/summary.md
var summaryPrompt []byte
46
// SessionAgentCall describes a single prompt submitted to a session agent.
type SessionAgentCall struct {
	// SessionID identifies the session the prompt belongs to. Run rejects an
	// empty value with ErrSessionMissing.
	SessionID string
	// Prompt is the user's text. Run rejects an empty value with
	// ErrEmptyPrompt.
	Prompt string
	// ProviderOptions is forwarded unchanged to the model call.
	ProviderOptions fantasy.ProviderOptions
	// Attachments are stored on the user message and sent to the model as
	// files.
	Attachments []message.Attachment
	// The following settings are forwarded to the model call as given.
	MaxOutputTokens  int64
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
	// NonInteractive suppresses the "turn completed" notification that Run
	// otherwise sends when the agent finishes.
	NonInteractive bool
}
60
// SessionAgent runs prompts against sessions and manages their in-flight and
// queued requests. The per-method notes describe the sessionAgent
// implementation in this file.
type SessionAgent interface {
	// Run processes the call for its session, or queues it and returns
	// (nil, nil) when that session already has an active request.
	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
	// SetModels replaces the large and small models.
	SetModels(large Model, small Model)
	// SetTools replaces the tool set offered to the model.
	SetTools(tools []fantasy.AgentTool)
	// Cancel cancels the session's active request and drops its queued
	// prompts.
	Cancel(sessionID string)
	// CancelAll cancels every active request and waits up to five seconds
	// for the agent to become idle.
	CancelAll()
	// IsSessionBusy reports whether the session has an active request.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any session has an active request.
	IsBusy() bool
	// QueuedPrompts returns the number of prompts queued for the session.
	QueuedPrompts(sessionID string) int
	// ClearQueue drops the prompts queued for the session.
	ClearQueue(sessionID string)
	// Summarize replaces the session's history with a generated summary
	// message; it returns ErrSessionBusy when the session is busy.
	Summarize(context.Context, string, fantasy.ProviderOptions) error
	// Model returns the large model.
	Model() Model
}
74
// Model bundles a language model with its catalog metadata and its selected
// configuration.
type Model struct {
	// Model is the language model handed to fantasy.NewAgent.
	Model fantasy.LanguageModel
	// CatwalkCfg carries the model metadata read in this file: context
	// window, pricing, default max tokens, and image/reasoning support.
	CatwalkCfg catwalk.Model
	// ModelCfg is the selected-model configuration; its Model and Provider
	// values are recorded on assistant messages and used for provider checks.
	ModelCfg config.SelectedModel
}
80
// sessionAgent is the SessionAgent implementation returned by
// NewSessionAgent.
type sessionAgent struct {
	// largeModel drives conversations and summaries; smallModel generates
	// session titles.
	largeModel Model
	smallModel Model
	// systemPromptPrefix, when non-empty, is prepended as an extra system
	// message to model calls.
	systemPromptPrefix string
	// systemPrompt is the system prompt of the conversation agent.
	systemPrompt string
	// tools is the tool set offered to the model in Run.
	tools []fantasy.AgentTool
	// sessions and messages are the persistence services for sessions and
	// their messages.
	sessions session.Service
	messages message.Service
	// disableAutoSummarize turns off the automatic stop-and-summarize check
	// that Run performs when the context window is nearly full.
	disableAutoSummarize bool
	// isYolo is copied from the options; it is not read in the code visible
	// here.
	isYolo bool

	// messageQueue holds, per session ID, the calls received while that
	// session was busy.
	messageQueue *csync.Map[string, []SessionAgentCall]
	// activeRequests maps a session ID to the cancel function of its
	// in-flight request; a present key marks the session as busy.
	activeRequests *csync.Map[string, context.CancelFunc]
}
95
// SessionAgentOptions carries the dependencies and settings that
// NewSessionAgent copies into a new agent.
type SessionAgentOptions struct {
	// LargeModel drives conversations and summaries; SmallModel generates
	// session titles.
	LargeModel Model
	SmallModel Model
	// SystemPromptPrefix, when non-empty, is prepended as an extra system
	// message to model calls.
	SystemPromptPrefix string
	// SystemPrompt is the system prompt of the conversation agent.
	SystemPrompt string
	// DisableAutoSummarize turns off automatic summarization when the
	// context window is nearly full.
	DisableAutoSummarize bool
	// IsYolo is stored on the agent; it is not read in the code visible
	// here.
	IsYolo bool
	// Sessions and Messages are the persistence services the agent uses.
	Sessions session.Service
	Messages message.Service
	// Tools is the tool set offered to the model.
	Tools []fantasy.AgentTool
}
107
108func NewSessionAgent(
109 opts SessionAgentOptions,
110) SessionAgent {
111 return &sessionAgent{
112 largeModel: opts.LargeModel,
113 smallModel: opts.SmallModel,
114 systemPromptPrefix: opts.SystemPromptPrefix,
115 systemPrompt: opts.SystemPrompt,
116 sessions: opts.Sessions,
117 messages: opts.Messages,
118 disableAutoSummarize: opts.DisableAutoSummarize,
119 tools: opts.Tools,
120 isYolo: opts.IsYolo,
121 messageQueue: csync.NewMap[string, []SessionAgentCall](),
122 activeRequests: csync.NewMap[string, context.CancelFunc](),
123 }
124}
125
126func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
127 if call.Prompt == "" {
128 return nil, ErrEmptyPrompt
129 }
130 if call.SessionID == "" {
131 return nil, ErrSessionMissing
132 }
133
134 // Queue the message if busy
135 if a.IsSessionBusy(call.SessionID) {
136 existing, ok := a.messageQueue.Get(call.SessionID)
137 if !ok {
138 existing = []SessionAgentCall{}
139 }
140 existing = append(existing, call)
141 a.messageQueue.Set(call.SessionID, existing)
142 return nil, nil
143 }
144
145 if len(a.tools) > 0 {
146 // Add Anthropic caching to the last tool.
147 a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
148 }
149
150 agent := fantasy.NewAgent(
151 a.largeModel.Model,
152 fantasy.WithSystemPrompt(a.systemPrompt),
153 fantasy.WithTools(a.tools...),
154 )
155
156 sessionLock := sync.Mutex{}
157 currentSession, err := a.sessions.Get(ctx, call.SessionID)
158 if err != nil {
159 return nil, fmt.Errorf("failed to get session: %w", err)
160 }
161
162 msgs, err := a.getSessionMessages(ctx, currentSession)
163 if err != nil {
164 return nil, fmt.Errorf("failed to get session messages: %w", err)
165 }
166
167 var wg sync.WaitGroup
168 // Generate title if first message.
169 if len(msgs) == 0 {
170 wg.Go(func() {
171 sessionLock.Lock()
172 a.generateTitle(ctx, ¤tSession, call.Prompt)
173 sessionLock.Unlock()
174 })
175 }
176
177 // Add the user message to the session.
178 _, err = a.createUserMessage(ctx, call)
179 if err != nil {
180 return nil, err
181 }
182
183 // Add the session to the context.
184 ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)
185
186 genCtx, cancel := context.WithCancel(ctx)
187 a.activeRequests.Set(call.SessionID, cancel)
188
189 defer cancel()
190 defer a.activeRequests.Del(call.SessionID)
191
192 history, files := a.preparePrompt(msgs, call.Attachments...)
193
194 startTime := time.Now()
195 a.eventPromptSent(call.SessionID)
196
197 var currentAssistant *message.Message
198 var shouldSummarize bool
199 result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
200 Prompt: call.Prompt,
201 Files: files,
202 Messages: history,
203 ProviderOptions: call.ProviderOptions,
204 MaxOutputTokens: &call.MaxOutputTokens,
205 TopP: call.TopP,
206 Temperature: call.Temperature,
207 PresencePenalty: call.PresencePenalty,
208 TopK: call.TopK,
209 FrequencyPenalty: call.FrequencyPenalty,
210 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
211 prepared.Messages = options.Messages
212 for i := range prepared.Messages {
213 prepared.Messages[i].ProviderOptions = nil
214 }
215
216 queuedCalls, _ := a.messageQueue.Get(call.SessionID)
217 a.messageQueue.Del(call.SessionID)
218 for _, queued := range queuedCalls {
219 userMessage, createErr := a.createUserMessage(callContext, queued)
220 if createErr != nil {
221 return callContext, prepared, createErr
222 }
223 prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
224 }
225
226 prepared.Messages = a.workaroundProviderMediaLimitations(prepared.Messages)
227
228 lastSystemRoleInx := 0
229 systemMessageUpdated := false
230 for i, msg := range prepared.Messages {
231 // Only add cache control to the last message.
232 if msg.Role == fantasy.MessageRoleSystem {
233 lastSystemRoleInx = i
234 } else if !systemMessageUpdated {
235 prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
236 systemMessageUpdated = true
237 }
238 // Than add cache control to the last 2 messages.
239 if i > len(prepared.Messages)-3 {
240 prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
241 }
242 }
243
244 if promptPrefix := a.promptPrefix(); promptPrefix != "" {
245 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(promptPrefix)}, prepared.Messages...)
246 }
247
248 var assistantMsg message.Message
249 assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
250 Role: message.Assistant,
251 Parts: []message.ContentPart{},
252 Model: a.largeModel.ModelCfg.Model,
253 Provider: a.largeModel.ModelCfg.Provider,
254 })
255 if err != nil {
256 return callContext, prepared, err
257 }
258 callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
259 callContext = context.WithValue(callContext, tools.SupportsImagesContextKey, a.largeModel.CatwalkCfg.SupportsImages)
260 callContext = context.WithValue(callContext, tools.ModelNameContextKey, a.largeModel.CatwalkCfg.Name)
261 currentAssistant = &assistantMsg
262 return callContext, prepared, err
263 },
264 OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
265 currentAssistant.AppendReasoningContent(reasoning.Text)
266 return a.messages.Update(genCtx, *currentAssistant)
267 },
268 OnReasoningDelta: func(id string, text string) error {
269 currentAssistant.AppendReasoningContent(text)
270 return a.messages.Update(genCtx, *currentAssistant)
271 },
272 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
273 // handle anthropic signature
274 if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
275 if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
276 currentAssistant.AppendReasoningSignature(reasoning.Signature)
277 }
278 }
279 if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
280 if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
281 currentAssistant.AppendThoughtSignature(reasoning.Signature, reasoning.ToolID)
282 }
283 }
284 if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
285 if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
286 currentAssistant.SetReasoningResponsesData(reasoning)
287 }
288 }
289 currentAssistant.FinishThinking()
290 return a.messages.Update(genCtx, *currentAssistant)
291 },
292 OnTextDelta: func(id string, text string) error {
293 // Strip leading newline from initial text content. This is is
294 // particularly important in non-interactive mode where leading
295 // newlines are very visible.
296 if len(currentAssistant.Parts) == 0 {
297 text = strings.TrimPrefix(text, "\n")
298 }
299
300 currentAssistant.AppendContent(text)
301 return a.messages.Update(genCtx, *currentAssistant)
302 },
303 OnToolInputStart: func(id string, toolName string) error {
304 toolCall := message.ToolCall{
305 ID: id,
306 Name: toolName,
307 ProviderExecuted: false,
308 Finished: false,
309 }
310 currentAssistant.AddToolCall(toolCall)
311 return a.messages.Update(genCtx, *currentAssistant)
312 },
313 OnRetry: func(err *fantasy.ProviderError, delay time.Duration) {
314 // TODO: implement
315 },
316 OnToolCall: func(tc fantasy.ToolCallContent) error {
317 toolCall := message.ToolCall{
318 ID: tc.ToolCallID,
319 Name: tc.ToolName,
320 Input: tc.Input,
321 ProviderExecuted: false,
322 Finished: true,
323 }
324 currentAssistant.AddToolCall(toolCall)
325 return a.messages.Update(genCtx, *currentAssistant)
326 },
327 OnToolResult: func(result fantasy.ToolResultContent) error {
328 toolResult := a.convertToToolResult(result)
329 _, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
330 Role: message.Tool,
331 Parts: []message.ContentPart{
332 toolResult,
333 },
334 })
335 return createMsgErr
336 },
337 OnStepFinish: func(stepResult fantasy.StepResult) error {
338 finishReason := message.FinishReasonUnknown
339 switch stepResult.FinishReason {
340 case fantasy.FinishReasonLength:
341 finishReason = message.FinishReasonMaxTokens
342 case fantasy.FinishReasonStop:
343 finishReason = message.FinishReasonEndTurn
344 case fantasy.FinishReasonToolCalls:
345 finishReason = message.FinishReasonToolUse
346 }
347 currentAssistant.AddFinish(finishReason, "", "")
348 a.updateSessionUsage(a.largeModel, ¤tSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
349 sessionLock.Lock()
350 _, sessionErr := a.sessions.Save(genCtx, currentSession)
351 sessionLock.Unlock()
352 if sessionErr != nil {
353 return sessionErr
354 }
355 return a.messages.Update(genCtx, *currentAssistant)
356 },
357 StopWhen: []fantasy.StopCondition{
358 func(_ []fantasy.StepResult) bool {
359 cw := int64(a.largeModel.CatwalkCfg.ContextWindow)
360 tokens := currentSession.CompletionTokens + currentSession.PromptTokens
361 remaining := cw - tokens
362 var threshold int64
363 if cw > 200_000 {
364 threshold = 20_000
365 } else {
366 threshold = int64(float64(cw) * 0.2)
367 }
368 if (remaining <= threshold) && !a.disableAutoSummarize {
369 shouldSummarize = true
370 return true
371 }
372 return false
373 },
374 },
375 })
376
377 a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))
378
379 if err != nil {
380 isCancelErr := errors.Is(err, context.Canceled)
381 isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
382 if currentAssistant == nil {
383 return result, err
384 }
385 // Ensure we finish thinking on error to close the reasoning state.
386 currentAssistant.FinishThinking()
387 toolCalls := currentAssistant.ToolCalls()
388 // INFO: we use the parent context here because the genCtx has been cancelled.
389 msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
390 if createErr != nil {
391 return nil, createErr
392 }
393 for _, tc := range toolCalls {
394 if !tc.Finished {
395 tc.Finished = true
396 tc.Input = "{}"
397 currentAssistant.AddToolCall(tc)
398 updateErr := a.messages.Update(ctx, *currentAssistant)
399 if updateErr != nil {
400 return nil, updateErr
401 }
402 }
403
404 found := false
405 for _, msg := range msgs {
406 if msg.Role == message.Tool {
407 for _, tr := range msg.ToolResults() {
408 if tr.ToolCallID == tc.ID {
409 found = true
410 break
411 }
412 }
413 }
414 if found {
415 break
416 }
417 }
418 if found {
419 continue
420 }
421 content := "There was an error while executing the tool"
422 if isCancelErr {
423 content = "Tool execution canceled by user"
424 } else if isPermissionErr {
425 content = "User denied permission"
426 }
427 toolResult := message.ToolResult{
428 ToolCallID: tc.ID,
429 Name: tc.Name,
430 Content: content,
431 IsError: true,
432 }
433 _, createErr = a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
434 Role: message.Tool,
435 Parts: []message.ContentPart{
436 toolResult,
437 },
438 })
439 if createErr != nil {
440 return nil, createErr
441 }
442 }
443 var fantasyErr *fantasy.Error
444 var providerErr *fantasy.ProviderError
445 const defaultTitle = "Provider Error"
446 if isCancelErr {
447 currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
448 } else if isPermissionErr {
449 currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "")
450 } else if errors.As(err, &providerErr) {
451 currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message)
452 } else if errors.As(err, &fantasyErr) {
453 currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(fantasyErr.Title), defaultTitle), fantasyErr.Message)
454 } else {
455 currentAssistant.AddFinish(message.FinishReasonError, defaultTitle, err.Error())
456 }
457 // Note: we use the parent context here because the genCtx has been
458 // cancelled.
459 updateErr := a.messages.Update(ctx, *currentAssistant)
460 if updateErr != nil {
461 return nil, updateErr
462 }
463 return nil, err
464 }
465 wg.Wait()
466
467 // Send notification that agent has finished its turn (skip for nested/non-interactive sessions).
468 if !call.NonInteractive {
469 notifBody := fmt.Sprintf("Agent's turn completed in \"%s\"", currentSession.Title)
470 _ = notification.Send("Crush is waiting...", notifBody)
471 }
472
473 if shouldSummarize {
474 a.activeRequests.Del(call.SessionID)
475 if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
476 return nil, summarizeErr
477 }
478 // If the agent wasn't done...
479 if len(currentAssistant.ToolCalls()) > 0 {
480 existing, ok := a.messageQueue.Get(call.SessionID)
481 if !ok {
482 existing = []SessionAgentCall{}
483 }
484 call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
485 existing = append(existing, call)
486 a.messageQueue.Set(call.SessionID, existing)
487 }
488 }
489
490 // Release active request before processing queued messages.
491 a.activeRequests.Del(call.SessionID)
492 cancel()
493
494 queuedMessages, ok := a.messageQueue.Get(call.SessionID)
495 if !ok || len(queuedMessages) == 0 {
496 return result, err
497 }
498 // There are queued messages restart the loop.
499 firstQueuedMessage := queuedMessages[0]
500 a.messageQueue.Set(call.SessionID, queuedMessages[1:])
501 return a.Run(ctx, firstQueuedMessage)
502}
503
504func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
505 if a.IsSessionBusy(sessionID) {
506 return ErrSessionBusy
507 }
508
509 currentSession, err := a.sessions.Get(ctx, sessionID)
510 if err != nil {
511 return fmt.Errorf("failed to get session: %w", err)
512 }
513 msgs, err := a.getSessionMessages(ctx, currentSession)
514 if err != nil {
515 return err
516 }
517 if len(msgs) == 0 {
518 // Nothing to summarize.
519 return nil
520 }
521
522 aiMsgs, _ := a.preparePrompt(msgs)
523
524 genCtx, cancel := context.WithCancel(ctx)
525 a.activeRequests.Set(sessionID, cancel)
526 defer a.activeRequests.Del(sessionID)
527 defer cancel()
528
529 agent := fantasy.NewAgent(a.largeModel.Model,
530 fantasy.WithSystemPrompt(string(summaryPrompt)),
531 )
532 summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
533 Role: message.Assistant,
534 Model: a.largeModel.Model.Model(),
535 Provider: a.largeModel.Model.Provider(),
536 IsSummaryMessage: true,
537 })
538 if err != nil {
539 return err
540 }
541
542 resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
543 Prompt: "Provide a detailed summary of our conversation above.",
544 Messages: aiMsgs,
545 ProviderOptions: opts,
546 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
547 prepared.Messages = options.Messages
548 if a.systemPromptPrefix != "" {
549 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
550 }
551 return callContext, prepared, nil
552 },
553 OnReasoningDelta: func(id string, text string) error {
554 summaryMessage.AppendReasoningContent(text)
555 return a.messages.Update(genCtx, summaryMessage)
556 },
557 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
558 // Handle anthropic signature.
559 if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
560 if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
561 summaryMessage.AppendReasoningSignature(signature.Signature)
562 }
563 }
564 summaryMessage.FinishThinking()
565 return a.messages.Update(genCtx, summaryMessage)
566 },
567 OnTextDelta: func(id, text string) error {
568 summaryMessage.AppendContent(text)
569 return a.messages.Update(genCtx, summaryMessage)
570 },
571 })
572 if err != nil {
573 isCancelErr := errors.Is(err, context.Canceled)
574 if isCancelErr {
575 // User cancelled summarize we need to remove the summary message.
576 deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
577 return deleteErr
578 }
579 return err
580 }
581
582 summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
583 err = a.messages.Update(genCtx, summaryMessage)
584 if err != nil {
585 return err
586 }
587
588 var openrouterCost *float64
589 for _, step := range resp.Steps {
590 stepCost := a.openrouterCost(step.ProviderMetadata)
591 if stepCost != nil {
592 newCost := *stepCost
593 if openrouterCost != nil {
594 newCost += *openrouterCost
595 }
596 openrouterCost = &newCost
597 }
598 }
599
600 a.updateSessionUsage(a.largeModel, ¤tSession, resp.TotalUsage, openrouterCost)
601
602 // Just in case, get just the last usage info.
603 usage := resp.Response.Usage
604 currentSession.SummaryMessageID = summaryMessage.ID
605 currentSession.CompletionTokens = usage.OutputTokens
606 currentSession.PromptTokens = 0
607 _, err = a.sessions.Save(genCtx, currentSession)
608 return err
609}
610
611func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
612 if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
613 return fantasy.ProviderOptions{}
614 }
615 return fantasy.ProviderOptions{
616 anthropic.Name: &anthropic.ProviderCacheControlOptions{
617 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
618 },
619 bedrock.Name: &anthropic.ProviderCacheControlOptions{
620 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
621 },
622 }
623}
624
625func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
626 var attachmentParts []message.ContentPart
627 for _, attachment := range call.Attachments {
628 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
629 }
630 parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
631 parts = append(parts, attachmentParts...)
632 msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
633 Role: message.User,
634 Parts: parts,
635 })
636 if err != nil {
637 return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
638 }
639 return msg, nil
640}
641
642func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
643 var history []fantasy.Message
644 for _, m := range msgs {
645 if len(m.Parts) == 0 {
646 continue
647 }
648 // Assistant message without content or tool calls (cancelled before it
649 // returned anything).
650 if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
651 continue
652 }
653 history = append(history, m.ToAIMessage()...)
654 }
655
656 var files []fantasy.FilePart
657 for _, attachment := range attachments {
658 files = append(files, fantasy.FilePart{
659 Filename: attachment.FileName,
660 Data: attachment.Content,
661 MediaType: attachment.MimeType,
662 })
663 }
664
665 return history, files
666}
667
668func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
669 msgs, err := a.messages.List(ctx, session.ID)
670 if err != nil {
671 return nil, fmt.Errorf("failed to list messages: %w", err)
672 }
673
674 if session.SummaryMessageID != "" {
675 summaryMsgInex := -1
676 for i, msg := range msgs {
677 if msg.ID == session.SummaryMessageID {
678 summaryMsgInex = i
679 break
680 }
681 }
682 if summaryMsgInex != -1 {
683 msgs = msgs[summaryMsgInex:]
684 msgs[0].Role = message.User
685 }
686 }
687 return msgs, nil
688}
689
690func (a *sessionAgent) generateTitle(ctx context.Context, session *session.Session, prompt string) {
691 if prompt == "" {
692 return
693 }
694
695 var maxOutput int64 = 40
696 if a.smallModel.CatwalkCfg.CanReason {
697 maxOutput = a.smallModel.CatwalkCfg.DefaultMaxTokens
698 }
699
700 agent := fantasy.NewAgent(a.smallModel.Model,
701 fantasy.WithSystemPrompt(string(titlePrompt)+"\n /no_think"),
702 fantasy.WithMaxOutputTokens(maxOutput),
703 )
704
705 resp, err := agent.Stream(ctx, fantasy.AgentStreamCall{
706 Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", prompt),
707 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
708 prepared.Messages = options.Messages
709 if a.systemPromptPrefix != "" {
710 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
711 }
712 return callContext, prepared, nil
713 },
714 })
715 if err != nil {
716 slog.Error("error generating title", "err", err)
717 return
718 }
719
720 title := resp.Response.Content.Text()
721
722 title = strings.ReplaceAll(title, "\n", " ")
723
724 // Remove thinking tags if present.
725 if idx := strings.Index(title, "</think>"); idx > 0 {
726 title = title[idx+len("</think>"):]
727 }
728
729 title = strings.TrimSpace(title)
730 if title == "" {
731 slog.Warn("failed to generate title", "warn", "empty title")
732 return
733 }
734
735 session.Title = title
736
737 var openrouterCost *float64
738 for _, step := range resp.Steps {
739 stepCost := a.openrouterCost(step.ProviderMetadata)
740 if stepCost != nil {
741 newCost := *stepCost
742 if openrouterCost != nil {
743 newCost += *openrouterCost
744 }
745 openrouterCost = &newCost
746 }
747 }
748
749 a.updateSessionUsage(a.smallModel, session, resp.TotalUsage, openrouterCost)
750 _, saveErr := a.sessions.Save(ctx, *session)
751 if saveErr != nil {
752 slog.Error("failed to save session title & usage", "error", saveErr)
753 return
754 }
755}
756
757func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
758 openrouterMetadata, ok := metadata[openrouter.Name]
759 if !ok {
760 return nil
761 }
762
763 opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
764 if !ok {
765 return nil
766 }
767 return &opts.Usage.Cost
768}
769
770func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
771 modelConfig := model.CatwalkCfg
772 cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
773 modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
774 modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
775 modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
776
777 if a.isClaudeCode() {
778 cost = 0
779 }
780
781 a.eventTokensUsed(session.ID, model, usage, cost)
782
783 if overrideCost != nil {
784 session.Cost += *overrideCost
785 } else {
786 session.Cost += cost
787 }
788
789 session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
790 session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
791}
792
793func (a *sessionAgent) Cancel(sessionID string) {
794 // Cancel regular requests.
795 if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
796 slog.Info("Request cancellation initiated", "session_id", sessionID)
797 cancel()
798 }
799
800 // Also check for summarize requests.
801 if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
802 slog.Info("Summarize cancellation initiated", "session_id", sessionID)
803 cancel()
804 }
805
806 if a.QueuedPrompts(sessionID) > 0 {
807 slog.Info("Clearing queued prompts", "session_id", sessionID)
808 a.messageQueue.Del(sessionID)
809 }
810}
811
812func (a *sessionAgent) ClearQueue(sessionID string) {
813 if a.QueuedPrompts(sessionID) > 0 {
814 slog.Info("Clearing queued prompts", "session_id", sessionID)
815 a.messageQueue.Del(sessionID)
816 }
817}
818
819func (a *sessionAgent) CancelAll() {
820 if !a.IsBusy() {
821 return
822 }
823 for key := range a.activeRequests.Seq2() {
824 a.Cancel(key) // key is sessionID
825 }
826
827 timeout := time.After(5 * time.Second)
828 for a.IsBusy() {
829 select {
830 case <-timeout:
831 return
832 default:
833 time.Sleep(200 * time.Millisecond)
834 }
835 }
836}
837
838func (a *sessionAgent) IsBusy() bool {
839 var busy bool
840 for cancelFunc := range a.activeRequests.Seq() {
841 if cancelFunc != nil {
842 busy = true
843 break
844 }
845 }
846 return busy
847}
848
849func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
850 _, busy := a.activeRequests.Get(sessionID)
851 return busy
852}
853
854func (a *sessionAgent) QueuedPrompts(sessionID string) int {
855 l, ok := a.messageQueue.Get(sessionID)
856 if !ok {
857 return 0
858 }
859 return len(l)
860}
861
862func (a *sessionAgent) SetModels(large Model, small Model) {
863 a.largeModel = large
864 a.smallModel = small
865}
866
// SetTools replaces the agent's tool set. The slice is stored as given, not
// copied.
func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
	a.tools = tools
}
870
// Model returns the agent's large model.
func (a *sessionAgent) Model() Model {
	return a.largeModel
}
874
875func (a *sessionAgent) promptPrefix() string {
876 if a.isClaudeCode() {
877 return "You are Claude Code, Anthropic's official CLI for Claude."
878 }
879 return a.systemPromptPrefix
880}
881
882func (a *sessionAgent) isClaudeCode() bool {
883 cfg := config.Get()
884 pc, ok := cfg.Providers.Get(a.largeModel.ModelCfg.Provider)
885 return ok && pc.ID == string(catwalk.InferenceProviderAnthropic) && pc.OAuthToken != nil
886}
887
888// convertToToolResult converts a fantasy tool result to a message tool result.
889func (a *sessionAgent) convertToToolResult(result fantasy.ToolResultContent) message.ToolResult {
890 baseResult := message.ToolResult{
891 ToolCallID: result.ToolCallID,
892 Name: result.ToolName,
893 Metadata: result.ClientMetadata,
894 }
895
896 switch result.Result.GetType() {
897 case fantasy.ToolResultContentTypeText:
898 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result); ok {
899 baseResult.Content = r.Text
900 }
901 case fantasy.ToolResultContentTypeError:
902 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result); ok {
903 baseResult.Content = r.Error.Error()
904 baseResult.IsError = true
905 }
906 case fantasy.ToolResultContentTypeMedia:
907 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](result.Result); ok {
908 content := r.Text
909 if content == "" {
910 content = fmt.Sprintf("Loaded %s content", r.MediaType)
911 }
912 baseResult.Content = content
913 baseResult.Data = r.Data
914 baseResult.MIMEType = r.MediaType
915 }
916 }
917
918 return baseResult
919}
920
921// workaroundProviderMediaLimitations converts media content in tool results to
922// user messages for providers that don't natively support images in tool results.
923//
924// Problem: OpenAI, Google, OpenRouter, and other OpenAI-compatible providers
925// don't support sending images/media in tool result messages - they only accept
926// text in tool results. However, they DO support images in user messages.
927//
928// If we send media in tool results to these providers, the API returns an error.
929//
930// Solution: For these providers, we:
931// 1. Replace the media in the tool result with a text placeholder
932// 2. Inject a user message immediately after with the image as a file attachment
933// 3. This maintains the tool execution flow while working around API limitations
934//
935// Anthropic and Bedrock support images natively in tool results, so we skip
936// this workaround for them.
937//
938// Example transformation:
939//
940// BEFORE: [tool result: image data]
941// AFTER: [tool result: "Image loaded - see attached"], [user: image attachment]
942func (a *sessionAgent) workaroundProviderMediaLimitations(messages []fantasy.Message) []fantasy.Message {
943 providerSupportsMedia := a.largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderAnthropic) ||
944 a.largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderBedrock)
945
946 if providerSupportsMedia {
947 return messages
948 }
949
950 convertedMessages := make([]fantasy.Message, 0, len(messages))
951
952 for _, msg := range messages {
953 if msg.Role != fantasy.MessageRoleTool {
954 convertedMessages = append(convertedMessages, msg)
955 continue
956 }
957
958 textParts := make([]fantasy.MessagePart, 0, len(msg.Content))
959 var mediaFiles []fantasy.FilePart
960
961 for _, part := range msg.Content {
962 toolResult, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
963 if !ok {
964 textParts = append(textParts, part)
965 continue
966 }
967
968 if media, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](toolResult.Output); ok {
969 decoded, err := base64.StdEncoding.DecodeString(media.Data)
970 if err != nil {
971 slog.Warn("failed to decode media data", "error", err)
972 textParts = append(textParts, part)
973 continue
974 }
975
976 mediaFiles = append(mediaFiles, fantasy.FilePart{
977 Data: decoded,
978 MediaType: media.MediaType,
979 Filename: fmt.Sprintf("tool-result-%s", toolResult.ToolCallID),
980 })
981
982 textParts = append(textParts, fantasy.ToolResultPart{
983 ToolCallID: toolResult.ToolCallID,
984 Output: fantasy.ToolResultOutputContentText{
985 Text: "[Image/media content loaded - see attached file]",
986 },
987 ProviderOptions: toolResult.ProviderOptions,
988 })
989 } else {
990 textParts = append(textParts, part)
991 }
992 }
993
994 convertedMessages = append(convertedMessages, fantasy.Message{
995 Role: fantasy.MessageRoleTool,
996 Content: textParts,
997 })
998
999 if len(mediaFiles) > 0 {
1000 convertedMessages = append(convertedMessages, fantasy.NewUserMessage(
1001 "Here is the media content from the tool result:",
1002 mediaFiles...,
1003 ))
1004 }
1005 }
1006
1007 return convertedMessages
1008}