agent.go

  1// Package agent is the core orchestration layer for Crush AI agents.
  2//
  3// It provides session-based AI agent functionality for managing
  4// conversations, tool execution, and message handling. It coordinates
  5// interactions between language models, messages, sessions, and tools while
  6// handling features like automatic summarization, queuing, and token
  7// management.
  8package agent
  9
 10import (
 11	"cmp"
 12	"context"
 13	_ "embed"
 14	"errors"
 15	"fmt"
 16	"log/slog"
 17	"os"
 18	"strconv"
 19	"strings"
 20	"sync"
 21	"time"
 22
 23	"charm.land/fantasy"
 24	"charm.land/fantasy/providers/anthropic"
 25	"charm.land/fantasy/providers/bedrock"
 26	"charm.land/fantasy/providers/google"
 27	"charm.land/fantasy/providers/openai"
 28	"charm.land/fantasy/providers/openrouter"
 29	"github.com/charmbracelet/catwalk/pkg/catwalk"
 30	"github.com/charmbracelet/crush/internal/agent/tools"
 31	"github.com/charmbracelet/crush/internal/config"
 32	"github.com/charmbracelet/crush/internal/csync"
 33	"github.com/charmbracelet/crush/internal/message"
 34	"github.com/charmbracelet/crush/internal/permission"
 35	"github.com/charmbracelet/crush/internal/session"
 36	"github.com/charmbracelet/crush/internal/stringext"
 37)
 38
// titlePrompt holds the embedded contents of templates/title.md. It is used
// as the system prompt when generating a session title.
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt holds the embedded contents of templates/summary.md. It is
// used as the system prompt when summarizing a session.
//
//go:embed templates/summary.md
var summaryPrompt []byte
 44
// SessionAgentCall describes a single prompt submitted to a session agent via
// Run.
//
// SessionID and Prompt are required; Run rejects a call where either is
// empty. Attachments are stored on the persisted user message and forwarded
// to the model as files. ProviderOptions, MaxOutputTokens, Temperature, TopP,
// TopK, FrequencyPenalty and PresencePenalty are passed through to the
// underlying fantasy stream call unchanged (MaxOutputTokens by pointer).
type SessionAgentCall struct {
	SessionID        string
	Prompt           string
	ProviderOptions  fantasy.ProviderOptions
	Attachments      []message.Attachment
	MaxOutputTokens  int64
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
}
 57
// SessionAgent is the public surface of a session-scoped agent: it runs
// prompts, tracks in-flight requests per session, queues prompts that arrive
// while a session is busy, and can summarize a session's history.
type SessionAgent interface {
	// Run submits a prompt to a session. When the session is already busy the
	// call is queued and Run returns (nil, nil).
	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
	// SetModels replaces the large and small models used by the agent.
	SetModels(large Model, small Model)
	// SetTools replaces the tool set offered to the model.
	SetTools(tools []fantasy.AgentTool)
	// Cancel aborts the in-flight request for the session, if any, and drops
	// its queued prompts.
	Cancel(sessionID string)
	// CancelAll cancels every active request and waits a bounded time for
	// them to wind down.
	CancelAll()
	// IsSessionBusy reports whether the session has an active request.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any session has an active request.
	IsBusy() bool
	// QueuedPrompts returns the number of prompts waiting for the session.
	QueuedPrompts(sessionID string) int
	// ClearQueue drops all prompts queued for the session.
	ClearQueue(sessionID string)
	// Summarize replaces the session's working history with a model-written
	// summary message.
	Summarize(context.Context, string, fantasy.ProviderOptions) error
	// Model returns the large model currently in use.
	Model() Model
}
 71
// Model bundles a language model with the metadata the agent needs about it.
//
// Model is the fantasy language model used for generation. CatwalkCfg
// supplies the catalog data read by this package (context window, per-token
// costs, reasoning capability, default max tokens). ModelCfg carries the
// selected model and provider identifiers recorded on assistant messages.
type Model struct {
	Model      fantasy.LanguageModel
	CatwalkCfg catwalk.Model
	ModelCfg   config.SelectedModel
}
 77
