1// Package agent is the core orchestration layer for Crush AI agents.
2//
3// It provides session-based AI agent functionality for managing
4// conversations, tool execution, and message handling. It coordinates
5// interactions between language models, messages, sessions, and tools while
6// handling features like automatic summarization, queuing, and token
7// management.
8package agent
9
10import (
11 "cmp"
12 "context"
13 _ "embed"
14 "errors"
15 "fmt"
16 "log/slog"
17 "os"
18 "strconv"
19 "strings"
20 "sync"
21 "time"
22
23 "charm.land/fantasy"
24 "charm.land/fantasy/providers/anthropic"
25 "charm.land/fantasy/providers/bedrock"
26 "charm.land/fantasy/providers/google"
27 "charm.land/fantasy/providers/openai"
28 "charm.land/fantasy/providers/openrouter"
29 "github.com/charmbracelet/catwalk/pkg/catwalk"
30 "github.com/charmbracelet/crush/internal/agent/tools"
31 "github.com/charmbracelet/crush/internal/config"
32 "github.com/charmbracelet/crush/internal/csync"
33 "github.com/charmbracelet/crush/internal/hooks"
34 "github.com/charmbracelet/crush/internal/message"
35 "github.com/charmbracelet/crush/internal/permission"
36 "github.com/charmbracelet/crush/internal/session"
37 "github.com/charmbracelet/crush/internal/stringext"
38)
39
// titlePrompt is the system prompt used when asking the small model to
// generate a session title (see generateTitle).
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt is the system prompt used when asking the large model to
// summarize a session (see Summarize).
//
//go:embed templates/summary.md
var summaryPrompt []byte
45
// SessionAgentCall describes a single prompt submitted to a SessionAgent.
//
// SessionID and Prompt are required; Run rejects a call where either is
// empty. The generation settings below are forwarded to the model stream
// unchanged.
type SessionAgentCall struct {
	// SessionID identifies the session the prompt belongs to.
	SessionID string
	// Prompt is the user's text.
	Prompt string
	// ProviderOptions are passed through to the provider for this call and
	// for any summarization it triggers.
	ProviderOptions fantasy.ProviderOptions
	// Attachments are stored on the user message and sent as file parts.
	Attachments []message.Attachment
	// MaxOutputTokens is always forwarded as a pointer, including when it
	// is zero. NOTE(review): confirm how providers interpret 0.
	MaxOutputTokens  int64
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
}
58
59type SessionAgent interface {
60 Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
61 SetModels(large Model, small Model)
62 SetTools(tools []fantasy.AgentTool)
63 Cancel(sessionID string)
64 CancelAll()
65 IsSessionBusy(sessionID string) bool
66 IsBusy() bool
67 QueuedPrompts(sessionID string) int
68 ClearQueue(sessionID string)
69 Summarize(context.Context, string, fantasy.ProviderOptions) error
70 Model() Model
71}
72
// Model bundles a language model client with its catwalk metadata and the
// user's selected-model configuration.
type Model struct {
	// Model is the client used to issue requests.
	Model fantasy.LanguageModel
	// CatwalkCfg supplies the context window (auto-summarization), per-token
	// pricing (cost tracking) and reasoning capability (title budget).
	CatwalkCfg catwalk.Model
	// ModelCfg supplies the model and provider IDs recorded on assistant
	// messages created by Run.
	ModelCfg config.SelectedModel
}
78
// sessionAgent is the default SessionAgent implementation.
//
// NOTE(review): largeModel, smallModel and tools are written by SetModels /
// SetTools without synchronization while Run may be reading them from
// another goroutine; confirm callers only swap them while idle.
type sessionAgent struct {
	// largeModel drives conversations and summaries; smallModel generates
	// session titles.
	largeModel Model
	smallModel Model
	// systemPromptPrefix, when non-empty, is prepended as an extra system
	// message before every model call.
	systemPromptPrefix   string
	systemPrompt         string
	tools                []fantasy.AgentTool
	sessions             session.Service
	messages             message.Service
	// disableAutoSummarize prevents Run from stopping and summarizing when
	// the context window is nearly full.
	disableAutoSummarize bool
	isYolo               bool
	// isSubAgent skips the user-prompt-submit hook.
	isSubAgent   bool
	hooksManager hooks.Manager
	// workingDir is passed to hooks.
	workingDir string

	// messageQueue holds prompts submitted while their session was busy.
	messageQueue *csync.Map[string, []SessionAgentCall]
	// activeRequests maps a session ID to the cancel func of its in-flight
	// request; the presence of a key marks the session as busy.
	activeRequests *csync.Map[string, context.CancelFunc]
}
96
// SessionAgentOptions configures NewSessionAgent. Each field is copied into
// the corresponding field of the created agent.
type SessionAgentOptions struct {
	// LargeModel handles conversations and summaries.
	LargeModel Model
	// SmallModel generates session titles.
	SmallModel Model
	// SystemPromptPrefix, when non-empty, is sent as an extra leading system
	// message on every model call.
	SystemPromptPrefix string
	SystemPrompt       string
	// DisableAutoSummarize turns off summarization near the context limit.
	DisableAutoSummarize bool
	IsYolo               bool
	// IsSubAgent disables the user-prompt-submit hook.
	IsSubAgent bool
	// HooksManager runs user hooks; nil disables hooks.
	HooksManager hooks.Manager
	// WorkingDir is passed to hooks.
	WorkingDir string
	Sessions   session.Service
	Messages   message.Service
	Tools      []fantasy.AgentTool
}
111
112func NewSessionAgent(
113 opts SessionAgentOptions,
114) SessionAgent {
115 return &sessionAgent{
116 largeModel: opts.LargeModel,
117 smallModel: opts.SmallModel,
118 systemPromptPrefix: opts.SystemPromptPrefix,
119 systemPrompt: opts.SystemPrompt,
120 sessions: opts.Sessions,
121 messages: opts.Messages,
122 disableAutoSummarize: opts.DisableAutoSummarize,
123 tools: opts.Tools,
124 isYolo: opts.IsYolo,
125 isSubAgent: opts.IsSubAgent,
126 hooksManager: opts.HooksManager,
127 workingDir: opts.WorkingDir,
128 messageQueue: csync.NewMap[string, []SessionAgentCall](),
129 activeRequests: csync.NewMap[string, context.CancelFunc](),
130 }
131}
132
133func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
134 if call.Prompt == "" {
135 return nil, ErrEmptyPrompt
136 }
137 if call.SessionID == "" {
138 return nil, ErrSessionMissing
139 }
140
141 // Queue the message if busy
142 if a.IsSessionBusy(call.SessionID) {
143 existing, ok := a.messageQueue.Get(call.SessionID)
144 if !ok {
145 existing = []SessionAgentCall{}
146 }
147 existing = append(existing, call)
148 a.messageQueue.Set(call.SessionID, existing)
149 return nil, nil
150 }
151
152 if len(a.tools) > 0 {
153 // Add Anthropic caching to the last tool.
154 a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
155 }
156
157 agent := fantasy.NewAgent(
158 a.largeModel.Model,
159 fantasy.WithSystemPrompt(a.systemPrompt),
160 fantasy.WithTools(a.tools...),
161 )
162
163 sessionLock := sync.Mutex{}
164 currentSession, err := a.sessions.Get(ctx, call.SessionID)
165 if err != nil {
166 return nil, fmt.Errorf("failed to get session: %w", err)
167 }
168
169 msgs, err := a.getSessionMessages(ctx, currentSession)
170 if err != nil {
171 return nil, fmt.Errorf("failed to get session messages: %w", err)
172 }
173
174 var wg sync.WaitGroup
175 // Generate title if first message.
176 if len(msgs) == 0 {
177 wg.Go(func() {
178 sessionLock.Lock()
179 a.generateTitle(ctx, ¤tSession, call.Prompt)
180 sessionLock.Unlock()
181 })
182 }
183
184 // Add the user message to the session.
185 msg, err := a.createUserMessage(ctx, call)
186 if err != nil {
187 return nil, err
188 }
189
190 // Add the session to the context.
191 ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)
192
193 genCtx, cancel := context.WithCancel(ctx)
194 a.activeRequests.Set(call.SessionID, cancel)
195
196 defer cancel()
197 defer a.activeRequests.Del(call.SessionID)
198
199 // create the agent message asap to show loading
200 var currentAssistant *message.Message
201 assistantMessage, err := a.messages.Create(genCtx, call.SessionID, message.CreateMessageParams{
202 Role: message.Assistant,
203 Parts: []message.ContentPart{},
204 Model: a.largeModel.ModelCfg.Model,
205 Provider: a.largeModel.ModelCfg.Provider,
206 })
207 if err != nil {
208 return nil, err
209 }
210
211 currentAssistant = &assistantMessage
212
213 hookErr := a.executePromptSubmitHook(genCtx, &msg, len(msgs) == 0)
214 if hookErr != nil {
215 // Delete the assistant message
216 // use the ctx since this could be a cancellation
217 deleteErr := a.messages.Delete(ctx, currentAssistant.ID)
218 return nil, cmp.Or(deleteErr, hookErr)
219 }
220
221 history, files := a.preparePrompt(msgs, call.Attachments...)
222
223 startTime := time.Now()
224 a.eventPromptSent(call.SessionID)
225
226 var shouldSummarize bool
227 result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
228 Prompt: msg.ContentWithHookContext(),
229 Files: files,
230 Messages: history,
231 ProviderOptions: call.ProviderOptions,
232 MaxOutputTokens: &call.MaxOutputTokens,
233 TopP: call.TopP,
234 Temperature: call.Temperature,
235 PresencePenalty: call.PresencePenalty,
236 TopK: call.TopK,
237 FrequencyPenalty: call.FrequencyPenalty,
238 // Before each step create a new assistant message.
239 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
240 // only add new assistant message when its not the first step
241 if options.StepNumber != 0 {
242 var assistantMsg message.Message
243 assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
244 Role: message.Assistant,
245 Model: a.largeModel.ModelCfg.Model,
246 Provider: a.largeModel.ModelCfg.Provider,
247 })
248 currentAssistant = &assistantMsg
249 // create the message first so we show loading asap
250 if err != nil {
251 return callContext, prepared, err
252 }
253 }
254 prepared.Messages = options.Messages
255 // Reset all cached items.
256 for i := range prepared.Messages {
257 prepared.Messages[i].ProviderOptions = nil
258 }
259
260 queuedCalls, _ := a.messageQueue.Get(call.SessionID)
261 a.messageQueue.Del(call.SessionID)
262 for _, queued := range queuedCalls {
263 userMessage, createErr := a.createUserMessage(callContext, queued)
264 if createErr != nil {
265 return callContext, prepared, createErr
266 }
267
268 hookErr := a.executePromptSubmitHook(ctx, &msg, len(msgs) == 0)
269 if hookErr != nil {
270 return callContext, prepared, hookErr
271 }
272
273 prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
274 }
275
276 lastSystemRoleInx := 0
277 systemMessageUpdated := false
278 for i, msg := range prepared.Messages {
279 // Only add cache control to the last message.
280 if msg.Role == fantasy.MessageRoleSystem {
281 lastSystemRoleInx = i
282 } else if !systemMessageUpdated {
283 prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
284 systemMessageUpdated = true
285 }
286 // Than add cache control to the last 2 messages.
287 if i > len(prepared.Messages)-3 {
288 prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
289 }
290 }
291
292 if a.systemPromptPrefix != "" {
293 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
294 }
295
296 callContext = context.WithValue(callContext, tools.MessageIDContextKey, currentAssistant.ID)
297 return callContext, prepared, err
298 },
299 OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
300 currentAssistant.AppendReasoningContent(reasoning.Text)
301 return a.messages.Update(genCtx, *currentAssistant)
302 },
303 OnReasoningDelta: func(id string, text string) error {
304 currentAssistant.AppendReasoningContent(text)
305 return a.messages.Update(genCtx, *currentAssistant)
306 },
307 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
308 // handle anthropic signature
309 if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
310 if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
311 currentAssistant.AppendReasoningSignature(reasoning.Signature)
312 }
313 }
314 if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
315 if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
316 currentAssistant.AppendThoughtSignature(reasoning.Signature, reasoning.ToolID)
317 }
318 }
319 if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
320 if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
321 currentAssistant.SetReasoningResponsesData(reasoning)
322 }
323 }
324 currentAssistant.FinishThinking()
325 return a.messages.Update(genCtx, *currentAssistant)
326 },
327 OnTextDelta: func(id string, text string) error {
328 // Strip leading newline from initial text content. This is is
329 // particularly important in non-interactive mode where leading
330 // newlines are very visible.
331 if len(currentAssistant.Parts) == 0 {
332 text = strings.TrimPrefix(text, "\n")
333 }
334
335 currentAssistant.AppendContent(text)
336 return a.messages.Update(genCtx, *currentAssistant)
337 },
338 OnToolInputStart: func(id string, toolName string) error {
339 toolCall := message.ToolCall{
340 ID: id,
341 Name: toolName,
342 ProviderExecuted: false,
343 Finished: false,
344 }
345 currentAssistant.AddToolCall(toolCall)
346 return a.messages.Update(genCtx, *currentAssistant)
347 },
348 OnRetry: func(err *fantasy.ProviderError, delay time.Duration) {
349 // TODO: implement
350 },
351 OnToolCall: func(tc fantasy.ToolCallContent) error {
352 toolCall := message.ToolCall{
353 ID: tc.ToolCallID,
354 Name: tc.ToolName,
355 Input: tc.Input,
356 ProviderExecuted: false,
357 Finished: true,
358 }
359 currentAssistant.AddToolCall(toolCall)
360 return a.messages.Update(genCtx, *currentAssistant)
361 },
362 OnToolResult: func(result fantasy.ToolResultContent) error {
363 var resultContent string
364 isError := false
365 switch result.Result.GetType() {
366 case fantasy.ToolResultContentTypeText:
367 r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result)
368 if ok {
369 resultContent = r.Text
370 }
371 case fantasy.ToolResultContentTypeError:
372 r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result)
373 if ok {
374 isError = true
375 resultContent = r.Error.Error()
376 }
377 case fantasy.ToolResultContentTypeMedia:
378 // TODO: handle this message type
379 }
380 toolResult := message.ToolResult{
381 ToolCallID: result.ToolCallID,
382 Name: result.ToolName,
383 Content: resultContent,
384 IsError: isError,
385 Metadata: result.ClientMetadata,
386 }
387 _, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
388 Role: message.Tool,
389 Parts: []message.ContentPart{
390 toolResult,
391 },
392 })
393 if createMsgErr != nil {
394 return createMsgErr
395 }
396 return nil
397 },
398 OnStepFinish: func(stepResult fantasy.StepResult) error {
399 finishReason := message.FinishReasonUnknown
400 switch stepResult.FinishReason {
401 case fantasy.FinishReasonLength:
402 finishReason = message.FinishReasonMaxTokens
403 case fantasy.FinishReasonStop:
404 finishReason = message.FinishReasonEndTurn
405 case fantasy.FinishReasonToolCalls:
406 finishReason = message.FinishReasonToolUse
407 }
408 currentAssistant.AddFinish(finishReason, "", "")
409 a.updateSessionUsage(a.largeModel, ¤tSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
410 sessionLock.Lock()
411 _, sessionErr := a.sessions.Save(genCtx, currentSession)
412 sessionLock.Unlock()
413 if sessionErr != nil {
414 return sessionErr
415 }
416 return a.messages.Update(genCtx, *currentAssistant)
417 },
418 StopWhen: []fantasy.StopCondition{
419 func(_ []fantasy.StepResult) bool {
420 cw := int64(a.largeModel.CatwalkCfg.ContextWindow)
421 tokens := currentSession.CompletionTokens + currentSession.PromptTokens
422 remaining := cw - tokens
423 var threshold int64
424 if cw > 200_000 {
425 threshold = 20_000
426 } else {
427 threshold = int64(float64(cw) * 0.2)
428 }
429 if (remaining <= threshold) && !a.disableAutoSummarize {
430 shouldSummarize = true
431 return true
432 }
433 return false
434 },
435 },
436 })
437
438 a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))
439
440 if err != nil {
441 isCancelErr := errors.Is(err, context.Canceled)
442 isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
443 if currentAssistant == nil {
444 return result, err
445 }
446 // Ensure we finish thinking on error to close the reasoning state.
447 currentAssistant.FinishThinking()
448 toolCalls := currentAssistant.ToolCalls()
449 // INFO: we use the parent context here because the genCtx has been cancelled.
450 msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
451 if createErr != nil {
452 return nil, createErr
453 }
454 for _, tc := range toolCalls {
455 if !tc.Finished {
456 tc.Finished = true
457 tc.Input = "{}"
458 currentAssistant.AddToolCall(tc)
459 updateErr := a.messages.Update(ctx, *currentAssistant)
460 if updateErr != nil {
461 return nil, updateErr
462 }
463 }
464
465 found := false
466 for _, msg := range msgs {
467 if msg.Role == message.Tool {
468 for _, tr := range msg.ToolResults() {
469 if tr.ToolCallID == tc.ID {
470 found = true
471 break
472 }
473 }
474 }
475 if found {
476 break
477 }
478 }
479 if found {
480 continue
481 }
482 content := "There was an error while executing the tool"
483 if isCancelErr {
484 content = "Tool execution canceled by user"
485 } else if isPermissionErr {
486 content = "User denied permission"
487 }
488 toolResult := message.ToolResult{
489 ToolCallID: tc.ID,
490 Name: tc.Name,
491 Content: content,
492 IsError: true,
493 }
494 _, createErr = a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
495 Role: message.Tool,
496 Parts: []message.ContentPart{
497 toolResult,
498 },
499 })
500 if createErr != nil {
501 return nil, createErr
502 }
503 }
504 var fantasyErr *fantasy.Error
505 var providerErr *fantasy.ProviderError
506 const defaultTitle = "Provider Error"
507 if isCancelErr {
508 currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
509 } else if isPermissionErr {
510 currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "")
511 } else if errors.As(err, &providerErr) {
512 currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message)
513 } else if errors.As(err, &fantasyErr) {
514 currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(fantasyErr.Title), defaultTitle), fantasyErr.Message)
515 } else {
516 currentAssistant.AddFinish(message.FinishReasonError, defaultTitle, err.Error())
517 }
518 // Note: we use the parent context here because the genCtx has been
519 // cancelled.
520 updateErr := a.messages.Update(ctx, *currentAssistant)
521 if updateErr != nil {
522 return nil, updateErr
523 }
524 return nil, err
525 }
526 wg.Wait()
527
528 if shouldSummarize {
529 a.activeRequests.Del(call.SessionID)
530 if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
531 return nil, summarizeErr
532 }
533 // If the agent wasn't done...
534 if len(currentAssistant.ToolCalls()) > 0 {
535 existing, ok := a.messageQueue.Get(call.SessionID)
536 if !ok {
537 existing = []SessionAgentCall{}
538 }
539 call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
540 existing = append(existing, call)
541 a.messageQueue.Set(call.SessionID, existing)
542 }
543 }
544
545 // Release active request before processing queued messages.
546 a.activeRequests.Del(call.SessionID)
547 cancel()
548
549 queuedMessages, ok := a.messageQueue.Get(call.SessionID)
550 if !ok || len(queuedMessages) == 0 {
551 return result, err
552 }
553 // There are queued messages restart the loop.
554 firstQueuedMessage := queuedMessages[0]
555 a.messageQueue.Set(call.SessionID, queuedMessages[1:])
556 return a.Run(ctx, firstQueuedMessage)
557}
558
559func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
560 if a.IsSessionBusy(sessionID) {
561 return ErrSessionBusy
562 }
563
564 currentSession, err := a.sessions.Get(ctx, sessionID)
565 if err != nil {
566 return fmt.Errorf("failed to get session: %w", err)
567 }
568 msgs, err := a.getSessionMessages(ctx, currentSession)
569 if err != nil {
570 return err
571 }
572 if len(msgs) == 0 {
573 // Nothing to summarize.
574 return nil
575 }
576
577 aiMsgs, _ := a.preparePrompt(msgs)
578
579 genCtx, cancel := context.WithCancel(ctx)
580 a.activeRequests.Set(sessionID, cancel)
581 defer a.activeRequests.Del(sessionID)
582 defer cancel()
583
584 agent := fantasy.NewAgent(a.largeModel.Model,
585 fantasy.WithSystemPrompt(string(summaryPrompt)),
586 )
587 summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
588 Role: message.Assistant,
589 Model: a.largeModel.Model.Model(),
590 Provider: a.largeModel.Model.Provider(),
591 IsSummaryMessage: true,
592 })
593 if err != nil {
594 return err
595 }
596
597 resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
598 Prompt: "Provide a detailed summary of our conversation above.",
599 Messages: aiMsgs,
600 ProviderOptions: opts,
601 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
602 prepared.Messages = options.Messages
603 if a.systemPromptPrefix != "" {
604 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
605 }
606 return callContext, prepared, nil
607 },
608 OnReasoningDelta: func(id string, text string) error {
609 summaryMessage.AppendReasoningContent(text)
610 return a.messages.Update(genCtx, summaryMessage)
611 },
612 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
613 // Handle anthropic signature.
614 if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
615 if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
616 summaryMessage.AppendReasoningSignature(signature.Signature)
617 }
618 }
619 summaryMessage.FinishThinking()
620 return a.messages.Update(genCtx, summaryMessage)
621 },
622 OnTextDelta: func(id, text string) error {
623 summaryMessage.AppendContent(text)
624 return a.messages.Update(genCtx, summaryMessage)
625 },
626 })
627 if err != nil {
628 isCancelErr := errors.Is(err, context.Canceled)
629 if isCancelErr {
630 // User cancelled summarize we need to remove the summary message.
631 deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
632 return deleteErr
633 }
634 return err
635 }
636
637 summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
638 err = a.messages.Update(genCtx, summaryMessage)
639 if err != nil {
640 return err
641 }
642
643 var openrouterCost *float64
644 for _, step := range resp.Steps {
645 stepCost := a.openrouterCost(step.ProviderMetadata)
646 if stepCost != nil {
647 newCost := *stepCost
648 if openrouterCost != nil {
649 newCost += *openrouterCost
650 }
651 openrouterCost = &newCost
652 }
653 }
654
655 a.updateSessionUsage(a.largeModel, ¤tSession, resp.TotalUsage, openrouterCost)
656
657 // Just in case, get just the last usage info.
658 usage := resp.Response.Usage
659 currentSession.SummaryMessageID = summaryMessage.ID
660 currentSession.CompletionTokens = usage.OutputTokens
661 currentSession.PromptTokens = 0
662 _, err = a.sessions.Save(genCtx, currentSession)
663 return err
664}
665
666func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
667 if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
668 return fantasy.ProviderOptions{}
669 }
670 return fantasy.ProviderOptions{
671 anthropic.Name: &anthropic.ProviderCacheControlOptions{
672 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
673 },
674 bedrock.Name: &anthropic.ProviderCacheControlOptions{
675 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
676 },
677 }
678}
679
680func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
681 var attachmentParts []message.ContentPart
682 for _, attachment := range call.Attachments {
683 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
684 }
685 parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
686 parts = append(parts, attachmentParts...)
687 msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
688 Role: message.User,
689 Parts: parts,
690 })
691 if err != nil {
692 return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
693 }
694 return msg, nil
695}
696
697func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
698 var history []fantasy.Message
699 for _, m := range msgs {
700 if len(m.Parts) == 0 {
701 continue
702 }
703 // Assistant message without content or tool calls (cancelled before it
704 // returned anything).
705 if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
706 continue
707 }
708 history = append(history, m.ToAIMessage()...)
709 }
710
711 var files []fantasy.FilePart
712 for _, attachment := range attachments {
713 files = append(files, fantasy.FilePart{
714 Filename: attachment.FileName,
715 Data: attachment.Content,
716 MediaType: attachment.MimeType,
717 })
718 }
719
720 return history, files
721}
722
723func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
724 msgs, err := a.messages.List(ctx, session.ID)
725 if err != nil {
726 return nil, fmt.Errorf("failed to list messages: %w", err)
727 }
728
729 if session.SummaryMessageID != "" {
730 summaryMsgInex := -1
731 for i, msg := range msgs {
732 if msg.ID == session.SummaryMessageID {
733 summaryMsgInex = i
734 break
735 }
736 }
737 if summaryMsgInex != -1 {
738 msgs = msgs[summaryMsgInex:]
739 msgs[0].Role = message.User
740 }
741 }
742 return msgs, nil
743}
744
745func (a *sessionAgent) generateTitle(ctx context.Context, session *session.Session, prompt string) {
746 if prompt == "" {
747 return
748 }
749
750 var maxOutput int64 = 40
751 if a.smallModel.CatwalkCfg.CanReason {
752 maxOutput = a.smallModel.CatwalkCfg.DefaultMaxTokens
753 }
754
755 agent := fantasy.NewAgent(a.smallModel.Model,
756 fantasy.WithSystemPrompt(string(titlePrompt)+"\n /no_think"),
757 fantasy.WithMaxOutputTokens(maxOutput),
758 )
759
760 resp, err := agent.Stream(ctx, fantasy.AgentStreamCall{
761 Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", prompt),
762 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
763 prepared.Messages = options.Messages
764 if a.systemPromptPrefix != "" {
765 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
766 }
767 return callContext, prepared, nil
768 },
769 })
770 if err != nil {
771 slog.Error("error generating title", "err", err)
772 return
773 }
774
775 title := resp.Response.Content.Text()
776
777 title = strings.ReplaceAll(title, "\n", " ")
778
779 // Remove thinking tags if present.
780 if idx := strings.Index(title, "</think>"); idx > 0 {
781 title = title[idx+len("</think>"):]
782 }
783
784 title = strings.TrimSpace(title)
785 if title == "" {
786 slog.Warn("failed to generate title", "warn", "empty title")
787 return
788 }
789
790 session.Title = title
791
792 var openrouterCost *float64
793 for _, step := range resp.Steps {
794 stepCost := a.openrouterCost(step.ProviderMetadata)
795 if stepCost != nil {
796 newCost := *stepCost
797 if openrouterCost != nil {
798 newCost += *openrouterCost
799 }
800 openrouterCost = &newCost
801 }
802 }
803
804 a.updateSessionUsage(a.smallModel, session, resp.TotalUsage, openrouterCost)
805 _, saveErr := a.sessions.Save(ctx, *session)
806 if saveErr != nil {
807 slog.Error("failed to save session title & usage", "error", saveErr)
808 return
809 }
810}
811
812func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
813 openrouterMetadata, ok := metadata[openrouter.Name]
814 if !ok {
815 return nil
816 }
817
818 opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
819 if !ok {
820 return nil
821 }
822 return &opts.Usage.Cost
823}
824
825func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
826 modelConfig := model.CatwalkCfg
827 cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
828 modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
829 modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
830 modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
831
832 a.eventTokensUsed(session.ID, model, usage, cost)
833
834 if overrideCost != nil {
835 session.Cost += *overrideCost
836 } else {
837 session.Cost += cost
838 }
839
840 session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
841 session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
842}
843
844func (a *sessionAgent) Cancel(sessionID string) {
845 // Cancel regular requests.
846 if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
847 slog.Info("Request cancellation initiated", "session_id", sessionID)
848 cancel()
849 }
850
851 // Also check for summarize requests.
852 if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
853 slog.Info("Summarize cancellation initiated", "session_id", sessionID)
854 cancel()
855 }
856
857 if a.QueuedPrompts(sessionID) > 0 {
858 slog.Info("Clearing queued prompts", "session_id", sessionID)
859 a.messageQueue.Del(sessionID)
860 }
861}
862
863func (a *sessionAgent) ClearQueue(sessionID string) {
864 if a.QueuedPrompts(sessionID) > 0 {
865 slog.Info("Clearing queued prompts", "session_id", sessionID)
866 a.messageQueue.Del(sessionID)
867 }
868}
869
870func (a *sessionAgent) CancelAll() {
871 if !a.IsBusy() {
872 return
873 }
874 for key := range a.activeRequests.Seq2() {
875 a.Cancel(key) // key is sessionID
876 }
877
878 timeout := time.After(5 * time.Second)
879 for a.IsBusy() {
880 select {
881 case <-timeout:
882 return
883 default:
884 time.Sleep(200 * time.Millisecond)
885 }
886 }
887}
888
889func (a *sessionAgent) IsBusy() bool {
890 var busy bool
891 for cancelFunc := range a.activeRequests.Seq() {
892 if cancelFunc != nil {
893 busy = true
894 break
895 }
896 }
897 return busy
898}
899
900func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
901 _, busy := a.activeRequests.Get(sessionID)
902 return busy
903}
904
905func (a *sessionAgent) QueuedPrompts(sessionID string) int {
906 l, ok := a.messageQueue.Get(sessionID)
907 if !ok {
908 return 0
909 }
910 return len(l)
911}
912
913func (a *sessionAgent) SetModels(large Model, small Model) {
914 a.largeModel = large
915 a.smallModel = small
916}
917
// SetTools replaces the tools offered to the model on subsequent Run calls.
// The slice is stored as given, not copied; Run sets provider options on its
// last element.
func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
	a.tools = tools
}
921
// Model returns the large model, the one used for conversations and
// summaries.
func (a *sessionAgent) Model() Model {
	return a.largeModel
}
925
926// executePromptSubmitHook executes the user-prompt-submit hook and applies modifications to the call.
927// Only runs for main agent (not sub-agents).
928func (a *sessionAgent) executePromptSubmitHook(ctx context.Context, msg *message.Message, isFirstMessage bool) error {
929 // Skip if sub-agent or no hooks manager.
930 if a.isSubAgent || a.hooksManager == nil {
931 return nil
932 }
933
934 // Convert attachments to file paths.
935 attachmentPaths := make([]string, len(msg.BinaryContent()))
936 for i, att := range msg.BinaryContent() {
937 attachmentPaths[i] = att.Path
938 }
939
940 hookResult, err := a.hooksManager.ExecuteUserPromptSubmit(ctx, msg.SessionID, a.workingDir, hooks.UserPromptSubmitData{
941 Prompt: msg.Content().Text,
942 Attachments: attachmentPaths,
943 Model: a.largeModel.CatwalkCfg.ID,
944 Provider: a.largeModel.Model.Provider(),
945 IsFirstMessage: isFirstMessage,
946 })
947 if err != nil {
948 return fmt.Errorf("hook execution failed: %w", err)
949 }
950
951 // Apply hook modifications to the prompt.
952 if hookResult.ModifiedPrompt != nil {
953 for i, part := range msg.Parts {
954 if _, ok := part.(message.TextContent); ok {
955 msg.Parts[i] = message.TextContent{Text: *hookResult.ModifiedPrompt}
956 }
957 }
958 }
959 msg.AddHookResult(hookResult)
960 err = a.messages.Update(ctx, *msg)
961 if err != nil {
962 return err
963 }
964 // If hook returned Continue: false, stop execution.
965 if !hookResult.Continue {
966 return ErrHookExecutionStop
967 }
968 return nil
969}