agent.go

  1package agent
  2
  3import (
  4	"context"
  5	_ "embed"
  6	"errors"
  7	"fmt"
  8	"log/slog"
  9	"os"
 10	"strconv"
 11	"strings"
 12	"sync"
 13	"time"
 14
 15	"charm.land/fantasy"
 16	"charm.land/fantasy/providers/anthropic"
 17	"charm.land/fantasy/providers/bedrock"
 18	"charm.land/fantasy/providers/google"
 19	"charm.land/fantasy/providers/openai"
 20	"charm.land/fantasy/providers/openrouter"
 21	"github.com/charmbracelet/catwalk/pkg/catwalk"
 22	"git.secluded.site/crush/internal/agent/tools"
 23	"git.secluded.site/crush/internal/config"
 24	"git.secluded.site/crush/internal/csync"
 25	"git.secluded.site/crush/internal/message"
 26	"git.secluded.site/crush/internal/notification"
 27	"git.secluded.site/crush/internal/permission"
 28	"git.secluded.site/crush/internal/session"
 29)
 30
// titlePrompt holds the embedded contents of templates/title.md. generateTitle
// uses it as the system prompt when asking the small model for a session title.
//
//go:embed templates/title.md
var titlePrompt []byte
 33
// summaryPrompt holds the embedded contents of templates/summary.md. Summarize
// uses it as the system prompt when condensing a session's conversation.
//
//go:embed templates/summary.md
var summaryPrompt []byte
 36
// SessionAgentCall describes one prompt submitted to a session through Run.
//
// SessionID and Prompt are required: Run rejects a call when either is empty.
// Attachments are stored as binary parts on the user message and forwarded to
// the model as files. ProviderOptions, MaxOutputTokens and the sampling fields
// (Temperature, TopP, TopK, FrequencyPenalty, PresencePenalty) are passed
// through to the model stream call unchanged; the pointer fields may be nil.
type SessionAgentCall struct {
	SessionID        string
	Prompt           string
	ProviderOptions  fantasy.ProviderOptions
	Attachments      []message.Attachment
	MaxOutputTokens  int64
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
}
 49
// SessionAgent runs prompts against sessions and exposes control over the
// in-flight requests, per-session prompt queues, and turn-completion
// notifications. The behaviour notes below describe the sessionAgent
// implementation in this file.
type SessionAgent interface {
	// Run submits a prompt to a session. When the session already has a
	// request in flight the call is queued and Run returns (nil, nil).
	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
	// SetModels replaces the large and small models used by the agent.
	SetModels(large Model, small Model)
	// SetTools replaces the tools offered to the model.
	SetTools(tools []fantasy.AgentTool)
	// Cancel cancels the active request for sessionID, if any, and drops
	// the session's queued prompts.
	Cancel(sessionID string)
	// CancelAll cancels every active request and any pending completion
	// notifications.
	CancelAll()
	// IsSessionBusy reports whether sessionID has a request in flight.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any session has a request in flight.
	IsBusy() bool
	// QueuedPrompts returns the number of prompts queued for sessionID.
	QueuedPrompts(sessionID string) int
	// ClearQueue drops the prompts queued for sessionID.
	ClearQueue(sessionID string)
	// Summarize replaces the session's working history with a generated
	// summary message. It returns ErrSessionBusy while the session has a
	// request in flight.
	Summarize(context.Context, string, fantasy.ProviderOptions) error
	// Model returns the large model currently in use.
	Model() Model
	// CancelCompletionNotification cancels any scheduled "turn ended"
	// notification for the provided sessionID.
	CancelCompletionNotification(sessionID string)
	// HasPendingCompletionNotification returns true if a turn-end
	// notification has been scheduled and not yet cancelled/shown.
	HasPendingCompletionNotification(sessionID string) bool
}
 69
// completionNotificationDelay is the delay handed to the notifier when a
// "turn completed" notification is scheduled for a session.
const completionNotificationDelay = 5 * time.Second
 71
