1package agent
2
3import (
4 "context"
5 _ "embed"
6 "encoding/json"
7 "errors"
8 "fmt"
9 "log/slog"
10 "os"
11 "strconv"
12 "strings"
13 "sync"
14 "time"
15
16 "charm.land/fantasy"
17 "charm.land/fantasy/providers/anthropic"
18 "charm.land/fantasy/providers/bedrock"
19 "charm.land/fantasy/providers/google"
20 "charm.land/fantasy/providers/openai"
21 "charm.land/fantasy/providers/openrouter"
22 "github.com/charmbracelet/catwalk/pkg/catwalk"
23 "github.com/charmbracelet/crush/internal/agent/tools"
24 "github.com/charmbracelet/crush/internal/config"
25 "github.com/charmbracelet/crush/internal/csync"
26 "github.com/charmbracelet/crush/internal/hooks"
27 "github.com/charmbracelet/crush/internal/message"
28 "github.com/charmbracelet/crush/internal/permission"
29 "github.com/charmbracelet/crush/internal/session"
30)
31
// titlePrompt holds the embedded contents of templates/title.md. It is used
// as the system prompt when generating a session title.
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt holds the embedded contents of templates/summary.md. It is
// used as the system prompt when summarizing a session.
//
//go:embed templates/summary.md
var summaryPrompt []byte
37
// SessionAgentCall describes a single prompt submitted to a SessionAgent.
type SessionAgentCall struct {
	// SessionID identifies the session the prompt belongs to. Run rejects an
	// empty value with ErrSessionMissing.
	SessionID string
	// Prompt is the user text. Run rejects an empty value with ErrEmptyPrompt.
	Prompt string
	// ProviderOptions is forwarded unchanged to the underlying stream call.
	ProviderOptions fantasy.ProviderOptions
	// Attachments are stored on the user message and sent to the model as
	// file parts.
	Attachments []message.Attachment
	// MaxOutputTokens is forwarded (by pointer) to the stream call.
	MaxOutputTokens int64
	// The sampling parameters below are forwarded as-is to the stream call;
	// nil pointers are passed through as nil.
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
}
50
// SessionAgent is the interface through which callers drive a per-session
// LLM agent. The method notes below describe the sessionAgent implementation
// in this file.
type SessionAgent interface {
	// Run submits a prompt. If the session already has an active request the
	// call is queued and (nil, nil) is returned.
	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
	// SetModels replaces the large and small models.
	SetModels(large Model, small Model)
	// SetTools replaces the tool set offered to the model.
	SetTools(tools []fantasy.AgentTool)
	// Cancel cancels the active request for the session, if any, and drops
	// its queued prompts.
	Cancel(sessionID string)
	// CancelAll cancels every active request and waits (bounded) for the
	// agent to become idle.
	CancelAll()
	// IsSessionBusy reports whether the session has an active request.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any session has an active request.
	IsBusy() bool
	// QueuedPrompts returns the number of prompts queued for the session.
	QueuedPrompts(sessionID string) int
	// ClearQueue drops the prompts queued for the session.
	ClearQueue(sessionID string)
	// Summarize replaces the session history with a model-generated summary.
	// It returns ErrSessionBusy if the session has an active request.
	Summarize(context.Context, string, fantasy.ProviderOptions) error
	// Model returns the large model.
	Model() Model
}
64
// Model bundles a language model with the configuration the agent needs to
// use it.
type Model struct {
	// Model is the language model handed to fantasy.NewAgent.
	Model fantasy.LanguageModel
	// CatwalkCfg supplies model facts read in this file: context window,
	// per-token costs, CanReason and DefaultMaxTokens.
	CatwalkCfg catwalk.Model
	// ModelCfg supplies the provider and model identifiers recorded on
	// assistant messages and passed to hooks.
	ModelCfg config.SelectedModel
}
70
// sessionAgent is the SessionAgent implementation returned by
// NewSessionAgent.
type sessionAgent struct {
	// largeModel runs prompts and summaries; smallModel generates titles.
	largeModel Model
	smallModel Model
	// systemPromptPrefix, when non-empty, is prepended as an extra system
	// message on every step.
	systemPromptPrefix string
	// systemPrompt is the system prompt used by Run.
	systemPrompt string
	// tools is the tool set offered to the model in Run.
	tools    []fantasy.AgentTool
	sessions session.Service
	messages message.Service
	// disableAutoSummarize turns off the context-window stop condition in Run.
	disableAutoSummarize bool
	// isYolo is stored from the options; it is not read in this part of the
	// file.
	isYolo bool
	// hooks is optional; every use is guarded by a nil check.
	hooks *hooks.Executor

	// messageQueue holds prompts received while a session was busy, keyed by
	// session ID.
	messageQueue *csync.Map[string, []SessionAgentCall]
	// activeRequests holds the cancel function of the in-flight request for
	// each session, keyed by session ID. Presence of a key marks the session
	// as busy.
	activeRequests *csync.Map[string, context.CancelFunc]
}
86
// SessionAgentOptions carries the dependencies and settings copied into a
// sessionAgent by NewSessionAgent.
type SessionAgentOptions struct {
	// LargeModel runs prompts and summaries; SmallModel generates titles.
	LargeModel Model
	SmallModel Model
	// SystemPromptPrefix, when non-empty, is prepended as an extra system
	// message on every step.
	SystemPromptPrefix string
	// SystemPrompt is the system prompt used for regular prompts.
	SystemPrompt string
	// DisableAutoSummarize turns off automatic summarization when the context
	// window is nearly full.
	DisableAutoSummarize bool
	// IsYolo is copied to the agent; it is not read in this part of the file.
	IsYolo bool
	// Sessions and Messages are the persistence services.
	Sessions session.Service
	Messages message.Service
	// Tools is the initial tool set.
	Tools []fantasy.AgentTool
	// Hooks is optional; a nil executor disables hook execution.
	Hooks *hooks.Executor
}
99
100func NewSessionAgent(
101 opts SessionAgentOptions,
102) SessionAgent {
103 return &sessionAgent{
104 largeModel: opts.LargeModel,
105 smallModel: opts.SmallModel,
106 systemPromptPrefix: opts.SystemPromptPrefix,
107 systemPrompt: opts.SystemPrompt,
108 sessions: opts.Sessions,
109 messages: opts.Messages,
110 disableAutoSummarize: opts.DisableAutoSummarize,
111 tools: opts.Tools,
112 isYolo: opts.IsYolo,
113 hooks: opts.Hooks,
114 messageQueue: csync.NewMap[string, []SessionAgentCall](),
115 activeRequests: csync.NewMap[string, context.CancelFunc](),
116 }
117}
118
119func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
120 if call.Prompt == "" {
121 return nil, ErrEmptyPrompt
122 }
123 if call.SessionID == "" {
124 return nil, ErrSessionMissing
125 }
126
127 // Queue the message if busy
128 if a.IsSessionBusy(call.SessionID) {
129 existing, ok := a.messageQueue.Get(call.SessionID)
130 if !ok {
131 existing = []SessionAgentCall{}
132 }
133 existing = append(existing, call)
134 a.messageQueue.Set(call.SessionID, existing)
135 return nil, nil
136 }
137
138 if len(a.tools) > 0 {
139 // add anthropic caching to the last tool
140 a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
141 }
142
143 agent := fantasy.NewAgent(
144 a.largeModel.Model,
145 fantasy.WithSystemPrompt(a.systemPrompt),
146 fantasy.WithTools(a.tools...),
147 )
148
149 sessionLock := sync.Mutex{}
150 currentSession, err := a.sessions.Get(ctx, call.SessionID)
151 if err != nil {
152 return nil, fmt.Errorf("failed to get session: %w", err)
153 }
154
155 msgs, err := a.getSessionMessages(ctx, currentSession)
156 if err != nil {
157 return nil, fmt.Errorf("failed to get session messages: %w", err)
158 }
159
160 var wg sync.WaitGroup
161 // Generate title if first message
162 if len(msgs) == 0 {
163 wg.Go(func() {
164 sessionLock.Lock()
165 a.generateTitle(ctx, ¤tSession, call.Prompt)
166 sessionLock.Unlock()
167 })
168 }
169
170 // Add the user message to the session
171 _, err = a.createUserMessage(ctx, call)
172 if err != nil {
173 return nil, err
174 }
175
176 // Execute UserPromptSubmit hook
177 if a.hooks != nil {
178 if err := a.hooks.Execute(ctx, hooks.HookContext{
179 EventType: config.UserPromptSubmit,
180 SessionID: call.SessionID,
181 UserPrompt: call.Prompt,
182 Provider: a.largeModel.ModelCfg.Provider,
183 Model: a.largeModel.ModelCfg.Model,
184 }); err != nil {
185 slog.Debug("user_prompt_submit hook execution failed", "error", err)
186 }
187 }
188
189 // add the session to the context
190 ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)
191
192 genCtx, cancel := context.WithCancel(ctx)
193 a.activeRequests.Set(call.SessionID, cancel)
194
195 defer cancel()
196 defer a.activeRequests.Del(call.SessionID)
197
198 history, files := a.preparePrompt(msgs, call.Attachments...)
199
200 startTime := time.Now()
201 a.eventPromptSent(call.SessionID)
202
203 var currentAssistant *message.Message
204 var shouldSummarize bool
205 result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
206 Prompt: call.Prompt,
207 Files: files,
208 Messages: history,
209 ProviderOptions: call.ProviderOptions,
210 MaxOutputTokens: &call.MaxOutputTokens,
211 TopP: call.TopP,
212 Temperature: call.Temperature,
213 PresencePenalty: call.PresencePenalty,
214 TopK: call.TopK,
215 FrequencyPenalty: call.FrequencyPenalty,
216 // Before each step create the new assistant message
217 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
218 prepared.Messages = options.Messages
219 // reset all cached items
220 for i := range prepared.Messages {
221 prepared.Messages[i].ProviderOptions = nil
222 }
223
224 queuedCalls, _ := a.messageQueue.Get(call.SessionID)
225 a.messageQueue.Del(call.SessionID)
226 for _, queued := range queuedCalls {
227 userMessage, createErr := a.createUserMessage(callContext, queued)
228 if createErr != nil {
229 return callContext, prepared, createErr
230 }
231 prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
232 }
233
234 lastSystemRoleInx := 0
235 systemMessageUpdated := false
236 for i, msg := range prepared.Messages {
237 // only add cache control to the last message
238 if msg.Role == fantasy.MessageRoleSystem {
239 lastSystemRoleInx = i
240 } else if !systemMessageUpdated {
241 prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
242 systemMessageUpdated = true
243 }
244 // than add cache control to the last 2 messages
245 if i > len(prepared.Messages)-3 {
246 prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
247 }
248 }
249
250 if a.systemPromptPrefix != "" {
251 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
252 }
253
254 var assistantMsg message.Message
255 assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
256 Role: message.Assistant,
257 Parts: []message.ContentPart{},
258 Model: a.largeModel.ModelCfg.Model,
259 Provider: a.largeModel.ModelCfg.Provider,
260 })
261 if err != nil {
262 return callContext, prepared, err
263 }
264 callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
265 currentAssistant = &assistantMsg
266 return callContext, prepared, err
267 },
268 OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
269 currentAssistant.AppendReasoningContent(reasoning.Text)
270 return a.messages.Update(genCtx, *currentAssistant)
271 },
272 OnReasoningDelta: func(id string, text string) error {
273 currentAssistant.AppendReasoningContent(text)
274 return a.messages.Update(genCtx, *currentAssistant)
275 },
276 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
277 // handle anthropic signature
278 if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
279 if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
280 currentAssistant.AppendReasoningSignature(reasoning.Signature)
281 }
282 }
283 if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
284 if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
285 currentAssistant.AppendReasoningSignature(reasoning.Signature)
286 }
287 }
288 if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
289 if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
290 currentAssistant.SetReasoningResponsesData(reasoning)
291 }
292 }
293 currentAssistant.FinishThinking()
294 return a.messages.Update(genCtx, *currentAssistant)
295 },
296 OnTextDelta: func(id string, text string) error {
297 currentAssistant.AppendContent(text)
298 return a.messages.Update(genCtx, *currentAssistant)
299 },
300 OnToolInputStart: func(id string, toolName string) error {
301 toolCall := message.ToolCall{
302 ID: id,
303 Name: toolName,
304 ProviderExecuted: false,
305 Finished: false,
306 }
307 currentAssistant.AddToolCall(toolCall)
308 return a.messages.Update(genCtx, *currentAssistant)
309 },
310 OnRetry: func(err *fantasy.APICallError, delay time.Duration) {
311 // TODO: implement
312 },
313 OnToolCall: func(tc fantasy.ToolCallContent) error {
314 // Execute PreToolUse hook - blocks tool execution on error
315 if a.hooks != nil {
316 toolInput := make(map[string]any)
317 if err := json.Unmarshal([]byte(tc.Input), &toolInput); err != nil {
318 slog.Warn("Failed to unmarshal tool input for PreToolUse hook", "error", err, "tool", tc.ToolName)
319 }
320 if err := a.hooks.Execute(genCtx, hooks.HookContext{
321 EventType: config.PreToolUse,
322 SessionID: call.SessionID,
323 ToolName: tc.ToolName,
324 ToolInput: toolInput,
325 MessageID: currentAssistant.ID,
326 Provider: a.largeModel.ModelCfg.Provider,
327 Model: a.largeModel.ModelCfg.Model,
328 }); err != nil {
329 return fmt.Errorf("PreToolUse hook blocked tool execution: %w", err)
330 }
331 }
332
333 toolCall := message.ToolCall{
334 ID: tc.ToolCallID,
335 Name: tc.ToolName,
336 Input: tc.Input,
337 ProviderExecuted: false,
338 Finished: true,
339 }
340 currentAssistant.AddToolCall(toolCall)
341 return a.messages.Update(genCtx, *currentAssistant)
342 },
343 OnToolResult: func(result fantasy.ToolResultContent) error {
344 var resultContent string
345 isError := false
346 switch result.Result.GetType() {
347 case fantasy.ToolResultContentTypeText:
348 r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result)
349 if ok {
350 resultContent = r.Text
351 }
352 case fantasy.ToolResultContentTypeError:
353 r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result)
354 if ok {
355 isError = true
356 resultContent = r.Error.Error()
357 }
358 case fantasy.ToolResultContentTypeMedia:
359 // TODO: handle this message type
360 }
361
362 // Execute PostToolUse hook
363 if a.hooks != nil {
364 toolInput := make(map[string]any)
365 // Try to get tool input from the assistant message
366 toolCalls := currentAssistant.ToolCalls()
367 for _, tc := range toolCalls {
368 if tc.ID == result.ToolCallID {
369 if err := json.Unmarshal([]byte(tc.Input), &toolInput); err != nil {
370 slog.Debug("Failed to unmarshal tool input for PostToolUse hook", "error", err, "tool", result.ToolName)
371 }
372 break
373 }
374 }
375
376 if err := a.hooks.Execute(genCtx, hooks.HookContext{
377 EventType: config.PostToolUse,
378 SessionID: call.SessionID,
379 ToolName: result.ToolName,
380 ToolInput: toolInput,
381 ToolResult: resultContent,
382 ToolError: isError,
383 MessageID: currentAssistant.ID,
384 Provider: a.largeModel.ModelCfg.Provider,
385 Model: a.largeModel.ModelCfg.Model,
386 }); err != nil {
387 slog.Debug("post_tool_use hook execution failed", "error", err)
388 }
389 }
390
391 toolResult := message.ToolResult{
392 ToolCallID: result.ToolCallID,
393 Name: result.ToolName,
394 Content: resultContent,
395 IsError: isError,
396 Metadata: result.ClientMetadata,
397 }
398 _, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
399 Role: message.Tool,
400 Parts: []message.ContentPart{
401 toolResult,
402 },
403 })
404 if createMsgErr != nil {
405 return createMsgErr
406 }
407 return nil
408 },
409 OnStepFinish: func(stepResult fantasy.StepResult) error {
410 finishReason := message.FinishReasonUnknown
411 switch stepResult.FinishReason {
412 case fantasy.FinishReasonLength:
413 finishReason = message.FinishReasonMaxTokens
414 case fantasy.FinishReasonStop:
415 finishReason = message.FinishReasonEndTurn
416 case fantasy.FinishReasonToolCalls:
417 finishReason = message.FinishReasonToolUse
418 }
419 currentAssistant.AddFinish(finishReason, "", "")
420 a.updateSessionUsage(a.largeModel, ¤tSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
421 sessionLock.Lock()
422 _, sessionErr := a.sessions.Save(genCtx, currentSession)
423 sessionLock.Unlock()
424 if sessionErr != nil {
425 return sessionErr
426 }
427 return a.messages.Update(genCtx, *currentAssistant)
428 },
429 StopWhen: []fantasy.StopCondition{
430 func(_ []fantasy.StepResult) bool {
431 cw := int64(a.largeModel.CatwalkCfg.ContextWindow)
432 tokens := currentSession.CompletionTokens + currentSession.PromptTokens
433 remaining := cw - tokens
434 var threshold int64
435 if cw > 200_000 {
436 threshold = 20_000
437 } else {
438 threshold = int64(float64(cw) * 0.2)
439 }
440 if (remaining <= threshold) && !a.disableAutoSummarize {
441 shouldSummarize = true
442 return true
443 }
444 return false
445 },
446 },
447 })
448
449 a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))
450
451 if err != nil {
452 isCancelErr := errors.Is(err, context.Canceled)
453 isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
454 if currentAssistant == nil {
455 return result, err
456 }
457 // Ensure we finish thinking on error to close the reasoning state
458 currentAssistant.FinishThinking()
459 toolCalls := currentAssistant.ToolCalls()
460 // INFO: we use the parent context here because the genCtx has been cancelled
461 msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
462 if createErr != nil {
463 return nil, createErr
464 }
465 for _, tc := range toolCalls {
466 if !tc.Finished {
467 tc.Finished = true
468 tc.Input = "{}"
469 currentAssistant.AddToolCall(tc)
470 updateErr := a.messages.Update(ctx, *currentAssistant)
471 if updateErr != nil {
472 return nil, updateErr
473 }
474 }
475
476 found := false
477 for _, msg := range msgs {
478 if msg.Role == message.Tool {
479 for _, tr := range msg.ToolResults() {
480 if tr.ToolCallID == tc.ID {
481 found = true
482 break
483 }
484 }
485 }
486 if found {
487 break
488 }
489 }
490 if found {
491 continue
492 }
493 content := "There was an error while executing the tool"
494 if isCancelErr {
495 content = "Tool execution canceled by user"
496 } else if isPermissionErr {
497 content = "Permission denied"
498 }
499 toolResult := message.ToolResult{
500 ToolCallID: tc.ID,
501 Name: tc.Name,
502 Content: content,
503 IsError: true,
504 }
505 _, createErr = a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
506 Role: message.Tool,
507 Parts: []message.ContentPart{
508 toolResult,
509 },
510 })
511 if createErr != nil {
512 return nil, createErr
513 }
514 }
515 if isCancelErr {
516 currentAssistant.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
517 } else if isPermissionErr {
518 currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "Permission denied", "")
519 } else {
520 currentAssistant.AddFinish(message.FinishReasonError, "API Error", err.Error())
521 }
522 // INFO: we use the parent context here because the genCtx has been cancelled
523 updateErr := a.messages.Update(ctx, *currentAssistant)
524 if updateErr != nil {
525 return nil, updateErr
526 }
527 return nil, err
528 }
529 wg.Wait()
530
531 // Execute Stop hook
532 if a.hooks != nil && result != nil {
533 var totalTokens, inputTokens int64
534 for _, step := range result.Steps {
535 totalTokens += step.Usage.TotalTokens
536 inputTokens += step.Usage.InputTokens
537 }
538
539 if err := a.hooks.Execute(ctx, hooks.HookContext{
540 EventType: config.Stop,
541 SessionID: call.SessionID,
542 MessageID: currentAssistant.ID,
543 Provider: a.largeModel.ModelCfg.Provider,
544 Model: a.largeModel.ModelCfg.Model,
545 TokensUsed: totalTokens,
546 TokensInput: inputTokens,
547 }); err != nil {
548 slog.Debug("stop hook execution failed", "error", err)
549 }
550 }
551
552 if shouldSummarize {
553 a.activeRequests.Del(call.SessionID)
554 if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
555 return nil, summarizeErr
556 }
557 // if the agent was not done...
558 if len(currentAssistant.ToolCalls()) > 0 {
559 existing, ok := a.messageQueue.Get(call.SessionID)
560 if !ok {
561 existing = []SessionAgentCall{}
562 }
563 call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
564 existing = append(existing, call)
565 a.messageQueue.Set(call.SessionID, existing)
566 }
567 }
568
569 // release active request before processing queued messages
570 a.activeRequests.Del(call.SessionID)
571 cancel()
572
573 queuedMessages, ok := a.messageQueue.Get(call.SessionID)
574 if !ok || len(queuedMessages) == 0 {
575 return result, err
576 }
577 // there are queued messages restart the loop
578 firstQueuedMessage := queuedMessages[0]
579 a.messageQueue.Set(call.SessionID, queuedMessages[1:])
580 return a.Run(ctx, firstQueuedMessage)
581}
582
583func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
584 if a.IsSessionBusy(sessionID) {
585 return ErrSessionBusy
586 }
587
588 currentSession, err := a.sessions.Get(ctx, sessionID)
589 if err != nil {
590 return fmt.Errorf("failed to get session: %w", err)
591 }
592 msgs, err := a.getSessionMessages(ctx, currentSession)
593 if err != nil {
594 return err
595 }
596 if len(msgs) == 0 {
597 // nothing to summarize
598 return nil
599 }
600
601 // Execute PreCompact hook
602 if a.hooks != nil {
603 if err := a.hooks.Execute(ctx, hooks.HookContext{
604 EventType: config.PreCompact,
605 SessionID: sessionID,
606 Provider: a.largeModel.ModelCfg.Provider,
607 Model: a.largeModel.ModelCfg.Model,
608 }); err != nil {
609 slog.Debug("pre_compact hook execution failed", "error", err)
610 }
611 }
612
613 aiMsgs, _ := a.preparePrompt(msgs)
614
615 genCtx, cancel := context.WithCancel(ctx)
616 a.activeRequests.Set(sessionID, cancel)
617 defer a.activeRequests.Del(sessionID)
618 defer cancel()
619
620 agent := fantasy.NewAgent(a.largeModel.Model,
621 fantasy.WithSystemPrompt(string(summaryPrompt)),
622 )
623 summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
624 Role: message.Assistant,
625 Model: a.largeModel.Model.Model(),
626 Provider: a.largeModel.Model.Provider(),
627 IsSummaryMessage: true,
628 })
629 if err != nil {
630 return err
631 }
632
633 resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
634 Prompt: "Provide a detailed summary of our conversation above.",
635 Messages: aiMsgs,
636 ProviderOptions: opts,
637 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
638 prepared.Messages = options.Messages
639 if a.systemPromptPrefix != "" {
640 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
641 }
642 return callContext, prepared, nil
643 },
644 OnReasoningDelta: func(id string, text string) error {
645 summaryMessage.AppendReasoningContent(text)
646 return a.messages.Update(genCtx, summaryMessage)
647 },
648 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
649 // handle anthropic signature
650 if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
651 if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
652 summaryMessage.AppendReasoningSignature(signature.Signature)
653 }
654 }
655 summaryMessage.FinishThinking()
656 return a.messages.Update(genCtx, summaryMessage)
657 },
658 OnTextDelta: func(id, text string) error {
659 summaryMessage.AppendContent(text)
660 return a.messages.Update(genCtx, summaryMessage)
661 },
662 })
663 if err != nil {
664 isCancelErr := errors.Is(err, context.Canceled)
665 if isCancelErr {
666 // User cancelled summarize we need to remove the summary message
667 deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
668 return deleteErr
669 }
670 return err
671 }
672
673 summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
674 err = a.messages.Update(genCtx, summaryMessage)
675 if err != nil {
676 return err
677 }
678
679 var openrouterCost *float64
680 for _, step := range resp.Steps {
681 stepCost := a.openrouterCost(step.ProviderMetadata)
682 if stepCost != nil {
683 newCost := *stepCost
684 if openrouterCost != nil {
685 newCost += *openrouterCost
686 }
687 openrouterCost = &newCost
688 }
689 }
690
691 a.updateSessionUsage(a.largeModel, ¤tSession, resp.TotalUsage, openrouterCost)
692
693 // just in case get just the last usage
694 usage := resp.Response.Usage
695 currentSession.SummaryMessageID = summaryMessage.ID
696 currentSession.CompletionTokens = usage.OutputTokens
697 currentSession.PromptTokens = 0
698 _, err = a.sessions.Save(genCtx, currentSession)
699 return err
700}
701
702func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
703 if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
704 return fantasy.ProviderOptions{}
705 }
706 return fantasy.ProviderOptions{
707 anthropic.Name: &anthropic.ProviderCacheControlOptions{
708 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
709 },
710 bedrock.Name: &anthropic.ProviderCacheControlOptions{
711 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
712 },
713 }
714}
715
716func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
717 var attachmentParts []message.ContentPart
718 for _, attachment := range call.Attachments {
719 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
720 }
721 parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
722 parts = append(parts, attachmentParts...)
723 msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
724 Role: message.User,
725 Parts: parts,
726 })
727 if err != nil {
728 return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
729 }
730 return msg, nil
731}
732
733func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
734 var history []fantasy.Message
735 for _, m := range msgs {
736 if len(m.Parts) == 0 {
737 continue
738 }
739 // Assistant message without content or tool calls (cancelled before it returned anything)
740 if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
741 continue
742 }
743 history = append(history, m.ToAIMessage()...)
744 }
745
746 var files []fantasy.FilePart
747 for _, attachment := range attachments {
748 files = append(files, fantasy.FilePart{
749 Filename: attachment.FileName,
750 Data: attachment.Content,
751 MediaType: attachment.MimeType,
752 })
753 }
754
755 return history, files
756}
757
758func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
759 msgs, err := a.messages.List(ctx, session.ID)
760 if err != nil {
761 return nil, fmt.Errorf("failed to list messages: %w", err)
762 }
763
764 if session.SummaryMessageID != "" {
765 summaryMsgInex := -1
766 for i, msg := range msgs {
767 if msg.ID == session.SummaryMessageID {
768 summaryMsgInex = i
769 break
770 }
771 }
772 if summaryMsgInex != -1 {
773 msgs = msgs[summaryMsgInex:]
774 msgs[0].Role = message.User
775 }
776 }
777 return msgs, nil
778}
779
780func (a *sessionAgent) generateTitle(ctx context.Context, session *session.Session, prompt string) {
781 if prompt == "" {
782 return
783 }
784
785 var maxOutput int64 = 40
786 if a.smallModel.CatwalkCfg.CanReason {
787 maxOutput = a.smallModel.CatwalkCfg.DefaultMaxTokens
788 }
789
790 agent := fantasy.NewAgent(a.smallModel.Model,
791 fantasy.WithSystemPrompt(string(titlePrompt)+"\n /no_think"),
792 fantasy.WithMaxOutputTokens(maxOutput),
793 )
794
795 resp, err := agent.Stream(ctx, fantasy.AgentStreamCall{
796 Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", prompt),
797 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
798 prepared.Messages = options.Messages
799 if a.systemPromptPrefix != "" {
800 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
801 }
802 return callContext, prepared, nil
803 },
804 })
805 if err != nil {
806 slog.Error("error generating title", "err", err)
807 return
808 }
809
810 title := resp.Response.Content.Text()
811
812 title = strings.ReplaceAll(title, "\n", " ")
813
814 // remove thinking tags if present
815 if idx := strings.Index(title, "</think>"); idx > 0 {
816 title = title[idx+len("</think>"):]
817 }
818
819 title = strings.TrimSpace(title)
820 if title == "" {
821 slog.Warn("failed to generate title", "warn", "empty title")
822 return
823 }
824
825 session.Title = title
826
827 var openrouterCost *float64
828 for _, step := range resp.Steps {
829 stepCost := a.openrouterCost(step.ProviderMetadata)
830 if stepCost != nil {
831 newCost := *stepCost
832 if openrouterCost != nil {
833 newCost += *openrouterCost
834 }
835 openrouterCost = &newCost
836 }
837 }
838
839 a.updateSessionUsage(a.smallModel, session, resp.TotalUsage, openrouterCost)
840 _, saveErr := a.sessions.Save(ctx, *session)
841 if saveErr != nil {
842 slog.Error("failed to save session title & usage", "error", saveErr)
843 return
844 }
845}
846
847func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
848 openrouterMetadata, ok := metadata[openrouter.Name]
849 if !ok {
850 return nil
851 }
852
853 opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
854 if !ok {
855 return nil
856 }
857 return &opts.Usage.Cost
858}
859
860func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
861 modelConfig := model.CatwalkCfg
862 cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
863 modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
864 modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
865 modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
866
867 a.eventTokensUsed(session.ID, model, usage, cost)
868
869 if overrideCost != nil {
870 session.Cost += *overrideCost
871 } else {
872 session.Cost += cost
873 }
874
875 session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
876 session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
877}
878
879func (a *sessionAgent) Cancel(sessionID string) {
880 // Cancel regular requests
881 if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
882 slog.Info("Request cancellation initiated", "session_id", sessionID)
883 cancel()
884 }
885
886 // Also check for summarize requests
887 if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
888 slog.Info("Summarize cancellation initiated", "session_id", sessionID)
889 cancel()
890 }
891
892 if a.QueuedPrompts(sessionID) > 0 {
893 slog.Info("Clearing queued prompts", "session_id", sessionID)
894 a.messageQueue.Del(sessionID)
895 }
896}
897
898func (a *sessionAgent) ClearQueue(sessionID string) {
899 if a.QueuedPrompts(sessionID) > 0 {
900 slog.Info("Clearing queued prompts", "session_id", sessionID)
901 a.messageQueue.Del(sessionID)
902 }
903}
904
905func (a *sessionAgent) CancelAll() {
906 if !a.IsBusy() {
907 return
908 }
909 for key := range a.activeRequests.Seq2() {
910 a.Cancel(key) // key is sessionID
911 }
912
913 timeout := time.After(5 * time.Second)
914 for a.IsBusy() {
915 select {
916 case <-timeout:
917 return
918 default:
919 time.Sleep(200 * time.Millisecond)
920 }
921 }
922}
923
924func (a *sessionAgent) IsBusy() bool {
925 var busy bool
926 for cancelFunc := range a.activeRequests.Seq() {
927 if cancelFunc != nil {
928 busy = true
929 break
930 }
931 }
932 return busy
933}
934
935func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
936 _, busy := a.activeRequests.Get(sessionID)
937 return busy
938}
939
940func (a *sessionAgent) QueuedPrompts(sessionID string) int {
941 l, ok := a.messageQueue.Get(sessionID)
942 if !ok {
943 return 0
944 }
945 return len(l)
946}
947
948func (a *sessionAgent) SetModels(large Model, small Model) {
949 a.largeModel = large
950 a.smallModel = small
951}
952
// SetTools replaces the agent's tool set. The slice is stored as given, not
// copied.
func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
	a.tools = tools
}
956
// Model returns the agent's large model.
func (a *sessionAgent) Model() Model {
	return a.largeModel
}