agent.go

  1// Package agent is the core orchestration layer for Crush AI agents.
  2//
  3// It provides session-based AI agent functionality for managing
  4// conversations, tool execution, and message handling. It coordinates
  5// interactions between language models, messages, sessions, and tools while
  6// handling features like automatic summarization, queuing, and token
  7// management.
  8package agent
  9
 10import (
 11	"cmp"
 12	"context"
 13	_ "embed"
 14	"errors"
 15	"fmt"
 16	"log/slog"
 17	"os"
 18	"strconv"
 19	"strings"
 20	"sync"
 21	"time"
 22
 23	"charm.land/fantasy"
 24	"charm.land/fantasy/providers/anthropic"
 25	"charm.land/fantasy/providers/bedrock"
 26	"charm.land/fantasy/providers/google"
 27	"charm.land/fantasy/providers/openai"
 28	"charm.land/fantasy/providers/openrouter"
 29	"github.com/charmbracelet/catwalk/pkg/catwalk"
 30	"github.com/charmbracelet/crush/internal/agent/tools"
 31	"github.com/charmbracelet/crush/internal/config"
 32	"github.com/charmbracelet/crush/internal/csync"
 33	"github.com/charmbracelet/crush/internal/hooks"
 34	"github.com/charmbracelet/crush/internal/message"
 35	"github.com/charmbracelet/crush/internal/permission"
 36	"github.com/charmbracelet/crush/internal/session"
 37	"github.com/charmbracelet/crush/internal/stringext"
 38)
 39
 40//go:embed templates/title.md
 41var titlePrompt []byte
 42
 43//go:embed templates/summary.md
 44var summaryPrompt []byte
 45
 46type SessionAgentCall struct {
 47	SessionID        string
 48	Prompt           string
 49	ProviderOptions  fantasy.ProviderOptions
 50	Attachments      []message.Attachment
 51	MaxOutputTokens  int64
 52	Temperature      *float64
 53	TopP             *float64
 54	TopK             *int64
 55	FrequencyPenalty *float64
 56	PresencePenalty  *float64
 57}
 58
 59type SessionAgent interface {
 60	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
 61	SetModels(large Model, small Model)
 62	SetTools(tools []fantasy.AgentTool)
 63	Cancel(sessionID string)
 64	CancelAll()
 65	IsSessionBusy(sessionID string) bool
 66	IsBusy() bool
 67	QueuedPrompts(sessionID string) int
 68	ClearQueue(sessionID string)
 69	Summarize(context.Context, string, fantasy.ProviderOptions) error
 70	Model() Model
 71}
 72
