agent.go

  1// Package agent is the core orchestration layer for Crush AI agents.
  2//
  3// It provides session-based AI agent functionality for managing
  4// conversations, tool execution, and message handling. It coordinates
  5// interactions between language models, messages, sessions, and tools while
  6// handling features like automatic summarization, queuing, and token
  7// management.
  8package agent
  9
 10import (
 11	"cmp"
 12	"context"
 13	_ "embed"
 14	"errors"
 15	"fmt"
 16	"log/slog"
 17	"os"
 18	"strconv"
 19	"strings"
 20	"sync"
 21	"time"
 22
 23	"charm.land/fantasy"
 24	"charm.land/fantasy/providers/anthropic"
 25	"charm.land/fantasy/providers/bedrock"
 26	"charm.land/fantasy/providers/google"
 27	"charm.land/fantasy/providers/openai"
 28	"charm.land/fantasy/providers/openrouter"
 29	"github.com/charmbracelet/catwalk/pkg/catwalk"
 30	"github.com/charmbracelet/crush/internal/agent/tools"
 31	"github.com/charmbracelet/crush/internal/config"
 32	"github.com/charmbracelet/crush/internal/csync"
 33	"github.com/charmbracelet/crush/internal/hooks"
 34	"github.com/charmbracelet/crush/internal/message"
 35	"github.com/charmbracelet/crush/internal/permission"
 36	"github.com/charmbracelet/crush/internal/session"
 37	"github.com/charmbracelet/crush/internal/stringext"
 38)
 39
// titlePrompt is the system prompt generateTitle sends to the small model,
// embedded from templates/title.md at build time.
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt is the system prompt Summarize sends to the large model,
// embedded from templates/summary.md at build time.
//
//go:embed templates/summary.md
var summaryPrompt []byte
 45
// SessionAgentCall is a single prompt submitted to a session via
// SessionAgent.Run.
//
// SessionID and Prompt are required (Run returns ErrSessionMissing /
// ErrEmptyPrompt when they are empty). Attachments are stored on the user
// message and also sent to the model as files. ProviderOptions and the
// sampling fields (MaxOutputTokens, Temperature, TopP, TopK, FrequencyPenalty,
// PresencePenalty) are passed through to the model request unchanged.
// NOTE(review): MaxOutputTokens is always passed as a pointer, so a zero
// value is sent as 0 rather than "unset" — confirm providers treat 0 as unset.
type SessionAgentCall struct {
	SessionID        string
	Prompt           string
	ProviderOptions  fantasy.ProviderOptions
	Attachments      []message.Attachment
	MaxOutputTokens  int64
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
}
 58
// SessionAgent runs prompts against a language model on behalf of chat
// sessions. Each session has at most one request in flight; prompts that
// arrive while a session is busy are queued.
type SessionAgent interface {
	// Run submits a prompt for a session. When the session already has a
	// request in flight the call is queued and (nil, nil) is returned.
	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
	// SetModels replaces the large (conversation) and small (title) models.
	SetModels(large Model, small Model)
	// SetTools replaces the tools offered to the model.
	SetTools(tools []fantasy.AgentTool)
	// Cancel cancels the session's in-flight request, if any, and drops its
	// queued prompts.
	Cancel(sessionID string)
	// CancelAll cancels every in-flight request and waits (up to a timeout)
	// for them to finish.
	CancelAll()
	// IsSessionBusy reports whether the session has a request in flight.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any session has a request in flight.
	IsBusy() bool
	// QueuedPrompts returns the number of prompts waiting for the session.
	QueuedPrompts(sessionID string) int
	// ClearQueue drops the session's queued prompts.
	ClearQueue(sessionID string)
	// Summarize condenses the session history into a summary message that
	// later requests start from.
	Summarize(context.Context, string, fantasy.ProviderOptions) error
	// Model returns the large model currently in use.
	Model() Model
}
 72