// sessionAgent is the default SessionAgent implementation.
//
// largeModel drives Run and Summarize; smallModel is used for title
// generation. systemPrompt is the system prompt for Run, and
// systemPromptPrefix is prepended as an extra system message (see
// promptPrefix). disableAutoSummarize turns off the context-window stop
// condition in Run.
// NOTE(review): isYolo is stored but not read anywhere in this file's visible
// code — confirm where it is consumed.
type sessionAgent struct {
	largeModel           Model
	smallModel           Model
	systemPromptPrefix   string
	systemPrompt         string
	tools                []fantasy.AgentTool
	sessions             session.Service
	messages             message.Service
	disableAutoSummarize bool
	isYolo               bool

	// messageQueue holds prompts received while a session was busy, keyed by
	// session ID. activeRequests maps a session ID to the cancel function of
	// its in-flight request; presence of a key means the session is busy.
	messageQueue   *csync.Map[string, []SessionAgentCall]
	activeRequests *csync.Map[string, context.CancelFunc]
}
 92
// SessionAgentOptions carries the dependencies and settings consumed by
// NewSessionAgent. Each field is copied verbatim into the corresponding
// field of the constructed agent.
type SessionAgentOptions struct {
	LargeModel           Model
	SmallModel           Model
	SystemPromptPrefix   string
	SystemPrompt         string
	DisableAutoSummarize bool
	IsYolo               bool
	Sessions             session.Service
	Messages             message.Service
	Tools                []fantasy.AgentTool
}
104
105func NewSessionAgent(
106	opts SessionAgentOptions,
107) SessionAgent {
108	return &sessionAgent{
109		largeModel:           opts.LargeModel,
110		smallModel:           opts.SmallModel,
111		systemPromptPrefix:   opts.SystemPromptPrefix,
112		systemPrompt:         opts.SystemPrompt,
113		sessions:             opts.Sessions,
114		messages:             opts.Messages,
115		disableAutoSummarize: opts.DisableAutoSummarize,
116		tools:                opts.Tools,
117		isYolo:               opts.IsYolo,
118		messageQueue:         csync.NewMap[string, []SessionAgentCall](),
119		activeRequests:       csync.NewMap[string, context.CancelFunc](),
120	}
121}
122
123func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
124	if call.Prompt == "" {
125		return nil, ErrEmptyPrompt
126	}
127	if call.SessionID == "" {
128		return nil, ErrSessionMissing
129	}
130
131	// Queue the message if busy
132	if a.IsSessionBusy(call.SessionID) {
133		existing, ok := a.messageQueue.Get(call.SessionID)
134		if !ok {
135			existing = []SessionAgentCall{}
136		}
137		existing = append(existing, call)
138		a.messageQueue.Set(call.SessionID, existing)
139		return nil, nil
140	}
141
142	if len(a.tools) > 0 {
143		// Add Anthropic caching to the last tool.
144		a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
145	}
146
147	agent := fantasy.NewAgent(
148		a.largeModel.Model,
149		fantasy.WithSystemPrompt(a.systemPrompt),
150		fantasy.WithTools(a.tools...),
151	)
152
153	sessionLock := sync.Mutex{}
154	currentSession, err := a.sessions.Get(ctx, call.SessionID)
155	if err != nil {
156		return nil, fmt.Errorf("failed to get session: %w", err)
157	}
158
159	msgs, err := a.getSessionMessages(ctx, currentSession)
160	if err != nil {
161		return nil, fmt.Errorf("failed to get session messages: %w", err)
162	}
163
164	var wg sync.WaitGroup
165	// Generate title if first message.
166	if len(msgs) == 0 {
167		wg.Go(func() {
168			sessionLock.Lock()
169			a.generateTitle(ctx, &currentSession, call.Prompt)
170			sessionLock.Unlock()
171		})
172	}
173
174	// Add the user message to the session.
175	_, err = a.createUserMessage(ctx, call)
176	if err != nil {
177		return nil, err
178	}
179
180	// Add the session to the context.
181	ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)
182
183	genCtx, cancel := context.WithCancel(ctx)
184	a.activeRequests.Set(call.SessionID, cancel)
185
186	defer cancel()
187	defer a.activeRequests.Del(call.SessionID)
188
189	history, files := a.preparePrompt(msgs, call.Attachments...)
190
191	startTime := time.Now()
192	a.eventPromptSent(call.SessionID)
193
194	var currentAssistant *message.Message
195	var shouldSummarize bool
196	result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
197		Prompt:           call.Prompt,
198		Files:            files,
199		Messages:         history,
200		ProviderOptions:  call.ProviderOptions,
201		MaxOutputTokens:  &call.MaxOutputTokens,
202		TopP:             call.TopP,
203		Temperature:      call.Temperature,
204		PresencePenalty:  call.PresencePenalty,
205		TopK:             call.TopK,
206		FrequencyPenalty: call.FrequencyPenalty,
207		// Before each step create a new assistant message.
208		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
209			prepared.Messages = options.Messages
210			// Reset all cached items.
211			for i := range prepared.Messages {
212				prepared.Messages[i].ProviderOptions = nil
213			}
214
215			// Use latest tools (updated by SetTools when MCP tools change).
216			prepared.Tools = a.tools
217
218			queuedCalls, _ := a.messageQueue.Get(call.SessionID)
219			a.messageQueue.Del(call.SessionID)
220			for _, queued := range queuedCalls {
221				userMessage, createErr := a.createUserMessage(callContext, queued)
222				if createErr != nil {
223					return callContext, prepared, createErr
224				}
225				prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
226			}
227
228			lastSystemRoleInx := 0
229			systemMessageUpdated := false
230			for i, msg := range prepared.Messages {
231				// Only add cache control to the last message.
232				if msg.Role == fantasy.MessageRoleSystem {
233					lastSystemRoleInx = i
234				} else if !systemMessageUpdated {
235					prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
236					systemMessageUpdated = true
237				}
238				// Than add cache control to the last 2 messages.
239				if i > len(prepared.Messages)-3 {
240					prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
241				}
242			}
243
244			if promptPrefix := a.promptPrefix(); promptPrefix != "" {
245				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(promptPrefix)}, prepared.Messages...)
246			}
247
248			var assistantMsg message.Message
249			assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
250				Role:     message.Assistant,
251				Parts:    []message.ContentPart{},
252				Model:    a.largeModel.ModelCfg.Model,
253				Provider: a.largeModel.ModelCfg.Provider,
254			})
255			if err != nil {
256				return callContext, prepared, err
257			}
258			callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
259			currentAssistant = &assistantMsg
260			return callContext, prepared, err
261		},
262		OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
263			currentAssistant.AppendReasoningContent(reasoning.Text)
264			return a.messages.Update(genCtx, *currentAssistant)
265		},
266		OnReasoningDelta: func(id string, text string) error {
267			currentAssistant.AppendReasoningContent(text)
268			return a.messages.Update(genCtx, *currentAssistant)
269		},
270		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
271			// handle anthropic signature
272			if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
273				if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
274					currentAssistant.AppendReasoningSignature(reasoning.Signature)
275				}
276			}
277			if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
278				if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
279					currentAssistant.AppendThoughtSignature(reasoning.Signature, reasoning.ToolID)
280				}
281			}
282			if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
283				if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
284					currentAssistant.SetReasoningResponsesData(reasoning)
285				}
286			}
287			currentAssistant.FinishThinking()
288			return a.messages.Update(genCtx, *currentAssistant)
289		},
290		OnTextDelta: func(id string, text string) error {
291			// Strip leading newline from initial text content. This is is
292			// particularly important in non-interactive mode where leading
293			// newlines are very visible.
294			if len(currentAssistant.Parts) == 0 {
295				text = strings.TrimPrefix(text, "\n")
296			}
297
298			currentAssistant.AppendContent(text)
299			return a.messages.Update(genCtx, *currentAssistant)
300		},
301		OnToolInputStart: func(id string, toolName string) error {
302			toolCall := message.ToolCall{
303				ID:               id,
304				Name:             toolName,
305				ProviderExecuted: false,
306				Finished:         false,
307			}
308			currentAssistant.AddToolCall(toolCall)
309			return a.messages.Update(genCtx, *currentAssistant)
310		},
311		OnRetry: func(err *fantasy.ProviderError, delay time.Duration) {
312			// TODO: implement
313		},
314		OnToolCall: func(tc fantasy.ToolCallContent) error {
315			toolCall := message.ToolCall{
316				ID:               tc.ToolCallID,
317				Name:             tc.ToolName,
318				Input:            tc.Input,
319				ProviderExecuted: false,
320				Finished:         true,
321			}
322			currentAssistant.AddToolCall(toolCall)
323			return a.messages.Update(genCtx, *currentAssistant)
324		},
325		OnToolResult: func(result fantasy.ToolResultContent) error {
326			var resultContent string
327			isError := false
328			switch result.Result.GetType() {
329			case fantasy.ToolResultContentTypeText:
330				r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result)
331				if ok {
332					resultContent = r.Text
333				}
334			case fantasy.ToolResultContentTypeError:
335				r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result)
336				if ok {
337					isError = true
338					resultContent = r.Error.Error()
339				}
340			case fantasy.ToolResultContentTypeMedia:
341				// TODO: handle this message type
342			}
343			toolResult := message.ToolResult{
344				ToolCallID: result.ToolCallID,
345				Name:       result.ToolName,
346				Content:    resultContent,
347				IsError:    isError,
348				Metadata:   result.ClientMetadata,
349			}
350			_, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
351				Role: message.Tool,
352				Parts: []message.ContentPart{
353					toolResult,
354				},
355			})
356			if createMsgErr != nil {
357				return createMsgErr
358			}
359			return nil
360		},
361		OnStepFinish: func(stepResult fantasy.StepResult) error {
362			finishReason := message.FinishReasonUnknown
363			switch stepResult.FinishReason {
364			case fantasy.FinishReasonLength:
365				finishReason = message.FinishReasonMaxTokens
366			case fantasy.FinishReasonStop:
367				finishReason = message.FinishReasonEndTurn
368			case fantasy.FinishReasonToolCalls:
369				finishReason = message.FinishReasonToolUse
370			}
371			currentAssistant.AddFinish(finishReason, "", "")
372			a.updateSessionUsage(a.largeModel, &currentSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
373			sessionLock.Lock()
374			_, sessionErr := a.sessions.Save(genCtx, currentSession)
375			sessionLock.Unlock()
376			if sessionErr != nil {
377				return sessionErr
378			}
379			return a.messages.Update(genCtx, *currentAssistant)
380		},
381		StopWhen: []fantasy.StopCondition{
382			func(_ []fantasy.StepResult) bool {
383				cw := int64(a.largeModel.CatwalkCfg.ContextWindow)
384				tokens := currentSession.CompletionTokens + currentSession.PromptTokens
385				remaining := cw - tokens
386				var threshold int64
387				if cw > 200_000 {
388					threshold = 20_000
389				} else {
390					threshold = int64(float64(cw) * 0.2)
391				}
392				if (remaining <= threshold) && !a.disableAutoSummarize {
393					shouldSummarize = true
394					return true
395				}
396				return false
397			},
398		},
399	})
400
401	a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))
402
403	if err != nil {
404		isCancelErr := errors.Is(err, context.Canceled)
405		isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
406		if currentAssistant == nil {
407			return result, err
408		}
409		// Ensure we finish thinking on error to close the reasoning state.
410		currentAssistant.FinishThinking()
411		toolCalls := currentAssistant.ToolCalls()
412		// INFO: we use the parent context here because the genCtx has been cancelled.
413		msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
414		if createErr != nil {
415			return nil, createErr
416		}
417		for _, tc := range toolCalls {
418			if !tc.Finished {
419				tc.Finished = true
420				tc.Input = "{}"
421				currentAssistant.AddToolCall(tc)
422				updateErr := a.messages.Update(ctx, *currentAssistant)
423				if updateErr != nil {
424					return nil, updateErr
425				}
426			}
427
428			found := false
429			for _, msg := range msgs {
430				if msg.Role == message.Tool {
431					for _, tr := range msg.ToolResults() {
432						if tr.ToolCallID == tc.ID {
433							found = true
434							break
435						}
436					}
437				}
438				if found {
439					break
440				}
441			}
442			if found {
443				continue
444			}
445			content := "There was an error while executing the tool"
446			if isCancelErr {
447				content = "Tool execution canceled by user"
448			} else if isPermissionErr {
449				content = "User denied permission"
450			}
451			toolResult := message.ToolResult{
452				ToolCallID: tc.ID,
453				Name:       tc.Name,
454				Content:    content,
455				IsError:    true,
456			}
457			_, createErr = a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
458				Role: message.Tool,
459				Parts: []message.ContentPart{
460					toolResult,
461				},
462			})
463			if createErr != nil {
464				return nil, createErr
465			}
466		}
467		var fantasyErr *fantasy.Error
468		var providerErr *fantasy.ProviderError
469		const defaultTitle = "Provider Error"
470		if isCancelErr {
471			currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
472		} else if isPermissionErr {
473			currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "")
474		} else if errors.As(err, &providerErr) {
475			currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message)
476		} else if errors.As(err, &fantasyErr) {
477			currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(fantasyErr.Title), defaultTitle), fantasyErr.Message)
478		} else {
479			currentAssistant.AddFinish(message.FinishReasonError, defaultTitle, err.Error())
480		}
481		// Note: we use the parent context here because the genCtx has been
482		// cancelled.
483		updateErr := a.messages.Update(ctx, *currentAssistant)
484		if updateErr != nil {
485			return nil, updateErr
486		}
487		return nil, err
488	}
489	wg.Wait()
490
491	if shouldSummarize {
492		a.activeRequests.Del(call.SessionID)
493		if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
494			return nil, summarizeErr
495		}
496		// If the agent wasn't done...
497		if len(currentAssistant.ToolCalls()) > 0 {
498			existing, ok := a.messageQueue.Get(call.SessionID)
499			if !ok {
500				existing = []SessionAgentCall{}
501			}
502			call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
503			existing = append(existing, call)
504			a.messageQueue.Set(call.SessionID, existing)
505		}
506	}
507
508	// Release active request before processing queued messages.
509	a.activeRequests.Del(call.SessionID)
510	cancel()
511
512	queuedMessages, ok := a.messageQueue.Get(call.SessionID)
513	if !ok || len(queuedMessages) == 0 {
514		return result, err
515	}
516	// There are queued messages restart the loop.
517	firstQueuedMessage := queuedMessages[0]
518	a.messageQueue.Set(call.SessionID, queuedMessages[1:])
519	return a.Run(ctx, firstQueuedMessage)
520}
521
522func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
523	if a.IsSessionBusy(sessionID) {
524		return ErrSessionBusy
525	}
526
527	currentSession, err := a.sessions.Get(ctx, sessionID)
528	if err != nil {
529		return fmt.Errorf("failed to get session: %w", err)
530	}
531	msgs, err := a.getSessionMessages(ctx, currentSession)
532	if err != nil {
533		return err
534	}
535	if len(msgs) == 0 {
536		// Nothing to summarize.
537		return nil
538	}
539
540	aiMsgs, _ := a.preparePrompt(msgs)
541
542	genCtx, cancel := context.WithCancel(ctx)
543	a.activeRequests.Set(sessionID, cancel)
544	defer a.activeRequests.Del(sessionID)
545	defer cancel()
546
547	agent := fantasy.NewAgent(a.largeModel.Model,
548		fantasy.WithSystemPrompt(string(summaryPrompt)),
549	)
550	summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
551		Role:             message.Assistant,
552		Model:            a.largeModel.Model.Model(),
553		Provider:         a.largeModel.Model.Provider(),
554		IsSummaryMessage: true,
555	})
556	if err != nil {
557		return err
558	}
559
560	resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
561		Prompt:          "Provide a detailed summary of our conversation above.",
562		Messages:        aiMsgs,
563		ProviderOptions: opts,
564		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
565			prepared.Messages = options.Messages
566			if a.systemPromptPrefix != "" {
567				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
568			}
569			return callContext, prepared, nil
570		},
571		OnReasoningDelta: func(id string, text string) error {
572			summaryMessage.AppendReasoningContent(text)
573			return a.messages.Update(genCtx, summaryMessage)
574		},
575		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
576			// Handle anthropic signature.
577			if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
578				if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
579					summaryMessage.AppendReasoningSignature(signature.Signature)
580				}
581			}
582			summaryMessage.FinishThinking()
583			return a.messages.Update(genCtx, summaryMessage)
584		},
585		OnTextDelta: func(id, text string) error {
586			summaryMessage.AppendContent(text)
587			return a.messages.Update(genCtx, summaryMessage)
588		},
589	})
590	if err != nil {
591		isCancelErr := errors.Is(err, context.Canceled)
592		if isCancelErr {
593			// User cancelled summarize we need to remove the summary message.
594			deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
595			return deleteErr
596		}
597		return err
598	}
599
600	summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
601	err = a.messages.Update(genCtx, summaryMessage)
602	if err != nil {
603		return err
604	}
605
606	var openrouterCost *float64
607	for _, step := range resp.Steps {
608		stepCost := a.openrouterCost(step.ProviderMetadata)
609		if stepCost != nil {
610			newCost := *stepCost
611			if openrouterCost != nil {
612				newCost += *openrouterCost
613			}
614			openrouterCost = &newCost
615		}
616	}
617
618	a.updateSessionUsage(a.largeModel, &currentSession, resp.TotalUsage, openrouterCost)
619
620	// Just in case, get just the last usage info.
621	usage := resp.Response.Usage
622	currentSession.SummaryMessageID = summaryMessage.ID
623	currentSession.CompletionTokens = usage.OutputTokens
624	currentSession.PromptTokens = 0
625	_, err = a.sessions.Save(genCtx, currentSession)
626	return err
627}
628
629func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
630	if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
631		return fantasy.ProviderOptions{}
632	}
633	return fantasy.ProviderOptions{
634		anthropic.Name: &anthropic.ProviderCacheControlOptions{
635			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
636		},
637		bedrock.Name: &anthropic.ProviderCacheControlOptions{
638			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
639		},
640	}
641}
642
643func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
644	var attachmentParts []message.ContentPart
645	for _, attachment := range call.Attachments {
646		attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
647	}
648	parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
649	parts = append(parts, attachmentParts...)
650	msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
651		Role:  message.User,
652		Parts: parts,
653	})
654	if err != nil {
655		return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
656	}
657	return msg, nil
658}
659
660func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
661	var history []fantasy.Message
662	for _, m := range msgs {
663		if len(m.Parts) == 0 {
664			continue
665		}
666		// Assistant message without content or tool calls (cancelled before it
667		// returned anything).
668		if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
669			continue
670		}
671		history = append(history, m.ToAIMessage()...)
672	}
673
674	var files []fantasy.FilePart
675	for _, attachment := range attachments {
676		files = append(files, fantasy.FilePart{
677			Filename:  attachment.FileName,
678			Data:      attachment.Content,
679			MediaType: attachment.MimeType,
680		})
681	}
682
683	return history, files
684}
685
686func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
687	msgs, err := a.messages.List(ctx, session.ID)
688	if err != nil {
689		return nil, fmt.Errorf("failed to list messages: %w", err)
690	}
691
692	if session.SummaryMessageID != "" {
693		summaryMsgInex := -1
694		for i, msg := range msgs {
695			if msg.ID == session.SummaryMessageID {
696				summaryMsgInex = i
697				break
698			}
699		}
700		if summaryMsgInex != -1 {
701			msgs = msgs[summaryMsgInex:]
702			msgs[0].Role = message.User
703		}
704	}
705	return msgs, nil
706}
707
708func (a *sessionAgent) generateTitle(ctx context.Context, session *session.Session, prompt string) {
709	if prompt == "" {
710		return
711	}
712
713	var maxOutput int64 = 40
714	if a.smallModel.CatwalkCfg.CanReason {
715		maxOutput = a.smallModel.CatwalkCfg.DefaultMaxTokens
716	}
717
718	agent := fantasy.NewAgent(a.smallModel.Model,
719		fantasy.WithSystemPrompt(string(titlePrompt)+"\n /no_think"),
720		fantasy.WithMaxOutputTokens(maxOutput),
721	)
722
723	resp, err := agent.Stream(ctx, fantasy.AgentStreamCall{
724		Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", prompt),
725		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
726			prepared.Messages = options.Messages
727			if a.systemPromptPrefix != "" {
728				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
729			}
730			return callContext, prepared, nil
731		},
732	})
733	if err != nil {
734		slog.Error("error generating title", "err", err)
735		return
736	}
737
738	title := resp.Response.Content.Text()
739
740	title = strings.ReplaceAll(title, "\n", " ")
741
742	// Remove thinking tags if present.
743	if idx := strings.Index(title, "</think>"); idx > 0 {
744		title = title[idx+len("</think>"):]
745	}
746
747	title = strings.TrimSpace(title)
748	if title == "" {
749		slog.Warn("failed to generate title", "warn", "empty title")
750		return
751	}
752
753	session.Title = title
754
755	var openrouterCost *float64
756	for _, step := range resp.Steps {
757		stepCost := a.openrouterCost(step.ProviderMetadata)
758		if stepCost != nil {
759			newCost := *stepCost
760			if openrouterCost != nil {
761				newCost += *openrouterCost
762			}
763			openrouterCost = &newCost
764		}
765	}
766
767	a.updateSessionUsage(a.smallModel, session, resp.TotalUsage, openrouterCost)
768	_, saveErr := a.sessions.Save(ctx, *session)
769	if saveErr != nil {
770		slog.Error("failed to save session title & usage", "error", saveErr)
771		return
772	}
773}
774
775func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
776	openrouterMetadata, ok := metadata[openrouter.Name]
777	if !ok {
778		return nil
779	}
780
781	opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
782	if !ok {
783		return nil
784	}
785	return &opts.Usage.Cost
786}
787
788func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
789	modelConfig := model.CatwalkCfg
790	cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
791		modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
792		modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
793		modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
794
795	if a.isClaudeCode() {
796		cost = 0
797	}
798
799	a.eventTokensUsed(session.ID, model, usage, cost)
800
801	if overrideCost != nil {
802		session.Cost += *overrideCost
803	} else {
804		session.Cost += cost
805	}
806
807	session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
808	session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
809}
810
811func (a *sessionAgent) Cancel(sessionID string) {
812	// Cancel regular requests.
813	if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
814		slog.Info("Request cancellation initiated", "session_id", sessionID)
815		cancel()
816	}
817
818	// Also check for summarize requests.
819	if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
820		slog.Info("Summarize cancellation initiated", "session_id", sessionID)
821		cancel()
822	}
823
824	if a.QueuedPrompts(sessionID) > 0 {
825		slog.Info("Clearing queued prompts", "session_id", sessionID)
826		a.messageQueue.Del(sessionID)
827	}
828}
829
830func (a *sessionAgent) ClearQueue(sessionID string) {
831	if a.QueuedPrompts(sessionID) > 0 {
832		slog.Info("Clearing queued prompts", "session_id", sessionID)
833		a.messageQueue.Del(sessionID)
834	}
835}
836
837func (a *sessionAgent) CancelAll() {
838	if !a.IsBusy() {
839		return
840	}
841	for key := range a.activeRequests.Seq2() {
842		a.Cancel(key) // key is sessionID
843	}
844
845	timeout := time.After(5 * time.Second)
846	for a.IsBusy() {
847		select {
848		case <-timeout:
849			return
850		default:
851			time.Sleep(200 * time.Millisecond)
852		}
853	}
854}
855
856func (a *sessionAgent) IsBusy() bool {
857	var busy bool
858	for cancelFunc := range a.activeRequests.Seq() {
859		if cancelFunc != nil {
860			busy = true
861			break
862		}
863	}
864	return busy
865}
866
867func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
868	_, busy := a.activeRequests.Get(sessionID)
869	return busy
870}
871
872func (a *sessionAgent) QueuedPrompts(sessionID string) int {
873	l, ok := a.messageQueue.Get(sessionID)
874	if !ok {
875		return 0
876	}
877	return len(l)
878}
879
880func (a *sessionAgent) SetModels(large Model, small Model) {
881	a.largeModel = large
882	a.smallModel = small
883}
884
885func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
886	a.tools = tools
887}
888
889func (a *sessionAgent) Model() Model {
890	return a.largeModel
891}
892
893func (a *sessionAgent) promptPrefix() string {
894	if a.isClaudeCode() {
895		return "You are Claude Code, Anthropic's official CLI for Claude."
896	}
897	return a.systemPromptPrefix
898}
899
900func (a *sessionAgent) isClaudeCode() bool {
901	cfg := config.Get()
902	pc, ok := cfg.Providers.Get(a.largeModel.ModelCfg.Provider)
903	return ok && pc.ID == string(catwalk.InferenceProviderAnthropic) && pc.OAuthToken != nil
904}