agent.go

  1// Package agent is the core orchestration layer for Crush AI agents.
  2//
  3// It provides session-based AI agent functionality for managing
  4// conversations, tool execution, and message handling. It coordinates
  5// interactions between language models, messages, sessions, and tools while
  6// handling features like automatic summarization, queuing, and token
  7// management.
  8package agent
  9
 10import (
 11	"cmp"
 12	"context"
 13	_ "embed"
 14	"errors"
 15	"fmt"
 16	"log/slog"
 17	"os"
 18	"strconv"
 19	"strings"
 20	"sync"
 21	"time"
 22
 23	"charm.land/fantasy"
 24	"charm.land/fantasy/providers/anthropic"
 25	"charm.land/fantasy/providers/bedrock"
 26	"charm.land/fantasy/providers/google"
 27	"charm.land/fantasy/providers/openai"
 28	"charm.land/fantasy/providers/openrouter"
 29	"git.secluded.site/crush/internal/agent/tools"
 30	"git.secluded.site/crush/internal/config"
 31	"git.secluded.site/crush/internal/csync"
 32	"git.secluded.site/crush/internal/message"
 33	"git.secluded.site/crush/internal/notification"
 34	"git.secluded.site/crush/internal/permission"
 35	"git.secluded.site/crush/internal/session"
 36	"git.secluded.site/crush/internal/stringext"
 37	"github.com/charmbracelet/catwalk/pkg/catwalk"
 38)
 39
// titlePrompt is the system prompt generateTitle gives the small model,
// embedded at build time from templates/title.md.
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt is the system prompt Summarize gives the large model,
// embedded at build time from templates/summary.md.
//
//go:embed templates/summary.md
var summaryPrompt []byte
 45
// SessionAgentCall is one prompt submitted to SessionAgent.Run.
//
// SessionID and Prompt are required: Run returns ErrSessionMissing or
// ErrEmptyPrompt when either is empty. Attachments are stored on the user
// message and also sent to the model as file parts. ProviderOptions and the
// sampling fields are forwarded unchanged to the model request.
// MaxOutputTokens is always sent by pointer, even when it is 0.
// NOTE(review): confirm that the provider treats 0 as "use the default".
type SessionAgentCall struct {
	SessionID        string
	Prompt           string
	ProviderOptions  fantasy.ProviderOptions
	Attachments      []message.Attachment
	MaxOutputTokens  int64
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
}
 58
// SessionAgent runs prompts against a language model on behalf of chat
// sessions. Prompts that arrive while a session is busy are queued rather
// than run concurrently.
type SessionAgent interface {
	// Run sends a prompt to the large model and streams the reply into the
	// session. If the session already has a request in flight, the call is
	// queued and Run returns (nil, nil).
	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
	// SetModels replaces the large model (conversation, summaries) and the
	// small model (title generation).
	SetModels(large Model, small Model)
	// SetTools replaces the tools offered to the model on later runs.
	SetTools(tools []fantasy.AgentTool)
	// Cancel aborts the in-flight request for sessionID and drops its queued
	// prompts.
	Cancel(sessionID string)
	// CancelAll cancels every in-flight request and any scheduled completion
	// notifications.
	CancelAll()
	// IsSessionBusy reports whether sessionID has a request in flight.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any session has a request in flight.
	IsBusy() bool
	// QueuedPrompts returns how many prompts are waiting for sessionID.
	QueuedPrompts(sessionID string) int
	// ClearQueue drops every prompt queued for sessionID.
	ClearQueue(sessionID string)
	// Summarize writes a model-generated summary of the session and marks it
	// as the point from which later runs read history.
	Summarize(context.Context, string, fantasy.ProviderOptions) error
	// Model returns the large model currently in use.
	Model() Model
	// CancelCompletionNotification cancels any scheduled "turn ended"
	// notification for the provided sessionID.
	CancelCompletionNotification(sessionID string)
	// HasPendingCompletionNotification returns true if a turn-end
	// notification has been scheduled and not yet cancelled/shown.
	HasPendingCompletionNotification(sessionID string) bool
}
 78