// Model bundles a language model with its configuration.
//
// Model is the client used for requests. CatwalkCfg is the model's catalog
// entry (the agent reads its context window, per-token costs, CanReason and
// DefaultMaxTokens from it). ModelCfg is the user's selected model, whose
// Model and Provider IDs are recorded on assistant messages.
type Model struct {
	Model      fantasy.LanguageModel
	CatwalkCfg catwalk.Model
	ModelCfg   config.SelectedModel
}
 78
// sessionAgent is the SessionAgent implementation returned by
// NewSessionAgent. Its configuration fields mirror SessionAgentOptions.
type sessionAgent struct {
	largeModel           Model
	smallModel           Model
	systemPromptPrefix   string
	systemPrompt         string
	tools                []fantasy.AgentTool
	sessions             session.Service
	messages             message.Service
	disableAutoSummarize bool
	isYolo               bool
	isSubAgent           bool
	hooksManager         hooks.Manager
	workingDir           string

	// messageQueue holds prompts submitted while a session was busy, keyed by
	// session ID. activeRequests holds the cancel func of each in-flight
	// request, keyed by session ID; an entry's presence marks the session busy.
	messageQueue   *csync.Map[string, []SessionAgentCall]
	activeRequests *csync.Map[string, context.CancelFunc]
}
 96
