// Package agent is the core orchestration layer for Crush AI agents.
//
// It provides session-based AI agent functionality for managing
// conversations, tool execution, and message handling. It coordinates
// interactions between language models, messages, sessions, and tools while
// handling features like automatic summarization, queuing, and token
// management.
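//
// A minimal usage sketch (illustrative only; the surrounding application is
// responsible for wiring up the models, services, and tools shown here as
// placeholder variables):
//
//	agent := NewSessionAgent(SessionAgentOptions{
//		LargeModel: largeModel,
//		SmallModel: smallModel,
//		Sessions:   sessions,
//		Messages:   messages,
//		Tools:      agentTools,
//	})
//	result, err := agent.Run(ctx, SessionAgentCall{
//		SessionID: sess.ID,
//		Prompt:    "Explain the build failure",
//	})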
package agent

import (
	"cmp"
	"context"
	_ "embed"
	"errors"
	"fmt"
	"log/slog"
	"os"
	"strconv"
	"strings"
	"sync"
	"time"

	"charm.land/fantasy"
	"charm.land/fantasy/providers/anthropic"
	"charm.land/fantasy/providers/bedrock"
	"charm.land/fantasy/providers/google"
	"charm.land/fantasy/providers/openai"
	"charm.land/fantasy/providers/openrouter"
	"github.com/charmbracelet/catwalk/pkg/catwalk"
	"github.com/charmbracelet/crush/internal/agent/tools"
	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/csync"
	"github.com/charmbracelet/crush/internal/message"
	"github.com/charmbracelet/crush/internal/permission"
	"github.com/charmbracelet/crush/internal/session"
	"github.com/charmbracelet/crush/internal/stringext"
)

//go:embed templates/title.md
var titlePrompt []byte

//go:embed templates/summary.md
var summaryPrompt []byte

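// SessionAgentCall describes a single prompt sent to a session: the target
// session, the user prompt, optional attachments, and per-call sampling and
// provider options.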
type SessionAgentCall struct {
	SessionID        string
	Prompt           string
	ProviderOptions  fantasy.ProviderOptions
	Attachments      []message.Attachment
	MaxOutputTokens  int64
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
}

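// SessionAgent runs prompts against a session, queueing calls while the
// session is busy, and exposes cancellation, queue inspection, and
// summarization on top of the configured models and tools.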
type SessionAgent interface {
	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
	SetModels(large Model, small Model)
	SetTools(tools []fantasy.AgentTool)
	Cancel(sessionID string)
	CancelAll()
	IsSessionBusy(sessionID string) bool
	IsBusy() bool
	QueuedPrompts(sessionID string) int
	ClearQueue(sessionID string)
	Summarize(context.Context, string, fantasy.ProviderOptions) error
	Model() Model
}

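// Model bundles a fantasy language model with its catwalk metadata and the
// user-selected model configuration.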
type Model struct {
	Model      fantasy.LanguageModel
	CatwalkCfg catwalk.Model
	ModelCfg   config.SelectedModel
}

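// sessionAgent is the default SessionAgent implementation. It tracks one
// active request per session and queues additional prompts until the active
// request finishes.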
type sessionAgent struct {
	largeModel           Model
	smallModel           Model
	systemPromptPrefix   string
	systemPrompt         string
	tools                []fantasy.AgentTool
	sessions             session.Service
	messages             message.Service
	disableAutoSummarize bool
	isYolo               bool

	messageQueue   *csync.Map[string, []SessionAgentCall]
	activeRequests *csync.Map[string, context.CancelFunc]
}

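// SessionAgentOptions configures NewSessionAgent.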
type SessionAgentOptions struct {
	LargeModel           Model
	SmallModel           Model
	SystemPromptPrefix   string
	SystemPrompt         string
	DisableAutoSummarize bool
	IsYolo               bool
	Sessions             session.Service
	Messages             message.Service
	Tools                []fantasy.AgentTool
}

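// NewSessionAgent creates a SessionAgent from the given options with empty
// message-queue and active-request maps.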
func NewSessionAgent(
	opts SessionAgentOptions,
) SessionAgent {
	return &sessionAgent{
		largeModel:           opts.LargeModel,
		smallModel:           opts.SmallModel,
		systemPromptPrefix:   opts.SystemPromptPrefix,
		systemPrompt:         opts.SystemPrompt,
		sessions:             opts.Sessions,
		messages:             opts.Messages,
		disableAutoSummarize: opts.DisableAutoSummarize,
		tools:                opts.Tools,
		isYolo:               opts.IsYolo,
		messageQueue:         csync.NewMap[string, []SessionAgentCall](),
		activeRequests:       csync.NewMap[string, context.CancelFunc](),
	}
}

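// Run executes a prompt against a session. If the session already has an
// active request the call is queued and (nil, nil) is returned. Otherwise it
// records the user message, streams the agent run (creating assistant and
// tool messages as steps complete), triggers title generation on the first
// message, and optionally summarizes and re-queues work when the context
// window runs low. Queued prompts are drained by recursive calls to Run once
// the current request finishes.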
func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
	if call.Prompt == "" {
		return nil, ErrEmptyPrompt
	}
	if call.SessionID == "" {
		return nil, ErrSessionMissing
	}

	// Queue the message if busy
	if a.IsSessionBusy(call.SessionID) {
		existing, ok := a.messageQueue.Get(call.SessionID)
		if !ok {
			existing = []SessionAgentCall{}
		}
		existing = append(existing, call)
		a.messageQueue.Set(call.SessionID, existing)
		return nil, nil
	}

	if len(a.tools) > 0 {
		// Add Anthropic caching to the last tool.
		a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
	}

	agent := fantasy.NewAgent(
		a.largeModel.Model,
		fantasy.WithSystemPrompt(a.systemPrompt),
		fantasy.WithTools(a.tools...),
	)

	sessionLock := sync.Mutex{}
	currentSession, err := a.sessions.Get(ctx, call.SessionID)
	if err != nil {
		return nil, fmt.Errorf("failed to get session: %w", err)
	}

	msgs, err := a.getSessionMessages(ctx, currentSession)
	if err != nil {
		return nil, fmt.Errorf("failed to get session messages: %w", err)
	}

	var wg sync.WaitGroup
	// Generate title if first message.
	if len(msgs) == 0 {
		wg.Go(func() {
			sessionLock.Lock()
			a.generateTitle(ctx, &currentSession, call.Prompt)
			sessionLock.Unlock()
		})
	}

	// Add the user message to the session.
	_, err = a.createUserMessage(ctx, call)
	if err != nil {
		return nil, err
	}

	// Add the session to the context.
	ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)

	genCtx, cancel := context.WithCancel(ctx)
	a.activeRequests.Set(call.SessionID, cancel)

	defer cancel()
	defer a.activeRequests.Del(call.SessionID)

	history, files := a.preparePrompt(msgs, call.Attachments...)

	startTime := time.Now()
	a.eventPromptSent(call.SessionID)

	var currentAssistant *message.Message
	var shouldSummarize bool
	result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
		Prompt:           call.Prompt,
		Files:            files,
		Messages:         history,
		ProviderOptions:  call.ProviderOptions,
		MaxOutputTokens:  &call.MaxOutputTokens,
		TopP:             call.TopP,
		Temperature:      call.Temperature,
		PresencePenalty:  call.PresencePenalty,
		TopK:             call.TopK,
		FrequencyPenalty: call.FrequencyPenalty,
		// Before each step create a new assistant message.
		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
			prepared.Messages = options.Messages
			// Reset all cached items.
			for i := range prepared.Messages {
				prepared.Messages[i].ProviderOptions = nil
			}

			queuedCalls, _ := a.messageQueue.Get(call.SessionID)
			a.messageQueue.Del(call.SessionID)
			for _, queued := range queuedCalls {
				userMessage, createErr := a.createUserMessage(callContext, queued)
				if createErr != nil {
					return callContext, prepared, createErr
				}
				prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
			}

			lastSystemRoleIdx := 0
			systemMessageUpdated := false
			for i, msg := range prepared.Messages {
				// Add cache control only to the last system message.
				if msg.Role == fantasy.MessageRoleSystem {
					lastSystemRoleIdx = i
				} else if !systemMessageUpdated {
					prepared.Messages[lastSystemRoleIdx].ProviderOptions = a.getCacheControlOptions()
					systemMessageUpdated = true
				}
				// Then add cache control to the last two messages.
				if i > len(prepared.Messages)-3 {
					prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
				}
			}

			if a.systemPromptPrefix != "" {
				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
			}

			var assistantMsg message.Message
			assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
				Role:     message.Assistant,
				Parts:    []message.ContentPart{},
				Model:    a.largeModel.ModelCfg.Model,
				Provider: a.largeModel.ModelCfg.Provider,
			})
			if err != nil {
				return callContext, prepared, err
			}
			callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
			currentAssistant = &assistantMsg
			return callContext, prepared, err
		},
		OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
			currentAssistant.AppendReasoningContent(reasoning.Text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnReasoningDelta: func(id string, text string) error {
			currentAssistant.AppendReasoningContent(text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
			// handle anthropic signature
			if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
				if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
					currentAssistant.AppendReasoningSignature(reasoning.Signature)
				}
			}
			if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
				if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
					currentAssistant.AppendThoughtSignature(reasoning.Signature, reasoning.ToolID)
				}
			}
			if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
				if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
					currentAssistant.SetReasoningResponsesData(reasoning)
				}
			}
			currentAssistant.FinishThinking()
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnTextDelta: func(id string, text string) error {
			// Strip leading newline from initial text content. This is
			// particularly important in non-interactive mode where leading
			// newlines are very visible.
			if len(currentAssistant.Parts) == 0 {
				text = strings.TrimPrefix(text, "\n")
			}

			currentAssistant.AppendContent(text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnToolInputStart: func(id string, toolName string) error {
			toolCall := message.ToolCall{
				ID:               id,
				Name:             toolName,
				ProviderExecuted: false,
				Finished:         false,
			}
			currentAssistant.AddToolCall(toolCall)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnRetry: func(err *fantasy.ProviderError, delay time.Duration) {
			// TODO: implement
		},
		OnToolCall: func(tc fantasy.ToolCallContent) error {
			toolCall := message.ToolCall{
				ID:               tc.ToolCallID,
				Name:             tc.ToolName,
				Input:            tc.Input,
				ProviderExecuted: false,
				Finished:         true,
			}
			currentAssistant.AddToolCall(toolCall)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnToolResult: func(result fantasy.ToolResultContent) error {
			var resultContent string
			isError := false
			switch result.Result.GetType() {
			case fantasy.ToolResultContentTypeText:
				r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result)
				if ok {
					resultContent = r.Text
				}
			case fantasy.ToolResultContentTypeError:
				r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result)
				if ok {
					isError = true
					resultContent = r.Error.Error()
				}
			case fantasy.ToolResultContentTypeMedia:
				// TODO: handle this message type
			}
			toolResult := message.ToolResult{
				ToolCallID: result.ToolCallID,
				Name:       result.ToolName,
				Content:    resultContent,
				IsError:    isError,
				Metadata:   result.ClientMetadata,
			}
			_, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
				Role: message.Tool,
				Parts: []message.ContentPart{
					toolResult,
				},
			})
			if createMsgErr != nil {
				return createMsgErr
			}
			return nil
		},
		OnStepFinish: func(stepResult fantasy.StepResult) error {
			finishReason := message.FinishReasonUnknown
			switch stepResult.FinishReason {
			case fantasy.FinishReasonLength:
				finishReason = message.FinishReasonMaxTokens
			case fantasy.FinishReasonStop:
				finishReason = message.FinishReasonEndTurn
			case fantasy.FinishReasonToolCalls:
				finishReason = message.FinishReasonToolUse
			}
			currentAssistant.AddFinish(finishReason, "", "")
			a.updateSessionUsage(a.largeModel, &currentSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
			sessionLock.Lock()
			_, sessionErr := a.sessions.Save(genCtx, currentSession)
			sessionLock.Unlock()
			if sessionErr != nil {
				return sessionErr
			}
			return a.messages.Update(genCtx, *currentAssistant)
		},
		StopWhen: []fantasy.StopCondition{
			func(_ []fantasy.StepResult) bool {
				cw := int64(a.largeModel.CatwalkCfg.ContextWindow)
				tokens := currentSession.CompletionTokens + currentSession.PromptTokens
				remaining := cw - tokens
				var threshold int64
				if cw > 200_000 {
					threshold = 20_000
				} else {
					threshold = int64(float64(cw) * 0.2)
				}
				if (remaining <= threshold) && !a.disableAutoSummarize {
					shouldSummarize = true
					return true
				}
				return false
			},
		},
	})

	a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))

	if err != nil {
		isCancelErr := errors.Is(err, context.Canceled)
		isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
		if currentAssistant == nil {
			return result, err
		}
		// Ensure we finish thinking on error to close the reasoning state.
		currentAssistant.FinishThinking()
		toolCalls := currentAssistant.ToolCalls()
		// INFO: we use the parent context here because the genCtx has been cancelled.
		msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
		if createErr != nil {
			return nil, createErr
		}
		for _, tc := range toolCalls {
			if !tc.Finished {
				tc.Finished = true
				tc.Input = "{}"
				currentAssistant.AddToolCall(tc)
				updateErr := a.messages.Update(ctx, *currentAssistant)
				if updateErr != nil {
					return nil, updateErr
				}
			}

			found := false
			for _, msg := range msgs {
				if msg.Role == message.Tool {
					for _, tr := range msg.ToolResults() {
						if tr.ToolCallID == tc.ID {
							found = true
							break
						}
					}
				}
				if found {
					break
				}
			}
			if found {
				continue
			}
			content := "There was an error while executing the tool"
			if isCancelErr {
				content = "Tool execution canceled by user"
			} else if isPermissionErr {
				content = "User denied permission"
			}
			toolResult := message.ToolResult{
				ToolCallID: tc.ID,
				Name:       tc.Name,
				Content:    content,
				IsError:    true,
			}
			_, createErr = a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
				Role: message.Tool,
				Parts: []message.ContentPart{
					toolResult,
				},
			})
			if createErr != nil {
				return nil, createErr
			}
		}
		var fantasyErr *fantasy.Error
		var providerErr *fantasy.ProviderError
		const defaultTitle = "Provider Error"
		if isCancelErr {
			currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
		} else if isPermissionErr {
			currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "")
		} else if errors.As(err, &providerErr) {
			currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message)
		} else if errors.As(err, &fantasyErr) {
			currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(fantasyErr.Title), defaultTitle), fantasyErr.Message)
		} else {
			currentAssistant.AddFinish(message.FinishReasonError, defaultTitle, err.Error())
		}
		// Note: we use the parent context here because the genCtx has been
		// cancelled.
		updateErr := a.messages.Update(ctx, *currentAssistant)
		if updateErr != nil {
			return nil, updateErr
		}
		return nil, err
	}
	wg.Wait()

	if shouldSummarize {
		a.activeRequests.Del(call.SessionID)
		if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
			return nil, summarizeErr
		}
		// If the agent wasn't done (the last step issued tool calls), re-queue
		// the request so it can continue after summarization.
		if len(currentAssistant.ToolCalls()) > 0 {
			existing, ok := a.messageQueue.Get(call.SessionID)
			if !ok {
				existing = []SessionAgentCall{}
			}
			call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
			existing = append(existing, call)
			a.messageQueue.Set(call.SessionID, existing)
		}
	}

	// Release active request before processing queued messages.
	a.activeRequests.Del(call.SessionID)
	cancel()

	queuedMessages, ok := a.messageQueue.Get(call.SessionID)
	if !ok || len(queuedMessages) == 0 {
		return result, err
	}
	// There are queued messages; restart the loop.
	firstQueuedMessage := queuedMessages[0]
	a.messageQueue.Set(call.SessionID, queuedMessages[1:])
	return a.Run(ctx, firstQueuedMessage)
}

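// Summarize generates a summary of the session's conversation so far using
// the large model, stores it as a summary message, and resets the session's
// token counts so subsequent requests start from the summary. It returns
// ErrSessionBusy if the session already has an active request.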
func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
	if a.IsSessionBusy(sessionID) {
		return ErrSessionBusy
	}

	currentSession, err := a.sessions.Get(ctx, sessionID)
	if err != nil {
		return fmt.Errorf("failed to get session: %w", err)
	}
	msgs, err := a.getSessionMessages(ctx, currentSession)
	if err != nil {
		return err
	}
	if len(msgs) == 0 {
		// Nothing to summarize.
		return nil
	}

	aiMsgs, _ := a.preparePrompt(msgs)

	genCtx, cancel := context.WithCancel(ctx)
	a.activeRequests.Set(sessionID, cancel)
	defer a.activeRequests.Del(sessionID)
	defer cancel()

	agent := fantasy.NewAgent(a.largeModel.Model,
		fantasy.WithSystemPrompt(string(summaryPrompt)),
	)
	summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
		Role:             message.Assistant,
		Model:            a.largeModel.Model.Model(),
		Provider:         a.largeModel.Model.Provider(),
		IsSummaryMessage: true,
	})
	if err != nil {
		return err
	}

	resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
		Prompt:          "Provide a detailed summary of our conversation above.",
		Messages:        aiMsgs,
		ProviderOptions: opts,
		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
			prepared.Messages = options.Messages
			if a.systemPromptPrefix != "" {
				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
			}
			return callContext, prepared, nil
		},
		OnReasoningDelta: func(id string, text string) error {
			summaryMessage.AppendReasoningContent(text)
			return a.messages.Update(genCtx, summaryMessage)
		},
		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
			// Handle anthropic signature.
			if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
				if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
					summaryMessage.AppendReasoningSignature(signature.Signature)
				}
			}
			summaryMessage.FinishThinking()
			return a.messages.Update(genCtx, summaryMessage)
		},
		OnTextDelta: func(id, text string) error {
			summaryMessage.AppendContent(text)
			return a.messages.Update(genCtx, summaryMessage)
		},
	})
	if err != nil {
		isCancelErr := errors.Is(err, context.Canceled)
		if isCancelErr {
			// The user cancelled the summarize request; remove the summary message.
			deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
			return deleteErr
		}
		return err
	}

	summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
	err = a.messages.Update(genCtx, summaryMessage)
	if err != nil {
		return err
	}

	var openrouterCost *float64
	for _, step := range resp.Steps {
		stepCost := a.openrouterCost(step.ProviderMetadata)
		if stepCost != nil {
			newCost := *stepCost
			if openrouterCost != nil {
				newCost += *openrouterCost
			}
			openrouterCost = &newCost
		}
	}

	a.updateSessionUsage(a.largeModel, &currentSession, resp.TotalUsage, openrouterCost)

	// Use only the final response's usage: the summary output becomes the new
	// completion count and the prompt count is reset.
	usage := resp.Response.Usage
	currentSession.SummaryMessageID = summaryMessage.ID
	currentSession.CompletionTokens = usage.OutputTokens
	currentSession.PromptTokens = 0
	_, err = a.sessions.Save(genCtx, currentSession)
	return err
}

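// getCacheControlOptions returns ephemeral cache-control provider options for
// Anthropic and Bedrock, or empty options when caching is disabled via the
// CRUSH_DISABLE_ANTHROPIC_CACHE environment variable.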
func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
	if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
		return fantasy.ProviderOptions{}
	}
	return fantasy.ProviderOptions{
		anthropic.Name: &anthropic.ProviderCacheControlOptions{
			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
		},
		bedrock.Name: &anthropic.ProviderCacheControlOptions{
			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
		},
	}
}

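// createUserMessage stores the prompt and any attachments as a user message
// in the session.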
func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
	var attachmentParts []message.ContentPart
	for _, attachment := range call.Attachments {
		attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
	}
	parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
	parts = append(parts, attachmentParts...)
	msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
		Role:  message.User,
		Parts: parts,
	})
	if err != nil {
		return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
	}
	return msg, nil
}

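// preparePrompt converts stored messages into the model history, dropping
// empty messages and assistant messages that were cancelled before producing
// any output, and converts attachments into file parts.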
func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
	var history []fantasy.Message
	for _, m := range msgs {
		if len(m.Parts) == 0 {
			continue
		}
		// Assistant message without content or tool calls (cancelled before it
		// returned anything).
		if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
			continue
		}
		history = append(history, m.ToAIMessage()...)
	}

	var files []fantasy.FilePart
	for _, attachment := range attachments {
		files = append(files, fantasy.FilePart{
			Filename:  attachment.FileName,
			Data:      attachment.Content,
			MediaType: attachment.MimeType,
		})
	}

	return history, files
}

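// getSessionMessages lists the session's messages. If the session has been
// summarized, only the summary message and everything after it are returned,
// with the summary re-cast as a user message so it seeds the next request.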
func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
	msgs, err := a.messages.List(ctx, session.ID)
	if err != nil {
		return nil, fmt.Errorf("failed to list messages: %w", err)
	}

	if session.SummaryMessageID != "" {
		summaryMsgIndex := -1
		for i, msg := range msgs {
			if msg.ID == session.SummaryMessageID {
				summaryMsgIndex = i
				break
			}
		}
		if summaryMsgIndex != -1 {
			msgs = msgs[summaryMsgIndex:]
			msgs[0].Role = message.User
		}
	}
	return msgs, nil
}

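// generateTitle asks the small model for a short session title based on the
// first prompt and saves it, along with the usage it incurred, on the
// session. Failures are logged rather than returned.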
func (a *sessionAgent) generateTitle(ctx context.Context, session *session.Session, prompt string) {
	if prompt == "" {
		return
	}

	var maxOutput int64 = 40
	if a.smallModel.CatwalkCfg.CanReason {
		maxOutput = a.smallModel.CatwalkCfg.DefaultMaxTokens
	}

	agent := fantasy.NewAgent(a.smallModel.Model,
		fantasy.WithSystemPrompt(string(titlePrompt)+"\n /no_think"),
		fantasy.WithMaxOutputTokens(maxOutput),
	)

	resp, err := agent.Stream(ctx, fantasy.AgentStreamCall{
		Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", prompt),
		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
			prepared.Messages = options.Messages
			if a.systemPromptPrefix != "" {
				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
			}
			return callContext, prepared, nil
		},
	})
	if err != nil {
		slog.Error("error generating title", "err", err)
		return
	}

	title := resp.Response.Content.Text()

	title = strings.ReplaceAll(title, "\n", " ")

	// Remove thinking tags if present.
	if idx := strings.Index(title, "</think>"); idx > 0 {
		title = title[idx+len("</think>"):]
	}

	title = strings.TrimSpace(title)
	if title == "" {
		slog.Warn("failed to generate title", "warn", "empty title")
		return
	}

	session.Title = title

	var openrouterCost *float64
	for _, step := range resp.Steps {
		stepCost := a.openrouterCost(step.ProviderMetadata)
		if stepCost != nil {
			newCost := *stepCost
			if openrouterCost != nil {
				newCost += *openrouterCost
			}
			openrouterCost = &newCost
		}
	}

	a.updateSessionUsage(a.smallModel, session, resp.TotalUsage, openrouterCost)
	_, saveErr := a.sessions.Save(ctx, *session)
	if saveErr != nil {
		slog.Error("failed to save session title & usage", "error", saveErr)
		return
	}
}

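// openrouterCost extracts the request cost reported by OpenRouter from the
// provider metadata, or returns nil when the metadata is absent.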
func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
	openrouterMetadata, ok := metadata[openrouter.Name]
	if !ok {
		return nil
	}

	opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
	if !ok {
		return nil
	}
	return &opts.Usage.Cost
}

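// updateSessionUsage computes the request cost from the model's per-million
// token prices (cache writes, cache reads, input, and output), emits a usage
// event, and updates the session's cost and token counts. When overrideCost
// is provided (e.g. an OpenRouter-reported cost) it is used instead of the
// computed value.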
func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
	modelConfig := model.CatwalkCfg
	cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
		modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
		modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
		modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)

	a.eventTokensUsed(session.ID, model, usage, cost)

	if overrideCost != nil {
		session.Cost += *overrideCost
	} else {
		session.Cost += cost
	}

	session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
	session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
}

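// Cancel aborts the session's active request (and any in-flight summarize
// request) and drops its queued prompts.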
func (a *sessionAgent) Cancel(sessionID string) {
	// Cancel regular requests.
	if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
		slog.Info("Request cancellation initiated", "session_id", sessionID)
		cancel()
	}

	// Also check for summarize requests.
	if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
		slog.Info("Summarize cancellation initiated", "session_id", sessionID)
		cancel()
	}

	if a.QueuedPrompts(sessionID) > 0 {
		slog.Info("Clearing queued prompts", "session_id", sessionID)
		a.messageQueue.Del(sessionID)
	}
}

func (a *sessionAgent) ClearQueue(sessionID string) {
	if a.QueuedPrompts(sessionID) > 0 {
		slog.Info("Clearing queued prompts", "session_id", sessionID)
		a.messageQueue.Del(sessionID)
	}
}

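// CancelAll cancels every active request and waits up to five seconds for
// them to finish.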
func (a *sessionAgent) CancelAll() {
	if !a.IsBusy() {
		return
	}
	for key := range a.activeRequests.Seq2() {
		a.Cancel(key) // key is sessionID
	}

	timeout := time.After(5 * time.Second)
	for a.IsBusy() {
		select {
		case <-timeout:
			return
		default:
			time.Sleep(200 * time.Millisecond)
		}
	}
}

func (a *sessionAgent) IsBusy() bool {
	var busy bool
	for cancelFunc := range a.activeRequests.Seq() {
		if cancelFunc != nil {
			busy = true
			break
		}
	}
	return busy
}

func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
	_, busy := a.activeRequests.Get(sessionID)
	return busy
}

func (a *sessionAgent) QueuedPrompts(sessionID string) int {
	l, ok := a.messageQueue.Get(sessionID)
	if !ok {
		return 0
	}
	return len(l)
}

func (a *sessionAgent) SetModels(large Model, small Model) {
	a.largeModel = large
	a.smallModel = small
}

func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
	a.tools = tools
}

func (a *sessionAgent) Model() Model {
	return a.largeModel
}