1package agent
2
3import (
4 "context"
5 _ "embed"
6 "errors"
7 "fmt"
8 "log/slog"
9 "os"
10 "strconv"
11 "strings"
12 "sync"
13 "time"
14
15 "charm.land/fantasy"
16 "charm.land/fantasy/providers/anthropic"
17 "charm.land/fantasy/providers/bedrock"
18 "charm.land/fantasy/providers/google"
19 "charm.land/fantasy/providers/openai"
20 "charm.land/fantasy/providers/openrouter"
21 "github.com/charmbracelet/catwalk/pkg/catwalk"
22 "git.secluded.site/crush/internal/agent/tools"
23 "git.secluded.site/crush/internal/config"
24 "git.secluded.site/crush/internal/csync"
25 "git.secluded.site/crush/internal/message"
26 "git.secluded.site/crush/internal/notification"
27 "git.secluded.site/crush/internal/permission"
28 "git.secluded.site/crush/internal/session"
29)
30
// titlePrompt holds the embedded contents of templates/title.md. It is used
// as the base system prompt when generating a session title.
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt holds the embedded contents of templates/summary.md. It is
// used as the system prompt when summarizing a session.
//
//go:embed templates/summary.md
var summaryPrompt []byte
36
// SessionAgentCall describes a single prompt submitted to a SessionAgent.
type SessionAgentCall struct {
	// SessionID identifies the session the prompt belongs to. Run rejects an
	// empty value with ErrSessionMissing.
	SessionID string
	// Prompt is the user's text. Run rejects an empty value with
	// ErrEmptyPrompt.
	Prompt string
	// ProviderOptions is forwarded unchanged to the model stream call.
	ProviderOptions fantasy.ProviderOptions
	// Attachments are stored on the user message and sent to the model as
	// file parts.
	Attachments []message.Attachment
	// MaxOutputTokens is forwarded to the model stream call.
	MaxOutputTokens int64
	// The remaining fields are optional sampling parameters that are
	// forwarded to the model stream call as-is.
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
}
49
// SessionAgent runs prompts against sessions and manages their lifecycle:
// queuing, cancellation, summarization and turn-end notifications.
type SessionAgent interface {
	// Run submits the prompt described by the call to its session.
	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
	// SetModels replaces the large and small models used by the agent.
	SetModels(large Model, small Model)
	// SetTools replaces the set of tools offered to the model.
	SetTools(tools []fantasy.AgentTool)
	// Cancel cancels the active request for sessionID, if any.
	Cancel(sessionID string)
	// CancelAll cancels the active requests of every session.
	CancelAll()
	// IsSessionBusy reports whether sessionID has an active request.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any session has an active request.
	IsBusy() bool
	// QueuedPrompts returns the number of prompts queued for sessionID.
	QueuedPrompts(sessionID string) int
	// ClearQueue drops all prompts queued for sessionID.
	ClearQueue(sessionID string)
	// Summarize replaces the conversation of the given session ID with a
	// model-generated summary.
	Summarize(context.Context, string, fantasy.ProviderOptions) error
	// Model returns the agent's primary (large) model.
	Model() Model
	// CancelCompletionNotification cancels any scheduled "turn ended"
	// notification for the provided sessionID.
	CancelCompletionNotification(sessionID string)
	// HasPendingCompletionNotification returns true if a turn-end
	// notification has been scheduled and not yet cancelled/shown.
	HasPendingCompletionNotification(sessionID string) bool
}
69
// completionNotificationDelay is the delay handed to the notifier when a
// turn-complete notification is scheduled.
const completionNotificationDelay = 5 * time.Second
71
// Model bundles a language model with the metadata and configuration the
// agent needs to use it.
type Model struct {
	// Model is the language model that requests are sent to.
	Model fantasy.LanguageModel
	// CatwalkCfg carries the model's catalog metadata; this file reads its
	// context window, pricing, reasoning flag and default max tokens.
	CatwalkCfg catwalk.Model
	// ModelCfg is the selected-model configuration; its Model and Provider
	// values are recorded on assistant messages.
	ModelCfg config.SelectedModel
}
77
// sessionAgent is the SessionAgent implementation returned by
// NewSessionAgent.
type sessionAgent struct {
	// largeModel drives regular turns and summaries; smallModel generates
	// session titles.
	largeModel Model
	smallModel Model
	// systemPromptPrefix, when non-empty, is prepended as an extra system
	// message to every model request.
	systemPromptPrefix string
	// systemPrompt is the system prompt used for regular turns.
	systemPrompt string
	// tools are the tools offered to the model during regular turns.
	tools []fantasy.AgentTool
	// sessions and messages persist session state and conversation messages.
	sessions session.Service
	messages message.Service
	// disableAutoSummarize turns off the automatic summarize-on-low-context
	// stop condition in Run.
	disableAutoSummarize bool
	// isYolo is copied from SessionAgentOptions; it is not read in this file.
	isYolo bool

	// messageQueue holds, per session ID, prompts that arrived while the
	// session was busy.
	messageQueue *csync.Map[string, []SessionAgentCall]
	// activeRequests maps a session ID to the cancel func of its in-flight
	// request; presence of a key means the session is busy.
	activeRequests *csync.Map[string, context.CancelFunc]
	// notifier schedules turn-complete notifications; nil disables them.
	notifier *notification.Notifier
	// notifyCtx is the context passed to the notifier.
	notifyCtx context.Context
	// completionCancels maps a session ID to the cancel func of its
	// scheduled turn-complete notification.
	completionCancels *csync.Map[string, context.CancelFunc]
}
95
// SessionAgentOptions carries the dependencies and settings consumed by
// NewSessionAgent.
type SessionAgentOptions struct {
	// LargeModel and SmallModel are the initial models of the agent.
	LargeModel Model
	SmallModel Model
	// SystemPromptPrefix, when non-empty, is sent as an additional leading
	// system message on every request.
	SystemPromptPrefix string
	// SystemPrompt is the system prompt for regular turns.
	SystemPrompt string
	// DisableAutoSummarize turns off automatic summarization when the
	// context window is nearly exhausted.
	DisableAutoSummarize bool
	// IsYolo is stored on the agent as-is.
	IsYolo bool
	// Sessions and Messages are the persistence services used by the agent.
	Sessions session.Service
	Messages message.Service
	// Tools are the tools offered to the model.
	Tools []fantasy.AgentTool
	// Notifier schedules turn-complete notifications; nil disables them.
	Notifier *notification.Notifier
	// NotificationCtx is the context passed to the notifier;
	// context.Background() is used when it is nil.
	NotificationCtx context.Context
}
109
110func NewSessionAgent(
111 opts SessionAgentOptions,
112) SessionAgent {
113 notifyCtx := opts.NotificationCtx
114 if notifyCtx == nil {
115 notifyCtx = context.Background()
116 }
117
118 return &sessionAgent{
119 largeModel: opts.LargeModel,
120 smallModel: opts.SmallModel,
121 systemPromptPrefix: opts.SystemPromptPrefix,
122 systemPrompt: opts.SystemPrompt,
123 sessions: opts.Sessions,
124 messages: opts.Messages,
125 disableAutoSummarize: opts.DisableAutoSummarize,
126 tools: opts.Tools,
127 isYolo: opts.IsYolo,
128 messageQueue: csync.NewMap[string, []SessionAgentCall](),
129 activeRequests: csync.NewMap[string, context.CancelFunc](),
130 notifier: opts.Notifier,
131 notifyCtx: notifyCtx,
132 completionCancels: csync.NewMap[string, context.CancelFunc](),
133 }
134}
135
136func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
137 if call.Prompt == "" {
138 return nil, ErrEmptyPrompt
139 }
140 if call.SessionID == "" {
141 return nil, ErrSessionMissing
142 }
143
144 a.cancelCompletionNotification(call.SessionID)
145
146 // Queue the message if busy
147 if a.IsSessionBusy(call.SessionID) {
148 existing, ok := a.messageQueue.Get(call.SessionID)
149 if !ok {
150 existing = []SessionAgentCall{}
151 }
152 existing = append(existing, call)
153 a.messageQueue.Set(call.SessionID, existing)
154 return nil, nil
155 }
156
157 if len(a.tools) > 0 {
158 // add anthropic caching to the last tool
159 a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
160 }
161
162 agent := fantasy.NewAgent(
163 a.largeModel.Model,
164 fantasy.WithSystemPrompt(a.systemPrompt),
165 fantasy.WithTools(a.tools...),
166 )
167
168 sessionLock := sync.Mutex{}
169 currentSession, err := a.sessions.Get(ctx, call.SessionID)
170 if err != nil {
171 return nil, fmt.Errorf("failed to get session: %w", err)
172 }
173
174 msgs, err := a.getSessionMessages(ctx, currentSession)
175 if err != nil {
176 return nil, fmt.Errorf("failed to get session messages: %w", err)
177 }
178
179 var wg sync.WaitGroup
180 // Generate title if first message
181 if len(msgs) == 0 {
182 wg.Go(func() {
183 sessionLock.Lock()
184 a.generateTitle(ctx, ¤tSession, call.Prompt)
185 sessionLock.Unlock()
186 })
187 }
188
189 // Add the user message to the session
190 _, err = a.createUserMessage(ctx, call)
191 if err != nil {
192 return nil, err
193 }
194
195 // add the session to the context
196 ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)
197
198 genCtx, cancel := context.WithCancel(ctx)
199 a.activeRequests.Set(call.SessionID, cancel)
200
201 defer cancel()
202 defer a.activeRequests.Del(call.SessionID)
203
204 history, files := a.preparePrompt(msgs, call.Attachments...)
205
206 startTime := time.Now()
207 a.eventPromptSent(call.SessionID)
208
209 var currentAssistant *message.Message
210 var shouldSummarize bool
211 result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
212 Prompt: call.Prompt,
213 Files: files,
214 Messages: history,
215 ProviderOptions: call.ProviderOptions,
216 MaxOutputTokens: &call.MaxOutputTokens,
217 TopP: call.TopP,
218 Temperature: call.Temperature,
219 PresencePenalty: call.PresencePenalty,
220 TopK: call.TopK,
221 FrequencyPenalty: call.FrequencyPenalty,
222 // Before each step create the new assistant message
223 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
224 prepared.Messages = options.Messages
225 // reset all cached items
226 for i := range prepared.Messages {
227 prepared.Messages[i].ProviderOptions = nil
228 }
229
230 queuedCalls, _ := a.messageQueue.Get(call.SessionID)
231 a.messageQueue.Del(call.SessionID)
232 for _, queued := range queuedCalls {
233 userMessage, createErr := a.createUserMessage(callContext, queued)
234 if createErr != nil {
235 return callContext, prepared, createErr
236 }
237 prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
238 }
239
240 lastSystemRoleInx := 0
241 systemMessageUpdated := false
242 for i, msg := range prepared.Messages {
243 // only add cache control to the last message
244 if msg.Role == fantasy.MessageRoleSystem {
245 lastSystemRoleInx = i
246 } else if !systemMessageUpdated {
247 prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
248 systemMessageUpdated = true
249 }
250 // than add cache control to the last 2 messages
251 if i > len(prepared.Messages)-3 {
252 prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
253 }
254 }
255
256 if a.systemPromptPrefix != "" {
257 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
258 }
259
260 var assistantMsg message.Message
261 assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
262 Role: message.Assistant,
263 Parts: []message.ContentPart{},
264 Model: a.largeModel.ModelCfg.Model,
265 Provider: a.largeModel.ModelCfg.Provider,
266 })
267 if err != nil {
268 return callContext, prepared, err
269 }
270 callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
271 currentAssistant = &assistantMsg
272 return callContext, prepared, err
273 },
274 OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
275 currentAssistant.AppendReasoningContent(reasoning.Text)
276 return a.messages.Update(genCtx, *currentAssistant)
277 },
278 OnReasoningDelta: func(id string, text string) error {
279 currentAssistant.AppendReasoningContent(text)
280 return a.messages.Update(genCtx, *currentAssistant)
281 },
282 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
283 // handle anthropic signature
284 if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
285 if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
286 currentAssistant.AppendReasoningSignature(reasoning.Signature)
287 }
288 }
289 if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
290 if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
291 currentAssistant.AppendReasoningSignature(reasoning.Signature)
292 }
293 }
294 if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
295 if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
296 currentAssistant.SetReasoningResponsesData(reasoning)
297 }
298 }
299 currentAssistant.FinishThinking()
300 return a.messages.Update(genCtx, *currentAssistant)
301 },
302 OnTextDelta: func(id string, text string) error {
303 currentAssistant.AppendContent(text)
304 return a.messages.Update(genCtx, *currentAssistant)
305 },
306 OnToolInputStart: func(id string, toolName string) error {
307 toolCall := message.ToolCall{
308 ID: id,
309 Name: toolName,
310 ProviderExecuted: false,
311 Finished: false,
312 }
313 currentAssistant.AddToolCall(toolCall)
314 return a.messages.Update(genCtx, *currentAssistant)
315 },
316 OnRetry: func(err *fantasy.APICallError, delay time.Duration) {
317 // TODO: implement
318 },
319 OnToolCall: func(tc fantasy.ToolCallContent) error {
320 toolCall := message.ToolCall{
321 ID: tc.ToolCallID,
322 Name: tc.ToolName,
323 Input: tc.Input,
324 ProviderExecuted: false,
325 Finished: true,
326 }
327 currentAssistant.AddToolCall(toolCall)
328 return a.messages.Update(genCtx, *currentAssistant)
329 },
330 OnToolResult: func(result fantasy.ToolResultContent) error {
331 var resultContent string
332 isError := false
333 switch result.Result.GetType() {
334 case fantasy.ToolResultContentTypeText:
335 r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result)
336 if ok {
337 resultContent = r.Text
338 }
339 case fantasy.ToolResultContentTypeError:
340 r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result)
341 if ok {
342 isError = true
343 resultContent = r.Error.Error()
344 }
345 case fantasy.ToolResultContentTypeMedia:
346 // TODO: handle this message type
347 }
348 toolResult := message.ToolResult{
349 ToolCallID: result.ToolCallID,
350 Name: result.ToolName,
351 Content: resultContent,
352 IsError: isError,
353 Metadata: result.ClientMetadata,
354 }
355 _, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
356 Role: message.Tool,
357 Parts: []message.ContentPart{
358 toolResult,
359 },
360 })
361 if createMsgErr != nil {
362 return createMsgErr
363 }
364 return nil
365 },
366 OnStepFinish: func(stepResult fantasy.StepResult) error {
367 finishReason := message.FinishReasonUnknown
368 switch stepResult.FinishReason {
369 case fantasy.FinishReasonLength:
370 finishReason = message.FinishReasonMaxTokens
371 case fantasy.FinishReasonStop:
372 finishReason = message.FinishReasonEndTurn
373 case fantasy.FinishReasonToolCalls:
374 finishReason = message.FinishReasonToolUse
375 }
376 currentAssistant.AddFinish(finishReason, "", "")
377 a.updateSessionUsage(a.largeModel, ¤tSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
378 sessionLock.Lock()
379 _, sessionErr := a.sessions.Save(genCtx, currentSession)
380 sessionLock.Unlock()
381 if sessionErr != nil {
382 return sessionErr
383 }
384 if finishReason == message.FinishReasonEndTurn {
385 a.scheduleCompletionNotification(call.SessionID, currentSession.Title)
386 }
387 return a.messages.Update(genCtx, *currentAssistant)
388 },
389 StopWhen: []fantasy.StopCondition{
390 func(_ []fantasy.StepResult) bool {
391 cw := int64(a.largeModel.CatwalkCfg.ContextWindow)
392 tokens := currentSession.CompletionTokens + currentSession.PromptTokens
393 remaining := cw - tokens
394 var threshold int64
395 if cw > 200_000 {
396 threshold = 20_000
397 } else {
398 threshold = int64(float64(cw) * 0.2)
399 }
400 if (remaining <= threshold) && !a.disableAutoSummarize {
401 shouldSummarize = true
402 return true
403 }
404 return false
405 },
406 },
407 })
408
409 a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))
410
411 if err != nil {
412 isCancelErr := errors.Is(err, context.Canceled)
413 isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
414 if currentAssistant == nil {
415 return result, err
416 }
417 // Ensure we finish thinking on error to close the reasoning state
418 currentAssistant.FinishThinking()
419 toolCalls := currentAssistant.ToolCalls()
420 // INFO: we use the parent context here because the genCtx has been cancelled
421 msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
422 if createErr != nil {
423 return nil, createErr
424 }
425 for _, tc := range toolCalls {
426 if !tc.Finished {
427 tc.Finished = true
428 tc.Input = "{}"
429 currentAssistant.AddToolCall(tc)
430 updateErr := a.messages.Update(ctx, *currentAssistant)
431 if updateErr != nil {
432 return nil, updateErr
433 }
434 }
435
436 found := false
437 for _, msg := range msgs {
438 if msg.Role == message.Tool {
439 for _, tr := range msg.ToolResults() {
440 if tr.ToolCallID == tc.ID {
441 found = true
442 break
443 }
444 }
445 }
446 if found {
447 break
448 }
449 }
450 if found {
451 continue
452 }
453 content := "There was an error while executing the tool"
454 if isCancelErr {
455 content = "Tool execution canceled by user"
456 } else if isPermissionErr {
457 content = "User denied permission"
458 }
459 toolResult := message.ToolResult{
460 ToolCallID: tc.ID,
461 Name: tc.Name,
462 Content: content,
463 IsError: true,
464 }
465 _, createErr = a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
466 Role: message.Tool,
467 Parts: []message.ContentPart{
468 toolResult,
469 },
470 })
471 if createErr != nil {
472 return nil, createErr
473 }
474 }
475 if isCancelErr {
476 currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
477 } else if isPermissionErr {
478 currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "")
479 } else {
480 currentAssistant.AddFinish(message.FinishReasonError, "API Error", err.Error())
481 }
482 // INFO: we use the parent context here because the genCtx has been cancelled
483 updateErr := a.messages.Update(ctx, *currentAssistant)
484 if updateErr != nil {
485 return nil, updateErr
486 }
487 return nil, err
488 }
489 wg.Wait()
490
491 if shouldSummarize {
492 a.cancelCompletionNotification(call.SessionID)
493 a.activeRequests.Del(call.SessionID)
494 if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
495 return nil, summarizeErr
496 }
497 // if the agent was not done...
498 if len(currentAssistant.ToolCalls()) > 0 {
499 existing, ok := a.messageQueue.Get(call.SessionID)
500 if !ok {
501 existing = []SessionAgentCall{}
502 }
503 call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
504 existing = append(existing, call)
505 a.messageQueue.Set(call.SessionID, existing)
506 }
507 }
508
509 // release active request before processing queued messages
510 a.activeRequests.Del(call.SessionID)
511 cancel()
512
513 queuedMessages, ok := a.messageQueue.Get(call.SessionID)
514 if !ok || len(queuedMessages) == 0 {
515 return result, err
516 }
517 // there are queued messages restart the loop
518 firstQueuedMessage := queuedMessages[0]
519 a.messageQueue.Set(call.SessionID, queuedMessages[1:])
520 return a.Run(ctx, firstQueuedMessage)
521}
522
523func (a *sessionAgent) scheduleCompletionNotification(sessionID, sessionTitle string) {
524 // Do not emit notifications for Agent-tool sub-sessions.
525 if a.sessions != nil && a.sessions.IsAgentToolSession(sessionID) {
526 return
527 }
528 if a.notifier == nil {
529 return
530 }
531
532 if sessionTitle == "" {
533 sessionTitle = sessionID
534 }
535
536 if cancel, ok := a.completionCancels.Take(sessionID); ok && cancel != nil {
537 cancel()
538 }
539
540 title := "💘 Crush is waiting"
541 message := fmt.Sprintf("Agent's turn completed in session \"%s\"", sessionTitle)
542 cancel := a.notifier.NotifyTaskComplete(a.notifyCtx, title, message, completionNotificationDelay)
543 if cancel == nil {
544 cancel = func() {}
545 }
546 a.completionCancels.Set(sessionID, cancel)
547}
548
549func (a *sessionAgent) cancelCompletionNotification(sessionID string) {
550 if a.notifier == nil {
551 return
552 }
553
554 if cancel, ok := a.completionCancels.Take(sessionID); ok && cancel != nil {
555 cancel()
556 }
557}
558
// CancelCompletionNotification implements SessionAgent. It cancels the
// turn-complete notification scheduled for sessionID, if any, by delegating
// to cancelCompletionNotification.
func (a *sessionAgent) CancelCompletionNotification(sessionID string) {
	a.cancelCompletionNotification(sessionID)
}
563
564// HasPendingCompletionNotification implements SessionAgent.
565func (a *sessionAgent) HasPendingCompletionNotification(sessionID string) bool {
566 if a.IsSessionBusy(sessionID) {
567 return false
568 }
569 _, ok := a.completionCancels.Get(sessionID)
570 return ok
571}
572
573func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
574 if a.IsSessionBusy(sessionID) {
575 return ErrSessionBusy
576 }
577
578 currentSession, err := a.sessions.Get(ctx, sessionID)
579 if err != nil {
580 return fmt.Errorf("failed to get session: %w", err)
581 }
582 msgs, err := a.getSessionMessages(ctx, currentSession)
583 if err != nil {
584 return err
585 }
586 if len(msgs) == 0 {
587 // nothing to summarize
588 return nil
589 }
590
591 aiMsgs, _ := a.preparePrompt(msgs)
592
593 genCtx, cancel := context.WithCancel(ctx)
594 a.activeRequests.Set(sessionID, cancel)
595 defer a.activeRequests.Del(sessionID)
596 defer cancel()
597
598 agent := fantasy.NewAgent(a.largeModel.Model,
599 fantasy.WithSystemPrompt(string(summaryPrompt)),
600 )
601 summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
602 Role: message.Assistant,
603 Model: a.largeModel.Model.Model(),
604 Provider: a.largeModel.Model.Provider(),
605 IsSummaryMessage: true,
606 })
607 if err != nil {
608 return err
609 }
610
611 resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
612 Prompt: "Provide a detailed summary of our conversation above.",
613 Messages: aiMsgs,
614 ProviderOptions: opts,
615 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
616 prepared.Messages = options.Messages
617 if a.systemPromptPrefix != "" {
618 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
619 }
620 return callContext, prepared, nil
621 },
622 OnReasoningDelta: func(id string, text string) error {
623 summaryMessage.AppendReasoningContent(text)
624 return a.messages.Update(genCtx, summaryMessage)
625 },
626 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
627 // handle anthropic signature
628 if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
629 if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
630 summaryMessage.AppendReasoningSignature(signature.Signature)
631 }
632 }
633 summaryMessage.FinishThinking()
634 return a.messages.Update(genCtx, summaryMessage)
635 },
636 OnTextDelta: func(id, text string) error {
637 summaryMessage.AppendContent(text)
638 return a.messages.Update(genCtx, summaryMessage)
639 },
640 })
641 if err != nil {
642 isCancelErr := errors.Is(err, context.Canceled)
643 if isCancelErr {
644 // User cancelled summarize we need to remove the summary message
645 deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
646 return deleteErr
647 }
648 return err
649 }
650
651 summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
652 err = a.messages.Update(genCtx, summaryMessage)
653 if err != nil {
654 return err
655 }
656
657 var openrouterCost *float64
658 for _, step := range resp.Steps {
659 stepCost := a.openrouterCost(step.ProviderMetadata)
660 if stepCost != nil {
661 newCost := *stepCost
662 if openrouterCost != nil {
663 newCost += *openrouterCost
664 }
665 openrouterCost = &newCost
666 }
667 }
668
669 a.updateSessionUsage(a.largeModel, ¤tSession, resp.TotalUsage, openrouterCost)
670
671 // just in case get just the last usage
672 usage := resp.Response.Usage
673 currentSession.SummaryMessageID = summaryMessage.ID
674 currentSession.CompletionTokens = usage.OutputTokens
675 currentSession.PromptTokens = 0
676 _, err = a.sessions.Save(genCtx, currentSession)
677 return err
678}
679
680func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
681 if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
682 return fantasy.ProviderOptions{}
683 }
684 return fantasy.ProviderOptions{
685 anthropic.Name: &anthropic.ProviderCacheControlOptions{
686 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
687 },
688 bedrock.Name: &anthropic.ProviderCacheControlOptions{
689 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
690 },
691 }
692}
693
694func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
695 var attachmentParts []message.ContentPart
696 for _, attachment := range call.Attachments {
697 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
698 }
699 parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
700 parts = append(parts, attachmentParts...)
701 msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
702 Role: message.User,
703 Parts: parts,
704 })
705 if err != nil {
706 return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
707 }
708 return msg, nil
709}
710
711func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
712 var history []fantasy.Message
713 for _, m := range msgs {
714 if len(m.Parts) == 0 {
715 continue
716 }
717 // Assistant message without content or tool calls (cancelled before it returned anything)
718 if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
719 continue
720 }
721 history = append(history, m.ToAIMessage()...)
722 }
723
724 var files []fantasy.FilePart
725 for _, attachment := range attachments {
726 files = append(files, fantasy.FilePart{
727 Filename: attachment.FileName,
728 Data: attachment.Content,
729 MediaType: attachment.MimeType,
730 })
731 }
732
733 return history, files
734}
735
736func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
737 msgs, err := a.messages.List(ctx, session.ID)
738 if err != nil {
739 return nil, fmt.Errorf("failed to list messages: %w", err)
740 }
741
742 if session.SummaryMessageID != "" {
743 summaryMsgInex := -1
744 for i, msg := range msgs {
745 if msg.ID == session.SummaryMessageID {
746 summaryMsgInex = i
747 break
748 }
749 }
750 if summaryMsgInex != -1 {
751 msgs = msgs[summaryMsgInex:]
752 msgs[0].Role = message.User
753 }
754 }
755 return msgs, nil
756}
757
758func (a *sessionAgent) generateTitle(ctx context.Context, session *session.Session, prompt string) {
759 if prompt == "" {
760 return
761 }
762
763 var maxOutput int64 = 40
764 if a.smallModel.CatwalkCfg.CanReason {
765 maxOutput = a.smallModel.CatwalkCfg.DefaultMaxTokens
766 }
767
768 agent := fantasy.NewAgent(a.smallModel.Model,
769 fantasy.WithSystemPrompt(string(titlePrompt)+"\n /no_think"),
770 fantasy.WithMaxOutputTokens(maxOutput),
771 )
772
773 resp, err := agent.Stream(ctx, fantasy.AgentStreamCall{
774 Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", prompt),
775 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
776 prepared.Messages = options.Messages
777 if a.systemPromptPrefix != "" {
778 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
779 }
780 return callContext, prepared, nil
781 },
782 })
783 if err != nil {
784 slog.Error("error generating title", "err", err)
785 return
786 }
787
788 title := resp.Response.Content.Text()
789
790 title = strings.ReplaceAll(title, "\n", " ")
791
792 // remove thinking tags if present
793 if idx := strings.Index(title, "</think>"); idx > 0 {
794 title = title[idx+len("</think>"):]
795 }
796
797 title = strings.TrimSpace(title)
798 if title == "" {
799 slog.Warn("failed to generate title", "warn", "empty title")
800 return
801 }
802
803 session.Title = title
804
805 var openrouterCost *float64
806 for _, step := range resp.Steps {
807 stepCost := a.openrouterCost(step.ProviderMetadata)
808 if stepCost != nil {
809 newCost := *stepCost
810 if openrouterCost != nil {
811 newCost += *openrouterCost
812 }
813 openrouterCost = &newCost
814 }
815 }
816
817 a.updateSessionUsage(a.smallModel, session, resp.TotalUsage, openrouterCost)
818 _, saveErr := a.sessions.Save(ctx, *session)
819 if saveErr != nil {
820 slog.Error("failed to save session title & usage", "error", saveErr)
821 return
822 }
823}
824
825func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
826 openrouterMetadata, ok := metadata[openrouter.Name]
827 if !ok {
828 return nil
829 }
830
831 opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
832 if !ok {
833 return nil
834 }
835 return &opts.Usage.Cost
836}
837
838func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
839 modelConfig := model.CatwalkCfg
840 cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
841 modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
842 modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
843 modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
844
845 a.eventTokensUsed(session.ID, model, usage, cost)
846
847 if overrideCost != nil {
848 session.Cost += *overrideCost
849 } else {
850 session.Cost += cost
851 }
852
853 session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
854 session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
855}
856
857func (a *sessionAgent) Cancel(sessionID string) {
858 // Cancel regular requests
859 if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
860 slog.Info("Request cancellation initiated", "session_id", sessionID)
861 cancel()
862 }
863
864 // Also check for summarize requests
865 if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
866 slog.Info("Summarize cancellation initiated", "session_id", sessionID)
867 cancel()
868 }
869
870 if a.QueuedPrompts(sessionID) > 0 {
871 slog.Info("Clearing queued prompts", "session_id", sessionID)
872 a.messageQueue.Del(sessionID)
873 }
874}
875
876func (a *sessionAgent) ClearQueue(sessionID string) {
877 if a.QueuedPrompts(sessionID) > 0 {
878 slog.Info("Clearing queued prompts", "session_id", sessionID)
879 a.messageQueue.Del(sessionID)
880 }
881}
882
883func (a *sessionAgent) CancelAll() {
884 if !a.IsBusy() {
885 // still ensure notifications are cancelled even when not busy
886 for cancel := range a.completionCancels.Seq() {
887 if cancel != nil {
888 cancel()
889 }
890 }
891 a.completionCancels.Reset(make(map[string]context.CancelFunc))
892 return
893 }
894 for key := range a.activeRequests.Seq2() {
895 a.Cancel(key) // key is sessionID
896 }
897
898 timeout := time.After(5 * time.Second)
899 for a.IsBusy() {
900 select {
901 case <-timeout:
902 return
903 default:
904 time.Sleep(200 * time.Millisecond)
905 }
906 }
907
908 for cancel := range a.completionCancels.Seq() {
909 if cancel != nil {
910 cancel()
911 }
912 }
913 a.completionCancels.Reset(make(map[string]context.CancelFunc))
914}
915
916func (a *sessionAgent) IsBusy() bool {
917 var busy bool
918 for cancelFunc := range a.activeRequests.Seq() {
919 if cancelFunc != nil {
920 busy = true
921 break
922 }
923 }
924 return busy
925}
926
927func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
928 _, busy := a.activeRequests.Get(sessionID)
929 return busy
930}
931
932func (a *sessionAgent) QueuedPrompts(sessionID string) int {
933 l, ok := a.messageQueue.Get(sessionID)
934 if !ok {
935 return 0
936 }
937 return len(l)
938}
939
940func (a *sessionAgent) SetModels(large Model, small Model) {
941 a.largeModel = large
942 a.smallModel = small
943}
944
// SetTools replaces the agent's tool set. The slice is stored as given, not
// copied.
func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
	a.tools = tools
}
948
// Model returns the agent's large model.
func (a *sessionAgent) Model() Model {
	return a.largeModel
}