1// Package agent is the core orchestration layer for Crush AI agents.
2//
3// It provides session-based AI agent functionality for managing
4// conversations, tool execution, and message handling. It coordinates
5// interactions between language models, messages, sessions, and tools while
6// handling features like automatic summarization, queuing, and token
7// management.
8package agent
9
10import (
11 "cmp"
12 "context"
13 _ "embed"
14 "encoding/json"
15 "errors"
16 "fmt"
17 "log/slog"
18 "os"
19 "strconv"
20 "strings"
21 "sync"
22 "time"
23
24 "charm.land/fantasy"
25 "charm.land/fantasy/providers/anthropic"
26 "charm.land/fantasy/providers/bedrock"
27 "charm.land/fantasy/providers/google"
28 "charm.land/fantasy/providers/openai"
29 "charm.land/fantasy/providers/openrouter"
30 "github.com/charmbracelet/catwalk/pkg/catwalk"
31 "github.com/charmbracelet/crush/internal/agent/tools"
32 "github.com/charmbracelet/crush/internal/config"
33 "github.com/charmbracelet/crush/internal/csync"
34 "github.com/charmbracelet/crush/internal/hooks"
35 "github.com/charmbracelet/crush/internal/message"
36 "github.com/charmbracelet/crush/internal/permission"
37 "github.com/charmbracelet/crush/internal/session"
38 "github.com/charmbracelet/crush/internal/stringext"
39)
40
// titlePrompt is the system prompt generateTitle sends to the small model
// when asking for a session title.
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt is the system prompt Summarize sends to the large model when
// condensing a session's history.
//
//go:embed templates/summary.md
var summaryPrompt []byte
46
// SessionAgentCall is a single prompt submission for SessionAgent.Run.
type SessionAgentCall struct {
	// SessionID identifies the target session; Run rejects an empty value
	// with ErrSessionMissing.
	SessionID string
	// Prompt is the user text; Run rejects an empty value with
	// ErrEmptyPrompt.
	Prompt string
	// ProviderOptions are passed to the model stream and, when Run triggers
	// auto-summarization, to Summarize.
	ProviderOptions fantasy.ProviderOptions
	// Attachments are stored on the user message and sent to the model as
	// file parts.
	Attachments []message.Attachment
	// MaxOutputTokens is always forwarded by pointer, so 0 is sent as an
	// explicit value. NOTE(review): confirm providers treat 0 as "no limit".
	MaxOutputTokens int64
	// The sampling parameters below are forwarded to the model stream as-is;
	// presumably nil means "use the provider default" — confirm in fantasy.
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
}
59
60type SessionAgent interface {
61 Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
62 SetModels(large Model, small Model)
63 SetTools(tools []fantasy.AgentTool)
64 Cancel(sessionID string)
65 CancelAll()
66 IsSessionBusy(sessionID string) bool
67 IsBusy() bool
68 QueuedPrompts(sessionID string) int
69 ClearQueue(sessionID string)
70 Summarize(context.Context, string, fantasy.ProviderOptions) error
71 Model() Model
72}
73
// Model bundles a language model with its catalog metadata and the user's
// model selection.
type Model struct {
	// Model is the fantasy model used to generate responses.
	Model fantasy.LanguageModel
	// CatwalkCfg is the model's catalog metadata; this file reads its context
	// window, per-token pricing, ID, reasoning support and default max tokens.
	CatwalkCfg catwalk.Model
	// ModelCfg is the selected model configuration; its Model and Provider
	// are recorded on assistant messages created by Run.
	ModelCfg config.SelectedModel
}
79
// sessionAgent is the SessionAgent implementation returned by
// NewSessionAgent.
type sessionAgent struct {
	// largeModel drives conversations and summarization; smallModel
	// generates session titles.
	largeModel Model
	smallModel Model
	// systemPromptPrefix, when non-empty, is prepended as an extra system
	// message to conversation, summary and title requests.
	systemPromptPrefix string
	// systemPrompt is the main conversation system prompt.
	systemPrompt string
	tools        []fantasy.AgentTool
	sessions     session.Service
	messages     message.Service
	// disableAutoSummarize turns off Run's context-window stop condition.
	disableAutoSummarize bool
	// isYolo is not read in the visible part of this file; presumably it
	// relaxes permission prompts elsewhere — TODO confirm.
	isYolo bool
	// isSubAgent disables hook execution for this agent.
	isSubAgent bool
	// hooksManager runs user hooks; nil disables hooks.
	hooksManager hooks.Manager
	// workingDir is passed to hook executions.
	workingDir string

	// messageQueue holds prompts submitted while a session was busy, keyed
	// by session ID.
	messageQueue *csync.Map[string, []SessionAgentCall]
	// activeRequests maps a session ID to the cancel func of its in-flight
	// Run or Summarize call; presence marks the session as busy.
	activeRequests *csync.Map[string, context.CancelFunc]
}
97
// SessionAgentOptions configures NewSessionAgent. Each field initializes the
// sessionAgent field of the same meaning.
type SessionAgentOptions struct {
	// LargeModel drives conversations and summarization.
	LargeModel Model
	// SmallModel generates session titles.
	SmallModel Model
	// SystemPromptPrefix, when non-empty, is sent as an extra leading system
	// message on every request.
	SystemPromptPrefix string
	// SystemPrompt is the main conversation system prompt.
	SystemPrompt string
	// DisableAutoSummarize turns off automatic summarization near the
	// context-window limit.
	DisableAutoSummarize bool
	IsYolo               bool
	// IsSubAgent disables hook execution.
	IsSubAgent bool
	// HooksManager runs user hooks; nil disables hooks.
	HooksManager hooks.Manager
	// WorkingDir is passed to hook executions.
	WorkingDir string
	Sessions   session.Service
	Messages   message.Service
	Tools      []fantasy.AgentTool
}
112
113func NewSessionAgent(
114 opts SessionAgentOptions,
115) SessionAgent {
116 return &sessionAgent{
117 largeModel: opts.LargeModel,
118 smallModel: opts.SmallModel,
119 systemPromptPrefix: opts.SystemPromptPrefix,
120 systemPrompt: opts.SystemPrompt,
121 sessions: opts.Sessions,
122 messages: opts.Messages,
123 disableAutoSummarize: opts.DisableAutoSummarize,
124 tools: opts.Tools,
125 isYolo: opts.IsYolo,
126 isSubAgent: opts.IsSubAgent,
127 hooksManager: opts.HooksManager,
128 workingDir: opts.WorkingDir,
129 messageQueue: csync.NewMap[string, []SessionAgentCall](),
130 activeRequests: csync.NewMap[string, context.CancelFunc](),
131 }
132}
133
134func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
135 if call.Prompt == "" {
136 return nil, ErrEmptyPrompt
137 }
138 if call.SessionID == "" {
139 return nil, ErrSessionMissing
140 }
141
142 // Queue the message if busy
143 if a.IsSessionBusy(call.SessionID) {
144 existing, ok := a.messageQueue.Get(call.SessionID)
145 if !ok {
146 existing = []SessionAgentCall{}
147 }
148 existing = append(existing, call)
149 a.messageQueue.Set(call.SessionID, existing)
150 return nil, nil
151 }
152
153 if len(a.tools) > 0 {
154 // Add Anthropic caching to the last tool.
155 a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
156 }
157
158 agent := fantasy.NewAgent(
159 a.largeModel.Model,
160 fantasy.WithSystemPrompt(a.systemPrompt),
161 fantasy.WithTools(a.tools...),
162 )
163
164 sessionLock := sync.Mutex{}
165 currentSession, err := a.sessions.Get(ctx, call.SessionID)
166 if err != nil {
167 return nil, fmt.Errorf("failed to get session: %w", err)
168 }
169
170 msgs, err := a.getSessionMessages(ctx, currentSession)
171 if err != nil {
172 return nil, fmt.Errorf("failed to get session messages: %w", err)
173 }
174
175 var wg sync.WaitGroup
176 // Generate title if first message.
177 if len(msgs) == 0 {
178 wg.Go(func() {
179 sessionLock.Lock()
180 a.generateTitle(ctx, ¤tSession, call.Prompt)
181 sessionLock.Unlock()
182 })
183 }
184
185 // Add the user message to the session.
186 msg, err := a.createUserMessage(ctx, call)
187 if err != nil {
188 return nil, err
189 }
190
191 // Add the session to the context.
192 ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)
193
194 genCtx, cancel := context.WithCancel(ctx)
195 a.activeRequests.Set(call.SessionID, cancel)
196
197 defer cancel()
198 defer a.activeRequests.Del(call.SessionID)
199
200 // Track completion reason for stop hook
201 var stopReason string
202 defer func() {
203 if stopReason != "" {
204 a.executeStopHook(ctx, call.SessionID, stopReason)
205 }
206 }()
207
208 // create the agent message asap to show loading
209 var currentAssistant *message.Message
210 assistantMessage, err := a.messages.Create(genCtx, call.SessionID, message.CreateMessageParams{
211 Role: message.Assistant,
212 Parts: []message.ContentPart{},
213 Model: a.largeModel.ModelCfg.Model,
214 Provider: a.largeModel.ModelCfg.Provider,
215 })
216 if err != nil {
217 return nil, err
218 }
219
220 currentAssistant = &assistantMessage
221
222 hookErr := a.executePromptSubmitHook(genCtx, &msg, len(msgs) == 0)
223 if hookErr != nil {
224 stopReason = "error"
225 // Delete the assistant message
226 // use the ctx since this could be a cancellation
227 deleteErr := a.messages.Delete(ctx, currentAssistant.ID)
228 return nil, cmp.Or(deleteErr, hookErr)
229 }
230
231 history, files := a.preparePrompt(msgs, call.Attachments...)
232
233 startTime := time.Now()
234 a.eventPromptSent(call.SessionID)
235
236 // Map to store post-tool-use hook results for OnToolResult callback
237 postToolHookResults := csync.NewMap[string, hooks.HookResult]()
238
239 var shouldSummarize bool
240 result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
241 Prompt: msg.ContentWithHookContext(),
242 Files: files,
243 Messages: history,
244 ProviderOptions: call.ProviderOptions,
245 MaxOutputTokens: &call.MaxOutputTokens,
246 TopP: call.TopP,
247 Temperature: call.Temperature,
248 PresencePenalty: call.PresencePenalty,
249 TopK: call.TopK,
250 FrequencyPenalty: call.FrequencyPenalty,
251 // Before each step create a new assistant message.
252 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
253 // only add new assistant message when its not the first step
254 if options.StepNumber != 0 {
255 var assistantMsg message.Message
256 assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
257 Role: message.Assistant,
258 Model: a.largeModel.ModelCfg.Model,
259 Provider: a.largeModel.ModelCfg.Provider,
260 })
261 currentAssistant = &assistantMsg
262 // create the message first so we show loading asap
263 if err != nil {
264 return callContext, prepared, err
265 }
266 }
267 prepared.Messages = options.Messages
268 // Reset all cached items.
269 for i := range prepared.Messages {
270 prepared.Messages[i].ProviderOptions = nil
271 }
272
273 queuedCalls, _ := a.messageQueue.Get(call.SessionID)
274 a.messageQueue.Del(call.SessionID)
275 for _, queued := range queuedCalls {
276 userMessage, createErr := a.createUserMessage(callContext, queued)
277 if createErr != nil {
278 return callContext, prepared, createErr
279 }
280
281 hookErr := a.executePromptSubmitHook(ctx, &msg, len(msgs) == 0)
282 if hookErr != nil {
283 return callContext, prepared, hookErr
284 }
285
286 prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
287 }
288
289 lastSystemRoleInx := 0
290 systemMessageUpdated := false
291 for i, msg := range prepared.Messages {
292 // Only add cache control to the last message.
293 if msg.Role == fantasy.MessageRoleSystem {
294 lastSystemRoleInx = i
295 } else if !systemMessageUpdated {
296 prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
297 systemMessageUpdated = true
298 }
299 // Than add cache control to the last 2 messages.
300 if i > len(prepared.Messages)-3 {
301 prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
302 }
303 }
304
305 if a.systemPromptPrefix != "" {
306 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
307 }
308
309 callContext = context.WithValue(callContext, tools.MessageIDContextKey, currentAssistant.ID)
310 return callContext, prepared, err
311 },
312 OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
313 currentAssistant.AppendReasoningContent(reasoning.Text)
314 return a.messages.Update(genCtx, *currentAssistant)
315 },
316 OnReasoningDelta: func(id string, text string) error {
317 currentAssistant.AppendReasoningContent(text)
318 return a.messages.Update(genCtx, *currentAssistant)
319 },
320 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
321 // handle anthropic signature
322 if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
323 if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
324 currentAssistant.AppendReasoningSignature(reasoning.Signature)
325 }
326 }
327 if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
328 if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
329 currentAssistant.AppendThoughtSignature(reasoning.Signature, reasoning.ToolID)
330 }
331 }
332 if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
333 if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
334 currentAssistant.SetReasoningResponsesData(reasoning)
335 }
336 }
337 currentAssistant.FinishThinking()
338 return a.messages.Update(genCtx, *currentAssistant)
339 },
340 OnTextDelta: func(id string, text string) error {
341 // Strip leading newline from initial text content. This is is
342 // particularly important in non-interactive mode where leading
343 // newlines are very visible.
344 if len(currentAssistant.Parts) == 0 {
345 text = strings.TrimPrefix(text, "\n")
346 }
347
348 currentAssistant.AppendContent(text)
349 return a.messages.Update(genCtx, *currentAssistant)
350 },
351 OnToolInputStart: func(id string, toolName string) error {
352 toolCall := message.ToolCall{
353 ID: id,
354 Name: toolName,
355 ProviderExecuted: false,
356 Finished: false,
357 }
358 currentAssistant.AddToolCall(toolCall)
359 return a.messages.Update(genCtx, *currentAssistant)
360 },
361 OnRetry: func(err *fantasy.ProviderError, delay time.Duration) {
362 // TODO: implement
363 },
364 OnToolCall: func(tc fantasy.ToolCallContent) error {
365 toolCall := message.ToolCall{
366 ID: tc.ToolCallID,
367 Name: tc.ToolName,
368 Input: tc.Input,
369 ProviderExecuted: false,
370 Finished: true,
371 }
372 currentAssistant.AddToolCall(toolCall)
373 return a.messages.Update(genCtx, *currentAssistant)
374 },
375 PreToolExecute: func(ctx context.Context, toolCall fantasy.ToolCall) (context.Context, *fantasy.ToolCall, error) {
376 return a.executePreToolUseHook(ctx, call.SessionID, toolCall, currentAssistant)
377 },
378 OnToolResult: func(result fantasy.ToolResultContent) error {
379 var resultContent string
380 isError := false
381 switch result.Result.GetType() {
382 case fantasy.ToolResultContentTypeText:
383 r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result)
384 if ok {
385 resultContent = r.Text
386 }
387 case fantasy.ToolResultContentTypeError:
388 r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result)
389 if ok {
390 isError = true
391 resultContent = r.Error.Error()
392 }
393 case fantasy.ToolResultContentTypeMedia:
394 // TODO: handle this message type
395 }
396 toolResult := message.ToolResult{
397 ToolCallID: result.ToolCallID,
398 Name: result.ToolName,
399 Content: resultContent,
400 IsError: isError,
401 Metadata: result.ClientMetadata,
402 }
403 // Attach hook result if available
404 if hookRes, ok := postToolHookResults.Get(result.ToolCallID); ok {
405 toolResult.HookResult = &hookRes
406 }
407 _, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
408 Role: message.Tool,
409 Parts: []message.ContentPart{
410 toolResult,
411 },
412 })
413 if createMsgErr != nil {
414 return createMsgErr
415 }
416 return nil
417 },
418 PostToolExecute: func(ctx context.Context, toolCall fantasy.ToolCall, response fantasy.ToolResponse, executionTimeMs int64) (*fantasy.ToolResponse, error) {
419 modifiedResponse, hookResult, err := a.executePostToolUseHook(ctx, call.SessionID, toolCall, response, executionTimeMs)
420 if hookResult != nil {
421 // Store for OnToolResult callback
422 postToolHookResults.Set(toolCall.ID, *hookResult)
423 }
424 return modifiedResponse, err
425 },
426 OnStepFinish: func(stepResult fantasy.StepResult) error {
427 finishReason := message.FinishReasonUnknown
428 switch stepResult.FinishReason {
429 case fantasy.FinishReasonLength:
430 finishReason = message.FinishReasonMaxTokens
431 case fantasy.FinishReasonStop:
432 finishReason = message.FinishReasonEndTurn
433 case fantasy.FinishReasonToolCalls:
434 finishReason = message.FinishReasonToolUse
435 }
436 currentAssistant.AddFinish(finishReason, "", "")
437 a.updateSessionUsage(a.largeModel, ¤tSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
438 sessionLock.Lock()
439 _, sessionErr := a.sessions.Save(genCtx, currentSession)
440 sessionLock.Unlock()
441 if sessionErr != nil {
442 return sessionErr
443 }
444 return a.messages.Update(genCtx, *currentAssistant)
445 },
446 StopWhen: []fantasy.StopCondition{
447 func(_ []fantasy.StepResult) bool {
448 cw := int64(a.largeModel.CatwalkCfg.ContextWindow)
449 tokens := currentSession.CompletionTokens + currentSession.PromptTokens
450 remaining := cw - tokens
451 var threshold int64
452 if cw > 200_000 {
453 threshold = 20_000
454 } else {
455 threshold = int64(float64(cw) * 0.2)
456 }
457 if (remaining <= threshold) && !a.disableAutoSummarize {
458 shouldSummarize = true
459 return true
460 }
461 return false
462 },
463 },
464 })
465
466 a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))
467
468 if err != nil {
469 isCancelErr := errors.Is(err, context.Canceled)
470 isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
471 isHookDenied := errors.Is(err, ErrHookDenied)
472
473 // Set stop reason for defer
474 if isCancelErr {
475 stopReason = "cancelled"
476 } else if isPermissionErr || isHookDenied {
477 stopReason = "permission_denied"
478 } else {
479 stopReason = "error"
480 }
481
482 if currentAssistant == nil {
483 return result, err
484 }
485 // Ensure we finish thinking on error to close the reasoning state.
486 currentAssistant.FinishThinking()
487 toolCalls := currentAssistant.ToolCalls()
488 // INFO: we use the parent context here because the genCtx has been cancelled.
489 msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
490 if createErr != nil {
491 return nil, createErr
492 }
493 for _, tc := range toolCalls {
494 if !tc.Finished {
495 tc.Finished = true
496 tc.Input = "{}"
497 currentAssistant.AddToolCall(tc)
498 updateErr := a.messages.Update(ctx, *currentAssistant)
499 if updateErr != nil {
500 return nil, updateErr
501 }
502 }
503
504 found := false
505 for _, msg := range msgs {
506 if msg.Role == message.Tool {
507 for _, tr := range msg.ToolResults() {
508 if tr.ToolCallID == tc.ID {
509 found = true
510 break
511 }
512 }
513 }
514 if found {
515 break
516 }
517 }
518 if found {
519 continue
520 }
521 content := "There was an error while executing the tool"
522 if isCancelErr {
523 content = "Tool execution canceled by user"
524 } else if isPermissionErr {
525 content = "User denied permission"
526 } else if isHookDenied {
527 content = "Hook denied execution"
528 }
529 toolResult := message.ToolResult{
530 ToolCallID: tc.ID,
531 Name: tc.Name,
532 Content: content,
533 IsError: true,
534 }
535 _, createErr = a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
536 Role: message.Tool,
537 Parts: []message.ContentPart{
538 toolResult,
539 },
540 })
541 if createErr != nil {
542 return nil, createErr
543 }
544 }
545 var fantasyErr *fantasy.Error
546 var providerErr *fantasy.ProviderError
547 const defaultTitle = "Provider Error"
548 if isCancelErr {
549 currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
550 } else if isPermissionErr {
551 currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "")
552 } else if isHookDenied {
553 currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "Hook denied execution", "")
554 } else if errors.As(err, &providerErr) {
555 currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message)
556 } else if errors.As(err, &fantasyErr) {
557 currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(fantasyErr.Title), defaultTitle), fantasyErr.Message)
558 } else {
559 currentAssistant.AddFinish(message.FinishReasonError, defaultTitle, err.Error())
560 }
561 // Note: we use the parent context here because the genCtx has been
562 // cancelled.
563 updateErr := a.messages.Update(ctx, *currentAssistant)
564 if updateErr != nil {
565 return nil, updateErr
566 }
567 return nil, err
568 }
569 wg.Wait()
570
571 // Set completion reason for stop hook
572 stopReason = "completed"
573
574 if shouldSummarize {
575 a.activeRequests.Del(call.SessionID)
576 if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
577 return nil, summarizeErr
578 }
579 // If the agent wasn't done...
580 if len(currentAssistant.ToolCalls()) > 0 {
581 existing, ok := a.messageQueue.Get(call.SessionID)
582 if !ok {
583 existing = []SessionAgentCall{}
584 }
585 call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
586 existing = append(existing, call)
587 a.messageQueue.Set(call.SessionID, existing)
588 }
589 }
590
591 // Release active request before processing queued messages.
592 a.activeRequests.Del(call.SessionID)
593 cancel()
594
595 queuedMessages, ok := a.messageQueue.Get(call.SessionID)
596 if !ok || len(queuedMessages) == 0 {
597 return result, err
598 }
599 // There are queued messages restart the loop.
600 firstQueuedMessage := queuedMessages[0]
601 a.messageQueue.Set(call.SessionID, queuedMessages[1:])
602 return a.Run(ctx, firstQueuedMessage)
603}
604
605func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
606 if a.IsSessionBusy(sessionID) {
607 return ErrSessionBusy
608 }
609
610 currentSession, err := a.sessions.Get(ctx, sessionID)
611 if err != nil {
612 return fmt.Errorf("failed to get session: %w", err)
613 }
614 msgs, err := a.getSessionMessages(ctx, currentSession)
615 if err != nil {
616 return err
617 }
618 if len(msgs) == 0 {
619 // Nothing to summarize.
620 return nil
621 }
622
623 aiMsgs, _ := a.preparePrompt(msgs)
624
625 genCtx, cancel := context.WithCancel(ctx)
626 a.activeRequests.Set(sessionID, cancel)
627 defer a.activeRequests.Del(sessionID)
628 defer cancel()
629
630 agent := fantasy.NewAgent(a.largeModel.Model,
631 fantasy.WithSystemPrompt(string(summaryPrompt)),
632 )
633 summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
634 Role: message.Assistant,
635 Model: a.largeModel.Model.Model(),
636 Provider: a.largeModel.Model.Provider(),
637 IsSummaryMessage: true,
638 })
639 if err != nil {
640 return err
641 }
642
643 resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
644 Prompt: "Provide a detailed summary of our conversation above.",
645 Messages: aiMsgs,
646 ProviderOptions: opts,
647 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
648 prepared.Messages = options.Messages
649 if a.systemPromptPrefix != "" {
650 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
651 }
652 return callContext, prepared, nil
653 },
654 OnReasoningDelta: func(id string, text string) error {
655 summaryMessage.AppendReasoningContent(text)
656 return a.messages.Update(genCtx, summaryMessage)
657 },
658 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
659 // Handle anthropic signature.
660 if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
661 if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
662 summaryMessage.AppendReasoningSignature(signature.Signature)
663 }
664 }
665 summaryMessage.FinishThinking()
666 return a.messages.Update(genCtx, summaryMessage)
667 },
668 OnTextDelta: func(id, text string) error {
669 summaryMessage.AppendContent(text)
670 return a.messages.Update(genCtx, summaryMessage)
671 },
672 })
673 if err != nil {
674 isCancelErr := errors.Is(err, context.Canceled)
675 if isCancelErr {
676 // User cancelled summarize we need to remove the summary message.
677 deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
678 return deleteErr
679 }
680 return err
681 }
682
683 summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
684 err = a.messages.Update(genCtx, summaryMessage)
685 if err != nil {
686 return err
687 }
688
689 var openrouterCost *float64
690 for _, step := range resp.Steps {
691 stepCost := a.openrouterCost(step.ProviderMetadata)
692 if stepCost != nil {
693 newCost := *stepCost
694 if openrouterCost != nil {
695 newCost += *openrouterCost
696 }
697 openrouterCost = &newCost
698 }
699 }
700
701 a.updateSessionUsage(a.largeModel, ¤tSession, resp.TotalUsage, openrouterCost)
702
703 // Just in case, get just the last usage info.
704 usage := resp.Response.Usage
705 currentSession.SummaryMessageID = summaryMessage.ID
706 currentSession.CompletionTokens = usage.OutputTokens
707 currentSession.PromptTokens = 0
708 _, err = a.sessions.Save(genCtx, currentSession)
709 return err
710}
711
712func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
713 if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
714 return fantasy.ProviderOptions{}
715 }
716 return fantasy.ProviderOptions{
717 anthropic.Name: &anthropic.ProviderCacheControlOptions{
718 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
719 },
720 bedrock.Name: &anthropic.ProviderCacheControlOptions{
721 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
722 },
723 }
724}
725
726func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
727 var attachmentParts []message.ContentPart
728 for _, attachment := range call.Attachments {
729 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
730 }
731 parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
732 parts = append(parts, attachmentParts...)
733 msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
734 Role: message.User,
735 Parts: parts,
736 })
737 if err != nil {
738 return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
739 }
740 return msg, nil
741}
742
743func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
744 var history []fantasy.Message
745 for _, m := range msgs {
746 if len(m.Parts) == 0 {
747 continue
748 }
749 // Assistant message without content or tool calls (cancelled before it
750 // returned anything).
751 if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
752 continue
753 }
754 history = append(history, m.ToAIMessage()...)
755 }
756
757 var files []fantasy.FilePart
758 for _, attachment := range attachments {
759 files = append(files, fantasy.FilePart{
760 Filename: attachment.FileName,
761 Data: attachment.Content,
762 MediaType: attachment.MimeType,
763 })
764 }
765
766 return history, files
767}
768
769func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
770 msgs, err := a.messages.List(ctx, session.ID)
771 if err != nil {
772 return nil, fmt.Errorf("failed to list messages: %w", err)
773 }
774
775 if session.SummaryMessageID != "" {
776 summaryMsgInex := -1
777 for i, msg := range msgs {
778 if msg.ID == session.SummaryMessageID {
779 summaryMsgInex = i
780 break
781 }
782 }
783 if summaryMsgInex != -1 {
784 msgs = msgs[summaryMsgInex:]
785 msgs[0].Role = message.User
786 }
787 }
788 return msgs, nil
789}
790
791func (a *sessionAgent) generateTitle(ctx context.Context, session *session.Session, prompt string) {
792 if prompt == "" {
793 return
794 }
795
796 var maxOutput int64 = 40
797 if a.smallModel.CatwalkCfg.CanReason {
798 maxOutput = a.smallModel.CatwalkCfg.DefaultMaxTokens
799 }
800
801 agent := fantasy.NewAgent(a.smallModel.Model,
802 fantasy.WithSystemPrompt(string(titlePrompt)+"\n /no_think"),
803 fantasy.WithMaxOutputTokens(maxOutput),
804 )
805
806 resp, err := agent.Stream(ctx, fantasy.AgentStreamCall{
807 Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", prompt),
808 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
809 prepared.Messages = options.Messages
810 if a.systemPromptPrefix != "" {
811 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
812 }
813 return callContext, prepared, nil
814 },
815 })
816 if err != nil {
817 slog.Error("error generating title", "err", err)
818 return
819 }
820
821 title := resp.Response.Content.Text()
822
823 title = strings.ReplaceAll(title, "\n", " ")
824
825 // Remove thinking tags if present.
826 if idx := strings.Index(title, "</think>"); idx > 0 {
827 title = title[idx+len("</think>"):]
828 }
829
830 title = strings.TrimSpace(title)
831 if title == "" {
832 slog.Warn("failed to generate title", "warn", "empty title")
833 return
834 }
835
836 session.Title = title
837
838 var openrouterCost *float64
839 for _, step := range resp.Steps {
840 stepCost := a.openrouterCost(step.ProviderMetadata)
841 if stepCost != nil {
842 newCost := *stepCost
843 if openrouterCost != nil {
844 newCost += *openrouterCost
845 }
846 openrouterCost = &newCost
847 }
848 }
849
850 a.updateSessionUsage(a.smallModel, session, resp.TotalUsage, openrouterCost)
851 _, saveErr := a.sessions.Save(ctx, *session)
852 if saveErr != nil {
853 slog.Error("failed to save session title & usage", "error", saveErr)
854 return
855 }
856}
857
858func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
859 openrouterMetadata, ok := metadata[openrouter.Name]
860 if !ok {
861 return nil
862 }
863
864 opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
865 if !ok {
866 return nil
867 }
868 return &opts.Usage.Cost
869}
870
871func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
872 modelConfig := model.CatwalkCfg
873 cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
874 modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
875 modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
876 modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
877
878 a.eventTokensUsed(session.ID, model, usage, cost)
879
880 if overrideCost != nil {
881 session.Cost += *overrideCost
882 } else {
883 session.Cost += cost
884 }
885
886 session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
887 session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
888}
889
890func (a *sessionAgent) Cancel(sessionID string) {
891 // Cancel regular requests.
892 if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
893 slog.Info("Request cancellation initiated", "session_id", sessionID)
894 cancel()
895 }
896
897 // Also check for summarize requests.
898 if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
899 slog.Info("Summarize cancellation initiated", "session_id", sessionID)
900 cancel()
901 }
902
903 if a.QueuedPrompts(sessionID) > 0 {
904 slog.Info("Clearing queued prompts", "session_id", sessionID)
905 a.messageQueue.Del(sessionID)
906 }
907}
908
909func (a *sessionAgent) ClearQueue(sessionID string) {
910 if a.QueuedPrompts(sessionID) > 0 {
911 slog.Info("Clearing queued prompts", "session_id", sessionID)
912 a.messageQueue.Del(sessionID)
913 }
914}
915
916func (a *sessionAgent) CancelAll() {
917 if !a.IsBusy() {
918 return
919 }
920 for key := range a.activeRequests.Seq2() {
921 a.Cancel(key) // key is sessionID
922 }
923
924 timeout := time.After(5 * time.Second)
925 for a.IsBusy() {
926 select {
927 case <-timeout:
928 return
929 default:
930 time.Sleep(200 * time.Millisecond)
931 }
932 }
933}
934
935func (a *sessionAgent) IsBusy() bool {
936 var busy bool
937 for cancelFunc := range a.activeRequests.Seq() {
938 if cancelFunc != nil {
939 busy = true
940 break
941 }
942 }
943 return busy
944}
945
946func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
947 _, busy := a.activeRequests.Get(sessionID)
948 return busy
949}
950
951func (a *sessionAgent) QueuedPrompts(sessionID string) int {
952 l, ok := a.messageQueue.Get(sessionID)
953 if !ok {
954 return 0
955 }
956 return len(l)
957}
958
959func (a *sessionAgent) SetModels(large Model, small Model) {
960 a.largeModel = large
961 a.smallModel = small
962}
963
964func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
965 a.tools = tools
966}
967
968func (a *sessionAgent) Model() Model {
969 return a.largeModel
970}
971
972// executePromptSubmitHook executes the user-prompt-submit hook and applies modifications to the call.
973// Only runs for main agent (not sub-agents).
974func (a *sessionAgent) executePromptSubmitHook(ctx context.Context, msg *message.Message, isFirstMessage bool) error {
975 // Skip if sub-agent or no hooks manager.
976 if a.isSubAgent || a.hooksManager == nil {
977 return nil
978 }
979
980 // Convert attachments to file paths.
981 attachmentPaths := make([]string, len(msg.BinaryContent()))
982 for i, att := range msg.BinaryContent() {
983 attachmentPaths[i] = att.Path
984 }
985
986 hookResult, err := a.hooksManager.ExecuteUserPromptSubmit(ctx, msg.SessionID, a.workingDir, hooks.UserPromptSubmitData{
987 Prompt: msg.Content().Text,
988 Attachments: attachmentPaths,
989 Model: a.largeModel.CatwalkCfg.ID,
990 Provider: a.largeModel.Model.Provider(),
991 IsFirstMessage: isFirstMessage,
992 })
993 if err != nil {
994 return fmt.Errorf("hook execution failed: %w", err)
995 }
996
997 // Apply hook modifications to the prompt.
998 if hookResult.ModifiedPrompt != nil {
999 for i, part := range msg.Parts {
1000 if _, ok := part.(message.TextContent); ok {
1001 msg.Parts[i] = message.TextContent{Text: *hookResult.ModifiedPrompt}
1002 }
1003 }
1004 }
1005 msg.AddHookResult(hookResult)
1006 err = a.messages.Update(ctx, *msg)
1007 if err != nil {
1008 return err
1009 }
1010 // If hook returned Continue: false, stop execution.
1011 if !hookResult.Continue {
1012 return ErrHookExecutionStop
1013 }
1014 return nil
1015}
1016
1017// executePreToolUseHook executes the pre-tool-use hook and applies modifications.
1018// Only runs for main agent (not sub-agents).
1019func (a *sessionAgent) executePreToolUseHook(ctx context.Context, sessionID string, toolCall fantasy.ToolCall, currentAssistant *message.Message) (context.Context, *fantasy.ToolCall, error) {
1020 // Skip if sub-agent or no hooks manager.
1021 if a.isSubAgent || a.hooksManager == nil {
1022 return ctx, nil, nil
1023 }
1024
1025 // Parse tool input to map
1026 var toolInput map[string]any
1027 if err := json.Unmarshal([]byte(toolCall.Input), &toolInput); err != nil {
1028 // If we can't parse the input, skip the hook
1029 return ctx, nil, nil
1030 }
1031
1032 hookResult, err := a.hooksManager.ExecutePreToolUse(ctx, sessionID, a.workingDir, hooks.PreToolUseData{
1033 ToolName: toolCall.Name,
1034 ToolCallID: toolCall.ID,
1035 ToolInput: toolInput,
1036 })
1037 if err != nil {
1038 return ctx, nil, fmt.Errorf("pre-tool-use hook execution failed: %w", err)
1039 }
1040
1041 // Store hook result in the current assistant's tool call
1042 for _, tc := range currentAssistant.ToolCalls() {
1043 if tc.ID == toolCall.ID {
1044 tc.HookResult = &hookResult
1045 currentAssistant.AddToolCall(tc)
1046 if updateErr := a.messages.Update(ctx, *currentAssistant); updateErr != nil {
1047 slog.Error("failed to update assistant message with pre-hook result", "error", updateErr)
1048 }
1049 break
1050 }
1051 }
1052
1053 // If hook returned Continue: false, deny execution.
1054 if !hookResult.Continue {
1055 return ctx, nil, ErrHookDenied
1056 }
1057
1058 // Set permission in context for tools to use
1059 if hookResult.Permission != "" {
1060 ctx = tools.SetHookPermissionInContext(ctx, hookResult.Permission)
1061 }
1062
1063 // Apply modified input if present.
1064 if len(hookResult.ModifiedInput) > 0 {
1065 // Merge modified input with original
1066 for k, v := range hookResult.ModifiedInput {
1067 toolInput[k] = v
1068 }
1069
1070 modifiedInputJSON, err := json.Marshal(toolInput)
1071 if err != nil {
1072 return ctx, nil, fmt.Errorf("failed to marshal modified input: %w", err)
1073 }
1074
1075 modifiedCall := toolCall
1076 modifiedCall.Input = string(modifiedInputJSON)
1077 return ctx, &modifiedCall, nil
1078 }
1079
1080 return ctx, nil, nil
1081}
1082
1083// executePostToolUseHook executes the post-tool-use hook and applies modifications.
1084// Only runs for main agent (not sub-agents).
1085func (a *sessionAgent) executePostToolUseHook(ctx context.Context, sessionID string, toolCall fantasy.ToolCall, response fantasy.ToolResponse, executionTimeMs int64) (*fantasy.ToolResponse, *hooks.HookResult, error) {
1086 // Skip if sub-agent or no hooks manager.
1087 if a.isSubAgent || a.hooksManager == nil {
1088 return nil, nil, nil
1089 }
1090
1091 // Parse tool input to map
1092 var toolInput map[string]any
1093 if err := json.Unmarshal([]byte(toolCall.Input), &toolInput); err != nil {
1094 return nil, nil, nil
1095 }
1096
1097 // Parse tool output to map
1098 toolOutput := map[string]any{
1099 "success": !response.IsError,
1100 "content": response.Content,
1101 }
1102 if response.Metadata != "" {
1103 toolOutput["metadata"] = response.Metadata
1104 }
1105
1106 hookResult, err := a.hooksManager.ExecutePostToolUse(ctx, sessionID, a.workingDir, hooks.PostToolUseData{
1107 ToolName: toolCall.Name,
1108 ToolCallID: toolCall.ID,
1109 ToolInput: toolInput,
1110 ToolOutput: toolOutput,
1111 ExecutionTimeMs: executionTimeMs,
1112 })
1113 if err != nil {
1114 return nil, nil, fmt.Errorf("post-tool-use hook execution failed: %w", err)
1115 }
1116
1117 // If hook returned Continue: false, return error to stop execution.
1118 if !hookResult.Continue {
1119 return nil, &hookResult, ErrHookDenied
1120 }
1121
1122 // Apply modified output if present.
1123 if len(hookResult.ModifiedOutput) > 0 {
1124 modifiedResponse := response
1125
1126 // Apply modifications
1127 if content, ok := hookResult.ModifiedOutput["content"].(string); ok {
1128 modifiedResponse.Content = content
1129 }
1130 if success, ok := hookResult.ModifiedOutput["success"].(bool); ok {
1131 modifiedResponse.IsError = !success
1132 }
1133 if metadata, ok := hookResult.ModifiedOutput["metadata"].(string); ok {
1134 modifiedResponse.Metadata = metadata
1135 }
1136
1137 return &modifiedResponse, &hookResult, nil
1138 }
1139
1140 return nil, &hookResult, nil
1141}
1142
1143// executeStopHook executes the stop hook when agent loop ends.
1144// Only runs for main agent (not sub-agents). Errors are logged but don't fail.
1145func (a *sessionAgent) executeStopHook(ctx context.Context, sessionID, reason string) {
1146 // Skip if sub-agent or no hooks manager.
1147 if a.isSubAgent || a.hooksManager == nil {
1148 return
1149 }
1150
1151 // Use a fresh context with timeout to ensure hook runs even if parent is cancelled
1152 hookCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
1153 defer cancel()
1154
1155 _, err := a.hooksManager.ExecuteStop(hookCtx, sessionID, a.workingDir, hooks.StopData{
1156 Reason: reason,
1157 })
1158 if err != nil {
1159 slog.Error("stop hook execution failed", "session_id", sessionID, "reason", reason, "error", err)
1160 }
1161}