1// Package agent is the core orchestration layer for Crush AI agents.
2//
3// It provides session-based AI agent functionality for managing
4// conversations, tool execution, and message handling. It coordinates
5// interactions between language models, messages, sessions, and tools while
6// handling features like automatic summarization, queuing, and token
7// management.
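//
// A minimal usage sketch, assuming the caller has already wired up the
// models, session and message services, and tools (sa, sess, and the other
// identifiers below are placeholders):
//
//	sa := NewSessionAgent(SessionAgentOptions{
//		LargeModel: large,
//		SmallModel: small,
//		Sessions:   sessions,
//		Messages:   messages,
//		Tools:      agentTools,
//	})
//	result, err := sa.Run(ctx, SessionAgentCall{
//		SessionID: sess.ID,
//		Prompt:    "Summarize the open TODOs in this repo",
//	})
//
// Run returns (nil, nil) when the session is already busy and the call was
// queued instead of executed.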
8package agent
9
10import (
11 "cmp"
12 "context"
13 _ "embed"
14 "encoding/base64"
15 "errors"
16 "fmt"
17 "log/slog"
18 "os"
19 "regexp"
20 "strconv"
21 "strings"
22 "sync"
23 "time"
24
25 "charm.land/fantasy"
26 "charm.land/fantasy/providers/anthropic"
27 "charm.land/fantasy/providers/bedrock"
28 "charm.land/fantasy/providers/google"
29 "charm.land/fantasy/providers/openai"
30 "charm.land/fantasy/providers/openrouter"
31 "charm.land/lipgloss/v2"
32 "github.com/charmbracelet/catwalk/pkg/catwalk"
33 "github.com/charmbracelet/crush/internal/agent/hyper"
34 "github.com/charmbracelet/crush/internal/agent/tools"
35 "github.com/charmbracelet/crush/internal/config"
36 "github.com/charmbracelet/crush/internal/csync"
37 "github.com/charmbracelet/crush/internal/message"
38 "github.com/charmbracelet/crush/internal/permission"
39 "github.com/charmbracelet/crush/internal/session"
40 "github.com/charmbracelet/crush/internal/stringext"
41)
42
43const defaultSessionName = "Untitled Session"
44
45//go:embed templates/title.md
46var titlePrompt []byte
47
48//go:embed templates/summary.md
49var summaryPrompt []byte
50
// thinkTagRegex matches <think> tags so they can be stripped from generated titles.
52var thinkTagRegex = regexp.MustCompile(`<think>.*?</think>`)
53
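// SessionAgentCall describes a single prompt to run against a session, along
// with optional attachments and per-call sampling parameters.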
54type SessionAgentCall struct {
55 SessionID string
56 Prompt string
57 ProviderOptions fantasy.ProviderOptions
58 Attachments []message.Attachment
59 MaxOutputTokens int64
60 Temperature *float64
61 TopP *float64
62 TopK *int64
63 FrequencyPenalty *float64
64 PresencePenalty *float64
65}
66
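// SessionAgent orchestrates prompts against sessions: it runs or queues
// calls, tracks busy sessions, supports cancellation, and can summarize a
// session's history.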
67type SessionAgent interface {
68 Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
69 SetModels(large Model, small Model)
70 SetTools(tools []fantasy.AgentTool)
71 Cancel(sessionID string)
72 CancelAll()
73 IsSessionBusy(sessionID string) bool
74 IsBusy() bool
75 QueuedPrompts(sessionID string) int
76 QueuedPromptsList(sessionID string) []string
77 ClearQueue(sessionID string)
78 Summarize(context.Context, string, fantasy.ProviderOptions) error
79 Model() Model
80}
81
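// Model bundles a fantasy language model with its catwalk metadata and the
// user-selected model configuration.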
82type Model struct {
83 Model fantasy.LanguageModel
84 CatwalkCfg catwalk.Model
85 ModelCfg config.SelectedModel
86}
87
88type sessionAgent struct {
89 largeModel Model
90 smallModel Model
91 systemPromptPrefix string
92 systemPrompt string
93 isSubAgent bool
94 tools []fantasy.AgentTool
95 sessions session.Service
96 messages message.Service
97 disableAutoSummarize bool
98 isYolo bool
99
100 messageQueue *csync.Map[string, []SessionAgentCall]
101 activeRequests *csync.Map[string, context.CancelFunc]
102}
103
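// SessionAgentOptions configures a SessionAgent created with NewSessionAgent.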
104type SessionAgentOptions struct {
105 LargeModel Model
106 SmallModel Model
107 SystemPromptPrefix string
108 SystemPrompt string
109 IsSubAgent bool
110 DisableAutoSummarize bool
111 IsYolo bool
112 Sessions session.Service
113 Messages message.Service
114 Tools []fantasy.AgentTool
115}
116
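// NewSessionAgent creates a SessionAgent from the given options.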
117func NewSessionAgent(
118 opts SessionAgentOptions,
119) SessionAgent {
120 return &sessionAgent{
121 largeModel: opts.LargeModel,
122 smallModel: opts.SmallModel,
123 systemPromptPrefix: opts.SystemPromptPrefix,
124 systemPrompt: opts.SystemPrompt,
125 isSubAgent: opts.IsSubAgent,
126 sessions: opts.Sessions,
127 messages: opts.Messages,
128 disableAutoSummarize: opts.DisableAutoSummarize,
129 tools: opts.Tools,
130 isYolo: opts.IsYolo,
131 messageQueue: csync.NewMap[string, []SessionAgentCall](),
132 activeRequests: csync.NewMap[string, context.CancelFunc](),
133 }
134}
135
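// Run executes the given call against its session. If the session is already
// busy the call is queued and Run returns a nil result and nil error; queued
// calls are drained once the active request finishes.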
136func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
137 if call.Prompt == "" {
138 return nil, ErrEmptyPrompt
139 }
140 if call.SessionID == "" {
141 return nil, ErrSessionMissing
142 }
143
144 // Queue the message if busy
145 if a.IsSessionBusy(call.SessionID) {
146 existing, ok := a.messageQueue.Get(call.SessionID)
147 if !ok {
148 existing = []SessionAgentCall{}
149 }
150 existing = append(existing, call)
151 a.messageQueue.Set(call.SessionID, existing)
152 return nil, nil
153 }
154
155 if len(a.tools) > 0 {
156 // Add Anthropic caching to the last tool.
157 a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
158 }
159
160 agent := fantasy.NewAgent(
161 a.largeModel.Model,
162 fantasy.WithSystemPrompt(a.systemPrompt),
163 fantasy.WithTools(a.tools...),
164 )
165
166 sessionLock := sync.Mutex{}
167 currentSession, err := a.sessions.Get(ctx, call.SessionID)
168 if err != nil {
169 return nil, fmt.Errorf("failed to get session: %w", err)
170 }
171
172 msgs, err := a.getSessionMessages(ctx, currentSession)
173 if err != nil {
174 return nil, fmt.Errorf("failed to get session messages: %w", err)
175 }
176
177 var wg sync.WaitGroup
	// Generate a title if this is the first message.
179 if len(msgs) == 0 {
180 titleCtx := ctx // Copy to avoid race with ctx reassignment below.
181 wg.Go(func() {
182 a.generateTitle(titleCtx, call.SessionID, call.Prompt)
183 })
184 }
185
186 // Add the user message to the session.
187 _, err = a.createUserMessage(ctx, call)
188 if err != nil {
189 return nil, err
190 }
191
192 // Add the session to the context.
193 ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)
194
195 genCtx, cancel := context.WithCancel(ctx)
196 a.activeRequests.Set(call.SessionID, cancel)
197
198 defer cancel()
199 defer a.activeRequests.Del(call.SessionID)
200
201 history, files := a.preparePrompt(msgs, call.Attachments...)
202
203 startTime := time.Now()
204 a.eventPromptSent(call.SessionID)
205
206 var currentAssistant *message.Message
207 var shouldSummarize bool
208 result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
209 Prompt: message.PromptWithTextAttachments(call.Prompt, call.Attachments),
210 Files: files,
211 Messages: history,
212 ProviderOptions: call.ProviderOptions,
213 MaxOutputTokens: &call.MaxOutputTokens,
214 TopP: call.TopP,
215 Temperature: call.Temperature,
216 PresencePenalty: call.PresencePenalty,
217 TopK: call.TopK,
218 FrequencyPenalty: call.FrequencyPenalty,
219 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
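			// Each step: clear stale provider options from the history, fold
			// queued user prompts into the conversation, apply provider media
			// workarounds and cache-control markers, and create the assistant
			// message that will receive the streamed response.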
220 prepared.Messages = options.Messages
221 for i := range prepared.Messages {
222 prepared.Messages[i].ProviderOptions = nil
223 }
224
225 queuedCalls, _ := a.messageQueue.Get(call.SessionID)
226 a.messageQueue.Del(call.SessionID)
227 for _, queued := range queuedCalls {
228 userMessage, createErr := a.createUserMessage(callContext, queued)
229 if createErr != nil {
230 return callContext, prepared, createErr
231 }
232 prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
233 }
234
235 prepared.Messages = a.workaroundProviderMediaLimitations(prepared.Messages)
236
237 lastSystemRoleInx := 0
238 systemMessageUpdated := false
239 for i, msg := range prepared.Messages {
				// Add cache control to the last system message.
241 if msg.Role == fantasy.MessageRoleSystem {
242 lastSystemRoleInx = i
243 } else if !systemMessageUpdated {
244 prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
245 systemMessageUpdated = true
246 }
				// Then add cache control to the last two messages.
248 if i > len(prepared.Messages)-3 {
249 prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
250 }
251 }
252
253 if promptPrefix := a.promptPrefix(); promptPrefix != "" {
254 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(promptPrefix)}, prepared.Messages...)
255 }
256
257 var assistantMsg message.Message
258 assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
259 Role: message.Assistant,
260 Parts: []message.ContentPart{},
261 Model: a.largeModel.ModelCfg.Model,
262 Provider: a.largeModel.ModelCfg.Provider,
263 })
264 if err != nil {
265 return callContext, prepared, err
266 }
267 callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
268 callContext = context.WithValue(callContext, tools.SupportsImagesContextKey, a.largeModel.CatwalkCfg.SupportsImages)
269 callContext = context.WithValue(callContext, tools.ModelNameContextKey, a.largeModel.CatwalkCfg.Name)
270 currentAssistant = &assistantMsg
271 return callContext, prepared, err
272 },
273 OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
274 currentAssistant.AppendReasoningContent(reasoning.Text)
275 return a.messages.Update(genCtx, *currentAssistant)
276 },
277 OnReasoningDelta: func(id string, text string) error {
278 currentAssistant.AppendReasoningContent(text)
279 return a.messages.Update(genCtx, *currentAssistant)
280 },
281 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
			// Handle the Anthropic reasoning signature.
283 if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
284 if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
285 currentAssistant.AppendReasoningSignature(reasoning.Signature)
286 }
287 }
288 if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
289 if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
290 currentAssistant.AppendThoughtSignature(reasoning.Signature, reasoning.ToolID)
291 }
292 }
293 if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
294 if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
295 currentAssistant.SetReasoningResponsesData(reasoning)
296 }
297 }
298 currentAssistant.FinishThinking()
299 return a.messages.Update(genCtx, *currentAssistant)
300 },
301 OnTextDelta: func(id string, text string) error {
			// Strip leading newline from initial text content. This is
303 // particularly important in non-interactive mode where leading
304 // newlines are very visible.
305 if len(currentAssistant.Parts) == 0 {
306 text = strings.TrimPrefix(text, "\n")
307 }
308
309 currentAssistant.AppendContent(text)
310 return a.messages.Update(genCtx, *currentAssistant)
311 },
312 OnToolInputStart: func(id string, toolName string) error {
313 toolCall := message.ToolCall{
314 ID: id,
315 Name: toolName,
316 ProviderExecuted: false,
317 Finished: false,
318 }
319 currentAssistant.AddToolCall(toolCall)
320 return a.messages.Update(genCtx, *currentAssistant)
321 },
322 OnRetry: func(err *fantasy.ProviderError, delay time.Duration) {
323 // TODO: implement
324 },
325 OnToolCall: func(tc fantasy.ToolCallContent) error {
326 toolCall := message.ToolCall{
327 ID: tc.ToolCallID,
328 Name: tc.ToolName,
329 Input: tc.Input,
330 ProviderExecuted: false,
331 Finished: true,
332 }
333 currentAssistant.AddToolCall(toolCall)
334 return a.messages.Update(genCtx, *currentAssistant)
335 },
336 OnToolResult: func(result fantasy.ToolResultContent) error {
337 toolResult := a.convertToToolResult(result)
338 _, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
339 Role: message.Tool,
340 Parts: []message.ContentPart{
341 toolResult,
342 },
343 })
344 return createMsgErr
345 },
346 OnStepFinish: func(stepResult fantasy.StepResult) error {
347 finishReason := message.FinishReasonUnknown
348 switch stepResult.FinishReason {
349 case fantasy.FinishReasonLength:
350 finishReason = message.FinishReasonMaxTokens
351 case fantasy.FinishReasonStop:
352 finishReason = message.FinishReasonEndTurn
353 case fantasy.FinishReasonToolCalls:
354 finishReason = message.FinishReasonToolUse
355 }
356 currentAssistant.AddFinish(finishReason, "", "")
357 sessionLock.Lock()
358 updatedSession, getSessionErr := a.sessions.Get(genCtx, call.SessionID)
359 if getSessionErr != nil {
360 sessionLock.Unlock()
361 return getSessionErr
362 }
363 a.updateSessionUsage(a.largeModel, &updatedSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
364 _, sessionErr := a.sessions.Save(genCtx, updatedSession)
365 sessionLock.Unlock()
366 if sessionErr != nil {
367 return sessionErr
368 }
369 return a.messages.Update(genCtx, *currentAssistant)
370 },
371 StopWhen: []fantasy.StopCondition{
372 func(_ []fantasy.StepResult) bool {
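				// Stop early to auto-summarize when the remaining context
				// drops below a threshold: 20k tokens for context windows
				// larger than 200k, otherwise 20% of the window. Skipped when
				// auto-summarize is disabled.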
373 cw := int64(a.largeModel.CatwalkCfg.ContextWindow)
374 tokens := currentSession.CompletionTokens + currentSession.PromptTokens
375 remaining := cw - tokens
376 var threshold int64
377 if cw > 200_000 {
378 threshold = 20_000
379 } else {
380 threshold = int64(float64(cw) * 0.2)
381 }
382 if (remaining <= threshold) && !a.disableAutoSummarize {
383 shouldSummarize = true
384 return true
385 }
386 return false
387 },
388 },
389 })
390
391 a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))
392
393 if err != nil {
394 isCancelErr := errors.Is(err, context.Canceled)
395 isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
396 if currentAssistant == nil {
397 return result, err
398 }
399 // Ensure we finish thinking on error to close the reasoning state.
400 currentAssistant.FinishThinking()
401 toolCalls := currentAssistant.ToolCalls()
402 // INFO: we use the parent context here because the genCtx has been cancelled.
403 msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
404 if createErr != nil {
405 return nil, createErr
406 }
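		// Close out any tool calls that never produced a result so the stored
		// history stays consistent for the next request.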
407 for _, tc := range toolCalls {
408 if !tc.Finished {
409 tc.Finished = true
410 tc.Input = "{}"
411 currentAssistant.AddToolCall(tc)
412 updateErr := a.messages.Update(ctx, *currentAssistant)
413 if updateErr != nil {
414 return nil, updateErr
415 }
416 }
417
418 found := false
419 for _, msg := range msgs {
420 if msg.Role == message.Tool {
421 for _, tr := range msg.ToolResults() {
422 if tr.ToolCallID == tc.ID {
423 found = true
424 break
425 }
426 }
427 }
428 if found {
429 break
430 }
431 }
432 if found {
433 continue
434 }
435 content := "There was an error while executing the tool"
436 if isCancelErr {
437 content = "Tool execution canceled by user"
438 } else if isPermissionErr {
439 content = "User denied permission"
440 }
441 toolResult := message.ToolResult{
442 ToolCallID: tc.ID,
443 Name: tc.Name,
444 Content: content,
445 IsError: true,
446 }
447 _, createErr = a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
448 Role: message.Tool,
449 Parts: []message.ContentPart{
450 toolResult,
451 },
452 })
453 if createErr != nil {
454 return nil, createErr
455 }
456 }
457 var fantasyErr *fantasy.Error
458 var providerErr *fantasy.ProviderError
459 const defaultTitle = "Provider Error"
460 if isCancelErr {
461 currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
462 } else if isPermissionErr {
463 currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "")
464 } else if errors.Is(err, hyper.ErrNoCredits) {
465 url := hyper.BaseURL()
466 link := lipgloss.NewStyle().Hyperlink(url, "id=hyper").Render(url)
467 currentAssistant.AddFinish(message.FinishReasonError, "No credits", "You're out of credits. Add more at "+link)
468 } else if errors.As(err, &providerErr) {
469 if providerErr.Message == "The requested model is not supported." {
470 url := "https://github.com/settings/copilot/features"
471 link := lipgloss.NewStyle().Hyperlink(url, "id=hyper").Render(url)
472 currentAssistant.AddFinish(
473 message.FinishReasonError,
474 "Copilot model not enabled",
475 fmt.Sprintf("%q is not enabled in Copilot. Go to the following page to enable it. Then, wait 5 minutes before trying again. %s", a.largeModel.CatwalkCfg.Name, link),
476 )
477 } else {
478 currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message)
479 }
480 } else if errors.As(err, &fantasyErr) {
481 currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(fantasyErr.Title), defaultTitle), fantasyErr.Message)
482 } else {
483 currentAssistant.AddFinish(message.FinishReasonError, defaultTitle, err.Error())
484 }
485 // Note: we use the parent context here because the genCtx has been
486 // cancelled.
487 updateErr := a.messages.Update(ctx, *currentAssistant)
488 if updateErr != nil {
489 return nil, updateErr
490 }
491 return nil, err
492 }
493 wg.Wait()
494
495 if shouldSummarize {
496 a.activeRequests.Del(call.SessionID)
497 if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
498 return nil, summarizeErr
499 }
		// If the agent still had pending tool calls, re-queue the original
		// request so it can resume after summarization.
501 if len(currentAssistant.ToolCalls()) > 0 {
502 existing, ok := a.messageQueue.Get(call.SessionID)
503 if !ok {
504 existing = []SessionAgentCall{}
505 }
			call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long. The initial user request was: `%s`", call.Prompt)
507 existing = append(existing, call)
508 a.messageQueue.Set(call.SessionID, existing)
509 }
510 }
511
512 // Release active request before processing queued messages.
513 a.activeRequests.Del(call.SessionID)
514 cancel()
515
516 queuedMessages, ok := a.messageQueue.Get(call.SessionID)
517 if !ok || len(queuedMessages) == 0 {
518 return result, err
519 }
	// There are queued messages; restart the loop.
521 firstQueuedMessage := queuedMessages[0]
522 a.messageQueue.Set(call.SessionID, queuedMessages[1:])
523 return a.Run(ctx, firstQueuedMessage)
524}
525
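// Summarize writes a summary message for the session and resets its token
// counts so the summary becomes the new context baseline.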
526func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
527 if a.IsSessionBusy(sessionID) {
528 return ErrSessionBusy
529 }
530
531 currentSession, err := a.sessions.Get(ctx, sessionID)
532 if err != nil {
533 return fmt.Errorf("failed to get session: %w", err)
534 }
535 msgs, err := a.getSessionMessages(ctx, currentSession)
536 if err != nil {
537 return err
538 }
539 if len(msgs) == 0 {
540 // Nothing to summarize.
541 return nil
542 }
543
544 aiMsgs, _ := a.preparePrompt(msgs)
545
546 genCtx, cancel := context.WithCancel(ctx)
547 a.activeRequests.Set(sessionID, cancel)
548 defer a.activeRequests.Del(sessionID)
549 defer cancel()
550
551 agent := fantasy.NewAgent(a.largeModel.Model,
552 fantasy.WithSystemPrompt(string(summaryPrompt)),
553 )
554 summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
555 Role: message.Assistant,
556 Model: a.largeModel.Model.Model(),
557 Provider: a.largeModel.Model.Provider(),
558 IsSummaryMessage: true,
559 })
560 if err != nil {
561 return err
562 }
563
564 summaryPromptText := "Provide a detailed summary of our conversation above."
565 if len(currentSession.Todos) > 0 {
566 summaryPromptText += "\n\n## Current Todo List\n\n"
567 for _, t := range currentSession.Todos {
568 summaryPromptText += fmt.Sprintf("- [%s] %s\n", t.Status, t.Content)
569 }
570 summaryPromptText += "\nInclude these tasks and their statuses in your summary. "
571 summaryPromptText += "Instruct the resuming assistant to use the `todos` tool to continue tracking progress on these tasks."
572 }
573
574 resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
575 Prompt: summaryPromptText,
576 Messages: aiMsgs,
577 ProviderOptions: opts,
578 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
579 prepared.Messages = options.Messages
580 if a.systemPromptPrefix != "" {
581 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
582 }
583 return callContext, prepared, nil
584 },
585 OnReasoningDelta: func(id string, text string) error {
586 summaryMessage.AppendReasoningContent(text)
587 return a.messages.Update(genCtx, summaryMessage)
588 },
589 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
590 // Handle anthropic signature.
			if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
592 if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
593 summaryMessage.AppendReasoningSignature(signature.Signature)
594 }
595 }
596 summaryMessage.FinishThinking()
597 return a.messages.Update(genCtx, summaryMessage)
598 },
599 OnTextDelta: func(id, text string) error {
600 summaryMessage.AppendContent(text)
601 return a.messages.Update(genCtx, summaryMessage)
602 },
603 })
604 if err != nil {
605 isCancelErr := errors.Is(err, context.Canceled)
606 if isCancelErr {
			// User cancelled the summarize request; remove the summary message.
608 deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
609 return deleteErr
610 }
611 return err
612 }
613
614 summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
615 err = a.messages.Update(genCtx, summaryMessage)
616 if err != nil {
617 return err
618 }
619
620 var openrouterCost *float64
621 for _, step := range resp.Steps {
622 stepCost := a.openrouterCost(step.ProviderMetadata)
623 if stepCost != nil {
624 newCost := *stepCost
625 if openrouterCost != nil {
626 newCost += *openrouterCost
627 }
628 openrouterCost = &newCost
629 }
630 }
631
	a.updateSessionUsage(a.largeModel, &currentSession, resp.TotalUsage, openrouterCost)
633
	// Use only the final response's usage so the token counts restart from the summary.
635 usage := resp.Response.Usage
636 currentSession.SummaryMessageID = summaryMessage.ID
637 currentSession.CompletionTokens = usage.OutputTokens
638 currentSession.PromptTokens = 0
639 _, err = a.sessions.Save(genCtx, currentSession)
640 return err
641}
642
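// getCacheControlOptions returns ephemeral cache-control provider options for
// Anthropic and Bedrock, or empty options when caching is disabled via the
// CRUSH_DISABLE_ANTHROPIC_CACHE environment variable.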
643func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
644 if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
645 return fantasy.ProviderOptions{}
646 }
647 return fantasy.ProviderOptions{
648 anthropic.Name: &anthropic.ProviderCacheControlOptions{
649 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
650 },
651 bedrock.Name: &anthropic.ProviderCacheControlOptions{
652 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
653 },
654 }
655}
656
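// createUserMessage stores the call's prompt and attachments as a user
// message in the session.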
657func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
658 parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
659 var attachmentParts []message.ContentPart
660 for _, attachment := range call.Attachments {
661 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
662 }
663 parts = append(parts, attachmentParts...)
664 msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
665 Role: message.User,
666 Parts: parts,
667 })
668 if err != nil {
669 return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
670 }
671 return msg, nil
672}
673
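// preparePrompt converts stored messages into fantasy history and collects
// non-text attachments as file parts. For top-level agents it also prepends a
// reminder about the todo list.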
674func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
675 var history []fantasy.Message
676 if !a.isSubAgent {
677 history = append(history, fantasy.NewUserMessage(
678 fmt.Sprintf("<system_reminder>%s</system_reminder>",
679 `This is a reminder that your todo list is currently empty. DO NOT mention this to the user explicitly because they are already aware.
680If you are working on tasks that would benefit from a todo list please use the "todos" tool to create one.
681If not, please feel free to ignore. Again do not mention this message to the user.`,
682 ),
683 ))
684 }
685 for _, m := range msgs {
686 if len(m.Parts) == 0 {
687 continue
688 }
689 // Assistant message without content or tool calls (cancelled before it
690 // returned anything).
691 if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
692 continue
693 }
694 history = append(history, m.ToAIMessage()...)
695 }
696
697 var files []fantasy.FilePart
698 for _, attachment := range attachments {
699 if attachment.IsText() {
700 continue
701 }
702 files = append(files, fantasy.FilePart{
703 Filename: attachment.FileName,
704 Data: attachment.Content,
705 MediaType: attachment.MimeType,
706 })
707 }
708
709 return history, files
710}
711
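// getSessionMessages lists the session's messages. If the session has been
// summarized, history is truncated to start at the summary message, which is
// re-sent as a user message.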
712func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
713 msgs, err := a.messages.List(ctx, session.ID)
714 if err != nil {
715 return nil, fmt.Errorf("failed to list messages: %w", err)
716 }
717
718 if session.SummaryMessageID != "" {
		summaryMsgIndex := -1
		for i, msg := range msgs {
			if msg.ID == session.SummaryMessageID {
				summaryMsgIndex = i
				break
			}
		}
		if summaryMsgIndex != -1 {
			msgs = msgs[summaryMsgIndex:]
728 msgs[0].Role = message.User
729 }
730 }
731 return msgs, nil
732}
733
// generateTitle generates a session title based on the initial prompt.
735func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, userPrompt string) {
736 if userPrompt == "" {
737 return
738 }
739
740 var maxOutputTokens int64 = 40
741 if a.smallModel.CatwalkCfg.CanReason {
742 maxOutputTokens = a.smallModel.CatwalkCfg.DefaultMaxTokens
743 }
744
745 newAgent := func(m fantasy.LanguageModel, p []byte, tok int64) fantasy.Agent {
746 return fantasy.NewAgent(m,
747 fantasy.WithSystemPrompt(string(p)+"\n /no_think"),
748 fantasy.WithMaxOutputTokens(tok),
749 )
750 }
751
752 streamCall := fantasy.AgentStreamCall{
753 Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", userPrompt),
754 PrepareStep: func(callCtx context.Context, opts fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
755 prepared.Messages = opts.Messages
756 if a.systemPromptPrefix != "" {
757 prepared.Messages = append([]fantasy.Message{
758 fantasy.NewSystemMessage(a.systemPromptPrefix),
759 }, prepared.Messages...)
760 }
761 return callCtx, prepared, nil
762 },
763 }
764
765 // Use the small model to generate the title.
766 model := &a.smallModel
767 agent := newAgent(model.Model, titlePrompt, maxOutputTokens)
768 resp, err := agent.Stream(ctx, streamCall)
769 if err == nil {
770 // We successfully generated a title with the small model.
771 slog.Info("generated title with small model")
772 } else {
773 // It didn't work. Let's try with the big model.
774 slog.Error("error generating title with small model; trying big model", "err", err)
775 model = &a.largeModel
776 agent = newAgent(model.Model, titlePrompt, maxOutputTokens)
777 resp, err = agent.Stream(ctx, streamCall)
778 if err == nil {
779 slog.Info("generated title with large model")
780 } else {
781 // Welp, the large model didn't work either. Use the default
782 // session name and return.
783 slog.Error("error generating title with large model", "err", err)
784 saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, defaultSessionName, 0, 0, 0)
785 if saveErr != nil {
786 slog.Error("failed to save session title and usage", "error", saveErr)
787 }
788 return
789 }
790 }
791
792 if resp == nil {
		// We didn't get a response, so we can't generate a title. Use the
		// default session name and return.
795 slog.Error("response is nil; can't generate title")
796 saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, defaultSessionName, 0, 0, 0)
797 if saveErr != nil {
798 slog.Error("failed to save session title and usage", "error", saveErr)
799 }
800 return
801 }
802
803 // Clean up title.
	title := strings.ReplaceAll(resp.Response.Content.Text(), "\n", " ")
806 slog.Info("generated title", "title", title)
807
808 // Remove thinking tags if present.
809 title = thinkTagRegex.ReplaceAllString(title, "")
810
811 title = strings.TrimSpace(title)
812 if title == "" {
813 slog.Warn("empty title; using fallback")
814 title = defaultSessionName
815 }
816
817 // Calculate usage and cost.
818 var openrouterCost *float64
819 for _, step := range resp.Steps {
820 stepCost := a.openrouterCost(step.ProviderMetadata)
821 if stepCost != nil {
822 newCost := *stepCost
823 if openrouterCost != nil {
824 newCost += *openrouterCost
825 }
826 openrouterCost = &newCost
827 }
828 }
829
830 modelConfig := model.CatwalkCfg
831 cost := modelConfig.CostPer1MInCached/1e6*float64(resp.TotalUsage.CacheCreationTokens) +
832 modelConfig.CostPer1MOutCached/1e6*float64(resp.TotalUsage.CacheReadTokens) +
833 modelConfig.CostPer1MIn/1e6*float64(resp.TotalUsage.InputTokens) +
834 modelConfig.CostPer1MOut/1e6*float64(resp.TotalUsage.OutputTokens)
835
836 // Use override cost if available (e.g., from OpenRouter).
837 if openrouterCost != nil {
838 cost = *openrouterCost
839 }
840
841 promptTokens := resp.TotalUsage.InputTokens + resp.TotalUsage.CacheCreationTokens
842 completionTokens := resp.TotalUsage.OutputTokens + resp.TotalUsage.CacheReadTokens
843
844 // Atomically update only title and usage fields to avoid overriding other
845 // concurrent session updates.
846 saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, title, promptTokens, completionTokens, cost)
847 if saveErr != nil {
848 slog.Error("failed to save session title and usage", "error", saveErr)
849 return
850 }
851}
852
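// openrouterCost extracts the upstream-reported cost from OpenRouter provider
// metadata, if present.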
853func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
854 openrouterMetadata, ok := metadata[openrouter.Name]
855 if !ok {
856 return nil
857 }
858
859 opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
860 if !ok {
861 return nil
862 }
863 return &opts.Usage.Cost
864}
865
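// updateSessionUsage adds the cost and token counts for the given usage to
// the session, preferring overrideCost (e.g. the cost reported by OpenRouter)
// when available. Cost is derived from the model's per-million-token rates;
// for example, 1,000 input tokens at $3.00 per 1M tokens cost $0.003.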
866func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
867 modelConfig := model.CatwalkCfg
868 cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
869 modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
870 modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
871 modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
872
873 a.eventTokensUsed(session.ID, model, usage, cost)
874
875 if overrideCost != nil {
876 session.Cost += *overrideCost
877 } else {
878 session.Cost += cost
879 }
880
881 session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
882 session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
883}
884
885func (a *sessionAgent) Cancel(sessionID string) {
886 // Cancel regular requests.
887 if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
888 slog.Info("Request cancellation initiated", "session_id", sessionID)
889 cancel()
890 }
891
892 // Also check for summarize requests.
893 if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
894 slog.Info("Summarize cancellation initiated", "session_id", sessionID)
895 cancel()
896 }
897
898 if a.QueuedPrompts(sessionID) > 0 {
899 slog.Info("Clearing queued prompts", "session_id", sessionID)
900 a.messageQueue.Del(sessionID)
901 }
902}
903
904func (a *sessionAgent) ClearQueue(sessionID string) {
905 if a.QueuedPrompts(sessionID) > 0 {
906 slog.Info("Clearing queued prompts", "session_id", sessionID)
907 a.messageQueue.Del(sessionID)
908 }
909}
910
911func (a *sessionAgent) CancelAll() {
912 if !a.IsBusy() {
913 return
914 }
915 for key := range a.activeRequests.Seq2() {
916 a.Cancel(key) // key is sessionID
917 }
918
919 timeout := time.After(5 * time.Second)
920 for a.IsBusy() {
921 select {
922 case <-timeout:
923 return
924 default:
925 time.Sleep(200 * time.Millisecond)
926 }
927 }
928}
929
930func (a *sessionAgent) IsBusy() bool {
931 var busy bool
932 for cancelFunc := range a.activeRequests.Seq() {
933 if cancelFunc != nil {
934 busy = true
935 break
936 }
937 }
938 return busy
939}
940
941func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
942 _, busy := a.activeRequests.Get(sessionID)
943 return busy
944}
945
946func (a *sessionAgent) QueuedPrompts(sessionID string) int {
947 l, ok := a.messageQueue.Get(sessionID)
948 if !ok {
949 return 0
950 }
951 return len(l)
952}
953
954func (a *sessionAgent) QueuedPromptsList(sessionID string) []string {
955 l, ok := a.messageQueue.Get(sessionID)
956 if !ok {
957 return nil
958 }
959 prompts := make([]string, len(l))
960 for i, call := range l {
961 prompts[i] = call.Prompt
962 }
963 return prompts
964}
965
966func (a *sessionAgent) SetModels(large Model, small Model) {
967 a.largeModel = large
968 a.smallModel = small
969}
970
971func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
972 a.tools = tools
973}
974
975func (a *sessionAgent) Model() Model {
976 return a.largeModel
977}
978
979func (a *sessionAgent) promptPrefix() string {
980 return a.systemPromptPrefix
981}
982
983// convertToToolResult converts a fantasy tool result to a message tool result.
984func (a *sessionAgent) convertToToolResult(result fantasy.ToolResultContent) message.ToolResult {
985 baseResult := message.ToolResult{
986 ToolCallID: result.ToolCallID,
987 Name: result.ToolName,
988 Metadata: result.ClientMetadata,
989 }
990
991 switch result.Result.GetType() {
992 case fantasy.ToolResultContentTypeText:
993 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result); ok {
994 baseResult.Content = r.Text
995 }
996 case fantasy.ToolResultContentTypeError:
997 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result); ok {
998 baseResult.Content = r.Error.Error()
999 baseResult.IsError = true
1000 }
1001 case fantasy.ToolResultContentTypeMedia:
1002 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](result.Result); ok {
1003 content := r.Text
1004 if content == "" {
1005 content = fmt.Sprintf("Loaded %s content", r.MediaType)
1006 }
1007 baseResult.Content = content
1008 baseResult.Data = r.Data
1009 baseResult.MIMEType = r.MediaType
1010 }
1011 }
1012
1013 return baseResult
1014}
1015
1016// workaroundProviderMediaLimitations converts media content in tool results to
1017// user messages for providers that don't natively support images in tool results.
1018//
1019// Problem: OpenAI, Google, OpenRouter, and other OpenAI-compatible providers
1020// don't support sending images/media in tool result messages - they only accept
1021// text in tool results. However, they DO support images in user messages.
1022//
1023// If we send media in tool results to these providers, the API returns an error.
1024//
1025// Solution: For these providers, we:
1026// 1. Replace the media in the tool result with a text placeholder
1027// 2. Inject a user message immediately after with the image as a file attachment
1028// 3. This maintains the tool execution flow while working around API limitations
1029//
1030// Anthropic and Bedrock support images natively in tool results, so we skip
1031// this workaround for them.
1032//
1033// Example transformation:
1034//
1035// BEFORE: [tool result: image data]
1036// AFTER: [tool result: "Image loaded - see attached"], [user: image attachment]
1037func (a *sessionAgent) workaroundProviderMediaLimitations(messages []fantasy.Message) []fantasy.Message {
1038 providerSupportsMedia := a.largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderAnthropic) ||
1039 a.largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderBedrock)
1040
1041 if providerSupportsMedia {
1042 return messages
1043 }
1044
1045 convertedMessages := make([]fantasy.Message, 0, len(messages))
1046
1047 for _, msg := range messages {
1048 if msg.Role != fantasy.MessageRoleTool {
1049 convertedMessages = append(convertedMessages, msg)
1050 continue
1051 }
1052
1053 textParts := make([]fantasy.MessagePart, 0, len(msg.Content))
1054 var mediaFiles []fantasy.FilePart
1055
1056 for _, part := range msg.Content {
1057 toolResult, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
1058 if !ok {
1059 textParts = append(textParts, part)
1060 continue
1061 }
1062
1063 if media, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](toolResult.Output); ok {
1064 decoded, err := base64.StdEncoding.DecodeString(media.Data)
1065 if err != nil {
1066 slog.Warn("failed to decode media data", "error", err)
1067 textParts = append(textParts, part)
1068 continue
1069 }
1070
1071 mediaFiles = append(mediaFiles, fantasy.FilePart{
1072 Data: decoded,
1073 MediaType: media.MediaType,
1074 Filename: fmt.Sprintf("tool-result-%s", toolResult.ToolCallID),
1075 })
1076
1077 textParts = append(textParts, fantasy.ToolResultPart{
1078 ToolCallID: toolResult.ToolCallID,
1079 Output: fantasy.ToolResultOutputContentText{
1080 Text: "[Image/media content loaded - see attached file]",
1081 },
1082 ProviderOptions: toolResult.ProviderOptions,
1083 })
1084 } else {
1085 textParts = append(textParts, part)
1086 }
1087 }
1088
1089 convertedMessages = append(convertedMessages, fantasy.Message{
1090 Role: fantasy.MessageRoleTool,
1091 Content: textParts,
1092 })
1093
1094 if len(mediaFiles) > 0 {
1095 convertedMessages = append(convertedMessages, fantasy.NewUserMessage(
1096 "Here is the media content from the tool result:",
1097 mediaFiles...,
1098 ))
1099 }
1100 }
1101
1102 return convertedMessages
1103}