agent.go

  1// Package agent is the core orchestration layer for Crush AI agents.
  2//
  3// It provides session-based AI agent functionality for managing
  4// conversations, tool execution, and message handling. It coordinates
  5// interactions between language models, messages, sessions, and tools while
  6// handling features like automatic summarization, queuing, and token
  7// management.
  8package agent
  9
 10import (
 11	"cmp"
 12	"context"
 13	_ "embed"
 14	"errors"
 15	"fmt"
 16	"log/slog"
 17	"os"
 18	"strconv"
 19	"strings"
 20	"sync"
 21	"time"
 22
 23	"charm.land/fantasy"
 24	"charm.land/fantasy/providers/anthropic"
 25	"charm.land/fantasy/providers/bedrock"
 26	"charm.land/fantasy/providers/google"
 27	"charm.land/fantasy/providers/openai"
 28	"charm.land/fantasy/providers/openrouter"
 29	"github.com/charmbracelet/catwalk/pkg/catwalk"
 30	"github.com/charmbracelet/crush/internal/agent/tools"
 31	"github.com/charmbracelet/crush/internal/config"
 32	"github.com/charmbracelet/crush/internal/csync"
 33	"github.com/charmbracelet/crush/internal/message"
 34	"github.com/charmbracelet/crush/internal/notification"
 35	"github.com/charmbracelet/crush/internal/permission"
 36	"github.com/charmbracelet/crush/internal/session"
 37	"github.com/charmbracelet/crush/internal/stringext"
 38)
 39
// titlePrompt holds the contents of templates/title.md, embedded at build
// time. generateTitle uses it as the system prompt when asking the small
// model for a session title.
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt holds the contents of templates/summary.md, embedded at build
// time. Summarize uses it as the system prompt when condensing a session.
//
//go:embed templates/summary.md
var summaryPrompt []byte
 45
// SessionAgentCall describes a single prompt submitted to a session agent
// together with the per-call generation settings that Run forwards to the
// model.
type SessionAgentCall struct {
	// SessionID identifies the target session. Run rejects an empty value
	// with ErrSessionMissing.
	SessionID        string
	// Prompt is the user's text. Run rejects an empty value with
	// ErrEmptyPrompt.
	Prompt           string
	// ProviderOptions is passed through unchanged to the streaming call (and
	// to Summarize when auto-summarization triggers).
	ProviderOptions  fantasy.ProviderOptions
	// Attachments are stored on the persisted user message and sent to the
	// model as file parts.
	Attachments      []message.Attachment
	// MaxOutputTokens is forwarded to the streaming call by pointer.
	MaxOutputTokens  int64
	// The remaining fields are optional sampling parameters forwarded as-is;
	// nil leaves the corresponding setting unset on the streaming call.
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
}
 58
// SessionAgent is the interface through which callers drive an agent bound to
// sessions. The behavior notes below describe the sessionAgent implementation
// in this file.
type SessionAgent interface {
	// Run sends a prompt to a session. If the session is already busy the
	// call is queued and Run returns (nil, nil).
	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
	// SetModels replaces the large and small models.
	SetModels(large Model, small Model)
	// SetTools replaces the tool set offered to the model.
	SetTools(tools []fantasy.AgentTool)
	// Cancel cancels the in-flight request for the session and drops its
	// queued prompts.
	Cancel(sessionID string)
	// CancelAll cancels every active request and waits up to five seconds
	// for them to wind down.
	CancelAll()
	// IsSessionBusy reports whether the session has a registered request.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any session has a registered request.
	IsBusy() bool
	// QueuedPrompts returns the number of prompts waiting for the session.
	QueuedPrompts(sessionID string) int
	// ClearQueue drops the prompts waiting for the session.
	ClearQueue(sessionID string)
	// Summarize replaces the session's working history with a
	// model-generated summary message.
	Summarize(context.Context, string, fantasy.ProviderOptions) error
	// Model returns the large model.
	Model() Model
}
 72
// Model bundles a language model with the metadata the agent needs alongside
// it.
type Model struct {
	// Model is the language model handed to fantasy.NewAgent.
	Model      fantasy.LanguageModel
	// CatwalkCfg supplies the context window, reasoning capability, default
	// max tokens and per-token pricing read by this package.
	CatwalkCfg catwalk.Model
	// ModelCfg supplies the model and provider names recorded on assistant
	// messages created by Run.
	ModelCfg   config.SelectedModel
}
 78