// SessionAgentOptions configures NewSessionAgent.
//
// LargeModel handles conversations and summaries; SmallModel generates
// session titles. SystemPrompt is the agent's system prompt, and
// SystemPromptPrefix, when non-empty, is prepended as an extra system message
// to every request. DisableAutoSummarize turns off automatic summarization
// near the context limit. IsSubAgent and a nil HooksManager both disable the
// user-prompt-submit hook; WorkingDir is passed to that hook.
type SessionAgentOptions struct {
	LargeModel           Model
	SmallModel           Model
	SystemPromptPrefix   string
	SystemPrompt         string
	DisableAutoSummarize bool
	IsYolo               bool
	IsSubAgent           bool
	HooksManager         hooks.Manager
	WorkingDir           string
	Sessions             session.Service
	Messages             message.Service
	Tools                []fantasy.AgentTool
}
111
112func NewSessionAgent(
113	opts SessionAgentOptions,
114) SessionAgent {
115	return &sessionAgent{
116		largeModel:           opts.LargeModel,
117		smallModel:           opts.SmallModel,
118		systemPromptPrefix:   opts.SystemPromptPrefix,
119		systemPrompt:         opts.SystemPrompt,
120		sessions:             opts.Sessions,
121		messages:             opts.Messages,
122		disableAutoSummarize: opts.DisableAutoSummarize,
123		tools:                opts.Tools,
124		isYolo:               opts.IsYolo,
125		isSubAgent:           opts.IsSubAgent,
126		hooksManager:         opts.HooksManager,
127		workingDir:           opts.WorkingDir,
128		messageQueue:         csync.NewMap[string, []SessionAgentCall](),
129		activeRequests:       csync.NewMap[string, context.CancelFunc](),
130	}
131}
132
133func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
134	if call.Prompt == "" {
135		return nil, ErrEmptyPrompt
136	}
137	if call.SessionID == "" {
138		return nil, ErrSessionMissing
139	}
140
141	// Queue the message if busy
142	if a.IsSessionBusy(call.SessionID) {
143		existing, ok := a.messageQueue.Get(call.SessionID)
144		if !ok {
145			existing = []SessionAgentCall{}
146		}
147		existing = append(existing, call)
148		a.messageQueue.Set(call.SessionID, existing)
149		return nil, nil
150	}
151
152	if len(a.tools) > 0 {
153		// Add Anthropic caching to the last tool.
154		a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
155	}
156
157	agent := fantasy.NewAgent(
158		a.largeModel.Model,
159		fantasy.WithSystemPrompt(a.systemPrompt),
160		fantasy.WithTools(a.tools...),
161	)
162
163	sessionLock := sync.Mutex{}
164	currentSession, err := a.sessions.Get(ctx, call.SessionID)
165	if err != nil {
166		return nil, fmt.Errorf("failed to get session: %w", err)
167	}
168
169	msgs, err := a.getSessionMessages(ctx, currentSession)
170	if err != nil {
171		return nil, fmt.Errorf("failed to get session messages: %w", err)
172	}
173
174	var wg sync.WaitGroup
175	// Generate title if first message.
176	if len(msgs) == 0 {
177		wg.Go(func() {
178			sessionLock.Lock()
179			a.generateTitle(ctx, &currentSession, call.Prompt)
180			sessionLock.Unlock()
181		})
182	}
183
184	// Add the user message to the session.
185	msg, err := a.createUserMessage(ctx, call)
186	if err != nil {
187		return nil, err
188	}
189
190	// Add the session to the context.
191	ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)
192
193	genCtx, cancel := context.WithCancel(ctx)
194	a.activeRequests.Set(call.SessionID, cancel)
195
196	defer cancel()
197	defer a.activeRequests.Del(call.SessionID)
198
199	// create the agent message asap to show loading
200	var currentAssistant *message.Message
201	assistantMessage, err := a.messages.Create(genCtx, call.SessionID, message.CreateMessageParams{
202		Role:     message.Assistant,
203		Parts:    []message.ContentPart{},
204		Model:    a.largeModel.ModelCfg.Model,
205		Provider: a.largeModel.ModelCfg.Provider,
206	})
207	if err != nil {
208		return nil, err
209	}
210
211	currentAssistant = &assistantMessage
212
213	hookErr := a.executePromptSubmitHook(genCtx, &msg, len(msgs) == 0)
214	if hookErr != nil {
215		// Delete the assistant message
216		// use the ctx since this could be a cancellation
217		deleteErr := a.messages.Delete(ctx, currentAssistant.ID)
218		return nil, cmp.Or(deleteErr, hookErr)
219	}
220
221	history, files := a.preparePrompt(msgs, call.Attachments...)
222
223	startTime := time.Now()
224	a.eventPromptSent(call.SessionID)
225
226	var shouldSummarize bool
227	result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
228		Prompt:           msg.ContentWithHookContext(),
229		Files:            files,
230		Messages:         history,
231		ProviderOptions:  call.ProviderOptions,
232		MaxOutputTokens:  &call.MaxOutputTokens,
233		TopP:             call.TopP,
234		Temperature:      call.Temperature,
235		PresencePenalty:  call.PresencePenalty,
236		TopK:             call.TopK,
237		FrequencyPenalty: call.FrequencyPenalty,
238		// Before each step create a new assistant message.
239		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
240			// only add new assistant message when its not the first step
241			if options.StepNumber != 0 {
242				var assistantMsg message.Message
243				assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
244					Role:     message.Assistant,
245					Parts:    []message.ContentPart{},
246					Model:    a.largeModel.ModelCfg.Model,
247					Provider: a.largeModel.ModelCfg.Provider,
248				})
249				currentAssistant = &assistantMsg
250				// create the message first so we show loading asap
251				if err != nil {
252					return callContext, prepared, err
253				}
254			}
255			prepared.Messages = options.Messages
256			// Reset all cached items.
257			for i := range prepared.Messages {
258				prepared.Messages[i].ProviderOptions = nil
259			}
260
261			queuedCalls, _ := a.messageQueue.Get(call.SessionID)
262			a.messageQueue.Del(call.SessionID)
263			for _, queued := range queuedCalls {
264				userMessage, createErr := a.createUserMessage(callContext, queued)
265				if createErr != nil {
266					return callContext, prepared, createErr
267				}
268
269				hookErr := a.executePromptSubmitHook(ctx, &msg, len(msgs) == 0)
270				if hookErr != nil {
271					return callContext, prepared, hookErr
272				}
273
274				prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
275			}
276
277			lastSystemRoleInx := 0
278			systemMessageUpdated := false
279			for i, msg := range prepared.Messages {
280				// Only add cache control to the last message.
281				if msg.Role == fantasy.MessageRoleSystem {
282					lastSystemRoleInx = i
283				} else if !systemMessageUpdated {
284					prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
285					systemMessageUpdated = true
286				}
287				// Than add cache control to the last 2 messages.
288				if i > len(prepared.Messages)-3 {
289					prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
290				}
291			}
292
293			if a.systemPromptPrefix != "" {
294				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
295			}
296
297			callContext = context.WithValue(callContext, tools.MessageIDContextKey, currentAssistant.ID)
298			return callContext, prepared, err
299		},
300		OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
301			currentAssistant.AppendReasoningContent(reasoning.Text)
302			return a.messages.Update(genCtx, *currentAssistant)
303		},
304		OnReasoningDelta: func(id string, text string) error {
305			currentAssistant.AppendReasoningContent(text)
306			return a.messages.Update(genCtx, *currentAssistant)
307		},
308		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
309			// handle anthropic signature
310			if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
311				if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
312					currentAssistant.AppendReasoningSignature(reasoning.Signature)
313				}
314			}
315			if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
316				if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
317					currentAssistant.AppendThoughtSignature(reasoning.Signature, reasoning.ToolID)
318				}
319			}
320			if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
321				if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
322					currentAssistant.SetReasoningResponsesData(reasoning)
323				}
324			}
325			currentAssistant.FinishThinking()
326			return a.messages.Update(genCtx, *currentAssistant)
327		},
328		OnTextDelta: func(id string, text string) error {
329			// Strip leading newline from initial text content. This is is
330			// particularly important in non-interactive mode where leading
331			// newlines are very visible.
332			if len(currentAssistant.Parts) == 0 {
333				text = strings.TrimPrefix(text, "\n")
334			}
335
336			currentAssistant.AppendContent(text)
337			return a.messages.Update(genCtx, *currentAssistant)
338		},
339		OnToolInputStart: func(id string, toolName string) error {
340			toolCall := message.ToolCall{
341				ID:               id,
342				Name:             toolName,
343				ProviderExecuted: false,
344				Finished:         false,
345			}
346			currentAssistant.AddToolCall(toolCall)
347			return a.messages.Update(genCtx, *currentAssistant)
348		},
349		OnRetry: func(err *fantasy.ProviderError, delay time.Duration) {
350			// TODO: implement
351		},
352		OnToolCall: func(tc fantasy.ToolCallContent) error {
353			toolCall := message.ToolCall{
354				ID:               tc.ToolCallID,
355				Name:             tc.ToolName,
356				Input:            tc.Input,
357				ProviderExecuted: false,
358				Finished:         true,
359			}
360			currentAssistant.AddToolCall(toolCall)
361			return a.messages.Update(genCtx, *currentAssistant)
362		},
363		OnToolResult: func(result fantasy.ToolResultContent) error {
364			var resultContent string
365			isError := false
366			switch result.Result.GetType() {
367			case fantasy.ToolResultContentTypeText:
368				r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result)
369				if ok {
370					resultContent = r.Text
371				}
372			case fantasy.ToolResultContentTypeError:
373				r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result)
374				if ok {
375					isError = true
376					resultContent = r.Error.Error()
377				}
378			case fantasy.ToolResultContentTypeMedia:
379				// TODO: handle this message type
380			}
381			toolResult := message.ToolResult{
382				ToolCallID: result.ToolCallID,
383				Name:       result.ToolName,
384				Content:    resultContent,
385				IsError:    isError,
386				Metadata:   result.ClientMetadata,
387			}
388			_, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
389				Role: message.Tool,
390				Parts: []message.ContentPart{
391					toolResult,
392				},
393			})
394			if createMsgErr != nil {
395				return createMsgErr
396			}
397			return nil
398		},
399		OnStepFinish: func(stepResult fantasy.StepResult) error {
400			finishReason := message.FinishReasonUnknown
401			switch stepResult.FinishReason {
402			case fantasy.FinishReasonLength:
403				finishReason = message.FinishReasonMaxTokens
404			case fantasy.FinishReasonStop:
405				finishReason = message.FinishReasonEndTurn
406			case fantasy.FinishReasonToolCalls:
407				finishReason = message.FinishReasonToolUse
408			}
409			currentAssistant.AddFinish(finishReason, "", "")
410			a.updateSessionUsage(a.largeModel, &currentSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
411			sessionLock.Lock()
412			_, sessionErr := a.sessions.Save(genCtx, currentSession)
413			sessionLock.Unlock()
414			if sessionErr != nil {
415				return sessionErr
416			}
417			return a.messages.Update(genCtx, *currentAssistant)
418		},
419		StopWhen: []fantasy.StopCondition{
420			func(_ []fantasy.StepResult) bool {
421				cw := int64(a.largeModel.CatwalkCfg.ContextWindow)
422				tokens := currentSession.CompletionTokens + currentSession.PromptTokens
423				remaining := cw - tokens
424				var threshold int64
425				if cw > 200_000 {
426					threshold = 20_000
427				} else {
428					threshold = int64(float64(cw) * 0.2)
429				}
430				if (remaining <= threshold) && !a.disableAutoSummarize {
431					shouldSummarize = true
432					return true
433				}
434				return false
435			},
436		},
437	})
438
439	a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))
440
441	if err != nil {
442		isCancelErr := errors.Is(err, context.Canceled)
443		isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
444		if currentAssistant == nil {
445			return result, err
446		}
447		// Ensure we finish thinking on error to close the reasoning state.
448		currentAssistant.FinishThinking()
449		toolCalls := currentAssistant.ToolCalls()
450		// INFO: we use the parent context here because the genCtx has been cancelled.
451		msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
452		if createErr != nil {
453			return nil, createErr
454		}
455		for _, tc := range toolCalls {
456			if !tc.Finished {
457				tc.Finished = true
458				tc.Input = "{}"
459				currentAssistant.AddToolCall(tc)
460				updateErr := a.messages.Update(ctx, *currentAssistant)
461				if updateErr != nil {
462					return nil, updateErr
463				}
464			}
465
466			found := false
467			for _, msg := range msgs {
468				if msg.Role == message.Tool {
469					for _, tr := range msg.ToolResults() {
470						if tr.ToolCallID == tc.ID {
471							found = true
472							break
473						}
474					}
475				}
476				if found {
477					break
478				}
479			}
480			if found {
481				continue
482			}
483			content := "There was an error while executing the tool"
484			if isCancelErr {
485				content = "Tool execution canceled by user"
486			} else if isPermissionErr {
487				content = "User denied permission"
488			}
489			toolResult := message.ToolResult{
490				ToolCallID: tc.ID,
491				Name:       tc.Name,
492				Content:    content,
493				IsError:    true,
494			}
495			_, createErr = a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
496				Role: message.Tool,
497				Parts: []message.ContentPart{
498					toolResult,
499				},
500			})
501			if createErr != nil {
502				return nil, createErr
503			}
504		}
505		var fantasyErr *fantasy.Error
506		var providerErr *fantasy.ProviderError
507		const defaultTitle = "Provider Error"
508		if isCancelErr {
509			currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
510		} else if isPermissionErr {
511			currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "")
512		} else if errors.As(err, &providerErr) {
513			currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message)
514		} else if errors.As(err, &fantasyErr) {
515			currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(fantasyErr.Title), defaultTitle), fantasyErr.Message)
516		} else {
517			currentAssistant.AddFinish(message.FinishReasonError, defaultTitle, err.Error())
518		}
519		// Note: we use the parent context here because the genCtx has been
520		// cancelled.
521		updateErr := a.messages.Update(ctx, *currentAssistant)
522		if updateErr != nil {
523			return nil, updateErr
524		}
525		return nil, err
526	}
527	wg.Wait()
528
529	if shouldSummarize {
530		a.activeRequests.Del(call.SessionID)
531		if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
532			return nil, summarizeErr
533		}
534		// If the agent wasn't done...
535		if len(currentAssistant.ToolCalls()) > 0 {
536			existing, ok := a.messageQueue.Get(call.SessionID)
537			if !ok {
538				existing = []SessionAgentCall{}
539			}
540			call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
541			existing = append(existing, call)
542			a.messageQueue.Set(call.SessionID, existing)
543		}
544	}
545
546	// Release active request before processing queued messages.
547	a.activeRequests.Del(call.SessionID)
548	cancel()
549
550	queuedMessages, ok := a.messageQueue.Get(call.SessionID)
551	if !ok || len(queuedMessages) == 0 {
552		return result, err
553	}
554	// There are queued messages restart the loop.
555	firstQueuedMessage := queuedMessages[0]
556	a.messageQueue.Set(call.SessionID, queuedMessages[1:])
557	return a.Run(ctx, firstQueuedMessage)
558}
559
560func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
561	if a.IsSessionBusy(sessionID) {
562		return ErrSessionBusy
563	}
564
565	currentSession, err := a.sessions.Get(ctx, sessionID)
566	if err != nil {
567		return fmt.Errorf("failed to get session: %w", err)
568	}
569	msgs, err := a.getSessionMessages(ctx, currentSession)
570	if err != nil {
571		return err
572	}
573	if len(msgs) == 0 {
574		// Nothing to summarize.
575		return nil
576	}
577
578	aiMsgs, _ := a.preparePrompt(msgs)
579
580	genCtx, cancel := context.WithCancel(ctx)
581	a.activeRequests.Set(sessionID, cancel)
582	defer a.activeRequests.Del(sessionID)
583	defer cancel()
584
585	agent := fantasy.NewAgent(a.largeModel.Model,
586		fantasy.WithSystemPrompt(string(summaryPrompt)),
587	)
588	summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
589		Role:             message.Assistant,
590		Model:            a.largeModel.Model.Model(),
591		Provider:         a.largeModel.Model.Provider(),
592		IsSummaryMessage: true,
593	})
594	if err != nil {
595		return err
596	}
597
598	resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
599		Prompt:          "Provide a detailed summary of our conversation above.",
600		Messages:        aiMsgs,
601		ProviderOptions: opts,
602		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
603			prepared.Messages = options.Messages
604			if a.systemPromptPrefix != "" {
605				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
606			}
607			return callContext, prepared, nil
608		},
609		OnReasoningDelta: func(id string, text string) error {
610			summaryMessage.AppendReasoningContent(text)
611			return a.messages.Update(genCtx, summaryMessage)
612		},
613		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
614			// Handle anthropic signature.
615			if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
616				if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
617					summaryMessage.AppendReasoningSignature(signature.Signature)
618				}
619			}
620			summaryMessage.FinishThinking()
621			return a.messages.Update(genCtx, summaryMessage)
622		},
623		OnTextDelta: func(id, text string) error {
624			summaryMessage.AppendContent(text)
625			return a.messages.Update(genCtx, summaryMessage)
626		},
627	})
628	if err != nil {
629		isCancelErr := errors.Is(err, context.Canceled)
630		if isCancelErr {
631			// User cancelled summarize we need to remove the summary message.
632			deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
633			return deleteErr
634		}
635		return err
636	}
637
638	summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
639	err = a.messages.Update(genCtx, summaryMessage)
640	if err != nil {
641		return err
642	}
643
644	var openrouterCost *float64
645	for _, step := range resp.Steps {
646		stepCost := a.openrouterCost(step.ProviderMetadata)
647		if stepCost != nil {
648			newCost := *stepCost
649			if openrouterCost != nil {
650				newCost += *openrouterCost
651			}
652			openrouterCost = &newCost
653		}
654	}
655
656	a.updateSessionUsage(a.largeModel, &currentSession, resp.TotalUsage, openrouterCost)
657
658	// Just in case, get just the last usage info.
659	usage := resp.Response.Usage
660	currentSession.SummaryMessageID = summaryMessage.ID
661	currentSession.CompletionTokens = usage.OutputTokens
662	currentSession.PromptTokens = 0
663	_, err = a.sessions.Save(genCtx, currentSession)
664	return err
665}
666
667func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
668	if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
669		return fantasy.ProviderOptions{}
670	}
671	return fantasy.ProviderOptions{
672		anthropic.Name: &anthropic.ProviderCacheControlOptions{
673			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
674		},
675		bedrock.Name: &anthropic.ProviderCacheControlOptions{
676			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
677		},
678	}
679}
680
681func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
682	var attachmentParts []message.ContentPart
683	for _, attachment := range call.Attachments {
684		attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
685	}
686	parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
687	parts = append(parts, attachmentParts...)
688	msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
689		Role:  message.User,
690		Parts: parts,
691	})
692	if err != nil {
693		return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
694	}
695	return msg, nil
696}
697
698func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
699	var history []fantasy.Message
700	for _, m := range msgs {
701		if len(m.Parts) == 0 {
702			continue
703		}
704		// Assistant message without content or tool calls (cancelled before it
705		// returned anything).
706		if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
707			continue
708		}
709		history = append(history, m.ToAIMessage()...)
710	}
711
712	var files []fantasy.FilePart
713	for _, attachment := range attachments {
714		files = append(files, fantasy.FilePart{
715			Filename:  attachment.FileName,
716			Data:      attachment.Content,
717			MediaType: attachment.MimeType,
718		})
719	}
720
721	return history, files
722}
723
724func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
725	msgs, err := a.messages.List(ctx, session.ID)
726	if err != nil {
727		return nil, fmt.Errorf("failed to list messages: %w", err)
728	}
729
730	if session.SummaryMessageID != "" {
731		summaryMsgInex := -1
732		for i, msg := range msgs {
733			if msg.ID == session.SummaryMessageID {
734				summaryMsgInex = i
735				break
736			}
737		}
738		if summaryMsgInex != -1 {
739			msgs = msgs[summaryMsgInex:]
740			msgs[0].Role = message.User
741		}
742	}
743	return msgs, nil
744}
745
746func (a *sessionAgent) generateTitle(ctx context.Context, session *session.Session, prompt string) {
747	if prompt == "" {
748		return
749	}
750
751	var maxOutput int64 = 40
752	if a.smallModel.CatwalkCfg.CanReason {
753		maxOutput = a.smallModel.CatwalkCfg.DefaultMaxTokens
754	}
755
756	agent := fantasy.NewAgent(a.smallModel.Model,
757		fantasy.WithSystemPrompt(string(titlePrompt)+"\n /no_think"),
758		fantasy.WithMaxOutputTokens(maxOutput),
759	)
760
761	resp, err := agent.Stream(ctx, fantasy.AgentStreamCall{
762		Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", prompt),
763		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
764			prepared.Messages = options.Messages
765			if a.systemPromptPrefix != "" {
766				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
767			}
768			return callContext, prepared, nil
769		},
770	})
771	if err != nil {
772		slog.Error("error generating title", "err", err)
773		return
774	}
775
776	title := resp.Response.Content.Text()
777
778	title = strings.ReplaceAll(title, "\n", " ")
779
780	// Remove thinking tags if present.
781	if idx := strings.Index(title, "</think>"); idx > 0 {
782		title = title[idx+len("</think>"):]
783	}
784
785	title = strings.TrimSpace(title)
786	if title == "" {
787		slog.Warn("failed to generate title", "warn", "empty title")
788		return
789	}
790
791	session.Title = title
792
793	var openrouterCost *float64
794	for _, step := range resp.Steps {
795		stepCost := a.openrouterCost(step.ProviderMetadata)
796		if stepCost != nil {
797			newCost := *stepCost
798			if openrouterCost != nil {
799				newCost += *openrouterCost
800			}
801			openrouterCost = &newCost
802		}
803	}
804
805	a.updateSessionUsage(a.smallModel, session, resp.TotalUsage, openrouterCost)
806	_, saveErr := a.sessions.Save(ctx, *session)
807	if saveErr != nil {
808		slog.Error("failed to save session title & usage", "error", saveErr)
809		return
810	}
811}
812
813func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
814	openrouterMetadata, ok := metadata[openrouter.Name]
815	if !ok {
816		return nil
817	}
818
819	opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
820	if !ok {
821		return nil
822	}
823	return &opts.Usage.Cost
824}
825
826func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
827	modelConfig := model.CatwalkCfg
828	cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
829		modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
830		modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
831		modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
832
833	a.eventTokensUsed(session.ID, model, usage, cost)
834
835	if overrideCost != nil {
836		session.Cost += *overrideCost
837	} else {
838		session.Cost += cost
839	}
840
841	session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
842	session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
843}
844
845func (a *sessionAgent) Cancel(sessionID string) {
846	// Cancel regular requests.
847	if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
848		slog.Info("Request cancellation initiated", "session_id", sessionID)
849		cancel()
850	}
851
852	// Also check for summarize requests.
853	if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
854		slog.Info("Summarize cancellation initiated", "session_id", sessionID)
855		cancel()
856	}
857
858	if a.QueuedPrompts(sessionID) > 0 {
859		slog.Info("Clearing queued prompts", "session_id", sessionID)
860		a.messageQueue.Del(sessionID)
861	}
862}
863
864func (a *sessionAgent) ClearQueue(sessionID string) {
865	if a.QueuedPrompts(sessionID) > 0 {
866		slog.Info("Clearing queued prompts", "session_id", sessionID)
867		a.messageQueue.Del(sessionID)
868	}
869}
870
871func (a *sessionAgent) CancelAll() {
872	if !a.IsBusy() {
873		return
874	}
875	for key := range a.activeRequests.Seq2() {
876		a.Cancel(key) // key is sessionID
877	}
878
879	timeout := time.After(5 * time.Second)
880	for a.IsBusy() {
881		select {
882		case <-timeout:
883			return
884		default:
885			time.Sleep(200 * time.Millisecond)
886		}
887	}
888}
889
890func (a *sessionAgent) IsBusy() bool {
891	var busy bool
892	for cancelFunc := range a.activeRequests.Seq() {
893		if cancelFunc != nil {
894			busy = true
895			break
896		}
897	}
898	return busy
899}
900
901func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
902	_, busy := a.activeRequests.Get(sessionID)
903	return busy
904}
905
906func (a *sessionAgent) QueuedPrompts(sessionID string) int {
907	l, ok := a.messageQueue.Get(sessionID)
908	if !ok {
909		return 0
910	}
911	return len(l)
912}
913
914func (a *sessionAgent) SetModels(large Model, small Model) {
915	a.largeModel = large
916	a.smallModel = small
917}
918
919func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
920	a.tools = tools
921}
922
923func (a *sessionAgent) Model() Model {
924	return a.largeModel
925}
926
927// executePromptSubmitHook executes the user-prompt-submit hook and applies modifications to the call.
928// Only runs for main agent (not sub-agents).
929func (a *sessionAgent) executePromptSubmitHook(ctx context.Context, msg *message.Message, isFirstMessage bool) error {
930	// Skip if sub-agent or no hooks manager.
931	if a.isSubAgent || a.hooksManager == nil {
932		return nil
933	}
934
935	// Convert attachments to file paths.
936	attachmentPaths := make([]string, len(msg.BinaryContent()))
937	for i, att := range msg.BinaryContent() {
938		attachmentPaths[i] = att.Path
939	}
940
941	hookResult, err := a.hooksManager.ExecuteUserPromptSubmit(ctx, msg.SessionID, a.workingDir, hooks.UserPromptSubmitData{
942		Prompt:         msg.Content().Text,
943		Attachments:    attachmentPaths,
944		Model:          a.largeModel.CatwalkCfg.ID,
945		Provider:       a.largeModel.Model.Provider(),
946		IsFirstMessage: isFirstMessage,
947	})
948	if err != nil {
949		return fmt.Errorf("hook execution failed: %w", err)
950	}
951
952	// Apply hook modifications to the prompt.
953	if hookResult.ModifiedPrompt != nil {
954		for i, part := range msg.Parts {
955			if _, ok := part.(message.TextContent); ok {
956				msg.Parts[i] = message.TextContent{Text: *hookResult.ModifiedPrompt}
957			}
958		}
959	}
960	msg.AddHookResult(hookResult)
961	err = a.messages.Update(ctx, *msg)
962	if err != nil {
963		return err
964	}
965	// If hook returned Continue: false, stop execution.
966	if !hookResult.Continue {
967		return ErrHookExecutionStop
968	}
969	return nil
970}