1// Package agent is the core orchestration layer for Crush AI agents.
2//
3// It provides session-based AI agent functionality for managing
4// conversations, tool execution, and message handling. It coordinates
5// interactions between language models, messages, sessions, and tools while
6// handling features like automatic summarization, queuing, and token
7// management.
8package agent
9
10import (
11 "context"
12 _ "embed"
13 "errors"
14 "fmt"
15 "log/slog"
16 "os"
17 "strconv"
18 "strings"
19 "sync"
20 "time"
21
22 "charm.land/fantasy"
23 "charm.land/fantasy/providers/anthropic"
24 "charm.land/fantasy/providers/bedrock"
25 "charm.land/fantasy/providers/google"
26 "charm.land/fantasy/providers/openai"
27 "charm.land/fantasy/providers/openrouter"
28 "github.com/charmbracelet/catwalk/pkg/catwalk"
29 "github.com/charmbracelet/crush/internal/agent/tools"
30 "github.com/charmbracelet/crush/internal/config"
31 "github.com/charmbracelet/crush/internal/csync"
32 "github.com/charmbracelet/crush/internal/message"
33 "github.com/charmbracelet/crush/internal/permission"
34 "github.com/charmbracelet/crush/internal/session"
35)
36
// titlePrompt is the system prompt used by generateTitle when asking the
// small model for a session title.
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt is the system prompt used by Summarize when asking the large
// model to condense a session's history.
//
//go:embed templates/summary.md
var summaryPrompt []byte
42
// SessionAgentCall describes a single prompt to run within a session.
type SessionAgentCall struct {
	// SessionID identifies the session the prompt belongs to; required.
	SessionID string
	// Prompt is the user's text; required (Run rejects an empty prompt).
	Prompt string
	// ProviderOptions are forwarded to the model stream call (and to
	// Summarize when auto-summarization triggers).
	ProviderOptions fantasy.ProviderOptions
	// Attachments are stored as binary parts of the user message and sent to
	// the model as files.
	Attachments []message.Attachment
	// MaxOutputTokens is forwarded to the stream call as a pointer, so the
	// zero value is passed through as 0 rather than "unset".
	// NOTE(review): confirm how fantasy treats a pointer to 0 here.
	MaxOutputTokens int64
	// The sampling parameters below are forwarded unchanged; nil means
	// "not set".
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
}
55
56type SessionAgent interface {
57 Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
58 SetModels(large Model, small Model)
59 SetTools(tools []fantasy.AgentTool)
60 Cancel(sessionID string)
61 CancelAll()
62 IsSessionBusy(sessionID string) bool
63 IsBusy() bool
64 QueuedPrompts(sessionID string) int
65 ClearQueue(sessionID string)
66 Summarize(context.Context, string, fantasy.ProviderOptions) error
67 Model() Model
68}
69
// Model bundles a language model with the metadata the agent needs about it.
type Model struct {
	// Model is the fantasy language model used to issue requests.
	Model fantasy.LanguageModel
	// CatwalkCfg is the catalog entry for the model; the agent reads its
	// context window, pricing, reasoning capability and default token limit.
	CatwalkCfg catwalk.Model
	// ModelCfg is the user's selected-model configuration; its Model and
	// Provider IDs are recorded on assistant messages.
	ModelCfg config.SelectedModel
}
75
// sessionAgent is the SessionAgent implementation returned by NewSessionAgent.
type sessionAgent struct {
	// largeModel answers prompts and writes summaries.
	largeModel Model
	// smallModel generates session titles.
	smallModel Model
	// systemPromptPrefix, when non-empty, is prepended as an extra system
	// message before every model call.
	systemPromptPrefix string
	// systemPrompt is the main system prompt for Run.
	systemPrompt string
	// tools are offered to the large model in Run.
	tools []fantasy.AgentTool
	sessions             session.Service
	messages             message.Service
	// disableAutoSummarize turns off the context-window stop condition in Run.
	disableAutoSummarize bool
	// isYolo is stored from the options; it is not read by the methods in
	// this file.
	isYolo bool

	// messageQueue holds prompts received while a session was busy, keyed
	// by session ID.
	messageQueue *csync.Map[string, []SessionAgentCall]
	// activeRequests maps session IDs to the cancel function of their
	// in-flight request; presence of a key means the session is busy.
	activeRequests *csync.Map[string, context.CancelFunc]
}
90
// SessionAgentOptions configures NewSessionAgent. Each field is copied onto
// the corresponding sessionAgent field.
type SessionAgentOptions struct {
	// LargeModel answers prompts and writes summaries.
	LargeModel Model
	// SmallModel generates session titles.
	SmallModel Model
	// SystemPromptPrefix, when non-empty, is sent as an extra leading system
	// message.
	SystemPromptPrefix string
	// SystemPrompt is the main system prompt.
	SystemPrompt string
	// DisableAutoSummarize disables automatic summarization when the context
	// window runs low.
	DisableAutoSummarize bool
	IsYolo               bool
	Sessions             session.Service
	Messages             message.Service
	// Tools are offered to the large model.
	Tools []fantasy.AgentTool
}
102
103func NewSessionAgent(
104 opts SessionAgentOptions,
105) SessionAgent {
106 return &sessionAgent{
107 largeModel: opts.LargeModel,
108 smallModel: opts.SmallModel,
109 systemPromptPrefix: opts.SystemPromptPrefix,
110 systemPrompt: opts.SystemPrompt,
111 sessions: opts.Sessions,
112 messages: opts.Messages,
113 disableAutoSummarize: opts.DisableAutoSummarize,
114 tools: opts.Tools,
115 isYolo: opts.IsYolo,
116 messageQueue: csync.NewMap[string, []SessionAgentCall](),
117 activeRequests: csync.NewMap[string, context.CancelFunc](),
118 }
119}
120
121func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
122 if call.Prompt == "" {
123 return nil, ErrEmptyPrompt
124 }
125 if call.SessionID == "" {
126 return nil, ErrSessionMissing
127 }
128
129 // Queue the message if busy
130 if a.IsSessionBusy(call.SessionID) {
131 existing, ok := a.messageQueue.Get(call.SessionID)
132 if !ok {
133 existing = []SessionAgentCall{}
134 }
135 existing = append(existing, call)
136 a.messageQueue.Set(call.SessionID, existing)
137 return nil, nil
138 }
139
140 if len(a.tools) > 0 {
141 // Add anthropic caching to the last tool.
142 a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
143 }
144
145 agent := fantasy.NewAgent(
146 a.largeModel.Model,
147 fantasy.WithSystemPrompt(a.systemPrompt),
148 fantasy.WithTools(a.tools...),
149 )
150
151 sessionLock := sync.Mutex{}
152 currentSession, err := a.sessions.Get(ctx, call.SessionID)
153 if err != nil {
154 return nil, fmt.Errorf("failed to get session: %w", err)
155 }
156
157 msgs, err := a.getSessionMessages(ctx, currentSession)
158 if err != nil {
159 return nil, fmt.Errorf("failed to get session messages: %w", err)
160 }
161
162 var wg sync.WaitGroup
163 // Generate title if first message.
164 if len(msgs) == 0 {
165 wg.Go(func() {
166 sessionLock.Lock()
167 a.generateTitle(ctx, ¤tSession, call.Prompt)
168 sessionLock.Unlock()
169 })
170 }
171
172 // Add the user message to the session.
173 _, err = a.createUserMessage(ctx, call)
174 if err != nil {
175 return nil, err
176 }
177
178 // Add the session to the context.
179 ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)
180
181 genCtx, cancel := context.WithCancel(ctx)
182 a.activeRequests.Set(call.SessionID, cancel)
183
184 defer cancel()
185 defer a.activeRequests.Del(call.SessionID)
186
187 history, files := a.preparePrompt(msgs, call.Attachments...)
188
189 startTime := time.Now()
190 a.eventPromptSent(call.SessionID)
191
192 var currentAssistant *message.Message
193 var shouldSummarize bool
194 result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
195 Prompt: call.Prompt,
196 Files: files,
197 Messages: history,
198 ProviderOptions: call.ProviderOptions,
199 MaxOutputTokens: &call.MaxOutputTokens,
200 TopP: call.TopP,
201 Temperature: call.Temperature,
202 PresencePenalty: call.PresencePenalty,
203 TopK: call.TopK,
204 FrequencyPenalty: call.FrequencyPenalty,
205 // Before each step create a new assistant message.
206 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
207 prepared.Messages = options.Messages
208 // Reset all cached items.
209 for i := range prepared.Messages {
210 prepared.Messages[i].ProviderOptions = nil
211 }
212
213 queuedCalls, _ := a.messageQueue.Get(call.SessionID)
214 a.messageQueue.Del(call.SessionID)
215 for _, queued := range queuedCalls {
216 userMessage, createErr := a.createUserMessage(callContext, queued)
217 if createErr != nil {
218 return callContext, prepared, createErr
219 }
220 prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
221 }
222
223 lastSystemRoleInx := 0
224 systemMessageUpdated := false
225 for i, msg := range prepared.Messages {
226 // Only add cache control to the last message.
227 if msg.Role == fantasy.MessageRoleSystem {
228 lastSystemRoleInx = i
229 } else if !systemMessageUpdated {
230 prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
231 systemMessageUpdated = true
232 }
233 // Than add cache control to the last 2 messages.
234 if i > len(prepared.Messages)-3 {
235 prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
236 }
237 }
238
239 if a.systemPromptPrefix != "" {
240 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
241 }
242
243 var assistantMsg message.Message
244 assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
245 Role: message.Assistant,
246 Parts: []message.ContentPart{},
247 Model: a.largeModel.ModelCfg.Model,
248 Provider: a.largeModel.ModelCfg.Provider,
249 })
250 if err != nil {
251 return callContext, prepared, err
252 }
253 callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
254 currentAssistant = &assistantMsg
255 return callContext, prepared, err
256 },
257 OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
258 currentAssistant.AppendReasoningContent(reasoning.Text)
259 return a.messages.Update(genCtx, *currentAssistant)
260 },
261 OnReasoningDelta: func(id string, text string) error {
262 currentAssistant.AppendReasoningContent(text)
263 return a.messages.Update(genCtx, *currentAssistant)
264 },
265 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
266 // handle anthropic signature
267 if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
268 if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
269 currentAssistant.AppendReasoningSignature(reasoning.Signature)
270 }
271 }
272 if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
273 if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
274 currentAssistant.AppendReasoningSignature(reasoning.Signature)
275 }
276 }
277 if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
278 if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
279 currentAssistant.SetReasoningResponsesData(reasoning)
280 }
281 }
282 currentAssistant.FinishThinking()
283 return a.messages.Update(genCtx, *currentAssistant)
284 },
285 OnTextDelta: func(id string, text string) error {
286 // Strip leading newline from initial text content. This is is
287 // particularly important in non-interactive mode where leading
288 // newlines are very visible.
289 if len(currentAssistant.Parts) == 0 && strings.HasPrefix(text, "\n") {
290 text = strings.TrimPrefix(text, "\n")
291 }
292
293 currentAssistant.AppendContent(text)
294 return a.messages.Update(genCtx, *currentAssistant)
295 },
296 OnToolInputStart: func(id string, toolName string) error {
297 toolCall := message.ToolCall{
298 ID: id,
299 Name: toolName,
300 ProviderExecuted: false,
301 Finished: false,
302 }
303 currentAssistant.AddToolCall(toolCall)
304 return a.messages.Update(genCtx, *currentAssistant)
305 },
306 OnRetry: func(err *fantasy.APICallError, delay time.Duration) {
307 // TODO: implement
308 },
309 OnToolCall: func(tc fantasy.ToolCallContent) error {
310 toolCall := message.ToolCall{
311 ID: tc.ToolCallID,
312 Name: tc.ToolName,
313 Input: tc.Input,
314 ProviderExecuted: false,
315 Finished: true,
316 }
317 currentAssistant.AddToolCall(toolCall)
318 return a.messages.Update(genCtx, *currentAssistant)
319 },
320 OnToolResult: func(result fantasy.ToolResultContent) error {
321 var resultContent string
322 isError := false
323 switch result.Result.GetType() {
324 case fantasy.ToolResultContentTypeText:
325 r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result)
326 if ok {
327 resultContent = r.Text
328 }
329 case fantasy.ToolResultContentTypeError:
330 r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result)
331 if ok {
332 isError = true
333 resultContent = r.Error.Error()
334 }
335 case fantasy.ToolResultContentTypeMedia:
336 // TODO: handle this message type
337 }
338 toolResult := message.ToolResult{
339 ToolCallID: result.ToolCallID,
340 Name: result.ToolName,
341 Content: resultContent,
342 IsError: isError,
343 Metadata: result.ClientMetadata,
344 }
345 _, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
346 Role: message.Tool,
347 Parts: []message.ContentPart{
348 toolResult,
349 },
350 })
351 if createMsgErr != nil {
352 return createMsgErr
353 }
354 return nil
355 },
356 OnStepFinish: func(stepResult fantasy.StepResult) error {
357 finishReason := message.FinishReasonUnknown
358 switch stepResult.FinishReason {
359 case fantasy.FinishReasonLength:
360 finishReason = message.FinishReasonMaxTokens
361 case fantasy.FinishReasonStop:
362 finishReason = message.FinishReasonEndTurn
363 case fantasy.FinishReasonToolCalls:
364 finishReason = message.FinishReasonToolUse
365 }
366 currentAssistant.AddFinish(finishReason, "", "")
367 a.updateSessionUsage(a.largeModel, ¤tSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
368 sessionLock.Lock()
369 _, sessionErr := a.sessions.Save(genCtx, currentSession)
370 sessionLock.Unlock()
371 if sessionErr != nil {
372 return sessionErr
373 }
374 return a.messages.Update(genCtx, *currentAssistant)
375 },
376 StopWhen: []fantasy.StopCondition{
377 func(_ []fantasy.StepResult) bool {
378 cw := int64(a.largeModel.CatwalkCfg.ContextWindow)
379 tokens := currentSession.CompletionTokens + currentSession.PromptTokens
380 remaining := cw - tokens
381 var threshold int64
382 if cw > 200_000 {
383 threshold = 20_000
384 } else {
385 threshold = int64(float64(cw) * 0.2)
386 }
387 if (remaining <= threshold) && !a.disableAutoSummarize {
388 shouldSummarize = true
389 return true
390 }
391 return false
392 },
393 },
394 })
395
396 a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))
397
398 if err != nil {
399 isCancelErr := errors.Is(err, context.Canceled)
400 isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
401 if currentAssistant == nil {
402 return result, err
403 }
404 // Ensure we finish thinking on error to close the reasoning state.
405 currentAssistant.FinishThinking()
406 toolCalls := currentAssistant.ToolCalls()
407 // INFO: we use the parent context here because the genCtx has been cancelled.
408 msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
409 if createErr != nil {
410 return nil, createErr
411 }
412 for _, tc := range toolCalls {
413 if !tc.Finished {
414 tc.Finished = true
415 tc.Input = "{}"
416 currentAssistant.AddToolCall(tc)
417 updateErr := a.messages.Update(ctx, *currentAssistant)
418 if updateErr != nil {
419 return nil, updateErr
420 }
421 }
422
423 found := false
424 for _, msg := range msgs {
425 if msg.Role == message.Tool {
426 for _, tr := range msg.ToolResults() {
427 if tr.ToolCallID == tc.ID {
428 found = true
429 break
430 }
431 }
432 }
433 if found {
434 break
435 }
436 }
437 if found {
438 continue
439 }
440 content := "There was an error while executing the tool"
441 if isCancelErr {
442 content = "Tool execution canceled by user"
443 } else if isPermissionErr {
444 content = "Permission denied"
445 }
446 toolResult := message.ToolResult{
447 ToolCallID: tc.ID,
448 Name: tc.Name,
449 Content: content,
450 IsError: true,
451 }
452 _, createErr = a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
453 Role: message.Tool,
454 Parts: []message.ContentPart{
455 toolResult,
456 },
457 })
458 if createErr != nil {
459 return nil, createErr
460 }
461 }
462 if isCancelErr {
463 currentAssistant.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
464 } else if isPermissionErr {
465 currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "Permission denied", "")
466 } else {
467 currentAssistant.AddFinish(message.FinishReasonError, "API Error", err.Error())
468 }
469 // Note: we use the parent context here because the genCtx has been
470 // cancelled.
471 updateErr := a.messages.Update(ctx, *currentAssistant)
472 if updateErr != nil {
473 return nil, updateErr
474 }
475 return nil, err
476 }
477 wg.Wait()
478
479 if shouldSummarize {
480 a.activeRequests.Del(call.SessionID)
481 if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
482 return nil, summarizeErr
483 }
484 // If the agent wasn't done...
485 if len(currentAssistant.ToolCalls()) > 0 {
486 existing, ok := a.messageQueue.Get(call.SessionID)
487 if !ok {
488 existing = []SessionAgentCall{}
489 }
490 call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
491 existing = append(existing, call)
492 a.messageQueue.Set(call.SessionID, existing)
493 }
494 }
495
496 // Release active request before processing queued messages.
497 a.activeRequests.Del(call.SessionID)
498 cancel()
499
500 queuedMessages, ok := a.messageQueue.Get(call.SessionID)
501 if !ok || len(queuedMessages) == 0 {
502 return result, err
503 }
504 // There are queued messages restart the loop.
505 firstQueuedMessage := queuedMessages[0]
506 a.messageQueue.Set(call.SessionID, queuedMessages[1:])
507 return a.Run(ctx, firstQueuedMessage)
508}
509
510func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
511 if a.IsSessionBusy(sessionID) {
512 return ErrSessionBusy
513 }
514
515 currentSession, err := a.sessions.Get(ctx, sessionID)
516 if err != nil {
517 return fmt.Errorf("failed to get session: %w", err)
518 }
519 msgs, err := a.getSessionMessages(ctx, currentSession)
520 if err != nil {
521 return err
522 }
523 if len(msgs) == 0 {
524 // Nothing to summarize.
525 return nil
526 }
527
528 aiMsgs, _ := a.preparePrompt(msgs)
529
530 genCtx, cancel := context.WithCancel(ctx)
531 a.activeRequests.Set(sessionID, cancel)
532 defer a.activeRequests.Del(sessionID)
533 defer cancel()
534
535 agent := fantasy.NewAgent(a.largeModel.Model,
536 fantasy.WithSystemPrompt(string(summaryPrompt)),
537 )
538 summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
539 Role: message.Assistant,
540 Model: a.largeModel.Model.Model(),
541 Provider: a.largeModel.Model.Provider(),
542 IsSummaryMessage: true,
543 })
544 if err != nil {
545 return err
546 }
547
548 resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
549 Prompt: "Provide a detailed summary of our conversation above.",
550 Messages: aiMsgs,
551 ProviderOptions: opts,
552 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
553 prepared.Messages = options.Messages
554 if a.systemPromptPrefix != "" {
555 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
556 }
557 return callContext, prepared, nil
558 },
559 OnReasoningDelta: func(id string, text string) error {
560 summaryMessage.AppendReasoningContent(text)
561 return a.messages.Update(genCtx, summaryMessage)
562 },
563 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
564 // Handle anthropic signature.
565 if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
566 if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
567 summaryMessage.AppendReasoningSignature(signature.Signature)
568 }
569 }
570 summaryMessage.FinishThinking()
571 return a.messages.Update(genCtx, summaryMessage)
572 },
573 OnTextDelta: func(id, text string) error {
574 summaryMessage.AppendContent(text)
575 return a.messages.Update(genCtx, summaryMessage)
576 },
577 })
578 if err != nil {
579 isCancelErr := errors.Is(err, context.Canceled)
580 if isCancelErr {
581 // User cancelled summarize we need to remove the summary message.
582 deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
583 return deleteErr
584 }
585 return err
586 }
587
588 summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
589 err = a.messages.Update(genCtx, summaryMessage)
590 if err != nil {
591 return err
592 }
593
594 var openrouterCost *float64
595 for _, step := range resp.Steps {
596 stepCost := a.openrouterCost(step.ProviderMetadata)
597 if stepCost != nil {
598 newCost := *stepCost
599 if openrouterCost != nil {
600 newCost += *openrouterCost
601 }
602 openrouterCost = &newCost
603 }
604 }
605
606 a.updateSessionUsage(a.largeModel, ¤tSession, resp.TotalUsage, openrouterCost)
607
608 // Just in case, get just the last usage info.
609 usage := resp.Response.Usage
610 currentSession.SummaryMessageID = summaryMessage.ID
611 currentSession.CompletionTokens = usage.OutputTokens
612 currentSession.PromptTokens = 0
613 _, err = a.sessions.Save(genCtx, currentSession)
614 return err
615}
616
617func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
618 if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
619 return fantasy.ProviderOptions{}
620 }
621 return fantasy.ProviderOptions{
622 anthropic.Name: &anthropic.ProviderCacheControlOptions{
623 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
624 },
625 bedrock.Name: &anthropic.ProviderCacheControlOptions{
626 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
627 },
628 }
629}
630
631func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
632 var attachmentParts []message.ContentPart
633 for _, attachment := range call.Attachments {
634 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
635 }
636 parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
637 parts = append(parts, attachmentParts...)
638 msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
639 Role: message.User,
640 Parts: parts,
641 })
642 if err != nil {
643 return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
644 }
645 return msg, nil
646}
647
648func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
649 var history []fantasy.Message
650 for _, m := range msgs {
651 if len(m.Parts) == 0 {
652 continue
653 }
654 // Assistant message without content or tool calls (cancelled before it
655 // returned anything).
656 if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
657 continue
658 }
659 history = append(history, m.ToAIMessage()...)
660 }
661
662 var files []fantasy.FilePart
663 for _, attachment := range attachments {
664 files = append(files, fantasy.FilePart{
665 Filename: attachment.FileName,
666 Data: attachment.Content,
667 MediaType: attachment.MimeType,
668 })
669 }
670
671 return history, files
672}
673
674func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
675 msgs, err := a.messages.List(ctx, session.ID)
676 if err != nil {
677 return nil, fmt.Errorf("failed to list messages: %w", err)
678 }
679
680 if session.SummaryMessageID != "" {
681 summaryMsgInex := -1
682 for i, msg := range msgs {
683 if msg.ID == session.SummaryMessageID {
684 summaryMsgInex = i
685 break
686 }
687 }
688 if summaryMsgInex != -1 {
689 msgs = msgs[summaryMsgInex:]
690 msgs[0].Role = message.User
691 }
692 }
693 return msgs, nil
694}
695
696func (a *sessionAgent) generateTitle(ctx context.Context, session *session.Session, prompt string) {
697 if prompt == "" {
698 return
699 }
700
701 var maxOutput int64 = 40
702 if a.smallModel.CatwalkCfg.CanReason {
703 maxOutput = a.smallModel.CatwalkCfg.DefaultMaxTokens
704 }
705
706 agent := fantasy.NewAgent(a.smallModel.Model,
707 fantasy.WithSystemPrompt(string(titlePrompt)+"\n /no_think"),
708 fantasy.WithMaxOutputTokens(maxOutput),
709 )
710
711 resp, err := agent.Stream(ctx, fantasy.AgentStreamCall{
712 Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", prompt),
713 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
714 prepared.Messages = options.Messages
715 if a.systemPromptPrefix != "" {
716 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
717 }
718 return callContext, prepared, nil
719 },
720 })
721 if err != nil {
722 slog.Error("error generating title", "err", err)
723 return
724 }
725
726 title := resp.Response.Content.Text()
727
728 title = strings.ReplaceAll(title, "\n", " ")
729
730 // Remove thinking tags if present.
731 if idx := strings.Index(title, "</think>"); idx > 0 {
732 title = title[idx+len("</think>"):]
733 }
734
735 title = strings.TrimSpace(title)
736 if title == "" {
737 slog.Warn("failed to generate title", "warn", "empty title")
738 return
739 }
740
741 session.Title = title
742
743 var openrouterCost *float64
744 for _, step := range resp.Steps {
745 stepCost := a.openrouterCost(step.ProviderMetadata)
746 if stepCost != nil {
747 newCost := *stepCost
748 if openrouterCost != nil {
749 newCost += *openrouterCost
750 }
751 openrouterCost = &newCost
752 }
753 }
754
755 a.updateSessionUsage(a.smallModel, session, resp.TotalUsage, openrouterCost)
756 _, saveErr := a.sessions.Save(ctx, *session)
757 if saveErr != nil {
758 slog.Error("failed to save session title & usage", "error", saveErr)
759 return
760 }
761}
762
763func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
764 openrouterMetadata, ok := metadata[openrouter.Name]
765 if !ok {
766 return nil
767 }
768
769 opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
770 if !ok {
771 return nil
772 }
773 return &opts.Usage.Cost
774}
775
776func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
777 modelConfig := model.CatwalkCfg
778 cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
779 modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
780 modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
781 modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
782
783 a.eventTokensUsed(session.ID, model, usage, cost)
784
785 if overrideCost != nil {
786 session.Cost += *overrideCost
787 } else {
788 session.Cost += cost
789 }
790
791 session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
792 session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
793}
794
795func (a *sessionAgent) Cancel(sessionID string) {
796 // Cancel regular requests.
797 if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
798 slog.Info("Request cancellation initiated", "session_id", sessionID)
799 cancel()
800 }
801
802 // Also check for summarize requests.
803 if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
804 slog.Info("Summarize cancellation initiated", "session_id", sessionID)
805 cancel()
806 }
807
808 if a.QueuedPrompts(sessionID) > 0 {
809 slog.Info("Clearing queued prompts", "session_id", sessionID)
810 a.messageQueue.Del(sessionID)
811 }
812}
813
814func (a *sessionAgent) ClearQueue(sessionID string) {
815 if a.QueuedPrompts(sessionID) > 0 {
816 slog.Info("Clearing queued prompts", "session_id", sessionID)
817 a.messageQueue.Del(sessionID)
818 }
819}
820
821func (a *sessionAgent) CancelAll() {
822 if !a.IsBusy() {
823 return
824 }
825 for key := range a.activeRequests.Seq2() {
826 a.Cancel(key) // key is sessionID
827 }
828
829 timeout := time.After(5 * time.Second)
830 for a.IsBusy() {
831 select {
832 case <-timeout:
833 return
834 default:
835 time.Sleep(200 * time.Millisecond)
836 }
837 }
838}
839
840func (a *sessionAgent) IsBusy() bool {
841 var busy bool
842 for cancelFunc := range a.activeRequests.Seq() {
843 if cancelFunc != nil {
844 busy = true
845 break
846 }
847 }
848 return busy
849}
850
851func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
852 _, busy := a.activeRequests.Get(sessionID)
853 return busy
854}
855
856func (a *sessionAgent) QueuedPrompts(sessionID string) int {
857 l, ok := a.messageQueue.Get(sessionID)
858 if !ok {
859 return 0
860 }
861 return len(l)
862}
863
864func (a *sessionAgent) SetModels(large Model, small Model) {
865 a.largeModel = large
866 a.smallModel = small
867}
868
869func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
870 a.tools = tools
871}
872
873func (a *sessionAgent) Model() Model {
874 return a.largeModel
875}