1// Package agent is the core orchestration layer for Crush AI agents.
2//
3// It provides session-based AI agent functionality for managing
4// conversations, tool execution, and message handling. It coordinates
5// interactions between language models, messages, sessions, and tools while
6// handling features like automatic summarization, queuing, and token
7// management.
8package agent
9
10import (
11 "cmp"
12 "context"
13 _ "embed"
14 "errors"
15 "fmt"
16 "log/slog"
17 "os"
18 "strconv"
19 "strings"
20 "sync"
21 "time"
22
23 "charm.land/fantasy"
24 "charm.land/fantasy/providers/anthropic"
25 "charm.land/fantasy/providers/bedrock"
26 "charm.land/fantasy/providers/google"
27 "charm.land/fantasy/providers/openai"
28 "charm.land/fantasy/providers/openrouter"
29 "github.com/charmbracelet/catwalk/pkg/catwalk"
30 "github.com/charmbracelet/crush/internal/agent/tools"
31 "github.com/charmbracelet/crush/internal/config"
32 "github.com/charmbracelet/crush/internal/csync"
33 "github.com/charmbracelet/crush/internal/hooks"
34 "github.com/charmbracelet/crush/internal/message"
35 "github.com/charmbracelet/crush/internal/permission"
36 "github.com/charmbracelet/crush/internal/session"
37 "github.com/charmbracelet/crush/internal/stringext"
38)
39
// titlePrompt is the system prompt used by generateTitle. Its contents are
// embedded at build time from templates/title.md.
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt is the system prompt used by Summarize. Its contents are
// embedded at build time from templates/summary.md.
//
//go:embed templates/summary.md
var summaryPrompt []byte
45
// SessionAgentCall describes one prompt submitted to a SessionAgent.
//
// SessionID and Prompt are required: Run returns ErrSessionMissing or
// ErrEmptyPrompt when either is empty. The sampling fields are forwarded to
// the model stream unchanged; nil pointers are passed through as nil.
type SessionAgentCall struct {
	// SessionID identifies the session the prompt belongs to.
	SessionID string
	// Prompt is the user's text.
	Prompt string
	// ProviderOptions are forwarded to the provider for this request (and to
	// Summarize if auto-summarization triggers).
	ProviderOptions fantasy.ProviderOptions
	// Attachments are stored on the user message as binary parts and sent to
	// the model as file parts.
	Attachments []message.Attachment
	// MaxOutputTokens is always sent as a pointer to this value, even when 0.
	MaxOutputTokens  int64
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
}
58
// SessionAgent runs prompts against a language model on behalf of chat
// sessions. It queues prompts for sessions that already have a request in
// flight.
type SessionAgent interface {
	// Run processes a prompt for call.SessionID. In the implementation in
	// this file, a prompt for a busy session is queued and Run returns
	// (nil, nil).
	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
	// SetModels replaces the large model (used for conversations and
	// summaries) and the small model (used for titles).
	SetModels(large Model, small Model)
	// SetTools replaces the tools offered to the model.
	SetTools(tools []fantasy.AgentTool)
	// Cancel aborts the in-flight request for a session and drops its queued
	// prompts.
	Cancel(sessionID string)
	// CancelAll cancels every active request, then waits (bounded by a
	// timeout) for the agent to become idle.
	CancelAll()
	// IsSessionBusy reports whether a request is registered for the session.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any session has an active request.
	IsBusy() bool
	// QueuedPrompts returns how many prompts are waiting for a session.
	QueuedPrompts(sessionID string) int
	// ClearQueue drops all prompts waiting for a session.
	ClearQueue(sessionID string)
	// Summarize condenses a session's history into a summary message.
	Summarize(context.Context, string, fantasy.ProviderOptions) error
	// Model returns the large model.
	Model() Model
}
72
// Model bundles a language model with the catalog and configuration metadata
// the agent needs to use it.
type Model struct {
	// Model is the fantasy language model used for generation.
	Model fantasy.LanguageModel
	// CatwalkCfg is catalog metadata. This file reads the model ID, context
	// window, per-million-token prices, CanReason and DefaultMaxTokens.
	CatwalkCfg catwalk.Model
	// ModelCfg is the selected model configuration. Its Model and Provider
	// fields are recorded on assistant messages.
	ModelCfg config.SelectedModel
}
78
// sessionAgent is the SessionAgent implementation returned by
// NewSessionAgent.
type sessionAgent struct {
	// largeModel drives conversation turns and summaries.
	largeModel Model
	// smallModel generates session titles.
	smallModel Model
	// systemPromptPrefix, when non-empty, is prepended as an extra system
	// message on every request (conversation, summary and title).
	systemPromptPrefix string
	// systemPrompt is the system prompt for conversation turns.
	systemPrompt string
	// tools are exposed to the model by Run. Run sets cache-control provider
	// options on the last one.
	tools    []fantasy.AgentTool
	sessions session.Service
	messages message.Service
	// disableAutoSummarize prevents Run from stopping and summarizing when the
	// context window runs low.
	disableAutoSummarize bool
	// NOTE(review): isYolo is stored but not read by any method in this file.
	isYolo bool
	// isSubAgent disables the user-prompt-submit hook.
	isSubAgent bool
	// hooksManager runs user hooks; nil disables them.
	hooksManager hooks.Manager
	// workingDir is passed to the hooks manager when executing hooks.
	workingDir string

	// messageQueue holds prompts submitted while their session was busy,
	// keyed by session ID.
	messageQueue *csync.Map[string, []SessionAgentCall]
	// activeRequests maps a session ID to the cancel func of its in-flight
	// request; presence of a key is what makes a session "busy".
	activeRequests *csync.Map[string, context.CancelFunc]
}
96
// SessionAgentOptions configures NewSessionAgent. Each field is copied into
// the sessionAgent field of the same name.
type SessionAgentOptions struct {
	// LargeModel drives conversations and summaries; SmallModel generates
	// titles.
	LargeModel Model
	SmallModel Model
	// SystemPromptPrefix, when non-empty, is sent as an extra leading system
	// message on every request.
	SystemPromptPrefix string
	// SystemPrompt is the system prompt for conversation turns.
	SystemPrompt string
	// DisableAutoSummarize turns off summarization when the context window
	// runs low.
	DisableAutoSummarize bool
	IsYolo               bool
	// IsSubAgent disables the user-prompt-submit hook.
	IsSubAgent bool
	// HooksManager runs user hooks; nil disables them.
	HooksManager hooks.Manager
	// WorkingDir is passed to the hooks manager when executing hooks.
	WorkingDir string
	Sessions   session.Service
	Messages   message.Service
	// Tools is stored without copying; Run mutates the provider options of
	// its last element.
	Tools []fantasy.AgentTool
}
111
112func NewSessionAgent(
113 opts SessionAgentOptions,
114) SessionAgent {
115 return &sessionAgent{
116 largeModel: opts.LargeModel,
117 smallModel: opts.SmallModel,
118 systemPromptPrefix: opts.SystemPromptPrefix,
119 systemPrompt: opts.SystemPrompt,
120 sessions: opts.Sessions,
121 messages: opts.Messages,
122 disableAutoSummarize: opts.DisableAutoSummarize,
123 tools: opts.Tools,
124 isYolo: opts.IsYolo,
125 isSubAgent: opts.IsSubAgent,
126 hooksManager: opts.HooksManager,
127 workingDir: opts.WorkingDir,
128 messageQueue: csync.NewMap[string, []SessionAgentCall](),
129 activeRequests: csync.NewMap[string, context.CancelFunc](),
130 }
131}
132
133func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
134 if call.Prompt == "" {
135 return nil, ErrEmptyPrompt
136 }
137 if call.SessionID == "" {
138 return nil, ErrSessionMissing
139 }
140
141 // Queue the message if busy
142 if a.IsSessionBusy(call.SessionID) {
143 existing, ok := a.messageQueue.Get(call.SessionID)
144 if !ok {
145 existing = []SessionAgentCall{}
146 }
147 existing = append(existing, call)
148 a.messageQueue.Set(call.SessionID, existing)
149 return nil, nil
150 }
151
152 if len(a.tools) > 0 {
153 // Add Anthropic caching to the last tool.
154 a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
155 }
156
157 agent := fantasy.NewAgent(
158 a.largeModel.Model,
159 fantasy.WithSystemPrompt(a.systemPrompt),
160 fantasy.WithTools(a.tools...),
161 )
162
163 sessionLock := sync.Mutex{}
164 currentSession, err := a.sessions.Get(ctx, call.SessionID)
165 if err != nil {
166 return nil, fmt.Errorf("failed to get session: %w", err)
167 }
168
169 msgs, err := a.getSessionMessages(ctx, currentSession)
170 if err != nil {
171 return nil, fmt.Errorf("failed to get session messages: %w", err)
172 }
173
174 var wg sync.WaitGroup
175 // Generate title if first message.
176 if len(msgs) == 0 {
177 wg.Go(func() {
178 sessionLock.Lock()
179 a.generateTitle(ctx, ¤tSession, call.Prompt)
180 sessionLock.Unlock()
181 })
182 }
183
184 // Add the user message to the session.
185 msg, err := a.createUserMessage(ctx, call)
186 if err != nil {
187 return nil, err
188 }
189
190 // Add the session to the context.
191 ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)
192
193 genCtx, cancel := context.WithCancel(ctx)
194 a.activeRequests.Set(call.SessionID, cancel)
195
196 defer cancel()
197 defer a.activeRequests.Del(call.SessionID)
198
199 // create the agent message asap to show loading
200 var currentAssistant *message.Message
201 assistantMessage, err := a.messages.Create(genCtx, call.SessionID, message.CreateMessageParams{
202 Role: message.Assistant,
203 Parts: []message.ContentPart{},
204 Model: a.largeModel.ModelCfg.Model,
205 Provider: a.largeModel.ModelCfg.Provider,
206 })
207 if err != nil {
208 return nil, err
209 }
210
211 currentAssistant = &assistantMessage
212
213 hookErr := a.executePromptSubmitHook(genCtx, &msg, len(msgs) == 0)
214 if hookErr != nil {
215 // Delete the assistant message
216 // use the ctx since this could be a cancellation
217 deleteErr := a.messages.Delete(ctx, currentAssistant.ID)
218 return nil, cmp.Or(deleteErr, hookErr)
219 }
220
221 history, files := a.preparePrompt(msgs, call.Attachments...)
222
223 startTime := time.Now()
224 a.eventPromptSent(call.SessionID)
225
226 var shouldSummarize bool
227 result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
228 Prompt: msg.ContentWithHookContext(),
229 Files: files,
230 Messages: history,
231 ProviderOptions: call.ProviderOptions,
232 MaxOutputTokens: &call.MaxOutputTokens,
233 TopP: call.TopP,
234 Temperature: call.Temperature,
235 PresencePenalty: call.PresencePenalty,
236 TopK: call.TopK,
237 FrequencyPenalty: call.FrequencyPenalty,
238 // Before each step create a new assistant message.
239 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
240 // only add new assistant message when its not the first step
241 if options.StepNumber != 0 {
242 var assistantMsg message.Message
243 assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
244 Role: message.Assistant,
245 Parts: []message.ContentPart{},
246 Model: a.largeModel.ModelCfg.Model,
247 Provider: a.largeModel.ModelCfg.Provider,
248 })
249 currentAssistant = &assistantMsg
250 // create the message first so we show loading asap
251 if err != nil {
252 return callContext, prepared, err
253 }
254 }
255 prepared.Messages = options.Messages
256 // Reset all cached items.
257 for i := range prepared.Messages {
258 prepared.Messages[i].ProviderOptions = nil
259 }
260
261 queuedCalls, _ := a.messageQueue.Get(call.SessionID)
262 a.messageQueue.Del(call.SessionID)
263 for _, queued := range queuedCalls {
264 userMessage, createErr := a.createUserMessage(callContext, queued)
265 if createErr != nil {
266 return callContext, prepared, createErr
267 }
268
269 hookErr := a.executePromptSubmitHook(ctx, &msg, len(msgs) == 0)
270 if hookErr != nil {
271 return callContext, prepared, hookErr
272 }
273
274 prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
275 }
276
277 lastSystemRoleInx := 0
278 systemMessageUpdated := false
279 for i, msg := range prepared.Messages {
280 // Only add cache control to the last message.
281 if msg.Role == fantasy.MessageRoleSystem {
282 lastSystemRoleInx = i
283 } else if !systemMessageUpdated {
284 prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
285 systemMessageUpdated = true
286 }
287 // Than add cache control to the last 2 messages.
288 if i > len(prepared.Messages)-3 {
289 prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
290 }
291 }
292
293 if a.systemPromptPrefix != "" {
294 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
295 }
296
297 callContext = context.WithValue(callContext, tools.MessageIDContextKey, currentAssistant.ID)
298 return callContext, prepared, err
299 },
300 OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
301 currentAssistant.AppendReasoningContent(reasoning.Text)
302 return a.messages.Update(genCtx, *currentAssistant)
303 },
304 OnReasoningDelta: func(id string, text string) error {
305 currentAssistant.AppendReasoningContent(text)
306 return a.messages.Update(genCtx, *currentAssistant)
307 },
308 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
309 // handle anthropic signature
310 if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
311 if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
312 currentAssistant.AppendReasoningSignature(reasoning.Signature)
313 }
314 }
315 if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
316 if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
317 currentAssistant.AppendThoughtSignature(reasoning.Signature, reasoning.ToolID)
318 }
319 }
320 if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
321 if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
322 currentAssistant.SetReasoningResponsesData(reasoning)
323 }
324 }
325 currentAssistant.FinishThinking()
326 return a.messages.Update(genCtx, *currentAssistant)
327 },
328 OnTextDelta: func(id string, text string) error {
329 // Strip leading newline from initial text content. This is is
330 // particularly important in non-interactive mode where leading
331 // newlines are very visible.
332 if len(currentAssistant.Parts) == 0 {
333 text = strings.TrimPrefix(text, "\n")
334 }
335
336 currentAssistant.AppendContent(text)
337 return a.messages.Update(genCtx, *currentAssistant)
338 },
339 OnToolInputStart: func(id string, toolName string) error {
340 toolCall := message.ToolCall{
341 ID: id,
342 Name: toolName,
343 ProviderExecuted: false,
344 Finished: false,
345 }
346 currentAssistant.AddToolCall(toolCall)
347 return a.messages.Update(genCtx, *currentAssistant)
348 },
349 OnRetry: func(err *fantasy.ProviderError, delay time.Duration) {
350 // TODO: implement
351 },
352 OnToolCall: func(tc fantasy.ToolCallContent) error {
353 toolCall := message.ToolCall{
354 ID: tc.ToolCallID,
355 Name: tc.ToolName,
356 Input: tc.Input,
357 ProviderExecuted: false,
358 Finished: true,
359 }
360 currentAssistant.AddToolCall(toolCall)
361 return a.messages.Update(genCtx, *currentAssistant)
362 },
363 OnToolResult: func(result fantasy.ToolResultContent) error {
364 var resultContent string
365 isError := false
366 switch result.Result.GetType() {
367 case fantasy.ToolResultContentTypeText:
368 r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result)
369 if ok {
370 resultContent = r.Text
371 }
372 case fantasy.ToolResultContentTypeError:
373 r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result)
374 if ok {
375 isError = true
376 resultContent = r.Error.Error()
377 }
378 case fantasy.ToolResultContentTypeMedia:
379 // TODO: handle this message type
380 }
381 toolResult := message.ToolResult{
382 ToolCallID: result.ToolCallID,
383 Name: result.ToolName,
384 Content: resultContent,
385 IsError: isError,
386 Metadata: result.ClientMetadata,
387 }
388 _, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
389 Role: message.Tool,
390 Parts: []message.ContentPart{
391 toolResult,
392 },
393 })
394 if createMsgErr != nil {
395 return createMsgErr
396 }
397 return nil
398 },
399 OnStepFinish: func(stepResult fantasy.StepResult) error {
400 finishReason := message.FinishReasonUnknown
401 switch stepResult.FinishReason {
402 case fantasy.FinishReasonLength:
403 finishReason = message.FinishReasonMaxTokens
404 case fantasy.FinishReasonStop:
405 finishReason = message.FinishReasonEndTurn
406 case fantasy.FinishReasonToolCalls:
407 finishReason = message.FinishReasonToolUse
408 }
409 currentAssistant.AddFinish(finishReason, "", "")
410 a.updateSessionUsage(a.largeModel, ¤tSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
411 sessionLock.Lock()
412 _, sessionErr := a.sessions.Save(genCtx, currentSession)
413 sessionLock.Unlock()
414 if sessionErr != nil {
415 return sessionErr
416 }
417 return a.messages.Update(genCtx, *currentAssistant)
418 },
419 StopWhen: []fantasy.StopCondition{
420 func(_ []fantasy.StepResult) bool {
421 cw := int64(a.largeModel.CatwalkCfg.ContextWindow)
422 tokens := currentSession.CompletionTokens + currentSession.PromptTokens
423 remaining := cw - tokens
424 var threshold int64
425 if cw > 200_000 {
426 threshold = 20_000
427 } else {
428 threshold = int64(float64(cw) * 0.2)
429 }
430 if (remaining <= threshold) && !a.disableAutoSummarize {
431 shouldSummarize = true
432 return true
433 }
434 return false
435 },
436 },
437 })
438
439 a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))
440
441 if err != nil {
442 isCancelErr := errors.Is(err, context.Canceled)
443 isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
444 if currentAssistant == nil {
445 return result, err
446 }
447 // Ensure we finish thinking on error to close the reasoning state.
448 currentAssistant.FinishThinking()
449 toolCalls := currentAssistant.ToolCalls()
450 // INFO: we use the parent context here because the genCtx has been cancelled.
451 msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
452 if createErr != nil {
453 return nil, createErr
454 }
455 for _, tc := range toolCalls {
456 if !tc.Finished {
457 tc.Finished = true
458 tc.Input = "{}"
459 currentAssistant.AddToolCall(tc)
460 updateErr := a.messages.Update(ctx, *currentAssistant)
461 if updateErr != nil {
462 return nil, updateErr
463 }
464 }
465
466 found := false
467 for _, msg := range msgs {
468 if msg.Role == message.Tool {
469 for _, tr := range msg.ToolResults() {
470 if tr.ToolCallID == tc.ID {
471 found = true
472 break
473 }
474 }
475 }
476 if found {
477 break
478 }
479 }
480 if found {
481 continue
482 }
483 content := "There was an error while executing the tool"
484 if isCancelErr {
485 content = "Tool execution canceled by user"
486 } else if isPermissionErr {
487 content = "User denied permission"
488 }
489 toolResult := message.ToolResult{
490 ToolCallID: tc.ID,
491 Name: tc.Name,
492 Content: content,
493 IsError: true,
494 }
495 _, createErr = a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
496 Role: message.Tool,
497 Parts: []message.ContentPart{
498 toolResult,
499 },
500 })
501 if createErr != nil {
502 return nil, createErr
503 }
504 }
505 var fantasyErr *fantasy.Error
506 var providerErr *fantasy.ProviderError
507 const defaultTitle = "Provider Error"
508 if isCancelErr {
509 currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
510 } else if isPermissionErr {
511 currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "")
512 } else if errors.As(err, &providerErr) {
513 currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message)
514 } else if errors.As(err, &fantasyErr) {
515 currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(fantasyErr.Title), defaultTitle), fantasyErr.Message)
516 } else {
517 currentAssistant.AddFinish(message.FinishReasonError, defaultTitle, err.Error())
518 }
519 // Note: we use the parent context here because the genCtx has been
520 // cancelled.
521 updateErr := a.messages.Update(ctx, *currentAssistant)
522 if updateErr != nil {
523 return nil, updateErr
524 }
525 return nil, err
526 }
527 wg.Wait()
528
529 if shouldSummarize {
530 a.activeRequests.Del(call.SessionID)
531 if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
532 return nil, summarizeErr
533 }
534 // If the agent wasn't done...
535 if len(currentAssistant.ToolCalls()) > 0 {
536 existing, ok := a.messageQueue.Get(call.SessionID)
537 if !ok {
538 existing = []SessionAgentCall{}
539 }
540 call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
541 existing = append(existing, call)
542 a.messageQueue.Set(call.SessionID, existing)
543 }
544 }
545
546 // Release active request before processing queued messages.
547 a.activeRequests.Del(call.SessionID)
548 cancel()
549
550 queuedMessages, ok := a.messageQueue.Get(call.SessionID)
551 if !ok || len(queuedMessages) == 0 {
552 return result, err
553 }
554 // There are queued messages restart the loop.
555 firstQueuedMessage := queuedMessages[0]
556 a.messageQueue.Set(call.SessionID, queuedMessages[1:])
557 return a.Run(ctx, firstQueuedMessage)
558}
559
560func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
561 if a.IsSessionBusy(sessionID) {
562 return ErrSessionBusy
563 }
564
565 currentSession, err := a.sessions.Get(ctx, sessionID)
566 if err != nil {
567 return fmt.Errorf("failed to get session: %w", err)
568 }
569 msgs, err := a.getSessionMessages(ctx, currentSession)
570 if err != nil {
571 return err
572 }
573 if len(msgs) == 0 {
574 // Nothing to summarize.
575 return nil
576 }
577
578 aiMsgs, _ := a.preparePrompt(msgs)
579
580 genCtx, cancel := context.WithCancel(ctx)
581 a.activeRequests.Set(sessionID, cancel)
582 defer a.activeRequests.Del(sessionID)
583 defer cancel()
584
585 agent := fantasy.NewAgent(a.largeModel.Model,
586 fantasy.WithSystemPrompt(string(summaryPrompt)),
587 )
588 summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
589 Role: message.Assistant,
590 Model: a.largeModel.Model.Model(),
591 Provider: a.largeModel.Model.Provider(),
592 IsSummaryMessage: true,
593 })
594 if err != nil {
595 return err
596 }
597
598 resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
599 Prompt: "Provide a detailed summary of our conversation above.",
600 Messages: aiMsgs,
601 ProviderOptions: opts,
602 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
603 prepared.Messages = options.Messages
604 if a.systemPromptPrefix != "" {
605 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
606 }
607 return callContext, prepared, nil
608 },
609 OnReasoningDelta: func(id string, text string) error {
610 summaryMessage.AppendReasoningContent(text)
611 return a.messages.Update(genCtx, summaryMessage)
612 },
613 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
614 // Handle anthropic signature.
615 if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
616 if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
617 summaryMessage.AppendReasoningSignature(signature.Signature)
618 }
619 }
620 summaryMessage.FinishThinking()
621 return a.messages.Update(genCtx, summaryMessage)
622 },
623 OnTextDelta: func(id, text string) error {
624 summaryMessage.AppendContent(text)
625 return a.messages.Update(genCtx, summaryMessage)
626 },
627 })
628 if err != nil {
629 isCancelErr := errors.Is(err, context.Canceled)
630 if isCancelErr {
631 // User cancelled summarize we need to remove the summary message.
632 deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
633 return deleteErr
634 }
635 return err
636 }
637
638 summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
639 err = a.messages.Update(genCtx, summaryMessage)
640 if err != nil {
641 return err
642 }
643
644 var openrouterCost *float64
645 for _, step := range resp.Steps {
646 stepCost := a.openrouterCost(step.ProviderMetadata)
647 if stepCost != nil {
648 newCost := *stepCost
649 if openrouterCost != nil {
650 newCost += *openrouterCost
651 }
652 openrouterCost = &newCost
653 }
654 }
655
656 a.updateSessionUsage(a.largeModel, ¤tSession, resp.TotalUsage, openrouterCost)
657
658 // Just in case, get just the last usage info.
659 usage := resp.Response.Usage
660 currentSession.SummaryMessageID = summaryMessage.ID
661 currentSession.CompletionTokens = usage.OutputTokens
662 currentSession.PromptTokens = 0
663 _, err = a.sessions.Save(genCtx, currentSession)
664 return err
665}
666
667func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
668 if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
669 return fantasy.ProviderOptions{}
670 }
671 return fantasy.ProviderOptions{
672 anthropic.Name: &anthropic.ProviderCacheControlOptions{
673 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
674 },
675 bedrock.Name: &anthropic.ProviderCacheControlOptions{
676 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
677 },
678 }
679}
680
681func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
682 var attachmentParts []message.ContentPart
683 for _, attachment := range call.Attachments {
684 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
685 }
686 parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
687 parts = append(parts, attachmentParts...)
688 msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
689 Role: message.User,
690 Parts: parts,
691 })
692 if err != nil {
693 return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
694 }
695 return msg, nil
696}
697
698func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
699 var history []fantasy.Message
700 for _, m := range msgs {
701 if len(m.Parts) == 0 {
702 continue
703 }
704 // Assistant message without content or tool calls (cancelled before it
705 // returned anything).
706 if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
707 continue
708 }
709 history = append(history, m.ToAIMessage()...)
710 }
711
712 var files []fantasy.FilePart
713 for _, attachment := range attachments {
714 files = append(files, fantasy.FilePart{
715 Filename: attachment.FileName,
716 Data: attachment.Content,
717 MediaType: attachment.MimeType,
718 })
719 }
720
721 return history, files
722}
723
724func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
725 msgs, err := a.messages.List(ctx, session.ID)
726 if err != nil {
727 return nil, fmt.Errorf("failed to list messages: %w", err)
728 }
729
730 if session.SummaryMessageID != "" {
731 summaryMsgInex := -1
732 for i, msg := range msgs {
733 if msg.ID == session.SummaryMessageID {
734 summaryMsgInex = i
735 break
736 }
737 }
738 if summaryMsgInex != -1 {
739 msgs = msgs[summaryMsgInex:]
740 msgs[0].Role = message.User
741 }
742 }
743 return msgs, nil
744}
745
746func (a *sessionAgent) generateTitle(ctx context.Context, session *session.Session, prompt string) {
747 if prompt == "" {
748 return
749 }
750
751 var maxOutput int64 = 40
752 if a.smallModel.CatwalkCfg.CanReason {
753 maxOutput = a.smallModel.CatwalkCfg.DefaultMaxTokens
754 }
755
756 agent := fantasy.NewAgent(a.smallModel.Model,
757 fantasy.WithSystemPrompt(string(titlePrompt)+"\n /no_think"),
758 fantasy.WithMaxOutputTokens(maxOutput),
759 )
760
761 resp, err := agent.Stream(ctx, fantasy.AgentStreamCall{
762 Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", prompt),
763 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
764 prepared.Messages = options.Messages
765 if a.systemPromptPrefix != "" {
766 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
767 }
768 return callContext, prepared, nil
769 },
770 })
771 if err != nil {
772 slog.Error("error generating title", "err", err)
773 return
774 }
775
776 title := resp.Response.Content.Text()
777
778 title = strings.ReplaceAll(title, "\n", " ")
779
780 // Remove thinking tags if present.
781 if idx := strings.Index(title, "</think>"); idx > 0 {
782 title = title[idx+len("</think>"):]
783 }
784
785 title = strings.TrimSpace(title)
786 if title == "" {
787 slog.Warn("failed to generate title", "warn", "empty title")
788 return
789 }
790
791 session.Title = title
792
793 var openrouterCost *float64
794 for _, step := range resp.Steps {
795 stepCost := a.openrouterCost(step.ProviderMetadata)
796 if stepCost != nil {
797 newCost := *stepCost
798 if openrouterCost != nil {
799 newCost += *openrouterCost
800 }
801 openrouterCost = &newCost
802 }
803 }
804
805 a.updateSessionUsage(a.smallModel, session, resp.TotalUsage, openrouterCost)
806 _, saveErr := a.sessions.Save(ctx, *session)
807 if saveErr != nil {
808 slog.Error("failed to save session title & usage", "error", saveErr)
809 return
810 }
811}
812
813func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
814 openrouterMetadata, ok := metadata[openrouter.Name]
815 if !ok {
816 return nil
817 }
818
819 opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
820 if !ok {
821 return nil
822 }
823 return &opts.Usage.Cost
824}
825
826func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
827 modelConfig := model.CatwalkCfg
828 cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
829 modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
830 modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
831 modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
832
833 a.eventTokensUsed(session.ID, model, usage, cost)
834
835 if overrideCost != nil {
836 session.Cost += *overrideCost
837 } else {
838 session.Cost += cost
839 }
840
841 session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
842 session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
843}
844
845func (a *sessionAgent) Cancel(sessionID string) {
846 // Cancel regular requests.
847 if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
848 slog.Info("Request cancellation initiated", "session_id", sessionID)
849 cancel()
850 }
851
852 // Also check for summarize requests.
853 if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
854 slog.Info("Summarize cancellation initiated", "session_id", sessionID)
855 cancel()
856 }
857
858 if a.QueuedPrompts(sessionID) > 0 {
859 slog.Info("Clearing queued prompts", "session_id", sessionID)
860 a.messageQueue.Del(sessionID)
861 }
862}
863
864func (a *sessionAgent) ClearQueue(sessionID string) {
865 if a.QueuedPrompts(sessionID) > 0 {
866 slog.Info("Clearing queued prompts", "session_id", sessionID)
867 a.messageQueue.Del(sessionID)
868 }
869}
870
871func (a *sessionAgent) CancelAll() {
872 if !a.IsBusy() {
873 return
874 }
875 for key := range a.activeRequests.Seq2() {
876 a.Cancel(key) // key is sessionID
877 }
878
879 timeout := time.After(5 * time.Second)
880 for a.IsBusy() {
881 select {
882 case <-timeout:
883 return
884 default:
885 time.Sleep(200 * time.Millisecond)
886 }
887 }
888}
889
890func (a *sessionAgent) IsBusy() bool {
891 var busy bool
892 for cancelFunc := range a.activeRequests.Seq() {
893 if cancelFunc != nil {
894 busy = true
895 break
896 }
897 }
898 return busy
899}
900
901func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
902 _, busy := a.activeRequests.Get(sessionID)
903 return busy
904}
905
906func (a *sessionAgent) QueuedPrompts(sessionID string) int {
907 l, ok := a.messageQueue.Get(sessionID)
908 if !ok {
909 return 0
910 }
911 return len(l)
912}
913
914func (a *sessionAgent) SetModels(large Model, small Model) {
915 a.largeModel = large
916 a.smallModel = small
917}
918
919func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
920 a.tools = tools
921}
922
923func (a *sessionAgent) Model() Model {
924 return a.largeModel
925}
926
927// executePromptSubmitHook executes the user-prompt-submit hook and applies modifications to the call.
928// Only runs for main agent (not sub-agents).
929func (a *sessionAgent) executePromptSubmitHook(ctx context.Context, msg *message.Message, isFirstMessage bool) error {
930 // Skip if sub-agent or no hooks manager.
931 if a.isSubAgent || a.hooksManager == nil {
932 return nil
933 }
934
935 // Convert attachments to file paths.
936 attachmentPaths := make([]string, len(msg.BinaryContent()))
937 for i, att := range msg.BinaryContent() {
938 attachmentPaths[i] = att.Path
939 }
940
941 hookResult, err := a.hooksManager.ExecuteUserPromptSubmit(ctx, msg.SessionID, a.workingDir, hooks.UserPromptSubmitData{
942 Prompt: msg.Content().Text,
943 Attachments: attachmentPaths,
944 Model: a.largeModel.CatwalkCfg.ID,
945 Provider: a.largeModel.Model.Provider(),
946 IsFirstMessage: isFirstMessage,
947 })
948 if err != nil {
949 return fmt.Errorf("hook execution failed: %w", err)
950 }
951
952 // Apply hook modifications to the prompt.
953 if hookResult.ModifiedPrompt != nil {
954 for i, part := range msg.Parts {
955 if _, ok := part.(message.TextContent); ok {
956 msg.Parts[i] = message.TextContent{Text: *hookResult.ModifiedPrompt}
957 }
958 }
959 }
960 msg.AddHookResult(hookResult)
961 err = a.messages.Update(ctx, *msg)
962 if err != nil {
963 return err
964 }
965 // If hook returned Continue: false, stop execution.
966 if !hookResult.Continue {
967 return ErrHookExecutionStop
968 }
969 return nil
970}