agent.go

  1package agent
  2
  3import (
  4	"context"
  5	_ "embed"
  6	"errors"
  7	"fmt"
  8	"log/slog"
  9	"os"
 10	"strconv"
 11	"strings"
 12	"sync"
 13	"time"
 14
 15	"charm.land/fantasy"
 16	"charm.land/fantasy/providers/anthropic"
 17	"charm.land/fantasy/providers/bedrock"
 18	"charm.land/fantasy/providers/google"
 19	"charm.land/fantasy/providers/openai"
 20	"charm.land/fantasy/providers/openrouter"
 21	"github.com/charmbracelet/catwalk/pkg/catwalk"
 22	"git.secluded.site/crush/internal/agent/tools"
 23	"git.secluded.site/crush/internal/config"
 24	"git.secluded.site/crush/internal/csync"
 25	"git.secluded.site/crush/internal/message"
 26	"git.secluded.site/crush/internal/notification"
 27	"git.secluded.site/crush/internal/permission"
 28	"git.secluded.site/crush/internal/session"
 29)
 30
// titlePrompt holds the embedded contents of templates/title.md. It is used
// as the system prompt when generating a session title.
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt holds the embedded contents of templates/summary.md. It is
// used as the system prompt when summarizing a session.
//
//go:embed templates/summary.md
var summaryPrompt []byte
 36
// SessionAgentCall describes a single prompt submitted to a SessionAgent.
//
//   - SessionID selects the session the prompt belongs to; Run rejects an
//     empty value with ErrSessionMissing.
//   - Prompt is the user text; Run rejects an empty value with ErrEmptyPrompt.
//   - ProviderOptions is forwarded to the model stream call (and to Summarize
//     when Run triggers an automatic summary).
//   - Attachments are stored as binary parts of the user message and sent to
//     the model as file parts.
//   - MaxOutputTokens, Temperature, TopP, TopK, FrequencyPenalty and
//     PresencePenalty are forwarded unchanged to the model stream call.
type SessionAgentCall struct {
	SessionID        string
	Prompt           string
	ProviderOptions  fantasy.ProviderOptions
	Attachments      []message.Attachment
	MaxOutputTokens  int64
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
}
 49
// SessionAgent runs prompts against a language model on a per-session basis
// and exposes control over in-flight requests, queued prompts and turn-end
// notifications.
type SessionAgent interface {
	// Run submits the call's prompt for its session. In the sessionAgent
	// implementation a call for a busy session is queued and Run returns
	// (nil, nil).
	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
	// SetModels replaces the large and small models used by the agent.
	SetModels(large Model, small Model)
	// SetTools replaces the tools offered to the model.
	SetTools(tools []fantasy.AgentTool)
	// Cancel cancels the active request of sessionID and drops its queued
	// prompts.
	Cancel(sessionID string)
	// CancelAll cancels the active requests of every session.
	CancelAll()
	// IsSessionBusy reports whether sessionID has an active request.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any session has an active request.
	IsBusy() bool
	// QueuedPrompts returns the number of prompts queued for sessionID.
	QueuedPrompts(sessionID string) int
	// ClearQueue drops all prompts queued for sessionID.
	ClearQueue(sessionID string)
	// Summarize generates a summary message for the given session ID using
	// the supplied provider options.
	Summarize(context.Context, string, fantasy.ProviderOptions) error
	// Model returns the large model currently in use.
	Model() Model
	// CancelCompletionNotification cancels any scheduled "turn ended"
	// notification for the provided sessionID.
	CancelCompletionNotification(sessionID string)
	// HasPendingCompletionNotification returns true if a turn-end
	// notification has been scheduled and not yet cancelled/shown.
	HasPendingCompletionNotification(sessionID string) bool
}
 69
// completionNotificationDelay is the delay handed to the notifier when a
// "turn ended" notification is scheduled.
const completionNotificationDelay = 5 * time.Second
 71
