1// Package agent is the core orchestration layer for Crush AI agents.
2//
3// It provides session-based AI agent functionality for managing
4// conversations, tool execution, and message handling. It coordinates
5// interactions between language models, messages, sessions, and tools while
6// handling features like automatic summarization, queuing, and token
7// management.
8package agent
9
10import (
11 "cmp"
12 "context"
13 _ "embed"
14 "encoding/base64"
15 "errors"
16 "fmt"
17 "log/slog"
18 "os"
19 "regexp"
20 "strconv"
21 "strings"
22 "sync"
23 "time"
24
25 "charm.land/fantasy"
26 "charm.land/fantasy/providers/anthropic"
27 "charm.land/fantasy/providers/bedrock"
28 "charm.land/fantasy/providers/google"
29 "charm.land/fantasy/providers/openai"
30 "charm.land/fantasy/providers/openrouter"
31 "charm.land/lipgloss/v2"
32 "github.com/charmbracelet/catwalk/pkg/catwalk"
33 "github.com/charmbracelet/crush/internal/agent/hyper"
34 "github.com/charmbracelet/crush/internal/agent/tools"
35 "github.com/charmbracelet/crush/internal/config"
36 "github.com/charmbracelet/crush/internal/csync"
37 "github.com/charmbracelet/crush/internal/message"
38 "github.com/charmbracelet/crush/internal/permission"
39 "github.com/charmbracelet/crush/internal/session"
40 "github.com/charmbracelet/crush/internal/stringext"
41)
42
43const defaultSessionName = "Untitled Session"
44
45//go:embed templates/title.md
46var titlePrompt []byte
47
48//go:embed templates/summary.md
49var summaryPrompt []byte
50
51// Used to remove <think> tags from generated titles.
52var thinkTagRegex = regexp.MustCompile(`<think>.*?</think>`)
53
54type SessionAgentCall struct {
55 SessionID string
56 Prompt string
57 ProviderOptions fantasy.ProviderOptions
58 Attachments []message.Attachment
59 MaxOutputTokens int64
60 Temperature *float64
61 TopP *float64
62 TopK *int64
63 FrequencyPenalty *float64
64 PresencePenalty *float64
65}
66
67type SessionAgent interface {
68 Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
69 SetModels(large Model, small Model)
70 SetTools(tools []fantasy.AgentTool)
71 SetSystemPrompt(systemPrompt string)
72 Cancel(sessionID string)
73 CancelAll()
74 IsSessionBusy(sessionID string) bool
75 IsBusy() bool
76 QueuedPrompts(sessionID string) int
77 QueuedPromptsList(sessionID string) []string
78 ClearQueue(sessionID string)
79 Summarize(context.Context, string, fantasy.ProviderOptions) error
80 Model() Model
81}
82
83type Model struct {
84 Model fantasy.LanguageModel
85 CatwalkCfg catwalk.Model
86 ModelCfg config.SelectedModel
87}
88
89type sessionAgent struct {
90 largeModel Model
91 smallModel Model
92 systemPromptPrefix string
93 systemPrompt string
94 isSubAgent bool
95 tools []fantasy.AgentTool
96 sessions session.Service
97 messages message.Service
98 disableAutoSummarize bool
99 isYolo bool
100
101 messageQueue *csync.Map[string, []SessionAgentCall]
102 activeRequests *csync.Map[string, context.CancelFunc]
103}
104
105type SessionAgentOptions struct {
106 LargeModel Model
107 SmallModel Model
108 SystemPromptPrefix string
109 SystemPrompt string
110 IsSubAgent bool
111 DisableAutoSummarize bool
112 IsYolo bool
113 Sessions session.Service
114 Messages message.Service
115 Tools []fantasy.AgentTool
116}
117
118func NewSessionAgent(
119 opts SessionAgentOptions,
120) SessionAgent {
121 return &sessionAgent{
122 largeModel: opts.LargeModel,
123 smallModel: opts.SmallModel,
124 systemPromptPrefix: opts.SystemPromptPrefix,
125 systemPrompt: opts.SystemPrompt,
126 isSubAgent: opts.IsSubAgent,
127 sessions: opts.Sessions,
128 messages: opts.Messages,
129 disableAutoSummarize: opts.DisableAutoSummarize,
130 tools: opts.Tools,
131 isYolo: opts.IsYolo,
132 messageQueue: csync.NewMap[string, []SessionAgentCall](),
133 activeRequests: csync.NewMap[string, context.CancelFunc](),
134 }
135}
136
137func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
138 if call.Prompt == "" && !message.ContainsTextAttachment(call.Attachments) {
139 return nil, ErrEmptyPrompt
140 }
141 if call.SessionID == "" {
142 return nil, ErrSessionMissing
143 }
144
145 // Queue the message if busy
146 if a.IsSessionBusy(call.SessionID) {
147 existing, ok := a.messageQueue.Get(call.SessionID)
148 if !ok {
149 existing = []SessionAgentCall{}
150 }
151 existing = append(existing, call)
152 a.messageQueue.Set(call.SessionID, existing)
153 return nil, nil
154 }
155
156 if len(a.tools) > 0 {
157 // Add Anthropic caching to the last tool.
158 a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
159 }
160
161 agent := fantasy.NewAgent(
162 a.largeModel.Model,
163 fantasy.WithSystemPrompt(a.systemPrompt),
164 fantasy.WithTools(a.tools...),
165 )
166
167 sessionLock := sync.Mutex{}
168 currentSession, err := a.sessions.Get(ctx, call.SessionID)
169 if err != nil {
170 return nil, fmt.Errorf("failed to get session: %w", err)
171 }
172
173 msgs, err := a.getSessionMessages(ctx, currentSession)
174 if err != nil {
175 return nil, fmt.Errorf("failed to get session messages: %w", err)
176 }
177
178 var wg sync.WaitGroup
179 // Generate title if first message.
180 if len(msgs) == 0 {
181 titleCtx := ctx // Copy to avoid race with ctx reassignment below.
182 wg.Go(func() {
183 a.generateTitle(titleCtx, call.SessionID, call.Prompt)
184 })
185 }
186
187 // Add the user message to the session.
188 _, err = a.createUserMessage(ctx, call)
189 if err != nil {
190 return nil, err
191 }
192
193 // Add the session to the context.
194 ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)
195
196 genCtx, cancel := context.WithCancel(ctx)
197 a.activeRequests.Set(call.SessionID, cancel)
198
199 defer cancel()
200 defer a.activeRequests.Del(call.SessionID)
201
202 history, files := a.preparePrompt(msgs, call.Attachments...)
203
204 startTime := time.Now()
205 a.eventPromptSent(call.SessionID)
206
207 var currentAssistant *message.Message
208 var shouldSummarize bool
209 result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
210 Prompt: message.PromptWithTextAttachments(call.Prompt, call.Attachments),
211 Files: files,
212 Messages: history,
213 ProviderOptions: call.ProviderOptions,
214 MaxOutputTokens: &call.MaxOutputTokens,
215 TopP: call.TopP,
216 Temperature: call.Temperature,
217 PresencePenalty: call.PresencePenalty,
218 TopK: call.TopK,
219 FrequencyPenalty: call.FrequencyPenalty,
220 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
221 prepared.Messages = options.Messages
222 for i := range prepared.Messages {
223 prepared.Messages[i].ProviderOptions = nil
224 }
225
226 queuedCalls, _ := a.messageQueue.Get(call.SessionID)
227 a.messageQueue.Del(call.SessionID)
228 for _, queued := range queuedCalls {
229 userMessage, createErr := a.createUserMessage(callContext, queued)
230 if createErr != nil {
231 return callContext, prepared, createErr
232 }
233 prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
234 }
235
236 prepared.Messages = a.workaroundProviderMediaLimitations(prepared.Messages)
237
238 lastSystemRoleInx := 0
239 systemMessageUpdated := false
240 for i, msg := range prepared.Messages {
241 // Only add cache control to the last message.
242 if msg.Role == fantasy.MessageRoleSystem {
243 lastSystemRoleInx = i
244 } else if !systemMessageUpdated {
245 prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
246 systemMessageUpdated = true
247 }
248 // Than add cache control to the last 2 messages.
249 if i > len(prepared.Messages)-3 {
250 prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
251 }
252 }
253
254 if promptPrefix := a.promptPrefix(); promptPrefix != "" {
255 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(promptPrefix)}, prepared.Messages...)
256 }
257
258 var assistantMsg message.Message
259 assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
260 Role: message.Assistant,
261 Parts: []message.ContentPart{},
262 Model: a.largeModel.ModelCfg.Model,
263 Provider: a.largeModel.ModelCfg.Provider,
264 })
265 if err != nil {
266 return callContext, prepared, err
267 }
268 callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
269 callContext = context.WithValue(callContext, tools.SupportsImagesContextKey, a.largeModel.CatwalkCfg.SupportsImages)
270 callContext = context.WithValue(callContext, tools.ModelNameContextKey, a.largeModel.CatwalkCfg.Name)
271 currentAssistant = &assistantMsg
272 return callContext, prepared, err
273 },
274 OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
275 currentAssistant.AppendReasoningContent(reasoning.Text)
276 return a.messages.Update(genCtx, *currentAssistant)
277 },
278 OnReasoningDelta: func(id string, text string) error {
279 currentAssistant.AppendReasoningContent(text)
280 return a.messages.Update(genCtx, *currentAssistant)
281 },
282 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
283 // handle anthropic signature
284 if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
285 if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
286 currentAssistant.AppendReasoningSignature(reasoning.Signature)
287 }
288 }
289 if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
290 if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
291 currentAssistant.AppendThoughtSignature(reasoning.Signature, reasoning.ToolID)
292 }
293 }
294 if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
295 if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
296 currentAssistant.SetReasoningResponsesData(reasoning)
297 }
298 }
299 currentAssistant.FinishThinking()
300 return a.messages.Update(genCtx, *currentAssistant)
301 },
302 OnTextDelta: func(id string, text string) error {
303 // Strip leading newline from initial text content. This is is
304 // particularly important in non-interactive mode where leading
305 // newlines are very visible.
306 if len(currentAssistant.Parts) == 0 {
307 text = strings.TrimPrefix(text, "\n")
308 }
309
310 currentAssistant.AppendContent(text)
311 return a.messages.Update(genCtx, *currentAssistant)
312 },
313 OnToolInputStart: func(id string, toolName string) error {
314 toolCall := message.ToolCall{
315 ID: id,
316 Name: toolName,
317 ProviderExecuted: false,
318 Finished: false,
319 }
320 currentAssistant.AddToolCall(toolCall)
321 return a.messages.Update(genCtx, *currentAssistant)
322 },
323 OnRetry: func(err *fantasy.ProviderError, delay time.Duration) {
324 // TODO: implement
325 },
326 OnToolCall: func(tc fantasy.ToolCallContent) error {
327 toolCall := message.ToolCall{
328 ID: tc.ToolCallID,
329 Name: tc.ToolName,
330 Input: tc.Input,
331 ProviderExecuted: false,
332 Finished: true,
333 }
334 currentAssistant.AddToolCall(toolCall)
335 return a.messages.Update(genCtx, *currentAssistant)
336 },
337 OnToolResult: func(result fantasy.ToolResultContent) error {
338 toolResult := a.convertToToolResult(result)
339 _, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
340 Role: message.Tool,
341 Parts: []message.ContentPart{
342 toolResult,
343 },
344 })
345 return createMsgErr
346 },
347 OnStepFinish: func(stepResult fantasy.StepResult) error {
348 finishReason := message.FinishReasonUnknown
349 switch stepResult.FinishReason {
350 case fantasy.FinishReasonLength:
351 finishReason = message.FinishReasonMaxTokens
352 case fantasy.FinishReasonStop:
353 finishReason = message.FinishReasonEndTurn
354 case fantasy.FinishReasonToolCalls:
355 finishReason = message.FinishReasonToolUse
356 }
357 currentAssistant.AddFinish(finishReason, "", "")
358 sessionLock.Lock()
359 updatedSession, getSessionErr := a.sessions.Get(genCtx, call.SessionID)
360 if getSessionErr != nil {
361 sessionLock.Unlock()
362 return getSessionErr
363 }
364 a.updateSessionUsage(a.largeModel, &updatedSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
365 _, sessionErr := a.sessions.Save(genCtx, updatedSession)
366 sessionLock.Unlock()
367 if sessionErr != nil {
368 return sessionErr
369 }
370 return a.messages.Update(genCtx, *currentAssistant)
371 },
372 StopWhen: []fantasy.StopCondition{
373 func(_ []fantasy.StepResult) bool {
374 cw := int64(a.largeModel.CatwalkCfg.ContextWindow)
375 tokens := currentSession.CompletionTokens + currentSession.PromptTokens
376 remaining := cw - tokens
377 var threshold int64
378 if cw > 200_000 {
379 threshold = 20_000
380 } else {
381 threshold = int64(float64(cw) * 0.2)
382 }
383 if (remaining <= threshold) && !a.disableAutoSummarize {
384 shouldSummarize = true
385 return true
386 }
387 return false
388 },
389 },
390 })
391
392 a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))
393
394 if err != nil {
395 isCancelErr := errors.Is(err, context.Canceled)
396 isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
397 if currentAssistant == nil {
398 return result, err
399 }
400 // Ensure we finish thinking on error to close the reasoning state.
401 currentAssistant.FinishThinking()
402 toolCalls := currentAssistant.ToolCalls()
403 // INFO: we use the parent context here because the genCtx has been cancelled.
404 msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
405 if createErr != nil {
406 return nil, createErr
407 }
408 for _, tc := range toolCalls {
409 if !tc.Finished {
410 tc.Finished = true
411 tc.Input = "{}"
412 currentAssistant.AddToolCall(tc)
413 updateErr := a.messages.Update(ctx, *currentAssistant)
414 if updateErr != nil {
415 return nil, updateErr
416 }
417 }
418
419 found := false
420 for _, msg := range msgs {
421 if msg.Role == message.Tool {
422 for _, tr := range msg.ToolResults() {
423 if tr.ToolCallID == tc.ID {
424 found = true
425 break
426 }
427 }
428 }
429 if found {
430 break
431 }
432 }
433 if found {
434 continue
435 }
436 content := "There was an error while executing the tool"
437 if isCancelErr {
438 content = "Tool execution canceled by user"
439 } else if isPermissionErr {
440 content = "User denied permission"
441 }
442 toolResult := message.ToolResult{
443 ToolCallID: tc.ID,
444 Name: tc.Name,
445 Content: content,
446 IsError: true,
447 }
448 _, createErr = a.messages.Create(ctx, currentAssistant.SessionID, message.CreateMessageParams{
449 Role: message.Tool,
450 Parts: []message.ContentPart{
451 toolResult,
452 },
453 })
454 if createErr != nil {
455 return nil, createErr
456 }
457 }
458 var fantasyErr *fantasy.Error
459 var providerErr *fantasy.ProviderError
460 const defaultTitle = "Provider Error"
461 if isCancelErr {
462 currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
463 } else if isPermissionErr {
464 currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "")
465 } else if errors.Is(err, hyper.ErrNoCredits) {
466 url := hyper.BaseURL()
467 link := lipgloss.NewStyle().Hyperlink(url, "id=hyper").Render(url)
468 currentAssistant.AddFinish(message.FinishReasonError, "No credits", "You're out of credits. Add more at "+link)
469 } else if errors.As(err, &providerErr) {
470 if providerErr.Message == "The requested model is not supported." {
471 url := "https://github.com/settings/copilot/features"
472 link := lipgloss.NewStyle().Hyperlink(url, "id=hyper").Render(url)
473 currentAssistant.AddFinish(
474 message.FinishReasonError,
475 "Copilot model not enabled",
476 fmt.Sprintf("%q is not enabled in Copilot. Go to the following page to enable it. Then, wait 5 minutes before trying again. %s", a.largeModel.CatwalkCfg.Name, link),
477 )
478 } else {
479 currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message)
480 }
481 } else if errors.As(err, &fantasyErr) {
482 currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(fantasyErr.Title), defaultTitle), fantasyErr.Message)
483 } else {
484 currentAssistant.AddFinish(message.FinishReasonError, defaultTitle, err.Error())
485 }
486 // Note: we use the parent context here because the genCtx has been
487 // cancelled.
488 updateErr := a.messages.Update(ctx, *currentAssistant)
489 if updateErr != nil {
490 return nil, updateErr
491 }
492 return nil, err
493 }
494 wg.Wait()
495
496 if shouldSummarize {
497 a.activeRequests.Del(call.SessionID)
498 if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
499 return nil, summarizeErr
500 }
501 // If the agent wasn't done...
502 if len(currentAssistant.ToolCalls()) > 0 {
503 existing, ok := a.messageQueue.Get(call.SessionID)
504 if !ok {
505 existing = []SessionAgentCall{}
506 }
507 call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
508 existing = append(existing, call)
509 a.messageQueue.Set(call.SessionID, existing)
510 }
511 }
512
513 // Release active request before processing queued messages.
514 a.activeRequests.Del(call.SessionID)
515 cancel()
516
517 queuedMessages, ok := a.messageQueue.Get(call.SessionID)
518 if !ok || len(queuedMessages) == 0 {
519 return result, err
520 }
521 // There are queued messages restart the loop.
522 firstQueuedMessage := queuedMessages[0]
523 a.messageQueue.Set(call.SessionID, queuedMessages[1:])
524 return a.Run(ctx, firstQueuedMessage)
525}
526
527func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
528 if a.IsSessionBusy(sessionID) {
529 return ErrSessionBusy
530 }
531
532 currentSession, err := a.sessions.Get(ctx, sessionID)
533 if err != nil {
534 return fmt.Errorf("failed to get session: %w", err)
535 }
536 msgs, err := a.getSessionMessages(ctx, currentSession)
537 if err != nil {
538 return err
539 }
540 if len(msgs) == 0 {
541 // Nothing to summarize.
542 return nil
543 }
544
545 aiMsgs, _ := a.preparePrompt(msgs)
546
547 genCtx, cancel := context.WithCancel(ctx)
548 a.activeRequests.Set(sessionID, cancel)
549 defer a.activeRequests.Del(sessionID)
550 defer cancel()
551
552 agent := fantasy.NewAgent(a.largeModel.Model,
553 fantasy.WithSystemPrompt(string(summaryPrompt)),
554 )
555 summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
556 Role: message.Assistant,
557 Model: a.largeModel.Model.Model(),
558 Provider: a.largeModel.Model.Provider(),
559 IsSummaryMessage: true,
560 })
561 if err != nil {
562 return err
563 }
564
565 summaryPromptText := buildSummaryPrompt(currentSession.Todos)
566
567 resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
568 Prompt: summaryPromptText,
569 Messages: aiMsgs,
570 ProviderOptions: opts,
571 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
572 prepared.Messages = options.Messages
573 if a.systemPromptPrefix != "" {
574 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
575 }
576 return callContext, prepared, nil
577 },
578 OnReasoningDelta: func(id string, text string) error {
579 summaryMessage.AppendReasoningContent(text)
580 return a.messages.Update(genCtx, summaryMessage)
581 },
582 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
583 // Handle anthropic signature.
584 if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
585 if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
586 summaryMessage.AppendReasoningSignature(signature.Signature)
587 }
588 }
589 summaryMessage.FinishThinking()
590 return a.messages.Update(genCtx, summaryMessage)
591 },
592 OnTextDelta: func(id, text string) error {
593 summaryMessage.AppendContent(text)
594 return a.messages.Update(genCtx, summaryMessage)
595 },
596 })
597 if err != nil {
598 isCancelErr := errors.Is(err, context.Canceled)
599 if isCancelErr {
600 // User cancelled summarize we need to remove the summary message.
601 deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
602 return deleteErr
603 }
604 return err
605 }
606
607 summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
608 err = a.messages.Update(genCtx, summaryMessage)
609 if err != nil {
610 return err
611 }
612
613 var openrouterCost *float64
614 for _, step := range resp.Steps {
615 stepCost := a.openrouterCost(step.ProviderMetadata)
616 if stepCost != nil {
617 newCost := *stepCost
618 if openrouterCost != nil {
619 newCost += *openrouterCost
620 }
621 openrouterCost = &newCost
622 }
623 }
624
625 a.updateSessionUsage(a.largeModel, ¤tSession, resp.TotalUsage, openrouterCost)
626
627 // Just in case, get just the last usage info.
628 usage := resp.Response.Usage
629 currentSession.SummaryMessageID = summaryMessage.ID
630 currentSession.CompletionTokens = usage.OutputTokens
631 currentSession.PromptTokens = 0
632 _, err = a.sessions.Save(genCtx, currentSession)
633 return err
634}
635
636func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
637 if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
638 return fantasy.ProviderOptions{}
639 }
640 return fantasy.ProviderOptions{
641 anthropic.Name: &anthropic.ProviderCacheControlOptions{
642 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
643 },
644 bedrock.Name: &anthropic.ProviderCacheControlOptions{
645 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
646 },
647 }
648}
649
650func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
651 parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
652 var attachmentParts []message.ContentPart
653 for _, attachment := range call.Attachments {
654 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
655 }
656 parts = append(parts, attachmentParts...)
657 msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
658 Role: message.User,
659 Parts: parts,
660 })
661 if err != nil {
662 return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
663 }
664 return msg, nil
665}
666
667func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
668 var history []fantasy.Message
669 if !a.isSubAgent {
670 history = append(history, fantasy.NewUserMessage(
671 fmt.Sprintf("<system_reminder>%s</system_reminder>",
672 `This is a reminder that your todo list is currently empty. DO NOT mention this to the user explicitly because they are already aware.
673If you are working on tasks that would benefit from a todo list please use the "todos" tool to create one.
674If not, please feel free to ignore. Again do not mention this message to the user.`,
675 ),
676 ))
677 }
678 for _, m := range msgs {
679 if len(m.Parts) == 0 {
680 continue
681 }
682 // Assistant message without content or tool calls (cancelled before it
683 // returned anything).
684 if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
685 continue
686 }
687 history = append(history, m.ToAIMessage()...)
688 }
689
690 var files []fantasy.FilePart
691 for _, attachment := range attachments {
692 if attachment.IsText() {
693 continue
694 }
695 files = append(files, fantasy.FilePart{
696 Filename: attachment.FileName,
697 Data: attachment.Content,
698 MediaType: attachment.MimeType,
699 })
700 }
701
702 return history, files
703}
704
705func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
706 msgs, err := a.messages.List(ctx, session.ID)
707 if err != nil {
708 return nil, fmt.Errorf("failed to list messages: %w", err)
709 }
710
711 if session.SummaryMessageID != "" {
712 summaryMsgInex := -1
713 for i, msg := range msgs {
714 if msg.ID == session.SummaryMessageID {
715 summaryMsgInex = i
716 break
717 }
718 }
719 if summaryMsgInex != -1 {
720 msgs = msgs[summaryMsgInex:]
721 msgs[0].Role = message.User
722 }
723 }
724 return msgs, nil
725}
726
727// generateTitle generates a session titled based on the initial prompt.
728func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, userPrompt string) {
729 if userPrompt == "" {
730 return
731 }
732
733 var maxOutputTokens int64 = 40
734 if a.smallModel.CatwalkCfg.CanReason {
735 maxOutputTokens = a.smallModel.CatwalkCfg.DefaultMaxTokens
736 }
737
738 newAgent := func(m fantasy.LanguageModel, p []byte, tok int64) fantasy.Agent {
739 return fantasy.NewAgent(m,
740 fantasy.WithSystemPrompt(string(p)+"\n /no_think"),
741 fantasy.WithMaxOutputTokens(tok),
742 )
743 }
744
745 streamCall := fantasy.AgentStreamCall{
746 Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", userPrompt),
747 PrepareStep: func(callCtx context.Context, opts fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
748 prepared.Messages = opts.Messages
749 if a.systemPromptPrefix != "" {
750 prepared.Messages = append([]fantasy.Message{
751 fantasy.NewSystemMessage(a.systemPromptPrefix),
752 }, prepared.Messages...)
753 }
754 return callCtx, prepared, nil
755 },
756 }
757
758 // Use the small model to generate the title.
759 model := &a.smallModel
760 agent := newAgent(model.Model, titlePrompt, maxOutputTokens)
761 resp, err := agent.Stream(ctx, streamCall)
762 if err == nil {
763 // We successfully generated a title with the small model.
764 slog.Info("generated title with small model")
765 } else {
766 // It didn't work. Let's try with the big model.
767 slog.Error("error generating title with small model; trying big model", "err", err)
768 model = &a.largeModel
769 agent = newAgent(model.Model, titlePrompt, maxOutputTokens)
770 resp, err = agent.Stream(ctx, streamCall)
771 if err == nil {
772 slog.Info("generated title with large model")
773 } else {
774 // Welp, the large model didn't work either. Use the default
775 // session name and return.
776 slog.Error("error generating title with large model", "err", err)
777 saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, defaultSessionName, 0, 0, 0)
778 if saveErr != nil {
779 slog.Error("failed to save session title and usage", "error", saveErr)
780 }
781 return
782 }
783 }
784
785 if resp == nil {
786 // Actually, we didn't get a response so we can't. Use the default
787 // session name and return.
788 slog.Error("response is nil; can't generate title")
789 saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, defaultSessionName, 0, 0, 0)
790 if saveErr != nil {
791 slog.Error("failed to save session title and usage", "error", saveErr)
792 }
793 return
794 }
795
796 // Clean up title.
797 var title string
798 title = strings.ReplaceAll(resp.Response.Content.Text(), "\n", " ")
799 slog.Info("generated title", "title", title)
800
801 // Remove thinking tags if present.
802 title = thinkTagRegex.ReplaceAllString(title, "")
803
804 title = strings.TrimSpace(title)
805 if title == "" {
806 slog.Warn("empty title; using fallback")
807 title = defaultSessionName
808 }
809
810 // Calculate usage and cost.
811 var openrouterCost *float64
812 for _, step := range resp.Steps {
813 stepCost := a.openrouterCost(step.ProviderMetadata)
814 if stepCost != nil {
815 newCost := *stepCost
816 if openrouterCost != nil {
817 newCost += *openrouterCost
818 }
819 openrouterCost = &newCost
820 }
821 }
822
823 modelConfig := model.CatwalkCfg
824 cost := modelConfig.CostPer1MInCached/1e6*float64(resp.TotalUsage.CacheCreationTokens) +
825 modelConfig.CostPer1MOutCached/1e6*float64(resp.TotalUsage.CacheReadTokens) +
826 modelConfig.CostPer1MIn/1e6*float64(resp.TotalUsage.InputTokens) +
827 modelConfig.CostPer1MOut/1e6*float64(resp.TotalUsage.OutputTokens)
828
829 // Use override cost if available (e.g., from OpenRouter).
830 if openrouterCost != nil {
831 cost = *openrouterCost
832 }
833
834 promptTokens := resp.TotalUsage.InputTokens + resp.TotalUsage.CacheCreationTokens
835 completionTokens := resp.TotalUsage.OutputTokens + resp.TotalUsage.CacheReadTokens
836
837 // Atomically update only title and usage fields to avoid overriding other
838 // concurrent session updates.
839 saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, title, promptTokens, completionTokens, cost)
840 if saveErr != nil {
841 slog.Error("failed to save session title and usage", "error", saveErr)
842 return
843 }
844}
845
846func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
847 openrouterMetadata, ok := metadata[openrouter.Name]
848 if !ok {
849 return nil
850 }
851
852 opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
853 if !ok {
854 return nil
855 }
856 return &opts.Usage.Cost
857}
858
859func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
860 modelConfig := model.CatwalkCfg
861 cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
862 modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
863 modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
864 modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
865
866 a.eventTokensUsed(session.ID, model, usage, cost)
867
868 if overrideCost != nil {
869 session.Cost += *overrideCost
870 } else {
871 session.Cost += cost
872 }
873
874 session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
875 session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
876}
877
878func (a *sessionAgent) Cancel(sessionID string) {
879 // Cancel regular requests. Don't use Take() here - we need the entry to
880 // remain in activeRequests so IsBusy() returns true until the goroutine
881 // fully completes (including error handling that may access the DB).
882 // The defer in processRequest will clean up the entry.
883 if cancel, ok := a.activeRequests.Get(sessionID); ok && cancel != nil {
884 slog.Info("Request cancellation initiated", "session_id", sessionID)
885 cancel()
886 }
887
888 // Also check for summarize requests.
889 if cancel, ok := a.activeRequests.Get(sessionID + "-summarize"); ok && cancel != nil {
890 slog.Info("Summarize cancellation initiated", "session_id", sessionID)
891 cancel()
892 }
893
894 if a.QueuedPrompts(sessionID) > 0 {
895 slog.Info("Clearing queued prompts", "session_id", sessionID)
896 a.messageQueue.Del(sessionID)
897 }
898}
899
900func (a *sessionAgent) ClearQueue(sessionID string) {
901 if a.QueuedPrompts(sessionID) > 0 {
902 slog.Info("Clearing queued prompts", "session_id", sessionID)
903 a.messageQueue.Del(sessionID)
904 }
905}
906
907func (a *sessionAgent) CancelAll() {
908 if !a.IsBusy() {
909 return
910 }
911 for key := range a.activeRequests.Seq2() {
912 a.Cancel(key) // key is sessionID
913 }
914
915 timeout := time.After(5 * time.Second)
916 for a.IsBusy() {
917 select {
918 case <-timeout:
919 return
920 default:
921 time.Sleep(200 * time.Millisecond)
922 }
923 }
924}
925
926func (a *sessionAgent) IsBusy() bool {
927 var busy bool
928 for cancelFunc := range a.activeRequests.Seq() {
929 if cancelFunc != nil {
930 busy = true
931 break
932 }
933 }
934 return busy
935}
936
937func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
938 _, busy := a.activeRequests.Get(sessionID)
939 return busy
940}
941
942func (a *sessionAgent) QueuedPrompts(sessionID string) int {
943 l, ok := a.messageQueue.Get(sessionID)
944 if !ok {
945 return 0
946 }
947 return len(l)
948}
949
950func (a *sessionAgent) QueuedPromptsList(sessionID string) []string {
951 l, ok := a.messageQueue.Get(sessionID)
952 if !ok {
953 return nil
954 }
955 prompts := make([]string, len(l))
956 for i, call := range l {
957 prompts[i] = call.Prompt
958 }
959 return prompts
960}
961
962func (a *sessionAgent) SetModels(large Model, small Model) {
963 a.largeModel = large
964 a.smallModel = small
965}
966
// SetTools replaces the agent's tool set with tools. The slice is stored as
// given, without copying.
func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
	a.tools = tools
}
970
// SetSystemPrompt replaces the agent's system prompt with systemPrompt.
func (a *sessionAgent) SetSystemPrompt(systemPrompt string) {
	a.systemPrompt = systemPrompt
}
974
// Model returns the agent's large model.
func (a *sessionAgent) Model() Model {
	return a.largeModel
}
978
// promptPrefix returns the agent's system prompt prefix.
func (a *sessionAgent) promptPrefix() string {
	return a.systemPromptPrefix
}
982
983// convertToToolResult converts a fantasy tool result to a message tool result.
984func (a *sessionAgent) convertToToolResult(result fantasy.ToolResultContent) message.ToolResult {
985 baseResult := message.ToolResult{
986 ToolCallID: result.ToolCallID,
987 Name: result.ToolName,
988 Metadata: result.ClientMetadata,
989 }
990
991 switch result.Result.GetType() {
992 case fantasy.ToolResultContentTypeText:
993 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result); ok {
994 baseResult.Content = r.Text
995 }
996 case fantasy.ToolResultContentTypeError:
997 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result); ok {
998 baseResult.Content = r.Error.Error()
999 baseResult.IsError = true
1000 }
1001 case fantasy.ToolResultContentTypeMedia:
1002 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](result.Result); ok {
1003 content := r.Text
1004 if content == "" {
1005 content = fmt.Sprintf("Loaded %s content", r.MediaType)
1006 }
1007 baseResult.Content = content
1008 baseResult.Data = r.Data
1009 baseResult.MIMEType = r.MediaType
1010 }
1011 }
1012
1013 return baseResult
1014}
1015
1016// workaroundProviderMediaLimitations converts media content in tool results to
1017// user messages for providers that don't natively support images in tool results.
1018//
1019// Problem: OpenAI, Google, OpenRouter, and other OpenAI-compatible providers
1020// don't support sending images/media in tool result messages - they only accept
1021// text in tool results. However, they DO support images in user messages.
1022//
1023// If we send media in tool results to these providers, the API returns an error.
1024//
1025// Solution: For these providers, we:
1026// 1. Replace the media in the tool result with a text placeholder
1027// 2. Inject a user message immediately after with the image as a file attachment
1028// 3. This maintains the tool execution flow while working around API limitations
1029//
1030// Anthropic and Bedrock support images natively in tool results, so we skip
1031// this workaround for them.
1032//
1033// Example transformation:
1034//
1035// BEFORE: [tool result: image data]
1036// AFTER: [tool result: "Image loaded - see attached"], [user: image attachment]
1037func (a *sessionAgent) workaroundProviderMediaLimitations(messages []fantasy.Message) []fantasy.Message {
1038 providerSupportsMedia := a.largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderAnthropic) ||
1039 a.largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderBedrock)
1040
1041 if providerSupportsMedia {
1042 return messages
1043 }
1044
1045 convertedMessages := make([]fantasy.Message, 0, len(messages))
1046
1047 for _, msg := range messages {
1048 if msg.Role != fantasy.MessageRoleTool {
1049 convertedMessages = append(convertedMessages, msg)
1050 continue
1051 }
1052
1053 textParts := make([]fantasy.MessagePart, 0, len(msg.Content))
1054 var mediaFiles []fantasy.FilePart
1055
1056 for _, part := range msg.Content {
1057 toolResult, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
1058 if !ok {
1059 textParts = append(textParts, part)
1060 continue
1061 }
1062
1063 if media, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](toolResult.Output); ok {
1064 decoded, err := base64.StdEncoding.DecodeString(media.Data)
1065 if err != nil {
1066 slog.Warn("failed to decode media data", "error", err)
1067 textParts = append(textParts, part)
1068 continue
1069 }
1070
1071 mediaFiles = append(mediaFiles, fantasy.FilePart{
1072 Data: decoded,
1073 MediaType: media.MediaType,
1074 Filename: fmt.Sprintf("tool-result-%s", toolResult.ToolCallID),
1075 })
1076
1077 textParts = append(textParts, fantasy.ToolResultPart{
1078 ToolCallID: toolResult.ToolCallID,
1079 Output: fantasy.ToolResultOutputContentText{
1080 Text: "[Image/media content loaded - see attached file]",
1081 },
1082 ProviderOptions: toolResult.ProviderOptions,
1083 })
1084 } else {
1085 textParts = append(textParts, part)
1086 }
1087 }
1088
1089 convertedMessages = append(convertedMessages, fantasy.Message{
1090 Role: fantasy.MessageRoleTool,
1091 Content: textParts,
1092 })
1093
1094 if len(mediaFiles) > 0 {
1095 convertedMessages = append(convertedMessages, fantasy.NewUserMessage(
1096 "Here is the media content from the tool result:",
1097 mediaFiles...,
1098 ))
1099 }
1100 }
1101
1102 return convertedMessages
1103}
1104
1105// buildSummaryPrompt constructs the prompt text for session summarization.
1106func buildSummaryPrompt(todos []session.Todo) string {
1107 var sb strings.Builder
1108 sb.WriteString("Provide a detailed summary of our conversation above.")
1109 if len(todos) > 0 {
1110 sb.WriteString("\n\n## Current Todo List\n\n")
1111 for _, t := range todos {
1112 fmt.Fprintf(&sb, "- [%s] %s\n", t.Status, t.Content)
1113 }
1114 sb.WriteString("\nInclude these tasks and their statuses in your summary. ")
1115 sb.WriteString("Instruct the resuming assistant to use the `todos` tool to continue tracking progress on these tasks.")
1116 }
1117 return sb.String()
1118}