agent.go

  1// Package agent is the core orchestration layer for Crush AI agents.
  2//
  3// It provides session-based AI agent functionality for managing
  4// conversations, tool execution, and message handling. It coordinates
  5// interactions between language models, messages, sessions, and tools while
  6// handling features like automatic summarization, queuing, and token
  7// management.
  8package agent
  9
 10import (
 11	"cmp"
 12	"context"
 13	_ "embed"
 14	"errors"
 15	"fmt"
 16	"log/slog"
 17	"os"
 18	"strconv"
 19	"strings"
 20	"sync"
 21	"time"
 22
 23	"charm.land/fantasy"
 24	"charm.land/fantasy/providers/anthropic"
 25	"charm.land/fantasy/providers/bedrock"
 26	"charm.land/fantasy/providers/google"
 27	"charm.land/fantasy/providers/openai"
 28	"charm.land/fantasy/providers/openrouter"
 29	"github.com/charmbracelet/catwalk/pkg/catwalk"
 30	"github.com/charmbracelet/crush/internal/agent/tools"
 31	"github.com/charmbracelet/crush/internal/config"
 32	"github.com/charmbracelet/crush/internal/csync"
 33	"github.com/charmbracelet/crush/internal/message"
 34	"github.com/charmbracelet/crush/internal/permission"
 35	"github.com/charmbracelet/crush/internal/session"
 36	"github.com/charmbracelet/crush/internal/stringext"
 37)
 38
// titlePrompt is the system prompt used by generateTitle, embedded from
// templates/title.md at build time.
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt is the system prompt used by Summarize, embedded from
// templates/summary.md at build time.
//
//go:embed templates/summary.md
var summaryPrompt []byte
 44
// SessionAgentCall describes a single prompt to run against a session via
// SessionAgent.Run. The sampling fields are forwarded unchanged to the
// model stream call.
type SessionAgentCall struct {
	// SessionID identifies the target session; Run rejects an empty value
	// with ErrSessionMissing.
	SessionID string
	// Prompt is the user text; Run rejects an empty value with
	// ErrEmptyPrompt.
	Prompt string
	// ProviderOptions are passed to the provider for this call (and to the
	// automatic summarization it may trigger).
	ProviderOptions fantasy.ProviderOptions
	// Attachments are stored with the user message and sent as file parts.
	Attachments []message.Attachment
	// MaxOutputTokens is always passed by pointer to the stream call, so a
	// zero value is sent as an explicit 0.
	MaxOutputTokens int64
	// Optional sampling parameters; nil presumably means "provider default"
	// — TODO confirm against fantasy.
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
}
 57
// SessionAgent runs prompts for chat sessions, queuing prompts for sessions
// that are already busy and supporting cancellation and summarization. The
// method notes below describe the sessionAgent implementation in this file.
type SessionAgent interface {
	// Run executes (or, if the session is busy, queues) a prompt.
	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
	// SetModels replaces the large (conversation) and small (title) models.
	SetModels(large Model, small Model)
	// SetTools replaces the tools offered to the model on later runs.
	SetTools(tools []fantasy.AgentTool)
	// Cancel aborts the session's in-flight request and drops its queue.
	Cancel(sessionID string)
	// CancelAll cancels every active request and waits briefly for them to
	// stop.
	CancelAll()
	// IsSessionBusy reports whether the session has an active request.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any session has an active request.
	IsBusy() bool
	// QueuedPrompts returns how many prompts are queued for the session.
	QueuedPrompts(sessionID string) int
	// ClearQueue drops all queued prompts for the session.
	ClearQueue(sessionID string)
	// Summarize condenses the session history into a summary message.
	Summarize(context.Context, string, fantasy.ProviderOptions) error
	// Model returns the large model.
	Model() Model
}
 71
// Model bundles a runnable language model with its catwalk metadata
// (context window, pricing, reasoning capability) and the user's selected
// model configuration.
type Model struct {
	Model      fantasy.LanguageModel
	CatwalkCfg catwalk.Model
	ModelCfg   config.SelectedModel
}
 77