// Model bundles a language model with its metadata.
//
//   - Model is the language model the agent streams from.
//   - CatwalkCfg supplies the context window, pricing and reasoning
//     capabilities read by the agent.
//   - ModelCfg supplies the model and provider names recorded on assistant
//     messages.
type Model struct {
	Model      fantasy.LanguageModel
	CatwalkCfg catwalk.Model
	ModelCfg   config.SelectedModel
}
 77
// sessionAgent is the SessionAgent implementation in this file.
//
//   - largeModel drives regular turns and summaries; smallModel generates
//     session titles.
//   - systemPromptPrefix, when non-empty, is prepended as an extra system
//     message to every request; systemPrompt is the main system prompt.
//   - tools are the tools offered to the model during Run.
//   - sessions and messages persist session state and messages.
//   - disableAutoSummarize turns off the automatic summary that Run triggers
//     when the context window is nearly exhausted.
//   - isYolo is stored from the options; it is not read in this file.
//   - messageQueue holds prompts submitted while a session is busy, keyed by
//     session ID.
//   - activeRequests maps a session ID to the cancel function of its running
//     request; presence of a key is what marks a session as busy.
//   - notifier and notifyCtx are used to schedule "turn ended" notifications;
//     completionCancels maps a session ID to the cancel function of its
//     scheduled notification.
type sessionAgent struct {
	largeModel           Model
	smallModel           Model
	systemPromptPrefix   string
	systemPrompt         string
	tools                []fantasy.AgentTool
	sessions             session.Service
	messages             message.Service
	disableAutoSummarize bool
	isYolo               bool

	messageQueue      *csync.Map[string, []SessionAgentCall]
	activeRequests    *csync.Map[string, context.CancelFunc]
	notifier          *notification.Notifier
	notifyCtx         context.Context
	completionCancels *csync.Map[string, context.CancelFunc]
}
 95
