agent.go

  1// Package agent is the core orchestration layer for Crush AI agents.
  2//
  3// It provides session-based AI agent functionality for managing
  4// conversations, tool execution, and message handling. It coordinates
  5// interactions between language models, messages, sessions, and tools while
  6// handling features like automatic summarization, queuing, and token
  7// management.
  8package agent
  9
 10import (
 11	"context"
 12	_ "embed"
 13	"errors"
 14	"fmt"
 15	"log/slog"
 16	"os"
 17	"strconv"
 18	"strings"
 19	"sync"
 20	"time"
 21
 22	"charm.land/fantasy"
 23	"charm.land/fantasy/providers/anthropic"
 24	"charm.land/fantasy/providers/bedrock"
 25	"charm.land/fantasy/providers/google"
 26	"charm.land/fantasy/providers/openai"
 27	"charm.land/fantasy/providers/openrouter"
 28	"git.secluded.site/crush/internal/agent/tools"
 29	"git.secluded.site/crush/internal/config"
 30	"git.secluded.site/crush/internal/csync"
 31	"git.secluded.site/crush/internal/message"
 32	"git.secluded.site/crush/internal/notification"
 33	"git.secluded.site/crush/internal/permission"
 34	"git.secluded.site/crush/internal/session"
 35	"github.com/charmbracelet/catwalk/pkg/catwalk"
 36)
 37
// titlePrompt is the system prompt used by generateTitle when asking the
// small model for a session title.
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt is the system prompt used by Summarize when asking the large
// model to condense a session's history.
//
//go:embed templates/summary.md
var summaryPrompt []byte
 43
// SessionAgentCall is a single prompt submitted to a SessionAgent together
// with the generation parameters forwarded to the model for that turn.
type SessionAgentCall struct {
	// SessionID selects the session; Run rejects an empty value with
	// ErrSessionMissing.
	SessionID string
	// Prompt is the user's message text; Run rejects an empty value with
	// ErrEmptyPrompt.
	Prompt string
	// ProviderOptions are passed through unchanged to the model request.
	ProviderOptions fantasy.ProviderOptions
	// Attachments are stored as binary parts on the user message and sent to
	// the model as file parts.
	Attachments []message.Attachment
	// MaxOutputTokens is forwarded by pointer, so a zero value is sent as 0
	// rather than omitted. NOTE(review): confirm the provider layer treats 0
	// as "unset".
	MaxOutputTokens int64
	// Optional sampling parameters, forwarded unchanged to the model
	// (presumably nil means "not set").
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
}
 56
 57type SessionAgent interface {
 58	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
 59	SetModels(large Model, small Model)
 60	SetTools(tools []fantasy.AgentTool)
 61	Cancel(sessionID string)
 62	CancelAll()
 63	IsSessionBusy(sessionID string) bool
 64	IsBusy() bool
 65	QueuedPrompts(sessionID string) int
 66	ClearQueue(sessionID string)
 67	Summarize(context.Context, string, fantasy.ProviderOptions) error
 68	Model() Model
 69	// CancelCompletionNotification cancels any scheduled "turn ended"
 70	// notification for the provided sessionID.
 71	CancelCompletionNotification(sessionID string)
 72	// HasPendingCompletionNotification returns true if a turn-end
 73	// notification has been scheduled and not yet cancelled/shown.
 74	HasPendingCompletionNotification(sessionID string) bool
 75}
 76
 77const completionNotificationDelay = 5 * time.Second
 78
// Model bundles a language model client with its catalog metadata and the
// user's selection for it.
type Model struct {
	// Model is the client the agent streams requests through.
	Model fantasy.LanguageModel
	// CatwalkCfg is the catalog entry: context window, per-token prices,
	// reasoning support and default max tokens are read from it.
	CatwalkCfg catwalk.Model
	// ModelCfg is the selected model; its Model and Provider IDs are recorded
	// on assistant messages.
	ModelCfg config.SelectedModel
}
 84
