1// Package agent is the core orchestration layer for Crush AI agents.
2//
3// It provides session-based AI agent functionality for managing
4// conversations, tool execution, and message handling. It coordinates
5// interactions between language models, messages, sessions, and tools while
6// handling features like automatic summarization, queuing, and token
7// management.
8package agent
9
10import (
11 "cmp"
12 "context"
13 _ "embed"
14 "encoding/base64"
15 "errors"
16 "fmt"
17 "log/slog"
18 "os"
19 "strconv"
20 "strings"
21 "sync"
22 "time"
23
24 "charm.land/fantasy"
25 "charm.land/fantasy/providers/anthropic"
26 "charm.land/fantasy/providers/bedrock"
27 "charm.land/fantasy/providers/google"
28 "charm.land/fantasy/providers/openai"
29 "charm.land/fantasy/providers/openrouter"
30 "charm.land/lipgloss/v2"
31 "git.secluded.site/crush/internal/agent/hyper"
32 "git.secluded.site/crush/internal/agent/tools"
33 "git.secluded.site/crush/internal/config"
34 "git.secluded.site/crush/internal/csync"
35 "git.secluded.site/crush/internal/message"
36 "git.secluded.site/crush/internal/notification"
37 "git.secluded.site/crush/internal/permission"
38 "git.secluded.site/crush/internal/session"
39 "git.secluded.site/crush/internal/stringext"
40 "github.com/charmbracelet/catwalk/pkg/catwalk"
41)
42
// titlePrompt is the system prompt used by generateTitle.
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt is the system prompt used by Summarize.
//
//go:embed templates/summary.md
var summaryPrompt []byte
48
// SessionAgentCall describes one prompt submitted to a SessionAgent.
//
// ProviderOptions and the sampling fields (MaxOutputTokens, Temperature,
// TopP, TopK, FrequencyPenalty, PresencePenalty) are forwarded to the model
// stream call made by Run.
type SessionAgentCall struct {
	// SessionID identifies the session the prompt belongs to; Run rejects an
	// empty value.
	SessionID string
	// Prompt is the user's text; Run rejects an empty value.
	Prompt          string
	ProviderOptions fantasy.ProviderOptions
	// Attachments are stored on the user message. Text attachments are
	// folded into the prompt text; the others are sent as file parts.
	Attachments      []message.Attachment
	MaxOutputTokens  int64
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
	// NonInteractive suppresses the notification Run sends when the agent's
	// turn completes.
	NonInteractive bool
}
62
// SessionAgent runs prompts against per-session conversations and manages
// the active request and the queue of pending prompts for each session.
type SessionAgent interface {
	// Run executes a prompt, or queues it when the session is busy.
	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
	// SetModels replaces the large (conversation) and small (title) models.
	SetModels(large Model, small Model)
	// SetTools replaces the tools offered to the large model.
	SetTools(tools []fantasy.AgentTool)
	// Cancel aborts the session's active request and drops its queue.
	Cancel(sessionID string)
	// CancelAll cancels every active request and waits briefly for them to
	// finish.
	CancelAll()
	// IsSessionBusy reports whether sessionID has an active request.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any session has an active request.
	IsBusy() bool
	// QueuedPrompts returns how many prompts are queued for sessionID.
	QueuedPrompts(sessionID string) int
	// QueuedPromptsList returns the text of the prompts queued for sessionID.
	QueuedPromptsList(sessionID string) []string
	// ClearQueue drops all prompts queued for sessionID.
	ClearQueue(sessionID string)
	// Summarize condenses the session's history into a summary message.
	Summarize(context.Context, string, fantasy.ProviderOptions) error
	// Model returns the large model.
	Model() Model
}
77
// Model pairs a fantasy language model with the catalog metadata and user
// selection the agent needs about it.
type Model struct {
	// Model is the client used to issue requests.
	Model fantasy.LanguageModel
	// CatwalkCfg is the model's catalog entry; this package reads fields
	// such as ContextWindow, the CostPer1M* prices, SupportsImages,
	// CanReason, DefaultMaxTokens and Name from it.
	CatwalkCfg catwalk.Model
	// ModelCfg is the user's model selection; its Model and Provider are
	// recorded on assistant messages, and Provider is used to look up the
	// provider configuration.
	ModelCfg config.SelectedModel
}
83
// sessionAgent is the SessionAgent implementation built by NewSessionAgent.
type sessionAgent struct {
	// largeModel drives Run and Summarize; smallModel generates titles.
	largeModel Model
	smallModel Model
	// systemPromptPrefix, when non-empty, is prepended as an extra system
	// message (Run uses promptPrefix, which may override it).
	systemPromptPrefix string
	systemPrompt       string
	// isSubAgent suppresses the todo-list reminder added by preparePrompt.
	isSubAgent bool
	tools      []fantasy.AgentTool
	sessions   session.Service
	messages   message.Service
	// disableAutoSummarize stops Run from summarizing when the context
	// window runs low.
	disableAutoSummarize bool
	// NOTE(review): isYolo is not read anywhere in the visible part of this
	// file; confirm where it is used.
	isYolo bool

	// messageQueue holds calls submitted while a session was busy, keyed by
	// session ID.
	messageQueue *csync.Map[string, []SessionAgentCall]
	// activeRequests maps a session ID to the cancel function of its
	// in-flight request; an entry's presence marks the session busy.
	activeRequests *csync.Map[string, context.CancelFunc]
}
99
// SessionAgentOptions configures NewSessionAgent. Each field initializes the
// sessionAgent field of the same meaning.
type SessionAgentOptions struct {
	LargeModel         Model
	SmallModel         Model
	SystemPromptPrefix string
	SystemPrompt       string
	// IsSubAgent marks nested agents; they do not get the todo reminder.
	IsSubAgent bool
	// DisableAutoSummarize turns off summarization when the context window
	// runs low.
	DisableAutoSummarize bool
	IsYolo               bool
	Sessions             session.Service
	Messages             message.Service
	Tools                []fantasy.AgentTool
}
112
113func NewSessionAgent(
114 opts SessionAgentOptions,
115) SessionAgent {
116 return &sessionAgent{
117 largeModel: opts.LargeModel,
118 smallModel: opts.SmallModel,
119 systemPromptPrefix: opts.SystemPromptPrefix,
120 systemPrompt: opts.SystemPrompt,
121 isSubAgent: opts.IsSubAgent,
122 sessions: opts.Sessions,
123 messages: opts.Messages,
124 disableAutoSummarize: opts.DisableAutoSummarize,
125 tools: opts.Tools,
126 isYolo: opts.IsYolo,
127 messageQueue: csync.NewMap[string, []SessionAgentCall](),
128 activeRequests: csync.NewMap[string, context.CancelFunc](),
129 }
130}
131
132func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
133 if call.Prompt == "" {
134 return nil, ErrEmptyPrompt
135 }
136 if call.SessionID == "" {
137 return nil, ErrSessionMissing
138 }
139
140 // Queue the message if busy
141 if a.IsSessionBusy(call.SessionID) {
142 existing, ok := a.messageQueue.Get(call.SessionID)
143 if !ok {
144 existing = []SessionAgentCall{}
145 }
146 existing = append(existing, call)
147 a.messageQueue.Set(call.SessionID, existing)
148 return nil, nil
149 }
150
151 if len(a.tools) > 0 {
152 // Add Anthropic caching to the last tool.
153 a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
154 }
155
156 agent := fantasy.NewAgent(
157 a.largeModel.Model,
158 fantasy.WithSystemPrompt(a.systemPrompt),
159 fantasy.WithTools(a.tools...),
160 )
161
162 sessionLock := sync.Mutex{}
163 currentSession, err := a.sessions.Get(ctx, call.SessionID)
164 if err != nil {
165 return nil, fmt.Errorf("failed to get session: %w", err)
166 }
167
168 msgs, err := a.getSessionMessages(ctx, currentSession)
169 if err != nil {
170 return nil, fmt.Errorf("failed to get session messages: %w", err)
171 }
172
173 var wg sync.WaitGroup
174 // Generate title if first message.
175 if len(msgs) == 0 {
176 titleCtx := ctx // Copy to avoid race with ctx reassignment below.
177 wg.Go(func() {
178 a.generateTitle(titleCtx, call.SessionID, call.Prompt)
179 })
180 }
181
182 // Add the user message to the session.
183 _, err = a.createUserMessage(ctx, call)
184 if err != nil {
185 return nil, err
186 }
187
188 // Add the session to the context.
189 ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)
190
191 genCtx, cancel := context.WithCancel(ctx)
192 a.activeRequests.Set(call.SessionID, cancel)
193
194 defer cancel()
195 defer a.activeRequests.Del(call.SessionID)
196
197 history, files := a.preparePrompt(msgs, call.Attachments...)
198
199 startTime := time.Now()
200 a.eventPromptSent(call.SessionID)
201
202 var currentAssistant *message.Message
203 var shouldSummarize bool
204 result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
205 Prompt: message.PromptWithTextAttachments(call.Prompt, call.Attachments),
206 Files: files,
207 Messages: history,
208 ProviderOptions: call.ProviderOptions,
209 MaxOutputTokens: &call.MaxOutputTokens,
210 TopP: call.TopP,
211 Temperature: call.Temperature,
212 PresencePenalty: call.PresencePenalty,
213 TopK: call.TopK,
214 FrequencyPenalty: call.FrequencyPenalty,
215 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
216 prepared.Messages = options.Messages
217 for i := range prepared.Messages {
218 prepared.Messages[i].ProviderOptions = nil
219 }
220
221 queuedCalls, _ := a.messageQueue.Get(call.SessionID)
222 a.messageQueue.Del(call.SessionID)
223 for _, queued := range queuedCalls {
224 userMessage, createErr := a.createUserMessage(callContext, queued)
225 if createErr != nil {
226 return callContext, prepared, createErr
227 }
228 prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
229 }
230
231 prepared.Messages = a.workaroundProviderMediaLimitations(prepared.Messages)
232
233 lastSystemRoleInx := 0
234 systemMessageUpdated := false
235 for i, msg := range prepared.Messages {
236 // Only add cache control to the last message.
237 if msg.Role == fantasy.MessageRoleSystem {
238 lastSystemRoleInx = i
239 } else if !systemMessageUpdated {
240 prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
241 systemMessageUpdated = true
242 }
243 // Than add cache control to the last 2 messages.
244 if i > len(prepared.Messages)-3 {
245 prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
246 }
247 }
248
249 if promptPrefix := a.promptPrefix(); promptPrefix != "" {
250 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(promptPrefix)}, prepared.Messages...)
251 }
252
253 var assistantMsg message.Message
254 assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
255 Role: message.Assistant,
256 Parts: []message.ContentPart{},
257 Model: a.largeModel.ModelCfg.Model,
258 Provider: a.largeModel.ModelCfg.Provider,
259 })
260 if err != nil {
261 return callContext, prepared, err
262 }
263 callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
264 callContext = context.WithValue(callContext, tools.SupportsImagesContextKey, a.largeModel.CatwalkCfg.SupportsImages)
265 callContext = context.WithValue(callContext, tools.ModelNameContextKey, a.largeModel.CatwalkCfg.Name)
266 currentAssistant = &assistantMsg
267 return callContext, prepared, err
268 },
269 OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
270 currentAssistant.AppendReasoningContent(reasoning.Text)
271 return a.messages.Update(genCtx, *currentAssistant)
272 },
273 OnReasoningDelta: func(id string, text string) error {
274 currentAssistant.AppendReasoningContent(text)
275 return a.messages.Update(genCtx, *currentAssistant)
276 },
277 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
278 // handle anthropic signature
279 if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
280 if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
281 currentAssistant.AppendReasoningSignature(reasoning.Signature)
282 }
283 }
284 if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
285 if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
286 currentAssistant.AppendThoughtSignature(reasoning.Signature, reasoning.ToolID)
287 }
288 }
289 if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
290 if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
291 currentAssistant.SetReasoningResponsesData(reasoning)
292 }
293 }
294 currentAssistant.FinishThinking()
295 return a.messages.Update(genCtx, *currentAssistant)
296 },
297 OnTextDelta: func(id string, text string) error {
298 // Strip leading newline from initial text content. This is is
299 // particularly important in non-interactive mode where leading
300 // newlines are very visible.
301 if len(currentAssistant.Parts) == 0 {
302 text = strings.TrimPrefix(text, "\n")
303 }
304
305 currentAssistant.AppendContent(text)
306 return a.messages.Update(genCtx, *currentAssistant)
307 },
308 OnToolInputStart: func(id string, toolName string) error {
309 toolCall := message.ToolCall{
310 ID: id,
311 Name: toolName,
312 ProviderExecuted: false,
313 Finished: false,
314 }
315 currentAssistant.AddToolCall(toolCall)
316 return a.messages.Update(genCtx, *currentAssistant)
317 },
318 OnRetry: func(err *fantasy.ProviderError, delay time.Duration) {
319 // TODO: implement
320 },
321 OnToolCall: func(tc fantasy.ToolCallContent) error {
322 toolCall := message.ToolCall{
323 ID: tc.ToolCallID,
324 Name: tc.ToolName,
325 Input: tc.Input,
326 ProviderExecuted: false,
327 Finished: true,
328 }
329 currentAssistant.AddToolCall(toolCall)
330 return a.messages.Update(genCtx, *currentAssistant)
331 },
332 OnToolResult: func(result fantasy.ToolResultContent) error {
333 toolResult := a.convertToToolResult(result)
334 _, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
335 Role: message.Tool,
336 Parts: []message.ContentPart{
337 toolResult,
338 },
339 })
340 return createMsgErr
341 },
342 OnStepFinish: func(stepResult fantasy.StepResult) error {
343 finishReason := message.FinishReasonUnknown
344 switch stepResult.FinishReason {
345 case fantasy.FinishReasonLength:
346 finishReason = message.FinishReasonMaxTokens
347 case fantasy.FinishReasonStop:
348 finishReason = message.FinishReasonEndTurn
349 case fantasy.FinishReasonToolCalls:
350 finishReason = message.FinishReasonToolUse
351 }
352 currentAssistant.AddFinish(finishReason, "", "")
353 sessionLock.Lock()
354 updatedSession, getSessionErr := a.sessions.Get(genCtx, call.SessionID)
355 if getSessionErr != nil {
356 sessionLock.Unlock()
357 return getSessionErr
358 }
359 a.updateSessionUsage(a.largeModel, &updatedSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
360 _, sessionErr := a.sessions.Save(genCtx, updatedSession)
361 sessionLock.Unlock()
362 if sessionErr != nil {
363 return sessionErr
364 }
365 return a.messages.Update(genCtx, *currentAssistant)
366 },
367 StopWhen: []fantasy.StopCondition{
368 func(_ []fantasy.StepResult) bool {
369 cw := int64(a.largeModel.CatwalkCfg.ContextWindow)
370 tokens := currentSession.CompletionTokens + currentSession.PromptTokens
371 remaining := cw - tokens
372 var threshold int64
373 if cw > 200_000 {
374 threshold = 20_000
375 } else {
376 threshold = int64(float64(cw) * 0.2)
377 }
378 if (remaining <= threshold) && !a.disableAutoSummarize {
379 shouldSummarize = true
380 return true
381 }
382 return false
383 },
384 },
385 })
386
387 a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))
388
389 if err != nil {
390 isCancelErr := errors.Is(err, context.Canceled)
391 isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
392 if currentAssistant == nil {
393 return result, err
394 }
395 // Ensure we finish thinking on error to close the reasoning state.
396 currentAssistant.FinishThinking()
397 toolCalls := currentAssistant.ToolCalls()
398 // INFO: we use the parent context here because the genCtx has been cancelled.
399 msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
400 if createErr != nil {
401 return nil, createErr
402 }
403 for _, tc := range toolCalls {
404 if !tc.Finished {
405 tc.Finished = true
406 tc.Input = "{}"
407 currentAssistant.AddToolCall(tc)
408 updateErr := a.messages.Update(ctx, *currentAssistant)
409 if updateErr != nil {
410 return nil, updateErr
411 }
412 }
413
414 found := false
415 for _, msg := range msgs {
416 if msg.Role == message.Tool {
417 for _, tr := range msg.ToolResults() {
418 if tr.ToolCallID == tc.ID {
419 found = true
420 break
421 }
422 }
423 }
424 if found {
425 break
426 }
427 }
428 if found {
429 continue
430 }
431 content := "There was an error while executing the tool"
432 if isCancelErr {
433 content = "Tool execution canceled by user"
434 } else if isPermissionErr {
435 content = "User denied permission"
436 }
437 toolResult := message.ToolResult{
438 ToolCallID: tc.ID,
439 Name: tc.Name,
440 Content: content,
441 IsError: true,
442 }
443 _, createErr = a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
444 Role: message.Tool,
445 Parts: []message.ContentPart{
446 toolResult,
447 },
448 })
449 if createErr != nil {
450 return nil, createErr
451 }
452 }
453 var fantasyErr *fantasy.Error
454 var providerErr *fantasy.ProviderError
455 const defaultTitle = "Provider Error"
456 if isCancelErr {
457 currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
458 } else if isPermissionErr {
459 currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "")
460 } else if errors.Is(err, hyper.ErrNoCredits) {
461 url := hyper.BaseURL()
462 link := lipgloss.NewStyle().Hyperlink(url, "id=hyper").Render(url)
463 currentAssistant.AddFinish(message.FinishReasonError, "No credits", "You're out of credits. Add more at "+link)
464 } else if errors.As(err, &providerErr) {
465 if providerErr.Message == "The requested model is not supported." {
466 url := "https://github.com/settings/copilot/features"
467 link := lipgloss.NewStyle().Hyperlink(url, "id=hyper").Render(url)
468 currentAssistant.AddFinish(
469 message.FinishReasonError,
470 "Copilot model not enabled",
471 fmt.Sprintf("%q is not enabled in Copilot. Go to the following page to enable it. Then, wait a minute before trying again. %s", a.largeModel.CatwalkCfg.Name, link),
472 )
473 } else {
474 currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message)
475 }
476 } else if errors.As(err, &fantasyErr) {
477 currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(fantasyErr.Title), defaultTitle), fantasyErr.Message)
478 } else {
479 currentAssistant.AddFinish(message.FinishReasonError, defaultTitle, err.Error())
480 }
481 // Note: we use the parent context here because the genCtx has been
482 // cancelled.
483 updateErr := a.messages.Update(ctx, *currentAssistant)
484 if updateErr != nil {
485 return nil, updateErr
486 }
487 return nil, err
488 }
489 wg.Wait()
490
491 // Send notification that agent has finished its turn (skip for nested/non-interactive sessions).
492 if !call.NonInteractive {
493 notifBody := fmt.Sprintf("Agent's turn completed in \"%s\"", currentSession.Title)
494 _ = notification.Send("Crush is waiting...", notifBody)
495 }
496
497 if shouldSummarize {
498 a.activeRequests.Del(call.SessionID)
499 if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
500 return nil, summarizeErr
501 }
502 // If the agent wasn't done...
503 if len(currentAssistant.ToolCalls()) > 0 {
504 existing, ok := a.messageQueue.Get(call.SessionID)
505 if !ok {
506 existing = []SessionAgentCall{}
507 }
508 call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
509 existing = append(existing, call)
510 a.messageQueue.Set(call.SessionID, existing)
511 }
512 }
513
514 // Release active request before processing queued messages.
515 a.activeRequests.Del(call.SessionID)
516 cancel()
517
518 queuedMessages, ok := a.messageQueue.Get(call.SessionID)
519 if !ok || len(queuedMessages) == 0 {
520 return result, err
521 }
522 // There are queued messages restart the loop.
523 firstQueuedMessage := queuedMessages[0]
524 a.messageQueue.Set(call.SessionID, queuedMessages[1:])
525 return a.Run(ctx, firstQueuedMessage)
526}
527
528func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
529 if a.IsSessionBusy(sessionID) {
530 return ErrSessionBusy
531 }
532
533 currentSession, err := a.sessions.Get(ctx, sessionID)
534 if err != nil {
535 return fmt.Errorf("failed to get session: %w", err)
536 }
537 msgs, err := a.getSessionMessages(ctx, currentSession)
538 if err != nil {
539 return err
540 }
541 if len(msgs) == 0 {
542 // Nothing to summarize.
543 return nil
544 }
545
546 aiMsgs, _ := a.preparePrompt(msgs)
547
548 genCtx, cancel := context.WithCancel(ctx)
549 a.activeRequests.Set(sessionID, cancel)
550 defer a.activeRequests.Del(sessionID)
551 defer cancel()
552
553 agent := fantasy.NewAgent(a.largeModel.Model,
554 fantasy.WithSystemPrompt(string(summaryPrompt)),
555 )
556 summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
557 Role: message.Assistant,
558 Model: a.largeModel.Model.Model(),
559 Provider: a.largeModel.Model.Provider(),
560 IsSummaryMessage: true,
561 })
562 if err != nil {
563 return err
564 }
565
566 summaryPromptText := "Provide a detailed summary of our conversation above."
567 if len(currentSession.Todos) > 0 {
568 summaryPromptText += "\n\n## Current Todo List\n\n"
569 for _, t := range currentSession.Todos {
570 summaryPromptText += fmt.Sprintf("- [%s] %s\n", t.Status, t.Content)
571 }
572 summaryPromptText += "\nInclude these tasks and their statuses in your summary. "
573 summaryPromptText += "Instruct the resuming assistant to use the `todos` tool to continue tracking progress on these tasks."
574 }
575
576 resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
577 Prompt: summaryPromptText,
578 Messages: aiMsgs,
579 ProviderOptions: opts,
580 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
581 prepared.Messages = options.Messages
582 if a.systemPromptPrefix != "" {
583 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
584 }
585 return callContext, prepared, nil
586 },
587 OnReasoningDelta: func(id string, text string) error {
588 summaryMessage.AppendReasoningContent(text)
589 return a.messages.Update(genCtx, summaryMessage)
590 },
591 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
592 // Handle anthropic signature.
593 if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
594 if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
595 summaryMessage.AppendReasoningSignature(signature.Signature)
596 }
597 }
598 summaryMessage.FinishThinking()
599 return a.messages.Update(genCtx, summaryMessage)
600 },
601 OnTextDelta: func(id, text string) error {
602 summaryMessage.AppendContent(text)
603 return a.messages.Update(genCtx, summaryMessage)
604 },
605 })
606 if err != nil {
607 isCancelErr := errors.Is(err, context.Canceled)
608 if isCancelErr {
609 // User cancelled summarize we need to remove the summary message.
610 deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
611 return deleteErr
612 }
613 return err
614 }
615
616 summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
617 err = a.messages.Update(genCtx, summaryMessage)
618 if err != nil {
619 return err
620 }
621
622 var openrouterCost *float64
623 for _, step := range resp.Steps {
624 stepCost := a.openrouterCost(step.ProviderMetadata)
625 if stepCost != nil {
626 newCost := *stepCost
627 if openrouterCost != nil {
628 newCost += *openrouterCost
629 }
630 openrouterCost = &newCost
631 }
632 }
633
634 a.updateSessionUsage(a.largeModel, ¤tSession, resp.TotalUsage, openrouterCost)
635
636 // Just in case, get just the last usage info.
637 usage := resp.Response.Usage
638 currentSession.SummaryMessageID = summaryMessage.ID
639 currentSession.CompletionTokens = usage.OutputTokens
640 currentSession.PromptTokens = 0
641 _, err = a.sessions.Save(genCtx, currentSession)
642 return err
643}
644
645func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
646 if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
647 return fantasy.ProviderOptions{}
648 }
649 return fantasy.ProviderOptions{
650 anthropic.Name: &anthropic.ProviderCacheControlOptions{
651 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
652 },
653 bedrock.Name: &anthropic.ProviderCacheControlOptions{
654 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
655 },
656 }
657}
658
659func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
660 parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
661 var attachmentParts []message.ContentPart
662 for _, attachment := range call.Attachments {
663 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
664 }
665 parts = append(parts, attachmentParts...)
666 msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
667 Role: message.User,
668 Parts: parts,
669 })
670 if err != nil {
671 return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
672 }
673 return msg, nil
674}
675
676func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
677 var history []fantasy.Message
678 if !a.isSubAgent {
679 history = append(history, fantasy.NewUserMessage(
680 fmt.Sprintf("<system_reminder>%s</system_reminder>",
681 `This is a reminder that your todo list is currently empty. DO NOT mention this to the user explicitly because they are already aware.
682If you are working on tasks that would benefit from a todo list please use the "todos" tool to create one.
683If not, please feel free to ignore. Again do not mention this message to the user.`,
684 ),
685 ))
686 }
687 for _, m := range msgs {
688 if len(m.Parts) == 0 {
689 continue
690 }
691 // Assistant message without content or tool calls (cancelled before it
692 // returned anything).
693 if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
694 continue
695 }
696 history = append(history, m.ToAIMessage()...)
697 }
698
699 var files []fantasy.FilePart
700 for _, attachment := range attachments {
701 if attachment.IsText() {
702 continue
703 }
704 files = append(files, fantasy.FilePart{
705 Filename: attachment.FileName,
706 Data: attachment.Content,
707 MediaType: attachment.MimeType,
708 })
709 }
710
711 return history, files
712}
713
714func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
715 msgs, err := a.messages.List(ctx, session.ID)
716 if err != nil {
717 return nil, fmt.Errorf("failed to list messages: %w", err)
718 }
719
720 if session.SummaryMessageID != "" {
721 summaryMsgInex := -1
722 for i, msg := range msgs {
723 if msg.ID == session.SummaryMessageID {
724 summaryMsgInex = i
725 break
726 }
727 }
728 if summaryMsgInex != -1 {
729 msgs = msgs[summaryMsgInex:]
730 msgs[0].Role = message.User
731 }
732 }
733 return msgs, nil
734}
735
736func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, prompt string) {
737 if prompt == "" {
738 return
739 }
740
741 var maxOutput int64 = 40
742 if a.smallModel.CatwalkCfg.CanReason {
743 maxOutput = a.smallModel.CatwalkCfg.DefaultMaxTokens
744 }
745
746 agent := fantasy.NewAgent(a.smallModel.Model,
747 fantasy.WithSystemPrompt(string(titlePrompt)+"\n /no_think"),
748 fantasy.WithMaxOutputTokens(maxOutput),
749 )
750
751 resp, err := agent.Stream(ctx, fantasy.AgentStreamCall{
752 Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", prompt),
753 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
754 prepared.Messages = options.Messages
755 if a.systemPromptPrefix != "" {
756 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
757 }
758 return callContext, prepared, nil
759 },
760 })
761 if err != nil {
762 slog.Error("error generating title", "err", err)
763 return
764 }
765
766 title := resp.Response.Content.Text()
767
768 title = strings.ReplaceAll(title, "\n", " ")
769
770 // Remove thinking tags if present.
771 if idx := strings.Index(title, "</think>"); idx > 0 {
772 title = title[idx+len("</think>"):]
773 }
774
775 title = strings.TrimSpace(title)
776 if title == "" {
777 slog.Warn("failed to generate title", "warn", "empty title")
778 return
779 }
780
781 // Calculate usage and cost.
782 var openrouterCost *float64
783 for _, step := range resp.Steps {
784 stepCost := a.openrouterCost(step.ProviderMetadata)
785 if stepCost != nil {
786 newCost := *stepCost
787 if openrouterCost != nil {
788 newCost += *openrouterCost
789 }
790 openrouterCost = &newCost
791 }
792 }
793
794 modelConfig := a.smallModel.CatwalkCfg
795 cost := modelConfig.CostPer1MInCached/1e6*float64(resp.TotalUsage.CacheCreationTokens) +
796 modelConfig.CostPer1MOutCached/1e6*float64(resp.TotalUsage.CacheReadTokens) +
797 modelConfig.CostPer1MIn/1e6*float64(resp.TotalUsage.InputTokens) +
798 modelConfig.CostPer1MOut/1e6*float64(resp.TotalUsage.OutputTokens)
799
800 if a.isClaudeCode() {
801 cost = 0
802 }
803
804 // Use override cost if available (e.g., from OpenRouter).
805 if openrouterCost != nil {
806 cost = *openrouterCost
807 }
808
809 promptTokens := resp.TotalUsage.InputTokens + resp.TotalUsage.CacheCreationTokens
810 completionTokens := resp.TotalUsage.OutputTokens + resp.TotalUsage.CacheReadTokens
811
812 // Atomically update only title and usage fields to avoid overriding other
813 // concurrent session updates.
814 saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, title, promptTokens, completionTokens, cost)
815 if saveErr != nil {
816 slog.Error("failed to save session title & usage", "error", saveErr)
817 return
818 }
819}
820
821func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
822 openrouterMetadata, ok := metadata[openrouter.Name]
823 if !ok {
824 return nil
825 }
826
827 opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
828 if !ok {
829 return nil
830 }
831 return &opts.Usage.Cost
832}
833
834func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
835 modelConfig := model.CatwalkCfg
836 cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
837 modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
838 modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
839 modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
840
841 if a.isClaudeCode() {
842 cost = 0
843 }
844
845 a.eventTokensUsed(session.ID, model, usage, cost)
846
847 if overrideCost != nil {
848 session.Cost += *overrideCost
849 } else {
850 session.Cost += cost
851 }
852
853 session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
854 session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
855}
856
857func (a *sessionAgent) Cancel(sessionID string) {
858 // Cancel regular requests.
859 if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
860 slog.Info("Request cancellation initiated", "session_id", sessionID)
861 cancel()
862 }
863
864 // Also check for summarize requests.
865 if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
866 slog.Info("Summarize cancellation initiated", "session_id", sessionID)
867 cancel()
868 }
869
870 if a.QueuedPrompts(sessionID) > 0 {
871 slog.Info("Clearing queued prompts", "session_id", sessionID)
872 a.messageQueue.Del(sessionID)
873 }
874}
875
876func (a *sessionAgent) ClearQueue(sessionID string) {
877 if a.QueuedPrompts(sessionID) > 0 {
878 slog.Info("Clearing queued prompts", "session_id", sessionID)
879 a.messageQueue.Del(sessionID)
880 }
881}
882
883func (a *sessionAgent) CancelAll() {
884 if !a.IsBusy() {
885 return
886 }
887 for key := range a.activeRequests.Seq2() {
888 a.Cancel(key) // key is sessionID
889 }
890
891 timeout := time.After(5 * time.Second)
892 for a.IsBusy() {
893 select {
894 case <-timeout:
895 return
896 default:
897 time.Sleep(200 * time.Millisecond)
898 }
899 }
900}
901
902func (a *sessionAgent) IsBusy() bool {
903 var busy bool
904 for cancelFunc := range a.activeRequests.Seq() {
905 if cancelFunc != nil {
906 busy = true
907 break
908 }
909 }
910 return busy
911}
912
913func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
914 _, busy := a.activeRequests.Get(sessionID)
915 return busy
916}
917
918func (a *sessionAgent) QueuedPrompts(sessionID string) int {
919 l, ok := a.messageQueue.Get(sessionID)
920 if !ok {
921 return 0
922 }
923 return len(l)
924}
925
926func (a *sessionAgent) QueuedPromptsList(sessionID string) []string {
927 l, ok := a.messageQueue.Get(sessionID)
928 if !ok {
929 return nil
930 }
931 prompts := make([]string, len(l))
932 for i, call := range l {
933 prompts[i] = call.Prompt
934 }
935 return prompts
936}
937
938func (a *sessionAgent) SetModels(large Model, small Model) {
939 a.largeModel = large
940 a.smallModel = small
941}
942
943func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
944 a.tools = tools
945}
946
947func (a *sessionAgent) Model() Model {
948 return a.largeModel
949}
950
951func (a *sessionAgent) promptPrefix() string {
952 if a.isClaudeCode() {
953 return "You are Claude Code, Anthropic's official CLI for Claude."
954 }
955 return a.systemPromptPrefix
956}
957
958func (a *sessionAgent) isClaudeCode() bool {
959 cfg := config.Get()
960 pc, ok := cfg.Providers.Get(a.largeModel.ModelCfg.Provider)
961 return ok && pc.ID == string(catwalk.InferenceProviderAnthropic) && pc.OAuthToken != nil
962}
963
964// convertToToolResult converts a fantasy tool result to a message tool result.
965func (a *sessionAgent) convertToToolResult(result fantasy.ToolResultContent) message.ToolResult {
966 baseResult := message.ToolResult{
967 ToolCallID: result.ToolCallID,
968 Name: result.ToolName,
969 Metadata: result.ClientMetadata,
970 }
971
972 switch result.Result.GetType() {
973 case fantasy.ToolResultContentTypeText:
974 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result); ok {
975 baseResult.Content = r.Text
976 }
977 case fantasy.ToolResultContentTypeError:
978 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result); ok {
979 baseResult.Content = r.Error.Error()
980 baseResult.IsError = true
981 }
982 case fantasy.ToolResultContentTypeMedia:
983 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](result.Result); ok {
984 content := r.Text
985 if content == "" {
986 content = fmt.Sprintf("Loaded %s content", r.MediaType)
987 }
988 baseResult.Content = content
989 baseResult.Data = r.Data
990 baseResult.MIMEType = r.MediaType
991 }
992 }
993
994 return baseResult
995}
996
997// workaroundProviderMediaLimitations converts media content in tool results to
998// user messages for providers that don't natively support images in tool results.
999//
1000// Problem: OpenAI, Google, OpenRouter, and other OpenAI-compatible providers
1001// don't support sending images/media in tool result messages - they only accept
1002// text in tool results. However, they DO support images in user messages.
1003//
1004// If we send media in tool results to these providers, the API returns an error.
1005//
1006// Solution: For these providers, we:
1007// 1. Replace the media in the tool result with a text placeholder
1008// 2. Inject a user message immediately after with the image as a file attachment
1009// 3. This maintains the tool execution flow while working around API limitations
1010//
1011// Anthropic and Bedrock support images natively in tool results, so we skip
1012// this workaround for them.
1013//
1014// Example transformation:
1015//
1016// BEFORE: [tool result: image data]
1017// AFTER: [tool result: "Image loaded - see attached"], [user: image attachment]
1018func (a *sessionAgent) workaroundProviderMediaLimitations(messages []fantasy.Message) []fantasy.Message {
1019 providerSupportsMedia := a.largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderAnthropic) ||
1020 a.largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderBedrock)
1021
1022 if providerSupportsMedia {
1023 return messages
1024 }
1025
1026 convertedMessages := make([]fantasy.Message, 0, len(messages))
1027
1028 for _, msg := range messages {
1029 if msg.Role != fantasy.MessageRoleTool {
1030 convertedMessages = append(convertedMessages, msg)
1031 continue
1032 }
1033
1034 textParts := make([]fantasy.MessagePart, 0, len(msg.Content))
1035 var mediaFiles []fantasy.FilePart
1036
1037 for _, part := range msg.Content {
1038 toolResult, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
1039 if !ok {
1040 textParts = append(textParts, part)
1041 continue
1042 }
1043
1044 if media, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](toolResult.Output); ok {
1045 decoded, err := base64.StdEncoding.DecodeString(media.Data)
1046 if err != nil {
1047 slog.Warn("failed to decode media data", "error", err)
1048 textParts = append(textParts, part)
1049 continue
1050 }
1051
1052 mediaFiles = append(mediaFiles, fantasy.FilePart{
1053 Data: decoded,
1054 MediaType: media.MediaType,
1055 Filename: fmt.Sprintf("tool-result-%s", toolResult.ToolCallID),
1056 })
1057
1058 textParts = append(textParts, fantasy.ToolResultPart{
1059 ToolCallID: toolResult.ToolCallID,
1060 Output: fantasy.ToolResultOutputContentText{
1061 Text: "[Image/media content loaded - see attached file]",
1062 },
1063 ProviderOptions: toolResult.ProviderOptions,
1064 })
1065 } else {
1066 textParts = append(textParts, part)
1067 }
1068 }
1069
1070 convertedMessages = append(convertedMessages, fantasy.Message{
1071 Role: fantasy.MessageRoleTool,
1072 Content: textParts,
1073 })
1074
1075 if len(mediaFiles) > 0 {
1076 convertedMessages = append(convertedMessages, fantasy.NewUserMessage(
1077 "Here is the media content from the tool result:",
1078 mediaFiles...,
1079 ))
1080 }
1081 }
1082
1083 return convertedMessages
1084}