1package agent
2
3import (
4 "context"
5 _ "embed"
6 "errors"
7 "fmt"
8 "log/slog"
9 "os"
10 "strconv"
11 "strings"
12 "sync"
13 "time"
14
15 "charm.land/fantasy"
16 "charm.land/fantasy/providers/anthropic"
17 "charm.land/fantasy/providers/bedrock"
18 "charm.land/fantasy/providers/google"
19 "charm.land/fantasy/providers/openai"
20 "charm.land/fantasy/providers/openrouter"
21 "github.com/charmbracelet/catwalk/pkg/catwalk"
22 "git.secluded.site/crush/internal/agent/tools"
23 "git.secluded.site/crush/internal/config"
24 "git.secluded.site/crush/internal/csync"
25 "git.secluded.site/crush/internal/message"
26 "git.secluded.site/crush/internal/notification"
27 "git.secluded.site/crush/internal/permission"
28 "git.secluded.site/crush/internal/session"
29)
30
// titlePrompt holds the contents of templates/title.md, embedded at build
// time. generateTitle uses it as the base of its system prompt.
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt holds the contents of templates/summary.md, embedded at build
// time. Summarize uses it as its system prompt.
//
//go:embed templates/summary.md
var summaryPrompt []byte
36
// SessionAgentCall describes a single prompt submitted to a SessionAgent.
type SessionAgentCall struct {
	// SessionID identifies the session the prompt belongs to. Run rejects an
	// empty value with ErrSessionMissing.
	SessionID string
	// Prompt is the user's text. Run rejects an empty value with
	// ErrEmptyPrompt.
	Prompt string
	// ProviderOptions is forwarded unchanged to the model stream call.
	ProviderOptions fantasy.ProviderOptions
	// Attachments are stored as binary parts of the user message and are also
	// sent to the model as files.
	Attachments []message.Attachment
	// MaxOutputTokens is forwarded to the model stream call.
	MaxOutputTokens int64
	// Temperature is an optional sampling parameter forwarded as-is.
	Temperature *float64
	// TopP is an optional sampling parameter forwarded as-is.
	TopP *float64
	// TopK is an optional sampling parameter forwarded as-is.
	TopK *int64
	// FrequencyPenalty is an optional sampling parameter forwarded as-is.
	FrequencyPenalty *float64
	// PresencePenalty is an optional sampling parameter forwarded as-is.
	PresencePenalty *float64
}
49
// SessionAgent drives model conversations for sessions: it runs prompts,
// queues prompts that arrive while a session is busy, and supports
// cancellation and summarization.
type SessionAgent interface {
	// Run submits the prompt in the call for its session and returns the
	// agent result, or an error.
	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
	// SetModels replaces the large and small models.
	SetModels(large Model, small Model)
	// SetTools replaces the tool set offered to the model.
	SetTools(tools []fantasy.AgentTool)
	// Cancel stops the in-flight request for the session, if any, and drops
	// its queued prompts.
	Cancel(sessionID string)
	// CancelAll cancels the in-flight requests of every session.
	CancelAll()
	// IsSessionBusy reports whether the session has a request in flight.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any session has a request in flight.
	IsBusy() bool
	// QueuedPrompts returns how many prompts are queued for the session.
	QueuedPrompts(sessionID string) int
	// ClearQueue drops the prompts queued for the session.
	ClearQueue(sessionID string)
	// Summarize condenses the conversation of the session whose ID is given
	// as the string argument into a summary message.
	Summarize(context.Context, string, fantasy.ProviderOptions) error
	// Model returns the large model.
	Model() Model
}
63
// completionNotificationDelay is the delay handed to the notifier when a
// "turn completed" notification is scheduled.
const completionNotificationDelay = 5 * time.Second
65
// Model bundles a language model with the catalogue and user configuration
// that describe it.
type Model struct {
	// Model is the language model handed to the fantasy agent.
	Model fantasy.LanguageModel
	// CatwalkCfg is the catalogue entry; this file reads its context window,
	// reasoning capability, default max tokens and per-token prices.
	CatwalkCfg catwalk.Model
	// ModelCfg is the selected-model configuration; this file reads its Model
	// and Provider identifiers when creating assistant messages.
	ModelCfg config.SelectedModel
}
71
// sessionAgent is the SessionAgent implementation returned by
// NewSessionAgent.
type sessionAgent struct {
	// largeModel serves Run and Summarize.
	largeModel Model
	// smallModel serves title generation.
	smallModel Model
	// systemPromptPrefix, when non-empty, is prepended as an extra system
	// message before every model step.
	systemPromptPrefix string
	// systemPrompt is the system prompt used by Run.
	systemPrompt string
	// tools are the tools offered to the model by Run.
	tools []fantasy.AgentTool
	// sessions loads and saves session records.
	sessions session.Service
	// messages loads, creates, updates and deletes conversation messages.
	messages message.Service
	// disableAutoSummarize stops Run from interrupting a conversation to
	// summarize it when the context window is nearly full.
	disableAutoSummarize bool
	// isYolo is copied from the options; it is not read in this file.
	isYolo bool

	// messageQueue holds prompts received while a session was busy, keyed by
	// session ID.
	messageQueue *csync.Map[string, []SessionAgentCall]
	// activeRequests holds the cancel function of each in-flight request,
	// keyed by session ID. Presence of a key marks the session as busy.
	activeRequests *csync.Map[string, context.CancelFunc]
	// notifier schedules completion notifications; nil disables them.
	notifier *notification.Notifier
	// notifyCtx is the context passed to the notifier.
	notifyCtx context.Context
	// completionCancels holds the cancel function of each pending completion
	// notification, keyed by session ID.
	completionCancels *csync.Map[string, context.CancelFunc]
}
89
// SessionAgentOptions carries the dependencies and settings consumed by
// NewSessionAgent.
type SessionAgentOptions struct {
	// LargeModel serves regular runs and summaries.
	LargeModel Model
	// SmallModel serves title generation.
	SmallModel Model
	// SystemPromptPrefix, when non-empty, is sent as an extra leading system
	// message.
	SystemPromptPrefix string
	// SystemPrompt is the system prompt for regular runs.
	SystemPrompt string
	// DisableAutoSummarize turns off automatic summarization when the context
	// window is nearly full.
	DisableAutoSummarize bool
	// IsYolo is stored on the agent; it is not read in this file.
	IsYolo bool
	// Sessions is the session store.
	Sessions session.Service
	// Messages is the message store.
	Messages message.Service
	// Tools are the tools offered to the model.
	Tools []fantasy.AgentTool
	// Notifier schedules completion notifications; nil disables them.
	Notifier *notification.Notifier
	// NotificationCtx is the context passed to the notifier. When nil,
	// context.Background() is used.
	NotificationCtx context.Context
}
103
104func NewSessionAgent(
105 opts SessionAgentOptions,
106) SessionAgent {
107 notifyCtx := opts.NotificationCtx
108 if notifyCtx == nil {
109 notifyCtx = context.Background()
110 }
111
112 return &sessionAgent{
113 largeModel: opts.LargeModel,
114 smallModel: opts.SmallModel,
115 systemPromptPrefix: opts.SystemPromptPrefix,
116 systemPrompt: opts.SystemPrompt,
117 sessions: opts.Sessions,
118 messages: opts.Messages,
119 disableAutoSummarize: opts.DisableAutoSummarize,
120 tools: opts.Tools,
121 isYolo: opts.IsYolo,
122 messageQueue: csync.NewMap[string, []SessionAgentCall](),
123 activeRequests: csync.NewMap[string, context.CancelFunc](),
124 notifier: opts.Notifier,
125 notifyCtx: notifyCtx,
126 completionCancels: csync.NewMap[string, context.CancelFunc](),
127 }
128}
129
130func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
131 if call.Prompt == "" {
132 return nil, ErrEmptyPrompt
133 }
134 if call.SessionID == "" {
135 return nil, ErrSessionMissing
136 }
137
138 a.cancelCompletionNotification(call.SessionID)
139
140 // Queue the message if busy
141 if a.IsSessionBusy(call.SessionID) {
142 existing, ok := a.messageQueue.Get(call.SessionID)
143 if !ok {
144 existing = []SessionAgentCall{}
145 }
146 existing = append(existing, call)
147 a.messageQueue.Set(call.SessionID, existing)
148 return nil, nil
149 }
150
151 if len(a.tools) > 0 {
152 // add anthropic caching to the last tool
153 a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
154 }
155
156 agent := fantasy.NewAgent(
157 a.largeModel.Model,
158 fantasy.WithSystemPrompt(a.systemPrompt),
159 fantasy.WithTools(a.tools...),
160 )
161
162 sessionLock := sync.Mutex{}
163 currentSession, err := a.sessions.Get(ctx, call.SessionID)
164 if err != nil {
165 return nil, fmt.Errorf("failed to get session: %w", err)
166 }
167
168 msgs, err := a.getSessionMessages(ctx, currentSession)
169 if err != nil {
170 return nil, fmt.Errorf("failed to get session messages: %w", err)
171 }
172
173 var wg sync.WaitGroup
174 // Generate title if first message
175 if len(msgs) == 0 {
176 wg.Go(func() {
177 sessionLock.Lock()
178 a.generateTitle(ctx, ¤tSession, call.Prompt)
179 sessionLock.Unlock()
180 })
181 }
182
183 // Add the user message to the session
184 _, err = a.createUserMessage(ctx, call)
185 if err != nil {
186 return nil, err
187 }
188
189 // add the session to the context
190 ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)
191
192 genCtx, cancel := context.WithCancel(ctx)
193 a.activeRequests.Set(call.SessionID, cancel)
194
195 defer cancel()
196 defer a.activeRequests.Del(call.SessionID)
197
198 history, files := a.preparePrompt(msgs, call.Attachments...)
199
200 startTime := time.Now()
201 a.eventPromptSent(call.SessionID)
202
203 var currentAssistant *message.Message
204 var shouldSummarize bool
205 result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
206 Prompt: call.Prompt,
207 Files: files,
208 Messages: history,
209 ProviderOptions: call.ProviderOptions,
210 MaxOutputTokens: &call.MaxOutputTokens,
211 TopP: call.TopP,
212 Temperature: call.Temperature,
213 PresencePenalty: call.PresencePenalty,
214 TopK: call.TopK,
215 FrequencyPenalty: call.FrequencyPenalty,
216 // Before each step create the new assistant message
217 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
218 prepared.Messages = options.Messages
219 // reset all cached items
220 for i := range prepared.Messages {
221 prepared.Messages[i].ProviderOptions = nil
222 }
223
224 queuedCalls, _ := a.messageQueue.Get(call.SessionID)
225 a.messageQueue.Del(call.SessionID)
226 for _, queued := range queuedCalls {
227 userMessage, createErr := a.createUserMessage(callContext, queued)
228 if createErr != nil {
229 return callContext, prepared, createErr
230 }
231 prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
232 }
233
234 lastSystemRoleInx := 0
235 systemMessageUpdated := false
236 for i, msg := range prepared.Messages {
237 // only add cache control to the last message
238 if msg.Role == fantasy.MessageRoleSystem {
239 lastSystemRoleInx = i
240 } else if !systemMessageUpdated {
241 prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
242 systemMessageUpdated = true
243 }
244 // than add cache control to the last 2 messages
245 if i > len(prepared.Messages)-3 {
246 prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
247 }
248 }
249
250 if a.systemPromptPrefix != "" {
251 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
252 }
253
254 var assistantMsg message.Message
255 assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
256 Role: message.Assistant,
257 Parts: []message.ContentPart{},
258 Model: a.largeModel.ModelCfg.Model,
259 Provider: a.largeModel.ModelCfg.Provider,
260 })
261 if err != nil {
262 return callContext, prepared, err
263 }
264 callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
265 currentAssistant = &assistantMsg
266 return callContext, prepared, err
267 },
268 OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
269 currentAssistant.AppendReasoningContent(reasoning.Text)
270 return a.messages.Update(genCtx, *currentAssistant)
271 },
272 OnReasoningDelta: func(id string, text string) error {
273 currentAssistant.AppendReasoningContent(text)
274 return a.messages.Update(genCtx, *currentAssistant)
275 },
276 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
277 // handle anthropic signature
278 if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
279 if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
280 currentAssistant.AppendReasoningSignature(reasoning.Signature)
281 }
282 }
283 if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
284 if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
285 currentAssistant.AppendReasoningSignature(reasoning.Signature)
286 }
287 }
288 if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
289 if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
290 currentAssistant.SetReasoningResponsesData(reasoning)
291 }
292 }
293 currentAssistant.FinishThinking()
294 return a.messages.Update(genCtx, *currentAssistant)
295 },
296 OnTextDelta: func(id string, text string) error {
297 currentAssistant.AppendContent(text)
298 return a.messages.Update(genCtx, *currentAssistant)
299 },
300 OnToolInputStart: func(id string, toolName string) error {
301 toolCall := message.ToolCall{
302 ID: id,
303 Name: toolName,
304 ProviderExecuted: false,
305 Finished: false,
306 }
307 currentAssistant.AddToolCall(toolCall)
308 return a.messages.Update(genCtx, *currentAssistant)
309 },
310 OnRetry: func(err *fantasy.APICallError, delay time.Duration) {
311 // TODO: implement
312 },
313 OnToolCall: func(tc fantasy.ToolCallContent) error {
314 toolCall := message.ToolCall{
315 ID: tc.ToolCallID,
316 Name: tc.ToolName,
317 Input: tc.Input,
318 ProviderExecuted: false,
319 Finished: true,
320 }
321 currentAssistant.AddToolCall(toolCall)
322 return a.messages.Update(genCtx, *currentAssistant)
323 },
324 OnToolResult: func(result fantasy.ToolResultContent) error {
325 var resultContent string
326 isError := false
327 switch result.Result.GetType() {
328 case fantasy.ToolResultContentTypeText:
329 r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result)
330 if ok {
331 resultContent = r.Text
332 }
333 case fantasy.ToolResultContentTypeError:
334 r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result)
335 if ok {
336 isError = true
337 resultContent = r.Error.Error()
338 }
339 case fantasy.ToolResultContentTypeMedia:
340 // TODO: handle this message type
341 }
342 toolResult := message.ToolResult{
343 ToolCallID: result.ToolCallID,
344 Name: result.ToolName,
345 Content: resultContent,
346 IsError: isError,
347 Metadata: result.ClientMetadata,
348 }
349 _, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
350 Role: message.Tool,
351 Parts: []message.ContentPart{
352 toolResult,
353 },
354 })
355 if createMsgErr != nil {
356 return createMsgErr
357 }
358 return nil
359 },
360 OnStepFinish: func(stepResult fantasy.StepResult) error {
361 finishReason := message.FinishReasonUnknown
362 switch stepResult.FinishReason {
363 case fantasy.FinishReasonLength:
364 finishReason = message.FinishReasonMaxTokens
365 case fantasy.FinishReasonStop:
366 finishReason = message.FinishReasonEndTurn
367 case fantasy.FinishReasonToolCalls:
368 finishReason = message.FinishReasonToolUse
369 }
370 currentAssistant.AddFinish(finishReason, "", "")
371 a.updateSessionUsage(a.largeModel, ¤tSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
372 sessionLock.Lock()
373 _, sessionErr := a.sessions.Save(genCtx, currentSession)
374 sessionLock.Unlock()
375 if sessionErr != nil {
376 return sessionErr
377 }
378 if finishReason == message.FinishReasonEndTurn {
379 a.scheduleCompletionNotification(call.SessionID, currentSession.Title)
380 }
381 return a.messages.Update(genCtx, *currentAssistant)
382 },
383 StopWhen: []fantasy.StopCondition{
384 func(_ []fantasy.StepResult) bool {
385 cw := int64(a.largeModel.CatwalkCfg.ContextWindow)
386 tokens := currentSession.CompletionTokens + currentSession.PromptTokens
387 remaining := cw - tokens
388 var threshold int64
389 if cw > 200_000 {
390 threshold = 20_000
391 } else {
392 threshold = int64(float64(cw) * 0.2)
393 }
394 if (remaining <= threshold) && !a.disableAutoSummarize {
395 shouldSummarize = true
396 return true
397 }
398 return false
399 },
400 },
401 })
402
403 a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))
404
405 if err != nil {
406 isCancelErr := errors.Is(err, context.Canceled)
407 isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
408 if currentAssistant == nil {
409 return result, err
410 }
411 // Ensure we finish thinking on error to close the reasoning state
412 currentAssistant.FinishThinking()
413 toolCalls := currentAssistant.ToolCalls()
414 // INFO: we use the parent context here because the genCtx has been cancelled
415 msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
416 if createErr != nil {
417 return nil, createErr
418 }
419 for _, tc := range toolCalls {
420 if !tc.Finished {
421 tc.Finished = true
422 tc.Input = "{}"
423 currentAssistant.AddToolCall(tc)
424 updateErr := a.messages.Update(ctx, *currentAssistant)
425 if updateErr != nil {
426 return nil, updateErr
427 }
428 }
429
430 found := false
431 for _, msg := range msgs {
432 if msg.Role == message.Tool {
433 for _, tr := range msg.ToolResults() {
434 if tr.ToolCallID == tc.ID {
435 found = true
436 break
437 }
438 }
439 }
440 if found {
441 break
442 }
443 }
444 if found {
445 continue
446 }
447 content := "There was an error while executing the tool"
448 if isCancelErr {
449 content = "Tool execution canceled by user"
450 } else if isPermissionErr {
451 content = "Permission denied"
452 }
453 toolResult := message.ToolResult{
454 ToolCallID: tc.ID,
455 Name: tc.Name,
456 Content: content,
457 IsError: true,
458 }
459 _, createErr = a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
460 Role: message.Tool,
461 Parts: []message.ContentPart{
462 toolResult,
463 },
464 })
465 if createErr != nil {
466 return nil, createErr
467 }
468 }
469 if isCancelErr {
470 currentAssistant.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
471 } else if isPermissionErr {
472 currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "Permission denied", "")
473 } else {
474 currentAssistant.AddFinish(message.FinishReasonError, "API Error", err.Error())
475 }
476 // INFO: we use the parent context here because the genCtx has been cancelled
477 updateErr := a.messages.Update(ctx, *currentAssistant)
478 if updateErr != nil {
479 return nil, updateErr
480 }
481 return nil, err
482 }
483 wg.Wait()
484
485 if shouldSummarize {
486 a.cancelCompletionNotification(call.SessionID)
487 a.activeRequests.Del(call.SessionID)
488 if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
489 return nil, summarizeErr
490 }
491 // if the agent was not done...
492 if len(currentAssistant.ToolCalls()) > 0 {
493 existing, ok := a.messageQueue.Get(call.SessionID)
494 if !ok {
495 existing = []SessionAgentCall{}
496 }
497 call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
498 existing = append(existing, call)
499 a.messageQueue.Set(call.SessionID, existing)
500 }
501 }
502
503 // release active request before processing queued messages
504 a.activeRequests.Del(call.SessionID)
505 cancel()
506
507 queuedMessages, ok := a.messageQueue.Get(call.SessionID)
508 if !ok || len(queuedMessages) == 0 {
509 return result, err
510 }
511 // there are queued messages restart the loop
512 firstQueuedMessage := queuedMessages[0]
513 a.messageQueue.Set(call.SessionID, queuedMessages[1:])
514 return a.Run(ctx, firstQueuedMessage)
515}
516
517func (a *sessionAgent) scheduleCompletionNotification(sessionID, sessionTitle string) {
518 if a.notifier == nil {
519 return
520 }
521
522 if sessionTitle == "" {
523 sessionTitle = sessionID
524 }
525
526 if cancel, ok := a.completionCancels.Take(sessionID); ok && cancel != nil {
527 cancel()
528 }
529
530 title := "💘 Crush is waiting"
531 message := fmt.Sprintf("Agent's turn completed in session \"%s\"", sessionTitle)
532 cancel := a.notifier.NotifyTaskComplete(a.notifyCtx, title, message, completionNotificationDelay)
533 if cancel == nil {
534 cancel = func() {}
535 }
536 a.completionCancels.Set(sessionID, cancel)
537}
538
539func (a *sessionAgent) cancelCompletionNotification(sessionID string) {
540 if a.notifier == nil {
541 return
542 }
543
544 if cancel, ok := a.completionCancels.Take(sessionID); ok && cancel != nil {
545 cancel()
546 }
547}
548
549func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
550 if a.IsSessionBusy(sessionID) {
551 return ErrSessionBusy
552 }
553
554 currentSession, err := a.sessions.Get(ctx, sessionID)
555 if err != nil {
556 return fmt.Errorf("failed to get session: %w", err)
557 }
558 msgs, err := a.getSessionMessages(ctx, currentSession)
559 if err != nil {
560 return err
561 }
562 if len(msgs) == 0 {
563 // nothing to summarize
564 return nil
565 }
566
567 aiMsgs, _ := a.preparePrompt(msgs)
568
569 genCtx, cancel := context.WithCancel(ctx)
570 a.activeRequests.Set(sessionID, cancel)
571 defer a.activeRequests.Del(sessionID)
572 defer cancel()
573
574 agent := fantasy.NewAgent(a.largeModel.Model,
575 fantasy.WithSystemPrompt(string(summaryPrompt)),
576 )
577 summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
578 Role: message.Assistant,
579 Model: a.largeModel.Model.Model(),
580 Provider: a.largeModel.Model.Provider(),
581 IsSummaryMessage: true,
582 })
583 if err != nil {
584 return err
585 }
586
587 resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
588 Prompt: "Provide a detailed summary of our conversation above.",
589 Messages: aiMsgs,
590 ProviderOptions: opts,
591 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
592 prepared.Messages = options.Messages
593 if a.systemPromptPrefix != "" {
594 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
595 }
596 return callContext, prepared, nil
597 },
598 OnReasoningDelta: func(id string, text string) error {
599 summaryMessage.AppendReasoningContent(text)
600 return a.messages.Update(genCtx, summaryMessage)
601 },
602 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
603 // handle anthropic signature
604 if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
605 if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
606 summaryMessage.AppendReasoningSignature(signature.Signature)
607 }
608 }
609 summaryMessage.FinishThinking()
610 return a.messages.Update(genCtx, summaryMessage)
611 },
612 OnTextDelta: func(id, text string) error {
613 summaryMessage.AppendContent(text)
614 return a.messages.Update(genCtx, summaryMessage)
615 },
616 })
617 if err != nil {
618 isCancelErr := errors.Is(err, context.Canceled)
619 if isCancelErr {
620 // User cancelled summarize we need to remove the summary message
621 deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
622 return deleteErr
623 }
624 return err
625 }
626
627 summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
628 err = a.messages.Update(genCtx, summaryMessage)
629 if err != nil {
630 return err
631 }
632
633 var openrouterCost *float64
634 for _, step := range resp.Steps {
635 stepCost := a.openrouterCost(step.ProviderMetadata)
636 if stepCost != nil {
637 newCost := *stepCost
638 if openrouterCost != nil {
639 newCost += *openrouterCost
640 }
641 openrouterCost = &newCost
642 }
643 }
644
645 a.updateSessionUsage(a.largeModel, ¤tSession, resp.TotalUsage, openrouterCost)
646
647 // just in case get just the last usage
648 usage := resp.Response.Usage
649 currentSession.SummaryMessageID = summaryMessage.ID
650 currentSession.CompletionTokens = usage.OutputTokens
651 currentSession.PromptTokens = 0
652 _, err = a.sessions.Save(genCtx, currentSession)
653 return err
654}
655
656func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
657 if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
658 return fantasy.ProviderOptions{}
659 }
660 return fantasy.ProviderOptions{
661 anthropic.Name: &anthropic.ProviderCacheControlOptions{
662 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
663 },
664 bedrock.Name: &anthropic.ProviderCacheControlOptions{
665 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
666 },
667 }
668}
669
670func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
671 var attachmentParts []message.ContentPart
672 for _, attachment := range call.Attachments {
673 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
674 }
675 parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
676 parts = append(parts, attachmentParts...)
677 msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
678 Role: message.User,
679 Parts: parts,
680 })
681 if err != nil {
682 return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
683 }
684 return msg, nil
685}
686
687func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
688 var history []fantasy.Message
689 for _, m := range msgs {
690 if len(m.Parts) == 0 {
691 continue
692 }
693 // Assistant message without content or tool calls (cancelled before it returned anything)
694 if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
695 continue
696 }
697 history = append(history, m.ToAIMessage()...)
698 }
699
700 var files []fantasy.FilePart
701 for _, attachment := range attachments {
702 files = append(files, fantasy.FilePart{
703 Filename: attachment.FileName,
704 Data: attachment.Content,
705 MediaType: attachment.MimeType,
706 })
707 }
708
709 return history, files
710}
711
712func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
713 msgs, err := a.messages.List(ctx, session.ID)
714 if err != nil {
715 return nil, fmt.Errorf("failed to list messages: %w", err)
716 }
717
718 if session.SummaryMessageID != "" {
719 summaryMsgInex := -1
720 for i, msg := range msgs {
721 if msg.ID == session.SummaryMessageID {
722 summaryMsgInex = i
723 break
724 }
725 }
726 if summaryMsgInex != -1 {
727 msgs = msgs[summaryMsgInex:]
728 msgs[0].Role = message.User
729 }
730 }
731 return msgs, nil
732}
733
734func (a *sessionAgent) generateTitle(ctx context.Context, session *session.Session, prompt string) {
735 if prompt == "" {
736 return
737 }
738
739 var maxOutput int64 = 40
740 if a.smallModel.CatwalkCfg.CanReason {
741 maxOutput = a.smallModel.CatwalkCfg.DefaultMaxTokens
742 }
743
744 agent := fantasy.NewAgent(a.smallModel.Model,
745 fantasy.WithSystemPrompt(string(titlePrompt)+"\n /no_think"),
746 fantasy.WithMaxOutputTokens(maxOutput),
747 )
748
749 resp, err := agent.Stream(ctx, fantasy.AgentStreamCall{
750 Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", prompt),
751 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
752 prepared.Messages = options.Messages
753 if a.systemPromptPrefix != "" {
754 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
755 }
756 return callContext, prepared, nil
757 },
758 })
759 if err != nil {
760 slog.Error("error generating title", "err", err)
761 return
762 }
763
764 title := resp.Response.Content.Text()
765
766 title = strings.ReplaceAll(title, "\n", " ")
767
768 // remove thinking tags if present
769 if idx := strings.Index(title, "</think>"); idx > 0 {
770 title = title[idx+len("</think>"):]
771 }
772
773 title = strings.TrimSpace(title)
774 if title == "" {
775 slog.Warn("failed to generate title", "warn", "empty title")
776 return
777 }
778
779 session.Title = title
780
781 var openrouterCost *float64
782 for _, step := range resp.Steps {
783 stepCost := a.openrouterCost(step.ProviderMetadata)
784 if stepCost != nil {
785 newCost := *stepCost
786 if openrouterCost != nil {
787 newCost += *openrouterCost
788 }
789 openrouterCost = &newCost
790 }
791 }
792
793 a.updateSessionUsage(a.smallModel, session, resp.TotalUsage, openrouterCost)
794 _, saveErr := a.sessions.Save(ctx, *session)
795 if saveErr != nil {
796 slog.Error("failed to save session title & usage", "error", saveErr)
797 return
798 }
799}
800
801func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
802 openrouterMetadata, ok := metadata[openrouter.Name]
803 if !ok {
804 return nil
805 }
806
807 opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
808 if !ok {
809 return nil
810 }
811 return &opts.Usage.Cost
812}
813
814func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
815 modelConfig := model.CatwalkCfg
816 cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
817 modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
818 modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
819 modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
820
821 a.eventTokensUsed(session.ID, model, usage, cost)
822
823 if overrideCost != nil {
824 session.Cost += *overrideCost
825 } else {
826 session.Cost += cost
827 }
828
829 session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
830 session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
831}
832
833func (a *sessionAgent) Cancel(sessionID string) {
834 // Cancel regular requests
835 if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
836 slog.Info("Request cancellation initiated", "session_id", sessionID)
837 cancel()
838 }
839
840 // Also check for summarize requests
841 if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
842 slog.Info("Summarize cancellation initiated", "session_id", sessionID)
843 cancel()
844 }
845
846 if a.QueuedPrompts(sessionID) > 0 {
847 slog.Info("Clearing queued prompts", "session_id", sessionID)
848 a.messageQueue.Del(sessionID)
849 }
850}
851
852func (a *sessionAgent) ClearQueue(sessionID string) {
853 if a.QueuedPrompts(sessionID) > 0 {
854 slog.Info("Clearing queued prompts", "session_id", sessionID)
855 a.messageQueue.Del(sessionID)
856 }
857}
858
859func (a *sessionAgent) CancelAll() {
860 if !a.IsBusy() {
861 // still ensure notifications are cancelled even when not busy
862 for cancel := range a.completionCancels.Seq() {
863 if cancel != nil {
864 cancel()
865 }
866 }
867 a.completionCancels.Reset(make(map[string]context.CancelFunc))
868 return
869 }
870 for key := range a.activeRequests.Seq2() {
871 a.Cancel(key) // key is sessionID
872 }
873
874 timeout := time.After(5 * time.Second)
875 for a.IsBusy() {
876 select {
877 case <-timeout:
878 return
879 default:
880 time.Sleep(200 * time.Millisecond)
881 }
882 }
883
884 for cancel := range a.completionCancels.Seq() {
885 if cancel != nil {
886 cancel()
887 }
888 }
889 a.completionCancels.Reset(make(map[string]context.CancelFunc))
890}
891
892func (a *sessionAgent) IsBusy() bool {
893 var busy bool
894 for cancelFunc := range a.activeRequests.Seq() {
895 if cancelFunc != nil {
896 busy = true
897 break
898 }
899 }
900 return busy
901}
902
903func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
904 _, busy := a.activeRequests.Get(sessionID)
905 return busy
906}
907
908func (a *sessionAgent) QueuedPrompts(sessionID string) int {
909 l, ok := a.messageQueue.Get(sessionID)
910 if !ok {
911 return 0
912 }
913 return len(l)
914}
915
916func (a *sessionAgent) SetModels(large Model, small Model) {
917 a.largeModel = large
918 a.smallModel = small
919}
920
// SetTools replaces the agent's tool set. The slice is stored as given, not
// copied, and the field is assigned without any locking.
func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
	a.tools = tools
}
924
// Model returns the agent's large model.
func (a *sessionAgent) Model() Model {
	return a.largeModel
}