1// Package agent is the core orchestration layer for Crush AI agents.
2//
3// It provides session-based AI agent functionality for managing
4// conversations, tool execution, and message handling. It coordinates
5// interactions between language models, messages, sessions, and tools while
6// handling features like automatic summarization, queuing, and token
7// management.
8package agent
9
10import (
11 "cmp"
12 "context"
13 _ "embed"
14 "encoding/base64"
15 "errors"
16 "fmt"
17 "log/slog"
18 "os"
19 "regexp"
20 "strconv"
21 "strings"
22 "sync"
23 "time"
24
25 "charm.land/fantasy"
26 "charm.land/fantasy/providers/anthropic"
27 "charm.land/fantasy/providers/bedrock"
28 "charm.land/fantasy/providers/google"
29 "charm.land/fantasy/providers/openai"
30 "charm.land/fantasy/providers/openrouter"
31 "charm.land/lipgloss/v2"
32 "github.com/charmbracelet/catwalk/pkg/catwalk"
33 "github.com/charmbracelet/crush/internal/agent/hyper"
34 "github.com/charmbracelet/crush/internal/agent/tools"
35 "github.com/charmbracelet/crush/internal/config"
36 "github.com/charmbracelet/crush/internal/csync"
37 "github.com/charmbracelet/crush/internal/message"
38 "github.com/charmbracelet/crush/internal/permission"
39 "github.com/charmbracelet/crush/internal/session"
40 "github.com/charmbracelet/crush/internal/stringext"
41)
42
// defaultSessionName is the title stored for a session when title generation
// fails or yields an empty title.
const defaultSessionName = "Untitled Session"
44
// titlePrompt holds the embedded contents of templates/title.md. It is used
// as the system prompt when generating a session title.
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt holds the embedded contents of templates/summary.md. It is
// used as the system prompt when summarizing a session.
//
//go:embed templates/summary.md
var summaryPrompt []byte
50
// thinkTagRegex matches a <think>...</think> span (non-greedy). It is used to
// remove such tags from generated titles.
var thinkTagRegex = regexp.MustCompile(`<think>.*?</think>`)
53
// SessionAgentCall describes a single prompt to run against a session,
// together with the sampling parameters forwarded to the model.
type SessionAgentCall struct {
	// SessionID identifies the session the prompt belongs to. Run rejects
	// calls where it is empty.
	SessionID string
	// Prompt is the user's prompt text. It may be empty only when a text
	// attachment is present.
	Prompt string
	// ProviderOptions is passed through unchanged to the model stream call.
	ProviderOptions fantasy.ProviderOptions
	// Attachments are stored with the user message; non-text attachments are
	// also sent to the model as files.
	Attachments []message.Attachment
	// MaxOutputTokens and the optional sampling parameters below are
	// forwarded to the model stream call as-is.
	MaxOutputTokens  int64
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
}
66
// SessionAgent runs prompts against sessions and manages the in-flight and
// queued requests of those sessions. The notes on each method describe the
// sessionAgent implementation in this package.
type SessionAgent interface {
	// Run sends the call's prompt to the large model for its session. When
	// the session already has an active request the call is queued and
	// (nil, nil) is returned.
	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
	// SetModels replaces the large and small models.
	SetModels(large Model, small Model)
	// SetTools replaces the tools offered to the model.
	SetTools(tools []fantasy.AgentTool)
	// Cancel cancels the active request of the session, if any, and drops
	// its queued prompts.
	Cancel(sessionID string)
	// CancelAll cancels every active request and waits a bounded time for
	// the agent to become idle.
	CancelAll()
	// IsSessionBusy reports whether the session has an active request.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any session has an active request.
	IsBusy() bool
	// QueuedPrompts returns the number of prompts queued for the session.
	QueuedPrompts(sessionID string) int
	// QueuedPromptsList returns the prompt texts queued for the session.
	QueuedPromptsList(sessionID string) []string
	// ClearQueue drops the prompts queued for the session.
	ClearQueue(sessionID string)
	// Summarize replaces the session's history with a model-generated
	// summary message.
	Summarize(context.Context, string, fantasy.ProviderOptions) error
	// Model returns the large model.
	Model() Model
}
81
// Model bundles a language model with the metadata and configuration the
// agent needs alongside it.
type Model struct {
	// Model is the language model that requests are sent to.
	Model fantasy.LanguageModel
	// CatwalkCfg carries model metadata read by the agent, such as the
	// context window, per-token costs and display name.
	CatwalkCfg catwalk.Model
	// ModelCfg is the selected model configuration; its Model and Provider
	// values are recorded on assistant messages.
	ModelCfg config.SelectedModel
}
87
// sessionAgent is the SessionAgent implementation returned by
// NewSessionAgent.
type sessionAgent struct {
	// largeModel handles prompts and summaries; smallModel is tried first
	// for title generation.
	largeModel Model
	smallModel Model
	// systemPromptPrefix, when non-empty, is prepended as an extra system
	// message on every step.
	systemPromptPrefix string
	systemPrompt       string
	// isSubAgent suppresses the todo-list reminder that is otherwise
	// prepended to the history.
	isSubAgent bool
	tools      []fantasy.AgentTool
	sessions   session.Service
	messages   message.Service
	// disableAutoSummarize turns off the automatic summarization that is
	// triggered when the context window is nearly full.
	disableAutoSummarize bool
	isYolo               bool

	// messageQueue holds, per session ID, the calls received while that
	// session was busy.
	messageQueue *csync.Map[string, []SessionAgentCall]
	// activeRequests holds, per session ID, the cancel function of the
	// request currently in flight.
	activeRequests *csync.Map[string, context.CancelFunc]
}
103
// SessionAgentOptions carries the dependencies and settings used by
// NewSessionAgent to construct a SessionAgent. Each field is copied onto the
// corresponding field of the agent.
type SessionAgentOptions struct {
	LargeModel           Model
	SmallModel           Model
	SystemPromptPrefix   string
	SystemPrompt         string
	IsSubAgent           bool
	DisableAutoSummarize bool
	IsYolo               bool
	Sessions             session.Service
	Messages             message.Service
	Tools                []fantasy.AgentTool
}
116
117func NewSessionAgent(
118 opts SessionAgentOptions,
119) SessionAgent {
120 return &sessionAgent{
121 largeModel: opts.LargeModel,
122 smallModel: opts.SmallModel,
123 systemPromptPrefix: opts.SystemPromptPrefix,
124 systemPrompt: opts.SystemPrompt,
125 isSubAgent: opts.IsSubAgent,
126 sessions: opts.Sessions,
127 messages: opts.Messages,
128 disableAutoSummarize: opts.DisableAutoSummarize,
129 tools: opts.Tools,
130 isYolo: opts.IsYolo,
131 messageQueue: csync.NewMap[string, []SessionAgentCall](),
132 activeRequests: csync.NewMap[string, context.CancelFunc](),
133 }
134}
135
136func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
137 if call.Prompt == "" && !message.ContainsTextAttachment(call.Attachments) {
138 return nil, ErrEmptyPrompt
139 }
140 if call.SessionID == "" {
141 return nil, ErrSessionMissing
142 }
143
144 // Queue the message if busy
145 if a.IsSessionBusy(call.SessionID) {
146 existing, ok := a.messageQueue.Get(call.SessionID)
147 if !ok {
148 existing = []SessionAgentCall{}
149 }
150 existing = append(existing, call)
151 a.messageQueue.Set(call.SessionID, existing)
152 return nil, nil
153 }
154
155 if len(a.tools) > 0 {
156 // Add Anthropic caching to the last tool.
157 a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
158 }
159
160 agent := fantasy.NewAgent(
161 a.largeModel.Model,
162 fantasy.WithSystemPrompt(a.systemPrompt),
163 fantasy.WithTools(a.tools...),
164 )
165
166 sessionLock := sync.Mutex{}
167 currentSession, err := a.sessions.Get(ctx, call.SessionID)
168 if err != nil {
169 return nil, fmt.Errorf("failed to get session: %w", err)
170 }
171
172 msgs, err := a.getSessionMessages(ctx, currentSession)
173 if err != nil {
174 return nil, fmt.Errorf("failed to get session messages: %w", err)
175 }
176
177 var wg sync.WaitGroup
178 // Generate title if first message.
179 if len(msgs) == 0 {
180 titleCtx := ctx // Copy to avoid race with ctx reassignment below.
181 wg.Go(func() {
182 a.generateTitle(titleCtx, call.SessionID, call.Prompt)
183 })
184 }
185
186 // Add the user message to the session.
187 _, err = a.createUserMessage(ctx, call)
188 if err != nil {
189 return nil, err
190 }
191
192 // Add the session to the context.
193 ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)
194
195 genCtx, cancel := context.WithCancel(ctx)
196 a.activeRequests.Set(call.SessionID, cancel)
197
198 defer cancel()
199 defer a.activeRequests.Del(call.SessionID)
200
201 history, files := a.preparePrompt(msgs, call.Attachments...)
202
203 startTime := time.Now()
204 a.eventPromptSent(call.SessionID)
205
206 var currentAssistant *message.Message
207 var shouldSummarize bool
208 result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
209 Prompt: message.PromptWithTextAttachments(call.Prompt, call.Attachments),
210 Files: files,
211 Messages: history,
212 ProviderOptions: call.ProviderOptions,
213 MaxOutputTokens: &call.MaxOutputTokens,
214 TopP: call.TopP,
215 Temperature: call.Temperature,
216 PresencePenalty: call.PresencePenalty,
217 TopK: call.TopK,
218 FrequencyPenalty: call.FrequencyPenalty,
219 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
220 prepared.Messages = options.Messages
221 for i := range prepared.Messages {
222 prepared.Messages[i].ProviderOptions = nil
223 }
224
225 queuedCalls, _ := a.messageQueue.Get(call.SessionID)
226 a.messageQueue.Del(call.SessionID)
227 for _, queued := range queuedCalls {
228 userMessage, createErr := a.createUserMessage(callContext, queued)
229 if createErr != nil {
230 return callContext, prepared, createErr
231 }
232 prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
233 }
234
235 prepared.Messages = a.workaroundProviderMediaLimitations(prepared.Messages)
236
237 lastSystemRoleInx := 0
238 systemMessageUpdated := false
239 for i, msg := range prepared.Messages {
240 // Only add cache control to the last message.
241 if msg.Role == fantasy.MessageRoleSystem {
242 lastSystemRoleInx = i
243 } else if !systemMessageUpdated {
244 prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
245 systemMessageUpdated = true
246 }
247 // Than add cache control to the last 2 messages.
248 if i > len(prepared.Messages)-3 {
249 prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
250 }
251 }
252
253 if promptPrefix := a.promptPrefix(); promptPrefix != "" {
254 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(promptPrefix)}, prepared.Messages...)
255 }
256
257 var assistantMsg message.Message
258 assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
259 Role: message.Assistant,
260 Parts: []message.ContentPart{},
261 Model: a.largeModel.ModelCfg.Model,
262 Provider: a.largeModel.ModelCfg.Provider,
263 })
264 if err != nil {
265 return callContext, prepared, err
266 }
267 callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
268 callContext = context.WithValue(callContext, tools.SupportsImagesContextKey, a.largeModel.CatwalkCfg.SupportsImages)
269 callContext = context.WithValue(callContext, tools.ModelNameContextKey, a.largeModel.CatwalkCfg.Name)
270 currentAssistant = &assistantMsg
271 return callContext, prepared, err
272 },
273 OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
274 currentAssistant.AppendReasoningContent(reasoning.Text)
275 return a.messages.Update(genCtx, *currentAssistant)
276 },
277 OnReasoningDelta: func(id string, text string) error {
278 currentAssistant.AppendReasoningContent(text)
279 return a.messages.Update(genCtx, *currentAssistant)
280 },
281 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
282 // handle anthropic signature
283 if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
284 if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
285 currentAssistant.AppendReasoningSignature(reasoning.Signature)
286 }
287 }
288 if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
289 if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
290 currentAssistant.AppendThoughtSignature(reasoning.Signature, reasoning.ToolID)
291 }
292 }
293 if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
294 if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
295 currentAssistant.SetReasoningResponsesData(reasoning)
296 }
297 }
298 currentAssistant.FinishThinking()
299 return a.messages.Update(genCtx, *currentAssistant)
300 },
301 OnTextDelta: func(id string, text string) error {
302 // Strip leading newline from initial text content. This is is
303 // particularly important in non-interactive mode where leading
304 // newlines are very visible.
305 if len(currentAssistant.Parts) == 0 {
306 text = strings.TrimPrefix(text, "\n")
307 }
308
309 currentAssistant.AppendContent(text)
310 return a.messages.Update(genCtx, *currentAssistant)
311 },
312 OnToolInputStart: func(id string, toolName string) error {
313 toolCall := message.ToolCall{
314 ID: id,
315 Name: toolName,
316 ProviderExecuted: false,
317 Finished: false,
318 }
319 currentAssistant.AddToolCall(toolCall)
320 return a.messages.Update(genCtx, *currentAssistant)
321 },
322 OnRetry: func(err *fantasy.ProviderError, delay time.Duration) {
323 // TODO: implement
324 },
325 OnToolCall: func(tc fantasy.ToolCallContent) error {
326 toolCall := message.ToolCall{
327 ID: tc.ToolCallID,
328 Name: tc.ToolName,
329 Input: tc.Input,
330 ProviderExecuted: false,
331 Finished: true,
332 }
333 currentAssistant.AddToolCall(toolCall)
334 return a.messages.Update(genCtx, *currentAssistant)
335 },
336 OnToolResult: func(result fantasy.ToolResultContent) error {
337 toolResult := a.convertToToolResult(result)
338 _, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
339 Role: message.Tool,
340 Parts: []message.ContentPart{
341 toolResult,
342 },
343 })
344 return createMsgErr
345 },
346 OnStepFinish: func(stepResult fantasy.StepResult) error {
347 finishReason := message.FinishReasonUnknown
348 switch stepResult.FinishReason {
349 case fantasy.FinishReasonLength:
350 finishReason = message.FinishReasonMaxTokens
351 case fantasy.FinishReasonStop:
352 finishReason = message.FinishReasonEndTurn
353 case fantasy.FinishReasonToolCalls:
354 finishReason = message.FinishReasonToolUse
355 }
356 currentAssistant.AddFinish(finishReason, "", "")
357 sessionLock.Lock()
358 updatedSession, getSessionErr := a.sessions.Get(genCtx, call.SessionID)
359 if getSessionErr != nil {
360 sessionLock.Unlock()
361 return getSessionErr
362 }
363 a.updateSessionUsage(a.largeModel, &updatedSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
364 _, sessionErr := a.sessions.Save(genCtx, updatedSession)
365 sessionLock.Unlock()
366 if sessionErr != nil {
367 return sessionErr
368 }
369 return a.messages.Update(genCtx, *currentAssistant)
370 },
371 StopWhen: []fantasy.StopCondition{
372 func(_ []fantasy.StepResult) bool {
373 cw := int64(a.largeModel.CatwalkCfg.ContextWindow)
374 tokens := currentSession.CompletionTokens + currentSession.PromptTokens
375 remaining := cw - tokens
376 var threshold int64
377 if cw > 200_000 {
378 threshold = 20_000
379 } else {
380 threshold = int64(float64(cw) * 0.2)
381 }
382 if (remaining <= threshold) && !a.disableAutoSummarize {
383 shouldSummarize = true
384 return true
385 }
386 return false
387 },
388 },
389 })
390
391 a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))
392
393 if err != nil {
394 isCancelErr := errors.Is(err, context.Canceled)
395 isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
396 if currentAssistant == nil {
397 return result, err
398 }
399 // Ensure we finish thinking on error to close the reasoning state.
400 currentAssistant.FinishThinking()
401 toolCalls := currentAssistant.ToolCalls()
402 // INFO: we use the parent context here because the genCtx has been cancelled.
403 msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
404 if createErr != nil {
405 return nil, createErr
406 }
407 for _, tc := range toolCalls {
408 if !tc.Finished {
409 tc.Finished = true
410 tc.Input = "{}"
411 currentAssistant.AddToolCall(tc)
412 updateErr := a.messages.Update(ctx, *currentAssistant)
413 if updateErr != nil {
414 return nil, updateErr
415 }
416 }
417
418 found := false
419 for _, msg := range msgs {
420 if msg.Role == message.Tool {
421 for _, tr := range msg.ToolResults() {
422 if tr.ToolCallID == tc.ID {
423 found = true
424 break
425 }
426 }
427 }
428 if found {
429 break
430 }
431 }
432 if found {
433 continue
434 }
435 content := "There was an error while executing the tool"
436 if isCancelErr {
437 content = "Tool execution canceled by user"
438 } else if isPermissionErr {
439 content = "User denied permission"
440 }
441 toolResult := message.ToolResult{
442 ToolCallID: tc.ID,
443 Name: tc.Name,
444 Content: content,
445 IsError: true,
446 }
447 _, createErr = a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
448 Role: message.Tool,
449 Parts: []message.ContentPart{
450 toolResult,
451 },
452 })
453 if createErr != nil {
454 return nil, createErr
455 }
456 }
457 var fantasyErr *fantasy.Error
458 var providerErr *fantasy.ProviderError
459 const defaultTitle = "Provider Error"
460 if isCancelErr {
461 currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
462 } else if isPermissionErr {
463 currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "")
464 } else if errors.Is(err, hyper.ErrNoCredits) {
465 url := hyper.BaseURL()
466 link := lipgloss.NewStyle().Hyperlink(url, "id=hyper").Render(url)
467 currentAssistant.AddFinish(message.FinishReasonError, "No credits", "You're out of credits. Add more at "+link)
468 } else if errors.As(err, &providerErr) {
469 if providerErr.Message == "The requested model is not supported." {
470 url := "https://github.com/settings/copilot/features"
471 link := lipgloss.NewStyle().Hyperlink(url, "id=hyper").Render(url)
472 currentAssistant.AddFinish(
473 message.FinishReasonError,
474 "Copilot model not enabled",
475 fmt.Sprintf("%q is not enabled in Copilot. Go to the following page to enable it. Then, wait 5 minutes before trying again. %s", a.largeModel.CatwalkCfg.Name, link),
476 )
477 } else {
478 currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message)
479 }
480 } else if errors.As(err, &fantasyErr) {
481 currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(fantasyErr.Title), defaultTitle), fantasyErr.Message)
482 } else {
483 currentAssistant.AddFinish(message.FinishReasonError, defaultTitle, err.Error())
484 }
485 // Note: we use the parent context here because the genCtx has been
486 // cancelled.
487 updateErr := a.messages.Update(ctx, *currentAssistant)
488 if updateErr != nil {
489 return nil, updateErr
490 }
491 return nil, err
492 }
493 wg.Wait()
494
495 if shouldSummarize {
496 a.activeRequests.Del(call.SessionID)
497 if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
498 return nil, summarizeErr
499 }
500 // If the agent wasn't done...
501 if len(currentAssistant.ToolCalls()) > 0 {
502 existing, ok := a.messageQueue.Get(call.SessionID)
503 if !ok {
504 existing = []SessionAgentCall{}
505 }
506 call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
507 existing = append(existing, call)
508 a.messageQueue.Set(call.SessionID, existing)
509 }
510 }
511
512 // Release active request before processing queued messages.
513 a.activeRequests.Del(call.SessionID)
514 cancel()
515
516 queuedMessages, ok := a.messageQueue.Get(call.SessionID)
517 if !ok || len(queuedMessages) == 0 {
518 return result, err
519 }
520 // There are queued messages restart the loop.
521 firstQueuedMessage := queuedMessages[0]
522 a.messageQueue.Set(call.SessionID, queuedMessages[1:])
523 return a.Run(ctx, firstQueuedMessage)
524}
525
526func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
527 if a.IsSessionBusy(sessionID) {
528 return ErrSessionBusy
529 }
530
531 currentSession, err := a.sessions.Get(ctx, sessionID)
532 if err != nil {
533 return fmt.Errorf("failed to get session: %w", err)
534 }
535 msgs, err := a.getSessionMessages(ctx, currentSession)
536 if err != nil {
537 return err
538 }
539 if len(msgs) == 0 {
540 // Nothing to summarize.
541 return nil
542 }
543
544 aiMsgs, _ := a.preparePrompt(msgs)
545
546 genCtx, cancel := context.WithCancel(ctx)
547 a.activeRequests.Set(sessionID, cancel)
548 defer a.activeRequests.Del(sessionID)
549 defer cancel()
550
551 agent := fantasy.NewAgent(a.largeModel.Model,
552 fantasy.WithSystemPrompt(string(summaryPrompt)),
553 )
554 summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
555 Role: message.Assistant,
556 Model: a.largeModel.Model.Model(),
557 Provider: a.largeModel.Model.Provider(),
558 IsSummaryMessage: true,
559 })
560 if err != nil {
561 return err
562 }
563
564 summaryPromptText := buildSummaryPrompt(currentSession.Todos)
565
566 resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
567 Prompt: summaryPromptText,
568 Messages: aiMsgs,
569 ProviderOptions: opts,
570 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
571 prepared.Messages = options.Messages
572 if a.systemPromptPrefix != "" {
573 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
574 }
575 return callContext, prepared, nil
576 },
577 OnReasoningDelta: func(id string, text string) error {
578 summaryMessage.AppendReasoningContent(text)
579 return a.messages.Update(genCtx, summaryMessage)
580 },
581 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
582 // Handle anthropic signature.
583 if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
584 if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
585 summaryMessage.AppendReasoningSignature(signature.Signature)
586 }
587 }
588 summaryMessage.FinishThinking()
589 return a.messages.Update(genCtx, summaryMessage)
590 },
591 OnTextDelta: func(id, text string) error {
592 summaryMessage.AppendContent(text)
593 return a.messages.Update(genCtx, summaryMessage)
594 },
595 })
596 if err != nil {
597 isCancelErr := errors.Is(err, context.Canceled)
598 if isCancelErr {
599 // User cancelled summarize we need to remove the summary message.
600 deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
601 return deleteErr
602 }
603 return err
604 }
605
606 summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
607 err = a.messages.Update(genCtx, summaryMessage)
608 if err != nil {
609 return err
610 }
611
612 var openrouterCost *float64
613 for _, step := range resp.Steps {
614 stepCost := a.openrouterCost(step.ProviderMetadata)
615 if stepCost != nil {
616 newCost := *stepCost
617 if openrouterCost != nil {
618 newCost += *openrouterCost
619 }
620 openrouterCost = &newCost
621 }
622 }
623
624 a.updateSessionUsage(a.largeModel, ¤tSession, resp.TotalUsage, openrouterCost)
625
626 // Just in case, get just the last usage info.
627 usage := resp.Response.Usage
628 currentSession.SummaryMessageID = summaryMessage.ID
629 currentSession.CompletionTokens = usage.OutputTokens
630 currentSession.PromptTokens = 0
631 _, err = a.sessions.Save(genCtx, currentSession)
632 return err
633}
634
635func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
636 if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
637 return fantasy.ProviderOptions{}
638 }
639 return fantasy.ProviderOptions{
640 anthropic.Name: &anthropic.ProviderCacheControlOptions{
641 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
642 },
643 bedrock.Name: &anthropic.ProviderCacheControlOptions{
644 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
645 },
646 }
647}
648
649func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
650 parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
651 var attachmentParts []message.ContentPart
652 for _, attachment := range call.Attachments {
653 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
654 }
655 parts = append(parts, attachmentParts...)
656 msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
657 Role: message.User,
658 Parts: parts,
659 })
660 if err != nil {
661 return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
662 }
663 return msg, nil
664}
665
666func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
667 var history []fantasy.Message
668 if !a.isSubAgent {
669 history = append(history, fantasy.NewUserMessage(
670 fmt.Sprintf("<system_reminder>%s</system_reminder>",
671 `This is a reminder that your todo list is currently empty. DO NOT mention this to the user explicitly because they are already aware.
672If you are working on tasks that would benefit from a todo list please use the "todos" tool to create one.
673If not, please feel free to ignore. Again do not mention this message to the user.`,
674 ),
675 ))
676 }
677 for _, m := range msgs {
678 if len(m.Parts) == 0 {
679 continue
680 }
681 // Assistant message without content or tool calls (cancelled before it
682 // returned anything).
683 if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
684 continue
685 }
686 history = append(history, m.ToAIMessage()...)
687 }
688
689 var files []fantasy.FilePart
690 for _, attachment := range attachments {
691 if attachment.IsText() {
692 continue
693 }
694 files = append(files, fantasy.FilePart{
695 Filename: attachment.FileName,
696 Data: attachment.Content,
697 MediaType: attachment.MimeType,
698 })
699 }
700
701 return history, files
702}
703
704func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
705 msgs, err := a.messages.List(ctx, session.ID)
706 if err != nil {
707 return nil, fmt.Errorf("failed to list messages: %w", err)
708 }
709
710 if session.SummaryMessageID != "" {
711 summaryMsgInex := -1
712 for i, msg := range msgs {
713 if msg.ID == session.SummaryMessageID {
714 summaryMsgInex = i
715 break
716 }
717 }
718 if summaryMsgInex != -1 {
719 msgs = msgs[summaryMsgInex:]
720 msgs[0].Role = message.User
721 }
722 }
723 return msgs, nil
724}
725
726// generateTitle generates a session titled based on the initial prompt.
727func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, userPrompt string) {
728 if userPrompt == "" {
729 return
730 }
731
732 var maxOutputTokens int64 = 40
733 if a.smallModel.CatwalkCfg.CanReason {
734 maxOutputTokens = a.smallModel.CatwalkCfg.DefaultMaxTokens
735 }
736
737 newAgent := func(m fantasy.LanguageModel, p []byte, tok int64) fantasy.Agent {
738 return fantasy.NewAgent(m,
739 fantasy.WithSystemPrompt(string(p)+"\n /no_think"),
740 fantasy.WithMaxOutputTokens(tok),
741 )
742 }
743
744 streamCall := fantasy.AgentStreamCall{
745 Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", userPrompt),
746 PrepareStep: func(callCtx context.Context, opts fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
747 prepared.Messages = opts.Messages
748 if a.systemPromptPrefix != "" {
749 prepared.Messages = append([]fantasy.Message{
750 fantasy.NewSystemMessage(a.systemPromptPrefix),
751 }, prepared.Messages...)
752 }
753 return callCtx, prepared, nil
754 },
755 }
756
757 // Use the small model to generate the title.
758 model := &a.smallModel
759 agent := newAgent(model.Model, titlePrompt, maxOutputTokens)
760 resp, err := agent.Stream(ctx, streamCall)
761 if err == nil {
762 // We successfully generated a title with the small model.
763 slog.Info("generated title with small model")
764 } else {
765 // It didn't work. Let's try with the big model.
766 slog.Error("error generating title with small model; trying big model", "err", err)
767 model = &a.largeModel
768 agent = newAgent(model.Model, titlePrompt, maxOutputTokens)
769 resp, err = agent.Stream(ctx, streamCall)
770 if err == nil {
771 slog.Info("generated title with large model")
772 } else {
773 // Welp, the large model didn't work either. Use the default
774 // session name and return.
775 slog.Error("error generating title with large model", "err", err)
776 saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, defaultSessionName, 0, 0, 0)
777 if saveErr != nil {
778 slog.Error("failed to save session title and usage", "error", saveErr)
779 }
780 return
781 }
782 }
783
784 if resp == nil {
785 // Actually, we didn't get a response so we can't. Use the default
786 // session name and return.
787 slog.Error("response is nil; can't generate title")
788 saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, defaultSessionName, 0, 0, 0)
789 if saveErr != nil {
790 slog.Error("failed to save session title and usage", "error", saveErr)
791 }
792 return
793 }
794
795 // Clean up title.
796 var title string
797 title = strings.ReplaceAll(resp.Response.Content.Text(), "\n", " ")
798 slog.Info("generated title", "title", title)
799
800 // Remove thinking tags if present.
801 title = thinkTagRegex.ReplaceAllString(title, "")
802
803 title = strings.TrimSpace(title)
804 if title == "" {
805 slog.Warn("empty title; using fallback")
806 title = defaultSessionName
807 }
808
809 // Calculate usage and cost.
810 var openrouterCost *float64
811 for _, step := range resp.Steps {
812 stepCost := a.openrouterCost(step.ProviderMetadata)
813 if stepCost != nil {
814 newCost := *stepCost
815 if openrouterCost != nil {
816 newCost += *openrouterCost
817 }
818 openrouterCost = &newCost
819 }
820 }
821
822 modelConfig := model.CatwalkCfg
823 cost := modelConfig.CostPer1MInCached/1e6*float64(resp.TotalUsage.CacheCreationTokens) +
824 modelConfig.CostPer1MOutCached/1e6*float64(resp.TotalUsage.CacheReadTokens) +
825 modelConfig.CostPer1MIn/1e6*float64(resp.TotalUsage.InputTokens) +
826 modelConfig.CostPer1MOut/1e6*float64(resp.TotalUsage.OutputTokens)
827
828 // Use override cost if available (e.g., from OpenRouter).
829 if openrouterCost != nil {
830 cost = *openrouterCost
831 }
832
833 promptTokens := resp.TotalUsage.InputTokens + resp.TotalUsage.CacheCreationTokens
834 completionTokens := resp.TotalUsage.OutputTokens + resp.TotalUsage.CacheReadTokens
835
836 // Atomically update only title and usage fields to avoid overriding other
837 // concurrent session updates.
838 saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, title, promptTokens, completionTokens, cost)
839 if saveErr != nil {
840 slog.Error("failed to save session title and usage", "error", saveErr)
841 return
842 }
843}
844
845func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
846 openrouterMetadata, ok := metadata[openrouter.Name]
847 if !ok {
848 return nil
849 }
850
851 opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
852 if !ok {
853 return nil
854 }
855 return &opts.Usage.Cost
856}
857
858func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
859 modelConfig := model.CatwalkCfg
860 cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
861 modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
862 modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
863 modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
864
865 a.eventTokensUsed(session.ID, model, usage, cost)
866
867 if overrideCost != nil {
868 session.Cost += *overrideCost
869 } else {
870 session.Cost += cost
871 }
872
873 session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
874 session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
875}
876
877func (a *sessionAgent) Cancel(sessionID string) {
878 // Cancel regular requests.
879 if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
880 slog.Info("Request cancellation initiated", "session_id", sessionID)
881 cancel()
882 }
883
884 // Also check for summarize requests.
885 if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
886 slog.Info("Summarize cancellation initiated", "session_id", sessionID)
887 cancel()
888 }
889
890 if a.QueuedPrompts(sessionID) > 0 {
891 slog.Info("Clearing queued prompts", "session_id", sessionID)
892 a.messageQueue.Del(sessionID)
893 }
894}
895
896func (a *sessionAgent) ClearQueue(sessionID string) {
897 if a.QueuedPrompts(sessionID) > 0 {
898 slog.Info("Clearing queued prompts", "session_id", sessionID)
899 a.messageQueue.Del(sessionID)
900 }
901}
902
903func (a *sessionAgent) CancelAll() {
904 if !a.IsBusy() {
905 return
906 }
907 for key := range a.activeRequests.Seq2() {
908 a.Cancel(key) // key is sessionID
909 }
910
911 timeout := time.After(5 * time.Second)
912 for a.IsBusy() {
913 select {
914 case <-timeout:
915 return
916 default:
917 time.Sleep(200 * time.Millisecond)
918 }
919 }
920}
921
922func (a *sessionAgent) IsBusy() bool {
923 var busy bool
924 for cancelFunc := range a.activeRequests.Seq() {
925 if cancelFunc != nil {
926 busy = true
927 break
928 }
929 }
930 return busy
931}
932
933func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
934 _, busy := a.activeRequests.Get(sessionID)
935 return busy
936}
937
938func (a *sessionAgent) QueuedPrompts(sessionID string) int {
939 l, ok := a.messageQueue.Get(sessionID)
940 if !ok {
941 return 0
942 }
943 return len(l)
944}
945
946func (a *sessionAgent) QueuedPromptsList(sessionID string) []string {
947 l, ok := a.messageQueue.Get(sessionID)
948 if !ok {
949 return nil
950 }
951 prompts := make([]string, len(l))
952 for i, call := range l {
953 prompts[i] = call.Prompt
954 }
955 return prompts
956}
957
958func (a *sessionAgent) SetModels(large Model, small Model) {
959 a.largeModel = large
960 a.smallModel = small
961}
962
963func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
964 a.tools = tools
965}
966
967func (a *sessionAgent) Model() Model {
968 return a.largeModel
969}
970
971func (a *sessionAgent) promptPrefix() string {
972 return a.systemPromptPrefix
973}
974
975// convertToToolResult converts a fantasy tool result to a message tool result.
976func (a *sessionAgent) convertToToolResult(result fantasy.ToolResultContent) message.ToolResult {
977 baseResult := message.ToolResult{
978 ToolCallID: result.ToolCallID,
979 Name: result.ToolName,
980 Metadata: result.ClientMetadata,
981 }
982
983 switch result.Result.GetType() {
984 case fantasy.ToolResultContentTypeText:
985 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result); ok {
986 baseResult.Content = r.Text
987 }
988 case fantasy.ToolResultContentTypeError:
989 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result); ok {
990 baseResult.Content = r.Error.Error()
991 baseResult.IsError = true
992 }
993 case fantasy.ToolResultContentTypeMedia:
994 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](result.Result); ok {
995 content := r.Text
996 if content == "" {
997 content = fmt.Sprintf("Loaded %s content", r.MediaType)
998 }
999 baseResult.Content = content
1000 baseResult.Data = r.Data
1001 baseResult.MIMEType = r.MediaType
1002 }
1003 }
1004
1005 return baseResult
1006}
1007
1008// workaroundProviderMediaLimitations converts media content in tool results to
1009// user messages for providers that don't natively support images in tool results.
1010//
1011// Problem: OpenAI, Google, OpenRouter, and other OpenAI-compatible providers
1012// don't support sending images/media in tool result messages - they only accept
1013// text in tool results. However, they DO support images in user messages.
1014//
1015// If we send media in tool results to these providers, the API returns an error.
1016//
1017// Solution: For these providers, we:
1018// 1. Replace the media in the tool result with a text placeholder
1019// 2. Inject a user message immediately after with the image as a file attachment
1020// 3. This maintains the tool execution flow while working around API limitations
1021//
1022// Anthropic and Bedrock support images natively in tool results, so we skip
1023// this workaround for them.
1024//
1025// Example transformation:
1026//
1027// BEFORE: [tool result: image data]
1028// AFTER: [tool result: "Image loaded - see attached"], [user: image attachment]
1029func (a *sessionAgent) workaroundProviderMediaLimitations(messages []fantasy.Message) []fantasy.Message {
1030 providerSupportsMedia := a.largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderAnthropic) ||
1031 a.largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderBedrock)
1032
1033 if providerSupportsMedia {
1034 return messages
1035 }
1036
1037 convertedMessages := make([]fantasy.Message, 0, len(messages))
1038
1039 for _, msg := range messages {
1040 if msg.Role != fantasy.MessageRoleTool {
1041 convertedMessages = append(convertedMessages, msg)
1042 continue
1043 }
1044
1045 textParts := make([]fantasy.MessagePart, 0, len(msg.Content))
1046 var mediaFiles []fantasy.FilePart
1047
1048 for _, part := range msg.Content {
1049 toolResult, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
1050 if !ok {
1051 textParts = append(textParts, part)
1052 continue
1053 }
1054
1055 if media, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](toolResult.Output); ok {
1056 decoded, err := base64.StdEncoding.DecodeString(media.Data)
1057 if err != nil {
1058 slog.Warn("failed to decode media data", "error", err)
1059 textParts = append(textParts, part)
1060 continue
1061 }
1062
1063 mediaFiles = append(mediaFiles, fantasy.FilePart{
1064 Data: decoded,
1065 MediaType: media.MediaType,
1066 Filename: fmt.Sprintf("tool-result-%s", toolResult.ToolCallID),
1067 })
1068
1069 textParts = append(textParts, fantasy.ToolResultPart{
1070 ToolCallID: toolResult.ToolCallID,
1071 Output: fantasy.ToolResultOutputContentText{
1072 Text: "[Image/media content loaded - see attached file]",
1073 },
1074 ProviderOptions: toolResult.ProviderOptions,
1075 })
1076 } else {
1077 textParts = append(textParts, part)
1078 }
1079 }
1080
1081 convertedMessages = append(convertedMessages, fantasy.Message{
1082 Role: fantasy.MessageRoleTool,
1083 Content: textParts,
1084 })
1085
1086 if len(mediaFiles) > 0 {
1087 convertedMessages = append(convertedMessages, fantasy.NewUserMessage(
1088 "Here is the media content from the tool result:",
1089 mediaFiles...,
1090 ))
1091 }
1092 }
1093
1094 return convertedMessages
1095}
1096
1097// buildSummaryPrompt constructs the prompt text for session summarization.
1098func buildSummaryPrompt(todos []session.Todo) string {
1099 var sb strings.Builder
1100 sb.WriteString("Provide a detailed summary of our conversation above.")
1101 if len(todos) > 0 {
1102 sb.WriteString("\n\n## Current Todo List\n\n")
1103 for _, t := range todos {
1104 fmt.Fprintf(&sb, "- [%s] %s\n", t.Status, t.Content)
1105 }
1106 sb.WriteString("\nInclude these tasks and their statuses in your summary. ")
1107 sb.WriteString("Instruct the resuming assistant to use the `todos` tool to continue tracking progress on these tasks.")
1108 }
1109 return sb.String()
1110}