1// Package agent is the core orchestration layer for Crush AI agents.
2//
3// It provides session-based AI agent functionality for managing
4// conversations, tool execution, and message handling. It coordinates
5// interactions between language models, messages, sessions, and tools while
6// handling features like automatic summarization, queuing, and token
7// management.
8package agent
9
10import (
11 "cmp"
12 "context"
13 _ "embed"
14 "errors"
15 "fmt"
16 "log/slog"
17 "os"
18 "strconv"
19 "strings"
20 "sync"
21 "time"
22
23 "charm.land/fantasy"
24 "charm.land/fantasy/providers/anthropic"
25 "charm.land/fantasy/providers/bedrock"
26 "charm.land/fantasy/providers/google"
27 "charm.land/fantasy/providers/openai"
28 "charm.land/fantasy/providers/openrouter"
29 "github.com/charmbracelet/catwalk/pkg/catwalk"
30 "github.com/charmbracelet/crush/internal/agent/tools"
31 "github.com/charmbracelet/crush/internal/config"
32 "github.com/charmbracelet/crush/internal/csync"
33 "github.com/charmbracelet/crush/internal/message"
34 "github.com/charmbracelet/crush/internal/notification"
35 "github.com/charmbracelet/crush/internal/permission"
36 "github.com/charmbracelet/crush/internal/session"
37 "github.com/charmbracelet/crush/internal/stringext"
38)
39
// titlePrompt is the embedded template used as the system prompt when
// generating a session title.
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt is the embedded template used as the system prompt when
// summarizing a session.
//
//go:embed templates/summary.md
var summaryPrompt []byte
45
// SessionAgentCall describes a single prompt submitted to a SessionAgent.
type SessionAgentCall struct {
	// SessionID identifies the session the prompt belongs to. Run rejects an
	// empty value with ErrSessionMissing.
	SessionID string
	// Prompt is the user's text. Run rejects an empty value with
	// ErrEmptyPrompt.
	Prompt string
	// ProviderOptions is forwarded unchanged to the model stream call.
	ProviderOptions fantasy.ProviderOptions
	// Attachments are recorded as binary parts of the stored user message.
	Attachments []message.Attachment
	// MaxOutputTokens is forwarded to the model stream call.
	MaxOutputTokens int64
	// Temperature is an optional sampling parameter forwarded as-is.
	Temperature *float64
	// TopP is an optional sampling parameter forwarded as-is.
	TopP *float64
	// TopK is an optional sampling parameter forwarded as-is.
	TopK *int64
	// FrequencyPenalty is an optional sampling parameter forwarded as-is.
	FrequencyPenalty *float64
	// PresencePenalty is an optional sampling parameter forwarded as-is.
	PresencePenalty *float64
}
58
// SessionAgent runs prompts against sessions and exposes control over the
// in-flight and queued work. The per-method notes describe the behavior of
// the sessionAgent implementation in this package.
type SessionAgent interface {
	// Run executes the call in its session, or queues it and returns
	// (nil, nil) when that session already has an active request.
	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
	// SetModels replaces the large and small models.
	SetModels(large Model, small Model)
	// SetTools replaces the tool set offered to the large model.
	SetTools(tools []fantasy.AgentTool)
	// Cancel cancels the active request for the session and drops its queued
	// prompts.
	Cancel(sessionID string)
	// CancelAll cancels every active request and waits briefly for them to
	// wind down.
	CancelAll()
	// IsSessionBusy reports whether the session has an active request.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any session has an active request.
	IsBusy() bool
	// QueuedPrompts returns the number of prompts queued for the session.
	QueuedPrompts(sessionID string) int
	// ClearQueue drops the prompts queued for the session.
	ClearQueue(sessionID string)
	// Summarize replaces the session's working history with a model-written
	// summary message.
	Summarize(context.Context, string, fantasy.ProviderOptions) error
	// Model returns the large model.
	Model() Model
}
72
// Model bundles a language model with the two pieces of configuration the
// agent consults alongside it.
type Model struct {
	// Model is the language model handed to fantasy.NewAgent.
	Model fantasy.LanguageModel
	// CatwalkCfg supplies the context window, pricing, and reasoning
	// capability fields read by this package.
	CatwalkCfg catwalk.Model
	// ModelCfg supplies the model and provider identifiers recorded on
	// assistant messages created by Run.
	ModelCfg config.SelectedModel
}
78
// sessionAgent is the SessionAgent implementation returned by
// NewSessionAgent.
type sessionAgent struct {
	// largeModel handles prompts and summaries.
	largeModel Model
	// smallModel handles title generation.
	smallModel Model
	// systemPromptPrefix, when non-empty, is prepended as an extra system
	// message to every model call.
	systemPromptPrefix string
	// systemPrompt is the system prompt for regular prompts.
	systemPrompt string
	// tools are offered to the large model for regular prompts.
	tools []fantasy.AgentTool
	// sessions loads and saves session records.
	sessions session.Service
	// messages creates, lists, updates, and deletes conversation messages.
	messages message.Service
	// disableAutoSummarize turns off the context-window stop condition in Run.
	disableAutoSummarize bool
	// isYolo is copied from SessionAgentOptions.IsYolo; it is not read in this
	// part of the file. NOTE(review): confirm where it is consumed.
	isYolo bool

	// messageQueue holds, per session ID, the calls that arrived while that
	// session was busy.
	messageQueue *csync.Map[string, []SessionAgentCall]
	// activeRequests holds, per session ID, the cancel function of the
	// request currently running for that session.
	activeRequests *csync.Map[string, context.CancelFunc]
}
93
// SessionAgentOptions carries the dependencies and settings NewSessionAgent
// copies onto the agent it builds.
type SessionAgentOptions struct {
	// LargeModel handles prompts and summaries.
	LargeModel Model
	// SmallModel handles title generation.
	SmallModel Model
	// SystemPromptPrefix, when non-empty, is prepended as an extra system
	// message to every model call.
	SystemPromptPrefix string
	// SystemPrompt is the system prompt for regular prompts.
	SystemPrompt string
	// DisableAutoSummarize turns off automatic summarization when the context
	// window is nearly full.
	DisableAutoSummarize bool
	// IsYolo is stored on the agent; it is not read in this part of the file.
	// NOTE(review): confirm where it is consumed.
	IsYolo bool
	// Sessions loads and saves session records.
	Sessions session.Service
	// Messages stores conversation messages.
	Messages message.Service
	// Tools are offered to the large model for regular prompts.
	Tools []fantasy.AgentTool
}
105
106func NewSessionAgent(
107 opts SessionAgentOptions,
108) SessionAgent {
109 return &sessionAgent{
110 largeModel: opts.LargeModel,
111 smallModel: opts.SmallModel,
112 systemPromptPrefix: opts.SystemPromptPrefix,
113 systemPrompt: opts.SystemPrompt,
114 sessions: opts.Sessions,
115 messages: opts.Messages,
116 disableAutoSummarize: opts.DisableAutoSummarize,
117 tools: opts.Tools,
118 isYolo: opts.IsYolo,
119 messageQueue: csync.NewMap[string, []SessionAgentCall](),
120 activeRequests: csync.NewMap[string, context.CancelFunc](),
121 }
122}
123
124func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
125 if call.Prompt == "" {
126 return nil, ErrEmptyPrompt
127 }
128 if call.SessionID == "" {
129 return nil, ErrSessionMissing
130 }
131
132 // Queue the message if busy
133 if a.IsSessionBusy(call.SessionID) {
134 existing, ok := a.messageQueue.Get(call.SessionID)
135 if !ok {
136 existing = []SessionAgentCall{}
137 }
138 existing = append(existing, call)
139 a.messageQueue.Set(call.SessionID, existing)
140 return nil, nil
141 }
142
143 if len(a.tools) > 0 {
144 // Add Anthropic caching to the last tool.
145 a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
146 }
147
148 agent := fantasy.NewAgent(
149 a.largeModel.Model,
150 fantasy.WithSystemPrompt(a.systemPrompt),
151 fantasy.WithTools(a.tools...),
152 )
153
154 sessionLock := sync.Mutex{}
155 currentSession, err := a.sessions.Get(ctx, call.SessionID)
156 if err != nil {
157 return nil, fmt.Errorf("failed to get session: %w", err)
158 }
159
160 msgs, err := a.getSessionMessages(ctx, currentSession)
161 if err != nil {
162 return nil, fmt.Errorf("failed to get session messages: %w", err)
163 }
164
165 var wg sync.WaitGroup
166 // Generate title if first message.
167 if len(msgs) == 0 {
168 wg.Go(func() {
169 sessionLock.Lock()
170 a.generateTitle(ctx, ¤tSession, call.Prompt)
171 sessionLock.Unlock()
172 })
173 }
174
175 // Add the user message to the session.
176 _, err = a.createUserMessage(ctx, call)
177 if err != nil {
178 return nil, err
179 }
180
181 // Add the session to the context.
182 ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)
183
184 genCtx, cancel := context.WithCancel(ctx)
185 a.activeRequests.Set(call.SessionID, cancel)
186
187 defer cancel()
188 defer a.activeRequests.Del(call.SessionID)
189
190 history, files := a.preparePrompt(msgs, call.Attachments...)
191
192 startTime := time.Now()
193 a.eventPromptSent(call.SessionID)
194
195 var currentAssistant *message.Message
196 var shouldSummarize bool
197 result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
198 Prompt: call.Prompt,
199 Files: files,
200 Messages: history,
201 ProviderOptions: call.ProviderOptions,
202 MaxOutputTokens: &call.MaxOutputTokens,
203 TopP: call.TopP,
204 Temperature: call.Temperature,
205 PresencePenalty: call.PresencePenalty,
206 TopK: call.TopK,
207 FrequencyPenalty: call.FrequencyPenalty,
208 // Before each step create a new assistant message.
209 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
210 prepared.Messages = options.Messages
211 // Reset all cached items.
212 for i := range prepared.Messages {
213 prepared.Messages[i].ProviderOptions = nil
214 }
215
216 queuedCalls, _ := a.messageQueue.Get(call.SessionID)
217 a.messageQueue.Del(call.SessionID)
218 for _, queued := range queuedCalls {
219 userMessage, createErr := a.createUserMessage(callContext, queued)
220 if createErr != nil {
221 return callContext, prepared, createErr
222 }
223 prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
224 }
225
226 lastSystemRoleInx := 0
227 systemMessageUpdated := false
228 for i, msg := range prepared.Messages {
229 // Only add cache control to the last message.
230 if msg.Role == fantasy.MessageRoleSystem {
231 lastSystemRoleInx = i
232 } else if !systemMessageUpdated {
233 prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
234 systemMessageUpdated = true
235 }
236 // Than add cache control to the last 2 messages.
237 if i > len(prepared.Messages)-3 {
238 prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
239 }
240 }
241
242 if a.systemPromptPrefix != "" {
243 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
244 }
245
246 var assistantMsg message.Message
247 assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
248 Role: message.Assistant,
249 Parts: []message.ContentPart{},
250 Model: a.largeModel.ModelCfg.Model,
251 Provider: a.largeModel.ModelCfg.Provider,
252 })
253 if err != nil {
254 return callContext, prepared, err
255 }
256 callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
257 currentAssistant = &assistantMsg
258 return callContext, prepared, err
259 },
260 OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
261 currentAssistant.AppendReasoningContent(reasoning.Text)
262 return a.messages.Update(genCtx, *currentAssistant)
263 },
264 OnReasoningDelta: func(id string, text string) error {
265 currentAssistant.AppendReasoningContent(text)
266 return a.messages.Update(genCtx, *currentAssistant)
267 },
268 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
269 // handle anthropic signature
270 if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
271 if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
272 currentAssistant.AppendReasoningSignature(reasoning.Signature)
273 }
274 }
275 if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
276 if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
277 currentAssistant.AppendThoughtSignature(reasoning.Signature, reasoning.ToolID)
278 }
279 }
280 if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
281 if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
282 currentAssistant.SetReasoningResponsesData(reasoning)
283 }
284 }
285 currentAssistant.FinishThinking()
286 return a.messages.Update(genCtx, *currentAssistant)
287 },
288 OnTextDelta: func(id string, text string) error {
289 // Strip leading newline from initial text content. This is is
290 // particularly important in non-interactive mode where leading
291 // newlines are very visible.
292 if len(currentAssistant.Parts) == 0 {
293 text = strings.TrimPrefix(text, "\n")
294 }
295
296 currentAssistant.AppendContent(text)
297 return a.messages.Update(genCtx, *currentAssistant)
298 },
299 OnToolInputStart: func(id string, toolName string) error {
300 toolCall := message.ToolCall{
301 ID: id,
302 Name: toolName,
303 ProviderExecuted: false,
304 Finished: false,
305 }
306 currentAssistant.AddToolCall(toolCall)
307 return a.messages.Update(genCtx, *currentAssistant)
308 },
309 OnRetry: func(err *fantasy.ProviderError, delay time.Duration) {
310 // TODO: implement
311 },
312 OnToolCall: func(tc fantasy.ToolCallContent) error {
313 toolCall := message.ToolCall{
314 ID: tc.ToolCallID,
315 Name: tc.ToolName,
316 Input: tc.Input,
317 ProviderExecuted: false,
318 Finished: true,
319 }
320 currentAssistant.AddToolCall(toolCall)
321 return a.messages.Update(genCtx, *currentAssistant)
322 },
323 OnToolResult: func(result fantasy.ToolResultContent) error {
324 var resultContent string
325 isError := false
326 switch result.Result.GetType() {
327 case fantasy.ToolResultContentTypeText:
328 r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result)
329 if ok {
330 resultContent = r.Text
331 }
332 case fantasy.ToolResultContentTypeError:
333 r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result)
334 if ok {
335 isError = true
336 resultContent = r.Error.Error()
337 }
338 case fantasy.ToolResultContentTypeMedia:
339 // TODO: handle this message type
340 }
341 toolResult := message.ToolResult{
342 ToolCallID: result.ToolCallID,
343 Name: result.ToolName,
344 Content: resultContent,
345 IsError: isError,
346 Metadata: result.ClientMetadata,
347 }
348 _, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
349 Role: message.Tool,
350 Parts: []message.ContentPart{
351 toolResult,
352 },
353 })
354 if createMsgErr != nil {
355 return createMsgErr
356 }
357 return nil
358 },
359 OnStepFinish: func(stepResult fantasy.StepResult) error {
360 finishReason := message.FinishReasonUnknown
361 switch stepResult.FinishReason {
362 case fantasy.FinishReasonLength:
363 finishReason = message.FinishReasonMaxTokens
364 case fantasy.FinishReasonStop:
365 finishReason = message.FinishReasonEndTurn
366 case fantasy.FinishReasonToolCalls:
367 finishReason = message.FinishReasonToolUse
368 }
369 currentAssistant.AddFinish(finishReason, "", "")
370 a.updateSessionUsage(a.largeModel, ¤tSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
371 sessionLock.Lock()
372 _, sessionErr := a.sessions.Save(genCtx, currentSession)
373 sessionLock.Unlock()
374 if sessionErr != nil {
375 return sessionErr
376 }
377 return a.messages.Update(genCtx, *currentAssistant)
378 },
379 StopWhen: []fantasy.StopCondition{
380 func(_ []fantasy.StepResult) bool {
381 cw := int64(a.largeModel.CatwalkCfg.ContextWindow)
382 tokens := currentSession.CompletionTokens + currentSession.PromptTokens
383 remaining := cw - tokens
384 var threshold int64
385 if cw > 200_000 {
386 threshold = 20_000
387 } else {
388 threshold = int64(float64(cw) * 0.2)
389 }
390 if (remaining <= threshold) && !a.disableAutoSummarize {
391 shouldSummarize = true
392 return true
393 }
394 return false
395 },
396 },
397 })
398
399 a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))
400
401 if err != nil {
402 isCancelErr := errors.Is(err, context.Canceled)
403 isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
404 if currentAssistant == nil {
405 return result, err
406 }
407 // Ensure we finish thinking on error to close the reasoning state.
408 currentAssistant.FinishThinking()
409 toolCalls := currentAssistant.ToolCalls()
410 // INFO: we use the parent context here because the genCtx has been cancelled.
411 msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
412 if createErr != nil {
413 return nil, createErr
414 }
415 for _, tc := range toolCalls {
416 if !tc.Finished {
417 tc.Finished = true
418 tc.Input = "{}"
419 currentAssistant.AddToolCall(tc)
420 updateErr := a.messages.Update(ctx, *currentAssistant)
421 if updateErr != nil {
422 return nil, updateErr
423 }
424 }
425
426 found := false
427 for _, msg := range msgs {
428 if msg.Role == message.Tool {
429 for _, tr := range msg.ToolResults() {
430 if tr.ToolCallID == tc.ID {
431 found = true
432 break
433 }
434 }
435 }
436 if found {
437 break
438 }
439 }
440 if found {
441 continue
442 }
443 content := "There was an error while executing the tool"
444 if isCancelErr {
445 content = "Tool execution canceled by user"
446 } else if isPermissionErr {
447 content = "User denied permission"
448 }
449 toolResult := message.ToolResult{
450 ToolCallID: tc.ID,
451 Name: tc.Name,
452 Content: content,
453 IsError: true,
454 }
455 _, createErr = a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
456 Role: message.Tool,
457 Parts: []message.ContentPart{
458 toolResult,
459 },
460 })
461 if createErr != nil {
462 return nil, createErr
463 }
464 }
465 var fantasyErr *fantasy.Error
466 var providerErr *fantasy.ProviderError
467 const defaultTitle = "Provider Error"
468 if isCancelErr {
469 currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
470 } else if isPermissionErr {
471 currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "")
472 } else if errors.As(err, &providerErr) {
473 currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message)
474 } else if errors.As(err, &fantasyErr) {
475 currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(fantasyErr.Title), defaultTitle), fantasyErr.Message)
476 } else {
477 currentAssistant.AddFinish(message.FinishReasonError, defaultTitle, err.Error())
478 }
479 // Note: we use the parent context here because the genCtx has been
480 // cancelled.
481 updateErr := a.messages.Update(ctx, *currentAssistant)
482 if updateErr != nil {
483 return nil, updateErr
484 }
485 return nil, err
486 }
487 wg.Wait()
488
489 // Send notification that agent has finished its turn.
490 sessionLock.Lock()
491 notifBody := fmt.Sprintf("Agent's turn completed in \"%s\"", currentSession.Title)
492 sessionLock.Unlock()
493 if err := notification.Send("Crush is waiting...", notifBody); err != nil {
494 // Log but don't fail on notification errors.
495 }
496
497 if shouldSummarize {
498 a.activeRequests.Del(call.SessionID)
499 if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
500 return nil, summarizeErr
501 }
502 // If the agent wasn't done...
503 if len(currentAssistant.ToolCalls()) > 0 {
504 existing, ok := a.messageQueue.Get(call.SessionID)
505 if !ok {
506 existing = []SessionAgentCall{}
507 }
508 call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
509 existing = append(existing, call)
510 a.messageQueue.Set(call.SessionID, existing)
511 }
512 }
513
514 // Release active request before processing queued messages.
515 a.activeRequests.Del(call.SessionID)
516 cancel()
517
518 queuedMessages, ok := a.messageQueue.Get(call.SessionID)
519 if !ok || len(queuedMessages) == 0 {
520 return result, err
521 }
522 // There are queued messages restart the loop.
523 firstQueuedMessage := queuedMessages[0]
524 a.messageQueue.Set(call.SessionID, queuedMessages[1:])
525 return a.Run(ctx, firstQueuedMessage)
526}
527
528func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
529 if a.IsSessionBusy(sessionID) {
530 return ErrSessionBusy
531 }
532
533 currentSession, err := a.sessions.Get(ctx, sessionID)
534 if err != nil {
535 return fmt.Errorf("failed to get session: %w", err)
536 }
537 msgs, err := a.getSessionMessages(ctx, currentSession)
538 if err != nil {
539 return err
540 }
541 if len(msgs) == 0 {
542 // Nothing to summarize.
543 return nil
544 }
545
546 aiMsgs, _ := a.preparePrompt(msgs)
547
548 genCtx, cancel := context.WithCancel(ctx)
549 a.activeRequests.Set(sessionID, cancel)
550 defer a.activeRequests.Del(sessionID)
551 defer cancel()
552
553 agent := fantasy.NewAgent(a.largeModel.Model,
554 fantasy.WithSystemPrompt(string(summaryPrompt)),
555 )
556 summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
557 Role: message.Assistant,
558 Model: a.largeModel.Model.Model(),
559 Provider: a.largeModel.Model.Provider(),
560 IsSummaryMessage: true,
561 })
562 if err != nil {
563 return err
564 }
565
566 resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
567 Prompt: "Provide a detailed summary of our conversation above.",
568 Messages: aiMsgs,
569 ProviderOptions: opts,
570 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
571 prepared.Messages = options.Messages
572 if a.systemPromptPrefix != "" {
573 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
574 }
575 return callContext, prepared, nil
576 },
577 OnReasoningDelta: func(id string, text string) error {
578 summaryMessage.AppendReasoningContent(text)
579 return a.messages.Update(genCtx, summaryMessage)
580 },
581 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
582 // Handle anthropic signature.
583 if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
584 if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
585 summaryMessage.AppendReasoningSignature(signature.Signature)
586 }
587 }
588 summaryMessage.FinishThinking()
589 return a.messages.Update(genCtx, summaryMessage)
590 },
591 OnTextDelta: func(id, text string) error {
592 summaryMessage.AppendContent(text)
593 return a.messages.Update(genCtx, summaryMessage)
594 },
595 })
596 if err != nil {
597 isCancelErr := errors.Is(err, context.Canceled)
598 if isCancelErr {
599 // User cancelled summarize we need to remove the summary message.
600 deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
601 return deleteErr
602 }
603 return err
604 }
605
606 summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
607 err = a.messages.Update(genCtx, summaryMessage)
608 if err != nil {
609 return err
610 }
611
612 var openrouterCost *float64
613 for _, step := range resp.Steps {
614 stepCost := a.openrouterCost(step.ProviderMetadata)
615 if stepCost != nil {
616 newCost := *stepCost
617 if openrouterCost != nil {
618 newCost += *openrouterCost
619 }
620 openrouterCost = &newCost
621 }
622 }
623
624 a.updateSessionUsage(a.largeModel, ¤tSession, resp.TotalUsage, openrouterCost)
625
626 // Just in case, get just the last usage info.
627 usage := resp.Response.Usage
628 currentSession.SummaryMessageID = summaryMessage.ID
629 currentSession.CompletionTokens = usage.OutputTokens
630 currentSession.PromptTokens = 0
631 _, err = a.sessions.Save(genCtx, currentSession)
632 return err
633}
634
635func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
636 if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
637 return fantasy.ProviderOptions{}
638 }
639 return fantasy.ProviderOptions{
640 anthropic.Name: &anthropic.ProviderCacheControlOptions{
641 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
642 },
643 bedrock.Name: &anthropic.ProviderCacheControlOptions{
644 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
645 },
646 }
647}
648
649func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
650 var attachmentParts []message.ContentPart
651 for _, attachment := range call.Attachments {
652 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
653 }
654 parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
655 parts = append(parts, attachmentParts...)
656 msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
657 Role: message.User,
658 Parts: parts,
659 })
660 if err != nil {
661 return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
662 }
663 return msg, nil
664}
665
666func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
667 var history []fantasy.Message
668 for _, m := range msgs {
669 if len(m.Parts) == 0 {
670 continue
671 }
672 // Assistant message without content or tool calls (cancelled before it
673 // returned anything).
674 if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
675 continue
676 }
677 history = append(history, m.ToAIMessage()...)
678 }
679
680 var files []fantasy.FilePart
681 for _, attachment := range attachments {
682 files = append(files, fantasy.FilePart{
683 Filename: attachment.FileName,
684 Data: attachment.Content,
685 MediaType: attachment.MimeType,
686 })
687 }
688
689 return history, files
690}
691
692func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
693 msgs, err := a.messages.List(ctx, session.ID)
694 if err != nil {
695 return nil, fmt.Errorf("failed to list messages: %w", err)
696 }
697
698 if session.SummaryMessageID != "" {
699 summaryMsgInex := -1
700 for i, msg := range msgs {
701 if msg.ID == session.SummaryMessageID {
702 summaryMsgInex = i
703 break
704 }
705 }
706 if summaryMsgInex != -1 {
707 msgs = msgs[summaryMsgInex:]
708 msgs[0].Role = message.User
709 }
710 }
711 return msgs, nil
712}
713
714func (a *sessionAgent) generateTitle(ctx context.Context, session *session.Session, prompt string) {
715 if prompt == "" {
716 return
717 }
718
719 var maxOutput int64 = 40
720 if a.smallModel.CatwalkCfg.CanReason {
721 maxOutput = a.smallModel.CatwalkCfg.DefaultMaxTokens
722 }
723
724 agent := fantasy.NewAgent(a.smallModel.Model,
725 fantasy.WithSystemPrompt(string(titlePrompt)+"\n /no_think"),
726 fantasy.WithMaxOutputTokens(maxOutput),
727 )
728
729 resp, err := agent.Stream(ctx, fantasy.AgentStreamCall{
730 Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", prompt),
731 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
732 prepared.Messages = options.Messages
733 if a.systemPromptPrefix != "" {
734 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
735 }
736 return callContext, prepared, nil
737 },
738 })
739 if err != nil {
740 slog.Error("error generating title", "err", err)
741 return
742 }
743
744 title := resp.Response.Content.Text()
745
746 title = strings.ReplaceAll(title, "\n", " ")
747
748 // Remove thinking tags if present.
749 if idx := strings.Index(title, "</think>"); idx > 0 {
750 title = title[idx+len("</think>"):]
751 }
752
753 title = strings.TrimSpace(title)
754 if title == "" {
755 slog.Warn("failed to generate title", "warn", "empty title")
756 return
757 }
758
759 session.Title = title
760
761 var openrouterCost *float64
762 for _, step := range resp.Steps {
763 stepCost := a.openrouterCost(step.ProviderMetadata)
764 if stepCost != nil {
765 newCost := *stepCost
766 if openrouterCost != nil {
767 newCost += *openrouterCost
768 }
769 openrouterCost = &newCost
770 }
771 }
772
773 a.updateSessionUsage(a.smallModel, session, resp.TotalUsage, openrouterCost)
774 _, saveErr := a.sessions.Save(ctx, *session)
775 if saveErr != nil {
776 slog.Error("failed to save session title & usage", "error", saveErr)
777 return
778 }
779}
780
781func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
782 openrouterMetadata, ok := metadata[openrouter.Name]
783 if !ok {
784 return nil
785 }
786
787 opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
788 if !ok {
789 return nil
790 }
791 return &opts.Usage.Cost
792}
793
794func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
795 modelConfig := model.CatwalkCfg
796 cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
797 modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
798 modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
799 modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
800
801 a.eventTokensUsed(session.ID, model, usage, cost)
802
803 if overrideCost != nil {
804 session.Cost += *overrideCost
805 } else {
806 session.Cost += cost
807 }
808
809 session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
810 session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
811}
812
813func (a *sessionAgent) Cancel(sessionID string) {
814 // Cancel regular requests.
815 if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
816 slog.Info("Request cancellation initiated", "session_id", sessionID)
817 cancel()
818 }
819
820 // Also check for summarize requests.
821 if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
822 slog.Info("Summarize cancellation initiated", "session_id", sessionID)
823 cancel()
824 }
825
826 if a.QueuedPrompts(sessionID) > 0 {
827 slog.Info("Clearing queued prompts", "session_id", sessionID)
828 a.messageQueue.Del(sessionID)
829 }
830}
831
832func (a *sessionAgent) ClearQueue(sessionID string) {
833 if a.QueuedPrompts(sessionID) > 0 {
834 slog.Info("Clearing queued prompts", "session_id", sessionID)
835 a.messageQueue.Del(sessionID)
836 }
837}
838
839func (a *sessionAgent) CancelAll() {
840 if !a.IsBusy() {
841 return
842 }
843 for key := range a.activeRequests.Seq2() {
844 a.Cancel(key) // key is sessionID
845 }
846
847 timeout := time.After(5 * time.Second)
848 for a.IsBusy() {
849 select {
850 case <-timeout:
851 return
852 default:
853 time.Sleep(200 * time.Millisecond)
854 }
855 }
856}
857
858func (a *sessionAgent) IsBusy() bool {
859 var busy bool
860 for cancelFunc := range a.activeRequests.Seq() {
861 if cancelFunc != nil {
862 busy = true
863 break
864 }
865 }
866 return busy
867}
868
869func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
870 _, busy := a.activeRequests.Get(sessionID)
871 return busy
872}
873
874func (a *sessionAgent) QueuedPrompts(sessionID string) int {
875 l, ok := a.messageQueue.Get(sessionID)
876 if !ok {
877 return 0
878 }
879 return len(l)
880}
881
882func (a *sessionAgent) SetModels(large Model, small Model) {
883 a.largeModel = large
884 a.smallModel = small
885}
886
// SetTools replaces the agent's tool set. The slice is stored as given, not
// copied.
func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
	a.tools = tools
}
890
// Model returns the agent's large model.
func (a *sessionAgent) Model() Model {
	return a.largeModel
}