// sessionAgent is the SessionAgent implementation returned by
// NewSessionAgent.
type sessionAgent struct {
	largeModel           Model  // used for agent turns and summaries
	smallModel           Model  // used for title generation
	systemPromptPrefix   string // prepended as an extra system message to every request
	systemPrompt         string // system prompt for agent turns
	tools                []fantasy.AgentTool
	sessions             session.Service
	messages             message.Service
	disableAutoSummarize bool // when true, Run never stops a turn to summarize
	isYolo               bool

	// messageQueue holds calls submitted while their session was busy,
	// keyed by session ID.
	messageQueue *csync.Map[string, []SessionAgentCall]
	// activeRequests maps a session ID to the cancel func of its in-flight
	// request; a present key is what makes IsSessionBusy report true.
	activeRequests *csync.Map[string, context.CancelFunc]
	// notifier delivers "turn ended" notifications; nil disables them.
	notifier *notification.Notifier
	// notifyCtx is passed to the notifier when scheduling notifications.
	notifyCtx context.Context
	// completionCancels maps a session ID to the cancel func of its pending
	// "turn ended" notification.
	completionCancels *csync.Map[string, context.CancelFunc]
}
102
// SessionAgentOptions configures NewSessionAgent. Each field is copied into
// the corresponding field of the created agent.
type SessionAgentOptions struct {
	LargeModel           Model  // model for agent turns and summaries
	SmallModel           Model  // model for title generation
	SystemPromptPrefix   string // optional extra system message sent first
	SystemPrompt         string
	DisableAutoSummarize bool
	IsYolo               bool
	Sessions             session.Service
	Messages             message.Service
	Tools                []fantasy.AgentTool
	Notifier             *notification.Notifier // nil disables completion notifications
	// NotificationCtx is passed to Notifier; nil is replaced with
	// context.Background().
	NotificationCtx context.Context
}
116
117func NewSessionAgent(
118	opts SessionAgentOptions,
119) SessionAgent {
120	notifyCtx := opts.NotificationCtx
121	if notifyCtx == nil {
122		notifyCtx = context.Background()
123	}
124
125	return &sessionAgent{
126		largeModel:           opts.LargeModel,
127		smallModel:           opts.SmallModel,
128		systemPromptPrefix:   opts.SystemPromptPrefix,
129		systemPrompt:         opts.SystemPrompt,
130		sessions:             opts.Sessions,
131		messages:             opts.Messages,
132		disableAutoSummarize: opts.DisableAutoSummarize,
133		tools:                opts.Tools,
134		isYolo:               opts.IsYolo,
135		messageQueue:         csync.NewMap[string, []SessionAgentCall](),
136		activeRequests:       csync.NewMap[string, context.CancelFunc](),
137		notifier:             opts.Notifier,
138		notifyCtx:            notifyCtx,
139		completionCancels:    csync.NewMap[string, context.CancelFunc](),
140	}
141}
142
143func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
144	if call.Prompt == "" {
145		return nil, ErrEmptyPrompt
146	}
147	if call.SessionID == "" {
148		return nil, ErrSessionMissing
149	}
150
151	a.cancelCompletionNotification(call.SessionID)
152
153	// Queue the message if busy
154	if a.IsSessionBusy(call.SessionID) {
155		existing, ok := a.messageQueue.Get(call.SessionID)
156		if !ok {
157			existing = []SessionAgentCall{}
158		}
159		existing = append(existing, call)
160		a.messageQueue.Set(call.SessionID, existing)
161		return nil, nil
162	}
163
164	if len(a.tools) > 0 {
165		// Add Anthropic caching to the last tool.
166		a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
167	}
168
169	agent := fantasy.NewAgent(
170		a.largeModel.Model,
171		fantasy.WithSystemPrompt(a.systemPrompt),
172		fantasy.WithTools(a.tools...),
173	)
174
175	sessionLock := sync.Mutex{}
176	currentSession, err := a.sessions.Get(ctx, call.SessionID)
177	if err != nil {
178		return nil, fmt.Errorf("failed to get session: %w", err)
179	}
180
181	msgs, err := a.getSessionMessages(ctx, currentSession)
182	if err != nil {
183		return nil, fmt.Errorf("failed to get session messages: %w", err)
184	}
185
186	var wg sync.WaitGroup
187	// Generate title if first message.
188	if len(msgs) == 0 {
189		wg.Go(func() {
190			sessionLock.Lock()
191			a.generateTitle(ctx, &currentSession, call.Prompt)
192			sessionLock.Unlock()
193		})
194	}
195
196	// Add the user message to the session.
197	_, err = a.createUserMessage(ctx, call)
198	if err != nil {
199		return nil, err
200	}
201
202	// Add the session to the context.
203	ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)
204
205	genCtx, cancel := context.WithCancel(ctx)
206	a.activeRequests.Set(call.SessionID, cancel)
207
208	defer cancel()
209	defer a.activeRequests.Del(call.SessionID)
210
211	history, files := a.preparePrompt(msgs, call.Attachments...)
212
213	startTime := time.Now()
214	a.eventPromptSent(call.SessionID)
215
216	var currentAssistant *message.Message
217	var shouldSummarize bool
218	result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
219		Prompt:           call.Prompt,
220		Files:            files,
221		Messages:         history,
222		ProviderOptions:  call.ProviderOptions,
223		MaxOutputTokens:  &call.MaxOutputTokens,
224		TopP:             call.TopP,
225		Temperature:      call.Temperature,
226		PresencePenalty:  call.PresencePenalty,
227		TopK:             call.TopK,
228		FrequencyPenalty: call.FrequencyPenalty,
229		// Before each step create a new assistant message.
230		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
231			prepared.Messages = options.Messages
232			// Reset all cached items.
233			for i := range prepared.Messages {
234				prepared.Messages[i].ProviderOptions = nil
235			}
236
237			queuedCalls, _ := a.messageQueue.Get(call.SessionID)
238			a.messageQueue.Del(call.SessionID)
239			for _, queued := range queuedCalls {
240				userMessage, createErr := a.createUserMessage(callContext, queued)
241				if createErr != nil {
242					return callContext, prepared, createErr
243				}
244				prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
245			}
246
247			lastSystemRoleInx := 0
248			systemMessageUpdated := false
249			for i, msg := range prepared.Messages {
250				// Only add cache control to the last message.
251				if msg.Role == fantasy.MessageRoleSystem {
252					lastSystemRoleInx = i
253				} else if !systemMessageUpdated {
254					prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
255					systemMessageUpdated = true
256				}
257				// Than add cache control to the last 2 messages.
258				if i > len(prepared.Messages)-3 {
259					prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
260				}
261			}
262
263			if a.systemPromptPrefix != "" {
264				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
265			}
266
267			var assistantMsg message.Message
268			assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
269				Role:     message.Assistant,
270				Parts:    []message.ContentPart{},
271				Model:    a.largeModel.ModelCfg.Model,
272				Provider: a.largeModel.ModelCfg.Provider,
273			})
274			if err != nil {
275				return callContext, prepared, err
276			}
277			callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
278			currentAssistant = &assistantMsg
279			return callContext, prepared, err
280		},
281		OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
282			currentAssistant.AppendReasoningContent(reasoning.Text)
283			return a.messages.Update(genCtx, *currentAssistant)
284		},
285		OnReasoningDelta: func(id string, text string) error {
286			currentAssistant.AppendReasoningContent(text)
287			return a.messages.Update(genCtx, *currentAssistant)
288		},
289		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
290			// handle anthropic signature
291			if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
292				if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
293					currentAssistant.AppendReasoningSignature(reasoning.Signature)
294				}
295			}
296			if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
297				if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
298					currentAssistant.AppendReasoningSignature(reasoning.Signature)
299				}
300			}
301			if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
302				if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
303					currentAssistant.SetReasoningResponsesData(reasoning)
304				}
305			}
306			currentAssistant.FinishThinking()
307			return a.messages.Update(genCtx, *currentAssistant)
308		},
309		OnTextDelta: func(id string, text string) error {
310			// Strip leading newline from initial text content. This is is
311			// particularly important in non-interactive mode where leading
312			// newlines are very visible.
313			if len(currentAssistant.Parts) == 0 {
314				text = strings.TrimPrefix(text, "\n")
315			}
316
317			currentAssistant.AppendContent(text)
318			return a.messages.Update(genCtx, *currentAssistant)
319		},
320		OnToolInputStart: func(id string, toolName string) error {
321			toolCall := message.ToolCall{
322				ID:               id,
323				Name:             toolName,
324				ProviderExecuted: false,
325				Finished:         false,
326			}
327			currentAssistant.AddToolCall(toolCall)
328			return a.messages.Update(genCtx, *currentAssistant)
329		},
330		OnRetry: func(err *fantasy.APICallError, delay time.Duration) {
331			// TODO: implement
332		},
333		OnToolCall: func(tc fantasy.ToolCallContent) error {
334			toolCall := message.ToolCall{
335				ID:               tc.ToolCallID,
336				Name:             tc.ToolName,
337				Input:            tc.Input,
338				ProviderExecuted: false,
339				Finished:         true,
340			}
341			currentAssistant.AddToolCall(toolCall)
342			return a.messages.Update(genCtx, *currentAssistant)
343		},
344		OnToolResult: func(result fantasy.ToolResultContent) error {
345			var resultContent string
346			isError := false
347			switch result.Result.GetType() {
348			case fantasy.ToolResultContentTypeText:
349				r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result)
350				if ok {
351					resultContent = r.Text
352				}
353			case fantasy.ToolResultContentTypeError:
354				r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result)
355				if ok {
356					isError = true
357					resultContent = r.Error.Error()
358				}
359			case fantasy.ToolResultContentTypeMedia:
360				// TODO: handle this message type
361			}
362			toolResult := message.ToolResult{
363				ToolCallID: result.ToolCallID,
364				Name:       result.ToolName,
365				Content:    resultContent,
366				IsError:    isError,
367				Metadata:   result.ClientMetadata,
368			}
369			_, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
370				Role: message.Tool,
371				Parts: []message.ContentPart{
372					toolResult,
373				},
374			})
375			if createMsgErr != nil {
376				return createMsgErr
377			}
378			return nil
379		},
380		OnStepFinish: func(stepResult fantasy.StepResult) error {
381			finishReason := message.FinishReasonUnknown
382			switch stepResult.FinishReason {
383			case fantasy.FinishReasonLength:
384				finishReason = message.FinishReasonMaxTokens
385			case fantasy.FinishReasonStop:
386				finishReason = message.FinishReasonEndTurn
387			case fantasy.FinishReasonToolCalls:
388				finishReason = message.FinishReasonToolUse
389			}
390			currentAssistant.AddFinish(finishReason, "", "")
391			a.updateSessionUsage(a.largeModel, &currentSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
392			sessionLock.Lock()
393			_, sessionErr := a.sessions.Save(genCtx, currentSession)
394			sessionLock.Unlock()
395			if sessionErr != nil {
396				return sessionErr
397			}
398			if finishReason == message.FinishReasonEndTurn {
399				a.scheduleCompletionNotification(call.SessionID, currentSession.Title)
400			}
401			return a.messages.Update(genCtx, *currentAssistant)
402		},
403		StopWhen: []fantasy.StopCondition{
404			func(_ []fantasy.StepResult) bool {
405				cw := int64(a.largeModel.CatwalkCfg.ContextWindow)
406				tokens := currentSession.CompletionTokens + currentSession.PromptTokens
407				remaining := cw - tokens
408				var threshold int64
409				if cw > 200_000 {
410					threshold = 20_000
411				} else {
412					threshold = int64(float64(cw) * 0.2)
413				}
414				if (remaining <= threshold) && !a.disableAutoSummarize {
415					shouldSummarize = true
416					return true
417				}
418				return false
419			},
420		},
421	})
422
423	a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))
424
425	if err != nil {
426		isCancelErr := errors.Is(err, context.Canceled)
427		isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
428		if currentAssistant == nil {
429			return result, err
430		}
431		// Ensure we finish thinking on error to close the reasoning state.
432		currentAssistant.FinishThinking()
433		toolCalls := currentAssistant.ToolCalls()
434		// INFO: we use the parent context here because the genCtx has been cancelled.
435		msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
436		if createErr != nil {
437			return nil, createErr
438		}
439		for _, tc := range toolCalls {
440			if !tc.Finished {
441				tc.Finished = true
442				tc.Input = "{}"
443				currentAssistant.AddToolCall(tc)
444				updateErr := a.messages.Update(ctx, *currentAssistant)
445				if updateErr != nil {
446					return nil, updateErr
447				}
448			}
449
450			found := false
451			for _, msg := range msgs {
452				if msg.Role == message.Tool {
453					for _, tr := range msg.ToolResults() {
454						if tr.ToolCallID == tc.ID {
455							found = true
456							break
457						}
458					}
459				}
460				if found {
461					break
462				}
463			}
464			if found {
465				continue
466			}
467			content := "There was an error while executing the tool"
468			if isCancelErr {
469				content = "Tool execution canceled by user"
470			} else if isPermissionErr {
471				content = "User denied permission"
472			}
473			toolResult := message.ToolResult{
474				ToolCallID: tc.ID,
475				Name:       tc.Name,
476				Content:    content,
477				IsError:    true,
478			}
479			_, createErr = a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
480				Role: message.Tool,
481				Parts: []message.ContentPart{
482					toolResult,
483				},
484			})
485			if createErr != nil {
486				return nil, createErr
487			}
488		}
489		if isCancelErr {
490			currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
491		} else if isPermissionErr {
492			currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "")
493		} else {
494			currentAssistant.AddFinish(message.FinishReasonError, "API Error", err.Error())
495		}
496		// Note: we use the parent context here because the genCtx has been
497		// cancelled.
498		updateErr := a.messages.Update(ctx, *currentAssistant)
499		if updateErr != nil {
500			return nil, updateErr
501		}
502		return nil, err
503	}
504	wg.Wait()
505
506	if shouldSummarize {
507		a.cancelCompletionNotification(call.SessionID)
508		a.activeRequests.Del(call.SessionID)
509		if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
510			return nil, summarizeErr
511		}
512		// If the agent wasn't done...
513		if len(currentAssistant.ToolCalls()) > 0 {
514			existing, ok := a.messageQueue.Get(call.SessionID)
515			if !ok {
516				existing = []SessionAgentCall{}
517			}
518			call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
519			existing = append(existing, call)
520			a.messageQueue.Set(call.SessionID, existing)
521		}
522	}
523
524	// Release active request before processing queued messages.
525	a.activeRequests.Del(call.SessionID)
526	cancel()
527
528	queuedMessages, ok := a.messageQueue.Get(call.SessionID)
529	if !ok || len(queuedMessages) == 0 {
530		return result, err
531	}
532	// There are queued messages restart the loop.
533	firstQueuedMessage := queuedMessages[0]
534	a.messageQueue.Set(call.SessionID, queuedMessages[1:])
535	return a.Run(ctx, firstQueuedMessage)
536}
537
538func (a *sessionAgent) scheduleCompletionNotification(sessionID, sessionTitle string) {
539	// Do not emit notifications for Agent-tool sub-sessions.
540	if a.sessions != nil && a.sessions.IsAgentToolSession(sessionID) {
541		return
542	}
543	if a.notifier == nil {
544		return
545	}
546
547	if sessionTitle == "" {
548		sessionTitle = sessionID
549	}
550
551	if cancel, ok := a.completionCancels.Take(sessionID); ok && cancel != nil {
552		cancel()
553	}
554
555	title := "💘 Crush is waiting"
556	message := fmt.Sprintf("Agent's turn completed in session \"%s\"", sessionTitle)
557	cancel := a.notifier.NotifyTaskComplete(a.notifyCtx, title, message, completionNotificationDelay)
558	if cancel == nil {
559		cancel = func() {}
560	}
561	a.completionCancels.Set(sessionID, cancel)
562}
563
564func (a *sessionAgent) cancelCompletionNotification(sessionID string) {
565	if a.notifier == nil {
566		return
567	}
568
569	if cancel, ok := a.completionCancels.Take(sessionID); ok && cancel != nil {
570		cancel()
571	}
572}
573
574// CancelCompletionNotification implements SessionAgent.
575func (a *sessionAgent) CancelCompletionNotification(sessionID string) {
576	a.cancelCompletionNotification(sessionID)
577}
578
579// HasPendingCompletionNotification implements SessionAgent.
580func (a *sessionAgent) HasPendingCompletionNotification(sessionID string) bool {
581	if a.IsSessionBusy(sessionID) {
582		return false
583	}
584	_, ok := a.completionCancels.Get(sessionID)
585	return ok
586}
587
588func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
589	if a.IsSessionBusy(sessionID) {
590		return ErrSessionBusy
591	}
592
593	currentSession, err := a.sessions.Get(ctx, sessionID)
594	if err != nil {
595		return fmt.Errorf("failed to get session: %w", err)
596	}
597	msgs, err := a.getSessionMessages(ctx, currentSession)
598	if err != nil {
599		return err
600	}
601	if len(msgs) == 0 {
602		// Nothing to summarize.
603		return nil
604	}
605
606	aiMsgs, _ := a.preparePrompt(msgs)
607
608	genCtx, cancel := context.WithCancel(ctx)
609	a.activeRequests.Set(sessionID, cancel)
610	defer a.activeRequests.Del(sessionID)
611	defer cancel()
612
613	agent := fantasy.NewAgent(a.largeModel.Model,
614		fantasy.WithSystemPrompt(string(summaryPrompt)),
615	)
616	summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
617		Role:             message.Assistant,
618		Model:            a.largeModel.Model.Model(),
619		Provider:         a.largeModel.Model.Provider(),
620		IsSummaryMessage: true,
621	})
622	if err != nil {
623		return err
624	}
625
626	resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
627		Prompt:          "Provide a detailed summary of our conversation above.",
628		Messages:        aiMsgs,
629		ProviderOptions: opts,
630		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
631			prepared.Messages = options.Messages
632			if a.systemPromptPrefix != "" {
633				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
634			}
635			return callContext, prepared, nil
636		},
637		OnReasoningDelta: func(id string, text string) error {
638			summaryMessage.AppendReasoningContent(text)
639			return a.messages.Update(genCtx, summaryMessage)
640		},
641		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
642			// Handle anthropic signature.
643			if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
644				if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
645					summaryMessage.AppendReasoningSignature(signature.Signature)
646				}
647			}
648			summaryMessage.FinishThinking()
649			return a.messages.Update(genCtx, summaryMessage)
650		},
651		OnTextDelta: func(id, text string) error {
652			summaryMessage.AppendContent(text)
653			return a.messages.Update(genCtx, summaryMessage)
654		},
655	})
656	if err != nil {
657		isCancelErr := errors.Is(err, context.Canceled)
658		if isCancelErr {
659			// User cancelled summarize we need to remove the summary message.
660			deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
661			return deleteErr
662		}
663		return err
664	}
665
666	summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
667	err = a.messages.Update(genCtx, summaryMessage)
668	if err != nil {
669		return err
670	}
671
672	var openrouterCost *float64
673	for _, step := range resp.Steps {
674		stepCost := a.openrouterCost(step.ProviderMetadata)
675		if stepCost != nil {
676			newCost := *stepCost
677			if openrouterCost != nil {
678				newCost += *openrouterCost
679			}
680			openrouterCost = &newCost
681		}
682	}
683
684	a.updateSessionUsage(a.largeModel, &currentSession, resp.TotalUsage, openrouterCost)
685
686	// Just in case, get just the last usage info.
687	usage := resp.Response.Usage
688	currentSession.SummaryMessageID = summaryMessage.ID
689	currentSession.CompletionTokens = usage.OutputTokens
690	currentSession.PromptTokens = 0
691	_, err = a.sessions.Save(genCtx, currentSession)
692	return err
693}
694
695func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
696	if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
697		return fantasy.ProviderOptions{}
698	}
699	return fantasy.ProviderOptions{
700		anthropic.Name: &anthropic.ProviderCacheControlOptions{
701			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
702		},
703		bedrock.Name: &anthropic.ProviderCacheControlOptions{
704			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
705		},
706	}
707}
708
709func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
710	var attachmentParts []message.ContentPart
711	for _, attachment := range call.Attachments {
712		attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
713	}
714	parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
715	parts = append(parts, attachmentParts...)
716	msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
717		Role:  message.User,
718		Parts: parts,
719	})
720	if err != nil {
721		return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
722	}
723	return msg, nil
724}
725
726func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
727	var history []fantasy.Message
728	for _, m := range msgs {
729		if len(m.Parts) == 0 {
730			continue
731		}
732		// Assistant message without content or tool calls (cancelled before it
733		// returned anything).
734		if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
735			continue
736		}
737		history = append(history, m.ToAIMessage()...)
738	}
739
740	var files []fantasy.FilePart
741	for _, attachment := range attachments {
742		files = append(files, fantasy.FilePart{
743			Filename:  attachment.FileName,
744			Data:      attachment.Content,
745			MediaType: attachment.MimeType,
746		})
747	}
748
749	return history, files
750}
751
752func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
753	msgs, err := a.messages.List(ctx, session.ID)
754	if err != nil {
755		return nil, fmt.Errorf("failed to list messages: %w", err)
756	}
757
758	if session.SummaryMessageID != "" {
759		summaryMsgInex := -1
760		for i, msg := range msgs {
761			if msg.ID == session.SummaryMessageID {
762				summaryMsgInex = i
763				break
764			}
765		}
766		if summaryMsgInex != -1 {
767			msgs = msgs[summaryMsgInex:]
768			msgs[0].Role = message.User
769		}
770	}
771	return msgs, nil
772}
773
774func (a *sessionAgent) generateTitle(ctx context.Context, session *session.Session, prompt string) {
775	if prompt == "" {
776		return
777	}
778
779	var maxOutput int64 = 40
780	if a.smallModel.CatwalkCfg.CanReason {
781		maxOutput = a.smallModel.CatwalkCfg.DefaultMaxTokens
782	}
783
784	agent := fantasy.NewAgent(a.smallModel.Model,
785		fantasy.WithSystemPrompt(string(titlePrompt)+"\n /no_think"),
786		fantasy.WithMaxOutputTokens(maxOutput),
787	)
788
789	resp, err := agent.Stream(ctx, fantasy.AgentStreamCall{
790		Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", prompt),
791		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
792			prepared.Messages = options.Messages
793			if a.systemPromptPrefix != "" {
794				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
795			}
796			return callContext, prepared, nil
797		},
798	})
799	if err != nil {
800		slog.Error("error generating title", "err", err)
801		return
802	}
803
804	title := resp.Response.Content.Text()
805
806	title = strings.ReplaceAll(title, "\n", " ")
807
808	// Remove thinking tags if present.
809	if idx := strings.Index(title, "</think>"); idx > 0 {
810		title = title[idx+len("</think>"):]
811	}
812
813	title = strings.TrimSpace(title)
814	if title == "" {
815		slog.Warn("failed to generate title", "warn", "empty title")
816		return
817	}
818
819	session.Title = title
820
821	var openrouterCost *float64
822	for _, step := range resp.Steps {
823		stepCost := a.openrouterCost(step.ProviderMetadata)
824		if stepCost != nil {
825			newCost := *stepCost
826			if openrouterCost != nil {
827				newCost += *openrouterCost
828			}
829			openrouterCost = &newCost
830		}
831	}
832
833	a.updateSessionUsage(a.smallModel, session, resp.TotalUsage, openrouterCost)
834	_, saveErr := a.sessions.Save(ctx, *session)
835	if saveErr != nil {
836		slog.Error("failed to save session title & usage", "error", saveErr)
837		return
838	}
839}
840
841func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
842	openrouterMetadata, ok := metadata[openrouter.Name]
843	if !ok {
844		return nil
845	}
846
847	opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
848	if !ok {
849		return nil
850	}
851	return &opts.Usage.Cost
852}
853
854func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
855	modelConfig := model.CatwalkCfg
856	cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
857		modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
858		modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
859		modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
860
861	a.eventTokensUsed(session.ID, model, usage, cost)
862
863	if overrideCost != nil {
864		session.Cost += *overrideCost
865	} else {
866		session.Cost += cost
867	}
868
869	session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
870	session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
871}
872
873func (a *sessionAgent) Cancel(sessionID string) {
874	// Cancel regular requests.
875	if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
876		slog.Info("Request cancellation initiated", "session_id", sessionID)
877		cancel()
878	}
879
880	// Also check for summarize requests.
881	if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
882		slog.Info("Summarize cancellation initiated", "session_id", sessionID)
883		cancel()
884	}
885
886	if a.QueuedPrompts(sessionID) > 0 {
887		slog.Info("Clearing queued prompts", "session_id", sessionID)
888		a.messageQueue.Del(sessionID)
889	}
890}
891
892func (a *sessionAgent) ClearQueue(sessionID string) {
893	if a.QueuedPrompts(sessionID) > 0 {
894		slog.Info("Clearing queued prompts", "session_id", sessionID)
895		a.messageQueue.Del(sessionID)
896	}
897}
898
899func (a *sessionAgent) CancelAll() {
900	if !a.IsBusy() {
901		// still ensure notifications are cancelled even when not busy
902		for cancel := range a.completionCancels.Seq() {
903			if cancel != nil {
904				cancel()
905			}
906		}
907		a.completionCancels.Reset(make(map[string]context.CancelFunc))
908		return
909	}
910	for key := range a.activeRequests.Seq2() {
911		a.Cancel(key) // key is sessionID
912	}
913
914	timeout := time.After(5 * time.Second)
915	for a.IsBusy() {
916		select {
917		case <-timeout:
918			return
919		default:
920			time.Sleep(200 * time.Millisecond)
921		}
922	}
923
924	for cancel := range a.completionCancels.Seq() {
925		if cancel != nil {
926			cancel()
927		}
928	}
929	a.completionCancels.Reset(make(map[string]context.CancelFunc))
930}
931
932func (a *sessionAgent) IsBusy() bool {
933	var busy bool
934	for cancelFunc := range a.activeRequests.Seq() {
935		if cancelFunc != nil {
936			busy = true
937			break
938		}
939	}
940	return busy
941}
942
943func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
944	_, busy := a.activeRequests.Get(sessionID)
945	return busy
946}
947
948func (a *sessionAgent) QueuedPrompts(sessionID string) int {
949	l, ok := a.messageQueue.Get(sessionID)
950	if !ok {
951		return 0
952	}
953	return len(l)
954}
955
956func (a *sessionAgent) SetModels(large Model, small Model) {
957	a.largeModel = large
958	a.smallModel = small
959}
960
961func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
962	a.tools = tools
963}
964
965func (a *sessionAgent) Model() Model {
966	return a.largeModel
967}