1// Package agent is the core orchestration layer for Crush AI agents.
2//
3// It provides session-based AI agent functionality for managing
4// conversations, tool execution, and message handling. It coordinates
5// interactions between language models, messages, sessions, and tools while
6// handling features like automatic summarization, queuing, and token
7// management.
8package agent
9
10import (
11 "cmp"
12 "context"
13 _ "embed"
14 "errors"
15 "fmt"
16 "log/slog"
17 "os"
18 "strconv"
19 "strings"
20 "sync"
21 "time"
22
23 "charm.land/fantasy"
24 "charm.land/fantasy/providers/anthropic"
25 "charm.land/fantasy/providers/bedrock"
26 "charm.land/fantasy/providers/google"
27 "charm.land/fantasy/providers/openai"
28 "charm.land/fantasy/providers/openrouter"
29 "github.com/charmbracelet/catwalk/pkg/catwalk"
30 "github.com/charmbracelet/crush/internal/agent/tools"
31 "github.com/charmbracelet/crush/internal/config"
32 "github.com/charmbracelet/crush/internal/csync"
33 "github.com/charmbracelet/crush/internal/message"
34 "github.com/charmbracelet/crush/internal/permission"
35 "github.com/charmbracelet/crush/internal/session"
36 "github.com/charmbracelet/crush/internal/stringext"
37)
38
// titlePrompt holds the embedded contents of templates/title.md. It is used
// as the system prompt when generating a session title.
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt holds the embedded contents of templates/summary.md. It is
// used as the system prompt when summarizing a session.
//
//go:embed templates/summary.md
var summaryPrompt []byte
44
// SessionAgentCall describes a single prompt submitted to a session.
//
// SessionID and Prompt are required: Run rejects a call when either is
// empty. Attachments are stored on the created user message and forwarded to
// the model as files. ProviderOptions, MaxOutputTokens and the sampling
// parameters (Temperature, TopP, TopK, FrequencyPenalty, PresencePenalty) are
// passed through to the underlying stream call unchanged.
type SessionAgentCall struct {
	SessionID        string
	Prompt           string
	ProviderOptions  fantasy.ProviderOptions
	Attachments      []message.Attachment
	MaxOutputTokens  int64
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
}
57
// SessionAgent is the session-scoped agent API. The behavior notes below
// describe the sessionAgent implementation in this file.
type SessionAgent interface {
	// Run submits a prompt to a session. When the session already has an
	// active request the call is queued and Run returns (nil, nil).
	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
	// SetModels replaces the large and small models.
	SetModels(large Model, small Model)
	// SetTools replaces the tool set offered to the model.
	SetTools(tools []fantasy.AgentTool)
	// Cancel cancels the active request for the session and drops any
	// prompts queued for it.
	Cancel(sessionID string)
	// CancelAll cancels every active request and waits a bounded time for
	// them to wind down.
	CancelAll()
	// IsSessionBusy reports whether the session has an active request.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any session has an active request.
	IsBusy() bool
	// QueuedPrompts returns the number of prompts queued for the session.
	QueuedPrompts(sessionID string) int
	// ClearQueue drops the prompts queued for the session.
	ClearQueue(sessionID string)
	// Summarize replaces the session's working history with a generated
	// summary message. It fails with ErrSessionBusy when the session has an
	// active request.
	Summarize(context.Context, string, fantasy.ProviderOptions) error
	// Model returns the large model.
	Model() Model
}
71
// Model bundles a language model with its configuration.
//
// CatwalkCfg supplies the context window, reasoning capability, default max
// tokens and per-token pricing read by this package. ModelCfg supplies the
// model and provider identifiers recorded on assistant messages.
type Model struct {
	Model      fantasy.LanguageModel
	CatwalkCfg catwalk.Model
	ModelCfg   config.SelectedModel
}
77
// sessionAgent is the SessionAgent implementation returned by
// NewSessionAgent.
type sessionAgent struct {
	// largeModel drives Run and Summarize; smallModel drives title
	// generation.
	largeModel Model
	smallModel Model
	// systemPromptPrefix, when non-empty, is prepended as its own system
	// message ahead of the conversation.
	systemPromptPrefix string
	systemPrompt       string
	tools              []fantasy.AgentTool
	sessions           session.Service
	messages           message.Service
	// disableAutoSummarize turns off the low-context stop condition in Run.
	disableAutoSummarize bool
	// NOTE(review): isYolo is stored but not read anywhere in this file —
	// confirm where it is consumed.
	isYolo bool

	// messageQueue holds prompts received while a session was busy, keyed by
	// session ID.
	messageQueue *csync.Map[string, []SessionAgentCall]
	// activeRequests holds the cancel function of each in-flight request,
	// keyed by session ID. Presence of a key is what marks a session busy.
	activeRequests *csync.Map[string, context.CancelFunc]
}
92
// SessionAgentOptions carries the dependencies and settings NewSessionAgent
// copies into a new agent. Each field maps one-to-one onto the sessionAgent
// field of the same (lower-cased) name.
type SessionAgentOptions struct {
	LargeModel           Model
	SmallModel           Model
	SystemPromptPrefix   string
	SystemPrompt         string
	DisableAutoSummarize bool
	IsYolo               bool
	Sessions             session.Service
	Messages             message.Service
	Tools                []fantasy.AgentTool
}
104
105func NewSessionAgent(
106 opts SessionAgentOptions,
107) SessionAgent {
108 return &sessionAgent{
109 largeModel: opts.LargeModel,
110 smallModel: opts.SmallModel,
111 systemPromptPrefix: opts.SystemPromptPrefix,
112 systemPrompt: opts.SystemPrompt,
113 sessions: opts.Sessions,
114 messages: opts.Messages,
115 disableAutoSummarize: opts.DisableAutoSummarize,
116 tools: opts.Tools,
117 isYolo: opts.IsYolo,
118 messageQueue: csync.NewMap[string, []SessionAgentCall](),
119 activeRequests: csync.NewMap[string, context.CancelFunc](),
120 }
121}
122
123func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
124 if call.Prompt == "" {
125 return nil, ErrEmptyPrompt
126 }
127 if call.SessionID == "" {
128 return nil, ErrSessionMissing
129 }
130
131 // Queue the message if busy
132 if a.IsSessionBusy(call.SessionID) {
133 existing, ok := a.messageQueue.Get(call.SessionID)
134 if !ok {
135 existing = []SessionAgentCall{}
136 }
137 existing = append(existing, call)
138 a.messageQueue.Set(call.SessionID, existing)
139 return nil, nil
140 }
141
142 if len(a.tools) > 0 {
143 // Add Anthropic caching to the last tool.
144 a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
145 }
146
147 agent := fantasy.NewAgent(
148 a.largeModel.Model,
149 fantasy.WithSystemPrompt(a.systemPrompt),
150 fantasy.WithTools(a.tools...),
151 )
152
153 sessionLock := sync.Mutex{}
154 currentSession, err := a.sessions.Get(ctx, call.SessionID)
155 if err != nil {
156 return nil, fmt.Errorf("failed to get session: %w", err)
157 }
158
159 msgs, err := a.getSessionMessages(ctx, currentSession)
160 if err != nil {
161 return nil, fmt.Errorf("failed to get session messages: %w", err)
162 }
163
164 var wg sync.WaitGroup
165 // Generate title if first message.
166 if len(msgs) == 0 {
167 wg.Go(func() {
168 sessionLock.Lock()
169 a.generateTitle(ctx, ¤tSession, call.Prompt)
170 sessionLock.Unlock()
171 })
172 }
173
174 // Add the user message to the session.
175 _, err = a.createUserMessage(ctx, call)
176 if err != nil {
177 return nil, err
178 }
179
180 // Add the session to the context.
181 ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)
182
183 genCtx, cancel := context.WithCancel(ctx)
184 a.activeRequests.Set(call.SessionID, cancel)
185
186 defer cancel()
187 defer a.activeRequests.Del(call.SessionID)
188
189 history, files := a.preparePrompt(msgs, call.Attachments...)
190
191 startTime := time.Now()
192 a.eventPromptSent(call.SessionID)
193
194 var currentAssistant *message.Message
195 var shouldSummarize bool
196 result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
197 Prompt: call.Prompt,
198 Files: files,
199 Messages: history,
200 ProviderOptions: call.ProviderOptions,
201 MaxOutputTokens: &call.MaxOutputTokens,
202 TopP: call.TopP,
203 Temperature: call.Temperature,
204 PresencePenalty: call.PresencePenalty,
205 TopK: call.TopK,
206 FrequencyPenalty: call.FrequencyPenalty,
207 // Before each step create a new assistant message.
208 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
209 prepared.Messages = options.Messages
210 // Reset all cached items.
211 for i := range prepared.Messages {
212 prepared.Messages[i].ProviderOptions = nil
213 }
214
215 // Use latest tools (updated by SetTools when MCP tools change).
216 prepared.Tools = a.tools
217
218 queuedCalls, _ := a.messageQueue.Get(call.SessionID)
219 a.messageQueue.Del(call.SessionID)
220 for _, queued := range queuedCalls {
221 userMessage, createErr := a.createUserMessage(callContext, queued)
222 if createErr != nil {
223 return callContext, prepared, createErr
224 }
225 prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
226 }
227
228 lastSystemRoleInx := 0
229 systemMessageUpdated := false
230 for i, msg := range prepared.Messages {
231 // Only add cache control to the last message.
232 if msg.Role == fantasy.MessageRoleSystem {
233 lastSystemRoleInx = i
234 } else if !systemMessageUpdated {
235 prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
236 systemMessageUpdated = true
237 }
238 // Than add cache control to the last 2 messages.
239 if i > len(prepared.Messages)-3 {
240 prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
241 }
242 }
243
244 if promptPrefix := a.promptPrefix(); promptPrefix != "" {
245 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(promptPrefix)}, prepared.Messages...)
246 }
247
248 var assistantMsg message.Message
249 assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
250 Role: message.Assistant,
251 Parts: []message.ContentPart{},
252 Model: a.largeModel.ModelCfg.Model,
253 Provider: a.largeModel.ModelCfg.Provider,
254 })
255 if err != nil {
256 return callContext, prepared, err
257 }
258 callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
259 currentAssistant = &assistantMsg
260 return callContext, prepared, err
261 },
262 OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
263 currentAssistant.AppendReasoningContent(reasoning.Text)
264 return a.messages.Update(genCtx, *currentAssistant)
265 },
266 OnReasoningDelta: func(id string, text string) error {
267 currentAssistant.AppendReasoningContent(text)
268 return a.messages.Update(genCtx, *currentAssistant)
269 },
270 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
271 // handle anthropic signature
272 if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
273 if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
274 currentAssistant.AppendReasoningSignature(reasoning.Signature)
275 }
276 }
277 if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
278 if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
279 currentAssistant.AppendThoughtSignature(reasoning.Signature, reasoning.ToolID)
280 }
281 }
282 if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
283 if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
284 currentAssistant.SetReasoningResponsesData(reasoning)
285 }
286 }
287 currentAssistant.FinishThinking()
288 return a.messages.Update(genCtx, *currentAssistant)
289 },
290 OnTextDelta: func(id string, text string) error {
291 // Strip leading newline from initial text content. This is is
292 // particularly important in non-interactive mode where leading
293 // newlines are very visible.
294 if len(currentAssistant.Parts) == 0 {
295 text = strings.TrimPrefix(text, "\n")
296 }
297
298 currentAssistant.AppendContent(text)
299 return a.messages.Update(genCtx, *currentAssistant)
300 },
301 OnToolInputStart: func(id string, toolName string) error {
302 toolCall := message.ToolCall{
303 ID: id,
304 Name: toolName,
305 ProviderExecuted: false,
306 Finished: false,
307 }
308 currentAssistant.AddToolCall(toolCall)
309 return a.messages.Update(genCtx, *currentAssistant)
310 },
311 OnRetry: func(err *fantasy.ProviderError, delay time.Duration) {
312 // TODO: implement
313 },
314 OnToolCall: func(tc fantasy.ToolCallContent) error {
315 toolCall := message.ToolCall{
316 ID: tc.ToolCallID,
317 Name: tc.ToolName,
318 Input: tc.Input,
319 ProviderExecuted: false,
320 Finished: true,
321 }
322 currentAssistant.AddToolCall(toolCall)
323 return a.messages.Update(genCtx, *currentAssistant)
324 },
325 OnToolResult: func(result fantasy.ToolResultContent) error {
326 var resultContent string
327 isError := false
328 switch result.Result.GetType() {
329 case fantasy.ToolResultContentTypeText:
330 r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result)
331 if ok {
332 resultContent = r.Text
333 }
334 case fantasy.ToolResultContentTypeError:
335 r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result)
336 if ok {
337 isError = true
338 resultContent = r.Error.Error()
339 }
340 case fantasy.ToolResultContentTypeMedia:
341 // TODO: handle this message type
342 }
343 toolResult := message.ToolResult{
344 ToolCallID: result.ToolCallID,
345 Name: result.ToolName,
346 Content: resultContent,
347 IsError: isError,
348 Metadata: result.ClientMetadata,
349 }
350 _, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
351 Role: message.Tool,
352 Parts: []message.ContentPart{
353 toolResult,
354 },
355 })
356 if createMsgErr != nil {
357 return createMsgErr
358 }
359 return nil
360 },
361 OnStepFinish: func(stepResult fantasy.StepResult) error {
362 finishReason := message.FinishReasonUnknown
363 switch stepResult.FinishReason {
364 case fantasy.FinishReasonLength:
365 finishReason = message.FinishReasonMaxTokens
366 case fantasy.FinishReasonStop:
367 finishReason = message.FinishReasonEndTurn
368 case fantasy.FinishReasonToolCalls:
369 finishReason = message.FinishReasonToolUse
370 }
371 currentAssistant.AddFinish(finishReason, "", "")
372 a.updateSessionUsage(a.largeModel, ¤tSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
373 sessionLock.Lock()
374 _, sessionErr := a.sessions.Save(genCtx, currentSession)
375 sessionLock.Unlock()
376 if sessionErr != nil {
377 return sessionErr
378 }
379 return a.messages.Update(genCtx, *currentAssistant)
380 },
381 StopWhen: []fantasy.StopCondition{
382 func(_ []fantasy.StepResult) bool {
383 cw := int64(a.largeModel.CatwalkCfg.ContextWindow)
384 tokens := currentSession.CompletionTokens + currentSession.PromptTokens
385 remaining := cw - tokens
386 var threshold int64
387 if cw > 200_000 {
388 threshold = 20_000
389 } else {
390 threshold = int64(float64(cw) * 0.2)
391 }
392 if (remaining <= threshold) && !a.disableAutoSummarize {
393 shouldSummarize = true
394 return true
395 }
396 return false
397 },
398 },
399 })
400
401 a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))
402
403 if err != nil {
404 isCancelErr := errors.Is(err, context.Canceled)
405 isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
406 if currentAssistant == nil {
407 return result, err
408 }
409 // Ensure we finish thinking on error to close the reasoning state.
410 currentAssistant.FinishThinking()
411 toolCalls := currentAssistant.ToolCalls()
412 // INFO: we use the parent context here because the genCtx has been cancelled.
413 msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
414 if createErr != nil {
415 return nil, createErr
416 }
417 for _, tc := range toolCalls {
418 if !tc.Finished {
419 tc.Finished = true
420 tc.Input = "{}"
421 currentAssistant.AddToolCall(tc)
422 updateErr := a.messages.Update(ctx, *currentAssistant)
423 if updateErr != nil {
424 return nil, updateErr
425 }
426 }
427
428 found := false
429 for _, msg := range msgs {
430 if msg.Role == message.Tool {
431 for _, tr := range msg.ToolResults() {
432 if tr.ToolCallID == tc.ID {
433 found = true
434 break
435 }
436 }
437 }
438 if found {
439 break
440 }
441 }
442 if found {
443 continue
444 }
445 content := "There was an error while executing the tool"
446 if isCancelErr {
447 content = "Tool execution canceled by user"
448 } else if isPermissionErr {
449 content = "User denied permission"
450 }
451 toolResult := message.ToolResult{
452 ToolCallID: tc.ID,
453 Name: tc.Name,
454 Content: content,
455 IsError: true,
456 }
457 _, createErr = a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
458 Role: message.Tool,
459 Parts: []message.ContentPart{
460 toolResult,
461 },
462 })
463 if createErr != nil {
464 return nil, createErr
465 }
466 }
467 var fantasyErr *fantasy.Error
468 var providerErr *fantasy.ProviderError
469 const defaultTitle = "Provider Error"
470 if isCancelErr {
471 currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
472 } else if isPermissionErr {
473 currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "")
474 } else if errors.As(err, &providerErr) {
475 currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message)
476 } else if errors.As(err, &fantasyErr) {
477 currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(fantasyErr.Title), defaultTitle), fantasyErr.Message)
478 } else {
479 currentAssistant.AddFinish(message.FinishReasonError, defaultTitle, err.Error())
480 }
481 // Note: we use the parent context here because the genCtx has been
482 // cancelled.
483 updateErr := a.messages.Update(ctx, *currentAssistant)
484 if updateErr != nil {
485 return nil, updateErr
486 }
487 return nil, err
488 }
489 wg.Wait()
490
491 if shouldSummarize {
492 a.activeRequests.Del(call.SessionID)
493 if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
494 return nil, summarizeErr
495 }
496 // If the agent wasn't done...
497 if len(currentAssistant.ToolCalls()) > 0 {
498 existing, ok := a.messageQueue.Get(call.SessionID)
499 if !ok {
500 existing = []SessionAgentCall{}
501 }
502 call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
503 existing = append(existing, call)
504 a.messageQueue.Set(call.SessionID, existing)
505 }
506 }
507
508 // Release active request before processing queued messages.
509 a.activeRequests.Del(call.SessionID)
510 cancel()
511
512 queuedMessages, ok := a.messageQueue.Get(call.SessionID)
513 if !ok || len(queuedMessages) == 0 {
514 return result, err
515 }
516 // There are queued messages restart the loop.
517 firstQueuedMessage := queuedMessages[0]
518 a.messageQueue.Set(call.SessionID, queuedMessages[1:])
519 return a.Run(ctx, firstQueuedMessage)
520}
521
522func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
523 if a.IsSessionBusy(sessionID) {
524 return ErrSessionBusy
525 }
526
527 currentSession, err := a.sessions.Get(ctx, sessionID)
528 if err != nil {
529 return fmt.Errorf("failed to get session: %w", err)
530 }
531 msgs, err := a.getSessionMessages(ctx, currentSession)
532 if err != nil {
533 return err
534 }
535 if len(msgs) == 0 {
536 // Nothing to summarize.
537 return nil
538 }
539
540 aiMsgs, _ := a.preparePrompt(msgs)
541
542 genCtx, cancel := context.WithCancel(ctx)
543 a.activeRequests.Set(sessionID, cancel)
544 defer a.activeRequests.Del(sessionID)
545 defer cancel()
546
547 agent := fantasy.NewAgent(a.largeModel.Model,
548 fantasy.WithSystemPrompt(string(summaryPrompt)),
549 )
550 summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
551 Role: message.Assistant,
552 Model: a.largeModel.Model.Model(),
553 Provider: a.largeModel.Model.Provider(),
554 IsSummaryMessage: true,
555 })
556 if err != nil {
557 return err
558 }
559
560 resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
561 Prompt: "Provide a detailed summary of our conversation above.",
562 Messages: aiMsgs,
563 ProviderOptions: opts,
564 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
565 prepared.Messages = options.Messages
566 if a.systemPromptPrefix != "" {
567 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
568 }
569 return callContext, prepared, nil
570 },
571 OnReasoningDelta: func(id string, text string) error {
572 summaryMessage.AppendReasoningContent(text)
573 return a.messages.Update(genCtx, summaryMessage)
574 },
575 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
576 // Handle anthropic signature.
577 if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
578 if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
579 summaryMessage.AppendReasoningSignature(signature.Signature)
580 }
581 }
582 summaryMessage.FinishThinking()
583 return a.messages.Update(genCtx, summaryMessage)
584 },
585 OnTextDelta: func(id, text string) error {
586 summaryMessage.AppendContent(text)
587 return a.messages.Update(genCtx, summaryMessage)
588 },
589 })
590 if err != nil {
591 isCancelErr := errors.Is(err, context.Canceled)
592 if isCancelErr {
593 // User cancelled summarize we need to remove the summary message.
594 deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
595 return deleteErr
596 }
597 return err
598 }
599
600 summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
601 err = a.messages.Update(genCtx, summaryMessage)
602 if err != nil {
603 return err
604 }
605
606 var openrouterCost *float64
607 for _, step := range resp.Steps {
608 stepCost := a.openrouterCost(step.ProviderMetadata)
609 if stepCost != nil {
610 newCost := *stepCost
611 if openrouterCost != nil {
612 newCost += *openrouterCost
613 }
614 openrouterCost = &newCost
615 }
616 }
617
618 a.updateSessionUsage(a.largeModel, ¤tSession, resp.TotalUsage, openrouterCost)
619
620 // Just in case, get just the last usage info.
621 usage := resp.Response.Usage
622 currentSession.SummaryMessageID = summaryMessage.ID
623 currentSession.CompletionTokens = usage.OutputTokens
624 currentSession.PromptTokens = 0
625 _, err = a.sessions.Save(genCtx, currentSession)
626 return err
627}
628
629func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
630 if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
631 return fantasy.ProviderOptions{}
632 }
633 return fantasy.ProviderOptions{
634 anthropic.Name: &anthropic.ProviderCacheControlOptions{
635 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
636 },
637 bedrock.Name: &anthropic.ProviderCacheControlOptions{
638 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
639 },
640 }
641}
642
643func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
644 var attachmentParts []message.ContentPart
645 for _, attachment := range call.Attachments {
646 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
647 }
648 parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
649 parts = append(parts, attachmentParts...)
650 msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
651 Role: message.User,
652 Parts: parts,
653 })
654 if err != nil {
655 return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
656 }
657 return msg, nil
658}
659
660func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
661 var history []fantasy.Message
662 for _, m := range msgs {
663 if len(m.Parts) == 0 {
664 continue
665 }
666 // Assistant message without content or tool calls (cancelled before it
667 // returned anything).
668 if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
669 continue
670 }
671 history = append(history, m.ToAIMessage()...)
672 }
673
674 var files []fantasy.FilePart
675 for _, attachment := range attachments {
676 files = append(files, fantasy.FilePart{
677 Filename: attachment.FileName,
678 Data: attachment.Content,
679 MediaType: attachment.MimeType,
680 })
681 }
682
683 return history, files
684}
685
686func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
687 msgs, err := a.messages.List(ctx, session.ID)
688 if err != nil {
689 return nil, fmt.Errorf("failed to list messages: %w", err)
690 }
691
692 if session.SummaryMessageID != "" {
693 summaryMsgInex := -1
694 for i, msg := range msgs {
695 if msg.ID == session.SummaryMessageID {
696 summaryMsgInex = i
697 break
698 }
699 }
700 if summaryMsgInex != -1 {
701 msgs = msgs[summaryMsgInex:]
702 msgs[0].Role = message.User
703 }
704 }
705 return msgs, nil
706}
707
708func (a *sessionAgent) generateTitle(ctx context.Context, session *session.Session, prompt string) {
709 if prompt == "" {
710 return
711 }
712
713 var maxOutput int64 = 40
714 if a.smallModel.CatwalkCfg.CanReason {
715 maxOutput = a.smallModel.CatwalkCfg.DefaultMaxTokens
716 }
717
718 agent := fantasy.NewAgent(a.smallModel.Model,
719 fantasy.WithSystemPrompt(string(titlePrompt)+"\n /no_think"),
720 fantasy.WithMaxOutputTokens(maxOutput),
721 )
722
723 resp, err := agent.Stream(ctx, fantasy.AgentStreamCall{
724 Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", prompt),
725 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
726 prepared.Messages = options.Messages
727 if a.systemPromptPrefix != "" {
728 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
729 }
730 return callContext, prepared, nil
731 },
732 })
733 if err != nil {
734 slog.Error("error generating title", "err", err)
735 return
736 }
737
738 title := resp.Response.Content.Text()
739
740 title = strings.ReplaceAll(title, "\n", " ")
741
742 // Remove thinking tags if present.
743 if idx := strings.Index(title, "</think>"); idx > 0 {
744 title = title[idx+len("</think>"):]
745 }
746
747 title = strings.TrimSpace(title)
748 if title == "" {
749 slog.Warn("failed to generate title", "warn", "empty title")
750 return
751 }
752
753 session.Title = title
754
755 var openrouterCost *float64
756 for _, step := range resp.Steps {
757 stepCost := a.openrouterCost(step.ProviderMetadata)
758 if stepCost != nil {
759 newCost := *stepCost
760 if openrouterCost != nil {
761 newCost += *openrouterCost
762 }
763 openrouterCost = &newCost
764 }
765 }
766
767 a.updateSessionUsage(a.smallModel, session, resp.TotalUsage, openrouterCost)
768 _, saveErr := a.sessions.Save(ctx, *session)
769 if saveErr != nil {
770 slog.Error("failed to save session title & usage", "error", saveErr)
771 return
772 }
773}
774
775func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
776 openrouterMetadata, ok := metadata[openrouter.Name]
777 if !ok {
778 return nil
779 }
780
781 opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
782 if !ok {
783 return nil
784 }
785 return &opts.Usage.Cost
786}
787
788func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
789 modelConfig := model.CatwalkCfg
790 cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
791 modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
792 modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
793 modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
794
795 if a.isClaudeCode() {
796 cost = 0
797 }
798
799 a.eventTokensUsed(session.ID, model, usage, cost)
800
801 if overrideCost != nil {
802 session.Cost += *overrideCost
803 } else {
804 session.Cost += cost
805 }
806
807 session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
808 session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
809}
810
811func (a *sessionAgent) Cancel(sessionID string) {
812 // Cancel regular requests.
813 if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
814 slog.Info("Request cancellation initiated", "session_id", sessionID)
815 cancel()
816 }
817
818 // Also check for summarize requests.
819 if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
820 slog.Info("Summarize cancellation initiated", "session_id", sessionID)
821 cancel()
822 }
823
824 if a.QueuedPrompts(sessionID) > 0 {
825 slog.Info("Clearing queued prompts", "session_id", sessionID)
826 a.messageQueue.Del(sessionID)
827 }
828}
829
830func (a *sessionAgent) ClearQueue(sessionID string) {
831 if a.QueuedPrompts(sessionID) > 0 {
832 slog.Info("Clearing queued prompts", "session_id", sessionID)
833 a.messageQueue.Del(sessionID)
834 }
835}
836
837func (a *sessionAgent) CancelAll() {
838 if !a.IsBusy() {
839 return
840 }
841 for key := range a.activeRequests.Seq2() {
842 a.Cancel(key) // key is sessionID
843 }
844
845 timeout := time.After(5 * time.Second)
846 for a.IsBusy() {
847 select {
848 case <-timeout:
849 return
850 default:
851 time.Sleep(200 * time.Millisecond)
852 }
853 }
854}
855
856func (a *sessionAgent) IsBusy() bool {
857 var busy bool
858 for cancelFunc := range a.activeRequests.Seq() {
859 if cancelFunc != nil {
860 busy = true
861 break
862 }
863 }
864 return busy
865}
866
867func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
868 _, busy := a.activeRequests.Get(sessionID)
869 return busy
870}
871
872func (a *sessionAgent) QueuedPrompts(sessionID string) int {
873 l, ok := a.messageQueue.Get(sessionID)
874 if !ok {
875 return 0
876 }
877 return len(l)
878}
879
880func (a *sessionAgent) SetModels(large Model, small Model) {
881 a.largeModel = large
882 a.smallModel = small
883}
884
// SetTools replaces the agent's tool set. The slice is stored as given, not
// copied. A running request picks up the new tools at its next step.
func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
	a.tools = tools
}
888
// Model returns the agent's large model.
func (a *sessionAgent) Model() Model {
	return a.largeModel
}
892
893func (a *sessionAgent) promptPrefix() string {
894 if a.isClaudeCode() {
895 return "You are Claude Code, Anthropic's official CLI for Claude."
896 }
897 return a.systemPromptPrefix
898}
899
900func (a *sessionAgent) isClaudeCode() bool {
901 cfg := config.Get()
902 pc, ok := cfg.Providers.Get(a.largeModel.ModelCfg.Provider)
903 return ok && pc.ID == string(catwalk.InferenceProviderAnthropic) && pc.OAuthToken != nil
904}