// sessionAgent is the SessionAgent implementation returned by
// NewSessionAgent.
type sessionAgent struct {
	// largeModel drives conversations and summaries; smallModel generates
	// session titles.
	largeModel           Model
	smallModel           Model
	// systemPromptPrefix, when non-empty, is prepended as an extra system
	// message on every step of Run, Summarize and generateTitle.
	systemPromptPrefix   string
	systemPrompt         string
	tools                []fantasy.AgentTool
	sessions             session.Service
	messages             message.Service
	// disableAutoSummarize turns off the context-window stop condition in
	// Run.
	disableAutoSummarize bool
	// isYolo is not read in this file; presumably consulted elsewhere in
	// the package — TODO confirm.
	isYolo               bool

	// messageQueue holds prompts submitted while a session was busy.
	messageQueue   *csync.Map[string, []SessionAgentCall]
	// activeRequests maps session IDs to the cancel func of the in-flight
	// Run or Summarize for that session.
	activeRequests *csync.Map[string, context.CancelFunc]
}
 92
// SessionAgentOptions configures NewSessionAgent. Each field is copied
// into the corresponding sessionAgent field.
type SessionAgentOptions struct {
	LargeModel           Model
	SmallModel           Model
	SystemPromptPrefix   string
	SystemPrompt         string
	DisableAutoSummarize bool
	IsYolo               bool
	Sessions             session.Service
	Messages             message.Service
	Tools                []fantasy.AgentTool
}
104
105func NewSessionAgent(
106	opts SessionAgentOptions,
107) SessionAgent {
108	return &sessionAgent{
109		largeModel:           opts.LargeModel,
110		smallModel:           opts.SmallModel,
111		systemPromptPrefix:   opts.SystemPromptPrefix,
112		systemPrompt:         opts.SystemPrompt,
113		sessions:             opts.Sessions,
114		messages:             opts.Messages,
115		disableAutoSummarize: opts.DisableAutoSummarize,
116		tools:                opts.Tools,
117		isYolo:               opts.IsYolo,
118		messageQueue:         csync.NewMap[string, []SessionAgentCall](),
119		activeRequests:       csync.NewMap[string, context.CancelFunc](),
120	}
121}
122
123func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
124	if call.Prompt == "" {
125		return nil, ErrEmptyPrompt
126	}
127	if call.SessionID == "" {
128		return nil, ErrSessionMissing
129	}
130
131	// Queue the message if busy
132	if a.IsSessionBusy(call.SessionID) {
133		existing, ok := a.messageQueue.Get(call.SessionID)
134		if !ok {
135			existing = []SessionAgentCall{}
136		}
137		existing = append(existing, call)
138		a.messageQueue.Set(call.SessionID, existing)
139		return nil, nil
140	}
141
142	if len(a.tools) > 0 {
143		// Add Anthropic caching to the last tool.
144		a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
145	}
146
147	agent := fantasy.NewAgent(
148		a.largeModel.Model,
149		fantasy.WithSystemPrompt(a.systemPrompt),
150		fantasy.WithTools(a.tools...),
151	)
152
153	sessionLock := sync.Mutex{}
154	currentSession, err := a.sessions.Get(ctx, call.SessionID)
155	if err != nil {
156		return nil, fmt.Errorf("failed to get session: %w", err)
157	}
158
159	msgs, err := a.getSessionMessages(ctx, currentSession)
160	if err != nil {
161		return nil, fmt.Errorf("failed to get session messages: %w", err)
162	}
163
164	var wg sync.WaitGroup
165	// Generate title if first message.
166	if len(msgs) == 0 {
167		wg.Go(func() {
168			sessionLock.Lock()
169			a.generateTitle(ctx, &currentSession, call.Prompt)
170			sessionLock.Unlock()
171		})
172	}
173
174	// Add the user message to the session.
175	_, err = a.createUserMessage(ctx, call)
176	if err != nil {
177		return nil, err
178	}
179
180	// Add the session to the context.
181	ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)
182
183	genCtx, cancel := context.WithCancel(ctx)
184	a.activeRequests.Set(call.SessionID, cancel)
185
186	defer cancel()
187	defer a.activeRequests.Del(call.SessionID)
188
189	history, files := a.preparePrompt(msgs, call.Attachments...)
190
191	startTime := time.Now()
192	a.eventPromptSent(call.SessionID)
193
194	var currentAssistant *message.Message
195	var shouldSummarize bool
196	result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
197		Prompt:           call.Prompt,
198		Files:            files,
199		Messages:         history,
200		ProviderOptions:  call.ProviderOptions,
201		MaxOutputTokens:  &call.MaxOutputTokens,
202		TopP:             call.TopP,
203		Temperature:      call.Temperature,
204		PresencePenalty:  call.PresencePenalty,
205		TopK:             call.TopK,
206		FrequencyPenalty: call.FrequencyPenalty,
207		// Before each step create a new assistant message.
208		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
209			prepared.Messages = options.Messages
210			// Reset all cached items.
211			for i := range prepared.Messages {
212				prepared.Messages[i].ProviderOptions = nil
213			}
214
215			queuedCalls, _ := a.messageQueue.Get(call.SessionID)
216			a.messageQueue.Del(call.SessionID)
217			for _, queued := range queuedCalls {
218				userMessage, createErr := a.createUserMessage(callContext, queued)
219				if createErr != nil {
220					return callContext, prepared, createErr
221				}
222				prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
223			}
224
225			lastSystemRoleInx := 0
226			systemMessageUpdated := false
227			for i, msg := range prepared.Messages {
228				// Only add cache control to the last message.
229				if msg.Role == fantasy.MessageRoleSystem {
230					lastSystemRoleInx = i
231				} else if !systemMessageUpdated {
232					prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
233					systemMessageUpdated = true
234				}
235				// Than add cache control to the last 2 messages.
236				if i > len(prepared.Messages)-3 {
237					prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
238				}
239			}
240
241			if a.systemPromptPrefix != "" {
242				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
243			}
244
245			var assistantMsg message.Message
246			assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
247				Role:     message.Assistant,
248				Parts:    []message.ContentPart{},
249				Model:    a.largeModel.ModelCfg.Model,
250				Provider: a.largeModel.ModelCfg.Provider,
251			})
252			if err != nil {
253				return callContext, prepared, err
254			}
255			callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
256			currentAssistant = &assistantMsg
257			return callContext, prepared, err
258		},
259		OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
260			currentAssistant.AppendReasoningContent(reasoning.Text)
261			return a.messages.Update(genCtx, *currentAssistant)
262		},
263		OnReasoningDelta: func(id string, text string) error {
264			currentAssistant.AppendReasoningContent(text)
265			return a.messages.Update(genCtx, *currentAssistant)
266		},
267		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
268			// handle anthropic signature
269			if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
270				if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
271					currentAssistant.AppendReasoningSignature(reasoning.Signature)
272				}
273			}
274			if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
275				if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
276					currentAssistant.AppendReasoningSignature(reasoning.Signature)
277				}
278			}
279			if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
280				if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
281					currentAssistant.SetReasoningResponsesData(reasoning)
282				}
283			}
284			currentAssistant.FinishThinking()
285			return a.messages.Update(genCtx, *currentAssistant)
286		},
287		OnTextDelta: func(id string, text string) error {
288			// Strip leading newline from initial text content. This is is
289			// particularly important in non-interactive mode where leading
290			// newlines are very visible.
291			if len(currentAssistant.Parts) == 0 {
292				text = strings.TrimPrefix(text, "\n")
293			}
294
295			currentAssistant.AppendContent(text)
296			return a.messages.Update(genCtx, *currentAssistant)
297		},
298		OnToolInputStart: func(id string, toolName string) error {
299			toolCall := message.ToolCall{
300				ID:               id,
301				Name:             toolName,
302				ProviderExecuted: false,
303				Finished:         false,
304			}
305			currentAssistant.AddToolCall(toolCall)
306			return a.messages.Update(genCtx, *currentAssistant)
307		},
308		OnRetry: func(err *fantasy.ProviderError, delay time.Duration) {
309			// TODO: implement
310		},
311		OnToolCall: func(tc fantasy.ToolCallContent) error {
312			toolCall := message.ToolCall{
313				ID:               tc.ToolCallID,
314				Name:             tc.ToolName,
315				Input:            tc.Input,
316				ProviderExecuted: false,
317				Finished:         true,
318			}
319			currentAssistant.AddToolCall(toolCall)
320			return a.messages.Update(genCtx, *currentAssistant)
321		},
322		OnToolResult: func(result fantasy.ToolResultContent) error {
323			var resultContent string
324			isError := false
325			switch result.Result.GetType() {
326			case fantasy.ToolResultContentTypeText:
327				r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result)
328				if ok {
329					resultContent = r.Text
330				}
331			case fantasy.ToolResultContentTypeError:
332				r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result)
333				if ok {
334					isError = true
335					resultContent = r.Error.Error()
336				}
337			case fantasy.ToolResultContentTypeMedia:
338				// TODO: handle this message type
339			}
340			toolResult := message.ToolResult{
341				ToolCallID: result.ToolCallID,
342				Name:       result.ToolName,
343				Content:    resultContent,
344				IsError:    isError,
345				Metadata:   result.ClientMetadata,
346			}
347			_, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
348				Role: message.Tool,
349				Parts: []message.ContentPart{
350					toolResult,
351				},
352			})
353			if createMsgErr != nil {
354				return createMsgErr
355			}
356			return nil
357		},
358		OnStepFinish: func(stepResult fantasy.StepResult) error {
359			finishReason := message.FinishReasonUnknown
360			switch stepResult.FinishReason {
361			case fantasy.FinishReasonLength:
362				finishReason = message.FinishReasonMaxTokens
363			case fantasy.FinishReasonStop:
364				finishReason = message.FinishReasonEndTurn
365			case fantasy.FinishReasonToolCalls:
366				finishReason = message.FinishReasonToolUse
367			}
368			currentAssistant.AddFinish(finishReason, "", "")
369			a.updateSessionUsage(a.largeModel, &currentSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
370			sessionLock.Lock()
371			_, sessionErr := a.sessions.Save(genCtx, currentSession)
372			sessionLock.Unlock()
373			if sessionErr != nil {
374				return sessionErr
375			}
376			return a.messages.Update(genCtx, *currentAssistant)
377		},
378		StopWhen: []fantasy.StopCondition{
379			func(_ []fantasy.StepResult) bool {
380				cw := int64(a.largeModel.CatwalkCfg.ContextWindow)
381				tokens := currentSession.CompletionTokens + currentSession.PromptTokens
382				remaining := cw - tokens
383				var threshold int64
384				if cw > 200_000 {
385					threshold = 20_000
386				} else {
387					threshold = int64(float64(cw) * 0.2)
388				}
389				if (remaining <= threshold) && !a.disableAutoSummarize {
390					shouldSummarize = true
391					return true
392				}
393				return false
394			},
395		},
396	})
397
398	a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))
399
400	if err != nil {
401		isCancelErr := errors.Is(err, context.Canceled)
402		isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
403		if currentAssistant == nil {
404			return result, err
405		}
406		// Ensure we finish thinking on error to close the reasoning state.
407		currentAssistant.FinishThinking()
408		toolCalls := currentAssistant.ToolCalls()
409		// INFO: we use the parent context here because the genCtx has been cancelled.
410		msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
411		if createErr != nil {
412			return nil, createErr
413		}
414		for _, tc := range toolCalls {
415			if !tc.Finished {
416				tc.Finished = true
417				tc.Input = "{}"
418				currentAssistant.AddToolCall(tc)
419				updateErr := a.messages.Update(ctx, *currentAssistant)
420				if updateErr != nil {
421					return nil, updateErr
422				}
423			}
424
425			found := false
426			for _, msg := range msgs {
427				if msg.Role == message.Tool {
428					for _, tr := range msg.ToolResults() {
429						if tr.ToolCallID == tc.ID {
430							found = true
431							break
432						}
433					}
434				}
435				if found {
436					break
437				}
438			}
439			if found {
440				continue
441			}
442			content := "There was an error while executing the tool"
443			if isCancelErr {
444				content = "Tool execution canceled by user"
445			} else if isPermissionErr {
446				content = "User denied permission"
447			}
448			toolResult := message.ToolResult{
449				ToolCallID: tc.ID,
450				Name:       tc.Name,
451				Content:    content,
452				IsError:    true,
453			}
454			_, createErr = a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
455				Role: message.Tool,
456				Parts: []message.ContentPart{
457					toolResult,
458				},
459			})
460			if createErr != nil {
461				return nil, createErr
462			}
463		}
464		var fantasyErr *fantasy.Error
465		var providerErr *fantasy.ProviderError
466		const defaultTitle = "Provider Error"
467		if isCancelErr {
468			currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
469		} else if isPermissionErr {
470			currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "")
471		} else if errors.As(err, &providerErr) {
472			currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message)
473		} else if errors.As(err, &fantasyErr) {
474			currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(fantasyErr.Title), defaultTitle), fantasyErr.Message)
475		} else {
476			currentAssistant.AddFinish(message.FinishReasonError, defaultTitle, err.Error())
477		}
478		// Note: we use the parent context here because the genCtx has been
479		// cancelled.
480		updateErr := a.messages.Update(ctx, *currentAssistant)
481		if updateErr != nil {
482			return nil, updateErr
483		}
484		return nil, err
485	}
486	wg.Wait()
487
488	if shouldSummarize {
489		a.activeRequests.Del(call.SessionID)
490		if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
491			return nil, summarizeErr
492		}
493		// If the agent wasn't done...
494		if len(currentAssistant.ToolCalls()) > 0 {
495			existing, ok := a.messageQueue.Get(call.SessionID)
496			if !ok {
497				existing = []SessionAgentCall{}
498			}
499			call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
500			existing = append(existing, call)
501			a.messageQueue.Set(call.SessionID, existing)
502		}
503	}
504
505	// Release active request before processing queued messages.
506	a.activeRequests.Del(call.SessionID)
507	cancel()
508
509	queuedMessages, ok := a.messageQueue.Get(call.SessionID)
510	if !ok || len(queuedMessages) == 0 {
511		return result, err
512	}
513	// There are queued messages restart the loop.
514	firstQueuedMessage := queuedMessages[0]
515	a.messageQueue.Set(call.SessionID, queuedMessages[1:])
516	return a.Run(ctx, firstQueuedMessage)
517}
518
519func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
520	if a.IsSessionBusy(sessionID) {
521		return ErrSessionBusy
522	}
523
524	currentSession, err := a.sessions.Get(ctx, sessionID)
525	if err != nil {
526		return fmt.Errorf("failed to get session: %w", err)
527	}
528	msgs, err := a.getSessionMessages(ctx, currentSession)
529	if err != nil {
530		return err
531	}
532	if len(msgs) == 0 {
533		// Nothing to summarize.
534		return nil
535	}
536
537	aiMsgs, _ := a.preparePrompt(msgs)
538
539	genCtx, cancel := context.WithCancel(ctx)
540	a.activeRequests.Set(sessionID, cancel)
541	defer a.activeRequests.Del(sessionID)
542	defer cancel()
543
544	agent := fantasy.NewAgent(a.largeModel.Model,
545		fantasy.WithSystemPrompt(string(summaryPrompt)),
546	)
547	summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
548		Role:             message.Assistant,
549		Model:            a.largeModel.Model.Model(),
550		Provider:         a.largeModel.Model.Provider(),
551		IsSummaryMessage: true,
552	})
553	if err != nil {
554		return err
555	}
556
557	resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
558		Prompt:          "Provide a detailed summary of our conversation above.",
559		Messages:        aiMsgs,
560		ProviderOptions: opts,
561		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
562			prepared.Messages = options.Messages
563			if a.systemPromptPrefix != "" {
564				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
565			}
566			return callContext, prepared, nil
567		},
568		OnReasoningDelta: func(id string, text string) error {
569			summaryMessage.AppendReasoningContent(text)
570			return a.messages.Update(genCtx, summaryMessage)
571		},
572		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
573			// Handle anthropic signature.
574			if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
575				if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
576					summaryMessage.AppendReasoningSignature(signature.Signature)
577				}
578			}
579			summaryMessage.FinishThinking()
580			return a.messages.Update(genCtx, summaryMessage)
581		},
582		OnTextDelta: func(id, text string) error {
583			summaryMessage.AppendContent(text)
584			return a.messages.Update(genCtx, summaryMessage)
585		},
586	})
587	if err != nil {
588		isCancelErr := errors.Is(err, context.Canceled)
589		if isCancelErr {
590			// User cancelled summarize we need to remove the summary message.
591			deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
592			return deleteErr
593		}
594		return err
595	}
596
597	summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
598	err = a.messages.Update(genCtx, summaryMessage)
599	if err != nil {
600		return err
601	}
602
603	var openrouterCost *float64
604	for _, step := range resp.Steps {
605		stepCost := a.openrouterCost(step.ProviderMetadata)
606		if stepCost != nil {
607			newCost := *stepCost
608			if openrouterCost != nil {
609				newCost += *openrouterCost
610			}
611			openrouterCost = &newCost
612		}
613	}
614
615	a.updateSessionUsage(a.largeModel, &currentSession, resp.TotalUsage, openrouterCost)
616
617	// Just in case, get just the last usage info.
618	usage := resp.Response.Usage
619	currentSession.SummaryMessageID = summaryMessage.ID
620	currentSession.CompletionTokens = usage.OutputTokens
621	currentSession.PromptTokens = 0
622	_, err = a.sessions.Save(genCtx, currentSession)
623	return err
624}
625
626func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
627	if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
628		return fantasy.ProviderOptions{}
629	}
630	return fantasy.ProviderOptions{
631		anthropic.Name: &anthropic.ProviderCacheControlOptions{
632			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
633		},
634		bedrock.Name: &anthropic.ProviderCacheControlOptions{
635			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
636		},
637	}
638}
639
640func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
641	var attachmentParts []message.ContentPart
642	for _, attachment := range call.Attachments {
643		attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
644	}
645	parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
646	parts = append(parts, attachmentParts...)
647	msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
648		Role:  message.User,
649		Parts: parts,
650	})
651	if err != nil {
652		return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
653	}
654	return msg, nil
655}
656
657func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
658	var history []fantasy.Message
659	for _, m := range msgs {
660		if len(m.Parts) == 0 {
661			continue
662		}
663		// Assistant message without content or tool calls (cancelled before it
664		// returned anything).
665		if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
666			continue
667		}
668		history = append(history, m.ToAIMessage()...)
669	}
670
671	var files []fantasy.FilePart
672	for _, attachment := range attachments {
673		files = append(files, fantasy.FilePart{
674			Filename:  attachment.FileName,
675			Data:      attachment.Content,
676			MediaType: attachment.MimeType,
677		})
678	}
679
680	return history, files
681}
682
683func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
684	msgs, err := a.messages.List(ctx, session.ID)
685	if err != nil {
686		return nil, fmt.Errorf("failed to list messages: %w", err)
687	}
688
689	if session.SummaryMessageID != "" {
690		summaryMsgInex := -1
691		for i, msg := range msgs {
692			if msg.ID == session.SummaryMessageID {
693				summaryMsgInex = i
694				break
695			}
696		}
697		if summaryMsgInex != -1 {
698			msgs = msgs[summaryMsgInex:]
699			msgs[0].Role = message.User
700		}
701	}
702	return msgs, nil
703}
704
705func (a *sessionAgent) generateTitle(ctx context.Context, session *session.Session, prompt string) {
706	if prompt == "" {
707		return
708	}
709
710	var maxOutput int64 = 40
711	if a.smallModel.CatwalkCfg.CanReason {
712		maxOutput = a.smallModel.CatwalkCfg.DefaultMaxTokens
713	}
714
715	agent := fantasy.NewAgent(a.smallModel.Model,
716		fantasy.WithSystemPrompt(string(titlePrompt)+"\n /no_think"),
717		fantasy.WithMaxOutputTokens(maxOutput),
718	)
719
720	resp, err := agent.Stream(ctx, fantasy.AgentStreamCall{
721		Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", prompt),
722		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
723			prepared.Messages = options.Messages
724			if a.systemPromptPrefix != "" {
725				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
726			}
727			return callContext, prepared, nil
728		},
729	})
730	if err != nil {
731		slog.Error("error generating title", "err", err)
732		return
733	}
734
735	title := resp.Response.Content.Text()
736
737	title = strings.ReplaceAll(title, "\n", " ")
738
739	// Remove thinking tags if present.
740	if idx := strings.Index(title, "</think>"); idx > 0 {
741		title = title[idx+len("</think>"):]
742	}
743
744	title = strings.TrimSpace(title)
745	if title == "" {
746		slog.Warn("failed to generate title", "warn", "empty title")
747		return
748	}
749
750	session.Title = title
751
752	var openrouterCost *float64
753	for _, step := range resp.Steps {
754		stepCost := a.openrouterCost(step.ProviderMetadata)
755		if stepCost != nil {
756			newCost := *stepCost
757			if openrouterCost != nil {
758				newCost += *openrouterCost
759			}
760			openrouterCost = &newCost
761		}
762	}
763
764	a.updateSessionUsage(a.smallModel, session, resp.TotalUsage, openrouterCost)
765	_, saveErr := a.sessions.Save(ctx, *session)
766	if saveErr != nil {
767		slog.Error("failed to save session title & usage", "error", saveErr)
768		return
769	}
770}
771
772func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
773	openrouterMetadata, ok := metadata[openrouter.Name]
774	if !ok {
775		return nil
776	}
777
778	opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
779	if !ok {
780		return nil
781	}
782	return &opts.Usage.Cost
783}
784
785func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
786	modelConfig := model.CatwalkCfg
787	cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
788		modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
789		modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
790		modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
791
792	a.eventTokensUsed(session.ID, model, usage, cost)
793
794	if overrideCost != nil {
795		session.Cost += *overrideCost
796	} else {
797		session.Cost += cost
798	}
799
800	session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
801	session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
802}
803
804func (a *sessionAgent) Cancel(sessionID string) {
805	// Cancel regular requests.
806	if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
807		slog.Info("Request cancellation initiated", "session_id", sessionID)
808		cancel()
809	}
810
811	// Also check for summarize requests.
812	if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
813		slog.Info("Summarize cancellation initiated", "session_id", sessionID)
814		cancel()
815	}
816
817	if a.QueuedPrompts(sessionID) > 0 {
818		slog.Info("Clearing queued prompts", "session_id", sessionID)
819		a.messageQueue.Del(sessionID)
820	}
821}
822
823func (a *sessionAgent) ClearQueue(sessionID string) {
824	if a.QueuedPrompts(sessionID) > 0 {
825		slog.Info("Clearing queued prompts", "session_id", sessionID)
826		a.messageQueue.Del(sessionID)
827	}
828}
829
830func (a *sessionAgent) CancelAll() {
831	if !a.IsBusy() {
832		return
833	}
834	for key := range a.activeRequests.Seq2() {
835		a.Cancel(key) // key is sessionID
836	}
837
838	timeout := time.After(5 * time.Second)
839	for a.IsBusy() {
840		select {
841		case <-timeout:
842			return
843		default:
844			time.Sleep(200 * time.Millisecond)
845		}
846	}
847}
848
849func (a *sessionAgent) IsBusy() bool {
850	var busy bool
851	for cancelFunc := range a.activeRequests.Seq() {
852		if cancelFunc != nil {
853			busy = true
854			break
855		}
856	}
857	return busy
858}
859
860func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
861	_, busy := a.activeRequests.Get(sessionID)
862	return busy
863}
864
865func (a *sessionAgent) QueuedPrompts(sessionID string) int {
866	l, ok := a.messageQueue.Get(sessionID)
867	if !ok {
868		return 0
869	}
870	return len(l)
871}
872
873func (a *sessionAgent) SetModels(large Model, small Model) {
874	a.largeModel = large
875	a.smallModel = small
876}
877
878func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
879	a.tools = tools
880}
881
882func (a *sessionAgent) Model() Model {
883	return a.largeModel
884}