// Model bundles a language model with the two configuration records the agent
// reads alongside it: CatwalkCfg supplies the context window, reasoning
// capability, default max tokens and per-token pricing, and ModelCfg supplies
// the model and provider identifiers recorded on assistant messages.
type Model struct {
	Model      fantasy.LanguageModel
	CatwalkCfg catwalk.Model
	ModelCfg   config.SelectedModel
}
 77
// sessionAgent is the SessionAgent implementation returned by NewSessionAgent.
//
// largeModel drives Run and Summarize; smallModel is used only for title
// generation. systemPromptPrefix, when non-empty, is prepended as an extra
// system message to every model call. messageQueue holds prompts submitted
// while a session was busy, activeRequests holds the cancel function of each
// in-flight request (an entry's presence is what marks a session as busy), and
// completionCancels holds the cancel function of each scheduled completion
// notification. notifier may be nil, in which case notifications are disabled;
// notifyCtx is the context passed to the notifier.
type sessionAgent struct {
	largeModel           Model
	smallModel           Model
	systemPromptPrefix   string
	systemPrompt         string
	tools                []fantasy.AgentTool
	sessions             session.Service
	messages             message.Service
	disableAutoSummarize bool
	isYolo               bool

	messageQueue      *csync.Map[string, []SessionAgentCall]
	activeRequests    *csync.Map[string, context.CancelFunc]
	notifier          *notification.Notifier
	notifyCtx         context.Context
	completionCancels *csync.Map[string, context.CancelFunc]
}
 95