// SessionAgentOptions carries the dependencies and settings consumed by
// NewSessionAgent. Each field is copied to the sessionAgent field of the same
// purpose. NotificationCtx may be nil, in which case NewSessionAgent falls
// back to context.Background(). Notifier may be nil, in which case no
// completion notifications are scheduled.
type SessionAgentOptions struct {
	LargeModel           Model
	SmallModel           Model
	SystemPromptPrefix   string
	SystemPrompt         string
	DisableAutoSummarize bool
	IsYolo               bool
	Sessions             session.Service
	Messages             message.Service
	Tools                []fantasy.AgentTool
	Notifier             *notification.Notifier
	NotificationCtx      context.Context
}
109
110func NewSessionAgent(
111	opts SessionAgentOptions,
112) SessionAgent {
113	notifyCtx := opts.NotificationCtx
114	if notifyCtx == nil {
115		notifyCtx = context.Background()
116	}
117
118	return &sessionAgent{
119		largeModel:           opts.LargeModel,
120		smallModel:           opts.SmallModel,
121		systemPromptPrefix:   opts.SystemPromptPrefix,
122		systemPrompt:         opts.SystemPrompt,
123		sessions:             opts.Sessions,
124		messages:             opts.Messages,
125		disableAutoSummarize: opts.DisableAutoSummarize,
126		tools:                opts.Tools,
127		isYolo:               opts.IsYolo,
128		messageQueue:         csync.NewMap[string, []SessionAgentCall](),
129		activeRequests:       csync.NewMap[string, context.CancelFunc](),
130		notifier:             opts.Notifier,
131		notifyCtx:            notifyCtx,
132		completionCancels:    csync.NewMap[string, context.CancelFunc](),
133	}
134}
135
136func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
137	if call.Prompt == "" {
138		return nil, ErrEmptyPrompt
139	}
140	if call.SessionID == "" {
141		return nil, ErrSessionMissing
142	}
143
144	a.cancelCompletionNotification(call.SessionID)
145
146	// Queue the message if busy
147	if a.IsSessionBusy(call.SessionID) {
148		existing, ok := a.messageQueue.Get(call.SessionID)
149		if !ok {
150			existing = []SessionAgentCall{}
151		}
152		existing = append(existing, call)
153		a.messageQueue.Set(call.SessionID, existing)
154		return nil, nil
155	}
156
157	if len(a.tools) > 0 {
158		// add anthropic caching to the last tool
159		a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
160	}
161
162	agent := fantasy.NewAgent(
163		a.largeModel.Model,
164		fantasy.WithSystemPrompt(a.systemPrompt),
165		fantasy.WithTools(a.tools...),
166	)
167
168	sessionLock := sync.Mutex{}
169	currentSession, err := a.sessions.Get(ctx, call.SessionID)
170	if err != nil {
171		return nil, fmt.Errorf("failed to get session: %w", err)
172	}
173
174	msgs, err := a.getSessionMessages(ctx, currentSession)
175	if err != nil {
176		return nil, fmt.Errorf("failed to get session messages: %w", err)
177	}
178
179	var wg sync.WaitGroup
180	// Generate title if first message
181	if len(msgs) == 0 {
182		wg.Go(func() {
183			sessionLock.Lock()
184			a.generateTitle(ctx, &currentSession, call.Prompt)
185			sessionLock.Unlock()
186		})
187	}
188
189	// Add the user message to the session
190	_, err = a.createUserMessage(ctx, call)
191	if err != nil {
192		return nil, err
193	}
194
195	// add the session to the context
196	ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)
197
198	genCtx, cancel := context.WithCancel(ctx)
199	a.activeRequests.Set(call.SessionID, cancel)
200
201	defer cancel()
202	defer a.activeRequests.Del(call.SessionID)
203
204	history, files := a.preparePrompt(msgs, call.Attachments...)
205
206	startTime := time.Now()
207	a.eventPromptSent(call.SessionID)
208
209	var currentAssistant *message.Message
210	var shouldSummarize bool
211	result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
212		Prompt:           call.Prompt,
213		Files:            files,
214		Messages:         history,
215		ProviderOptions:  call.ProviderOptions,
216		MaxOutputTokens:  &call.MaxOutputTokens,
217		TopP:             call.TopP,
218		Temperature:      call.Temperature,
219		PresencePenalty:  call.PresencePenalty,
220		TopK:             call.TopK,
221		FrequencyPenalty: call.FrequencyPenalty,
222		// Before each step create the new assistant message
223		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
224			prepared.Messages = options.Messages
225			// reset all cached items
226			for i := range prepared.Messages {
227				prepared.Messages[i].ProviderOptions = nil
228			}
229
230			queuedCalls, _ := a.messageQueue.Get(call.SessionID)
231			a.messageQueue.Del(call.SessionID)
232			for _, queued := range queuedCalls {
233				userMessage, createErr := a.createUserMessage(callContext, queued)
234				if createErr != nil {
235					return callContext, prepared, createErr
236				}
237				prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
238			}
239
240			lastSystemRoleInx := 0
241			systemMessageUpdated := false
242			for i, msg := range prepared.Messages {
243				// only add cache control to the last message
244				if msg.Role == fantasy.MessageRoleSystem {
245					lastSystemRoleInx = i
246				} else if !systemMessageUpdated {
247					prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
248					systemMessageUpdated = true
249				}
250				// than add cache control to the last 2 messages
251				if i > len(prepared.Messages)-3 {
252					prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
253				}
254			}
255
256			if a.systemPromptPrefix != "" {
257				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
258			}
259
260			var assistantMsg message.Message
261			assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
262				Role:     message.Assistant,
263				Parts:    []message.ContentPart{},
264				Model:    a.largeModel.ModelCfg.Model,
265				Provider: a.largeModel.ModelCfg.Provider,
266			})
267			if err != nil {
268				return callContext, prepared, err
269			}
270			callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
271			currentAssistant = &assistantMsg
272			return callContext, prepared, err
273		},
274		OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
275			currentAssistant.AppendReasoningContent(reasoning.Text)
276			return a.messages.Update(genCtx, *currentAssistant)
277		},
278		OnReasoningDelta: func(id string, text string) error {
279			currentAssistant.AppendReasoningContent(text)
280			return a.messages.Update(genCtx, *currentAssistant)
281		},
282		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
283			// handle anthropic signature
284			if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
285				if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
286					currentAssistant.AppendReasoningSignature(reasoning.Signature)
287				}
288			}
289			if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
290				if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
291					currentAssistant.AppendReasoningSignature(reasoning.Signature)
292				}
293			}
294			if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
295				if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
296					currentAssistant.SetReasoningResponsesData(reasoning)
297				}
298			}
299			currentAssistant.FinishThinking()
300			return a.messages.Update(genCtx, *currentAssistant)
301		},
302		OnTextDelta: func(id string, text string) error {
303			currentAssistant.AppendContent(text)
304			return a.messages.Update(genCtx, *currentAssistant)
305		},
306		OnToolInputStart: func(id string, toolName string) error {
307			toolCall := message.ToolCall{
308				ID:               id,
309				Name:             toolName,
310				ProviderExecuted: false,
311				Finished:         false,
312			}
313			currentAssistant.AddToolCall(toolCall)
314			return a.messages.Update(genCtx, *currentAssistant)
315		},
316		OnRetry: func(err *fantasy.APICallError, delay time.Duration) {
317			// TODO: implement
318		},
319		OnToolCall: func(tc fantasy.ToolCallContent) error {
320			toolCall := message.ToolCall{
321				ID:               tc.ToolCallID,
322				Name:             tc.ToolName,
323				Input:            tc.Input,
324				ProviderExecuted: false,
325				Finished:         true,
326			}
327			currentAssistant.AddToolCall(toolCall)
328			return a.messages.Update(genCtx, *currentAssistant)
329		},
330		OnToolResult: func(result fantasy.ToolResultContent) error {
331			var resultContent string
332			isError := false
333			switch result.Result.GetType() {
334			case fantasy.ToolResultContentTypeText:
335				r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result)
336				if ok {
337					resultContent = r.Text
338				}
339			case fantasy.ToolResultContentTypeError:
340				r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result)
341				if ok {
342					isError = true
343					resultContent = r.Error.Error()
344				}
345			case fantasy.ToolResultContentTypeMedia:
346				// TODO: handle this message type
347			}
348			toolResult := message.ToolResult{
349				ToolCallID: result.ToolCallID,
350				Name:       result.ToolName,
351				Content:    resultContent,
352				IsError:    isError,
353				Metadata:   result.ClientMetadata,
354			}
355			_, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
356				Role: message.Tool,
357				Parts: []message.ContentPart{
358					toolResult,
359				},
360			})
361			if createMsgErr != nil {
362				return createMsgErr
363			}
364			return nil
365		},
366		OnStepFinish: func(stepResult fantasy.StepResult) error {
367			finishReason := message.FinishReasonUnknown
368			switch stepResult.FinishReason {
369			case fantasy.FinishReasonLength:
370				finishReason = message.FinishReasonMaxTokens
371			case fantasy.FinishReasonStop:
372				finishReason = message.FinishReasonEndTurn
373			case fantasy.FinishReasonToolCalls:
374				finishReason = message.FinishReasonToolUse
375			}
376			currentAssistant.AddFinish(finishReason, "", "")
377			a.updateSessionUsage(a.largeModel, &currentSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
378			sessionLock.Lock()
379			_, sessionErr := a.sessions.Save(genCtx, currentSession)
380			sessionLock.Unlock()
381			if sessionErr != nil {
382				return sessionErr
383			}
384			if finishReason == message.FinishReasonEndTurn {
385				a.scheduleCompletionNotification(call.SessionID, currentSession.Title)
386			}
387			return a.messages.Update(genCtx, *currentAssistant)
388		},
389		StopWhen: []fantasy.StopCondition{
390			func(_ []fantasy.StepResult) bool {
391				cw := int64(a.largeModel.CatwalkCfg.ContextWindow)
392				tokens := currentSession.CompletionTokens + currentSession.PromptTokens
393				remaining := cw - tokens
394				var threshold int64
395				if cw > 200_000 {
396					threshold = 20_000
397				} else {
398					threshold = int64(float64(cw) * 0.2)
399				}
400				if (remaining <= threshold) && !a.disableAutoSummarize {
401					shouldSummarize = true
402					return true
403				}
404				return false
405			},
406		},
407	})
408
409	a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))
410
411	if err != nil {
412		isCancelErr := errors.Is(err, context.Canceled)
413		isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
414		if currentAssistant == nil {
415			return result, err
416		}
417		// Ensure we finish thinking on error to close the reasoning state
418		currentAssistant.FinishThinking()
419		toolCalls := currentAssistant.ToolCalls()
420		// INFO: we use the parent context here because the genCtx has been cancelled
421		msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
422		if createErr != nil {
423			return nil, createErr
424		}
425		for _, tc := range toolCalls {
426			if !tc.Finished {
427				tc.Finished = true
428				tc.Input = "{}"
429				currentAssistant.AddToolCall(tc)
430				updateErr := a.messages.Update(ctx, *currentAssistant)
431				if updateErr != nil {
432					return nil, updateErr
433				}
434			}
435
436			found := false
437			for _, msg := range msgs {
438				if msg.Role == message.Tool {
439					for _, tr := range msg.ToolResults() {
440						if tr.ToolCallID == tc.ID {
441							found = true
442							break
443						}
444					}
445				}
446				if found {
447					break
448				}
449			}
450			if found {
451				continue
452			}
453			content := "There was an error while executing the tool"
454			if isCancelErr {
455				content = "Tool execution canceled by user"
456			} else if isPermissionErr {
457				content = "User denied permission"
458			}
459			toolResult := message.ToolResult{
460				ToolCallID: tc.ID,
461				Name:       tc.Name,
462				Content:    content,
463				IsError:    true,
464			}
465			_, createErr = a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
466				Role: message.Tool,
467				Parts: []message.ContentPart{
468					toolResult,
469				},
470			})
471			if createErr != nil {
472				return nil, createErr
473			}
474		}
475		if isCancelErr {
476			currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
477		} else if isPermissionErr {
478			currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "")
479		} else {
480			currentAssistant.AddFinish(message.FinishReasonError, "API Error", err.Error())
481		}
482		// INFO: we use the parent context here because the genCtx has been cancelled
483		updateErr := a.messages.Update(ctx, *currentAssistant)
484		if updateErr != nil {
485			return nil, updateErr
486		}
487		return nil, err
488	}
489	wg.Wait()
490
491	if shouldSummarize {
492		a.cancelCompletionNotification(call.SessionID)
493		a.activeRequests.Del(call.SessionID)
494		if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
495			return nil, summarizeErr
496		}
497		// if the agent was not done...
498		if len(currentAssistant.ToolCalls()) > 0 {
499			existing, ok := a.messageQueue.Get(call.SessionID)
500			if !ok {
501				existing = []SessionAgentCall{}
502			}
503			call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
504			existing = append(existing, call)
505			a.messageQueue.Set(call.SessionID, existing)
506		}
507	}
508
509	// release active request before processing queued messages
510	a.activeRequests.Del(call.SessionID)
511	cancel()
512
513	queuedMessages, ok := a.messageQueue.Get(call.SessionID)
514	if !ok || len(queuedMessages) == 0 {
515		return result, err
516	}
517	// there are queued messages restart the loop
518	firstQueuedMessage := queuedMessages[0]
519	a.messageQueue.Set(call.SessionID, queuedMessages[1:])
520	return a.Run(ctx, firstQueuedMessage)
521}
522
523func (a *sessionAgent) scheduleCompletionNotification(sessionID, sessionTitle string) {
524	// Do not emit notifications for Agent-tool sub-sessions.
525	if a.sessions != nil && a.sessions.IsAgentToolSession(sessionID) {
526		return
527	}
528	if a.notifier == nil {
529		return
530	}
531
532	if sessionTitle == "" {
533		sessionTitle = sessionID
534	}
535
536	if cancel, ok := a.completionCancels.Take(sessionID); ok && cancel != nil {
537		cancel()
538	}
539
540	title := "💘 Crush is waiting"
541	message := fmt.Sprintf("Agent's turn completed in session \"%s\"", sessionTitle)
542	cancel := a.notifier.NotifyTaskComplete(a.notifyCtx, title, message, completionNotificationDelay)
543	if cancel == nil {
544		cancel = func() {}
545	}
546	a.completionCancels.Set(sessionID, cancel)
547}
548
549func (a *sessionAgent) cancelCompletionNotification(sessionID string) {
550	if a.notifier == nil {
551		return
552	}
553
554	if cancel, ok := a.completionCancels.Take(sessionID); ok && cancel != nil {
555		cancel()
556	}
557}
558
// CancelCompletionNotification implements SessionAgent. It cancels the
// scheduled "turn ended" notification of sessionID, if any, by delegating to
// cancelCompletionNotification.
func (a *sessionAgent) CancelCompletionNotification(sessionID string) {
	a.cancelCompletionNotification(sessionID)
}
563
564// HasPendingCompletionNotification implements SessionAgent.
565func (a *sessionAgent) HasPendingCompletionNotification(sessionID string) bool {
566	if a.IsSessionBusy(sessionID) {
567		return false
568	}
569	_, ok := a.completionCancels.Get(sessionID)
570	return ok
571}
572
573func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
574	if a.IsSessionBusy(sessionID) {
575		return ErrSessionBusy
576	}
577
578	currentSession, err := a.sessions.Get(ctx, sessionID)
579	if err != nil {
580		return fmt.Errorf("failed to get session: %w", err)
581	}
582	msgs, err := a.getSessionMessages(ctx, currentSession)
583	if err != nil {
584		return err
585	}
586	if len(msgs) == 0 {
587		// nothing to summarize
588		return nil
589	}
590
591	aiMsgs, _ := a.preparePrompt(msgs)
592
593	genCtx, cancel := context.WithCancel(ctx)
594	a.activeRequests.Set(sessionID, cancel)
595	defer a.activeRequests.Del(sessionID)
596	defer cancel()
597
598	agent := fantasy.NewAgent(a.largeModel.Model,
599		fantasy.WithSystemPrompt(string(summaryPrompt)),
600	)
601	summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
602		Role:             message.Assistant,
603		Model:            a.largeModel.Model.Model(),
604		Provider:         a.largeModel.Model.Provider(),
605		IsSummaryMessage: true,
606	})
607	if err != nil {
608		return err
609	}
610
611	resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
612		Prompt:          "Provide a detailed summary of our conversation above.",
613		Messages:        aiMsgs,
614		ProviderOptions: opts,
615		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
616			prepared.Messages = options.Messages
617			if a.systemPromptPrefix != "" {
618				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
619			}
620			return callContext, prepared, nil
621		},
622		OnReasoningDelta: func(id string, text string) error {
623			summaryMessage.AppendReasoningContent(text)
624			return a.messages.Update(genCtx, summaryMessage)
625		},
626		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
627			// handle anthropic signature
628			if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
629				if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
630					summaryMessage.AppendReasoningSignature(signature.Signature)
631				}
632			}
633			summaryMessage.FinishThinking()
634			return a.messages.Update(genCtx, summaryMessage)
635		},
636		OnTextDelta: func(id, text string) error {
637			summaryMessage.AppendContent(text)
638			return a.messages.Update(genCtx, summaryMessage)
639		},
640	})
641	if err != nil {
642		isCancelErr := errors.Is(err, context.Canceled)
643		if isCancelErr {
644			// User cancelled summarize we need to remove the summary message
645			deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
646			return deleteErr
647		}
648		return err
649	}
650
651	summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
652	err = a.messages.Update(genCtx, summaryMessage)
653	if err != nil {
654		return err
655	}
656
657	var openrouterCost *float64
658	for _, step := range resp.Steps {
659		stepCost := a.openrouterCost(step.ProviderMetadata)
660		if stepCost != nil {
661			newCost := *stepCost
662			if openrouterCost != nil {
663				newCost += *openrouterCost
664			}
665			openrouterCost = &newCost
666		}
667	}
668
669	a.updateSessionUsage(a.largeModel, &currentSession, resp.TotalUsage, openrouterCost)
670
671	// just in case get just the last usage
672	usage := resp.Response.Usage
673	currentSession.SummaryMessageID = summaryMessage.ID
674	currentSession.CompletionTokens = usage.OutputTokens
675	currentSession.PromptTokens = 0
676	_, err = a.sessions.Save(genCtx, currentSession)
677	return err
678}
679
680func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
681	if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
682		return fantasy.ProviderOptions{}
683	}
684	return fantasy.ProviderOptions{
685		anthropic.Name: &anthropic.ProviderCacheControlOptions{
686			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
687		},
688		bedrock.Name: &anthropic.ProviderCacheControlOptions{
689			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
690		},
691	}
692}
693
694func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
695	var attachmentParts []message.ContentPart
696	for _, attachment := range call.Attachments {
697		attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
698	}
699	parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
700	parts = append(parts, attachmentParts...)
701	msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
702		Role:  message.User,
703		Parts: parts,
704	})
705	if err != nil {
706		return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
707	}
708	return msg, nil
709}
710
711func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
712	var history []fantasy.Message
713	for _, m := range msgs {
714		if len(m.Parts) == 0 {
715			continue
716		}
717		// Assistant message without content or tool calls (cancelled before it returned anything)
718		if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
719			continue
720		}
721		history = append(history, m.ToAIMessage()...)
722	}
723
724	var files []fantasy.FilePart
725	for _, attachment := range attachments {
726		files = append(files, fantasy.FilePart{
727			Filename:  attachment.FileName,
728			Data:      attachment.Content,
729			MediaType: attachment.MimeType,
730		})
731	}
732
733	return history, files
734}
735
736func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
737	msgs, err := a.messages.List(ctx, session.ID)
738	if err != nil {
739		return nil, fmt.Errorf("failed to list messages: %w", err)
740	}
741
742	if session.SummaryMessageID != "" {
743		summaryMsgInex := -1
744		for i, msg := range msgs {
745			if msg.ID == session.SummaryMessageID {
746				summaryMsgInex = i
747				break
748			}
749		}
750		if summaryMsgInex != -1 {
751			msgs = msgs[summaryMsgInex:]
752			msgs[0].Role = message.User
753		}
754	}
755	return msgs, nil
756}
757
758func (a *sessionAgent) generateTitle(ctx context.Context, session *session.Session, prompt string) {
759	if prompt == "" {
760		return
761	}
762
763	var maxOutput int64 = 40
764	if a.smallModel.CatwalkCfg.CanReason {
765		maxOutput = a.smallModel.CatwalkCfg.DefaultMaxTokens
766	}
767
768	agent := fantasy.NewAgent(a.smallModel.Model,
769		fantasy.WithSystemPrompt(string(titlePrompt)+"\n /no_think"),
770		fantasy.WithMaxOutputTokens(maxOutput),
771	)
772
773	resp, err := agent.Stream(ctx, fantasy.AgentStreamCall{
774		Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", prompt),
775		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
776			prepared.Messages = options.Messages
777			if a.systemPromptPrefix != "" {
778				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
779			}
780			return callContext, prepared, nil
781		},
782	})
783	if err != nil {
784		slog.Error("error generating title", "err", err)
785		return
786	}
787
788	title := resp.Response.Content.Text()
789
790	title = strings.ReplaceAll(title, "\n", " ")
791
792	// remove thinking tags if present
793	if idx := strings.Index(title, "</think>"); idx > 0 {
794		title = title[idx+len("</think>"):]
795	}
796
797	title = strings.TrimSpace(title)
798	if title == "" {
799		slog.Warn("failed to generate title", "warn", "empty title")
800		return
801	}
802
803	session.Title = title
804
805	var openrouterCost *float64
806	for _, step := range resp.Steps {
807		stepCost := a.openrouterCost(step.ProviderMetadata)
808		if stepCost != nil {
809			newCost := *stepCost
810			if openrouterCost != nil {
811				newCost += *openrouterCost
812			}
813			openrouterCost = &newCost
814		}
815	}
816
817	a.updateSessionUsage(a.smallModel, session, resp.TotalUsage, openrouterCost)
818	_, saveErr := a.sessions.Save(ctx, *session)
819	if saveErr != nil {
820		slog.Error("failed to save session title & usage", "error", saveErr)
821		return
822	}
823}
824
825func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
826	openrouterMetadata, ok := metadata[openrouter.Name]
827	if !ok {
828		return nil
829	}
830
831	opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
832	if !ok {
833		return nil
834	}
835	return &opts.Usage.Cost
836}
837
838func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
839	modelConfig := model.CatwalkCfg
840	cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
841		modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
842		modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
843		modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
844
845	a.eventTokensUsed(session.ID, model, usage, cost)
846
847	if overrideCost != nil {
848		session.Cost += *overrideCost
849	} else {
850		session.Cost += cost
851	}
852
853	session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
854	session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
855}
856
857func (a *sessionAgent) Cancel(sessionID string) {
858	// Cancel regular requests
859	if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
860		slog.Info("Request cancellation initiated", "session_id", sessionID)
861		cancel()
862	}
863
864	// Also check for summarize requests
865	if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
866		slog.Info("Summarize cancellation initiated", "session_id", sessionID)
867		cancel()
868	}
869
870	if a.QueuedPrompts(sessionID) > 0 {
871		slog.Info("Clearing queued prompts", "session_id", sessionID)
872		a.messageQueue.Del(sessionID)
873	}
874}
875
876func (a *sessionAgent) ClearQueue(sessionID string) {
877	if a.QueuedPrompts(sessionID) > 0 {
878		slog.Info("Clearing queued prompts", "session_id", sessionID)
879		a.messageQueue.Del(sessionID)
880	}
881}
882
883func (a *sessionAgent) CancelAll() {
884	if !a.IsBusy() {
885		// still ensure notifications are cancelled even when not busy
886		for cancel := range a.completionCancels.Seq() {
887			if cancel != nil {
888				cancel()
889			}
890		}
891		a.completionCancels.Reset(make(map[string]context.CancelFunc))
892		return
893	}
894	for key := range a.activeRequests.Seq2() {
895		a.Cancel(key) // key is sessionID
896	}
897
898	timeout := time.After(5 * time.Second)
899	for a.IsBusy() {
900		select {
901		case <-timeout:
902			return
903		default:
904			time.Sleep(200 * time.Millisecond)
905		}
906	}
907
908	for cancel := range a.completionCancels.Seq() {
909		if cancel != nil {
910			cancel()
911		}
912	}
913	a.completionCancels.Reset(make(map[string]context.CancelFunc))
914}
915
916func (a *sessionAgent) IsBusy() bool {
917	var busy bool
918	for cancelFunc := range a.activeRequests.Seq() {
919		if cancelFunc != nil {
920			busy = true
921			break
922		}
923	}
924	return busy
925}
926
927func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
928	_, busy := a.activeRequests.Get(sessionID)
929	return busy
930}
931
932func (a *sessionAgent) QueuedPrompts(sessionID string) int {
933	l, ok := a.messageQueue.Get(sessionID)
934	if !ok {
935		return 0
936	}
937	return len(l)
938}
939
940func (a *sessionAgent) SetModels(large Model, small Model) {
941	a.largeModel = large
942	a.smallModel = small
943}
944
// SetTools replaces the tools offered to the model. The slice is stored as
// given, without copying.
func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
	a.tools = tools
}
948
// Model returns the large model currently used by the agent.
func (a *sessionAgent) Model() Model {
	return a.largeModel
}