// sessionAgent is the SessionAgent implementation returned by
// NewSessionAgent.
type sessionAgent struct {
	// largeModel serves conversations and summaries; smallModel generates
	// session titles.
	largeModel           Model
	smallModel           Model
	// systemPromptPrefix, when non-empty, is prepended as an extra system
	// message on every model step.
	systemPromptPrefix   string
	// systemPrompt is the system prompt for conversation requests.
	systemPrompt         string
	tools                []fantasy.AgentTool
	sessions             session.Service
	messages             message.Service
	// disableAutoSummarize turns off the context-window stop condition in
	// Run.
	disableAutoSummarize bool
	// NOTE(review): isYolo is stored by the constructor but not read by any
	// method in this file — confirm where it is consumed.
	isYolo               bool

	// messageQueue holds prompts submitted while a session was busy, keyed
	// by session ID.
	messageQueue   *csync.Map[string, []SessionAgentCall]
	// activeRequests holds the cancel function of each in-flight request,
	// keyed by session ID; presence of a key marks the session as busy.
	activeRequests *csync.Map[string, context.CancelFunc]
}
 93
// SessionAgentOptions carries the dependencies and settings that
// NewSessionAgent copies into a new agent.
type SessionAgentOptions struct {
	// LargeModel serves conversations and summaries; SmallModel generates
	// session titles.
	LargeModel           Model
	SmallModel           Model
	// SystemPromptPrefix, when non-empty, is prepended as an extra system
	// message on every model step.
	SystemPromptPrefix   string
	// SystemPrompt is the system prompt for conversation requests.
	SystemPrompt         string
	// DisableAutoSummarize turns off automatic summarization when the
	// context window is nearly full.
	DisableAutoSummarize bool
	IsYolo               bool
	// Sessions and Messages are the persistence services the agent reads
	// from and writes to.
	Sessions             session.Service
	Messages             message.Service
	// Tools is the initial tool set offered to the model.
	Tools                []fantasy.AgentTool
}
105
106func NewSessionAgent(
107	opts SessionAgentOptions,
108) SessionAgent {
109	return &sessionAgent{
110		largeModel:           opts.LargeModel,
111		smallModel:           opts.SmallModel,
112		systemPromptPrefix:   opts.SystemPromptPrefix,
113		systemPrompt:         opts.SystemPrompt,
114		sessions:             opts.Sessions,
115		messages:             opts.Messages,
116		disableAutoSummarize: opts.DisableAutoSummarize,
117		tools:                opts.Tools,
118		isYolo:               opts.IsYolo,
119		messageQueue:         csync.NewMap[string, []SessionAgentCall](),
120		activeRequests:       csync.NewMap[string, context.CancelFunc](),
121	}
122}
123
124func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
125	if call.Prompt == "" {
126		return nil, ErrEmptyPrompt
127	}
128	if call.SessionID == "" {
129		return nil, ErrSessionMissing
130	}
131
132	// Queue the message if busy
133	if a.IsSessionBusy(call.SessionID) {
134		existing, ok := a.messageQueue.Get(call.SessionID)
135		if !ok {
136			existing = []SessionAgentCall{}
137		}
138		existing = append(existing, call)
139		a.messageQueue.Set(call.SessionID, existing)
140		return nil, nil
141	}
142
143	if len(a.tools) > 0 {
144		// Add Anthropic caching to the last tool.
145		a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
146	}
147
148	agent := fantasy.NewAgent(
149		a.largeModel.Model,
150		fantasy.WithSystemPrompt(a.systemPrompt),
151		fantasy.WithTools(a.tools...),
152	)
153
154	sessionLock := sync.Mutex{}
155	currentSession, err := a.sessions.Get(ctx, call.SessionID)
156	if err != nil {
157		return nil, fmt.Errorf("failed to get session: %w", err)
158	}
159
160	msgs, err := a.getSessionMessages(ctx, currentSession)
161	if err != nil {
162		return nil, fmt.Errorf("failed to get session messages: %w", err)
163	}
164
165	var wg sync.WaitGroup
166	// Generate title if first message.
167	if len(msgs) == 0 {
168		wg.Go(func() {
169			sessionLock.Lock()
170			a.generateTitle(ctx, &currentSession, call.Prompt)
171			sessionLock.Unlock()
172		})
173	}
174
175	// Add the user message to the session.
176	_, err = a.createUserMessage(ctx, call)
177	if err != nil {
178		return nil, err
179	}
180
181	// Add the session to the context.
182	ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)
183
184	genCtx, cancel := context.WithCancel(ctx)
185	a.activeRequests.Set(call.SessionID, cancel)
186
187	defer cancel()
188	defer a.activeRequests.Del(call.SessionID)
189
190	history, files := a.preparePrompt(msgs, call.Attachments...)
191
192	startTime := time.Now()
193	a.eventPromptSent(call.SessionID)
194
195	var currentAssistant *message.Message
196	var shouldSummarize bool
197	result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
198		Prompt:           call.Prompt,
199		Files:            files,
200		Messages:         history,
201		ProviderOptions:  call.ProviderOptions,
202		MaxOutputTokens:  &call.MaxOutputTokens,
203		TopP:             call.TopP,
204		Temperature:      call.Temperature,
205		PresencePenalty:  call.PresencePenalty,
206		TopK:             call.TopK,
207		FrequencyPenalty: call.FrequencyPenalty,
208		// Before each step create a new assistant message.
209		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
210			prepared.Messages = options.Messages
211			// Reset all cached items.
212			for i := range prepared.Messages {
213				prepared.Messages[i].ProviderOptions = nil
214			}
215
216			queuedCalls, _ := a.messageQueue.Get(call.SessionID)
217			a.messageQueue.Del(call.SessionID)
218			for _, queued := range queuedCalls {
219				userMessage, createErr := a.createUserMessage(callContext, queued)
220				if createErr != nil {
221					return callContext, prepared, createErr
222				}
223				prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
224			}
225
226			lastSystemRoleInx := 0
227			systemMessageUpdated := false
228			for i, msg := range prepared.Messages {
229				// Only add cache control to the last message.
230				if msg.Role == fantasy.MessageRoleSystem {
231					lastSystemRoleInx = i
232				} else if !systemMessageUpdated {
233					prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
234					systemMessageUpdated = true
235				}
236				// Than add cache control to the last 2 messages.
237				if i > len(prepared.Messages)-3 {
238					prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
239				}
240			}
241
242			if a.systemPromptPrefix != "" {
243				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
244			}
245
246			var assistantMsg message.Message
247			assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
248				Role:     message.Assistant,
249				Parts:    []message.ContentPart{},
250				Model:    a.largeModel.ModelCfg.Model,
251				Provider: a.largeModel.ModelCfg.Provider,
252			})
253			if err != nil {
254				return callContext, prepared, err
255			}
256			callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
257			currentAssistant = &assistantMsg
258			return callContext, prepared, err
259		},
260		OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
261			currentAssistant.AppendReasoningContent(reasoning.Text)
262			return a.messages.Update(genCtx, *currentAssistant)
263		},
264		OnReasoningDelta: func(id string, text string) error {
265			currentAssistant.AppendReasoningContent(text)
266			return a.messages.Update(genCtx, *currentAssistant)
267		},
268		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
269			// handle anthropic signature
270			if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
271				if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
272					currentAssistant.AppendReasoningSignature(reasoning.Signature)
273				}
274			}
275			if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
276				if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
277					currentAssistant.AppendThoughtSignature(reasoning.Signature, reasoning.ToolID)
278				}
279			}
280			if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
281				if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
282					currentAssistant.SetReasoningResponsesData(reasoning)
283				}
284			}
285			currentAssistant.FinishThinking()
286			return a.messages.Update(genCtx, *currentAssistant)
287		},
288		OnTextDelta: func(id string, text string) error {
289			// Strip leading newline from initial text content. This is is
290			// particularly important in non-interactive mode where leading
291			// newlines are very visible.
292			if len(currentAssistant.Parts) == 0 {
293				text = strings.TrimPrefix(text, "\n")
294			}
295
296			currentAssistant.AppendContent(text)
297			return a.messages.Update(genCtx, *currentAssistant)
298		},
299		OnToolInputStart: func(id string, toolName string) error {
300			toolCall := message.ToolCall{
301				ID:               id,
302				Name:             toolName,
303				ProviderExecuted: false,
304				Finished:         false,
305			}
306			currentAssistant.AddToolCall(toolCall)
307			return a.messages.Update(genCtx, *currentAssistant)
308		},
309		OnRetry: func(err *fantasy.ProviderError, delay time.Duration) {
310			// TODO: implement
311		},
312		OnToolCall: func(tc fantasy.ToolCallContent) error {
313			toolCall := message.ToolCall{
314				ID:               tc.ToolCallID,
315				Name:             tc.ToolName,
316				Input:            tc.Input,
317				ProviderExecuted: false,
318				Finished:         true,
319			}
320			currentAssistant.AddToolCall(toolCall)
321			return a.messages.Update(genCtx, *currentAssistant)
322		},
323		OnToolResult: func(result fantasy.ToolResultContent) error {
324			var resultContent string
325			isError := false
326			switch result.Result.GetType() {
327			case fantasy.ToolResultContentTypeText:
328				r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result)
329				if ok {
330					resultContent = r.Text
331				}
332			case fantasy.ToolResultContentTypeError:
333				r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result)
334				if ok {
335					isError = true
336					resultContent = r.Error.Error()
337				}
338			case fantasy.ToolResultContentTypeMedia:
339				// TODO: handle this message type
340			}
341			toolResult := message.ToolResult{
342				ToolCallID: result.ToolCallID,
343				Name:       result.ToolName,
344				Content:    resultContent,
345				IsError:    isError,
346				Metadata:   result.ClientMetadata,
347			}
348			_, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
349				Role: message.Tool,
350				Parts: []message.ContentPart{
351					toolResult,
352				},
353			})
354			if createMsgErr != nil {
355				return createMsgErr
356			}
357			return nil
358		},
359		OnStepFinish: func(stepResult fantasy.StepResult) error {
360			finishReason := message.FinishReasonUnknown
361			switch stepResult.FinishReason {
362			case fantasy.FinishReasonLength:
363				finishReason = message.FinishReasonMaxTokens
364			case fantasy.FinishReasonStop:
365				finishReason = message.FinishReasonEndTurn
366			case fantasy.FinishReasonToolCalls:
367				finishReason = message.FinishReasonToolUse
368			}
369			currentAssistant.AddFinish(finishReason, "", "")
370			a.updateSessionUsage(a.largeModel, &currentSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
371			sessionLock.Lock()
372			_, sessionErr := a.sessions.Save(genCtx, currentSession)
373			sessionLock.Unlock()
374			if sessionErr != nil {
375				return sessionErr
376			}
377			return a.messages.Update(genCtx, *currentAssistant)
378		},
379		StopWhen: []fantasy.StopCondition{
380			func(_ []fantasy.StepResult) bool {
381				cw := int64(a.largeModel.CatwalkCfg.ContextWindow)
382				tokens := currentSession.CompletionTokens + currentSession.PromptTokens
383				remaining := cw - tokens
384				var threshold int64
385				if cw > 200_000 {
386					threshold = 20_000
387				} else {
388					threshold = int64(float64(cw) * 0.2)
389				}
390				if (remaining <= threshold) && !a.disableAutoSummarize {
391					shouldSummarize = true
392					return true
393				}
394				return false
395			},
396		},
397	})
398
399	a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))
400
401	if err != nil {
402		isCancelErr := errors.Is(err, context.Canceled)
403		isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
404		if currentAssistant == nil {
405			return result, err
406		}
407		// Ensure we finish thinking on error to close the reasoning state.
408		currentAssistant.FinishThinking()
409		toolCalls := currentAssistant.ToolCalls()
410		// INFO: we use the parent context here because the genCtx has been cancelled.
411		msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
412		if createErr != nil {
413			return nil, createErr
414		}
415		for _, tc := range toolCalls {
416			if !tc.Finished {
417				tc.Finished = true
418				tc.Input = "{}"
419				currentAssistant.AddToolCall(tc)
420				updateErr := a.messages.Update(ctx, *currentAssistant)
421				if updateErr != nil {
422					return nil, updateErr
423				}
424			}
425
426			found := false
427			for _, msg := range msgs {
428				if msg.Role == message.Tool {
429					for _, tr := range msg.ToolResults() {
430						if tr.ToolCallID == tc.ID {
431							found = true
432							break
433						}
434					}
435				}
436				if found {
437					break
438				}
439			}
440			if found {
441				continue
442			}
443			content := "There was an error while executing the tool"
444			if isCancelErr {
445				content = "Tool execution canceled by user"
446			} else if isPermissionErr {
447				content = "User denied permission"
448			}
449			toolResult := message.ToolResult{
450				ToolCallID: tc.ID,
451				Name:       tc.Name,
452				Content:    content,
453				IsError:    true,
454			}
455			_, createErr = a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
456				Role: message.Tool,
457				Parts: []message.ContentPart{
458					toolResult,
459				},
460			})
461			if createErr != nil {
462				return nil, createErr
463			}
464		}
465		var fantasyErr *fantasy.Error
466		var providerErr *fantasy.ProviderError
467		const defaultTitle = "Provider Error"
468		if isCancelErr {
469			currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
470		} else if isPermissionErr {
471			currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "")
472		} else if errors.As(err, &providerErr) {
473			currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message)
474		} else if errors.As(err, &fantasyErr) {
475			currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(fantasyErr.Title), defaultTitle), fantasyErr.Message)
476		} else {
477			currentAssistant.AddFinish(message.FinishReasonError, defaultTitle, err.Error())
478		}
479		// Note: we use the parent context here because the genCtx has been
480		// cancelled.
481		updateErr := a.messages.Update(ctx, *currentAssistant)
482		if updateErr != nil {
483			return nil, updateErr
484		}
485		return nil, err
486	}
487	wg.Wait()
488
489	// Send notification that agent has finished its turn.
490	sessionLock.Lock()
491	notifBody := fmt.Sprintf("Agent's turn completed in \"%s\"", currentSession.Title)
492	sessionLock.Unlock()
493	if err := notification.Send("Crush is waiting...", notifBody); err != nil {
494		// Log but don't fail on notification errors.
495	}
496
497	if shouldSummarize {
498		a.activeRequests.Del(call.SessionID)
499		if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
500			return nil, summarizeErr
501		}
502		// If the agent wasn't done...
503		if len(currentAssistant.ToolCalls()) > 0 {
504			existing, ok := a.messageQueue.Get(call.SessionID)
505			if !ok {
506				existing = []SessionAgentCall{}
507			}
508			call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
509			existing = append(existing, call)
510			a.messageQueue.Set(call.SessionID, existing)
511		}
512	}
513
514	// Release active request before processing queued messages.
515	a.activeRequests.Del(call.SessionID)
516	cancel()
517
518	queuedMessages, ok := a.messageQueue.Get(call.SessionID)
519	if !ok || len(queuedMessages) == 0 {
520		return result, err
521	}
522	// There are queued messages restart the loop.
523	firstQueuedMessage := queuedMessages[0]
524	a.messageQueue.Set(call.SessionID, queuedMessages[1:])
525	return a.Run(ctx, firstQueuedMessage)
526}
527
528func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
529	if a.IsSessionBusy(sessionID) {
530		return ErrSessionBusy
531	}
532
533	currentSession, err := a.sessions.Get(ctx, sessionID)
534	if err != nil {
535		return fmt.Errorf("failed to get session: %w", err)
536	}
537	msgs, err := a.getSessionMessages(ctx, currentSession)
538	if err != nil {
539		return err
540	}
541	if len(msgs) == 0 {
542		// Nothing to summarize.
543		return nil
544	}
545
546	aiMsgs, _ := a.preparePrompt(msgs)
547
548	genCtx, cancel := context.WithCancel(ctx)
549	a.activeRequests.Set(sessionID, cancel)
550	defer a.activeRequests.Del(sessionID)
551	defer cancel()
552
553	agent := fantasy.NewAgent(a.largeModel.Model,
554		fantasy.WithSystemPrompt(string(summaryPrompt)),
555	)
556	summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
557		Role:             message.Assistant,
558		Model:            a.largeModel.Model.Model(),
559		Provider:         a.largeModel.Model.Provider(),
560		IsSummaryMessage: true,
561	})
562	if err != nil {
563		return err
564	}
565
566	resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
567		Prompt:          "Provide a detailed summary of our conversation above.",
568		Messages:        aiMsgs,
569		ProviderOptions: opts,
570		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
571			prepared.Messages = options.Messages
572			if a.systemPromptPrefix != "" {
573				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
574			}
575			return callContext, prepared, nil
576		},
577		OnReasoningDelta: func(id string, text string) error {
578			summaryMessage.AppendReasoningContent(text)
579			return a.messages.Update(genCtx, summaryMessage)
580		},
581		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
582			// Handle anthropic signature.
583			if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
584				if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
585					summaryMessage.AppendReasoningSignature(signature.Signature)
586				}
587			}
588			summaryMessage.FinishThinking()
589			return a.messages.Update(genCtx, summaryMessage)
590		},
591		OnTextDelta: func(id, text string) error {
592			summaryMessage.AppendContent(text)
593			return a.messages.Update(genCtx, summaryMessage)
594		},
595	})
596	if err != nil {
597		isCancelErr := errors.Is(err, context.Canceled)
598		if isCancelErr {
599			// User cancelled summarize we need to remove the summary message.
600			deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
601			return deleteErr
602		}
603		return err
604	}
605
606	summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
607	err = a.messages.Update(genCtx, summaryMessage)
608	if err != nil {
609		return err
610	}
611
612	var openrouterCost *float64
613	for _, step := range resp.Steps {
614		stepCost := a.openrouterCost(step.ProviderMetadata)
615		if stepCost != nil {
616			newCost := *stepCost
617			if openrouterCost != nil {
618				newCost += *openrouterCost
619			}
620			openrouterCost = &newCost
621		}
622	}
623
624	a.updateSessionUsage(a.largeModel, &currentSession, resp.TotalUsage, openrouterCost)
625
626	// Just in case, get just the last usage info.
627	usage := resp.Response.Usage
628	currentSession.SummaryMessageID = summaryMessage.ID
629	currentSession.CompletionTokens = usage.OutputTokens
630	currentSession.PromptTokens = 0
631	_, err = a.sessions.Save(genCtx, currentSession)
632	return err
633}
634
635func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
636	if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
637		return fantasy.ProviderOptions{}
638	}
639	return fantasy.ProviderOptions{
640		anthropic.Name: &anthropic.ProviderCacheControlOptions{
641			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
642		},
643		bedrock.Name: &anthropic.ProviderCacheControlOptions{
644			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
645		},
646	}
647}
648
649func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
650	var attachmentParts []message.ContentPart
651	for _, attachment := range call.Attachments {
652		attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
653	}
654	parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
655	parts = append(parts, attachmentParts...)
656	msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
657		Role:  message.User,
658		Parts: parts,
659	})
660	if err != nil {
661		return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
662	}
663	return msg, nil
664}
665
666func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
667	var history []fantasy.Message
668	for _, m := range msgs {
669		if len(m.Parts) == 0 {
670			continue
671		}
672		// Assistant message without content or tool calls (cancelled before it
673		// returned anything).
674		if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
675			continue
676		}
677		history = append(history, m.ToAIMessage()...)
678	}
679
680	var files []fantasy.FilePart
681	for _, attachment := range attachments {
682		files = append(files, fantasy.FilePart{
683			Filename:  attachment.FileName,
684			Data:      attachment.Content,
685			MediaType: attachment.MimeType,
686		})
687	}
688
689	return history, files
690}
691
692func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
693	msgs, err := a.messages.List(ctx, session.ID)
694	if err != nil {
695		return nil, fmt.Errorf("failed to list messages: %w", err)
696	}
697
698	if session.SummaryMessageID != "" {
699		summaryMsgInex := -1
700		for i, msg := range msgs {
701			if msg.ID == session.SummaryMessageID {
702				summaryMsgInex = i
703				break
704			}
705		}
706		if summaryMsgInex != -1 {
707			msgs = msgs[summaryMsgInex:]
708			msgs[0].Role = message.User
709		}
710	}
711	return msgs, nil
712}
713
714func (a *sessionAgent) generateTitle(ctx context.Context, session *session.Session, prompt string) {
715	if prompt == "" {
716		return
717	}
718
719	var maxOutput int64 = 40
720	if a.smallModel.CatwalkCfg.CanReason {
721		maxOutput = a.smallModel.CatwalkCfg.DefaultMaxTokens
722	}
723
724	agent := fantasy.NewAgent(a.smallModel.Model,
725		fantasy.WithSystemPrompt(string(titlePrompt)+"\n /no_think"),
726		fantasy.WithMaxOutputTokens(maxOutput),
727	)
728
729	resp, err := agent.Stream(ctx, fantasy.AgentStreamCall{
730		Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", prompt),
731		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
732			prepared.Messages = options.Messages
733			if a.systemPromptPrefix != "" {
734				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
735			}
736			return callContext, prepared, nil
737		},
738	})
739	if err != nil {
740		slog.Error("error generating title", "err", err)
741		return
742	}
743
744	title := resp.Response.Content.Text()
745
746	title = strings.ReplaceAll(title, "\n", " ")
747
748	// Remove thinking tags if present.
749	if idx := strings.Index(title, "</think>"); idx > 0 {
750		title = title[idx+len("</think>"):]
751	}
752
753	title = strings.TrimSpace(title)
754	if title == "" {
755		slog.Warn("failed to generate title", "warn", "empty title")
756		return
757	}
758
759	session.Title = title
760
761	var openrouterCost *float64
762	for _, step := range resp.Steps {
763		stepCost := a.openrouterCost(step.ProviderMetadata)
764		if stepCost != nil {
765			newCost := *stepCost
766			if openrouterCost != nil {
767				newCost += *openrouterCost
768			}
769			openrouterCost = &newCost
770		}
771	}
772
773	a.updateSessionUsage(a.smallModel, session, resp.TotalUsage, openrouterCost)
774	_, saveErr := a.sessions.Save(ctx, *session)
775	if saveErr != nil {
776		slog.Error("failed to save session title & usage", "error", saveErr)
777		return
778	}
779}
780
781func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
782	openrouterMetadata, ok := metadata[openrouter.Name]
783	if !ok {
784		return nil
785	}
786
787	opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
788	if !ok {
789		return nil
790	}
791	return &opts.Usage.Cost
792}
793
794func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
795	modelConfig := model.CatwalkCfg
796	cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
797		modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
798		modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
799		modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
800
801	a.eventTokensUsed(session.ID, model, usage, cost)
802
803	if overrideCost != nil {
804		session.Cost += *overrideCost
805	} else {
806		session.Cost += cost
807	}
808
809	session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
810	session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
811}
812
813func (a *sessionAgent) Cancel(sessionID string) {
814	// Cancel regular requests.
815	if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
816		slog.Info("Request cancellation initiated", "session_id", sessionID)
817		cancel()
818	}
819
820	// Also check for summarize requests.
821	if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
822		slog.Info("Summarize cancellation initiated", "session_id", sessionID)
823		cancel()
824	}
825
826	if a.QueuedPrompts(sessionID) > 0 {
827		slog.Info("Clearing queued prompts", "session_id", sessionID)
828		a.messageQueue.Del(sessionID)
829	}
830}
831
832func (a *sessionAgent) ClearQueue(sessionID string) {
833	if a.QueuedPrompts(sessionID) > 0 {
834		slog.Info("Clearing queued prompts", "session_id", sessionID)
835		a.messageQueue.Del(sessionID)
836	}
837}
838
839func (a *sessionAgent) CancelAll() {
840	if !a.IsBusy() {
841		return
842	}
843	for key := range a.activeRequests.Seq2() {
844		a.Cancel(key) // key is sessionID
845	}
846
847	timeout := time.After(5 * time.Second)
848	for a.IsBusy() {
849		select {
850		case <-timeout:
851			return
852		default:
853			time.Sleep(200 * time.Millisecond)
854		}
855	}
856}
857
858func (a *sessionAgent) IsBusy() bool {
859	var busy bool
860	for cancelFunc := range a.activeRequests.Seq() {
861		if cancelFunc != nil {
862			busy = true
863			break
864		}
865	}
866	return busy
867}
868
869func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
870	_, busy := a.activeRequests.Get(sessionID)
871	return busy
872}
873
874func (a *sessionAgent) QueuedPrompts(sessionID string) int {
875	l, ok := a.messageQueue.Get(sessionID)
876	if !ok {
877		return 0
878	}
879	return len(l)
880}
881
882func (a *sessionAgent) SetModels(large Model, small Model) {
883	a.largeModel = large
884	a.smallModel = small
885}
886
// SetTools replaces the agent's tool set. The slice is stored as given, not
// copied.
func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
	a.tools = tools
}
890
// Model returns the agent's large model.
func (a *sessionAgent) Model() Model {
	return a.largeModel
}