// SessionAgentOptions carries the dependencies and settings consumed by
// NewSessionAgent. Notifier may be nil to disable completion notifications,
// and a nil NotificationCtx is replaced with context.Background().
type SessionAgentOptions struct {
	LargeModel           Model
	SmallModel           Model
	SystemPromptPrefix   string
	SystemPrompt         string
	DisableAutoSummarize bool
	IsYolo               bool
	Sessions             session.Service
	Messages             message.Service
	Tools                []fantasy.AgentTool
	Notifier             *notification.Notifier
	NotificationCtx      context.Context
}
109
110func NewSessionAgent(
111	opts SessionAgentOptions,
112) SessionAgent {
113	notifyCtx := opts.NotificationCtx
114	if notifyCtx == nil {
115		notifyCtx = context.Background()
116	}
117
118	return &sessionAgent{
119		largeModel:           opts.LargeModel,
120		smallModel:           opts.SmallModel,
121		systemPromptPrefix:   opts.SystemPromptPrefix,
122		systemPrompt:         opts.SystemPrompt,
123		sessions:             opts.Sessions,
124		messages:             opts.Messages,
125		disableAutoSummarize: opts.DisableAutoSummarize,
126		tools:                opts.Tools,
127		isYolo:               opts.IsYolo,
128		messageQueue:         csync.NewMap[string, []SessionAgentCall](),
129		activeRequests:       csync.NewMap[string, context.CancelFunc](),
130		notifier:             opts.Notifier,
131		notifyCtx:            notifyCtx,
132		completionCancels:    csync.NewMap[string, context.CancelFunc](),
133	}
134}
135
136func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
137	if call.Prompt == "" {
138		return nil, ErrEmptyPrompt
139	}
140	if call.SessionID == "" {
141		return nil, ErrSessionMissing
142	}
143
144	a.cancelCompletionNotification(call.SessionID)
145
146	// Queue the message if busy
147	if a.IsSessionBusy(call.SessionID) {
148		existing, ok := a.messageQueue.Get(call.SessionID)
149		if !ok {
150			existing = []SessionAgentCall{}
151		}
152		existing = append(existing, call)
153		a.messageQueue.Set(call.SessionID, existing)
154		return nil, nil
155	}
156
157	if len(a.tools) > 0 {
158		// add anthropic caching to the last tool
159		a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
160	}
161
162	agent := fantasy.NewAgent(
163		a.largeModel.Model,
164		fantasy.WithSystemPrompt(a.systemPrompt),
165		fantasy.WithTools(a.tools...),
166	)
167
168	sessionLock := sync.Mutex{}
169	currentSession, err := a.sessions.Get(ctx, call.SessionID)
170	if err != nil {
171		return nil, fmt.Errorf("failed to get session: %w", err)
172	}
173
174	msgs, err := a.getSessionMessages(ctx, currentSession)
175	if err != nil {
176		return nil, fmt.Errorf("failed to get session messages: %w", err)
177	}
178
179	var wg sync.WaitGroup
180	// Generate title if first message
181	if len(msgs) == 0 {
182		wg.Go(func() {
183			sessionLock.Lock()
184			a.generateTitle(ctx, &currentSession, call.Prompt)
185			sessionLock.Unlock()
186		})
187	}
188
189	// Add the user message to the session
190	_, err = a.createUserMessage(ctx, call)
191	if err != nil {
192		return nil, err
193	}
194
195	// add the session to the context
196	ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)
197
198	genCtx, cancel := context.WithCancel(ctx)
199	a.activeRequests.Set(call.SessionID, cancel)
200
201	defer cancel()
202	defer a.activeRequests.Del(call.SessionID)
203
204	history, files := a.preparePrompt(msgs, call.Attachments...)
205
206	startTime := time.Now()
207	a.eventPromptSent(call.SessionID)
208
209	var currentAssistant *message.Message
210	var shouldSummarize bool
211	result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
212		Prompt:           call.Prompt,
213		Files:            files,
214		Messages:         history,
215		ProviderOptions:  call.ProviderOptions,
216		MaxOutputTokens:  &call.MaxOutputTokens,
217		TopP:             call.TopP,
218		Temperature:      call.Temperature,
219		PresencePenalty:  call.PresencePenalty,
220		TopK:             call.TopK,
221		FrequencyPenalty: call.FrequencyPenalty,
222		// Before each step create the new assistant message
223		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
224			prepared.Messages = options.Messages
225			// reset all cached items
226			for i := range prepared.Messages {
227				prepared.Messages[i].ProviderOptions = nil
228			}
229
230			queuedCalls, _ := a.messageQueue.Get(call.SessionID)
231			a.messageQueue.Del(call.SessionID)
232			for _, queued := range queuedCalls {
233				userMessage, createErr := a.createUserMessage(callContext, queued)
234				if createErr != nil {
235					return callContext, prepared, createErr
236				}
237				prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
238			}
239
240			lastSystemRoleInx := 0
241			systemMessageUpdated := false
242			for i, msg := range prepared.Messages {
243				// only add cache control to the last message
244				if msg.Role == fantasy.MessageRoleSystem {
245					lastSystemRoleInx = i
246				} else if !systemMessageUpdated {
247					prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
248					systemMessageUpdated = true
249				}
250				// than add cache control to the last 2 messages
251				if i > len(prepared.Messages)-3 {
252					prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
253				}
254			}
255
256			if a.systemPromptPrefix != "" {
257				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
258			}
259
260			var assistantMsg message.Message
261			assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
262				Role:     message.Assistant,
263				Parts:    []message.ContentPart{},
264				Model:    a.largeModel.ModelCfg.Model,
265				Provider: a.largeModel.ModelCfg.Provider,
266			})
267			if err != nil {
268				return callContext, prepared, err
269			}
270			callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
271			currentAssistant = &assistantMsg
272			return callContext, prepared, err
273		},
274		OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
275			currentAssistant.AppendReasoningContent(reasoning.Text)
276			return a.messages.Update(genCtx, *currentAssistant)
277		},
278		OnReasoningDelta: func(id string, text string) error {
279			currentAssistant.AppendReasoningContent(text)
280			return a.messages.Update(genCtx, *currentAssistant)
281		},
282		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
283			// handle anthropic signature
284			if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
285				if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
286					currentAssistant.AppendReasoningSignature(reasoning.Signature)
287				}
288			}
289			if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
290				if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
291					currentAssistant.AppendReasoningSignature(reasoning.Signature)
292				}
293			}
294			if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
295				if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
296					currentAssistant.SetReasoningResponsesData(reasoning)
297				}
298			}
299			currentAssistant.FinishThinking()
300			return a.messages.Update(genCtx, *currentAssistant)
301		},
302		OnTextDelta: func(id string, text string) error {
303			currentAssistant.AppendContent(text)
304			return a.messages.Update(genCtx, *currentAssistant)
305		},
306		OnToolInputStart: func(id string, toolName string) error {
307			toolCall := message.ToolCall{
308				ID:               id,
309				Name:             toolName,
310				ProviderExecuted: false,
311				Finished:         false,
312			}
313			currentAssistant.AddToolCall(toolCall)
314			return a.messages.Update(genCtx, *currentAssistant)
315		},
316		OnRetry: func(err *fantasy.APICallError, delay time.Duration) {
317			// TODO: implement
318		},
319		OnToolCall: func(tc fantasy.ToolCallContent) error {
320			toolCall := message.ToolCall{
321				ID:               tc.ToolCallID,
322				Name:             tc.ToolName,
323				Input:            tc.Input,
324				ProviderExecuted: false,
325				Finished:         true,
326			}
327			currentAssistant.AddToolCall(toolCall)
328			return a.messages.Update(genCtx, *currentAssistant)
329		},
330		OnToolResult: func(result fantasy.ToolResultContent) error {
331			var resultContent string
332			isError := false
333			switch result.Result.GetType() {
334			case fantasy.ToolResultContentTypeText:
335				r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result)
336				if ok {
337					resultContent = r.Text
338				}
339			case fantasy.ToolResultContentTypeError:
340				r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result)
341				if ok {
342					isError = true
343					resultContent = r.Error.Error()
344				}
345			case fantasy.ToolResultContentTypeMedia:
346				// TODO: handle this message type
347			}
348			toolResult := message.ToolResult{
349				ToolCallID: result.ToolCallID,
350				Name:       result.ToolName,
351				Content:    resultContent,
352				IsError:    isError,
353				Metadata:   result.ClientMetadata,
354			}
355			_, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
356				Role: message.Tool,
357				Parts: []message.ContentPart{
358					toolResult,
359				},
360			})
361			if createMsgErr != nil {
362				return createMsgErr
363			}
364			return nil
365		},
366		OnStepFinish: func(stepResult fantasy.StepResult) error {
367			finishReason := message.FinishReasonUnknown
368			switch stepResult.FinishReason {
369			case fantasy.FinishReasonLength:
370				finishReason = message.FinishReasonMaxTokens
371			case fantasy.FinishReasonStop:
372				finishReason = message.FinishReasonEndTurn
373			case fantasy.FinishReasonToolCalls:
374				finishReason = message.FinishReasonToolUse
375			}
376			currentAssistant.AddFinish(finishReason, "", "")
377			a.updateSessionUsage(a.largeModel, &currentSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
378			sessionLock.Lock()
379			_, sessionErr := a.sessions.Save(genCtx, currentSession)
380			sessionLock.Unlock()
381			if sessionErr != nil {
382				return sessionErr
383			}
384			if finishReason == message.FinishReasonEndTurn {
385				a.scheduleCompletionNotification(call.SessionID, currentSession.Title)
386			}
387			return a.messages.Update(genCtx, *currentAssistant)
388		},
389		StopWhen: []fantasy.StopCondition{
390			func(_ []fantasy.StepResult) bool {
391				cw := int64(a.largeModel.CatwalkCfg.ContextWindow)
392				tokens := currentSession.CompletionTokens + currentSession.PromptTokens
393				remaining := cw - tokens
394				var threshold int64
395				if cw > 200_000 {
396					threshold = 20_000
397				} else {
398					threshold = int64(float64(cw) * 0.2)
399				}
400				if (remaining <= threshold) && !a.disableAutoSummarize {
401					shouldSummarize = true
402					return true
403				}
404				return false
405			},
406		},
407	})
408
409	a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))
410
411	if err != nil {
412		isCancelErr := errors.Is(err, context.Canceled)
413		isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
414		if currentAssistant == nil {
415			return result, err
416		}
417		// Ensure we finish thinking on error to close the reasoning state
418		currentAssistant.FinishThinking()
419		toolCalls := currentAssistant.ToolCalls()
420		// INFO: we use the parent context here because the genCtx has been cancelled
421		msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
422		if createErr != nil {
423			return nil, createErr
424		}
425		for _, tc := range toolCalls {
426			if !tc.Finished {
427				tc.Finished = true
428				tc.Input = "{}"
429				currentAssistant.AddToolCall(tc)
430				updateErr := a.messages.Update(ctx, *currentAssistant)
431				if updateErr != nil {
432					return nil, updateErr
433				}
434			}
435
436			found := false
437			for _, msg := range msgs {
438				if msg.Role == message.Tool {
439					for _, tr := range msg.ToolResults() {
440						if tr.ToolCallID == tc.ID {
441							found = true
442							break
443						}
444					}
445				}
446				if found {
447					break
448				}
449			}
450			if found {
451				continue
452			}
453			content := "There was an error while executing the tool"
454			if isCancelErr {
455				content = "Tool execution canceled by user"
456			} else if isPermissionErr {
457				content = "Permission denied"
458			}
459			toolResult := message.ToolResult{
460				ToolCallID: tc.ID,
461				Name:       tc.Name,
462				Content:    content,
463				IsError:    true,
464			}
465			_, createErr = a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
466				Role: message.Tool,
467				Parts: []message.ContentPart{
468					toolResult,
469				},
470			})
471			if createErr != nil {
472				return nil, createErr
473			}
474		}
475		if isCancelErr {
476			currentAssistant.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
477		} else if isPermissionErr {
478			currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "Permission denied", "")
479		} else {
480			currentAssistant.AddFinish(message.FinishReasonError, "API Error", err.Error())
481		}
482		// INFO: we use the parent context here because the genCtx has been cancelled
483		updateErr := a.messages.Update(ctx, *currentAssistant)
484		if updateErr != nil {
485			return nil, updateErr
486		}
487		return nil, err
488	}
489	wg.Wait()
490
491	if shouldSummarize {
492		a.cancelCompletionNotification(call.SessionID)
493		a.activeRequests.Del(call.SessionID)
494		if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
495			return nil, summarizeErr
496		}
497		// if the agent was not done...
498		if len(currentAssistant.ToolCalls()) > 0 {
499			existing, ok := a.messageQueue.Get(call.SessionID)
500			if !ok {
501				existing = []SessionAgentCall{}
502			}
503			call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
504			existing = append(existing, call)
505			a.messageQueue.Set(call.SessionID, existing)
506		}
507	}
508
509	// release active request before processing queued messages
510	a.activeRequests.Del(call.SessionID)
511	cancel()
512
513	queuedMessages, ok := a.messageQueue.Get(call.SessionID)
514	if !ok || len(queuedMessages) == 0 {
515		return result, err
516	}
517	// there are queued messages restart the loop
518	firstQueuedMessage := queuedMessages[0]
519	a.messageQueue.Set(call.SessionID, queuedMessages[1:])
520	return a.Run(ctx, firstQueuedMessage)
521}
522
523func (a *sessionAgent) scheduleCompletionNotification(sessionID, sessionTitle string) {
524	if a.notifier == nil {
525		return
526	}
527
528	if sessionTitle == "" {
529		sessionTitle = sessionID
530	}
531
532	if cancel, ok := a.completionCancels.Take(sessionID); ok && cancel != nil {
533		cancel()
534	}
535
536	title := "💘 Crush is waiting"
537	message := fmt.Sprintf("Agent's turn completed in session \"%s\"", sessionTitle)
538	cancel := a.notifier.NotifyTaskComplete(a.notifyCtx, title, message, completionNotificationDelay)
539	if cancel == nil {
540		cancel = func() {}
541	}
542	a.completionCancels.Set(sessionID, cancel)
543}
544
545func (a *sessionAgent) cancelCompletionNotification(sessionID string) {
546	if a.notifier == nil {
547		return
548	}
549
550	if cancel, ok := a.completionCancels.Take(sessionID); ok && cancel != nil {
551		cancel()
552	}
553}
554
// CancelCompletionNotification implements SessionAgent. It cancels the pending
// "turn completed" notification for sessionID, if any, and is a no-op when no
// notifier is configured.
func (a *sessionAgent) CancelCompletionNotification(sessionID string) {
	a.cancelCompletionNotification(sessionID)
}
559
560// HasPendingCompletionNotification implements SessionAgent.
561func (a *sessionAgent) HasPendingCompletionNotification(sessionID string) bool {
562	if a.IsSessionBusy(sessionID) {
563		return false
564	}
565	_, ok := a.completionCancels.Get(sessionID)
566	return ok
567}
568
569func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
570	if a.IsSessionBusy(sessionID) {
571		return ErrSessionBusy
572	}
573
574	currentSession, err := a.sessions.Get(ctx, sessionID)
575	if err != nil {
576		return fmt.Errorf("failed to get session: %w", err)
577	}
578	msgs, err := a.getSessionMessages(ctx, currentSession)
579	if err != nil {
580		return err
581	}
582	if len(msgs) == 0 {
583		// nothing to summarize
584		return nil
585	}
586
587	aiMsgs, _ := a.preparePrompt(msgs)
588
589	genCtx, cancel := context.WithCancel(ctx)
590	a.activeRequests.Set(sessionID, cancel)
591	defer a.activeRequests.Del(sessionID)
592	defer cancel()
593
594	agent := fantasy.NewAgent(a.largeModel.Model,
595		fantasy.WithSystemPrompt(string(summaryPrompt)),
596	)
597	summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
598		Role:             message.Assistant,
599		Model:            a.largeModel.Model.Model(),
600		Provider:         a.largeModel.Model.Provider(),
601		IsSummaryMessage: true,
602	})
603	if err != nil {
604		return err
605	}
606
607	resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
608		Prompt:          "Provide a detailed summary of our conversation above.",
609		Messages:        aiMsgs,
610		ProviderOptions: opts,
611		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
612			prepared.Messages = options.Messages
613			if a.systemPromptPrefix != "" {
614				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
615			}
616			return callContext, prepared, nil
617		},
618		OnReasoningDelta: func(id string, text string) error {
619			summaryMessage.AppendReasoningContent(text)
620			return a.messages.Update(genCtx, summaryMessage)
621		},
622		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
623			// handle anthropic signature
624			if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
625				if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
626					summaryMessage.AppendReasoningSignature(signature.Signature)
627				}
628			}
629			summaryMessage.FinishThinking()
630			return a.messages.Update(genCtx, summaryMessage)
631		},
632		OnTextDelta: func(id, text string) error {
633			summaryMessage.AppendContent(text)
634			return a.messages.Update(genCtx, summaryMessage)
635		},
636	})
637	if err != nil {
638		isCancelErr := errors.Is(err, context.Canceled)
639		if isCancelErr {
640			// User cancelled summarize we need to remove the summary message
641			deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
642			return deleteErr
643		}
644		return err
645	}
646
647	summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
648	err = a.messages.Update(genCtx, summaryMessage)
649	if err != nil {
650		return err
651	}
652
653	var openrouterCost *float64
654	for _, step := range resp.Steps {
655		stepCost := a.openrouterCost(step.ProviderMetadata)
656		if stepCost != nil {
657			newCost := *stepCost
658			if openrouterCost != nil {
659				newCost += *openrouterCost
660			}
661			openrouterCost = &newCost
662		}
663	}
664
665	a.updateSessionUsage(a.largeModel, &currentSession, resp.TotalUsage, openrouterCost)
666
667	// just in case get just the last usage
668	usage := resp.Response.Usage
669	currentSession.SummaryMessageID = summaryMessage.ID
670	currentSession.CompletionTokens = usage.OutputTokens
671	currentSession.PromptTokens = 0
672	_, err = a.sessions.Save(genCtx, currentSession)
673	return err
674}
675
676func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
677	if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
678		return fantasy.ProviderOptions{}
679	}
680	return fantasy.ProviderOptions{
681		anthropic.Name: &anthropic.ProviderCacheControlOptions{
682			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
683		},
684		bedrock.Name: &anthropic.ProviderCacheControlOptions{
685			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
686		},
687	}
688}
689
690func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
691	var attachmentParts []message.ContentPart
692	for _, attachment := range call.Attachments {
693		attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
694	}
695	parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
696	parts = append(parts, attachmentParts...)
697	msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
698		Role:  message.User,
699		Parts: parts,
700	})
701	if err != nil {
702		return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
703	}
704	return msg, nil
705}
706
707func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
708	var history []fantasy.Message
709	for _, m := range msgs {
710		if len(m.Parts) == 0 {
711			continue
712		}
713		// Assistant message without content or tool calls (cancelled before it returned anything)
714		if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
715			continue
716		}
717		history = append(history, m.ToAIMessage()...)
718	}
719
720	var files []fantasy.FilePart
721	for _, attachment := range attachments {
722		files = append(files, fantasy.FilePart{
723			Filename:  attachment.FileName,
724			Data:      attachment.Content,
725			MediaType: attachment.MimeType,
726		})
727	}
728
729	return history, files
730}
731
732func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
733	msgs, err := a.messages.List(ctx, session.ID)
734	if err != nil {
735		return nil, fmt.Errorf("failed to list messages: %w", err)
736	}
737
738	if session.SummaryMessageID != "" {
739		summaryMsgInex := -1
740		for i, msg := range msgs {
741			if msg.ID == session.SummaryMessageID {
742				summaryMsgInex = i
743				break
744			}
745		}
746		if summaryMsgInex != -1 {
747			msgs = msgs[summaryMsgInex:]
748			msgs[0].Role = message.User
749		}
750	}
751	return msgs, nil
752}
753
754func (a *sessionAgent) generateTitle(ctx context.Context, session *session.Session, prompt string) {
755	if prompt == "" {
756		return
757	}
758
759	var maxOutput int64 = 40
760	if a.smallModel.CatwalkCfg.CanReason {
761		maxOutput = a.smallModel.CatwalkCfg.DefaultMaxTokens
762	}
763
764	agent := fantasy.NewAgent(a.smallModel.Model,
765		fantasy.WithSystemPrompt(string(titlePrompt)+"\n /no_think"),
766		fantasy.WithMaxOutputTokens(maxOutput),
767	)
768
769	resp, err := agent.Stream(ctx, fantasy.AgentStreamCall{
770		Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", prompt),
771		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
772			prepared.Messages = options.Messages
773			if a.systemPromptPrefix != "" {
774				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
775			}
776			return callContext, prepared, nil
777		},
778	})
779	if err != nil {
780		slog.Error("error generating title", "err", err)
781		return
782	}
783
784	title := resp.Response.Content.Text()
785
786	title = strings.ReplaceAll(title, "\n", " ")
787
788	// remove thinking tags if present
789	if idx := strings.Index(title, "</think>"); idx > 0 {
790		title = title[idx+len("</think>"):]
791	}
792
793	title = strings.TrimSpace(title)
794	if title == "" {
795		slog.Warn("failed to generate title", "warn", "empty title")
796		return
797	}
798
799	session.Title = title
800
801	var openrouterCost *float64
802	for _, step := range resp.Steps {
803		stepCost := a.openrouterCost(step.ProviderMetadata)
804		if stepCost != nil {
805			newCost := *stepCost
806			if openrouterCost != nil {
807				newCost += *openrouterCost
808			}
809			openrouterCost = &newCost
810		}
811	}
812
813	a.updateSessionUsage(a.smallModel, session, resp.TotalUsage, openrouterCost)
814	_, saveErr := a.sessions.Save(ctx, *session)
815	if saveErr != nil {
816		slog.Error("failed to save session title & usage", "error", saveErr)
817		return
818	}
819}
820
821func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
822	openrouterMetadata, ok := metadata[openrouter.Name]
823	if !ok {
824		return nil
825	}
826
827	opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
828	if !ok {
829		return nil
830	}
831	return &opts.Usage.Cost
832}
833
834func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
835	modelConfig := model.CatwalkCfg
836	cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
837		modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
838		modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
839		modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
840
841	a.eventTokensUsed(session.ID, model, usage, cost)
842
843	if overrideCost != nil {
844		session.Cost += *overrideCost
845	} else {
846		session.Cost += cost
847	}
848
849	session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
850	session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
851}
852
853func (a *sessionAgent) Cancel(sessionID string) {
854	// Cancel regular requests
855	if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
856		slog.Info("Request cancellation initiated", "session_id", sessionID)
857		cancel()
858	}
859
860	// Also check for summarize requests
861	if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
862		slog.Info("Summarize cancellation initiated", "session_id", sessionID)
863		cancel()
864	}
865
866	if a.QueuedPrompts(sessionID) > 0 {
867		slog.Info("Clearing queued prompts", "session_id", sessionID)
868		a.messageQueue.Del(sessionID)
869	}
870}
871
872func (a *sessionAgent) ClearQueue(sessionID string) {
873	if a.QueuedPrompts(sessionID) > 0 {
874		slog.Info("Clearing queued prompts", "session_id", sessionID)
875		a.messageQueue.Del(sessionID)
876	}
877}
878
879func (a *sessionAgent) CancelAll() {
880	if !a.IsBusy() {
881		// still ensure notifications are cancelled even when not busy
882		for cancel := range a.completionCancels.Seq() {
883			if cancel != nil {
884				cancel()
885			}
886		}
887		a.completionCancels.Reset(make(map[string]context.CancelFunc))
888		return
889	}
890	for key := range a.activeRequests.Seq2() {
891		a.Cancel(key) // key is sessionID
892	}
893
894	timeout := time.After(5 * time.Second)
895	for a.IsBusy() {
896		select {
897		case <-timeout:
898			return
899		default:
900			time.Sleep(200 * time.Millisecond)
901		}
902	}
903
904	for cancel := range a.completionCancels.Seq() {
905		if cancel != nil {
906			cancel()
907		}
908	}
909	a.completionCancels.Reset(make(map[string]context.CancelFunc))
910}
911
912func (a *sessionAgent) IsBusy() bool {
913	var busy bool
914	for cancelFunc := range a.activeRequests.Seq() {
915		if cancelFunc != nil {
916			busy = true
917			break
918		}
919	}
920	return busy
921}
922
923func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
924	_, busy := a.activeRequests.Get(sessionID)
925	return busy
926}
927
928func (a *sessionAgent) QueuedPrompts(sessionID string) int {
929	l, ok := a.messageQueue.Get(sessionID)
930	if !ok {
931		return 0
932	}
933	return len(l)
934}
935
936func (a *sessionAgent) SetModels(large Model, small Model) {
937	a.largeModel = large
938	a.smallModel = small
939}
940
// SetTools replaces the tools offered to the model. The slice is stored as
// given, without copying; Run later sets provider options on its last element.
func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
	a.tools = tools
}
944
// Model returns the large model currently configured on the agent.
func (a *sessionAgent) Model() Model {
	return a.largeModel
}