1// Package agent is the core orchestration layer for Crush AI agents.
2//
3// It provides session-based AI agent functionality for managing
4// conversations, tool execution, and message handling. It coordinates
5// interactions between language models, messages, sessions, and tools while
6// handling features like automatic summarization, queuing, and token
7// management.
8package agent
9
10import (
11 "context"
12 _ "embed"
13 "errors"
14 "fmt"
15 "log/slog"
16 "os"
17 "strconv"
18 "strings"
19 "sync"
20 "time"
21
22 "charm.land/fantasy"
23 "charm.land/fantasy/providers/anthropic"
24 "charm.land/fantasy/providers/bedrock"
25 "charm.land/fantasy/providers/google"
26 "charm.land/fantasy/providers/openai"
27 "charm.land/fantasy/providers/openrouter"
28 "git.secluded.site/crush/internal/agent/tools"
29 "git.secluded.site/crush/internal/config"
30 "git.secluded.site/crush/internal/csync"
31 "git.secluded.site/crush/internal/message"
32 "git.secluded.site/crush/internal/notification"
33 "git.secluded.site/crush/internal/permission"
34 "git.secluded.site/crush/internal/session"
35 "github.com/charmbracelet/catwalk/pkg/catwalk"
36)
37
// titlePrompt holds the embedded contents of templates/title.md. It is used
// as the base system prompt when generating a session title.
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt holds the embedded contents of templates/summary.md. It is
// used as the system prompt when summarizing a session.
//
//go:embed templates/summary.md
var summaryPrompt []byte
43
// SessionAgentCall describes a single prompt submitted to a session agent.
type SessionAgentCall struct {
	// SessionID identifies the session the prompt belongs to. Run rejects an
	// empty value.
	SessionID string
	// Prompt is the user's text. Run rejects an empty value.
	Prompt string
	// ProviderOptions is forwarded unchanged to the model stream call.
	ProviderOptions fantasy.ProviderOptions
	// Attachments are stored as binary parts of the user message and passed
	// to the model as files.
	Attachments []message.Attachment
	// MaxOutputTokens is forwarded to the model stream call.
	MaxOutputTokens int64
	// The sampling parameters below are forwarded to the model stream call
	// as-is; nil pointers are passed through unchanged.
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
}
56
// SessionAgent is the interface through which callers drive a session-based
// agent. The per-method notes below describe the sessionAgent implementation
// in this file.
type SessionAgent interface {
	// Run sends the prompt in the call to the model for the call's session.
	// If the session is already busy the call is queued and (nil, nil) is
	// returned.
	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
	// SetModels replaces the large and small models.
	SetModels(large Model, small Model)
	// SetTools replaces the tool set offered to the model.
	SetTools(tools []fantasy.AgentTool)
	// Cancel cancels the active request for sessionID, if any, and drops its
	// queued prompts.
	Cancel(sessionID string)
	// CancelAll cancels every active request and pending completion
	// notification.
	CancelAll()
	// IsSessionBusy reports whether sessionID has an active request.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any session has an active request.
	IsBusy() bool
	// QueuedPrompts returns the number of prompts queued for sessionID.
	QueuedPrompts(sessionID string) int
	// ClearQueue drops all prompts queued for sessionID.
	ClearQueue(sessionID string)
	// Summarize asks the model to summarize the given session and records
	// the summary message on the session.
	Summarize(context.Context, string, fantasy.ProviderOptions) error
	// Model returns the large model.
	Model() Model
	// CancelCompletionNotification cancels any scheduled "turn ended"
	// notification for the provided sessionID.
	CancelCompletionNotification(sessionID string)
	// HasPendingCompletionNotification returns true if a turn-end
	// notification has been scheduled and not yet cancelled/shown.
	HasPendingCompletionNotification(sessionID string) bool
}
76
// completionNotificationDelay is the delay handed to the notifier when a
// turn-completed notification is scheduled.
const completionNotificationDelay = 5 * time.Second
78
// Model bundles a language model with the configuration this package reads
// alongside it.
type Model struct {
	// Model is the language model used to create agents and stream responses.
	Model fantasy.LanguageModel
	// CatwalkCfg carries model metadata; this file reads its context window,
	// per-token pricing, reasoning capability, and default max tokens.
	CatwalkCfg catwalk.Model
	// ModelCfg is the selected-model configuration; its Model and Provider
	// values are recorded on assistant messages created by Run.
	ModelCfg config.SelectedModel
}
84
// sessionAgent is the SessionAgent implementation returned by
// NewSessionAgent.
type sessionAgent struct {
	// largeModel drives conversations and summaries; smallModel generates
	// session titles.
	largeModel Model
	smallModel Model
	// systemPromptPrefix, when non-empty, is prepended as an extra system
	// message on every step.
	systemPromptPrefix string
	// systemPrompt is the system prompt for regular runs.
	systemPrompt string
	// tools is the tool set offered to the model during Run.
	tools    []fantasy.AgentTool
	sessions session.Service
	messages message.Service
	// disableAutoSummarize turns off the context-window stop condition in Run.
	disableAutoSummarize bool
	// isYolo is copied from the options; it is not read in this file.
	isYolo bool

	// messageQueue holds prompts submitted while their session was busy,
	// keyed by session ID.
	messageQueue *csync.Map[string, []SessionAgentCall]
	// activeRequests maps a session ID to the cancel function of its
	// in-flight request; presence of a key marks the session as busy.
	activeRequests *csync.Map[string, context.CancelFunc]
	// notifier, when non-nil, is used to schedule turn-completed
	// notifications using notifyCtx.
	notifier  *notification.Notifier
	notifyCtx context.Context
	// completionCancels maps a session ID to the cancel function of its
	// scheduled completion notification.
	completionCancels *csync.Map[string, context.CancelFunc]
}
102
// SessionAgentOptions carries the dependencies and settings consumed by
// NewSessionAgent.
type SessionAgentOptions struct {
	// LargeModel and SmallModel are the initial models; they can be replaced
	// later with SetModels.
	LargeModel Model
	SmallModel Model
	// SystemPromptPrefix, when non-empty, is prepended as an extra system
	// message on every step.
	SystemPromptPrefix string
	// SystemPrompt is the system prompt for regular runs.
	SystemPrompt string
	// DisableAutoSummarize turns off automatic summarization when the
	// context window is nearly exhausted.
	DisableAutoSummarize bool
	// IsYolo is stored on the agent as-is.
	IsYolo   bool
	Sessions session.Service
	Messages message.Service
	// Tools is the initial tool set; it can be replaced later with SetTools.
	Tools []fantasy.AgentTool
	// Notifier is optional; when nil no completion notifications are
	// scheduled.
	Notifier *notification.Notifier
	// NotificationCtx is the context passed to the notifier. When nil,
	// context.Background() is used.
	NotificationCtx context.Context
}
116
117func NewSessionAgent(
118 opts SessionAgentOptions,
119) SessionAgent {
120 notifyCtx := opts.NotificationCtx
121 if notifyCtx == nil {
122 notifyCtx = context.Background()
123 }
124
125 return &sessionAgent{
126 largeModel: opts.LargeModel,
127 smallModel: opts.SmallModel,
128 systemPromptPrefix: opts.SystemPromptPrefix,
129 systemPrompt: opts.SystemPrompt,
130 sessions: opts.Sessions,
131 messages: opts.Messages,
132 disableAutoSummarize: opts.DisableAutoSummarize,
133 tools: opts.Tools,
134 isYolo: opts.IsYolo,
135 messageQueue: csync.NewMap[string, []SessionAgentCall](),
136 activeRequests: csync.NewMap[string, context.CancelFunc](),
137 notifier: opts.Notifier,
138 notifyCtx: notifyCtx,
139 completionCancels: csync.NewMap[string, context.CancelFunc](),
140 }
141}
142
143func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
144 if call.Prompt == "" {
145 return nil, ErrEmptyPrompt
146 }
147 if call.SessionID == "" {
148 return nil, ErrSessionMissing
149 }
150
151 a.cancelCompletionNotification(call.SessionID)
152
153 // Queue the message if busy
154 if a.IsSessionBusy(call.SessionID) {
155 existing, ok := a.messageQueue.Get(call.SessionID)
156 if !ok {
157 existing = []SessionAgentCall{}
158 }
159 existing = append(existing, call)
160 a.messageQueue.Set(call.SessionID, existing)
161 return nil, nil
162 }
163
164 if len(a.tools) > 0 {
165 // Add Anthropic caching to the last tool.
166 a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
167 }
168
169 agent := fantasy.NewAgent(
170 a.largeModel.Model,
171 fantasy.WithSystemPrompt(a.systemPrompt),
172 fantasy.WithTools(a.tools...),
173 )
174
175 sessionLock := sync.Mutex{}
176 currentSession, err := a.sessions.Get(ctx, call.SessionID)
177 if err != nil {
178 return nil, fmt.Errorf("failed to get session: %w", err)
179 }
180
181 msgs, err := a.getSessionMessages(ctx, currentSession)
182 if err != nil {
183 return nil, fmt.Errorf("failed to get session messages: %w", err)
184 }
185
186 var wg sync.WaitGroup
187 // Generate title if first message.
188 if len(msgs) == 0 {
189 wg.Go(func() {
190 sessionLock.Lock()
191 a.generateTitle(ctx, ¤tSession, call.Prompt)
192 sessionLock.Unlock()
193 })
194 }
195
196 // Add the user message to the session.
197 _, err = a.createUserMessage(ctx, call)
198 if err != nil {
199 return nil, err
200 }
201
202 // Add the session to the context.
203 ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)
204
205 genCtx, cancel := context.WithCancel(ctx)
206 a.activeRequests.Set(call.SessionID, cancel)
207
208 defer cancel()
209 defer a.activeRequests.Del(call.SessionID)
210
211 history, files := a.preparePrompt(msgs, call.Attachments...)
212
213 startTime := time.Now()
214 a.eventPromptSent(call.SessionID)
215
216 var currentAssistant *message.Message
217 var shouldSummarize bool
218 result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
219 Prompt: call.Prompt,
220 Files: files,
221 Messages: history,
222 ProviderOptions: call.ProviderOptions,
223 MaxOutputTokens: &call.MaxOutputTokens,
224 TopP: call.TopP,
225 Temperature: call.Temperature,
226 PresencePenalty: call.PresencePenalty,
227 TopK: call.TopK,
228 FrequencyPenalty: call.FrequencyPenalty,
229 // Before each step create a new assistant message.
230 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
231 prepared.Messages = options.Messages
232 // Reset all cached items.
233 for i := range prepared.Messages {
234 prepared.Messages[i].ProviderOptions = nil
235 }
236
237 queuedCalls, _ := a.messageQueue.Get(call.SessionID)
238 a.messageQueue.Del(call.SessionID)
239 for _, queued := range queuedCalls {
240 userMessage, createErr := a.createUserMessage(callContext, queued)
241 if createErr != nil {
242 return callContext, prepared, createErr
243 }
244 prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
245 }
246
247 lastSystemRoleInx := 0
248 systemMessageUpdated := false
249 for i, msg := range prepared.Messages {
250 // Only add cache control to the last message.
251 if msg.Role == fantasy.MessageRoleSystem {
252 lastSystemRoleInx = i
253 } else if !systemMessageUpdated {
254 prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
255 systemMessageUpdated = true
256 }
257 // Than add cache control to the last 2 messages.
258 if i > len(prepared.Messages)-3 {
259 prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
260 }
261 }
262
263 if a.systemPromptPrefix != "" {
264 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
265 }
266
267 var assistantMsg message.Message
268 assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
269 Role: message.Assistant,
270 Parts: []message.ContentPart{},
271 Model: a.largeModel.ModelCfg.Model,
272 Provider: a.largeModel.ModelCfg.Provider,
273 })
274 if err != nil {
275 return callContext, prepared, err
276 }
277 callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
278 currentAssistant = &assistantMsg
279 return callContext, prepared, err
280 },
281 OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
282 currentAssistant.AppendReasoningContent(reasoning.Text)
283 return a.messages.Update(genCtx, *currentAssistant)
284 },
285 OnReasoningDelta: func(id string, text string) error {
286 currentAssistant.AppendReasoningContent(text)
287 return a.messages.Update(genCtx, *currentAssistant)
288 },
289 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
290 // handle anthropic signature
291 if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
292 if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
293 currentAssistant.AppendReasoningSignature(reasoning.Signature)
294 }
295 }
296 if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
297 if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
298 currentAssistant.AppendReasoningSignature(reasoning.Signature)
299 }
300 }
301 if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
302 if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
303 currentAssistant.SetReasoningResponsesData(reasoning)
304 }
305 }
306 currentAssistant.FinishThinking()
307 return a.messages.Update(genCtx, *currentAssistant)
308 },
309 OnTextDelta: func(id string, text string) error {
310 // Strip leading newline from initial text content. This is is
311 // particularly important in non-interactive mode where leading
312 // newlines are very visible.
313 if len(currentAssistant.Parts) == 0 {
314 text = strings.TrimPrefix(text, "\n")
315 }
316
317 currentAssistant.AppendContent(text)
318 return a.messages.Update(genCtx, *currentAssistant)
319 },
320 OnToolInputStart: func(id string, toolName string) error {
321 toolCall := message.ToolCall{
322 ID: id,
323 Name: toolName,
324 ProviderExecuted: false,
325 Finished: false,
326 }
327 currentAssistant.AddToolCall(toolCall)
328 return a.messages.Update(genCtx, *currentAssistant)
329 },
330 OnRetry: func(err *fantasy.APICallError, delay time.Duration) {
331 // TODO: implement
332 },
333 OnToolCall: func(tc fantasy.ToolCallContent) error {
334 toolCall := message.ToolCall{
335 ID: tc.ToolCallID,
336 Name: tc.ToolName,
337 Input: tc.Input,
338 ProviderExecuted: false,
339 Finished: true,
340 }
341 currentAssistant.AddToolCall(toolCall)
342 return a.messages.Update(genCtx, *currentAssistant)
343 },
344 OnToolResult: func(result fantasy.ToolResultContent) error {
345 var resultContent string
346 isError := false
347 switch result.Result.GetType() {
348 case fantasy.ToolResultContentTypeText:
349 r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result)
350 if ok {
351 resultContent = r.Text
352 }
353 case fantasy.ToolResultContentTypeError:
354 r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result)
355 if ok {
356 isError = true
357 resultContent = r.Error.Error()
358 }
359 case fantasy.ToolResultContentTypeMedia:
360 // TODO: handle this message type
361 }
362 toolResult := message.ToolResult{
363 ToolCallID: result.ToolCallID,
364 Name: result.ToolName,
365 Content: resultContent,
366 IsError: isError,
367 Metadata: result.ClientMetadata,
368 }
369 _, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
370 Role: message.Tool,
371 Parts: []message.ContentPart{
372 toolResult,
373 },
374 })
375 if createMsgErr != nil {
376 return createMsgErr
377 }
378 return nil
379 },
380 OnStepFinish: func(stepResult fantasy.StepResult) error {
381 finishReason := message.FinishReasonUnknown
382 switch stepResult.FinishReason {
383 case fantasy.FinishReasonLength:
384 finishReason = message.FinishReasonMaxTokens
385 case fantasy.FinishReasonStop:
386 finishReason = message.FinishReasonEndTurn
387 case fantasy.FinishReasonToolCalls:
388 finishReason = message.FinishReasonToolUse
389 }
390 currentAssistant.AddFinish(finishReason, "", "")
391 a.updateSessionUsage(a.largeModel, ¤tSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
392 sessionLock.Lock()
393 _, sessionErr := a.sessions.Save(genCtx, currentSession)
394 sessionLock.Unlock()
395 if sessionErr != nil {
396 return sessionErr
397 }
398 if finishReason == message.FinishReasonEndTurn {
399 a.scheduleCompletionNotification(call.SessionID, currentSession.Title)
400 }
401 return a.messages.Update(genCtx, *currentAssistant)
402 },
403 StopWhen: []fantasy.StopCondition{
404 func(_ []fantasy.StepResult) bool {
405 cw := int64(a.largeModel.CatwalkCfg.ContextWindow)
406 tokens := currentSession.CompletionTokens + currentSession.PromptTokens
407 remaining := cw - tokens
408 var threshold int64
409 if cw > 200_000 {
410 threshold = 20_000
411 } else {
412 threshold = int64(float64(cw) * 0.2)
413 }
414 if (remaining <= threshold) && !a.disableAutoSummarize {
415 shouldSummarize = true
416 return true
417 }
418 return false
419 },
420 },
421 })
422
423 a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))
424
425 if err != nil {
426 isCancelErr := errors.Is(err, context.Canceled)
427 isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
428 if currentAssistant == nil {
429 return result, err
430 }
431 // Ensure we finish thinking on error to close the reasoning state.
432 currentAssistant.FinishThinking()
433 toolCalls := currentAssistant.ToolCalls()
434 // INFO: we use the parent context here because the genCtx has been cancelled.
435 msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
436 if createErr != nil {
437 return nil, createErr
438 }
439 for _, tc := range toolCalls {
440 if !tc.Finished {
441 tc.Finished = true
442 tc.Input = "{}"
443 currentAssistant.AddToolCall(tc)
444 updateErr := a.messages.Update(ctx, *currentAssistant)
445 if updateErr != nil {
446 return nil, updateErr
447 }
448 }
449
450 found := false
451 for _, msg := range msgs {
452 if msg.Role == message.Tool {
453 for _, tr := range msg.ToolResults() {
454 if tr.ToolCallID == tc.ID {
455 found = true
456 break
457 }
458 }
459 }
460 if found {
461 break
462 }
463 }
464 if found {
465 continue
466 }
467 content := "There was an error while executing the tool"
468 if isCancelErr {
469 content = "Tool execution canceled by user"
470 } else if isPermissionErr {
471 content = "User denied permission"
472 }
473 toolResult := message.ToolResult{
474 ToolCallID: tc.ID,
475 Name: tc.Name,
476 Content: content,
477 IsError: true,
478 }
479 _, createErr = a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
480 Role: message.Tool,
481 Parts: []message.ContentPart{
482 toolResult,
483 },
484 })
485 if createErr != nil {
486 return nil, createErr
487 }
488 }
489 if isCancelErr {
490 currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
491 } else if isPermissionErr {
492 currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "")
493 } else {
494 currentAssistant.AddFinish(message.FinishReasonError, "API Error", err.Error())
495 }
496 // Note: we use the parent context here because the genCtx has been
497 // cancelled.
498 updateErr := a.messages.Update(ctx, *currentAssistant)
499 if updateErr != nil {
500 return nil, updateErr
501 }
502 return nil, err
503 }
504 wg.Wait()
505
506 if shouldSummarize {
507 a.cancelCompletionNotification(call.SessionID)
508 a.activeRequests.Del(call.SessionID)
509 if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
510 return nil, summarizeErr
511 }
512 // If the agent wasn't done...
513 if len(currentAssistant.ToolCalls()) > 0 {
514 existing, ok := a.messageQueue.Get(call.SessionID)
515 if !ok {
516 existing = []SessionAgentCall{}
517 }
518 call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
519 existing = append(existing, call)
520 a.messageQueue.Set(call.SessionID, existing)
521 }
522 }
523
524 // Release active request before processing queued messages.
525 a.activeRequests.Del(call.SessionID)
526 cancel()
527
528 queuedMessages, ok := a.messageQueue.Get(call.SessionID)
529 if !ok || len(queuedMessages) == 0 {
530 return result, err
531 }
532 // There are queued messages restart the loop.
533 firstQueuedMessage := queuedMessages[0]
534 a.messageQueue.Set(call.SessionID, queuedMessages[1:])
535 return a.Run(ctx, firstQueuedMessage)
536}
537
538func (a *sessionAgent) scheduleCompletionNotification(sessionID, sessionTitle string) {
539 // Do not emit notifications for Agent-tool sub-sessions.
540 if a.sessions != nil && a.sessions.IsAgentToolSession(sessionID) {
541 return
542 }
543 if a.notifier == nil {
544 return
545 }
546
547 if sessionTitle == "" {
548 sessionTitle = sessionID
549 }
550
551 if cancel, ok := a.completionCancels.Take(sessionID); ok && cancel != nil {
552 cancel()
553 }
554
555 title := "💘 Crush is waiting"
556 message := fmt.Sprintf("Agent's turn completed in session \"%s\"", sessionTitle)
557 cancel := a.notifier.NotifyTaskComplete(a.notifyCtx, title, message, completionNotificationDelay)
558 if cancel == nil {
559 cancel = func() {}
560 }
561 a.completionCancels.Set(sessionID, cancel)
562}
563
564func (a *sessionAgent) cancelCompletionNotification(sessionID string) {
565 if a.notifier == nil {
566 return
567 }
568
569 if cancel, ok := a.completionCancels.Take(sessionID); ok && cancel != nil {
570 cancel()
571 }
572}
573
// CancelCompletionNotification implements SessionAgent. It is the exported
// entry point for cancelCompletionNotification and cancels any completion
// notification scheduled for sessionID.
func (a *sessionAgent) CancelCompletionNotification(sessionID string) {
	a.cancelCompletionNotification(sessionID)
}
578
579// HasPendingCompletionNotification implements SessionAgent.
580func (a *sessionAgent) HasPendingCompletionNotification(sessionID string) bool {
581 if a.IsSessionBusy(sessionID) {
582 return false
583 }
584 _, ok := a.completionCancels.Get(sessionID)
585 return ok
586}
587
588func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
589 if a.IsSessionBusy(sessionID) {
590 return ErrSessionBusy
591 }
592
593 currentSession, err := a.sessions.Get(ctx, sessionID)
594 if err != nil {
595 return fmt.Errorf("failed to get session: %w", err)
596 }
597 msgs, err := a.getSessionMessages(ctx, currentSession)
598 if err != nil {
599 return err
600 }
601 if len(msgs) == 0 {
602 // Nothing to summarize.
603 return nil
604 }
605
606 aiMsgs, _ := a.preparePrompt(msgs)
607
608 genCtx, cancel := context.WithCancel(ctx)
609 a.activeRequests.Set(sessionID, cancel)
610 defer a.activeRequests.Del(sessionID)
611 defer cancel()
612
613 agent := fantasy.NewAgent(a.largeModel.Model,
614 fantasy.WithSystemPrompt(string(summaryPrompt)),
615 )
616 summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
617 Role: message.Assistant,
618 Model: a.largeModel.Model.Model(),
619 Provider: a.largeModel.Model.Provider(),
620 IsSummaryMessage: true,
621 })
622 if err != nil {
623 return err
624 }
625
626 resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
627 Prompt: "Provide a detailed summary of our conversation above.",
628 Messages: aiMsgs,
629 ProviderOptions: opts,
630 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
631 prepared.Messages = options.Messages
632 if a.systemPromptPrefix != "" {
633 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
634 }
635 return callContext, prepared, nil
636 },
637 OnReasoningDelta: func(id string, text string) error {
638 summaryMessage.AppendReasoningContent(text)
639 return a.messages.Update(genCtx, summaryMessage)
640 },
641 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
642 // Handle anthropic signature.
643 if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
644 if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
645 summaryMessage.AppendReasoningSignature(signature.Signature)
646 }
647 }
648 summaryMessage.FinishThinking()
649 return a.messages.Update(genCtx, summaryMessage)
650 },
651 OnTextDelta: func(id, text string) error {
652 summaryMessage.AppendContent(text)
653 return a.messages.Update(genCtx, summaryMessage)
654 },
655 })
656 if err != nil {
657 isCancelErr := errors.Is(err, context.Canceled)
658 if isCancelErr {
659 // User cancelled summarize we need to remove the summary message.
660 deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
661 return deleteErr
662 }
663 return err
664 }
665
666 summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
667 err = a.messages.Update(genCtx, summaryMessage)
668 if err != nil {
669 return err
670 }
671
672 var openrouterCost *float64
673 for _, step := range resp.Steps {
674 stepCost := a.openrouterCost(step.ProviderMetadata)
675 if stepCost != nil {
676 newCost := *stepCost
677 if openrouterCost != nil {
678 newCost += *openrouterCost
679 }
680 openrouterCost = &newCost
681 }
682 }
683
684 a.updateSessionUsage(a.largeModel, ¤tSession, resp.TotalUsage, openrouterCost)
685
686 // Just in case, get just the last usage info.
687 usage := resp.Response.Usage
688 currentSession.SummaryMessageID = summaryMessage.ID
689 currentSession.CompletionTokens = usage.OutputTokens
690 currentSession.PromptTokens = 0
691 _, err = a.sessions.Save(genCtx, currentSession)
692 return err
693}
694
695func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
696 if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
697 return fantasy.ProviderOptions{}
698 }
699 return fantasy.ProviderOptions{
700 anthropic.Name: &anthropic.ProviderCacheControlOptions{
701 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
702 },
703 bedrock.Name: &anthropic.ProviderCacheControlOptions{
704 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
705 },
706 }
707}
708
709func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
710 var attachmentParts []message.ContentPart
711 for _, attachment := range call.Attachments {
712 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
713 }
714 parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
715 parts = append(parts, attachmentParts...)
716 msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
717 Role: message.User,
718 Parts: parts,
719 })
720 if err != nil {
721 return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
722 }
723 return msg, nil
724}
725
726func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
727 var history []fantasy.Message
728 for _, m := range msgs {
729 if len(m.Parts) == 0 {
730 continue
731 }
732 // Assistant message without content or tool calls (cancelled before it
733 // returned anything).
734 if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
735 continue
736 }
737 history = append(history, m.ToAIMessage()...)
738 }
739
740 var files []fantasy.FilePart
741 for _, attachment := range attachments {
742 files = append(files, fantasy.FilePart{
743 Filename: attachment.FileName,
744 Data: attachment.Content,
745 MediaType: attachment.MimeType,
746 })
747 }
748
749 return history, files
750}
751
752func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
753 msgs, err := a.messages.List(ctx, session.ID)
754 if err != nil {
755 return nil, fmt.Errorf("failed to list messages: %w", err)
756 }
757
758 if session.SummaryMessageID != "" {
759 summaryMsgInex := -1
760 for i, msg := range msgs {
761 if msg.ID == session.SummaryMessageID {
762 summaryMsgInex = i
763 break
764 }
765 }
766 if summaryMsgInex != -1 {
767 msgs = msgs[summaryMsgInex:]
768 msgs[0].Role = message.User
769 }
770 }
771 return msgs, nil
772}
773
774func (a *sessionAgent) generateTitle(ctx context.Context, session *session.Session, prompt string) {
775 if prompt == "" {
776 return
777 }
778
779 var maxOutput int64 = 40
780 if a.smallModel.CatwalkCfg.CanReason {
781 maxOutput = a.smallModel.CatwalkCfg.DefaultMaxTokens
782 }
783
784 agent := fantasy.NewAgent(a.smallModel.Model,
785 fantasy.WithSystemPrompt(string(titlePrompt)+"\n /no_think"),
786 fantasy.WithMaxOutputTokens(maxOutput),
787 )
788
789 resp, err := agent.Stream(ctx, fantasy.AgentStreamCall{
790 Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", prompt),
791 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
792 prepared.Messages = options.Messages
793 if a.systemPromptPrefix != "" {
794 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
795 }
796 return callContext, prepared, nil
797 },
798 })
799 if err != nil {
800 slog.Error("error generating title", "err", err)
801 return
802 }
803
804 title := resp.Response.Content.Text()
805
806 title = strings.ReplaceAll(title, "\n", " ")
807
808 // Remove thinking tags if present.
809 if idx := strings.Index(title, "</think>"); idx > 0 {
810 title = title[idx+len("</think>"):]
811 }
812
813 title = strings.TrimSpace(title)
814 if title == "" {
815 slog.Warn("failed to generate title", "warn", "empty title")
816 return
817 }
818
819 session.Title = title
820
821 var openrouterCost *float64
822 for _, step := range resp.Steps {
823 stepCost := a.openrouterCost(step.ProviderMetadata)
824 if stepCost != nil {
825 newCost := *stepCost
826 if openrouterCost != nil {
827 newCost += *openrouterCost
828 }
829 openrouterCost = &newCost
830 }
831 }
832
833 a.updateSessionUsage(a.smallModel, session, resp.TotalUsage, openrouterCost)
834 _, saveErr := a.sessions.Save(ctx, *session)
835 if saveErr != nil {
836 slog.Error("failed to save session title & usage", "error", saveErr)
837 return
838 }
839}
840
841func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
842 openrouterMetadata, ok := metadata[openrouter.Name]
843 if !ok {
844 return nil
845 }
846
847 opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
848 if !ok {
849 return nil
850 }
851 return &opts.Usage.Cost
852}
853
854func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
855 modelConfig := model.CatwalkCfg
856 cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
857 modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
858 modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
859 modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
860
861 a.eventTokensUsed(session.ID, model, usage, cost)
862
863 if overrideCost != nil {
864 session.Cost += *overrideCost
865 } else {
866 session.Cost += cost
867 }
868
869 session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
870 session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
871}
872
873func (a *sessionAgent) Cancel(sessionID string) {
874 // Cancel regular requests.
875 if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
876 slog.Info("Request cancellation initiated", "session_id", sessionID)
877 cancel()
878 }
879
880 // Also check for summarize requests.
881 if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
882 slog.Info("Summarize cancellation initiated", "session_id", sessionID)
883 cancel()
884 }
885
886 if a.QueuedPrompts(sessionID) > 0 {
887 slog.Info("Clearing queued prompts", "session_id", sessionID)
888 a.messageQueue.Del(sessionID)
889 }
890}
891
892func (a *sessionAgent) ClearQueue(sessionID string) {
893 if a.QueuedPrompts(sessionID) > 0 {
894 slog.Info("Clearing queued prompts", "session_id", sessionID)
895 a.messageQueue.Del(sessionID)
896 }
897}
898
899func (a *sessionAgent) CancelAll() {
900 if !a.IsBusy() {
901 // still ensure notifications are cancelled even when not busy
902 for cancel := range a.completionCancels.Seq() {
903 if cancel != nil {
904 cancel()
905 }
906 }
907 a.completionCancels.Reset(make(map[string]context.CancelFunc))
908 return
909 }
910 for key := range a.activeRequests.Seq2() {
911 a.Cancel(key) // key is sessionID
912 }
913
914 timeout := time.After(5 * time.Second)
915 for a.IsBusy() {
916 select {
917 case <-timeout:
918 return
919 default:
920 time.Sleep(200 * time.Millisecond)
921 }
922 }
923
924 for cancel := range a.completionCancels.Seq() {
925 if cancel != nil {
926 cancel()
927 }
928 }
929 a.completionCancels.Reset(make(map[string]context.CancelFunc))
930}
931
932func (a *sessionAgent) IsBusy() bool {
933 var busy bool
934 for cancelFunc := range a.activeRequests.Seq() {
935 if cancelFunc != nil {
936 busy = true
937 break
938 }
939 }
940 return busy
941}
942
943func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
944 _, busy := a.activeRequests.Get(sessionID)
945 return busy
946}
947
948func (a *sessionAgent) QueuedPrompts(sessionID string) int {
949 l, ok := a.messageQueue.Get(sessionID)
950 if !ok {
951 return 0
952 }
953 return len(l)
954}
955
956func (a *sessionAgent) SetModels(large Model, small Model) {
957 a.largeModel = large
958 a.smallModel = small
959}
960
// SetTools replaces the agent's tool set. The slice is stored as given, not
// copied.
func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
	a.tools = tools
}
964
// Model returns the agent's large model.
func (a *sessionAgent) Model() Model {
	return a.largeModel
}