// completionNotificationDelay is the delay handed to
// Notifier.NotifyTaskComplete when a turn ends. Starting a new Run on the
// same session cancels the notification before the delay elapses.
const completionNotificationDelay = 5 * time.Second
 80
// Model pairs a fantasy language model with its catwalk metadata and the
// user's selected-model config. The catwalk metadata supplies the context
// window, per-token prices, CanReason and DefaultMaxTokens. The config
// supplies the provider and model IDs recorded on messages.
type Model struct {
	Model      fantasy.LanguageModel
	CatwalkCfg catwalk.Model
	ModelCfg   config.SelectedModel
}
 86
// sessionAgent is the SessionAgent implementation returned by
// NewSessionAgent.
//
// The bookkeeping maps are all keyed by session ID:
//   - messageQueue holds prompts submitted while the session was busy.
//   - activeRequests holds the cancel func of the session's in-flight Run or
//     Summarize. The presence of an entry is what IsSessionBusy checks.
//   - completionCancels holds the cancel func of the session's scheduled
//     "turn completed" notification.
//
// notifyCtx is the context that notifications are scheduled with.
// NOTE(review): isYolo is not read by any method in this file; confirm it is
// still needed.
type sessionAgent struct {
	largeModel           Model
	smallModel           Model
	systemPromptPrefix   string
	systemPrompt         string
	tools                []fantasy.AgentTool
	sessions             session.Service
	messages             message.Service
	disableAutoSummarize bool
	isYolo               bool

	messageQueue      *csync.Map[string, []SessionAgentCall]
	activeRequests    *csync.Map[string, context.CancelFunc]
	notifier          *notification.Notifier
	notifyCtx         context.Context
	completionCancels *csync.Map[string, context.CancelFunc]
}
104
// SessionAgentOptions configures NewSessionAgent.
//
// Notifier may be nil, in which case no completion notifications are
// scheduled. A nil NotificationCtx is replaced with context.Background().
// SystemPromptPrefix, when non-empty, is prepended as an extra system
// message to every request (conversation, summary and title).
type SessionAgentOptions struct {
	LargeModel           Model
	SmallModel           Model
	SystemPromptPrefix   string
	SystemPrompt         string
	DisableAutoSummarize bool
	IsYolo               bool
	Sessions             session.Service
	Messages             message.Service
	Tools                []fantasy.AgentTool
	Notifier             *notification.Notifier
	NotificationCtx      context.Context
}
118
119func NewSessionAgent(
120	opts SessionAgentOptions,
121) SessionAgent {
122	notifyCtx := opts.NotificationCtx
123	if notifyCtx == nil {
124		notifyCtx = context.Background()
125	}
126
127	return &sessionAgent{
128		largeModel:           opts.LargeModel,
129		smallModel:           opts.SmallModel,
130		systemPromptPrefix:   opts.SystemPromptPrefix,
131		systemPrompt:         opts.SystemPrompt,
132		sessions:             opts.Sessions,
133		messages:             opts.Messages,
134		disableAutoSummarize: opts.DisableAutoSummarize,
135		tools:                opts.Tools,
136		isYolo:               opts.IsYolo,
137		messageQueue:         csync.NewMap[string, []SessionAgentCall](),
138		activeRequests:       csync.NewMap[string, context.CancelFunc](),
139		notifier:             opts.Notifier,
140		notifyCtx:            notifyCtx,
141		completionCancels:    csync.NewMap[string, context.CancelFunc](),
142	}
143}
144
145func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
146	if call.Prompt == "" {
147		return nil, ErrEmptyPrompt
148	}
149	if call.SessionID == "" {
150		return nil, ErrSessionMissing
151	}
152
153	a.cancelCompletionNotification(call.SessionID)
154
155	// Queue the message if busy
156	if a.IsSessionBusy(call.SessionID) {
157		existing, ok := a.messageQueue.Get(call.SessionID)
158		if !ok {
159			existing = []SessionAgentCall{}
160		}
161		existing = append(existing, call)
162		a.messageQueue.Set(call.SessionID, existing)
163		return nil, nil
164	}
165
166	if len(a.tools) > 0 {
167		// Add Anthropic caching to the last tool.
168		a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
169	}
170
171	agent := fantasy.NewAgent(
172		a.largeModel.Model,
173		fantasy.WithSystemPrompt(a.systemPrompt),
174		fantasy.WithTools(a.tools...),
175	)
176
177	sessionLock := sync.Mutex{}
178	currentSession, err := a.sessions.Get(ctx, call.SessionID)
179	if err != nil {
180		return nil, fmt.Errorf("failed to get session: %w", err)
181	}
182
183	msgs, err := a.getSessionMessages(ctx, currentSession)
184	if err != nil {
185		return nil, fmt.Errorf("failed to get session messages: %w", err)
186	}
187
188	var wg sync.WaitGroup
189	// Generate title if first message.
190	if len(msgs) == 0 {
191		wg.Go(func() {
192			sessionLock.Lock()
193			a.generateTitle(ctx, &currentSession, call.Prompt)
194			sessionLock.Unlock()
195		})
196	}
197
198	// Add the user message to the session.
199	_, err = a.createUserMessage(ctx, call)
200	if err != nil {
201		return nil, err
202	}
203
204	// Add the session to the context.
205	ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)
206
207	genCtx, cancel := context.WithCancel(ctx)
208	a.activeRequests.Set(call.SessionID, cancel)
209
210	defer cancel()
211	defer a.activeRequests.Del(call.SessionID)
212
213	history, files := a.preparePrompt(msgs, call.Attachments...)
214
215	startTime := time.Now()
216	a.eventPromptSent(call.SessionID)
217
218	var currentAssistant *message.Message
219	var shouldSummarize bool
220	result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
221		Prompt:           call.Prompt,
222		Files:            files,
223		Messages:         history,
224		ProviderOptions:  call.ProviderOptions,
225		MaxOutputTokens:  &call.MaxOutputTokens,
226		TopP:             call.TopP,
227		Temperature:      call.Temperature,
228		PresencePenalty:  call.PresencePenalty,
229		TopK:             call.TopK,
230		FrequencyPenalty: call.FrequencyPenalty,
231		// Before each step create a new assistant message.
232		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
233			prepared.Messages = options.Messages
234			// Reset all cached items.
235			for i := range prepared.Messages {
236				prepared.Messages[i].ProviderOptions = nil
237			}
238
239			queuedCalls, _ := a.messageQueue.Get(call.SessionID)
240			a.messageQueue.Del(call.SessionID)
241			for _, queued := range queuedCalls {
242				userMessage, createErr := a.createUserMessage(callContext, queued)
243				if createErr != nil {
244					return callContext, prepared, createErr
245				}
246				prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
247			}
248
249			lastSystemRoleInx := 0
250			systemMessageUpdated := false
251			for i, msg := range prepared.Messages {
252				// Only add cache control to the last message.
253				if msg.Role == fantasy.MessageRoleSystem {
254					lastSystemRoleInx = i
255				} else if !systemMessageUpdated {
256					prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
257					systemMessageUpdated = true
258				}
259				// Than add cache control to the last 2 messages.
260				if i > len(prepared.Messages)-3 {
261					prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
262				}
263			}
264
265			if a.systemPromptPrefix != "" {
266				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
267			}
268
269			var assistantMsg message.Message
270			assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
271				Role:     message.Assistant,
272				Parts:    []message.ContentPart{},
273				Model:    a.largeModel.ModelCfg.Model,
274				Provider: a.largeModel.ModelCfg.Provider,
275			})
276			if err != nil {
277				return callContext, prepared, err
278			}
279			callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
280			currentAssistant = &assistantMsg
281			return callContext, prepared, err
282		},
283		OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
284			currentAssistant.AppendReasoningContent(reasoning.Text)
285			return a.messages.Update(genCtx, *currentAssistant)
286		},
287		OnReasoningDelta: func(id string, text string) error {
288			currentAssistant.AppendReasoningContent(text)
289			return a.messages.Update(genCtx, *currentAssistant)
290		},
291		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
292			// handle anthropic signature
293			if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
294				if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
295					currentAssistant.AppendReasoningSignature(reasoning.Signature)
296				}
297			}
298			if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
299				if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
300					currentAssistant.AppendThoughtSignature(reasoning.Signature, reasoning.ToolID)
301				}
302			}
303			if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
304				if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
305					currentAssistant.SetReasoningResponsesData(reasoning)
306				}
307			}
308			currentAssistant.FinishThinking()
309			return a.messages.Update(genCtx, *currentAssistant)
310		},
311		OnTextDelta: func(id string, text string) error {
312			// Strip leading newline from initial text content. This is is
313			// particularly important in non-interactive mode where leading
314			// newlines are very visible.
315			if len(currentAssistant.Parts) == 0 {
316				text = strings.TrimPrefix(text, "\n")
317			}
318
319			currentAssistant.AppendContent(text)
320			return a.messages.Update(genCtx, *currentAssistant)
321		},
322		OnToolInputStart: func(id string, toolName string) error {
323			toolCall := message.ToolCall{
324				ID:               id,
325				Name:             toolName,
326				ProviderExecuted: false,
327				Finished:         false,
328			}
329			currentAssistant.AddToolCall(toolCall)
330			return a.messages.Update(genCtx, *currentAssistant)
331		},
332		OnRetry: func(err *fantasy.ProviderError, delay time.Duration) {
333			// TODO: implement
334		},
335		OnToolCall: func(tc fantasy.ToolCallContent) error {
336			toolCall := message.ToolCall{
337				ID:               tc.ToolCallID,
338				Name:             tc.ToolName,
339				Input:            tc.Input,
340				ProviderExecuted: false,
341				Finished:         true,
342			}
343			currentAssistant.AddToolCall(toolCall)
344			return a.messages.Update(genCtx, *currentAssistant)
345		},
346		OnToolResult: func(result fantasy.ToolResultContent) error {
347			var resultContent string
348			isError := false
349			switch result.Result.GetType() {
350			case fantasy.ToolResultContentTypeText:
351				r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result)
352				if ok {
353					resultContent = r.Text
354				}
355			case fantasy.ToolResultContentTypeError:
356				r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result)
357				if ok {
358					isError = true
359					resultContent = r.Error.Error()
360				}
361			case fantasy.ToolResultContentTypeMedia:
362				// TODO: handle this message type
363			}
364			toolResult := message.ToolResult{
365				ToolCallID: result.ToolCallID,
366				Name:       result.ToolName,
367				Content:    resultContent,
368				IsError:    isError,
369				Metadata:   result.ClientMetadata,
370			}
371			_, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
372				Role: message.Tool,
373				Parts: []message.ContentPart{
374					toolResult,
375				},
376			})
377			if createMsgErr != nil {
378				return createMsgErr
379			}
380			return nil
381		},
382		OnStepFinish: func(stepResult fantasy.StepResult) error {
383			finishReason := message.FinishReasonUnknown
384			switch stepResult.FinishReason {
385			case fantasy.FinishReasonLength:
386				finishReason = message.FinishReasonMaxTokens
387			case fantasy.FinishReasonStop:
388				finishReason = message.FinishReasonEndTurn
389			case fantasy.FinishReasonToolCalls:
390				finishReason = message.FinishReasonToolUse
391			}
392			currentAssistant.AddFinish(finishReason, "", "")
393			a.updateSessionUsage(a.largeModel, &currentSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
394			sessionLock.Lock()
395			_, sessionErr := a.sessions.Save(genCtx, currentSession)
396			sessionLock.Unlock()
397			if sessionErr != nil {
398				return sessionErr
399			}
400			if finishReason == message.FinishReasonEndTurn {
401				a.scheduleCompletionNotification(call.SessionID, currentSession.Title)
402			}
403			return a.messages.Update(genCtx, *currentAssistant)
404		},
405		StopWhen: []fantasy.StopCondition{
406			func(_ []fantasy.StepResult) bool {
407				cw := int64(a.largeModel.CatwalkCfg.ContextWindow)
408				tokens := currentSession.CompletionTokens + currentSession.PromptTokens
409				remaining := cw - tokens
410				var threshold int64
411				if cw > 200_000 {
412					threshold = 20_000
413				} else {
414					threshold = int64(float64(cw) * 0.2)
415				}
416				if (remaining <= threshold) && !a.disableAutoSummarize {
417					shouldSummarize = true
418					return true
419				}
420				return false
421			},
422		},
423	})
424
425	a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))
426
427	if err != nil {
428		isCancelErr := errors.Is(err, context.Canceled)
429		isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
430		if currentAssistant == nil {
431			return result, err
432		}
433		// Ensure we finish thinking on error to close the reasoning state.
434		currentAssistant.FinishThinking()
435		toolCalls := currentAssistant.ToolCalls()
436		// INFO: we use the parent context here because the genCtx has been cancelled.
437		msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
438		if createErr != nil {
439			return nil, createErr
440		}
441		for _, tc := range toolCalls {
442			if !tc.Finished {
443				tc.Finished = true
444				tc.Input = "{}"
445				currentAssistant.AddToolCall(tc)
446				updateErr := a.messages.Update(ctx, *currentAssistant)
447				if updateErr != nil {
448					return nil, updateErr
449				}
450			}
451
452			found := false
453			for _, msg := range msgs {
454				if msg.Role == message.Tool {
455					for _, tr := range msg.ToolResults() {
456						if tr.ToolCallID == tc.ID {
457							found = true
458							break
459						}
460					}
461				}
462				if found {
463					break
464				}
465			}
466			if found {
467				continue
468			}
469			content := "There was an error while executing the tool"
470			if isCancelErr {
471				content = "Tool execution canceled by user"
472			} else if isPermissionErr {
473				content = "User denied permission"
474			}
475			toolResult := message.ToolResult{
476				ToolCallID: tc.ID,
477				Name:       tc.Name,
478				Content:    content,
479				IsError:    true,
480			}
481			_, createErr = a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
482				Role: message.Tool,
483				Parts: []message.ContentPart{
484					toolResult,
485				},
486			})
487			if createErr != nil {
488				return nil, createErr
489			}
490		}
491		var fantasyErr *fantasy.Error
492		var providerErr *fantasy.ProviderError
493		const defaultTitle = "Provider Error"
494		if isCancelErr {
495			currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
496		} else if isPermissionErr {
497			currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "")
498		} else if errors.As(err, &providerErr) {
499			currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message)
500		} else if errors.As(err, &fantasyErr) {
501			currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(fantasyErr.Title), defaultTitle), fantasyErr.Message)
502		} else {
503			currentAssistant.AddFinish(message.FinishReasonError, defaultTitle, err.Error())
504		}
505		// Note: we use the parent context here because the genCtx has been
506		// cancelled.
507		updateErr := a.messages.Update(ctx, *currentAssistant)
508		if updateErr != nil {
509			return nil, updateErr
510		}
511		return nil, err
512	}
513	wg.Wait()
514
515	if shouldSummarize {
516		a.cancelCompletionNotification(call.SessionID)
517		a.activeRequests.Del(call.SessionID)
518		if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
519			return nil, summarizeErr
520		}
521		// If the agent wasn't done...
522		if len(currentAssistant.ToolCalls()) > 0 {
523			existing, ok := a.messageQueue.Get(call.SessionID)
524			if !ok {
525				existing = []SessionAgentCall{}
526			}
527			call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
528			existing = append(existing, call)
529			a.messageQueue.Set(call.SessionID, existing)
530		}
531	}
532
533	// Release active request before processing queued messages.
534	a.activeRequests.Del(call.SessionID)
535	cancel()
536
537	queuedMessages, ok := a.messageQueue.Get(call.SessionID)
538	if !ok || len(queuedMessages) == 0 {
539		return result, err
540	}
541	// There are queued messages restart the loop.
542	firstQueuedMessage := queuedMessages[0]
543	a.messageQueue.Set(call.SessionID, queuedMessages[1:])
544	return a.Run(ctx, firstQueuedMessage)
545}
546
547func (a *sessionAgent) scheduleCompletionNotification(sessionID, sessionTitle string) {
548	// Do not emit notifications for Agent-tool sub-sessions.
549	if a.sessions != nil && a.sessions.IsAgentToolSession(sessionID) {
550		return
551	}
552	if a.notifier == nil {
553		return
554	}
555
556	if sessionTitle == "" {
557		sessionTitle = sessionID
558	}
559
560	if cancel, ok := a.completionCancels.Take(sessionID); ok && cancel != nil {
561		cancel()
562	}
563
564	title := "💘 Crush is waiting"
565	message := fmt.Sprintf("Agent's turn completed in session \"%s\"", sessionTitle)
566	cancel := a.notifier.NotifyTaskComplete(a.notifyCtx, title, message, completionNotificationDelay)
567	if cancel == nil {
568		cancel = func() {}
569	}
570	a.completionCancels.Set(sessionID, cancel)
571}
572
573func (a *sessionAgent) cancelCompletionNotification(sessionID string) {
574	if a.notifier == nil {
575		return
576	}
577
578	if cancel, ok := a.completionCancels.Take(sessionID); ok && cancel != nil {
579		cancel()
580	}
581}
582
583// CancelCompletionNotification implements SessionAgent.
584func (a *sessionAgent) CancelCompletionNotification(sessionID string) {
585	a.cancelCompletionNotification(sessionID)
586}
587
588// HasPendingCompletionNotification implements SessionAgent.
589func (a *sessionAgent) HasPendingCompletionNotification(sessionID string) bool {
590	if a.IsSessionBusy(sessionID) {
591		return false
592	}
593	_, ok := a.completionCancels.Get(sessionID)
594	return ok
595}
596
597func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
598	if a.IsSessionBusy(sessionID) {
599		return ErrSessionBusy
600	}
601
602	currentSession, err := a.sessions.Get(ctx, sessionID)
603	if err != nil {
604		return fmt.Errorf("failed to get session: %w", err)
605	}
606	msgs, err := a.getSessionMessages(ctx, currentSession)
607	if err != nil {
608		return err
609	}
610	if len(msgs) == 0 {
611		// Nothing to summarize.
612		return nil
613	}
614
615	aiMsgs, _ := a.preparePrompt(msgs)
616
617	genCtx, cancel := context.WithCancel(ctx)
618	a.activeRequests.Set(sessionID, cancel)
619	defer a.activeRequests.Del(sessionID)
620	defer cancel()
621
622	agent := fantasy.NewAgent(a.largeModel.Model,
623		fantasy.WithSystemPrompt(string(summaryPrompt)),
624	)
625	summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
626		Role:             message.Assistant,
627		Model:            a.largeModel.Model.Model(),
628		Provider:         a.largeModel.Model.Provider(),
629		IsSummaryMessage: true,
630	})
631	if err != nil {
632		return err
633	}
634
635	resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
636		Prompt:          "Provide a detailed summary of our conversation above.",
637		Messages:        aiMsgs,
638		ProviderOptions: opts,
639		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
640			prepared.Messages = options.Messages
641			if a.systemPromptPrefix != "" {
642				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
643			}
644			return callContext, prepared, nil
645		},
646		OnReasoningDelta: func(id string, text string) error {
647			summaryMessage.AppendReasoningContent(text)
648			return a.messages.Update(genCtx, summaryMessage)
649		},
650		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
651			// Handle anthropic signature.
652			if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
653				if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
654					summaryMessage.AppendReasoningSignature(signature.Signature)
655				}
656			}
657			summaryMessage.FinishThinking()
658			return a.messages.Update(genCtx, summaryMessage)
659		},
660		OnTextDelta: func(id, text string) error {
661			summaryMessage.AppendContent(text)
662			return a.messages.Update(genCtx, summaryMessage)
663		},
664	})
665	if err != nil {
666		isCancelErr := errors.Is(err, context.Canceled)
667		if isCancelErr {
668			// User cancelled summarize we need to remove the summary message.
669			deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
670			return deleteErr
671		}
672		return err
673	}
674
675	summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
676	err = a.messages.Update(genCtx, summaryMessage)
677	if err != nil {
678		return err
679	}
680
681	var openrouterCost *float64
682	for _, step := range resp.Steps {
683		stepCost := a.openrouterCost(step.ProviderMetadata)
684		if stepCost != nil {
685			newCost := *stepCost
686			if openrouterCost != nil {
687				newCost += *openrouterCost
688			}
689			openrouterCost = &newCost
690		}
691	}
692
693	a.updateSessionUsage(a.largeModel, &currentSession, resp.TotalUsage, openrouterCost)
694
695	// Just in case, get just the last usage info.
696	usage := resp.Response.Usage
697	currentSession.SummaryMessageID = summaryMessage.ID
698	currentSession.CompletionTokens = usage.OutputTokens
699	currentSession.PromptTokens = 0
700	_, err = a.sessions.Save(genCtx, currentSession)
701	return err
702}
703
704func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
705	if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
706		return fantasy.ProviderOptions{}
707	}
708	return fantasy.ProviderOptions{
709		anthropic.Name: &anthropic.ProviderCacheControlOptions{
710			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
711		},
712		bedrock.Name: &anthropic.ProviderCacheControlOptions{
713			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
714		},
715	}
716}
717
718func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
719	var attachmentParts []message.ContentPart
720	for _, attachment := range call.Attachments {
721		attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
722	}
723	parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
724	parts = append(parts, attachmentParts...)
725	msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
726		Role:  message.User,
727		Parts: parts,
728	})
729	if err != nil {
730		return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
731	}
732	return msg, nil
733}
734
735func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
736	var history []fantasy.Message
737	for _, m := range msgs {
738		if len(m.Parts) == 0 {
739			continue
740		}
741		// Assistant message without content or tool calls (cancelled before it
742		// returned anything).
743		if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
744			continue
745		}
746		history = append(history, m.ToAIMessage()...)
747	}
748
749	var files []fantasy.FilePart
750	for _, attachment := range attachments {
751		files = append(files, fantasy.FilePart{
752			Filename:  attachment.FileName,
753			Data:      attachment.Content,
754			MediaType: attachment.MimeType,
755		})
756	}
757
758	return history, files
759}
760
761func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
762	msgs, err := a.messages.List(ctx, session.ID)
763	if err != nil {
764		return nil, fmt.Errorf("failed to list messages: %w", err)
765	}
766
767	if session.SummaryMessageID != "" {
768		summaryMsgInex := -1
769		for i, msg := range msgs {
770			if msg.ID == session.SummaryMessageID {
771				summaryMsgInex = i
772				break
773			}
774		}
775		if summaryMsgInex != -1 {
776			msgs = msgs[summaryMsgInex:]
777			msgs[0].Role = message.User
778		}
779	}
780	return msgs, nil
781}
782
783func (a *sessionAgent) generateTitle(ctx context.Context, session *session.Session, prompt string) {
784	if prompt == "" {
785		return
786	}
787
788	var maxOutput int64 = 40
789	if a.smallModel.CatwalkCfg.CanReason {
790		maxOutput = a.smallModel.CatwalkCfg.DefaultMaxTokens
791	}
792
793	agent := fantasy.NewAgent(a.smallModel.Model,
794		fantasy.WithSystemPrompt(string(titlePrompt)+"\n /no_think"),
795		fantasy.WithMaxOutputTokens(maxOutput),
796	)
797
798	resp, err := agent.Stream(ctx, fantasy.AgentStreamCall{
799		Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", prompt),
800		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
801			prepared.Messages = options.Messages
802			if a.systemPromptPrefix != "" {
803				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
804			}
805			return callContext, prepared, nil
806		},
807	})
808	if err != nil {
809		slog.Error("error generating title", "err", err)
810		return
811	}
812
813	title := resp.Response.Content.Text()
814
815	title = strings.ReplaceAll(title, "\n", " ")
816
817	// Remove thinking tags if present.
818	if idx := strings.Index(title, "</think>"); idx > 0 {
819		title = title[idx+len("</think>"):]
820	}
821
822	title = strings.TrimSpace(title)
823	if title == "" {
824		slog.Warn("failed to generate title", "warn", "empty title")
825		return
826	}
827
828	session.Title = title
829
830	var openrouterCost *float64
831	for _, step := range resp.Steps {
832		stepCost := a.openrouterCost(step.ProviderMetadata)
833		if stepCost != nil {
834			newCost := *stepCost
835			if openrouterCost != nil {
836				newCost += *openrouterCost
837			}
838			openrouterCost = &newCost
839		}
840	}
841
842	a.updateSessionUsage(a.smallModel, session, resp.TotalUsage, openrouterCost)
843	_, saveErr := a.sessions.Save(ctx, *session)
844	if saveErr != nil {
845		slog.Error("failed to save session title & usage", "error", saveErr)
846		return
847	}
848}
849
850func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
851	openrouterMetadata, ok := metadata[openrouter.Name]
852	if !ok {
853		return nil
854	}
855
856	opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
857	if !ok {
858		return nil
859	}
860	return &opts.Usage.Cost
861}
862
863func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
864	modelConfig := model.CatwalkCfg
865	cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
866		modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
867		modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
868		modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
869
870	a.eventTokensUsed(session.ID, model, usage, cost)
871
872	if overrideCost != nil {
873		session.Cost += *overrideCost
874	} else {
875		session.Cost += cost
876	}
877
878	session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
879	session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
880}
881
882func (a *sessionAgent) Cancel(sessionID string) {
883	// Cancel regular requests.
884	if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
885		slog.Info("Request cancellation initiated", "session_id", sessionID)
886		cancel()
887	}
888
889	// Also check for summarize requests.
890	if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
891		slog.Info("Summarize cancellation initiated", "session_id", sessionID)
892		cancel()
893	}
894
895	if a.QueuedPrompts(sessionID) > 0 {
896		slog.Info("Clearing queued prompts", "session_id", sessionID)
897		a.messageQueue.Del(sessionID)
898	}
899}
900
901func (a *sessionAgent) ClearQueue(sessionID string) {
902	if a.QueuedPrompts(sessionID) > 0 {
903		slog.Info("Clearing queued prompts", "session_id", sessionID)
904		a.messageQueue.Del(sessionID)
905	}
906}
907
908func (a *sessionAgent) CancelAll() {
909	if !a.IsBusy() {
910		// still ensure notifications are cancelled even when not busy
911		for cancel := range a.completionCancels.Seq() {
912			if cancel != nil {
913				cancel()
914			}
915		}
916		a.completionCancels.Reset(make(map[string]context.CancelFunc))
917		return
918	}
919	for key := range a.activeRequests.Seq2() {
920		a.Cancel(key) // key is sessionID
921	}
922
923	timeout := time.After(5 * time.Second)
924	for a.IsBusy() {
925		select {
926		case <-timeout:
927			return
928		default:
929			time.Sleep(200 * time.Millisecond)
930		}
931	}
932
933	for cancel := range a.completionCancels.Seq() {
934		if cancel != nil {
935			cancel()
936		}
937	}
938	a.completionCancels.Reset(make(map[string]context.CancelFunc))
939}
940
941func (a *sessionAgent) IsBusy() bool {
942	var busy bool
943	for cancelFunc := range a.activeRequests.Seq() {
944		if cancelFunc != nil {
945			busy = true
946			break
947		}
948	}
949	return busy
950}
951
952func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
953	_, busy := a.activeRequests.Get(sessionID)
954	return busy
955}
956
957func (a *sessionAgent) QueuedPrompts(sessionID string) int {
958	l, ok := a.messageQueue.Get(sessionID)
959	if !ok {
960		return 0
961	}
962	return len(l)
963}
964
965func (a *sessionAgent) SetModels(large Model, small Model) {
966	a.largeModel = large
967	a.smallModel = small
968}
969
970func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
971	a.tools = tools
972}
973
974func (a *sessionAgent) Model() Model {
975	return a.largeModel
976}