1// Package agent is the core orchestration layer for Crush AI agents.
2//
3// It provides session-based AI agent functionality for managing
4// conversations, tool execution, and message handling. It coordinates
5// interactions between language models, messages, sessions, and tools while
6// handling features like automatic summarization, queuing, and token
7// management.
8package agent
9
10import (
11 "cmp"
12 "context"
13 _ "embed"
14 "errors"
15 "fmt"
16 "log/slog"
17 "os"
18 "strconv"
19 "strings"
20 "sync"
21 "time"
22
23 "charm.land/fantasy"
24 "charm.land/fantasy/providers/anthropic"
25 "charm.land/fantasy/providers/bedrock"
26 "charm.land/fantasy/providers/google"
27 "charm.land/fantasy/providers/openai"
28 "charm.land/fantasy/providers/openrouter"
29 "github.com/charmbracelet/catwalk/pkg/catwalk"
30 "github.com/charmbracelet/crush/internal/agent/tools"
31 "github.com/charmbracelet/crush/internal/config"
32 "github.com/charmbracelet/crush/internal/csync"
33 "github.com/charmbracelet/crush/internal/hooks"
34 "github.com/charmbracelet/crush/internal/message"
35 "github.com/charmbracelet/crush/internal/permission"
36 "github.com/charmbracelet/crush/internal/session"
37 "github.com/charmbracelet/crush/internal/stringext"
38)
39
40//go:embed templates/title.md
41var titlePrompt []byte
42
43//go:embed templates/summary.md
44var summaryPrompt []byte
45
46type SessionAgentCall struct {
47 SessionID string
48 Prompt string
49 ProviderOptions fantasy.ProviderOptions
50 Attachments []message.Attachment
51 MaxOutputTokens int64
52 Temperature *float64
53 TopP *float64
54 TopK *int64
55 FrequencyPenalty *float64
56 PresencePenalty *float64
57}
58
59type SessionAgent interface {
60 Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
61 SetModels(large Model, small Model)
62 SetTools(tools []fantasy.AgentTool)
63 Cancel(sessionID string)
64 CancelAll()
65 IsSessionBusy(sessionID string) bool
66 IsBusy() bool
67 QueuedPrompts(sessionID string) int
68 ClearQueue(sessionID string)
69 Summarize(context.Context, string, fantasy.ProviderOptions) error
70 Model() Model
71}
72
73type Model struct {
74 Model fantasy.LanguageModel
75 CatwalkCfg catwalk.Model
76 ModelCfg config.SelectedModel
77}
78
79type sessionAgent struct {
80 largeModel Model
81 smallModel Model
82 systemPromptPrefix string
83 systemPrompt string
84 tools []fantasy.AgentTool
85 sessions session.Service
86 messages message.Service
87 hooks hooks.Service
88 disableAutoSummarize bool
89 isYolo bool
90
91 messageQueue *csync.Map[string, []SessionAgentCall]
92 activeRequests *csync.Map[string, context.CancelFunc]
93}
94
95type SessionAgentOptions struct {
96 LargeModel Model
97 SmallModel Model
98 SystemPromptPrefix string
99 SystemPrompt string
100 DisableAutoSummarize bool
101 IsYolo bool
102 Sessions session.Service
103 Messages message.Service
104 Hooks hooks.Service
105 Tools []fantasy.AgentTool
106}
107
108func NewSessionAgent(
109 opts SessionAgentOptions,
110) SessionAgent {
111 return &sessionAgent{
112 largeModel: opts.LargeModel,
113 smallModel: opts.SmallModel,
114 systemPromptPrefix: opts.SystemPromptPrefix,
115 systemPrompt: opts.SystemPrompt,
116 sessions: opts.Sessions,
117 messages: opts.Messages,
118 disableAutoSummarize: opts.DisableAutoSummarize,
119 hooks: opts.Hooks,
120 tools: opts.Tools,
121 isYolo: opts.IsYolo,
122 messageQueue: csync.NewMap[string, []SessionAgentCall](),
123 activeRequests: csync.NewMap[string, context.CancelFunc](),
124 }
125}
126
127func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
128 if call.Prompt == "" {
129 return nil, ErrEmptyPrompt
130 }
131 if call.SessionID == "" {
132 return nil, ErrSessionMissing
133 }
134
135 // Queue the message if busy
136 if a.IsSessionBusy(call.SessionID) {
137 existing, ok := a.messageQueue.Get(call.SessionID)
138 if !ok {
139 existing = []SessionAgentCall{}
140 }
141 existing = append(existing, call)
142 a.messageQueue.Set(call.SessionID, existing)
143 return nil, nil
144 }
145
146 if len(a.tools) > 0 {
147 // Add Anthropic caching to the last tool.
148 a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
149 }
150
151 agent := fantasy.NewAgent(
152 a.largeModel.Model,
153 fantasy.WithSystemPrompt(a.systemPrompt),
154 fantasy.WithTools(a.tools...),
155 )
156
157 sessionLock := sync.Mutex{}
158 currentSession, err := a.sessions.Get(ctx, call.SessionID)
159 if err != nil {
160 return nil, fmt.Errorf("failed to get session: %w", err)
161 }
162
163 msgs, err := a.getSessionMessages(ctx, currentSession)
164 if err != nil {
165 return nil, fmt.Errorf("failed to get session messages: %w", err)
166 }
167
168 var wg sync.WaitGroup
169 // Generate title if first message.
170 if len(msgs) == 0 {
171 wg.Go(func() {
172 sessionLock.Lock()
173 a.generateTitle(ctx, ¤tSession, call.Prompt)
174 sessionLock.Unlock()
175 })
176 }
177
178 // Add the user message to the session.
179 userMsg, err := a.createUserMessage(ctx, call)
180 if err != nil {
181 return nil, err
182 }
183
184 // Add the session to the context.
185 ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)
186
187 genCtx, cancel := context.WithCancel(ctx)
188 a.activeRequests.Set(call.SessionID, cancel)
189 defer cancel()
190 defer a.activeRequests.Del(call.SessionID)
191
192 // create the agent message asap to show loading
193 var currentAssistant *message.Message
194 assistantMessage, err := a.messages.Create(genCtx, call.SessionID, message.CreateMessageParams{
195 Role: message.Assistant,
196 Parts: []message.ContentPart{},
197 Model: a.largeModel.ModelCfg.Model,
198 Provider: a.largeModel.ModelCfg.Provider,
199 })
200 if err != nil {
201 return nil, err
202 }
203 currentAssistant = &assistantMessage
204
205 // run hooks after assistant message is created
206 // this way we show loading for the user
207 hooks, err := a.executeUserPromptSubmitHook(genCtx, call.SessionID, call.Prompt)
208 if err != nil {
209 return nil, err
210 }
211 userMsg.AddHookOutputs(hooks...)
212 if updateErr := a.messages.Update(genCtx, userMsg); updateErr != nil {
213 return nil, updateErr
214 }
215
216 for _, hook := range hooks {
217 // execution stopped
218 if hook.Stop {
219 deleteErr := a.messages.Delete(genCtx, assistantMessage.ID)
220 if deleteErr != nil {
221 return nil, deleteErr
222 }
223 return nil, nil
224 }
225 }
226
227 history, files := a.preparePrompt(msgs, call.Attachments...)
228
229 startTime := time.Now()
230 a.eventPromptSent(call.SessionID)
231
232 var shouldSummarize bool
233 result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
234 Prompt: userMsg.ContentWithHooksContext(),
235 Files: files,
236 Messages: history,
237 ProviderOptions: call.ProviderOptions,
238 MaxOutputTokens: &call.MaxOutputTokens,
239 TopP: call.TopP,
240 Temperature: call.Temperature,
241 PresencePenalty: call.PresencePenalty,
242 TopK: call.TopK,
243 FrequencyPenalty: call.FrequencyPenalty,
244 // Before each step create a new assistant message.
245 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
246 // only add new assistant message when its not the first step
247 if options.StepNumber != 0 {
248 var assistantMsg message.Message
249 assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
250 Role: message.Assistant,
251 Parts: []message.ContentPart{},
252 Model: a.largeModel.ModelCfg.Model,
253 Provider: a.largeModel.ModelCfg.Provider,
254 })
255 currentAssistant = &assistantMsg
256 // create the message first so we show loading asap
257 if err != nil {
258 return callContext, prepared, err
259 }
260 }
261
262 prepared.Messages = options.Messages
263 // Reset all cached items.
264 for i := range prepared.Messages {
265 prepared.Messages[i].ProviderOptions = nil
266 }
267
268 queuedCalls, _ := a.messageQueue.Get(call.SessionID)
269 a.messageQueue.Del(call.SessionID)
270 for _, queued := range queuedCalls {
271 userMessage, createErr := a.createUserMessage(callContext, queued)
272 if createErr != nil {
273 return callContext, prepared, createErr
274 }
275
276 // run hooks after assistant message is created
277 // this way we show loading for the user
278 hooks, hookErr := a.executeUserPromptSubmitHook(genCtx, call.SessionID, call.Prompt)
279 if hookErr != nil {
280 return callContext, prepared, hookErr
281 }
282 userMsg.AddHookOutputs(hooks...)
283 for _, hook := range hooks {
284 // execution stopped
285 if hook.Stop {
286 deleteErr := a.messages.Delete(genCtx, assistantMessage.ID)
287 if deleteErr != nil {
288 return callContext, prepared, deleteErr
289 }
290 return callContext, prepared, ErrHookCancellation
291 }
292 }
293 if updateErr := a.messages.Update(genCtx, userMsg); updateErr != nil {
294 return callContext, prepared, updateErr
295 }
296 prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
297 }
298
299 lastSystemRoleInx := 0
300 systemMessageUpdated := false
301 for i, msg := range prepared.Messages {
302 // Only add cache control to the last message.
303 if msg.Role == fantasy.MessageRoleSystem {
304 lastSystemRoleInx = i
305 } else if !systemMessageUpdated {
306 prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
307 systemMessageUpdated = true
308 }
309 // Than add cache control to the last 2 messages.
310 if i > len(prepared.Messages)-3 {
311 prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
312 }
313 }
314
315 if a.systemPromptPrefix != "" {
316 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
317 }
318
319 callContext = context.WithValue(callContext, tools.MessageIDContextKey, currentAssistant.ID)
320 return callContext, prepared, err
321 },
322 OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
323 currentAssistant.AppendReasoningContent(reasoning.Text)
324 return a.messages.Update(genCtx, *currentAssistant)
325 },
326 OnReasoningDelta: func(id string, text string) error {
327 currentAssistant.AppendReasoningContent(text)
328 return a.messages.Update(genCtx, *currentAssistant)
329 },
330 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
331 // handle anthropic signature
332 if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
333 if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
334 currentAssistant.AppendReasoningSignature(reasoning.Signature)
335 }
336 }
337 if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
338 if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
339 currentAssistant.AppendReasoningSignature(reasoning.Signature)
340 }
341 }
342 if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
343 if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
344 currentAssistant.SetReasoningResponsesData(reasoning)
345 }
346 }
347 currentAssistant.FinishThinking()
348 return a.messages.Update(genCtx, *currentAssistant)
349 },
350 OnTextDelta: func(id string, text string) error {
351 // Strip leading newline from initial text content. This is is
352 // particularly important in non-interactive mode where leading
353 // newlines are very visible.
354 if len(currentAssistant.Parts) == 0 {
355 text = strings.TrimPrefix(text, "\n")
356 }
357
358 currentAssistant.AppendContent(text)
359 return a.messages.Update(genCtx, *currentAssistant)
360 },
361 OnToolInputStart: func(id string, toolName string) error {
362 toolCall := message.ToolCall{
363 ID: id,
364 Name: toolName,
365 ProviderExecuted: false,
366 Finished: false,
367 }
368 currentAssistant.AddToolCall(toolCall)
369 return a.messages.Update(genCtx, *currentAssistant)
370 },
371 OnRetry: func(err *fantasy.ProviderError, delay time.Duration) {
372 // TODO: implement
373 },
374 OnToolCall: func(tc fantasy.ToolCallContent) error {
375 toolCall := message.ToolCall{
376 ID: tc.ToolCallID,
377 Name: tc.ToolName,
378 Input: tc.Input,
379 ProviderExecuted: false,
380 Finished: true,
381 }
382 currentAssistant.AddToolCall(toolCall)
383 return a.messages.Update(genCtx, *currentAssistant)
384 },
385 OnToolResult: func(result fantasy.ToolResultContent) error {
386 var resultContent string
387 isError := false
388 switch result.Result.GetType() {
389 case fantasy.ToolResultContentTypeText:
390 r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result)
391 if ok {
392 resultContent = r.Text
393 }
394 case fantasy.ToolResultContentTypeError:
395 r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result)
396 if ok {
397 isError = true
398 resultContent = r.Error.Error()
399 }
400 case fantasy.ToolResultContentTypeMedia:
401 // TODO: handle this message type
402 }
403 toolResult := message.ToolResult{
404 ToolCallID: result.ToolCallID,
405 Name: result.ToolName,
406 Content: resultContent,
407 IsError: isError,
408 Metadata: result.ClientMetadata,
409 }
410 _, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
411 Role: message.Tool,
412 Parts: []message.ContentPart{
413 toolResult,
414 },
415 })
416 if createMsgErr != nil {
417 return createMsgErr
418 }
419 return nil
420 },
421 OnStepFinish: func(stepResult fantasy.StepResult) error {
422 finishReason := message.FinishReasonUnknown
423 switch stepResult.FinishReason {
424 case fantasy.FinishReasonLength:
425 finishReason = message.FinishReasonMaxTokens
426 case fantasy.FinishReasonStop:
427 finishReason = message.FinishReasonEndTurn
428 case fantasy.FinishReasonToolCalls:
429 finishReason = message.FinishReasonToolUse
430 }
431 currentAssistant.AddFinish(finishReason, "", "")
432 a.updateSessionUsage(a.largeModel, ¤tSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
433 sessionLock.Lock()
434 _, sessionErr := a.sessions.Save(genCtx, currentSession)
435 sessionLock.Unlock()
436 if sessionErr != nil {
437 return sessionErr
438 }
439 return a.messages.Update(genCtx, *currentAssistant)
440 },
441 StopWhen: []fantasy.StopCondition{
442 func(_ []fantasy.StepResult) bool {
443 cw := int64(a.largeModel.CatwalkCfg.ContextWindow)
444 tokens := currentSession.CompletionTokens + currentSession.PromptTokens
445 remaining := cw - tokens
446 var threshold int64
447 if cw > 200_000 {
448 threshold = 20_000
449 } else {
450 threshold = int64(float64(cw) * 0.2)
451 }
452 if (remaining <= threshold) && !a.disableAutoSummarize {
453 shouldSummarize = true
454 return true
455 }
456 return false
457 },
458 },
459 })
460
461 a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))
462
463 if err != nil {
464 isCancelErr := errors.Is(err, context.Canceled)
465 isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
466 isHookCancelled := errors.Is(err, ErrHookCancellation)
467 if currentAssistant == nil {
468 return result, err
469 }
470 // Ensure we finish thinking on error to close the reasoning state.
471 currentAssistant.FinishThinking()
472 toolCalls := currentAssistant.ToolCalls()
473 // INFO: we use the parent context here because the genCtx has been cancelled.
474 msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
475 if createErr != nil {
476 return nil, createErr
477 }
478 for _, tc := range toolCalls {
479 if !tc.Finished {
480 tc.Finished = true
481 tc.Input = "{}"
482 currentAssistant.AddToolCall(tc)
483 updateErr := a.messages.Update(ctx, *currentAssistant)
484 if updateErr != nil {
485 return nil, updateErr
486 }
487 }
488
489 found := false
490 for _, msg := range msgs {
491 if msg.Role == message.Tool {
492 for _, tr := range msg.ToolResults() {
493 if tr.ToolCallID == tc.ID {
494 found = true
495 break
496 }
497 }
498 }
499 if found {
500 break
501 }
502 }
503 if found {
504 continue
505 }
506 content := "There was an error while executing the tool"
507 if isCancelErr {
508 content = "Tool execution canceled by user"
509 } else if isPermissionErr {
510 content = "User denied permission"
511 }
512 toolResult := message.ToolResult{
513 ToolCallID: tc.ID,
514 Name: tc.Name,
515 Content: content,
516 IsError: true,
517 }
518 _, createErr = a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
519 Role: message.Tool,
520 Parts: []message.ContentPart{
521 toolResult,
522 },
523 })
524 if createErr != nil {
525 return nil, createErr
526 }
527 }
528 var fantasyErr *fantasy.Error
529 var providerErr *fantasy.ProviderError
530 const defaultTitle = "Provider Error"
531 if isCancelErr {
532 currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
533 } else if isPermissionErr {
534 currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "")
535 } else if isHookCancelled {
536 currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "Hook canceled request", "")
537 } else if errors.As(err, &providerErr) {
538 currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message)
539 } else if errors.As(err, &fantasyErr) {
540 currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(fantasyErr.Title), defaultTitle), fantasyErr.Message)
541 } else {
542 currentAssistant.AddFinish(message.FinishReasonError, defaultTitle, err.Error())
543 }
544 // Note: we use the parent context here because the genCtx has been
545 // cancelled.
546 updateErr := a.messages.Update(ctx, *currentAssistant)
547 if updateErr != nil {
548 return nil, updateErr
549 }
550 return nil, err
551 }
552 wg.Wait()
553
554 if shouldSummarize {
555 a.activeRequests.Del(call.SessionID)
556 if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
557 return nil, summarizeErr
558 }
559 // If the agent wasn't done...
560 if len(currentAssistant.ToolCalls()) > 0 {
561 existing, ok := a.messageQueue.Get(call.SessionID)
562 if !ok {
563 existing = []SessionAgentCall{}
564 }
565 call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
566 existing = append(existing, call)
567 a.messageQueue.Set(call.SessionID, existing)
568 }
569 }
570
571 // Release active request before processing queued messages.
572 a.activeRequests.Del(call.SessionID)
573 cancel()
574
575 queuedMessages, ok := a.messageQueue.Get(call.SessionID)
576 if !ok || len(queuedMessages) == 0 {
577 return result, err
578 }
579 // There are queued messages restart the loop.
580 firstQueuedMessage := queuedMessages[0]
581 a.messageQueue.Set(call.SessionID, queuedMessages[1:])
582 return a.Run(ctx, firstQueuedMessage)
583}
584
585func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
586 if a.IsSessionBusy(sessionID) {
587 return ErrSessionBusy
588 }
589
590 currentSession, err := a.sessions.Get(ctx, sessionID)
591 if err != nil {
592 return fmt.Errorf("failed to get session: %w", err)
593 }
594 msgs, err := a.getSessionMessages(ctx, currentSession)
595 if err != nil {
596 return err
597 }
598 if len(msgs) == 0 {
599 // Nothing to summarize.
600 return nil
601 }
602
603 aiMsgs, _ := a.preparePrompt(msgs)
604
605 genCtx, cancel := context.WithCancel(ctx)
606 a.activeRequests.Set(sessionID, cancel)
607 defer a.activeRequests.Del(sessionID)
608 defer cancel()
609
610 agent := fantasy.NewAgent(a.largeModel.Model,
611 fantasy.WithSystemPrompt(string(summaryPrompt)),
612 )
613 summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
614 Role: message.Assistant,
615 Model: a.largeModel.Model.Model(),
616 Provider: a.largeModel.Model.Provider(),
617 IsSummaryMessage: true,
618 })
619 if err != nil {
620 return err
621 }
622
623 resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
624 Prompt: "Provide a detailed summary of our conversation above.",
625 Messages: aiMsgs,
626 ProviderOptions: opts,
627 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
628 prepared.Messages = options.Messages
629 if a.systemPromptPrefix != "" {
630 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
631 }
632 return callContext, prepared, nil
633 },
634 OnReasoningDelta: func(id string, text string) error {
635 summaryMessage.AppendReasoningContent(text)
636 return a.messages.Update(genCtx, summaryMessage)
637 },
638 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
639 // Handle anthropic signature.
640 if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
641 if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
642 summaryMessage.AppendReasoningSignature(signature.Signature)
643 }
644 }
645 summaryMessage.FinishThinking()
646 return a.messages.Update(genCtx, summaryMessage)
647 },
648 OnTextDelta: func(id, text string) error {
649 summaryMessage.AppendContent(text)
650 return a.messages.Update(genCtx, summaryMessage)
651 },
652 })
653 if err != nil {
654 isCancelErr := errors.Is(err, context.Canceled)
655 if isCancelErr {
656 // User cancelled summarize we need to remove the summary message.
657 deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
658 return deleteErr
659 }
660 return err
661 }
662
663 summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
664 err = a.messages.Update(genCtx, summaryMessage)
665 if err != nil {
666 return err
667 }
668
669 var openrouterCost *float64
670 for _, step := range resp.Steps {
671 stepCost := a.openrouterCost(step.ProviderMetadata)
672 if stepCost != nil {
673 newCost := *stepCost
674 if openrouterCost != nil {
675 newCost += *openrouterCost
676 }
677 openrouterCost = &newCost
678 }
679 }
680
681 a.updateSessionUsage(a.largeModel, ¤tSession, resp.TotalUsage, openrouterCost)
682
683 // Just in case, get just the last usage info.
684 usage := resp.Response.Usage
685 currentSession.SummaryMessageID = summaryMessage.ID
686 currentSession.CompletionTokens = usage.OutputTokens
687 currentSession.PromptTokens = 0
688 _, err = a.sessions.Save(genCtx, currentSession)
689 return err
690}
691
692func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
693 if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
694 return fantasy.ProviderOptions{}
695 }
696 return fantasy.ProviderOptions{
697 anthropic.Name: &anthropic.ProviderCacheControlOptions{
698 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
699 },
700 bedrock.Name: &anthropic.ProviderCacheControlOptions{
701 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
702 },
703 }
704}
705
706func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
707 var attachmentParts []message.ContentPart
708 for _, attachment := range call.Attachments {
709 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
710 }
711 parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
712 parts = append(parts, attachmentParts...)
713 msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
714 Role: message.User,
715 Parts: parts,
716 })
717 if err != nil {
718 return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
719 }
720 return msg, nil
721}
722
723func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
724 var history []fantasy.Message
725 for _, m := range msgs {
726 if len(m.Parts) == 0 {
727 continue
728 }
729 // Assistant message without content or tool calls (cancelled before it
730 // returned anything).
731 if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
732 continue
733 }
734 history = append(history, m.ToAIMessage()...)
735 }
736
737 var files []fantasy.FilePart
738 for _, attachment := range attachments {
739 files = append(files, fantasy.FilePart{
740 Filename: attachment.FileName,
741 Data: attachment.Content,
742 MediaType: attachment.MimeType,
743 })
744 }
745
746 return history, files
747}
748
749func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
750 msgs, err := a.messages.List(ctx, session.ID)
751 if err != nil {
752 return nil, fmt.Errorf("failed to list messages: %w", err)
753 }
754
755 if session.SummaryMessageID != "" {
756 summaryMsgInex := -1
757 for i, msg := range msgs {
758 if msg.ID == session.SummaryMessageID {
759 summaryMsgInex = i
760 break
761 }
762 }
763 if summaryMsgInex != -1 {
764 msgs = msgs[summaryMsgInex:]
765 msgs[0].Role = message.User
766 }
767 }
768 return msgs, nil
769}
770
771func (a *sessionAgent) generateTitle(ctx context.Context, session *session.Session, prompt string) {
772 if prompt == "" {
773 return
774 }
775
776 var maxOutput int64 = 40
777 if a.smallModel.CatwalkCfg.CanReason {
778 maxOutput = a.smallModel.CatwalkCfg.DefaultMaxTokens
779 }
780
781 agent := fantasy.NewAgent(a.smallModel.Model,
782 fantasy.WithSystemPrompt(string(titlePrompt)+"\n /no_think"),
783 fantasy.WithMaxOutputTokens(maxOutput),
784 )
785
786 resp, err := agent.Stream(ctx, fantasy.AgentStreamCall{
787 Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", prompt),
788 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
789 prepared.Messages = options.Messages
790 if a.systemPromptPrefix != "" {
791 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
792 }
793 return callContext, prepared, nil
794 },
795 })
796 if err != nil {
797 slog.Error("error generating title", "err", err)
798 return
799 }
800
801 title := resp.Response.Content.Text()
802
803 title = strings.ReplaceAll(title, "\n", " ")
804
805 // Remove thinking tags if present.
806 if idx := strings.Index(title, "</think>"); idx > 0 {
807 title = title[idx+len("</think>"):]
808 }
809
810 title = strings.TrimSpace(title)
811 if title == "" {
812 slog.Warn("failed to generate title", "warn", "empty title")
813 return
814 }
815
816 session.Title = title
817
818 var openrouterCost *float64
819 for _, step := range resp.Steps {
820 stepCost := a.openrouterCost(step.ProviderMetadata)
821 if stepCost != nil {
822 newCost := *stepCost
823 if openrouterCost != nil {
824 newCost += *openrouterCost
825 }
826 openrouterCost = &newCost
827 }
828 }
829
830 a.updateSessionUsage(a.smallModel, session, resp.TotalUsage, openrouterCost)
831 _, saveErr := a.sessions.Save(ctx, *session)
832 if saveErr != nil {
833 slog.Error("failed to save session title & usage", "error", saveErr)
834 return
835 }
836}
837
838func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
839 openrouterMetadata, ok := metadata[openrouter.Name]
840 if !ok {
841 return nil
842 }
843
844 opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
845 if !ok {
846 return nil
847 }
848 return &opts.Usage.Cost
849}
850
851func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
852 modelConfig := model.CatwalkCfg
853 cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
854 modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
855 modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
856 modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
857
858 a.eventTokensUsed(session.ID, model, usage, cost)
859
860 if overrideCost != nil {
861 session.Cost += *overrideCost
862 } else {
863 session.Cost += cost
864 }
865
866 session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
867 session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
868}
869
870func (a *sessionAgent) Cancel(sessionID string) {
871 // Cancel regular requests.
872 if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
873 slog.Info("Request cancellation initiated", "session_id", sessionID)
874 cancel()
875 }
876
877 // Also check for summarize requests.
878 if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
879 slog.Info("Summarize cancellation initiated", "session_id", sessionID)
880 cancel()
881 }
882
883 if a.QueuedPrompts(sessionID) > 0 {
884 slog.Info("Clearing queued prompts", "session_id", sessionID)
885 a.messageQueue.Del(sessionID)
886 }
887}
888
889func (a *sessionAgent) ClearQueue(sessionID string) {
890 if a.QueuedPrompts(sessionID) > 0 {
891 slog.Info("Clearing queued prompts", "session_id", sessionID)
892 a.messageQueue.Del(sessionID)
893 }
894}
895
896func (a *sessionAgent) CancelAll() {
897 if !a.IsBusy() {
898 return
899 }
900 for key := range a.activeRequests.Seq2() {
901 a.Cancel(key) // key is sessionID
902 }
903
904 timeout := time.After(5 * time.Second)
905 for a.IsBusy() {
906 select {
907 case <-timeout:
908 return
909 default:
910 time.Sleep(200 * time.Millisecond)
911 }
912 }
913}
914
915func (a *sessionAgent) IsBusy() bool {
916 var busy bool
917 for cancelFunc := range a.activeRequests.Seq() {
918 if cancelFunc != nil {
919 busy = true
920 break
921 }
922 }
923 return busy
924}
925
926func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
927 _, busy := a.activeRequests.Get(sessionID)
928 return busy
929}
930
931func (a *sessionAgent) QueuedPrompts(sessionID string) int {
932 l, ok := a.messageQueue.Get(sessionID)
933 if !ok {
934 return 0
935 }
936 return len(l)
937}
938
939func (a *sessionAgent) SetModels(large Model, small Model) {
940 a.largeModel = large
941 a.smallModel = small
942}
943
944func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
945 a.tools = tools
946}
947
948func (a *sessionAgent) Model() Model {
949 return a.largeModel
950}
951
952func (a *sessionAgent) executeUserPromptSubmitHook(
953 ctx context.Context,
954 sessionID string,
955 prompt string,
956) ([]message.HookOutput, error) {
957 if a.hooks == nil {
958 return nil, nil
959 }
960 return a.hooks.Execute(ctx, hooks.HookContext{
961 EventType: config.UserPromptSubmit,
962 SessionID: sessionID,
963 UserPrompt: prompt,
964 Provider: a.largeModel.ModelCfg.Provider,
965 Model: a.largeModel.ModelCfg.Model,
966 })
967}