// Model bundles a language model with the configuration describing it.
type Model struct {
	// Model is the language model handed to fantasy.NewAgent.
	Model fantasy.LanguageModel
	// CatwalkCfg is the catalog entry for the model; this file reads its
	// context window, pricing, reasoning capability, default max tokens and ID.
	CatwalkCfg catwalk.Model
	// ModelCfg is the selected-model configuration; its Model and Provider
	// values are recorded on assistant messages.
	ModelCfg config.SelectedModel
}
 78
 79type sessionAgent struct {
 80	largeModel           Model
 81	smallModel           Model
 82	systemPromptPrefix   string
 83	systemPrompt         string
 84	tools                []fantasy.AgentTool
 85	sessions             session.Service
 86	messages             message.Service
 87	disableAutoSummarize bool
 88	isYolo               bool
 89	isSubAgent           bool
 90	hooksManager         hooks.Manager
 91	workingDir           string
 92
 93	messageQueue   *csync.Map[string, []SessionAgentCall]
 94	activeRequests *csync.Map[string, context.CancelFunc]
 95}
 96
 97type SessionAgentOptions struct {
 98	LargeModel           Model
 99	SmallModel           Model
100	SystemPromptPrefix   string
101	SystemPrompt         string
102	DisableAutoSummarize bool
103	IsYolo               bool
104	IsSubAgent           bool
105	HooksManager         hooks.Manager
106	WorkingDir           string
107	Sessions             session.Service
108	Messages             message.Service
109	Tools                []fantasy.AgentTool
110}
111
112func NewSessionAgent(
113	opts SessionAgentOptions,
114) SessionAgent {
115	return &sessionAgent{
116		largeModel:           opts.LargeModel,
117		smallModel:           opts.SmallModel,
118		systemPromptPrefix:   opts.SystemPromptPrefix,
119		systemPrompt:         opts.SystemPrompt,
120		sessions:             opts.Sessions,
121		messages:             opts.Messages,
122		disableAutoSummarize: opts.DisableAutoSummarize,
123		tools:                opts.Tools,
124		isYolo:               opts.IsYolo,
125		isSubAgent:           opts.IsSubAgent,
126		hooksManager:         opts.HooksManager,
127		workingDir:           opts.WorkingDir,
128		messageQueue:         csync.NewMap[string, []SessionAgentCall](),
129		activeRequests:       csync.NewMap[string, context.CancelFunc](),
130	}
131}
132
133func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
134	if call.Prompt == "" {
135		return nil, ErrEmptyPrompt
136	}
137	if call.SessionID == "" {
138		return nil, ErrSessionMissing
139	}
140
141	// Queue the message if busy
142	if a.IsSessionBusy(call.SessionID) {
143		existing, ok := a.messageQueue.Get(call.SessionID)
144		if !ok {
145			existing = []SessionAgentCall{}
146		}
147		existing = append(existing, call)
148		a.messageQueue.Set(call.SessionID, existing)
149		return nil, nil
150	}
151
152	if len(a.tools) > 0 {
153		// Add Anthropic caching to the last tool.
154		a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
155	}
156
157	agent := fantasy.NewAgent(
158		a.largeModel.Model,
159		fantasy.WithSystemPrompt(a.systemPrompt),
160		fantasy.WithTools(a.tools...),
161	)
162
163	sessionLock := sync.Mutex{}
164	currentSession, err := a.sessions.Get(ctx, call.SessionID)
165	if err != nil {
166		return nil, fmt.Errorf("failed to get session: %w", err)
167	}
168
169	msgs, err := a.getSessionMessages(ctx, currentSession)
170	if err != nil {
171		return nil, fmt.Errorf("failed to get session messages: %w", err)
172	}
173
174	var wg sync.WaitGroup
175	// Generate title if first message.
176	if len(msgs) == 0 {
177		wg.Go(func() {
178			sessionLock.Lock()
179			a.generateTitle(ctx, &currentSession, call.Prompt)
180			sessionLock.Unlock()
181		})
182	}
183
184	// Add the user message to the session.
185	msg, err := a.createUserMessage(ctx, call)
186	if err != nil {
187		return nil, err
188	}
189
190	// Add the session to the context.
191	ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)
192
193	genCtx, cancel := context.WithCancel(ctx)
194	a.activeRequests.Set(call.SessionID, cancel)
195
196	defer cancel()
197	defer a.activeRequests.Del(call.SessionID)
198
199	// create the agent message asap to show loading
200	var currentAssistant *message.Message
201	assistantMessage, err := a.messages.Create(genCtx, call.SessionID, message.CreateMessageParams{
202		Role:     message.Assistant,
203		Parts:    []message.ContentPart{},
204		Model:    a.largeModel.ModelCfg.Model,
205		Provider: a.largeModel.ModelCfg.Provider,
206	})
207	if err != nil {
208		return nil, err
209	}
210
211	currentAssistant = &assistantMessage
212
213	hookErr := a.executePromptSubmitHook(genCtx, &msg, len(msgs) == 0)
214	if hookErr != nil {
215		// Delete the assistant message
216		// use the ctx since this could be a cancellation
217		deleteErr := a.messages.Delete(ctx, currentAssistant.ID)
218		return nil, cmp.Or(deleteErr, hookErr)
219	}
220
221	history, files := a.preparePrompt(msgs, call.Attachments...)
222
223	startTime := time.Now()
224	a.eventPromptSent(call.SessionID)
225
226	var shouldSummarize bool
227	result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
228		Prompt:           msg.ContentWithHookContext(),
229		Files:            files,
230		Messages:         history,
231		ProviderOptions:  call.ProviderOptions,
232		MaxOutputTokens:  &call.MaxOutputTokens,
233		TopP:             call.TopP,
234		Temperature:      call.Temperature,
235		PresencePenalty:  call.PresencePenalty,
236		TopK:             call.TopK,
237		FrequencyPenalty: call.FrequencyPenalty,
238		// Before each step create a new assistant message.
239		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
240			// only add new assistant message when its not the first step
241			if options.StepNumber != 0 {
242				var assistantMsg message.Message
243				assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
244					Role:     message.Assistant,
245					Model:    a.largeModel.ModelCfg.Model,
246					Provider: a.largeModel.ModelCfg.Provider,
247				})
248				currentAssistant = &assistantMsg
249				// create the message first so we show loading asap
250				if err != nil {
251					return callContext, prepared, err
252				}
253			}
254			prepared.Messages = options.Messages
255			// Reset all cached items.
256			for i := range prepared.Messages {
257				prepared.Messages[i].ProviderOptions = nil
258			}
259
260			queuedCalls, _ := a.messageQueue.Get(call.SessionID)
261			a.messageQueue.Del(call.SessionID)
262			for _, queued := range queuedCalls {
263				userMessage, createErr := a.createUserMessage(callContext, queued)
264				if createErr != nil {
265					return callContext, prepared, createErr
266				}
267
268				hookErr := a.executePromptSubmitHook(ctx, &msg, len(msgs) == 0)
269				if hookErr != nil {
270					return callContext, prepared, hookErr
271				}
272
273				prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
274			}
275
276			lastSystemRoleInx := 0
277			systemMessageUpdated := false
278			for i, msg := range prepared.Messages {
279				// Only add cache control to the last message.
280				if msg.Role == fantasy.MessageRoleSystem {
281					lastSystemRoleInx = i
282				} else if !systemMessageUpdated {
283					prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
284					systemMessageUpdated = true
285				}
286				// Than add cache control to the last 2 messages.
287				if i > len(prepared.Messages)-3 {
288					prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
289				}
290			}
291
292			if a.systemPromptPrefix != "" {
293				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
294			}
295
296			callContext = context.WithValue(callContext, tools.MessageIDContextKey, currentAssistant.ID)
297			return callContext, prepared, err
298		},
299		OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
300			currentAssistant.AppendReasoningContent(reasoning.Text)
301			return a.messages.Update(genCtx, *currentAssistant)
302		},
303		OnReasoningDelta: func(id string, text string) error {
304			currentAssistant.AppendReasoningContent(text)
305			return a.messages.Update(genCtx, *currentAssistant)
306		},
307		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
308			// handle anthropic signature
309			if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
310				if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
311					currentAssistant.AppendReasoningSignature(reasoning.Signature)
312				}
313			}
314			if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
315				if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
316					currentAssistant.AppendThoughtSignature(reasoning.Signature, reasoning.ToolID)
317				}
318			}
319			if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
320				if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
321					currentAssistant.SetReasoningResponsesData(reasoning)
322				}
323			}
324			currentAssistant.FinishThinking()
325			return a.messages.Update(genCtx, *currentAssistant)
326		},
327		OnTextDelta: func(id string, text string) error {
328			// Strip leading newline from initial text content. This is is
329			// particularly important in non-interactive mode where leading
330			// newlines are very visible.
331			if len(currentAssistant.Parts) == 0 {
332				text = strings.TrimPrefix(text, "\n")
333			}
334
335			currentAssistant.AppendContent(text)
336			return a.messages.Update(genCtx, *currentAssistant)
337		},
338		OnToolInputStart: func(id string, toolName string) error {
339			toolCall := message.ToolCall{
340				ID:               id,
341				Name:             toolName,
342				ProviderExecuted: false,
343				Finished:         false,
344			}
345			currentAssistant.AddToolCall(toolCall)
346			return a.messages.Update(genCtx, *currentAssistant)
347		},
348		OnRetry: func(err *fantasy.ProviderError, delay time.Duration) {
349			// TODO: implement
350		},
351		OnToolCall: func(tc fantasy.ToolCallContent) error {
352			toolCall := message.ToolCall{
353				ID:               tc.ToolCallID,
354				Name:             tc.ToolName,
355				Input:            tc.Input,
356				ProviderExecuted: false,
357				Finished:         true,
358			}
359			currentAssistant.AddToolCall(toolCall)
360			return a.messages.Update(genCtx, *currentAssistant)
361		},
362		OnToolResult: func(result fantasy.ToolResultContent) error {
363			var resultContent string
364			isError := false
365			switch result.Result.GetType() {
366			case fantasy.ToolResultContentTypeText:
367				r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result)
368				if ok {
369					resultContent = r.Text
370				}
371			case fantasy.ToolResultContentTypeError:
372				r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result)
373				if ok {
374					isError = true
375					resultContent = r.Error.Error()
376				}
377			case fantasy.ToolResultContentTypeMedia:
378				// TODO: handle this message type
379			}
380			toolResult := message.ToolResult{
381				ToolCallID: result.ToolCallID,
382				Name:       result.ToolName,
383				Content:    resultContent,
384				IsError:    isError,
385				Metadata:   result.ClientMetadata,
386			}
387			_, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
388				Role: message.Tool,
389				Parts: []message.ContentPart{
390					toolResult,
391				},
392			})
393			if createMsgErr != nil {
394				return createMsgErr
395			}
396			return nil
397		},
398		OnStepFinish: func(stepResult fantasy.StepResult) error {
399			finishReason := message.FinishReasonUnknown
400			switch stepResult.FinishReason {
401			case fantasy.FinishReasonLength:
402				finishReason = message.FinishReasonMaxTokens
403			case fantasy.FinishReasonStop:
404				finishReason = message.FinishReasonEndTurn
405			case fantasy.FinishReasonToolCalls:
406				finishReason = message.FinishReasonToolUse
407			}
408			currentAssistant.AddFinish(finishReason, "", "")
409			a.updateSessionUsage(a.largeModel, &currentSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
410			sessionLock.Lock()
411			_, sessionErr := a.sessions.Save(genCtx, currentSession)
412			sessionLock.Unlock()
413			if sessionErr != nil {
414				return sessionErr
415			}
416			return a.messages.Update(genCtx, *currentAssistant)
417		},
418		StopWhen: []fantasy.StopCondition{
419			func(_ []fantasy.StepResult) bool {
420				cw := int64(a.largeModel.CatwalkCfg.ContextWindow)
421				tokens := currentSession.CompletionTokens + currentSession.PromptTokens
422				remaining := cw - tokens
423				var threshold int64
424				if cw > 200_000 {
425					threshold = 20_000
426				} else {
427					threshold = int64(float64(cw) * 0.2)
428				}
429				if (remaining <= threshold) && !a.disableAutoSummarize {
430					shouldSummarize = true
431					return true
432				}
433				return false
434			},
435		},
436	})
437
438	a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))
439
440	if err != nil {
441		isCancelErr := errors.Is(err, context.Canceled)
442		isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
443		if currentAssistant == nil {
444			return result, err
445		}
446		// Ensure we finish thinking on error to close the reasoning state.
447		currentAssistant.FinishThinking()
448		toolCalls := currentAssistant.ToolCalls()
449		// INFO: we use the parent context here because the genCtx has been cancelled.
450		msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
451		if createErr != nil {
452			return nil, createErr
453		}
454		for _, tc := range toolCalls {
455			if !tc.Finished {
456				tc.Finished = true
457				tc.Input = "{}"
458				currentAssistant.AddToolCall(tc)
459				updateErr := a.messages.Update(ctx, *currentAssistant)
460				if updateErr != nil {
461					return nil, updateErr
462				}
463			}
464
465			found := false
466			for _, msg := range msgs {
467				if msg.Role == message.Tool {
468					for _, tr := range msg.ToolResults() {
469						if tr.ToolCallID == tc.ID {
470							found = true
471							break
472						}
473					}
474				}
475				if found {
476					break
477				}
478			}
479			if found {
480				continue
481			}
482			content := "There was an error while executing the tool"
483			if isCancelErr {
484				content = "Tool execution canceled by user"
485			} else if isPermissionErr {
486				content = "User denied permission"
487			}
488			toolResult := message.ToolResult{
489				ToolCallID: tc.ID,
490				Name:       tc.Name,
491				Content:    content,
492				IsError:    true,
493			}
494			_, createErr = a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
495				Role: message.Tool,
496				Parts: []message.ContentPart{
497					toolResult,
498				},
499			})
500			if createErr != nil {
501				return nil, createErr
502			}
503		}
504		var fantasyErr *fantasy.Error
505		var providerErr *fantasy.ProviderError
506		const defaultTitle = "Provider Error"
507		if isCancelErr {
508			currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
509		} else if isPermissionErr {
510			currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "")
511		} else if errors.As(err, &providerErr) {
512			currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message)
513		} else if errors.As(err, &fantasyErr) {
514			currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(fantasyErr.Title), defaultTitle), fantasyErr.Message)
515		} else {
516			currentAssistant.AddFinish(message.FinishReasonError, defaultTitle, err.Error())
517		}
518		// Note: we use the parent context here because the genCtx has been
519		// cancelled.
520		updateErr := a.messages.Update(ctx, *currentAssistant)
521		if updateErr != nil {
522			return nil, updateErr
523		}
524		return nil, err
525	}
526	wg.Wait()
527
528	if shouldSummarize {
529		a.activeRequests.Del(call.SessionID)
530		if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
531			return nil, summarizeErr
532		}
533		// If the agent wasn't done...
534		if len(currentAssistant.ToolCalls()) > 0 {
535			existing, ok := a.messageQueue.Get(call.SessionID)
536			if !ok {
537				existing = []SessionAgentCall{}
538			}
539			call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
540			existing = append(existing, call)
541			a.messageQueue.Set(call.SessionID, existing)
542		}
543	}
544
545	// Release active request before processing queued messages.
546	a.activeRequests.Del(call.SessionID)
547	cancel()
548
549	queuedMessages, ok := a.messageQueue.Get(call.SessionID)
550	if !ok || len(queuedMessages) == 0 {
551		return result, err
552	}
553	// There are queued messages restart the loop.
554	firstQueuedMessage := queuedMessages[0]
555	a.messageQueue.Set(call.SessionID, queuedMessages[1:])
556	return a.Run(ctx, firstQueuedMessage)
557}
558
559func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
560	if a.IsSessionBusy(sessionID) {
561		return ErrSessionBusy
562	}
563
564	currentSession, err := a.sessions.Get(ctx, sessionID)
565	if err != nil {
566		return fmt.Errorf("failed to get session: %w", err)
567	}
568	msgs, err := a.getSessionMessages(ctx, currentSession)
569	if err != nil {
570		return err
571	}
572	if len(msgs) == 0 {
573		// Nothing to summarize.
574		return nil
575	}
576
577	aiMsgs, _ := a.preparePrompt(msgs)
578
579	genCtx, cancel := context.WithCancel(ctx)
580	a.activeRequests.Set(sessionID, cancel)
581	defer a.activeRequests.Del(sessionID)
582	defer cancel()
583
584	agent := fantasy.NewAgent(a.largeModel.Model,
585		fantasy.WithSystemPrompt(string(summaryPrompt)),
586	)
587	summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
588		Role:             message.Assistant,
589		Model:            a.largeModel.Model.Model(),
590		Provider:         a.largeModel.Model.Provider(),
591		IsSummaryMessage: true,
592	})
593	if err != nil {
594		return err
595	}
596
597	resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
598		Prompt:          "Provide a detailed summary of our conversation above.",
599		Messages:        aiMsgs,
600		ProviderOptions: opts,
601		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
602			prepared.Messages = options.Messages
603			if a.systemPromptPrefix != "" {
604				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
605			}
606			return callContext, prepared, nil
607		},
608		OnReasoningDelta: func(id string, text string) error {
609			summaryMessage.AppendReasoningContent(text)
610			return a.messages.Update(genCtx, summaryMessage)
611		},
612		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
613			// Handle anthropic signature.
614			if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
615				if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
616					summaryMessage.AppendReasoningSignature(signature.Signature)
617				}
618			}
619			summaryMessage.FinishThinking()
620			return a.messages.Update(genCtx, summaryMessage)
621		},
622		OnTextDelta: func(id, text string) error {
623			summaryMessage.AppendContent(text)
624			return a.messages.Update(genCtx, summaryMessage)
625		},
626	})
627	if err != nil {
628		isCancelErr := errors.Is(err, context.Canceled)
629		if isCancelErr {
630			// User cancelled summarize we need to remove the summary message.
631			deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
632			return deleteErr
633		}
634		return err
635	}
636
637	summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
638	err = a.messages.Update(genCtx, summaryMessage)
639	if err != nil {
640		return err
641	}
642
643	var openrouterCost *float64
644	for _, step := range resp.Steps {
645		stepCost := a.openrouterCost(step.ProviderMetadata)
646		if stepCost != nil {
647			newCost := *stepCost
648			if openrouterCost != nil {
649				newCost += *openrouterCost
650			}
651			openrouterCost = &newCost
652		}
653	}
654
655	a.updateSessionUsage(a.largeModel, &currentSession, resp.TotalUsage, openrouterCost)
656
657	// Just in case, get just the last usage info.
658	usage := resp.Response.Usage
659	currentSession.SummaryMessageID = summaryMessage.ID
660	currentSession.CompletionTokens = usage.OutputTokens
661	currentSession.PromptTokens = 0
662	_, err = a.sessions.Save(genCtx, currentSession)
663	return err
664}
665
666func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
667	if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
668		return fantasy.ProviderOptions{}
669	}
670	return fantasy.ProviderOptions{
671		anthropic.Name: &anthropic.ProviderCacheControlOptions{
672			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
673		},
674		bedrock.Name: &anthropic.ProviderCacheControlOptions{
675			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
676		},
677	}
678}
679
680func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
681	var attachmentParts []message.ContentPart
682	for _, attachment := range call.Attachments {
683		attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
684	}
685	parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
686	parts = append(parts, attachmentParts...)
687	msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
688		Role:  message.User,
689		Parts: parts,
690	})
691	if err != nil {
692		return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
693	}
694	return msg, nil
695}
696
697func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
698	var history []fantasy.Message
699	for _, m := range msgs {
700		if len(m.Parts) == 0 {
701			continue
702		}
703		// Assistant message without content or tool calls (cancelled before it
704		// returned anything).
705		if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
706			continue
707		}
708		history = append(history, m.ToAIMessage()...)
709	}
710
711	var files []fantasy.FilePart
712	for _, attachment := range attachments {
713		files = append(files, fantasy.FilePart{
714			Filename:  attachment.FileName,
715			Data:      attachment.Content,
716			MediaType: attachment.MimeType,
717		})
718	}
719
720	return history, files
721}
722
723func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
724	msgs, err := a.messages.List(ctx, session.ID)
725	if err != nil {
726		return nil, fmt.Errorf("failed to list messages: %w", err)
727	}
728
729	if session.SummaryMessageID != "" {
730		summaryMsgInex := -1
731		for i, msg := range msgs {
732			if msg.ID == session.SummaryMessageID {
733				summaryMsgInex = i
734				break
735			}
736		}
737		if summaryMsgInex != -1 {
738			msgs = msgs[summaryMsgInex:]
739			msgs[0].Role = message.User
740		}
741	}
742	return msgs, nil
743}
744
745func (a *sessionAgent) generateTitle(ctx context.Context, session *session.Session, prompt string) {
746	if prompt == "" {
747		return
748	}
749
750	var maxOutput int64 = 40
751	if a.smallModel.CatwalkCfg.CanReason {
752		maxOutput = a.smallModel.CatwalkCfg.DefaultMaxTokens
753	}
754
755	agent := fantasy.NewAgent(a.smallModel.Model,
756		fantasy.WithSystemPrompt(string(titlePrompt)+"\n /no_think"),
757		fantasy.WithMaxOutputTokens(maxOutput),
758	)
759
760	resp, err := agent.Stream(ctx, fantasy.AgentStreamCall{
761		Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", prompt),
762		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
763			prepared.Messages = options.Messages
764			if a.systemPromptPrefix != "" {
765				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
766			}
767			return callContext, prepared, nil
768		},
769	})
770	if err != nil {
771		slog.Error("error generating title", "err", err)
772		return
773	}
774
775	title := resp.Response.Content.Text()
776
777	title = strings.ReplaceAll(title, "\n", " ")
778
779	// Remove thinking tags if present.
780	if idx := strings.Index(title, "</think>"); idx > 0 {
781		title = title[idx+len("</think>"):]
782	}
783
784	title = strings.TrimSpace(title)
785	if title == "" {
786		slog.Warn("failed to generate title", "warn", "empty title")
787		return
788	}
789
790	session.Title = title
791
792	var openrouterCost *float64
793	for _, step := range resp.Steps {
794		stepCost := a.openrouterCost(step.ProviderMetadata)
795		if stepCost != nil {
796			newCost := *stepCost
797			if openrouterCost != nil {
798				newCost += *openrouterCost
799			}
800			openrouterCost = &newCost
801		}
802	}
803
804	a.updateSessionUsage(a.smallModel, session, resp.TotalUsage, openrouterCost)
805	_, saveErr := a.sessions.Save(ctx, *session)
806	if saveErr != nil {
807		slog.Error("failed to save session title & usage", "error", saveErr)
808		return
809	}
810}
811
812func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
813	openrouterMetadata, ok := metadata[openrouter.Name]
814	if !ok {
815		return nil
816	}
817
818	opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
819	if !ok {
820		return nil
821	}
822	return &opts.Usage.Cost
823}
824
825func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
826	modelConfig := model.CatwalkCfg
827	cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
828		modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
829		modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
830		modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
831
832	a.eventTokensUsed(session.ID, model, usage, cost)
833
834	if overrideCost != nil {
835		session.Cost += *overrideCost
836	} else {
837		session.Cost += cost
838	}
839
840	session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
841	session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
842}
843
844func (a *sessionAgent) Cancel(sessionID string) {
845	// Cancel regular requests.
846	if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
847		slog.Info("Request cancellation initiated", "session_id", sessionID)
848		cancel()
849	}
850
851	// Also check for summarize requests.
852	if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
853		slog.Info("Summarize cancellation initiated", "session_id", sessionID)
854		cancel()
855	}
856
857	if a.QueuedPrompts(sessionID) > 0 {
858		slog.Info("Clearing queued prompts", "session_id", sessionID)
859		a.messageQueue.Del(sessionID)
860	}
861}
862
863func (a *sessionAgent) ClearQueue(sessionID string) {
864	if a.QueuedPrompts(sessionID) > 0 {
865		slog.Info("Clearing queued prompts", "session_id", sessionID)
866		a.messageQueue.Del(sessionID)
867	}
868}
869
870func (a *sessionAgent) CancelAll() {
871	if !a.IsBusy() {
872		return
873	}
874	for key := range a.activeRequests.Seq2() {
875		a.Cancel(key) // key is sessionID
876	}
877
878	timeout := time.After(5 * time.Second)
879	for a.IsBusy() {
880		select {
881		case <-timeout:
882			return
883		default:
884			time.Sleep(200 * time.Millisecond)
885		}
886	}
887}
888
889func (a *sessionAgent) IsBusy() bool {
890	var busy bool
891	for cancelFunc := range a.activeRequests.Seq() {
892		if cancelFunc != nil {
893			busy = true
894			break
895		}
896	}
897	return busy
898}
899
900func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
901	_, busy := a.activeRequests.Get(sessionID)
902	return busy
903}
904
905func (a *sessionAgent) QueuedPrompts(sessionID string) int {
906	l, ok := a.messageQueue.Get(sessionID)
907	if !ok {
908		return 0
909	}
910	return len(l)
911}
912
913func (a *sessionAgent) SetModels(large Model, small Model) {
914	a.largeModel = large
915	a.smallModel = small
916}
917
// SetTools replaces the tools offered to the model. The slice is stored as
// given, not copied.
func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
	a.tools = tools
}
921
// Model returns the agent's current large model.
func (a *sessionAgent) Model() Model {
	return a.largeModel
}
925
926// executePromptSubmitHook executes the user-prompt-submit hook and applies modifications to the call.
927// Only runs for main agent (not sub-agents).
928func (a *sessionAgent) executePromptSubmitHook(ctx context.Context, msg *message.Message, isFirstMessage bool) error {
929	// Skip if sub-agent or no hooks manager.
930	if a.isSubAgent || a.hooksManager == nil {
931		return nil
932	}
933
934	// Convert attachments to file paths.
935	attachmentPaths := make([]string, len(msg.BinaryContent()))
936	for i, att := range msg.BinaryContent() {
937		attachmentPaths[i] = att.Path
938	}
939
940	hookResult, err := a.hooksManager.ExecuteUserPromptSubmit(ctx, msg.SessionID, a.workingDir, hooks.UserPromptSubmitData{
941		Prompt:         msg.Content().Text,
942		Attachments:    attachmentPaths,
943		Model:          a.largeModel.CatwalkCfg.ID,
944		Provider:       a.largeModel.Model.Provider(),
945		IsFirstMessage: isFirstMessage,
946	})
947	if err != nil {
948		return fmt.Errorf("hook execution failed: %w", err)
949	}
950
951	// Apply hook modifications to the prompt.
952	if hookResult.ModifiedPrompt != nil {
953		for i, part := range msg.Parts {
954			if _, ok := part.(message.TextContent); ok {
955				msg.Parts[i] = message.TextContent{Text: *hookResult.ModifiedPrompt}
956			}
957		}
958	}
959	msg.AddHookResult(hookResult)
960	err = a.messages.Update(ctx, *msg)
961	if err != nil {
962		return err
963	}
964	// If hook returned Continue: false, stop execution.
965	if !hookResult.Continue {
966		return ErrHookExecutionStop
967	}
968	return nil
969}