1// Package agent is the core orchestration layer for Crush AI agents.
2//
3// It provides session-based AI agent functionality for managing
4// conversations, tool execution, and message handling. It coordinates
5// interactions between language models, messages, sessions, and tools while
6// handling features like automatic summarization, queuing, and token
7// management.
8package agent
9
10import (
11 "cmp"
12 "context"
13 _ "embed"
14 "errors"
15 "fmt"
16 "log/slog"
17 "os"
18 "strconv"
19 "strings"
20 "sync"
21 "time"
22
23 "charm.land/fantasy"
24 "charm.land/fantasy/providers/anthropic"
25 "charm.land/fantasy/providers/bedrock"
26 "charm.land/fantasy/providers/google"
27 "charm.land/fantasy/providers/openai"
28 "charm.land/fantasy/providers/openrouter"
29 "git.secluded.site/crush/internal/agent/tools"
30 "git.secluded.site/crush/internal/config"
31 "git.secluded.site/crush/internal/csync"
32 "git.secluded.site/crush/internal/message"
33 "git.secluded.site/crush/internal/notification"
34 "git.secluded.site/crush/internal/permission"
35 "git.secluded.site/crush/internal/session"
36 "git.secluded.site/crush/internal/stringext"
37 "github.com/charmbracelet/catwalk/pkg/catwalk"
38)
39
// titlePrompt holds the embedded contents of templates/title.md. It is used
// as the base of the system prompt when generating a session title.
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt holds the embedded contents of templates/summary.md. It is
// used as the system prompt when summarizing a session.
//
//go:embed templates/summary.md
var summaryPrompt []byte
45
// SessionAgentCall describes a single prompt submitted to a session agent.
type SessionAgentCall struct {
	// SessionID identifies the session the prompt belongs to. Run rejects an
	// empty value with ErrSessionMissing.
	SessionID string
	// Prompt is the user's text. Run rejects an empty value with
	// ErrEmptyPrompt.
	Prompt string
	// ProviderOptions is forwarded to the model stream call and, when the
	// session is auto-summarized, to Summarize.
	ProviderOptions fantasy.ProviderOptions
	// Attachments are stored as binary parts of the user message and sent to
	// the model as files.
	Attachments []message.Attachment
	// MaxOutputTokens is forwarded to the model stream call.
	MaxOutputTokens int64
	// Temperature is forwarded to the model stream call.
	Temperature *float64
	// TopP is forwarded to the model stream call.
	TopP *float64
	// TopK is forwarded to the model stream call.
	TopK *int64
	// FrequencyPenalty is forwarded to the model stream call.
	FrequencyPenalty *float64
	// PresencePenalty is forwarded to the model stream call.
	PresencePenalty *float64
}
58
// SessionAgent runs prompts against sessions and exposes control over
// in-flight requests, queued prompts and completion notifications.
type SessionAgent interface {
	// Run submits a prompt for the call's session. When the session is
	// already busy the prompt is queued and Run returns (nil, nil).
	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
	// SetModels replaces the large and small models used by the agent.
	SetModels(large Model, small Model)
	// SetTools replaces the tools made available to the model.
	SetTools(tools []fantasy.AgentTool)
	// Cancel cancels the active request for sessionID, if any, and drops its
	// queued prompts.
	Cancel(sessionID string)
	// CancelAll cancels the active requests of every session and any
	// scheduled completion notifications.
	CancelAll()
	// IsSessionBusy reports whether sessionID has an active request.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any session has an active request.
	IsBusy() bool
	// QueuedPrompts returns the number of prompts queued for sessionID.
	QueuedPrompts(sessionID string) int
	// ClearQueue drops all prompts queued for sessionID.
	ClearQueue(sessionID string)
	// Summarize generates a summary message for the given session ID using
	// the supplied provider options.
	Summarize(context.Context, string, fantasy.ProviderOptions) error
	// Model returns the large model currently in use.
	Model() Model
	// CancelCompletionNotification cancels any scheduled "turn ended"
	// notification for the provided sessionID.
	CancelCompletionNotification(sessionID string)
	// HasPendingCompletionNotification returns true if a turn-end
	// notification has been scheduled and not yet cancelled/shown.
	HasPendingCompletionNotification(sessionID string) bool
}
78
// completionNotificationDelay is the delay handed to the notifier when a
// turn-completed notification is scheduled.
const completionNotificationDelay = 5 * time.Second
80
// Model bundles a language model with its catalog and selection metadata.
type Model struct {
	// Model is the language model used to run requests.
	Model fantasy.LanguageModel
	// CatwalkCfg supplies the context window, pricing and reasoning
	// capabilities consulted by the agent.
	CatwalkCfg catwalk.Model
	// ModelCfg supplies the model and provider identifiers recorded on
	// assistant messages.
	ModelCfg config.SelectedModel
}
86
// sessionAgent is the SessionAgent implementation returned by
// NewSessionAgent.
type sessionAgent struct {
	// largeModel is used by Run and Summarize.
	largeModel Model
	// smallModel is used for title generation.
	smallModel Model
	// systemPromptPrefix, when non-empty, is prepended as an extra system
	// message to every step sent to a model.
	systemPromptPrefix string
	// systemPrompt is the system prompt of the agent created by Run.
	systemPrompt string
	// tools are handed to the agent created by Run.
	tools []fantasy.AgentTool
	// sessions loads and saves session records.
	sessions session.Service
	// messages creates, lists, updates and deletes session messages.
	messages message.Service
	// disableAutoSummarize turns off the context-window stop condition in
	// Run.
	disableAutoSummarize bool
	// isYolo is copied from the options; it is not read in this part of the
	// file.
	isYolo bool

	// messageQueue holds prompts received while a session was busy, keyed by
	// session ID.
	messageQueue *csync.Map[string, []SessionAgentCall]
	// activeRequests holds the cancel function of each in-flight request,
	// keyed by session ID.
	activeRequests *csync.Map[string, context.CancelFunc]
	// notifier schedules completion notifications; when nil none are
	// scheduled.
	notifier *notification.Notifier
	// notifyCtx is the context passed to the notifier.
	notifyCtx context.Context
	// completionCancels holds the cancel function of each scheduled
	// completion notification, keyed by session ID.
	completionCancels *csync.Map[string, context.CancelFunc]
}
104
// SessionAgentOptions carries the dependencies and settings consumed by
// NewSessionAgent.
type SessionAgentOptions struct {
	// LargeModel is the model used for prompts and summaries.
	LargeModel Model
	// SmallModel is the model used for title generation.
	SmallModel Model
	// SystemPromptPrefix, when non-empty, is sent as an extra leading system
	// message.
	SystemPromptPrefix string
	// SystemPrompt is the agent's system prompt.
	SystemPrompt string
	// DisableAutoSummarize turns off automatic summarization when the
	// context window is nearly exhausted.
	DisableAutoSummarize bool
	// IsYolo is stored on the agent; it is not read in this part of the file.
	IsYolo bool
	// Sessions is the session persistence service.
	Sessions session.Service
	// Messages is the message persistence service.
	Messages message.Service
	// Tools are the tools offered to the model.
	Tools []fantasy.AgentTool
	// Notifier schedules completion notifications; nil disables them.
	Notifier *notification.Notifier
	// NotificationCtx is the context passed to the notifier. When nil,
	// context.Background() is used.
	NotificationCtx context.Context
}
118
119func NewSessionAgent(
120 opts SessionAgentOptions,
121) SessionAgent {
122 notifyCtx := opts.NotificationCtx
123 if notifyCtx == nil {
124 notifyCtx = context.Background()
125 }
126
127 return &sessionAgent{
128 largeModel: opts.LargeModel,
129 smallModel: opts.SmallModel,
130 systemPromptPrefix: opts.SystemPromptPrefix,
131 systemPrompt: opts.SystemPrompt,
132 sessions: opts.Sessions,
133 messages: opts.Messages,
134 disableAutoSummarize: opts.DisableAutoSummarize,
135 tools: opts.Tools,
136 isYolo: opts.IsYolo,
137 messageQueue: csync.NewMap[string, []SessionAgentCall](),
138 activeRequests: csync.NewMap[string, context.CancelFunc](),
139 notifier: opts.Notifier,
140 notifyCtx: notifyCtx,
141 completionCancels: csync.NewMap[string, context.CancelFunc](),
142 }
143}
144
145func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
146 if call.Prompt == "" {
147 return nil, ErrEmptyPrompt
148 }
149 if call.SessionID == "" {
150 return nil, ErrSessionMissing
151 }
152
153 a.cancelCompletionNotification(call.SessionID)
154
155 // Queue the message if busy
156 if a.IsSessionBusy(call.SessionID) {
157 existing, ok := a.messageQueue.Get(call.SessionID)
158 if !ok {
159 existing = []SessionAgentCall{}
160 }
161 existing = append(existing, call)
162 a.messageQueue.Set(call.SessionID, existing)
163 return nil, nil
164 }
165
166 if len(a.tools) > 0 {
167 // Add Anthropic caching to the last tool.
168 a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
169 }
170
171 agent := fantasy.NewAgent(
172 a.largeModel.Model,
173 fantasy.WithSystemPrompt(a.systemPrompt),
174 fantasy.WithTools(a.tools...),
175 )
176
177 sessionLock := sync.Mutex{}
178 currentSession, err := a.sessions.Get(ctx, call.SessionID)
179 if err != nil {
180 return nil, fmt.Errorf("failed to get session: %w", err)
181 }
182
183 msgs, err := a.getSessionMessages(ctx, currentSession)
184 if err != nil {
185 return nil, fmt.Errorf("failed to get session messages: %w", err)
186 }
187
188 var wg sync.WaitGroup
189 // Generate title if first message.
190 if len(msgs) == 0 {
191 wg.Go(func() {
192 sessionLock.Lock()
193 a.generateTitle(ctx, ¤tSession, call.Prompt)
194 sessionLock.Unlock()
195 })
196 }
197
198 // Add the user message to the session.
199 _, err = a.createUserMessage(ctx, call)
200 if err != nil {
201 return nil, err
202 }
203
204 // Add the session to the context.
205 ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)
206
207 genCtx, cancel := context.WithCancel(ctx)
208 a.activeRequests.Set(call.SessionID, cancel)
209
210 defer cancel()
211 defer a.activeRequests.Del(call.SessionID)
212
213 history, files := a.preparePrompt(msgs, call.Attachments...)
214
215 startTime := time.Now()
216 a.eventPromptSent(call.SessionID)
217
218 var currentAssistant *message.Message
219 var shouldSummarize bool
220 result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
221 Prompt: call.Prompt,
222 Files: files,
223 Messages: history,
224 ProviderOptions: call.ProviderOptions,
225 MaxOutputTokens: &call.MaxOutputTokens,
226 TopP: call.TopP,
227 Temperature: call.Temperature,
228 PresencePenalty: call.PresencePenalty,
229 TopK: call.TopK,
230 FrequencyPenalty: call.FrequencyPenalty,
231 // Before each step create a new assistant message.
232 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
233 prepared.Messages = options.Messages
234 // Reset all cached items.
235 for i := range prepared.Messages {
236 prepared.Messages[i].ProviderOptions = nil
237 }
238
239 queuedCalls, _ := a.messageQueue.Get(call.SessionID)
240 a.messageQueue.Del(call.SessionID)
241 for _, queued := range queuedCalls {
242 userMessage, createErr := a.createUserMessage(callContext, queued)
243 if createErr != nil {
244 return callContext, prepared, createErr
245 }
246 prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
247 }
248
249 lastSystemRoleInx := 0
250 systemMessageUpdated := false
251 for i, msg := range prepared.Messages {
252 // Only add cache control to the last message.
253 if msg.Role == fantasy.MessageRoleSystem {
254 lastSystemRoleInx = i
255 } else if !systemMessageUpdated {
256 prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
257 systemMessageUpdated = true
258 }
259 // Than add cache control to the last 2 messages.
260 if i > len(prepared.Messages)-3 {
261 prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
262 }
263 }
264
265 if a.systemPromptPrefix != "" {
266 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
267 }
268
269 var assistantMsg message.Message
270 assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
271 Role: message.Assistant,
272 Parts: []message.ContentPart{},
273 Model: a.largeModel.ModelCfg.Model,
274 Provider: a.largeModel.ModelCfg.Provider,
275 })
276 if err != nil {
277 return callContext, prepared, err
278 }
279 callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
280 currentAssistant = &assistantMsg
281 return callContext, prepared, err
282 },
283 OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
284 currentAssistant.AppendReasoningContent(reasoning.Text)
285 return a.messages.Update(genCtx, *currentAssistant)
286 },
287 OnReasoningDelta: func(id string, text string) error {
288 currentAssistant.AppendReasoningContent(text)
289 return a.messages.Update(genCtx, *currentAssistant)
290 },
291 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
292 // handle anthropic signature
293 if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
294 if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
295 currentAssistant.AppendReasoningSignature(reasoning.Signature)
296 }
297 }
298 if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
299 if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
300 currentAssistant.AppendReasoningSignature(reasoning.Signature)
301 }
302 }
303 if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
304 if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
305 currentAssistant.SetReasoningResponsesData(reasoning)
306 }
307 }
308 currentAssistant.FinishThinking()
309 return a.messages.Update(genCtx, *currentAssistant)
310 },
311 OnTextDelta: func(id string, text string) error {
312 // Strip leading newline from initial text content. This is is
313 // particularly important in non-interactive mode where leading
314 // newlines are very visible.
315 if len(currentAssistant.Parts) == 0 {
316 text = strings.TrimPrefix(text, "\n")
317 }
318
319 currentAssistant.AppendContent(text)
320 return a.messages.Update(genCtx, *currentAssistant)
321 },
322 OnToolInputStart: func(id string, toolName string) error {
323 toolCall := message.ToolCall{
324 ID: id,
325 Name: toolName,
326 ProviderExecuted: false,
327 Finished: false,
328 }
329 currentAssistant.AddToolCall(toolCall)
330 return a.messages.Update(genCtx, *currentAssistant)
331 },
332 OnRetry: func(err *fantasy.ProviderError, delay time.Duration) {
333 // TODO: implement
334 },
335 OnToolCall: func(tc fantasy.ToolCallContent) error {
336 toolCall := message.ToolCall{
337 ID: tc.ToolCallID,
338 Name: tc.ToolName,
339 Input: tc.Input,
340 ProviderExecuted: false,
341 Finished: true,
342 }
343 currentAssistant.AddToolCall(toolCall)
344 return a.messages.Update(genCtx, *currentAssistant)
345 },
346 OnToolResult: func(result fantasy.ToolResultContent) error {
347 var resultContent string
348 isError := false
349 switch result.Result.GetType() {
350 case fantasy.ToolResultContentTypeText:
351 r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result)
352 if ok {
353 resultContent = r.Text
354 }
355 case fantasy.ToolResultContentTypeError:
356 r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result)
357 if ok {
358 isError = true
359 resultContent = r.Error.Error()
360 }
361 case fantasy.ToolResultContentTypeMedia:
362 // TODO: handle this message type
363 }
364 toolResult := message.ToolResult{
365 ToolCallID: result.ToolCallID,
366 Name: result.ToolName,
367 Content: resultContent,
368 IsError: isError,
369 Metadata: result.ClientMetadata,
370 }
371 _, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
372 Role: message.Tool,
373 Parts: []message.ContentPart{
374 toolResult,
375 },
376 })
377 if createMsgErr != nil {
378 return createMsgErr
379 }
380 return nil
381 },
382 OnStepFinish: func(stepResult fantasy.StepResult) error {
383 finishReason := message.FinishReasonUnknown
384 switch stepResult.FinishReason {
385 case fantasy.FinishReasonLength:
386 finishReason = message.FinishReasonMaxTokens
387 case fantasy.FinishReasonStop:
388 finishReason = message.FinishReasonEndTurn
389 case fantasy.FinishReasonToolCalls:
390 finishReason = message.FinishReasonToolUse
391 }
392 currentAssistant.AddFinish(finishReason, "", "")
393 a.updateSessionUsage(a.largeModel, ¤tSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
394 sessionLock.Lock()
395 _, sessionErr := a.sessions.Save(genCtx, currentSession)
396 sessionLock.Unlock()
397 if sessionErr != nil {
398 return sessionErr
399 }
400 if finishReason == message.FinishReasonEndTurn {
401 a.scheduleCompletionNotification(call.SessionID, currentSession.Title)
402 }
403 return a.messages.Update(genCtx, *currentAssistant)
404 },
405 StopWhen: []fantasy.StopCondition{
406 func(_ []fantasy.StepResult) bool {
407 cw := int64(a.largeModel.CatwalkCfg.ContextWindow)
408 tokens := currentSession.CompletionTokens + currentSession.PromptTokens
409 remaining := cw - tokens
410 var threshold int64
411 if cw > 200_000 {
412 threshold = 20_000
413 } else {
414 threshold = int64(float64(cw) * 0.2)
415 }
416 if (remaining <= threshold) && !a.disableAutoSummarize {
417 shouldSummarize = true
418 return true
419 }
420 return false
421 },
422 },
423 })
424
425 a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))
426
427 if err != nil {
428 isCancelErr := errors.Is(err, context.Canceled)
429 isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
430 if currentAssistant == nil {
431 return result, err
432 }
433 // Ensure we finish thinking on error to close the reasoning state.
434 currentAssistant.FinishThinking()
435 toolCalls := currentAssistant.ToolCalls()
436 // INFO: we use the parent context here because the genCtx has been cancelled.
437 msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
438 if createErr != nil {
439 return nil, createErr
440 }
441 for _, tc := range toolCalls {
442 if !tc.Finished {
443 tc.Finished = true
444 tc.Input = "{}"
445 currentAssistant.AddToolCall(tc)
446 updateErr := a.messages.Update(ctx, *currentAssistant)
447 if updateErr != nil {
448 return nil, updateErr
449 }
450 }
451
452 found := false
453 for _, msg := range msgs {
454 if msg.Role == message.Tool {
455 for _, tr := range msg.ToolResults() {
456 if tr.ToolCallID == tc.ID {
457 found = true
458 break
459 }
460 }
461 }
462 if found {
463 break
464 }
465 }
466 if found {
467 continue
468 }
469 content := "There was an error while executing the tool"
470 if isCancelErr {
471 content = "Tool execution canceled by user"
472 } else if isPermissionErr {
473 content = "User denied permission"
474 }
475 toolResult := message.ToolResult{
476 ToolCallID: tc.ID,
477 Name: tc.Name,
478 Content: content,
479 IsError: true,
480 }
481 _, createErr = a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
482 Role: message.Tool,
483 Parts: []message.ContentPart{
484 toolResult,
485 },
486 })
487 if createErr != nil {
488 return nil, createErr
489 }
490 }
491 var fantasyErr *fantasy.Error
492 var providerErr *fantasy.ProviderError
493 const defaultTitle = "Provider Error"
494 if isCancelErr {
495 currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
496 } else if isPermissionErr {
497 currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "")
498 } else if errors.As(err, &providerErr) {
499 currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message)
500 } else if errors.As(err, &fantasyErr) {
501 currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(fantasyErr.Title), defaultTitle), fantasyErr.Message)
502 } else {
503 currentAssistant.AddFinish(message.FinishReasonError, defaultTitle, err.Error())
504 }
505 // Note: we use the parent context here because the genCtx has been
506 // cancelled.
507 updateErr := a.messages.Update(ctx, *currentAssistant)
508 if updateErr != nil {
509 return nil, updateErr
510 }
511 return nil, err
512 }
513 wg.Wait()
514
515 if shouldSummarize {
516 a.cancelCompletionNotification(call.SessionID)
517 a.activeRequests.Del(call.SessionID)
518 if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
519 return nil, summarizeErr
520 }
521 // If the agent wasn't done...
522 if len(currentAssistant.ToolCalls()) > 0 {
523 existing, ok := a.messageQueue.Get(call.SessionID)
524 if !ok {
525 existing = []SessionAgentCall{}
526 }
527 call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
528 existing = append(existing, call)
529 a.messageQueue.Set(call.SessionID, existing)
530 }
531 }
532
533 // Release active request before processing queued messages.
534 a.activeRequests.Del(call.SessionID)
535 cancel()
536
537 queuedMessages, ok := a.messageQueue.Get(call.SessionID)
538 if !ok || len(queuedMessages) == 0 {
539 return result, err
540 }
541 // There are queued messages restart the loop.
542 firstQueuedMessage := queuedMessages[0]
543 a.messageQueue.Set(call.SessionID, queuedMessages[1:])
544 return a.Run(ctx, firstQueuedMessage)
545}
546
547func (a *sessionAgent) scheduleCompletionNotification(sessionID, sessionTitle string) {
548 // Do not emit notifications for Agent-tool sub-sessions.
549 if a.sessions != nil && a.sessions.IsAgentToolSession(sessionID) {
550 return
551 }
552 if a.notifier == nil {
553 return
554 }
555
556 if sessionTitle == "" {
557 sessionTitle = sessionID
558 }
559
560 if cancel, ok := a.completionCancels.Take(sessionID); ok && cancel != nil {
561 cancel()
562 }
563
564 title := "💘 Crush is waiting"
565 message := fmt.Sprintf("Agent's turn completed in session \"%s\"", sessionTitle)
566 cancel := a.notifier.NotifyTaskComplete(a.notifyCtx, title, message, completionNotificationDelay)
567 if cancel == nil {
568 cancel = func() {}
569 }
570 a.completionCancels.Set(sessionID, cancel)
571}
572
573func (a *sessionAgent) cancelCompletionNotification(sessionID string) {
574 if a.notifier == nil {
575 return
576 }
577
578 if cancel, ok := a.completionCancels.Take(sessionID); ok && cancel != nil {
579 cancel()
580 }
581}
582
// CancelCompletionNotification implements SessionAgent. It cancels the
// completion notification scheduled for sessionID, if any, by delegating to
// cancelCompletionNotification.
func (a *sessionAgent) CancelCompletionNotification(sessionID string) {
	a.cancelCompletionNotification(sessionID)
}
587
588// HasPendingCompletionNotification implements SessionAgent.
589func (a *sessionAgent) HasPendingCompletionNotification(sessionID string) bool {
590 if a.IsSessionBusy(sessionID) {
591 return false
592 }
593 _, ok := a.completionCancels.Get(sessionID)
594 return ok
595}
596
597func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
598 if a.IsSessionBusy(sessionID) {
599 return ErrSessionBusy
600 }
601
602 currentSession, err := a.sessions.Get(ctx, sessionID)
603 if err != nil {
604 return fmt.Errorf("failed to get session: %w", err)
605 }
606 msgs, err := a.getSessionMessages(ctx, currentSession)
607 if err != nil {
608 return err
609 }
610 if len(msgs) == 0 {
611 // Nothing to summarize.
612 return nil
613 }
614
615 aiMsgs, _ := a.preparePrompt(msgs)
616
617 genCtx, cancel := context.WithCancel(ctx)
618 a.activeRequests.Set(sessionID, cancel)
619 defer a.activeRequests.Del(sessionID)
620 defer cancel()
621
622 agent := fantasy.NewAgent(a.largeModel.Model,
623 fantasy.WithSystemPrompt(string(summaryPrompt)),
624 )
625 summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
626 Role: message.Assistant,
627 Model: a.largeModel.Model.Model(),
628 Provider: a.largeModel.Model.Provider(),
629 IsSummaryMessage: true,
630 })
631 if err != nil {
632 return err
633 }
634
635 resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
636 Prompt: "Provide a detailed summary of our conversation above.",
637 Messages: aiMsgs,
638 ProviderOptions: opts,
639 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
640 prepared.Messages = options.Messages
641 if a.systemPromptPrefix != "" {
642 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
643 }
644 return callContext, prepared, nil
645 },
646 OnReasoningDelta: func(id string, text string) error {
647 summaryMessage.AppendReasoningContent(text)
648 return a.messages.Update(genCtx, summaryMessage)
649 },
650 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
651 // Handle anthropic signature.
652 if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
653 if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
654 summaryMessage.AppendReasoningSignature(signature.Signature)
655 }
656 }
657 summaryMessage.FinishThinking()
658 return a.messages.Update(genCtx, summaryMessage)
659 },
660 OnTextDelta: func(id, text string) error {
661 summaryMessage.AppendContent(text)
662 return a.messages.Update(genCtx, summaryMessage)
663 },
664 })
665 if err != nil {
666 isCancelErr := errors.Is(err, context.Canceled)
667 if isCancelErr {
668 // User cancelled summarize we need to remove the summary message.
669 deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
670 return deleteErr
671 }
672 return err
673 }
674
675 summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
676 err = a.messages.Update(genCtx, summaryMessage)
677 if err != nil {
678 return err
679 }
680
681 var openrouterCost *float64
682 for _, step := range resp.Steps {
683 stepCost := a.openrouterCost(step.ProviderMetadata)
684 if stepCost != nil {
685 newCost := *stepCost
686 if openrouterCost != nil {
687 newCost += *openrouterCost
688 }
689 openrouterCost = &newCost
690 }
691 }
692
693 a.updateSessionUsage(a.largeModel, ¤tSession, resp.TotalUsage, openrouterCost)
694
695 // Just in case, get just the last usage info.
696 usage := resp.Response.Usage
697 currentSession.SummaryMessageID = summaryMessage.ID
698 currentSession.CompletionTokens = usage.OutputTokens
699 currentSession.PromptTokens = 0
700 _, err = a.sessions.Save(genCtx, currentSession)
701 return err
702}
703
704func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
705 if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
706 return fantasy.ProviderOptions{}
707 }
708 return fantasy.ProviderOptions{
709 anthropic.Name: &anthropic.ProviderCacheControlOptions{
710 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
711 },
712 bedrock.Name: &anthropic.ProviderCacheControlOptions{
713 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
714 },
715 }
716}
717
718func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
719 var attachmentParts []message.ContentPart
720 for _, attachment := range call.Attachments {
721 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
722 }
723 parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
724 parts = append(parts, attachmentParts...)
725 msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
726 Role: message.User,
727 Parts: parts,
728 })
729 if err != nil {
730 return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
731 }
732 return msg, nil
733}
734
735func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
736 var history []fantasy.Message
737 for _, m := range msgs {
738 if len(m.Parts) == 0 {
739 continue
740 }
741 // Assistant message without content or tool calls (cancelled before it
742 // returned anything).
743 if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
744 continue
745 }
746 history = append(history, m.ToAIMessage()...)
747 }
748
749 var files []fantasy.FilePart
750 for _, attachment := range attachments {
751 files = append(files, fantasy.FilePart{
752 Filename: attachment.FileName,
753 Data: attachment.Content,
754 MediaType: attachment.MimeType,
755 })
756 }
757
758 return history, files
759}
760
761func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
762 msgs, err := a.messages.List(ctx, session.ID)
763 if err != nil {
764 return nil, fmt.Errorf("failed to list messages: %w", err)
765 }
766
767 if session.SummaryMessageID != "" {
768 summaryMsgInex := -1
769 for i, msg := range msgs {
770 if msg.ID == session.SummaryMessageID {
771 summaryMsgInex = i
772 break
773 }
774 }
775 if summaryMsgInex != -1 {
776 msgs = msgs[summaryMsgInex:]
777 msgs[0].Role = message.User
778 }
779 }
780 return msgs, nil
781}
782
783func (a *sessionAgent) generateTitle(ctx context.Context, session *session.Session, prompt string) {
784 if prompt == "" {
785 return
786 }
787
788 var maxOutput int64 = 40
789 if a.smallModel.CatwalkCfg.CanReason {
790 maxOutput = a.smallModel.CatwalkCfg.DefaultMaxTokens
791 }
792
793 agent := fantasy.NewAgent(a.smallModel.Model,
794 fantasy.WithSystemPrompt(string(titlePrompt)+"\n /no_think"),
795 fantasy.WithMaxOutputTokens(maxOutput),
796 )
797
798 resp, err := agent.Stream(ctx, fantasy.AgentStreamCall{
799 Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", prompt),
800 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
801 prepared.Messages = options.Messages
802 if a.systemPromptPrefix != "" {
803 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
804 }
805 return callContext, prepared, nil
806 },
807 })
808 if err != nil {
809 slog.Error("error generating title", "err", err)
810 return
811 }
812
813 title := resp.Response.Content.Text()
814
815 title = strings.ReplaceAll(title, "\n", " ")
816
817 // Remove thinking tags if present.
818 if idx := strings.Index(title, "</think>"); idx > 0 {
819 title = title[idx+len("</think>"):]
820 }
821
822 title = strings.TrimSpace(title)
823 if title == "" {
824 slog.Warn("failed to generate title", "warn", "empty title")
825 return
826 }
827
828 session.Title = title
829
830 var openrouterCost *float64
831 for _, step := range resp.Steps {
832 stepCost := a.openrouterCost(step.ProviderMetadata)
833 if stepCost != nil {
834 newCost := *stepCost
835 if openrouterCost != nil {
836 newCost += *openrouterCost
837 }
838 openrouterCost = &newCost
839 }
840 }
841
842 a.updateSessionUsage(a.smallModel, session, resp.TotalUsage, openrouterCost)
843 _, saveErr := a.sessions.Save(ctx, *session)
844 if saveErr != nil {
845 slog.Error("failed to save session title & usage", "error", saveErr)
846 return
847 }
848}
849
850func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
851 openrouterMetadata, ok := metadata[openrouter.Name]
852 if !ok {
853 return nil
854 }
855
856 opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
857 if !ok {
858 return nil
859 }
860 return &opts.Usage.Cost
861}
862
863func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
864 modelConfig := model.CatwalkCfg
865 cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
866 modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
867 modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
868 modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
869
870 a.eventTokensUsed(session.ID, model, usage, cost)
871
872 if overrideCost != nil {
873 session.Cost += *overrideCost
874 } else {
875 session.Cost += cost
876 }
877
878 session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
879 session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
880}
881
882func (a *sessionAgent) Cancel(sessionID string) {
883 // Cancel regular requests.
884 if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
885 slog.Info("Request cancellation initiated", "session_id", sessionID)
886 cancel()
887 }
888
889 // Also check for summarize requests.
890 if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
891 slog.Info("Summarize cancellation initiated", "session_id", sessionID)
892 cancel()
893 }
894
895 if a.QueuedPrompts(sessionID) > 0 {
896 slog.Info("Clearing queued prompts", "session_id", sessionID)
897 a.messageQueue.Del(sessionID)
898 }
899}
900
901func (a *sessionAgent) ClearQueue(sessionID string) {
902 if a.QueuedPrompts(sessionID) > 0 {
903 slog.Info("Clearing queued prompts", "session_id", sessionID)
904 a.messageQueue.Del(sessionID)
905 }
906}
907
908func (a *sessionAgent) CancelAll() {
909 if !a.IsBusy() {
910 // still ensure notifications are cancelled even when not busy
911 for cancel := range a.completionCancels.Seq() {
912 if cancel != nil {
913 cancel()
914 }
915 }
916 a.completionCancels.Reset(make(map[string]context.CancelFunc))
917 return
918 }
919 for key := range a.activeRequests.Seq2() {
920 a.Cancel(key) // key is sessionID
921 }
922
923 timeout := time.After(5 * time.Second)
924 for a.IsBusy() {
925 select {
926 case <-timeout:
927 return
928 default:
929 time.Sleep(200 * time.Millisecond)
930 }
931 }
932
933 for cancel := range a.completionCancels.Seq() {
934 if cancel != nil {
935 cancel()
936 }
937 }
938 a.completionCancels.Reset(make(map[string]context.CancelFunc))
939}
940
941func (a *sessionAgent) IsBusy() bool {
942 var busy bool
943 for cancelFunc := range a.activeRequests.Seq() {
944 if cancelFunc != nil {
945 busy = true
946 break
947 }
948 }
949 return busy
950}
951
952func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
953 _, busy := a.activeRequests.Get(sessionID)
954 return busy
955}
956
957func (a *sessionAgent) QueuedPrompts(sessionID string) int {
958 l, ok := a.messageQueue.Get(sessionID)
959 if !ok {
960 return 0
961 }
962 return len(l)
963}
964
965func (a *sessionAgent) SetModels(large Model, small Model) {
966 a.largeModel = large
967 a.smallModel = small
968}
969
// SetTools replaces the tools handed to the model on subsequent runs. The
// slice is stored as given, not copied.
func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
	a.tools = tools
}
973
// Model returns the agent's large model.
func (a *sessionAgent) Model() Model {
	return a.largeModel
}