agent.go

  1// Package agent is the core orchestration layer for Crush AI agents.
  2//
  3// It provides session-based AI agent functionality for managing
  4// conversations, tool execution, and message handling. It coordinates
  5// interactions between language models, messages, sessions, and tools while
  6// handling features like automatic summarization, queuing, and token
  7// management.
  8package agent
  9
 10import (
 11	"cmp"
 12	"context"
 13	_ "embed"
 14	"errors"
 15	"fmt"
 16	"log/slog"
 17	"os"
 18	"strconv"
 19	"strings"
 20	"sync"
 21	"time"
 22
 23	"charm.land/fantasy"
 24	"charm.land/fantasy/providers/anthropic"
 25	"charm.land/fantasy/providers/bedrock"
 26	"charm.land/fantasy/providers/google"
 27	"charm.land/fantasy/providers/openai"
 28	"charm.land/fantasy/providers/openrouter"
 29	"git.secluded.site/crush/internal/agent/tools"
 30	"git.secluded.site/crush/internal/config"
 31	"git.secluded.site/crush/internal/csync"
 32	"git.secluded.site/crush/internal/message"
 33	"git.secluded.site/crush/internal/notification"
 34	"git.secluded.site/crush/internal/permission"
 35	"git.secluded.site/crush/internal/session"
 36	"git.secluded.site/crush/internal/stringext"
 37	"github.com/charmbracelet/catwalk/pkg/catwalk"
 38)
 39
 40//go:embed templates/title.md
 41var titlePrompt []byte
 42
 43//go:embed templates/summary.md
 44var summaryPrompt []byte
 45
 46type SessionAgentCall struct {
 47	SessionID        string
 48	Prompt           string
 49	ProviderOptions  fantasy.ProviderOptions
 50	Attachments      []message.Attachment
 51	MaxOutputTokens  int64
 52	Temperature      *float64
 53	TopP             *float64
 54	TopK             *int64
 55	FrequencyPenalty *float64
 56	PresencePenalty  *float64
 57	NonInteractive   bool
 58}
 59
 60type SessionAgent interface {
 61	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
 62	SetModels(large Model, small Model)
 63	SetTools(tools []fantasy.AgentTool)
 64	Cancel(sessionID string)
 65	CancelAll()
 66	IsSessionBusy(sessionID string) bool
 67	IsBusy() bool
 68	QueuedPrompts(sessionID string) int
 69	ClearQueue(sessionID string)
 70	Summarize(context.Context, string, fantasy.ProviderOptions) error
 71	Model() Model
 72}
 73
 74type Model struct {
 75	Model      fantasy.LanguageModel
 76	CatwalkCfg catwalk.Model
 77	ModelCfg   config.SelectedModel
 78}
 79
 80type sessionAgent struct {
 81	largeModel           Model
 82	smallModel           Model
 83	systemPromptPrefix   string
 84	systemPrompt         string
 85	tools                []fantasy.AgentTool
 86	sessions             session.Service
 87	messages             message.Service
 88	disableAutoSummarize bool
 89	isYolo               bool
 90
 91	messageQueue   *csync.Map[string, []SessionAgentCall]
 92	activeRequests *csync.Map[string, context.CancelFunc]
 93}
 94
 95type SessionAgentOptions struct {
 96	LargeModel           Model
 97	SmallModel           Model
 98	SystemPromptPrefix   string
 99	SystemPrompt         string
100	DisableAutoSummarize bool
101	IsYolo               bool
102	Sessions             session.Service
103	Messages             message.Service
104	Tools                []fantasy.AgentTool
105}
106
107func NewSessionAgent(
108	opts SessionAgentOptions,
109) SessionAgent {
110	return &sessionAgent{
111		largeModel:           opts.LargeModel,
112		smallModel:           opts.SmallModel,
113		systemPromptPrefix:   opts.SystemPromptPrefix,
114		systemPrompt:         opts.SystemPrompt,
115		sessions:             opts.Sessions,
116		messages:             opts.Messages,
117		disableAutoSummarize: opts.DisableAutoSummarize,
118		tools:                opts.Tools,
119		isYolo:               opts.IsYolo,
120		messageQueue:         csync.NewMap[string, []SessionAgentCall](),
121		activeRequests:       csync.NewMap[string, context.CancelFunc](),
122	}
123}
124
125func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
126	if call.Prompt == "" {
127		return nil, ErrEmptyPrompt
128	}
129	if call.SessionID == "" {
130		return nil, ErrSessionMissing
131	}
132
133	// Queue the message if busy
134	if a.IsSessionBusy(call.SessionID) {
135		existing, ok := a.messageQueue.Get(call.SessionID)
136		if !ok {
137			existing = []SessionAgentCall{}
138		}
139		existing = append(existing, call)
140		a.messageQueue.Set(call.SessionID, existing)
141		return nil, nil
142	}
143
144	if len(a.tools) > 0 {
145		// Add Anthropic caching to the last tool.
146		a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
147	}
148
149	agent := fantasy.NewAgent(
150		a.largeModel.Model,
151		fantasy.WithSystemPrompt(a.systemPrompt),
152		fantasy.WithTools(a.tools...),
153	)
154
155	sessionLock := sync.Mutex{}
156	currentSession, err := a.sessions.Get(ctx, call.SessionID)
157	if err != nil {
158		return nil, fmt.Errorf("failed to get session: %w", err)
159	}
160
161	msgs, err := a.getSessionMessages(ctx, currentSession)
162	if err != nil {
163		return nil, fmt.Errorf("failed to get session messages: %w", err)
164	}
165
166	var wg sync.WaitGroup
167	// Generate title if first message.
168	if len(msgs) == 0 {
169		wg.Go(func() {
170			sessionLock.Lock()
171			a.generateTitle(ctx, &currentSession, call.Prompt)
172			sessionLock.Unlock()
173		})
174	}
175
176	// Add the user message to the session.
177	_, err = a.createUserMessage(ctx, call)
178	if err != nil {
179		return nil, err
180	}
181
182	// Add the session to the context.
183	ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)
184
185	genCtx, cancel := context.WithCancel(ctx)
186	a.activeRequests.Set(call.SessionID, cancel)
187
188	defer cancel()
189	defer a.activeRequests.Del(call.SessionID)
190
191	history, files := a.preparePrompt(msgs, call.Attachments...)
192
193	startTime := time.Now()
194	a.eventPromptSent(call.SessionID)
195
196	var currentAssistant *message.Message
197	var shouldSummarize bool
198	result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
199		Prompt:           call.Prompt,
200		Files:            files,
201		Messages:         history,
202		ProviderOptions:  call.ProviderOptions,
203		MaxOutputTokens:  &call.MaxOutputTokens,
204		TopP:             call.TopP,
205		Temperature:      call.Temperature,
206		PresencePenalty:  call.PresencePenalty,
207		TopK:             call.TopK,
208		FrequencyPenalty: call.FrequencyPenalty,
209		// Before each step create a new assistant message.
210		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
211			prepared.Messages = options.Messages
212			// Reset all cached items.
213			for i := range prepared.Messages {
214				prepared.Messages[i].ProviderOptions = nil
215			}
216
217			queuedCalls, _ := a.messageQueue.Get(call.SessionID)
218			a.messageQueue.Del(call.SessionID)
219			for _, queued := range queuedCalls {
220				userMessage, createErr := a.createUserMessage(callContext, queued)
221				if createErr != nil {
222					return callContext, prepared, createErr
223				}
224				prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
225			}
226
227			lastSystemRoleInx := 0
228			systemMessageUpdated := false
229			for i, msg := range prepared.Messages {
230				// Only add cache control to the last message.
231				if msg.Role == fantasy.MessageRoleSystem {
232					lastSystemRoleInx = i
233				} else if !systemMessageUpdated {
234					prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
235					systemMessageUpdated = true
236				}
237				// Than add cache control to the last 2 messages.
238				if i > len(prepared.Messages)-3 {
239					prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
240				}
241			}
242
243			if promptPrefix := a.promptPrefix(); promptPrefix != "" {
244				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(promptPrefix)}, prepared.Messages...)
245			}
246
247			var assistantMsg message.Message
248			assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
249				Role:     message.Assistant,
250				Parts:    []message.ContentPart{},
251				Model:    a.largeModel.ModelCfg.Model,
252				Provider: a.largeModel.ModelCfg.Provider,
253			})
254			if err != nil {
255				return callContext, prepared, err
256			}
257			callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
258			currentAssistant = &assistantMsg
259			return callContext, prepared, err
260		},
261		OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
262			currentAssistant.AppendReasoningContent(reasoning.Text)
263			return a.messages.Update(genCtx, *currentAssistant)
264		},
265		OnReasoningDelta: func(id string, text string) error {
266			currentAssistant.AppendReasoningContent(text)
267			return a.messages.Update(genCtx, *currentAssistant)
268		},
269		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
270			// handle anthropic signature
271			if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
272				if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
273					currentAssistant.AppendReasoningSignature(reasoning.Signature)
274				}
275			}
276			if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
277				if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
278					currentAssistant.AppendThoughtSignature(reasoning.Signature, reasoning.ToolID)
279				}
280			}
281			if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
282				if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
283					currentAssistant.SetReasoningResponsesData(reasoning)
284				}
285			}
286			currentAssistant.FinishThinking()
287			return a.messages.Update(genCtx, *currentAssistant)
288		},
289		OnTextDelta: func(id string, text string) error {
290			// Strip leading newline from initial text content. This is is
291			// particularly important in non-interactive mode where leading
292			// newlines are very visible.
293			if len(currentAssistant.Parts) == 0 {
294				text = strings.TrimPrefix(text, "\n")
295			}
296
297			currentAssistant.AppendContent(text)
298			return a.messages.Update(genCtx, *currentAssistant)
299		},
300		OnToolInputStart: func(id string, toolName string) error {
301			toolCall := message.ToolCall{
302				ID:               id,
303				Name:             toolName,
304				ProviderExecuted: false,
305				Finished:         false,
306			}
307			currentAssistant.AddToolCall(toolCall)
308			return a.messages.Update(genCtx, *currentAssistant)
309		},
310		OnRetry: func(err *fantasy.ProviderError, delay time.Duration) {
311			// TODO: implement
312		},
313		OnToolCall: func(tc fantasy.ToolCallContent) error {
314			toolCall := message.ToolCall{
315				ID:               tc.ToolCallID,
316				Name:             tc.ToolName,
317				Input:            tc.Input,
318				ProviderExecuted: false,
319				Finished:         true,
320			}
321			currentAssistant.AddToolCall(toolCall)
322			return a.messages.Update(genCtx, *currentAssistant)
323		},
324		OnToolResult: func(result fantasy.ToolResultContent) error {
325			var resultContent string
326			isError := false
327			switch result.Result.GetType() {
328			case fantasy.ToolResultContentTypeText:
329				r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result)
330				if ok {
331					resultContent = r.Text
332				}
333			case fantasy.ToolResultContentTypeError:
334				r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result)
335				if ok {
336					isError = true
337					resultContent = r.Error.Error()
338				}
339			case fantasy.ToolResultContentTypeMedia:
340				// TODO: handle this message type
341			}
342			toolResult := message.ToolResult{
343				ToolCallID: result.ToolCallID,
344				Name:       result.ToolName,
345				Content:    resultContent,
346				IsError:    isError,
347				Metadata:   result.ClientMetadata,
348			}
349			_, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
350				Role: message.Tool,
351				Parts: []message.ContentPart{
352					toolResult,
353				},
354			})
355			if createMsgErr != nil {
356				return createMsgErr
357			}
358			return nil
359		},
360		OnStepFinish: func(stepResult fantasy.StepResult) error {
361			finishReason := message.FinishReasonUnknown
362			switch stepResult.FinishReason {
363			case fantasy.FinishReasonLength:
364				finishReason = message.FinishReasonMaxTokens
365			case fantasy.FinishReasonStop:
366				finishReason = message.FinishReasonEndTurn
367			case fantasy.FinishReasonToolCalls:
368				finishReason = message.FinishReasonToolUse
369			}
370			currentAssistant.AddFinish(finishReason, "", "")
371			a.updateSessionUsage(a.largeModel, &currentSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
372			sessionLock.Lock()
373			_, sessionErr := a.sessions.Save(genCtx, currentSession)
374			sessionLock.Unlock()
375			if sessionErr != nil {
376				return sessionErr
377			}
378			return a.messages.Update(genCtx, *currentAssistant)
379		},
380		StopWhen: []fantasy.StopCondition{
381			func(_ []fantasy.StepResult) bool {
382				cw := int64(a.largeModel.CatwalkCfg.ContextWindow)
383				tokens := currentSession.CompletionTokens + currentSession.PromptTokens
384				remaining := cw - tokens
385				var threshold int64
386				if cw > 200_000 {
387					threshold = 20_000
388				} else {
389					threshold = int64(float64(cw) * 0.2)
390				}
391				if (remaining <= threshold) && !a.disableAutoSummarize {
392					shouldSummarize = true
393					return true
394				}
395				return false
396			},
397		},
398	})
399
400	a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))
401
402	if err != nil {
403		isCancelErr := errors.Is(err, context.Canceled)
404		isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
405		if currentAssistant == nil {
406			return result, err
407		}
408		// Ensure we finish thinking on error to close the reasoning state.
409		currentAssistant.FinishThinking()
410		toolCalls := currentAssistant.ToolCalls()
411		// INFO: we use the parent context here because the genCtx has been cancelled.
412		msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
413		if createErr != nil {
414			return nil, createErr
415		}
416		for _, tc := range toolCalls {
417			if !tc.Finished {
418				tc.Finished = true
419				tc.Input = "{}"
420				currentAssistant.AddToolCall(tc)
421				updateErr := a.messages.Update(ctx, *currentAssistant)
422				if updateErr != nil {
423					return nil, updateErr
424				}
425			}
426
427			found := false
428			for _, msg := range msgs {
429				if msg.Role == message.Tool {
430					for _, tr := range msg.ToolResults() {
431						if tr.ToolCallID == tc.ID {
432							found = true
433							break
434						}
435					}
436				}
437				if found {
438					break
439				}
440			}
441			if found {
442				continue
443			}
444			content := "There was an error while executing the tool"
445			if isCancelErr {
446				content = "Tool execution canceled by user"
447			} else if isPermissionErr {
448				content = "User denied permission"
449			}
450			toolResult := message.ToolResult{
451				ToolCallID: tc.ID,
452				Name:       tc.Name,
453				Content:    content,
454				IsError:    true,
455			}
456			_, createErr = a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
457				Role: message.Tool,
458				Parts: []message.ContentPart{
459					toolResult,
460				},
461			})
462			if createErr != nil {
463				return nil, createErr
464			}
465		}
466		var fantasyErr *fantasy.Error
467		var providerErr *fantasy.ProviderError
468		const defaultTitle = "Provider Error"
469		if isCancelErr {
470			currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
471		} else if isPermissionErr {
472			currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "")
473		} else if errors.As(err, &providerErr) {
474			currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message)
475		} else if errors.As(err, &fantasyErr) {
476			currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(fantasyErr.Title), defaultTitle), fantasyErr.Message)
477		} else {
478			currentAssistant.AddFinish(message.FinishReasonError, defaultTitle, err.Error())
479		}
480		// Note: we use the parent context here because the genCtx has been
481		// cancelled.
482		updateErr := a.messages.Update(ctx, *currentAssistant)
483		if updateErr != nil {
484			return nil, updateErr
485		}
486		return nil, err
487	}
488	wg.Wait()
489
490	// Send notification that agent has finished its turn (skip for nested/non-interactive sessions).
491	if !call.NonInteractive {
492		notifBody := fmt.Sprintf("Agent's turn completed in \"%s\"", currentSession.Title)
493		_ = notification.Send("Crush is waiting...", notifBody)
494	}
495
496	if shouldSummarize {
497		a.activeRequests.Del(call.SessionID)
498		if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
499			return nil, summarizeErr
500		}
501		// If the agent wasn't done...
502		if len(currentAssistant.ToolCalls()) > 0 {
503			existing, ok := a.messageQueue.Get(call.SessionID)
504			if !ok {
505				existing = []SessionAgentCall{}
506			}
507			call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
508			existing = append(existing, call)
509			a.messageQueue.Set(call.SessionID, existing)
510		}
511	}
512
513	// Release active request before processing queued messages.
514	a.activeRequests.Del(call.SessionID)
515	cancel()
516
517	queuedMessages, ok := a.messageQueue.Get(call.SessionID)
518	if !ok || len(queuedMessages) == 0 {
519		return result, err
520	}
521	// There are queued messages restart the loop.
522	firstQueuedMessage := queuedMessages[0]
523	a.messageQueue.Set(call.SessionID, queuedMessages[1:])
524	return a.Run(ctx, firstQueuedMessage)
525}
526
527func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
528	if a.IsSessionBusy(sessionID) {
529		return ErrSessionBusy
530	}
531
532	currentSession, err := a.sessions.Get(ctx, sessionID)
533	if err != nil {
534		return fmt.Errorf("failed to get session: %w", err)
535	}
536	msgs, err := a.getSessionMessages(ctx, currentSession)
537	if err != nil {
538		return err
539	}
540	if len(msgs) == 0 {
541		// Nothing to summarize.
542		return nil
543	}
544
545	aiMsgs, _ := a.preparePrompt(msgs)
546
547	genCtx, cancel := context.WithCancel(ctx)
548	a.activeRequests.Set(sessionID, cancel)
549	defer a.activeRequests.Del(sessionID)
550	defer cancel()
551
552	agent := fantasy.NewAgent(a.largeModel.Model,
553		fantasy.WithSystemPrompt(string(summaryPrompt)),
554	)
555	summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
556		Role:             message.Assistant,
557		Model:            a.largeModel.Model.Model(),
558		Provider:         a.largeModel.Model.Provider(),
559		IsSummaryMessage: true,
560	})
561	if err != nil {
562		return err
563	}
564
565	resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
566		Prompt:          "Provide a detailed summary of our conversation above.",
567		Messages:        aiMsgs,
568		ProviderOptions: opts,
569		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
570			prepared.Messages = options.Messages
571			if a.systemPromptPrefix != "" {
572				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
573			}
574			return callContext, prepared, nil
575		},
576		OnReasoningDelta: func(id string, text string) error {
577			summaryMessage.AppendReasoningContent(text)
578			return a.messages.Update(genCtx, summaryMessage)
579		},
580		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
581			// Handle anthropic signature.
582			if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
583				if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
584					summaryMessage.AppendReasoningSignature(signature.Signature)
585				}
586			}
587			summaryMessage.FinishThinking()
588			return a.messages.Update(genCtx, summaryMessage)
589		},
590		OnTextDelta: func(id, text string) error {
591			summaryMessage.AppendContent(text)
592			return a.messages.Update(genCtx, summaryMessage)
593		},
594	})
595	if err != nil {
596		isCancelErr := errors.Is(err, context.Canceled)
597		if isCancelErr {
598			// User cancelled summarize we need to remove the summary message.
599			deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
600			return deleteErr
601		}
602		return err
603	}
604
605	summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
606	err = a.messages.Update(genCtx, summaryMessage)
607	if err != nil {
608		return err
609	}
610
611	var openrouterCost *float64
612	for _, step := range resp.Steps {
613		stepCost := a.openrouterCost(step.ProviderMetadata)
614		if stepCost != nil {
615			newCost := *stepCost
616			if openrouterCost != nil {
617				newCost += *openrouterCost
618			}
619			openrouterCost = &newCost
620		}
621	}
622
623	a.updateSessionUsage(a.largeModel, &currentSession, resp.TotalUsage, openrouterCost)
624
625	// Just in case, get just the last usage info.
626	usage := resp.Response.Usage
627	currentSession.SummaryMessageID = summaryMessage.ID
628	currentSession.CompletionTokens = usage.OutputTokens
629	currentSession.PromptTokens = 0
630	_, err = a.sessions.Save(genCtx, currentSession)
631	return err
632}
633
634func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
635	if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
636		return fantasy.ProviderOptions{}
637	}
638	return fantasy.ProviderOptions{
639		anthropic.Name: &anthropic.ProviderCacheControlOptions{
640			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
641		},
642		bedrock.Name: &anthropic.ProviderCacheControlOptions{
643			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
644		},
645	}
646}
647
648func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
649	var attachmentParts []message.ContentPart
650	for _, attachment := range call.Attachments {
651		attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
652	}
653	parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
654	parts = append(parts, attachmentParts...)
655	msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
656		Role:  message.User,
657		Parts: parts,
658	})
659	if err != nil {
660		return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
661	}
662	return msg, nil
663}
664
665func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
666	var history []fantasy.Message
667	for _, m := range msgs {
668		if len(m.Parts) == 0 {
669			continue
670		}
671		// Assistant message without content or tool calls (cancelled before it
672		// returned anything).
673		if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
674			continue
675		}
676		history = append(history, m.ToAIMessage()...)
677	}
678
679	var files []fantasy.FilePart
680	for _, attachment := range attachments {
681		files = append(files, fantasy.FilePart{
682			Filename:  attachment.FileName,
683			Data:      attachment.Content,
684			MediaType: attachment.MimeType,
685		})
686	}
687
688	return history, files
689}
690
691func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
692	msgs, err := a.messages.List(ctx, session.ID)
693	if err != nil {
694		return nil, fmt.Errorf("failed to list messages: %w", err)
695	}
696
697	if session.SummaryMessageID != "" {
698		summaryMsgInex := -1
699		for i, msg := range msgs {
700			if msg.ID == session.SummaryMessageID {
701				summaryMsgInex = i
702				break
703			}
704		}
705		if summaryMsgInex != -1 {
706			msgs = msgs[summaryMsgInex:]
707			msgs[0].Role = message.User
708		}
709	}
710	return msgs, nil
711}
712
713func (a *sessionAgent) generateTitle(ctx context.Context, session *session.Session, prompt string) {
714	if prompt == "" {
715		return
716	}
717
718	var maxOutput int64 = 40
719	if a.smallModel.CatwalkCfg.CanReason {
720		maxOutput = a.smallModel.CatwalkCfg.DefaultMaxTokens
721	}
722
723	agent := fantasy.NewAgent(a.smallModel.Model,
724		fantasy.WithSystemPrompt(string(titlePrompt)+"\n /no_think"),
725		fantasy.WithMaxOutputTokens(maxOutput),
726	)
727
728	resp, err := agent.Stream(ctx, fantasy.AgentStreamCall{
729		Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", prompt),
730		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
731			prepared.Messages = options.Messages
732			if a.systemPromptPrefix != "" {
733				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
734			}
735			return callContext, prepared, nil
736		},
737	})
738	if err != nil {
739		slog.Error("error generating title", "err", err)
740		return
741	}
742
743	title := resp.Response.Content.Text()
744
745	title = strings.ReplaceAll(title, "\n", " ")
746
747	// Remove thinking tags if present.
748	if idx := strings.Index(title, "</think>"); idx > 0 {
749		title = title[idx+len("</think>"):]
750	}
751
752	title = strings.TrimSpace(title)
753	if title == "" {
754		slog.Warn("failed to generate title", "warn", "empty title")
755		return
756	}
757
758	session.Title = title
759
760	var openrouterCost *float64
761	for _, step := range resp.Steps {
762		stepCost := a.openrouterCost(step.ProviderMetadata)
763		if stepCost != nil {
764			newCost := *stepCost
765			if openrouterCost != nil {
766				newCost += *openrouterCost
767			}
768			openrouterCost = &newCost
769		}
770	}
771
772	a.updateSessionUsage(a.smallModel, session, resp.TotalUsage, openrouterCost)
773	_, saveErr := a.sessions.Save(ctx, *session)
774	if saveErr != nil {
775		slog.Error("failed to save session title & usage", "error", saveErr)
776		return
777	}
778}
779
780func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
781	openrouterMetadata, ok := metadata[openrouter.Name]
782	if !ok {
783		return nil
784	}
785
786	opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
787	if !ok {
788		return nil
789	}
790	return &opts.Usage.Cost
791}
792
793func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
794	modelConfig := model.CatwalkCfg
795	cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
796		modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
797		modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
798		modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
799
800	if a.isClaudeCode() {
801		cost = 0
802	}
803
804	a.eventTokensUsed(session.ID, model, usage, cost)
805
806	if overrideCost != nil {
807		session.Cost += *overrideCost
808	} else {
809		session.Cost += cost
810	}
811
812	session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
813	session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
814}
815
816func (a *sessionAgent) Cancel(sessionID string) {
817	// Cancel regular requests.
818	if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
819		slog.Info("Request cancellation initiated", "session_id", sessionID)
820		cancel()
821	}
822
823	// Also check for summarize requests.
824	if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
825		slog.Info("Summarize cancellation initiated", "session_id", sessionID)
826		cancel()
827	}
828
829	if a.QueuedPrompts(sessionID) > 0 {
830		slog.Info("Clearing queued prompts", "session_id", sessionID)
831		a.messageQueue.Del(sessionID)
832	}
833}
834
835func (a *sessionAgent) ClearQueue(sessionID string) {
836	if a.QueuedPrompts(sessionID) > 0 {
837		slog.Info("Clearing queued prompts", "session_id", sessionID)
838		a.messageQueue.Del(sessionID)
839	}
840}
841
842func (a *sessionAgent) CancelAll() {
843	if !a.IsBusy() {
844		return
845	}
846	for key := range a.activeRequests.Seq2() {
847		a.Cancel(key) // key is sessionID
848	}
849
850	timeout := time.After(5 * time.Second)
851	for a.IsBusy() {
852		select {
853		case <-timeout:
854			return
855		default:
856			time.Sleep(200 * time.Millisecond)
857		}
858	}
859}
860
861func (a *sessionAgent) IsBusy() bool {
862	var busy bool
863	for cancelFunc := range a.activeRequests.Seq() {
864		if cancelFunc != nil {
865			busy = true
866			break
867		}
868	}
869	return busy
870}
871
872func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
873	_, busy := a.activeRequests.Get(sessionID)
874	return busy
875}
876
877func (a *sessionAgent) QueuedPrompts(sessionID string) int {
878	l, ok := a.messageQueue.Get(sessionID)
879	if !ok {
880		return 0
881	}
882	return len(l)
883}
884
885func (a *sessionAgent) SetModels(large Model, small Model) {
886	a.largeModel = large
887	a.smallModel = small
888}
889
890func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
891	a.tools = tools
892}
893
894func (a *sessionAgent) Model() Model {
895	return a.largeModel
896}
897
898func (a *sessionAgent) promptPrefix() string {
899	if a.isClaudeCode() {
900		return "You are Claude Code, Anthropic's official CLI for Claude."
901	}
902	return a.systemPromptPrefix
903}
904
905func (a *sessionAgent) isClaudeCode() bool {
906	cfg := config.Get()
907	pc, ok := cfg.Providers.Get(a.largeModel.ModelCfg.Provider)
908	return ok && pc.ID == string(catwalk.InferenceProviderAnthropic) && pc.OAuthToken != nil
909}