1// Package agent is the core orchestration layer for Crush AI agents.
2//
3// It provides session-based AI agent functionality for managing
4// conversations, tool execution, and message handling. It coordinates
5// interactions between language models, messages, sessions, and tools while
6// handling features like automatic summarization, queuing, and token
7// management.
8package agent
9
10import (
11 "cmp"
12 "context"
13 _ "embed"
14 "encoding/base64"
15 "errors"
16 "fmt"
17 "log/slog"
18 "os"
19 "strconv"
20 "strings"
21 "sync"
22 "time"
23
24 "charm.land/fantasy"
25 "charm.land/fantasy/providers/anthropic"
26 "charm.land/fantasy/providers/bedrock"
27 "charm.land/fantasy/providers/google"
28 "charm.land/fantasy/providers/openai"
29 "charm.land/fantasy/providers/openrouter"
30 "git.secluded.site/crush/internal/agent/tools"
31 "git.secluded.site/crush/internal/config"
32 "git.secluded.site/crush/internal/csync"
33 "git.secluded.site/crush/internal/message"
34 "git.secluded.site/crush/internal/notification"
35 "git.secluded.site/crush/internal/permission"
36 "git.secluded.site/crush/internal/session"
37 "git.secluded.site/crush/internal/stringext"
38 "github.com/charmbracelet/catwalk/pkg/catwalk"
39)
40
// titlePrompt holds the contents of templates/title.md, embedded at build
// time. It is used as the base system prompt when generating session titles.
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt holds the contents of templates/summary.md, embedded at build
// time. It is used as the system prompt when summarizing a session.
//
//go:embed templates/summary.md
var summaryPrompt []byte
46
// SessionAgentCall describes a single prompt submitted to a SessionAgent.
type SessionAgentCall struct {
	// SessionID identifies the session the prompt belongs to. Run rejects
	// calls where it is empty.
	SessionID string

	// Prompt is the user's text. Run rejects calls where it is empty.
	Prompt string

	// ProviderOptions are forwarded to the model stream call.
	ProviderOptions fantasy.ProviderOptions

	// Attachments are stored on the user message and sent to the model as
	// files.
	Attachments []message.Attachment

	// The remaining generation settings are forwarded to the model stream
	// call as-is.
	MaxOutputTokens  int64
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64

	// NonInteractive suppresses the "turn completed" notification sent at
	// the end of Run.
	NonInteractive bool
}
60
// SessionAgent runs prompts against sessions and exposes control over
// in-flight and queued work.
type SessionAgent interface {
	// Run submits a prompt to a session. The implementation in this file
	// queues the call and returns (nil, nil) when the session is busy.
	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
	// SetModels replaces the large and small models used by the agent.
	SetModels(large Model, small Model)
	// SetTools replaces the tools made available to the model.
	SetTools(tools []fantasy.AgentTool)
	// Cancel cancels the active request for a session and drops its queued
	// prompts.
	Cancel(sessionID string)
	// CancelAll cancels the active requests of every session.
	CancelAll()
	// IsSessionBusy reports whether the session has an active request.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any session has an active request.
	IsBusy() bool
	// QueuedPrompts returns the number of prompts queued for a session.
	QueuedPrompts(sessionID string) int
	// QueuedPromptsList returns the text of the prompts queued for a session.
	QueuedPromptsList(sessionID string) []string
	// ClearQueue drops the prompts queued for a session.
	ClearQueue(sessionID string)
	// Summarize generates a summary message for the given session ID using
	// the supplied provider options.
	Summarize(context.Context, string, fantasy.ProviderOptions) error
	// Model returns the model used for the main conversation.
	Model() Model
}
75
// Model bundles a language model with its catalog and configuration data.
type Model struct {
	// Model is the language model used to stream responses.
	Model fantasy.LanguageModel
	// CatwalkCfg is the catwalk model description; this file reads its
	// context window, pricing, and capability fields.
	CatwalkCfg catwalk.Model
	// ModelCfg is the selected-model configuration; this file reads its
	// Model and Provider fields.
	ModelCfg config.SelectedModel
}
81
// sessionAgent is the SessionAgent implementation returned by
// NewSessionAgent.
type sessionAgent struct {
	// largeModel drives the main conversation and summaries; smallModel is
	// used for title generation.
	largeModel Model
	smallModel Model

	// systemPromptPrefix is prepended as an extra system message;
	// systemPrompt is the agent's main system prompt.
	systemPromptPrefix string
	systemPrompt       string

	// isSubAgent disables the todo-list reminder injected by preparePrompt.
	isSubAgent bool
	tools      []fantasy.AgentTool
	sessions   session.Service
	messages   message.Service

	// disableAutoSummarize turns off the context-window stop condition in
	// Run.
	disableAutoSummarize bool
	isYolo               bool

	// messageQueue holds, per session ID, the calls submitted while that
	// session was busy.
	messageQueue *csync.Map[string, []SessionAgentCall]
	// activeRequests holds, per session ID, the cancel function of the
	// request currently in flight.
	activeRequests *csync.Map[string, context.CancelFunc]
}
97
// SessionAgentOptions carries the dependencies and settings used by
// NewSessionAgent to build a SessionAgent. Each field is copied onto the
// corresponding field of the agent.
type SessionAgentOptions struct {
	LargeModel           Model
	SmallModel           Model
	SystemPromptPrefix   string
	SystemPrompt         string
	IsSubAgent           bool
	DisableAutoSummarize bool
	IsYolo               bool
	Sessions             session.Service
	Messages             message.Service
	Tools                []fantasy.AgentTool
}
110
111func NewSessionAgent(
112 opts SessionAgentOptions,
113) SessionAgent {
114 return &sessionAgent{
115 largeModel: opts.LargeModel,
116 smallModel: opts.SmallModel,
117 systemPromptPrefix: opts.SystemPromptPrefix,
118 systemPrompt: opts.SystemPrompt,
119 isSubAgent: opts.IsSubAgent,
120 sessions: opts.Sessions,
121 messages: opts.Messages,
122 disableAutoSummarize: opts.DisableAutoSummarize,
123 tools: opts.Tools,
124 isYolo: opts.IsYolo,
125 messageQueue: csync.NewMap[string, []SessionAgentCall](),
126 activeRequests: csync.NewMap[string, context.CancelFunc](),
127 }
128}
129
130func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
131 if call.Prompt == "" {
132 return nil, ErrEmptyPrompt
133 }
134 if call.SessionID == "" {
135 return nil, ErrSessionMissing
136 }
137
138 // Queue the message if busy
139 if a.IsSessionBusy(call.SessionID) {
140 existing, ok := a.messageQueue.Get(call.SessionID)
141 if !ok {
142 existing = []SessionAgentCall{}
143 }
144 existing = append(existing, call)
145 a.messageQueue.Set(call.SessionID, existing)
146 return nil, nil
147 }
148
149 if len(a.tools) > 0 {
150 // Add Anthropic caching to the last tool.
151 a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
152 }
153
154 agent := fantasy.NewAgent(
155 a.largeModel.Model,
156 fantasy.WithSystemPrompt(a.systemPrompt),
157 fantasy.WithTools(a.tools...),
158 )
159
160 sessionLock := sync.Mutex{}
161 currentSession, err := a.sessions.Get(ctx, call.SessionID)
162 if err != nil {
163 return nil, fmt.Errorf("failed to get session: %w", err)
164 }
165
166 msgs, err := a.getSessionMessages(ctx, currentSession)
167 if err != nil {
168 return nil, fmt.Errorf("failed to get session messages: %w", err)
169 }
170
171 var wg sync.WaitGroup
172 // Generate title if first message.
173 if len(msgs) == 0 {
174 wg.Go(func() {
175 sessionLock.Lock()
176 a.generateTitle(ctx, ¤tSession, call.Prompt)
177 sessionLock.Unlock()
178 })
179 }
180
181 // Add the user message to the session.
182 _, err = a.createUserMessage(ctx, call)
183 if err != nil {
184 return nil, err
185 }
186
187 // Add the session to the context.
188 ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)
189
190 genCtx, cancel := context.WithCancel(ctx)
191 a.activeRequests.Set(call.SessionID, cancel)
192
193 defer cancel()
194 defer a.activeRequests.Del(call.SessionID)
195
196 history, files := a.preparePrompt(msgs, call.Attachments...)
197
198 startTime := time.Now()
199 a.eventPromptSent(call.SessionID)
200
201 var currentAssistant *message.Message
202 var shouldSummarize bool
203 result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
204 Prompt: call.Prompt,
205 Files: files,
206 Messages: history,
207 ProviderOptions: call.ProviderOptions,
208 MaxOutputTokens: &call.MaxOutputTokens,
209 TopP: call.TopP,
210 Temperature: call.Temperature,
211 PresencePenalty: call.PresencePenalty,
212 TopK: call.TopK,
213 FrequencyPenalty: call.FrequencyPenalty,
214 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
215 prepared.Messages = options.Messages
216 for i := range prepared.Messages {
217 prepared.Messages[i].ProviderOptions = nil
218 }
219
220 queuedCalls, _ := a.messageQueue.Get(call.SessionID)
221 a.messageQueue.Del(call.SessionID)
222 for _, queued := range queuedCalls {
223 userMessage, createErr := a.createUserMessage(callContext, queued)
224 if createErr != nil {
225 return callContext, prepared, createErr
226 }
227 prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
228 }
229
230 prepared.Messages = a.workaroundProviderMediaLimitations(prepared.Messages)
231
232 lastSystemRoleInx := 0
233 systemMessageUpdated := false
234 for i, msg := range prepared.Messages {
235 // Only add cache control to the last message.
236 if msg.Role == fantasy.MessageRoleSystem {
237 lastSystemRoleInx = i
238 } else if !systemMessageUpdated {
239 prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
240 systemMessageUpdated = true
241 }
242 // Than add cache control to the last 2 messages.
243 if i > len(prepared.Messages)-3 {
244 prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
245 }
246 }
247
248 if promptPrefix := a.promptPrefix(); promptPrefix != "" {
249 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(promptPrefix)}, prepared.Messages...)
250 }
251
252 var assistantMsg message.Message
253 assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
254 Role: message.Assistant,
255 Parts: []message.ContentPart{},
256 Model: a.largeModel.ModelCfg.Model,
257 Provider: a.largeModel.ModelCfg.Provider,
258 })
259 if err != nil {
260 return callContext, prepared, err
261 }
262 callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
263 callContext = context.WithValue(callContext, tools.SupportsImagesContextKey, a.largeModel.CatwalkCfg.SupportsImages)
264 callContext = context.WithValue(callContext, tools.ModelNameContextKey, a.largeModel.CatwalkCfg.Name)
265 currentAssistant = &assistantMsg
266 return callContext, prepared, err
267 },
268 OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
269 currentAssistant.AppendReasoningContent(reasoning.Text)
270 return a.messages.Update(genCtx, *currentAssistant)
271 },
272 OnReasoningDelta: func(id string, text string) error {
273 currentAssistant.AppendReasoningContent(text)
274 return a.messages.Update(genCtx, *currentAssistant)
275 },
276 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
277 // handle anthropic signature
278 if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
279 if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
280 currentAssistant.AppendReasoningSignature(reasoning.Signature)
281 }
282 }
283 if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
284 if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
285 currentAssistant.AppendThoughtSignature(reasoning.Signature, reasoning.ToolID)
286 }
287 }
288 if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
289 if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
290 currentAssistant.SetReasoningResponsesData(reasoning)
291 }
292 }
293 currentAssistant.FinishThinking()
294 return a.messages.Update(genCtx, *currentAssistant)
295 },
296 OnTextDelta: func(id string, text string) error {
297 // Strip leading newline from initial text content. This is is
298 // particularly important in non-interactive mode where leading
299 // newlines are very visible.
300 if len(currentAssistant.Parts) == 0 {
301 text = strings.TrimPrefix(text, "\n")
302 }
303
304 currentAssistant.AppendContent(text)
305 return a.messages.Update(genCtx, *currentAssistant)
306 },
307 OnToolInputStart: func(id string, toolName string) error {
308 toolCall := message.ToolCall{
309 ID: id,
310 Name: toolName,
311 ProviderExecuted: false,
312 Finished: false,
313 }
314 currentAssistant.AddToolCall(toolCall)
315 return a.messages.Update(genCtx, *currentAssistant)
316 },
317 OnRetry: func(err *fantasy.ProviderError, delay time.Duration) {
318 // TODO: implement
319 },
320 OnToolCall: func(tc fantasy.ToolCallContent) error {
321 toolCall := message.ToolCall{
322 ID: tc.ToolCallID,
323 Name: tc.ToolName,
324 Input: tc.Input,
325 ProviderExecuted: false,
326 Finished: true,
327 }
328 currentAssistant.AddToolCall(toolCall)
329 return a.messages.Update(genCtx, *currentAssistant)
330 },
331 OnToolResult: func(result fantasy.ToolResultContent) error {
332 toolResult := a.convertToToolResult(result)
333 _, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
334 Role: message.Tool,
335 Parts: []message.ContentPart{
336 toolResult,
337 },
338 })
339 return createMsgErr
340 },
341 OnStepFinish: func(stepResult fantasy.StepResult) error {
342 finishReason := message.FinishReasonUnknown
343 switch stepResult.FinishReason {
344 case fantasy.FinishReasonLength:
345 finishReason = message.FinishReasonMaxTokens
346 case fantasy.FinishReasonStop:
347 finishReason = message.FinishReasonEndTurn
348 case fantasy.FinishReasonToolCalls:
349 finishReason = message.FinishReasonToolUse
350 }
351 currentAssistant.AddFinish(finishReason, "", "")
352 sessionLock.Lock()
353 updatedSession, getSessionErr := a.sessions.Get(genCtx, call.SessionID)
354 if getSessionErr != nil {
355 sessionLock.Unlock()
356 return getSessionErr
357 }
358 a.updateSessionUsage(a.largeModel, &updatedSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
359 _, sessionErr := a.sessions.Save(genCtx, updatedSession)
360 sessionLock.Unlock()
361 if sessionErr != nil {
362 return sessionErr
363 }
364 return a.messages.Update(genCtx, *currentAssistant)
365 },
366 StopWhen: []fantasy.StopCondition{
367 func(_ []fantasy.StepResult) bool {
368 cw := int64(a.largeModel.CatwalkCfg.ContextWindow)
369 tokens := currentSession.CompletionTokens + currentSession.PromptTokens
370 remaining := cw - tokens
371 var threshold int64
372 if cw > 200_000 {
373 threshold = 20_000
374 } else {
375 threshold = int64(float64(cw) * 0.2)
376 }
377 if (remaining <= threshold) && !a.disableAutoSummarize {
378 shouldSummarize = true
379 return true
380 }
381 return false
382 },
383 },
384 })
385
386 a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))
387
388 if err != nil {
389 isCancelErr := errors.Is(err, context.Canceled)
390 isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
391 if currentAssistant == nil {
392 return result, err
393 }
394 // Ensure we finish thinking on error to close the reasoning state.
395 currentAssistant.FinishThinking()
396 toolCalls := currentAssistant.ToolCalls()
397 // INFO: we use the parent context here because the genCtx has been cancelled.
398 msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
399 if createErr != nil {
400 return nil, createErr
401 }
402 for _, tc := range toolCalls {
403 if !tc.Finished {
404 tc.Finished = true
405 tc.Input = "{}"
406 currentAssistant.AddToolCall(tc)
407 updateErr := a.messages.Update(ctx, *currentAssistant)
408 if updateErr != nil {
409 return nil, updateErr
410 }
411 }
412
413 found := false
414 for _, msg := range msgs {
415 if msg.Role == message.Tool {
416 for _, tr := range msg.ToolResults() {
417 if tr.ToolCallID == tc.ID {
418 found = true
419 break
420 }
421 }
422 }
423 if found {
424 break
425 }
426 }
427 if found {
428 continue
429 }
430 content := "There was an error while executing the tool"
431 if isCancelErr {
432 content = "Tool execution canceled by user"
433 } else if isPermissionErr {
434 content = "User denied permission"
435 }
436 toolResult := message.ToolResult{
437 ToolCallID: tc.ID,
438 Name: tc.Name,
439 Content: content,
440 IsError: true,
441 }
442 _, createErr = a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
443 Role: message.Tool,
444 Parts: []message.ContentPart{
445 toolResult,
446 },
447 })
448 if createErr != nil {
449 return nil, createErr
450 }
451 }
452 var fantasyErr *fantasy.Error
453 var providerErr *fantasy.ProviderError
454 const defaultTitle = "Provider Error"
455 if isCancelErr {
456 currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
457 } else if isPermissionErr {
458 currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "")
459 } else if errors.As(err, &providerErr) {
460 currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message)
461 } else if errors.As(err, &fantasyErr) {
462 currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(fantasyErr.Title), defaultTitle), fantasyErr.Message)
463 } else {
464 currentAssistant.AddFinish(message.FinishReasonError, defaultTitle, err.Error())
465 }
466 // Note: we use the parent context here because the genCtx has been
467 // cancelled.
468 updateErr := a.messages.Update(ctx, *currentAssistant)
469 if updateErr != nil {
470 return nil, updateErr
471 }
472 return nil, err
473 }
474 wg.Wait()
475
476 // Send notification that agent has finished its turn (skip for nested/non-interactive sessions).
477 if !call.NonInteractive {
478 notifBody := fmt.Sprintf("Agent's turn completed in \"%s\"", currentSession.Title)
479 _ = notification.Send("Crush is waiting...", notifBody)
480 }
481
482 if shouldSummarize {
483 a.activeRequests.Del(call.SessionID)
484 if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
485 return nil, summarizeErr
486 }
487 // If the agent wasn't done...
488 if len(currentAssistant.ToolCalls()) > 0 {
489 existing, ok := a.messageQueue.Get(call.SessionID)
490 if !ok {
491 existing = []SessionAgentCall{}
492 }
493 call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
494 existing = append(existing, call)
495 a.messageQueue.Set(call.SessionID, existing)
496 }
497 }
498
499 // Release active request before processing queued messages.
500 a.activeRequests.Del(call.SessionID)
501 cancel()
502
503 queuedMessages, ok := a.messageQueue.Get(call.SessionID)
504 if !ok || len(queuedMessages) == 0 {
505 return result, err
506 }
507 // There are queued messages restart the loop.
508 firstQueuedMessage := queuedMessages[0]
509 a.messageQueue.Set(call.SessionID, queuedMessages[1:])
510 return a.Run(ctx, firstQueuedMessage)
511}
512
513func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
514 if a.IsSessionBusy(sessionID) {
515 return ErrSessionBusy
516 }
517
518 currentSession, err := a.sessions.Get(ctx, sessionID)
519 if err != nil {
520 return fmt.Errorf("failed to get session: %w", err)
521 }
522 msgs, err := a.getSessionMessages(ctx, currentSession)
523 if err != nil {
524 return err
525 }
526 if len(msgs) == 0 {
527 // Nothing to summarize.
528 return nil
529 }
530
531 aiMsgs, _ := a.preparePrompt(msgs)
532
533 genCtx, cancel := context.WithCancel(ctx)
534 a.activeRequests.Set(sessionID, cancel)
535 defer a.activeRequests.Del(sessionID)
536 defer cancel()
537
538 agent := fantasy.NewAgent(a.largeModel.Model,
539 fantasy.WithSystemPrompt(string(summaryPrompt)),
540 )
541 summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
542 Role: message.Assistant,
543 Model: a.largeModel.Model.Model(),
544 Provider: a.largeModel.Model.Provider(),
545 IsSummaryMessage: true,
546 })
547 if err != nil {
548 return err
549 }
550
551 summaryPromptText := "Provide a detailed summary of our conversation above."
552 if len(currentSession.Todos) > 0 {
553 summaryPromptText += "\n\n## Current Todo List\n\n"
554 for _, t := range currentSession.Todos {
555 summaryPromptText += fmt.Sprintf("- [%s] %s\n", t.Status, t.Content)
556 }
557 summaryPromptText += "\nInclude these tasks and their statuses in your summary. "
558 summaryPromptText += "Instruct the resuming assistant to use the `todos` tool to continue tracking progress on these tasks."
559 }
560
561 resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
562 Prompt: summaryPromptText,
563 Messages: aiMsgs,
564 ProviderOptions: opts,
565 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
566 prepared.Messages = options.Messages
567 if a.systemPromptPrefix != "" {
568 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
569 }
570 return callContext, prepared, nil
571 },
572 OnReasoningDelta: func(id string, text string) error {
573 summaryMessage.AppendReasoningContent(text)
574 return a.messages.Update(genCtx, summaryMessage)
575 },
576 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
577 // Handle anthropic signature.
578 if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
579 if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
580 summaryMessage.AppendReasoningSignature(signature.Signature)
581 }
582 }
583 summaryMessage.FinishThinking()
584 return a.messages.Update(genCtx, summaryMessage)
585 },
586 OnTextDelta: func(id, text string) error {
587 summaryMessage.AppendContent(text)
588 return a.messages.Update(genCtx, summaryMessage)
589 },
590 })
591 if err != nil {
592 isCancelErr := errors.Is(err, context.Canceled)
593 if isCancelErr {
594 // User cancelled summarize we need to remove the summary message.
595 deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
596 return deleteErr
597 }
598 return err
599 }
600
601 summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
602 err = a.messages.Update(genCtx, summaryMessage)
603 if err != nil {
604 return err
605 }
606
607 var openrouterCost *float64
608 for _, step := range resp.Steps {
609 stepCost := a.openrouterCost(step.ProviderMetadata)
610 if stepCost != nil {
611 newCost := *stepCost
612 if openrouterCost != nil {
613 newCost += *openrouterCost
614 }
615 openrouterCost = &newCost
616 }
617 }
618
619 a.updateSessionUsage(a.largeModel, ¤tSession, resp.TotalUsage, openrouterCost)
620
621 // Just in case, get just the last usage info.
622 usage := resp.Response.Usage
623 currentSession.SummaryMessageID = summaryMessage.ID
624 currentSession.CompletionTokens = usage.OutputTokens
625 currentSession.PromptTokens = 0
626 _, err = a.sessions.Save(genCtx, currentSession)
627 return err
628}
629
630func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
631 if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
632 return fantasy.ProviderOptions{}
633 }
634 return fantasy.ProviderOptions{
635 anthropic.Name: &anthropic.ProviderCacheControlOptions{
636 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
637 },
638 bedrock.Name: &anthropic.ProviderCacheControlOptions{
639 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
640 },
641 }
642}
643
644func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
645 var attachmentParts []message.ContentPart
646 for _, attachment := range call.Attachments {
647 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
648 }
649 parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
650 parts = append(parts, attachmentParts...)
651 msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
652 Role: message.User,
653 Parts: parts,
654 })
655 if err != nil {
656 return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
657 }
658 return msg, nil
659}
660
661func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
662 var history []fantasy.Message
663 if !a.isSubAgent {
664 history = append(history, fantasy.NewUserMessage(
665 fmt.Sprintf("<system_reminder>%s</system_reminder>",
666 `This is a reminder that your todo list is currently empty. DO NOT mention this to the user explicitly because they are already aware.
667If you are working on tasks that would benefit from a todo list please use the "todos" tool to create one.
668If not, please feel free to ignore. Again do not mention this message to the user.`,
669 ),
670 ))
671 }
672 for _, m := range msgs {
673 if len(m.Parts) == 0 {
674 continue
675 }
676 // Assistant message without content or tool calls (cancelled before it
677 // returned anything).
678 if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
679 continue
680 }
681 history = append(history, m.ToAIMessage()...)
682 }
683
684 var files []fantasy.FilePart
685 for _, attachment := range attachments {
686 files = append(files, fantasy.FilePart{
687 Filename: attachment.FileName,
688 Data: attachment.Content,
689 MediaType: attachment.MimeType,
690 })
691 }
692
693 return history, files
694}
695
696func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
697 msgs, err := a.messages.List(ctx, session.ID)
698 if err != nil {
699 return nil, fmt.Errorf("failed to list messages: %w", err)
700 }
701
702 if session.SummaryMessageID != "" {
703 summaryMsgInex := -1
704 for i, msg := range msgs {
705 if msg.ID == session.SummaryMessageID {
706 summaryMsgInex = i
707 break
708 }
709 }
710 if summaryMsgInex != -1 {
711 msgs = msgs[summaryMsgInex:]
712 msgs[0].Role = message.User
713 }
714 }
715 return msgs, nil
716}
717
718func (a *sessionAgent) generateTitle(ctx context.Context, session *session.Session, prompt string) {
719 if prompt == "" {
720 return
721 }
722
723 var maxOutput int64 = 40
724 if a.smallModel.CatwalkCfg.CanReason {
725 maxOutput = a.smallModel.CatwalkCfg.DefaultMaxTokens
726 }
727
728 agent := fantasy.NewAgent(a.smallModel.Model,
729 fantasy.WithSystemPrompt(string(titlePrompt)+"\n /no_think"),
730 fantasy.WithMaxOutputTokens(maxOutput),
731 )
732
733 resp, err := agent.Stream(ctx, fantasy.AgentStreamCall{
734 Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", prompt),
735 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
736 prepared.Messages = options.Messages
737 if a.systemPromptPrefix != "" {
738 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
739 }
740 return callContext, prepared, nil
741 },
742 })
743 if err != nil {
744 slog.Error("error generating title", "err", err)
745 return
746 }
747
748 title := resp.Response.Content.Text()
749
750 title = strings.ReplaceAll(title, "\n", " ")
751
752 // Remove thinking tags if present.
753 if idx := strings.Index(title, "</think>"); idx > 0 {
754 title = title[idx+len("</think>"):]
755 }
756
757 title = strings.TrimSpace(title)
758 if title == "" {
759 slog.Warn("failed to generate title", "warn", "empty title")
760 return
761 }
762
763 session.Title = title
764
765 var openrouterCost *float64
766 for _, step := range resp.Steps {
767 stepCost := a.openrouterCost(step.ProviderMetadata)
768 if stepCost != nil {
769 newCost := *stepCost
770 if openrouterCost != nil {
771 newCost += *openrouterCost
772 }
773 openrouterCost = &newCost
774 }
775 }
776
777 a.updateSessionUsage(a.smallModel, session, resp.TotalUsage, openrouterCost)
778 _, saveErr := a.sessions.Save(ctx, *session)
779 if saveErr != nil {
780 slog.Error("failed to save session title & usage", "error", saveErr)
781 return
782 }
783}
784
785func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
786 openrouterMetadata, ok := metadata[openrouter.Name]
787 if !ok {
788 return nil
789 }
790
791 opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
792 if !ok {
793 return nil
794 }
795 return &opts.Usage.Cost
796}
797
798func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
799 modelConfig := model.CatwalkCfg
800 cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
801 modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
802 modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
803 modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
804
805 if a.isClaudeCode() {
806 cost = 0
807 }
808
809 a.eventTokensUsed(session.ID, model, usage, cost)
810
811 if overrideCost != nil {
812 session.Cost += *overrideCost
813 } else {
814 session.Cost += cost
815 }
816
817 session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
818 session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
819}
820
821func (a *sessionAgent) Cancel(sessionID string) {
822 // Cancel regular requests.
823 if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
824 slog.Info("Request cancellation initiated", "session_id", sessionID)
825 cancel()
826 }
827
828 // Also check for summarize requests.
829 if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
830 slog.Info("Summarize cancellation initiated", "session_id", sessionID)
831 cancel()
832 }
833
834 if a.QueuedPrompts(sessionID) > 0 {
835 slog.Info("Clearing queued prompts", "session_id", sessionID)
836 a.messageQueue.Del(sessionID)
837 }
838}
839
840func (a *sessionAgent) ClearQueue(sessionID string) {
841 if a.QueuedPrompts(sessionID) > 0 {
842 slog.Info("Clearing queued prompts", "session_id", sessionID)
843 a.messageQueue.Del(sessionID)
844 }
845}
846
847func (a *sessionAgent) CancelAll() {
848 if !a.IsBusy() {
849 return
850 }
851 for key := range a.activeRequests.Seq2() {
852 a.Cancel(key) // key is sessionID
853 }
854
855 timeout := time.After(5 * time.Second)
856 for a.IsBusy() {
857 select {
858 case <-timeout:
859 return
860 default:
861 time.Sleep(200 * time.Millisecond)
862 }
863 }
864}
865
866func (a *sessionAgent) IsBusy() bool {
867 var busy bool
868 for cancelFunc := range a.activeRequests.Seq() {
869 if cancelFunc != nil {
870 busy = true
871 break
872 }
873 }
874 return busy
875}
876
877func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
878 _, busy := a.activeRequests.Get(sessionID)
879 return busy
880}
881
882func (a *sessionAgent) QueuedPrompts(sessionID string) int {
883 l, ok := a.messageQueue.Get(sessionID)
884 if !ok {
885 return 0
886 }
887 return len(l)
888}
889
890func (a *sessionAgent) QueuedPromptsList(sessionID string) []string {
891 l, ok := a.messageQueue.Get(sessionID)
892 if !ok {
893 return nil
894 }
895 prompts := make([]string, len(l))
896 for i, call := range l {
897 prompts[i] = call.Prompt
898 }
899 return prompts
900}
901
902func (a *sessionAgent) SetModels(large Model, small Model) {
903 a.largeModel = large
904 a.smallModel = small
905}
906
// SetTools replaces the agent's tools. The slice is stored as given, not
// copied.
func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
	a.tools = tools
}
910
// Model returns the agent's large model.
func (a *sessionAgent) Model() Model {
	return a.largeModel
}
914
915func (a *sessionAgent) promptPrefix() string {
916 if a.isClaudeCode() {
917 return "You are Claude Code, Anthropic's official CLI for Claude."
918 }
919 return a.systemPromptPrefix
920}
921
922func (a *sessionAgent) isClaudeCode() bool {
923 cfg := config.Get()
924 pc, ok := cfg.Providers.Get(a.largeModel.ModelCfg.Provider)
925 return ok && pc.ID == string(catwalk.InferenceProviderAnthropic) && pc.OAuthToken != nil
926}
927
928// convertToToolResult converts a fantasy tool result to a message tool result.
929func (a *sessionAgent) convertToToolResult(result fantasy.ToolResultContent) message.ToolResult {
930 baseResult := message.ToolResult{
931 ToolCallID: result.ToolCallID,
932 Name: result.ToolName,
933 Metadata: result.ClientMetadata,
934 }
935
936 switch result.Result.GetType() {
937 case fantasy.ToolResultContentTypeText:
938 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result); ok {
939 baseResult.Content = r.Text
940 }
941 case fantasy.ToolResultContentTypeError:
942 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result); ok {
943 baseResult.Content = r.Error.Error()
944 baseResult.IsError = true
945 }
946 case fantasy.ToolResultContentTypeMedia:
947 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](result.Result); ok {
948 content := r.Text
949 if content == "" {
950 content = fmt.Sprintf("Loaded %s content", r.MediaType)
951 }
952 baseResult.Content = content
953 baseResult.Data = r.Data
954 baseResult.MIMEType = r.MediaType
955 }
956 }
957
958 return baseResult
959}
960
961// workaroundProviderMediaLimitations converts media content in tool results to
962// user messages for providers that don't natively support images in tool results.
963//
964// Problem: OpenAI, Google, OpenRouter, and other OpenAI-compatible providers
965// don't support sending images/media in tool result messages - they only accept
966// text in tool results. However, they DO support images in user messages.
967//
968// If we send media in tool results to these providers, the API returns an error.
969//
970// Solution: For these providers, we:
971// 1. Replace the media in the tool result with a text placeholder
972// 2. Inject a user message immediately after with the image as a file attachment
973// 3. This maintains the tool execution flow while working around API limitations
974//
975// Anthropic and Bedrock support images natively in tool results, so we skip
976// this workaround for them.
977//
978// Example transformation:
979//
980// BEFORE: [tool result: image data]
981// AFTER: [tool result: "Image loaded - see attached"], [user: image attachment]
982func (a *sessionAgent) workaroundProviderMediaLimitations(messages []fantasy.Message) []fantasy.Message {
983 providerSupportsMedia := a.largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderAnthropic) ||
984 a.largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderBedrock)
985
986 if providerSupportsMedia {
987 return messages
988 }
989
990 convertedMessages := make([]fantasy.Message, 0, len(messages))
991
992 for _, msg := range messages {
993 if msg.Role != fantasy.MessageRoleTool {
994 convertedMessages = append(convertedMessages, msg)
995 continue
996 }
997
998 textParts := make([]fantasy.MessagePart, 0, len(msg.Content))
999 var mediaFiles []fantasy.FilePart
1000
1001 for _, part := range msg.Content {
1002 toolResult, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
1003 if !ok {
1004 textParts = append(textParts, part)
1005 continue
1006 }
1007
1008 if media, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](toolResult.Output); ok {
1009 decoded, err := base64.StdEncoding.DecodeString(media.Data)
1010 if err != nil {
1011 slog.Warn("failed to decode media data", "error", err)
1012 textParts = append(textParts, part)
1013 continue
1014 }
1015
1016 mediaFiles = append(mediaFiles, fantasy.FilePart{
1017 Data: decoded,
1018 MediaType: media.MediaType,
1019 Filename: fmt.Sprintf("tool-result-%s", toolResult.ToolCallID),
1020 })
1021
1022 textParts = append(textParts, fantasy.ToolResultPart{
1023 ToolCallID: toolResult.ToolCallID,
1024 Output: fantasy.ToolResultOutputContentText{
1025 Text: "[Image/media content loaded - see attached file]",
1026 },
1027 ProviderOptions: toolResult.ProviderOptions,
1028 })
1029 } else {
1030 textParts = append(textParts, part)
1031 }
1032 }
1033
1034 convertedMessages = append(convertedMessages, fantasy.Message{
1035 Role: fantasy.MessageRoleTool,
1036 Content: textParts,
1037 })
1038
1039 if len(mediaFiles) > 0 {
1040 convertedMessages = append(convertedMessages, fantasy.NewUserMessage(
1041 "Here is the media content from the tool result:",
1042 mediaFiles...,
1043 ))
1044 }
1045 }
1046
1047 return convertedMessages
1048}