1// Package agent is the core orchestration layer for Crush AI agents.
2//
3// It provides session-based AI agent functionality for managing
4// conversations, tool execution, and message handling. It coordinates
5// interactions between language models, messages, sessions, and tools while
6// handling features like automatic summarization, queuing, and token
7// management.
8package agent
9
10import (
11 "cmp"
12 "context"
13 _ "embed"
14 "errors"
15 "fmt"
16 "log/slog"
17 "os"
18 "strconv"
19 "strings"
20 "sync"
21 "time"
22
23 "charm.land/fantasy"
24 "charm.land/fantasy/providers/anthropic"
25 "charm.land/fantasy/providers/bedrock"
26 "charm.land/fantasy/providers/google"
27 "charm.land/fantasy/providers/openai"
28 "charm.land/fantasy/providers/openrouter"
29 "git.secluded.site/crush/internal/agent/tools"
30 "git.secluded.site/crush/internal/config"
31 "git.secluded.site/crush/internal/csync"
32 "git.secluded.site/crush/internal/message"
33 "git.secluded.site/crush/internal/notification"
34 "git.secluded.site/crush/internal/permission"
35 "git.secluded.site/crush/internal/session"
36 "git.secluded.site/crush/internal/stringext"
37 "github.com/charmbracelet/catwalk/pkg/catwalk"
38)
39
// titlePrompt is the system prompt generateTitle gives the small model when
// asking it to name a new session.
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt is the system prompt Summarize gives the large model when
// asking it to condense a session's conversation.
//
//go:embed templates/summary.md
var summaryPrompt []byte
45
// SessionAgentCall describes one prompt submitted to a SessionAgent through
// Run. A call made while its session is busy is queued (see Run).
type SessionAgentCall struct {
	// SessionID identifies the target session; Run rejects an empty value
	// with ErrSessionMissing.
	SessionID string
	// Prompt is the user's text; Run rejects an empty value with
	// ErrEmptyPrompt.
	Prompt string
	// ProviderOptions are forwarded to the model stream, and to Summarize
	// when Run triggers an automatic summarization.
	ProviderOptions fantasy.ProviderOptions
	// Attachments are stored on the user message as binary parts and sent to
	// the model as file parts.
	Attachments []message.Attachment
	// MaxOutputTokens is passed (by pointer) to the stream call as its
	// output-token limit.
	MaxOutputTokens int64
	// Optional sampling parameters, passed through to the stream call
	// unchanged; nil leaves them unset there.
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
	// NonInteractive suppresses the notification Run sends when the agent
	// finishes its turn (for nested/non-interactive sessions).
	NonInteractive bool
}
59
// SessionAgent runs prompts against language models on behalf of chat
// sessions, with per-session queuing, cancellation and summarization.
type SessionAgent interface {
	// Run processes a prompt for a session. If the session already has an
	// active request the call is queued and (nil, nil) is returned.
	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
	// SetModels replaces the large model (used for runs and summaries) and
	// the small model (used for title generation).
	SetModels(large Model, small Model)
	// SetTools replaces the tools offered to the model on later runs.
	SetTools(tools []fantasy.AgentTool)
	// Cancel aborts the session's active request, if any, and clears its
	// queued prompts.
	Cancel(sessionID string)
	// CancelAll cancels every active request and waits up to about five
	// seconds for them to finish.
	CancelAll()
	// IsSessionBusy reports whether the session has an active request.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any session has an active request.
	IsBusy() bool
	// QueuedPrompts returns how many prompts are waiting for the session.
	QueuedPrompts(sessionID string) int
	// ClearQueue drops the session's queued prompts.
	ClearQueue(sessionID string)
	// Summarize asks the model to summarize the session and records the
	// summary so later runs start their history from it.
	Summarize(context.Context, string, fantasy.ProviderOptions) error
	// Model returns the large model currently in use.
	Model() Model
}
73
// Model bundles a language model with the metadata the agent needs about it.
type Model struct {
	// Model is the language model used for generation.
	Model fantasy.LanguageModel
	// CatwalkCfg is the model's catalog entry; the agent reads its context
	// window, per-token prices, reasoning capability and default max tokens.
	CatwalkCfg catwalk.Model
	// ModelCfg is the selected model; its Model and Provider IDs are recorded
	// on assistant messages and used to look up the provider configuration.
	ModelCfg config.SelectedModel
}
79
// sessionAgent is the SessionAgent implementation returned by
// NewSessionAgent.
type sessionAgent struct {
	// largeModel drives runs and summaries; smallModel generates titles.
	largeModel Model
	smallModel Model
	// systemPromptPrefix is prepended as an extra system message (see
	// promptPrefix); systemPrompt is the agent's main system prompt.
	systemPromptPrefix string
	systemPrompt       string
	// tools are offered to the model on every run.
	tools    []fantasy.AgentTool
	sessions session.Service
	messages message.Service
	// disableAutoSummarize turns off the context-window stop condition in Run.
	disableAutoSummarize bool
	// isYolo is stored from the options; it is not read in this file.
	isYolo bool

	// messageQueue holds calls submitted while a session was busy, keyed by
	// session ID.
	messageQueue *csync.Map[string, []SessionAgentCall]
	// activeRequests maps a session ID to the cancel function of its
	// in-flight run or summarization; presence means the session is busy.
	activeRequests *csync.Map[string, context.CancelFunc]
}
94
// SessionAgentOptions configures NewSessionAgent; each field is copied into
// the corresponding field of the new agent.
type SessionAgentOptions struct {
	// LargeModel drives runs and summaries; SmallModel generates titles.
	LargeModel Model
	SmallModel Model
	// SystemPromptPrefix is sent as a leading system message; SystemPrompt is
	// the agent's main system prompt.
	SystemPromptPrefix string
	SystemPrompt       string
	// DisableAutoSummarize stops Run from summarizing when the context window
	// fills up.
	DisableAutoSummarize bool
	IsYolo               bool
	Sessions             session.Service
	Messages             message.Service
	// Tools are offered to the model on every run.
	Tools []fantasy.AgentTool
}
106
107func NewSessionAgent(
108 opts SessionAgentOptions,
109) SessionAgent {
110 return &sessionAgent{
111 largeModel: opts.LargeModel,
112 smallModel: opts.SmallModel,
113 systemPromptPrefix: opts.SystemPromptPrefix,
114 systemPrompt: opts.SystemPrompt,
115 sessions: opts.Sessions,
116 messages: opts.Messages,
117 disableAutoSummarize: opts.DisableAutoSummarize,
118 tools: opts.Tools,
119 isYolo: opts.IsYolo,
120 messageQueue: csync.NewMap[string, []SessionAgentCall](),
121 activeRequests: csync.NewMap[string, context.CancelFunc](),
122 }
123}
124
125func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
126 if call.Prompt == "" {
127 return nil, ErrEmptyPrompt
128 }
129 if call.SessionID == "" {
130 return nil, ErrSessionMissing
131 }
132
133 // Queue the message if busy
134 if a.IsSessionBusy(call.SessionID) {
135 existing, ok := a.messageQueue.Get(call.SessionID)
136 if !ok {
137 existing = []SessionAgentCall{}
138 }
139 existing = append(existing, call)
140 a.messageQueue.Set(call.SessionID, existing)
141 return nil, nil
142 }
143
144 if len(a.tools) > 0 {
145 // Add Anthropic caching to the last tool.
146 a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
147 }
148
149 agent := fantasy.NewAgent(
150 a.largeModel.Model,
151 fantasy.WithSystemPrompt(a.systemPrompt),
152 fantasy.WithTools(a.tools...),
153 )
154
155 sessionLock := sync.Mutex{}
156 currentSession, err := a.sessions.Get(ctx, call.SessionID)
157 if err != nil {
158 return nil, fmt.Errorf("failed to get session: %w", err)
159 }
160
161 msgs, err := a.getSessionMessages(ctx, currentSession)
162 if err != nil {
163 return nil, fmt.Errorf("failed to get session messages: %w", err)
164 }
165
166 var wg sync.WaitGroup
167 // Generate title if first message.
168 if len(msgs) == 0 {
169 wg.Go(func() {
170 sessionLock.Lock()
171 a.generateTitle(ctx, ¤tSession, call.Prompt)
172 sessionLock.Unlock()
173 })
174 }
175
176 // Add the user message to the session.
177 _, err = a.createUserMessage(ctx, call)
178 if err != nil {
179 return nil, err
180 }
181
182 // Add the session to the context.
183 ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)
184
185 genCtx, cancel := context.WithCancel(ctx)
186 a.activeRequests.Set(call.SessionID, cancel)
187
188 defer cancel()
189 defer a.activeRequests.Del(call.SessionID)
190
191 history, files := a.preparePrompt(msgs, call.Attachments...)
192
193 startTime := time.Now()
194 a.eventPromptSent(call.SessionID)
195
196 var currentAssistant *message.Message
197 var shouldSummarize bool
198 result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
199 Prompt: call.Prompt,
200 Files: files,
201 Messages: history,
202 ProviderOptions: call.ProviderOptions,
203 MaxOutputTokens: &call.MaxOutputTokens,
204 TopP: call.TopP,
205 Temperature: call.Temperature,
206 PresencePenalty: call.PresencePenalty,
207 TopK: call.TopK,
208 FrequencyPenalty: call.FrequencyPenalty,
209 // Before each step create a new assistant message.
210 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
211 prepared.Messages = options.Messages
212 // Reset all cached items.
213 for i := range prepared.Messages {
214 prepared.Messages[i].ProviderOptions = nil
215 }
216
217 queuedCalls, _ := a.messageQueue.Get(call.SessionID)
218 a.messageQueue.Del(call.SessionID)
219 for _, queued := range queuedCalls {
220 userMessage, createErr := a.createUserMessage(callContext, queued)
221 if createErr != nil {
222 return callContext, prepared, createErr
223 }
224 prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
225 }
226
227 lastSystemRoleInx := 0
228 systemMessageUpdated := false
229 for i, msg := range prepared.Messages {
230 // Only add cache control to the last message.
231 if msg.Role == fantasy.MessageRoleSystem {
232 lastSystemRoleInx = i
233 } else if !systemMessageUpdated {
234 prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
235 systemMessageUpdated = true
236 }
237 // Than add cache control to the last 2 messages.
238 if i > len(prepared.Messages)-3 {
239 prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
240 }
241 }
242
243 if promptPrefix := a.promptPrefix(); promptPrefix != "" {
244 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(promptPrefix)}, prepared.Messages...)
245 }
246
247 var assistantMsg message.Message
248 assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
249 Role: message.Assistant,
250 Parts: []message.ContentPart{},
251 Model: a.largeModel.ModelCfg.Model,
252 Provider: a.largeModel.ModelCfg.Provider,
253 })
254 if err != nil {
255 return callContext, prepared, err
256 }
257 callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
258 currentAssistant = &assistantMsg
259 return callContext, prepared, err
260 },
261 OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
262 currentAssistant.AppendReasoningContent(reasoning.Text)
263 return a.messages.Update(genCtx, *currentAssistant)
264 },
265 OnReasoningDelta: func(id string, text string) error {
266 currentAssistant.AppendReasoningContent(text)
267 return a.messages.Update(genCtx, *currentAssistant)
268 },
269 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
270 // handle anthropic signature
271 if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
272 if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
273 currentAssistant.AppendReasoningSignature(reasoning.Signature)
274 }
275 }
276 if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
277 if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
278 currentAssistant.AppendThoughtSignature(reasoning.Signature, reasoning.ToolID)
279 }
280 }
281 if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
282 if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
283 currentAssistant.SetReasoningResponsesData(reasoning)
284 }
285 }
286 currentAssistant.FinishThinking()
287 return a.messages.Update(genCtx, *currentAssistant)
288 },
289 OnTextDelta: func(id string, text string) error {
290 // Strip leading newline from initial text content. This is is
291 // particularly important in non-interactive mode where leading
292 // newlines are very visible.
293 if len(currentAssistant.Parts) == 0 {
294 text = strings.TrimPrefix(text, "\n")
295 }
296
297 currentAssistant.AppendContent(text)
298 return a.messages.Update(genCtx, *currentAssistant)
299 },
300 OnToolInputStart: func(id string, toolName string) error {
301 toolCall := message.ToolCall{
302 ID: id,
303 Name: toolName,
304 ProviderExecuted: false,
305 Finished: false,
306 }
307 currentAssistant.AddToolCall(toolCall)
308 return a.messages.Update(genCtx, *currentAssistant)
309 },
310 OnRetry: func(err *fantasy.ProviderError, delay time.Duration) {
311 // TODO: implement
312 },
313 OnToolCall: func(tc fantasy.ToolCallContent) error {
314 toolCall := message.ToolCall{
315 ID: tc.ToolCallID,
316 Name: tc.ToolName,
317 Input: tc.Input,
318 ProviderExecuted: false,
319 Finished: true,
320 }
321 currentAssistant.AddToolCall(toolCall)
322 return a.messages.Update(genCtx, *currentAssistant)
323 },
324 OnToolResult: func(result fantasy.ToolResultContent) error {
325 var resultContent string
326 isError := false
327 switch result.Result.GetType() {
328 case fantasy.ToolResultContentTypeText:
329 r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result)
330 if ok {
331 resultContent = r.Text
332 }
333 case fantasy.ToolResultContentTypeError:
334 r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result)
335 if ok {
336 isError = true
337 resultContent = r.Error.Error()
338 }
339 case fantasy.ToolResultContentTypeMedia:
340 // TODO: handle this message type
341 }
342 toolResult := message.ToolResult{
343 ToolCallID: result.ToolCallID,
344 Name: result.ToolName,
345 Content: resultContent,
346 IsError: isError,
347 Metadata: result.ClientMetadata,
348 }
349 _, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
350 Role: message.Tool,
351 Parts: []message.ContentPart{
352 toolResult,
353 },
354 })
355 if createMsgErr != nil {
356 return createMsgErr
357 }
358 return nil
359 },
360 OnStepFinish: func(stepResult fantasy.StepResult) error {
361 finishReason := message.FinishReasonUnknown
362 switch stepResult.FinishReason {
363 case fantasy.FinishReasonLength:
364 finishReason = message.FinishReasonMaxTokens
365 case fantasy.FinishReasonStop:
366 finishReason = message.FinishReasonEndTurn
367 case fantasy.FinishReasonToolCalls:
368 finishReason = message.FinishReasonToolUse
369 }
370 currentAssistant.AddFinish(finishReason, "", "")
371 a.updateSessionUsage(a.largeModel, ¤tSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
372 sessionLock.Lock()
373 _, sessionErr := a.sessions.Save(genCtx, currentSession)
374 sessionLock.Unlock()
375 if sessionErr != nil {
376 return sessionErr
377 }
378 return a.messages.Update(genCtx, *currentAssistant)
379 },
380 StopWhen: []fantasy.StopCondition{
381 func(_ []fantasy.StepResult) bool {
382 cw := int64(a.largeModel.CatwalkCfg.ContextWindow)
383 tokens := currentSession.CompletionTokens + currentSession.PromptTokens
384 remaining := cw - tokens
385 var threshold int64
386 if cw > 200_000 {
387 threshold = 20_000
388 } else {
389 threshold = int64(float64(cw) * 0.2)
390 }
391 if (remaining <= threshold) && !a.disableAutoSummarize {
392 shouldSummarize = true
393 return true
394 }
395 return false
396 },
397 },
398 })
399
400 a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))
401
402 if err != nil {
403 isCancelErr := errors.Is(err, context.Canceled)
404 isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
405 if currentAssistant == nil {
406 return result, err
407 }
408 // Ensure we finish thinking on error to close the reasoning state.
409 currentAssistant.FinishThinking()
410 toolCalls := currentAssistant.ToolCalls()
411 // INFO: we use the parent context here because the genCtx has been cancelled.
412 msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
413 if createErr != nil {
414 return nil, createErr
415 }
416 for _, tc := range toolCalls {
417 if !tc.Finished {
418 tc.Finished = true
419 tc.Input = "{}"
420 currentAssistant.AddToolCall(tc)
421 updateErr := a.messages.Update(ctx, *currentAssistant)
422 if updateErr != nil {
423 return nil, updateErr
424 }
425 }
426
427 found := false
428 for _, msg := range msgs {
429 if msg.Role == message.Tool {
430 for _, tr := range msg.ToolResults() {
431 if tr.ToolCallID == tc.ID {
432 found = true
433 break
434 }
435 }
436 }
437 if found {
438 break
439 }
440 }
441 if found {
442 continue
443 }
444 content := "There was an error while executing the tool"
445 if isCancelErr {
446 content = "Tool execution canceled by user"
447 } else if isPermissionErr {
448 content = "User denied permission"
449 }
450 toolResult := message.ToolResult{
451 ToolCallID: tc.ID,
452 Name: tc.Name,
453 Content: content,
454 IsError: true,
455 }
456 _, createErr = a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
457 Role: message.Tool,
458 Parts: []message.ContentPart{
459 toolResult,
460 },
461 })
462 if createErr != nil {
463 return nil, createErr
464 }
465 }
466 var fantasyErr *fantasy.Error
467 var providerErr *fantasy.ProviderError
468 const defaultTitle = "Provider Error"
469 if isCancelErr {
470 currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
471 } else if isPermissionErr {
472 currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "")
473 } else if errors.As(err, &providerErr) {
474 currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message)
475 } else if errors.As(err, &fantasyErr) {
476 currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(fantasyErr.Title), defaultTitle), fantasyErr.Message)
477 } else {
478 currentAssistant.AddFinish(message.FinishReasonError, defaultTitle, err.Error())
479 }
480 // Note: we use the parent context here because the genCtx has been
481 // cancelled.
482 updateErr := a.messages.Update(ctx, *currentAssistant)
483 if updateErr != nil {
484 return nil, updateErr
485 }
486 return nil, err
487 }
488 wg.Wait()
489
490 // Send notification that agent has finished its turn (skip for nested/non-interactive sessions).
491 if !call.NonInteractive {
492 notifBody := fmt.Sprintf("Agent's turn completed in \"%s\"", currentSession.Title)
493 _ = notification.Send("Crush is waiting...", notifBody)
494 }
495
496 if shouldSummarize {
497 a.activeRequests.Del(call.SessionID)
498 if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
499 return nil, summarizeErr
500 }
501 // If the agent wasn't done...
502 if len(currentAssistant.ToolCalls()) > 0 {
503 existing, ok := a.messageQueue.Get(call.SessionID)
504 if !ok {
505 existing = []SessionAgentCall{}
506 }
507 call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
508 existing = append(existing, call)
509 a.messageQueue.Set(call.SessionID, existing)
510 }
511 }
512
513 // Release active request before processing queued messages.
514 a.activeRequests.Del(call.SessionID)
515 cancel()
516
517 queuedMessages, ok := a.messageQueue.Get(call.SessionID)
518 if !ok || len(queuedMessages) == 0 {
519 return result, err
520 }
521 // There are queued messages restart the loop.
522 firstQueuedMessage := queuedMessages[0]
523 a.messageQueue.Set(call.SessionID, queuedMessages[1:])
524 return a.Run(ctx, firstQueuedMessage)
525}
526
527func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
528 if a.IsSessionBusy(sessionID) {
529 return ErrSessionBusy
530 }
531
532 currentSession, err := a.sessions.Get(ctx, sessionID)
533 if err != nil {
534 return fmt.Errorf("failed to get session: %w", err)
535 }
536 msgs, err := a.getSessionMessages(ctx, currentSession)
537 if err != nil {
538 return err
539 }
540 if len(msgs) == 0 {
541 // Nothing to summarize.
542 return nil
543 }
544
545 aiMsgs, _ := a.preparePrompt(msgs)
546
547 genCtx, cancel := context.WithCancel(ctx)
548 a.activeRequests.Set(sessionID, cancel)
549 defer a.activeRequests.Del(sessionID)
550 defer cancel()
551
552 agent := fantasy.NewAgent(a.largeModel.Model,
553 fantasy.WithSystemPrompt(string(summaryPrompt)),
554 )
555 summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
556 Role: message.Assistant,
557 Model: a.largeModel.Model.Model(),
558 Provider: a.largeModel.Model.Provider(),
559 IsSummaryMessage: true,
560 })
561 if err != nil {
562 return err
563 }
564
565 resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
566 Prompt: "Provide a detailed summary of our conversation above.",
567 Messages: aiMsgs,
568 ProviderOptions: opts,
569 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
570 prepared.Messages = options.Messages
571 if a.systemPromptPrefix != "" {
572 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
573 }
574 return callContext, prepared, nil
575 },
576 OnReasoningDelta: func(id string, text string) error {
577 summaryMessage.AppendReasoningContent(text)
578 return a.messages.Update(genCtx, summaryMessage)
579 },
580 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
581 // Handle anthropic signature.
582 if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
583 if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
584 summaryMessage.AppendReasoningSignature(signature.Signature)
585 }
586 }
587 summaryMessage.FinishThinking()
588 return a.messages.Update(genCtx, summaryMessage)
589 },
590 OnTextDelta: func(id, text string) error {
591 summaryMessage.AppendContent(text)
592 return a.messages.Update(genCtx, summaryMessage)
593 },
594 })
595 if err != nil {
596 isCancelErr := errors.Is(err, context.Canceled)
597 if isCancelErr {
598 // User cancelled summarize we need to remove the summary message.
599 deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
600 return deleteErr
601 }
602 return err
603 }
604
605 summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
606 err = a.messages.Update(genCtx, summaryMessage)
607 if err != nil {
608 return err
609 }
610
611 var openrouterCost *float64
612 for _, step := range resp.Steps {
613 stepCost := a.openrouterCost(step.ProviderMetadata)
614 if stepCost != nil {
615 newCost := *stepCost
616 if openrouterCost != nil {
617 newCost += *openrouterCost
618 }
619 openrouterCost = &newCost
620 }
621 }
622
623 a.updateSessionUsage(a.largeModel, ¤tSession, resp.TotalUsage, openrouterCost)
624
625 // Just in case, get just the last usage info.
626 usage := resp.Response.Usage
627 currentSession.SummaryMessageID = summaryMessage.ID
628 currentSession.CompletionTokens = usage.OutputTokens
629 currentSession.PromptTokens = 0
630 _, err = a.sessions.Save(genCtx, currentSession)
631 return err
632}
633
634func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
635 if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
636 return fantasy.ProviderOptions{}
637 }
638 return fantasy.ProviderOptions{
639 anthropic.Name: &anthropic.ProviderCacheControlOptions{
640 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
641 },
642 bedrock.Name: &anthropic.ProviderCacheControlOptions{
643 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
644 },
645 }
646}
647
648func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
649 var attachmentParts []message.ContentPart
650 for _, attachment := range call.Attachments {
651 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
652 }
653 parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
654 parts = append(parts, attachmentParts...)
655 msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
656 Role: message.User,
657 Parts: parts,
658 })
659 if err != nil {
660 return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
661 }
662 return msg, nil
663}
664
665func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
666 var history []fantasy.Message
667 for _, m := range msgs {
668 if len(m.Parts) == 0 {
669 continue
670 }
671 // Assistant message without content or tool calls (cancelled before it
672 // returned anything).
673 if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
674 continue
675 }
676 history = append(history, m.ToAIMessage()...)
677 }
678
679 var files []fantasy.FilePart
680 for _, attachment := range attachments {
681 files = append(files, fantasy.FilePart{
682 Filename: attachment.FileName,
683 Data: attachment.Content,
684 MediaType: attachment.MimeType,
685 })
686 }
687
688 return history, files
689}
690
691func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
692 msgs, err := a.messages.List(ctx, session.ID)
693 if err != nil {
694 return nil, fmt.Errorf("failed to list messages: %w", err)
695 }
696
697 if session.SummaryMessageID != "" {
698 summaryMsgInex := -1
699 for i, msg := range msgs {
700 if msg.ID == session.SummaryMessageID {
701 summaryMsgInex = i
702 break
703 }
704 }
705 if summaryMsgInex != -1 {
706 msgs = msgs[summaryMsgInex:]
707 msgs[0].Role = message.User
708 }
709 }
710 return msgs, nil
711}
712
713func (a *sessionAgent) generateTitle(ctx context.Context, session *session.Session, prompt string) {
714 if prompt == "" {
715 return
716 }
717
718 var maxOutput int64 = 40
719 if a.smallModel.CatwalkCfg.CanReason {
720 maxOutput = a.smallModel.CatwalkCfg.DefaultMaxTokens
721 }
722
723 agent := fantasy.NewAgent(a.smallModel.Model,
724 fantasy.WithSystemPrompt(string(titlePrompt)+"\n /no_think"),
725 fantasy.WithMaxOutputTokens(maxOutput),
726 )
727
728 resp, err := agent.Stream(ctx, fantasy.AgentStreamCall{
729 Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", prompt),
730 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
731 prepared.Messages = options.Messages
732 if a.systemPromptPrefix != "" {
733 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
734 }
735 return callContext, prepared, nil
736 },
737 })
738 if err != nil {
739 slog.Error("error generating title", "err", err)
740 return
741 }
742
743 title := resp.Response.Content.Text()
744
745 title = strings.ReplaceAll(title, "\n", " ")
746
747 // Remove thinking tags if present.
748 if idx := strings.Index(title, "</think>"); idx > 0 {
749 title = title[idx+len("</think>"):]
750 }
751
752 title = strings.TrimSpace(title)
753 if title == "" {
754 slog.Warn("failed to generate title", "warn", "empty title")
755 return
756 }
757
758 session.Title = title
759
760 var openrouterCost *float64
761 for _, step := range resp.Steps {
762 stepCost := a.openrouterCost(step.ProviderMetadata)
763 if stepCost != nil {
764 newCost := *stepCost
765 if openrouterCost != nil {
766 newCost += *openrouterCost
767 }
768 openrouterCost = &newCost
769 }
770 }
771
772 a.updateSessionUsage(a.smallModel, session, resp.TotalUsage, openrouterCost)
773 _, saveErr := a.sessions.Save(ctx, *session)
774 if saveErr != nil {
775 slog.Error("failed to save session title & usage", "error", saveErr)
776 return
777 }
778}
779
780func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
781 openrouterMetadata, ok := metadata[openrouter.Name]
782 if !ok {
783 return nil
784 }
785
786 opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
787 if !ok {
788 return nil
789 }
790 return &opts.Usage.Cost
791}
792
793func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
794 modelConfig := model.CatwalkCfg
795 cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
796 modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
797 modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
798 modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
799
800 if a.isClaudeCode() {
801 cost = 0
802 }
803
804 a.eventTokensUsed(session.ID, model, usage, cost)
805
806 if overrideCost != nil {
807 session.Cost += *overrideCost
808 } else {
809 session.Cost += cost
810 }
811
812 session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
813 session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
814}
815
816func (a *sessionAgent) Cancel(sessionID string) {
817 // Cancel regular requests.
818 if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
819 slog.Info("Request cancellation initiated", "session_id", sessionID)
820 cancel()
821 }
822
823 // Also check for summarize requests.
824 if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
825 slog.Info("Summarize cancellation initiated", "session_id", sessionID)
826 cancel()
827 }
828
829 if a.QueuedPrompts(sessionID) > 0 {
830 slog.Info("Clearing queued prompts", "session_id", sessionID)
831 a.messageQueue.Del(sessionID)
832 }
833}
834
835func (a *sessionAgent) ClearQueue(sessionID string) {
836 if a.QueuedPrompts(sessionID) > 0 {
837 slog.Info("Clearing queued prompts", "session_id", sessionID)
838 a.messageQueue.Del(sessionID)
839 }
840}
841
842func (a *sessionAgent) CancelAll() {
843 if !a.IsBusy() {
844 return
845 }
846 for key := range a.activeRequests.Seq2() {
847 a.Cancel(key) // key is sessionID
848 }
849
850 timeout := time.After(5 * time.Second)
851 for a.IsBusy() {
852 select {
853 case <-timeout:
854 return
855 default:
856 time.Sleep(200 * time.Millisecond)
857 }
858 }
859}
860
861func (a *sessionAgent) IsBusy() bool {
862 var busy bool
863 for cancelFunc := range a.activeRequests.Seq() {
864 if cancelFunc != nil {
865 busy = true
866 break
867 }
868 }
869 return busy
870}
871
872func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
873 _, busy := a.activeRequests.Get(sessionID)
874 return busy
875}
876
877func (a *sessionAgent) QueuedPrompts(sessionID string) int {
878 l, ok := a.messageQueue.Get(sessionID)
879 if !ok {
880 return 0
881 }
882 return len(l)
883}
884
885func (a *sessionAgent) SetModels(large Model, small Model) {
886 a.largeModel = large
887 a.smallModel = small
888}
889
890func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
891 a.tools = tools
892}
893
894func (a *sessionAgent) Model() Model {
895 return a.largeModel
896}
897
898func (a *sessionAgent) promptPrefix() string {
899 if a.isClaudeCode() {
900 return "You are Claude Code, Anthropic's official CLI for Claude."
901 }
902 return a.systemPromptPrefix
903}
904
905func (a *sessionAgent) isClaudeCode() bool {
906 cfg := config.Get()
907 pc, ok := cfg.Providers.Get(a.largeModel.ModelCfg.Provider)
908 return ok && pc.ID == string(catwalk.InferenceProviderAnthropic) && pc.OAuthToken != nil
909}