agent.go

   1// Package agent is the core orchestration layer for Crush AI agents.
   2//
   3// It provides session-based AI agent functionality for managing
   4// conversations, tool execution, and message handling. It coordinates
   5// interactions between language models, messages, sessions, and tools while
   6// handling features like automatic summarization, queuing, and token
   7// management.
   8package agent
   9
  10import (
  11	"cmp"
  12	"context"
  13	_ "embed"
  14	"encoding/json"
  15	"errors"
  16	"fmt"
  17	"log/slog"
  18	"os"
  19	"strconv"
  20	"strings"
  21	"sync"
  22	"time"
  23
  24	"charm.land/fantasy"
  25	"charm.land/fantasy/providers/anthropic"
  26	"charm.land/fantasy/providers/bedrock"
  27	"charm.land/fantasy/providers/google"
  28	"charm.land/fantasy/providers/openai"
  29	"charm.land/fantasy/providers/openrouter"
  30	"github.com/charmbracelet/catwalk/pkg/catwalk"
  31	"github.com/charmbracelet/crush/internal/agent/tools"
  32	"github.com/charmbracelet/crush/internal/config"
  33	"github.com/charmbracelet/crush/internal/csync"
  34	"github.com/charmbracelet/crush/internal/hooks"
  35	"github.com/charmbracelet/crush/internal/message"
  36	"github.com/charmbracelet/crush/internal/permission"
  37	"github.com/charmbracelet/crush/internal/session"
  38	"github.com/charmbracelet/crush/internal/stringext"
  39)
  40
  // titlePrompt holds the embedded contents of templates/title.md. It is used
  // as the system prompt when generating a session title.
  41//go:embed templates/title.md
  42var titlePrompt []byte
  43
  // summaryPrompt holds the embedded contents of templates/summary.md. It is
  // used as the system prompt when summarizing a session.
  44//go:embed templates/summary.md
  45var summaryPrompt []byte
  46
  // SessionAgentCall carries the parameters of a single prompt submitted to a
  // SessionAgent via Run.
  47type SessionAgentCall struct {
	// SessionID identifies the target session; Run rejects an empty value.
  48	SessionID        string
	// Prompt is the user text; Run rejects an empty value.
  49	Prompt           string
	// ProviderOptions are forwarded unchanged to the model stream call.
  50	ProviderOptions  fantasy.ProviderOptions
	// Attachments are stored on the user message and sent to the model as
	// files.
  51	Attachments      []message.Attachment
	// The remaining fields are sampling/generation settings that Run passes
	// straight through to the stream call.
  52	MaxOutputTokens  int64
  53	Temperature      *float64
  54	TopP             *float64
  55	TopK             *int64
  56	FrequencyPenalty *float64
  57	PresencePenalty  *float64
  58}
  59
  // SessionAgent is the session-scoped agent API. The notes on each method
  // describe the behavior of the sessionAgent implementation in this file.
  60type SessionAgent interface {
	// Run submits a prompt. If the session is already busy the call is
	// queued and (nil, nil) is returned.
  61	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
	// SetModels replaces the large and small models.
  62	SetModels(large Model, small Model)
	// SetTools replaces the tool set used by subsequent runs.
  63	SetTools(tools []fantasy.AgentTool)
	// Cancel cancels the active request of a session and drops its queued
	// prompts.
  64	Cancel(sessionID string)
	// CancelAll cancels every active request and waits (bounded) for them to
	// finish.
  65	CancelAll()
	// IsSessionBusy reports whether the session has an active request.
  66	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any session has an active request.
  67	IsBusy() bool
	// QueuedPrompts returns the number of prompts queued for the session.
  68	QueuedPrompts(sessionID string) int
	// ClearQueue drops the prompts queued for the session.
  69	ClearQueue(sessionID string)
	// Summarize replaces the session's working history with a model-written
	// summary message.
  70	Summarize(context.Context, string, fantasy.ProviderOptions) error
	// Model returns the large model.
  71	Model() Model
  72}
  73
  // Model bundles a language model with the two configuration records the
  // agent consults for it.
  74type Model struct {
	// Model is the language model handed to fantasy.NewAgent.
  75	Model      fantasy.LanguageModel
	// CatwalkCfg supplies the context window, per-token costs and reasoning
	// capability read in this file.
  76	CatwalkCfg catwalk.Model
	// ModelCfg supplies the model and provider names recorded on assistant
	// messages.
  77	ModelCfg   config.SelectedModel
  78}
  79
  // sessionAgent is the SessionAgent implementation returned by
  // NewSessionAgent.
  80type sessionAgent struct {
	// largeModel serves prompts and summaries; smallModel generates titles.
  81	largeModel           Model
  82	smallModel           Model
	// systemPromptPrefix, when non-empty, is prepended as an extra system
	// message on every step.
  83	systemPromptPrefix   string
  84	systemPrompt         string
  85	tools                []fantasy.AgentTool
  86	sessions             session.Service
  87	messages             message.Service
	// disableAutoSummarize turns off the low-context stop condition in Run.
  88	disableAutoSummarize bool
	// NOTE(review): isYolo, isSubAgent, hooksManager and workingDir are not
	// referenced in the visible part of this file; presumably they are used
	// by the hook helpers defined elsewhere.
  89	isYolo               bool
  90	isSubAgent           bool
  91	hooksManager         hooks.Manager
  92	workingDir           string
  93
	// messageQueue holds prompts submitted while a session was busy, keyed by
	// session ID.
  94	messageQueue   *csync.Map[string, []SessionAgentCall]
	// activeRequests maps a session ID to the cancel function of its
	// in-flight request; presence of a key marks the session as busy.
  95	activeRequests *csync.Map[string, context.CancelFunc]
  96}
  97
  // SessionAgentOptions is the configuration consumed by NewSessionAgent. Each
  // field is copied verbatim into the corresponding sessionAgent field.
  98type SessionAgentOptions struct {
  99	LargeModel           Model
 100	SmallModel           Model
 101	SystemPromptPrefix   string
 102	SystemPrompt         string
 103	DisableAutoSummarize bool
 104	IsYolo               bool
 105	IsSubAgent           bool
 106	HooksManager         hooks.Manager
 107	WorkingDir           string
 108	Sessions             session.Service
 109	Messages             message.Service
 110	Tools                []fantasy.AgentTool
 111}
 112
 113func NewSessionAgent(
 114	opts SessionAgentOptions,
 115) SessionAgent {
 116	return &sessionAgent{
 117		largeModel:           opts.LargeModel,
 118		smallModel:           opts.SmallModel,
 119		systemPromptPrefix:   opts.SystemPromptPrefix,
 120		systemPrompt:         opts.SystemPrompt,
 121		sessions:             opts.Sessions,
 122		messages:             opts.Messages,
 123		disableAutoSummarize: opts.DisableAutoSummarize,
 124		tools:                opts.Tools,
 125		isYolo:               opts.IsYolo,
 126		isSubAgent:           opts.IsSubAgent,
 127		hooksManager:         opts.HooksManager,
 128		workingDir:           opts.WorkingDir,
 129		messageQueue:         csync.NewMap[string, []SessionAgentCall](),
 130		activeRequests:       csync.NewMap[string, context.CancelFunc](),
 131	}
 132}
 133
 134func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
 135	if call.Prompt == "" {
 136		return nil, ErrEmptyPrompt
 137	}
 138	if call.SessionID == "" {
 139		return nil, ErrSessionMissing
 140	}
 141
 142	// Queue the message if busy
 143	if a.IsSessionBusy(call.SessionID) {
 144		existing, ok := a.messageQueue.Get(call.SessionID)
 145		if !ok {
 146			existing = []SessionAgentCall{}
 147		}
 148		existing = append(existing, call)
 149		a.messageQueue.Set(call.SessionID, existing)
 150		return nil, nil
 151	}
 152
 153	if len(a.tools) > 0 {
 154		// Add Anthropic caching to the last tool.
 155		a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
 156	}
 157
 158	agent := fantasy.NewAgent(
 159		a.largeModel.Model,
 160		fantasy.WithSystemPrompt(a.systemPrompt),
 161		fantasy.WithTools(a.tools...),
 162	)
 163
 164	sessionLock := sync.Mutex{}
 165	currentSession, err := a.sessions.Get(ctx, call.SessionID)
 166	if err != nil {
 167		return nil, fmt.Errorf("failed to get session: %w", err)
 168	}
 169
 170	msgs, err := a.getSessionMessages(ctx, currentSession)
 171	if err != nil {
 172		return nil, fmt.Errorf("failed to get session messages: %w", err)
 173	}
 174
 175	var wg sync.WaitGroup
 176	// Generate title if first message.
 177	if len(msgs) == 0 {
 178		wg.Go(func() {
 179			sessionLock.Lock()
 180			a.generateTitle(ctx, &currentSession, call.Prompt)
 181			sessionLock.Unlock()
 182		})
 183	}
 184
 185	// Add the user message to the session.
 186	msg, err := a.createUserMessage(ctx, call)
 187	if err != nil {
 188		return nil, err
 189	}
 190
 191	// Add the session to the context.
 192	ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)
 193
 194	genCtx, cancel := context.WithCancel(ctx)
 195	a.activeRequests.Set(call.SessionID, cancel)
 196
 197	defer cancel()
 198	defer a.activeRequests.Del(call.SessionID)
 199
 200	// Track completion reason for stop hook
 201	var stopReason string
 202	defer func() {
 203		if stopReason != "" {
 204			a.executeStopHook(ctx, call.SessionID, stopReason)
 205		}
 206	}()
 207
 208	// create the agent message asap to show loading
 209	var currentAssistant *message.Message
 210	assistantMessage, err := a.messages.Create(genCtx, call.SessionID, message.CreateMessageParams{
 211		Role:     message.Assistant,
 212		Parts:    []message.ContentPart{},
 213		Model:    a.largeModel.ModelCfg.Model,
 214		Provider: a.largeModel.ModelCfg.Provider,
 215	})
 216	if err != nil {
 217		return nil, err
 218	}
 219
 220	currentAssistant = &assistantMessage
 221
 222	hookErr := a.executePromptSubmitHook(genCtx, &msg, len(msgs) == 0)
 223	if hookErr != nil {
 224		stopReason = "error"
 225		// Delete the assistant message
 226		// use the ctx since this could be a cancellation
 227		deleteErr := a.messages.Delete(ctx, currentAssistant.ID)
 228		return nil, cmp.Or(deleteErr, hookErr)
 229	}
 230
 231	history, files := a.preparePrompt(msgs, call.Attachments...)
 232
 233	startTime := time.Now()
 234	a.eventPromptSent(call.SessionID)
 235
 236	// Map to store post-tool-use hook results for OnToolResult callback
 237	postToolHookResults := csync.NewMap[string, hooks.HookResult]()
 238
 239	var shouldSummarize bool
 240	result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
 241		Prompt:           msg.ContentWithHookContext(),
 242		Files:            files,
 243		Messages:         history,
 244		ProviderOptions:  call.ProviderOptions,
 245		MaxOutputTokens:  &call.MaxOutputTokens,
 246		TopP:             call.TopP,
 247		Temperature:      call.Temperature,
 248		PresencePenalty:  call.PresencePenalty,
 249		TopK:             call.TopK,
 250		FrequencyPenalty: call.FrequencyPenalty,
 251		// Before each step create a new assistant message.
 252		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
 253			// only add new assistant message when its not the first step
 254			if options.StepNumber != 0 {
 255				var assistantMsg message.Message
 256				assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
 257					Role:     message.Assistant,
 258					Model:    a.largeModel.ModelCfg.Model,
 259					Provider: a.largeModel.ModelCfg.Provider,
 260				})
 261				currentAssistant = &assistantMsg
 262				// create the message first so we show loading asap
 263				if err != nil {
 264					return callContext, prepared, err
 265				}
 266			}
 267			prepared.Messages = options.Messages
 268			// Reset all cached items.
 269			for i := range prepared.Messages {
 270				prepared.Messages[i].ProviderOptions = nil
 271			}
 272
 273			queuedCalls, _ := a.messageQueue.Get(call.SessionID)
 274			a.messageQueue.Del(call.SessionID)
 275			for _, queued := range queuedCalls {
 276				userMessage, createErr := a.createUserMessage(callContext, queued)
 277				if createErr != nil {
 278					return callContext, prepared, createErr
 279				}
 280
 281				hookErr := a.executePromptSubmitHook(ctx, &msg, len(msgs) == 0)
 282				if hookErr != nil {
 283					return callContext, prepared, hookErr
 284				}
 285
 286				prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
 287			}
 288
 289			lastSystemRoleInx := 0
 290			systemMessageUpdated := false
 291			for i, msg := range prepared.Messages {
 292				// Only add cache control to the last message.
 293				if msg.Role == fantasy.MessageRoleSystem {
 294					lastSystemRoleInx = i
 295				} else if !systemMessageUpdated {
 296					prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
 297					systemMessageUpdated = true
 298				}
 299				// Than add cache control to the last 2 messages.
 300				if i > len(prepared.Messages)-3 {
 301					prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
 302				}
 303			}
 304
 305			if a.systemPromptPrefix != "" {
 306				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
 307			}
 308
 309			callContext = context.WithValue(callContext, tools.MessageIDContextKey, currentAssistant.ID)
 310			return callContext, prepared, err
 311		},
 312		OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
 313			currentAssistant.AppendReasoningContent(reasoning.Text)
 314			return a.messages.Update(genCtx, *currentAssistant)
 315		},
 316		OnReasoningDelta: func(id string, text string) error {
 317			currentAssistant.AppendReasoningContent(text)
 318			return a.messages.Update(genCtx, *currentAssistant)
 319		},
 320		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
 321			// handle anthropic signature
 322			if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
 323				if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
 324					currentAssistant.AppendReasoningSignature(reasoning.Signature)
 325				}
 326			}
 327			if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
 328				if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
 329					currentAssistant.AppendThoughtSignature(reasoning.Signature, reasoning.ToolID)
 330				}
 331			}
 332			if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
 333				if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
 334					currentAssistant.SetReasoningResponsesData(reasoning)
 335				}
 336			}
 337			currentAssistant.FinishThinking()
 338			return a.messages.Update(genCtx, *currentAssistant)
 339		},
 340		OnTextDelta: func(id string, text string) error {
 341			// Strip leading newline from initial text content. This is is
 342			// particularly important in non-interactive mode where leading
 343			// newlines are very visible.
 344			if len(currentAssistant.Parts) == 0 {
 345				text = strings.TrimPrefix(text, "\n")
 346			}
 347
 348			currentAssistant.AppendContent(text)
 349			return a.messages.Update(genCtx, *currentAssistant)
 350		},
 351		OnToolInputStart: func(id string, toolName string) error {
 352			toolCall := message.ToolCall{
 353				ID:               id,
 354				Name:             toolName,
 355				ProviderExecuted: false,
 356				Finished:         false,
 357			}
 358			currentAssistant.AddToolCall(toolCall)
 359			return a.messages.Update(genCtx, *currentAssistant)
 360		},
 361		OnRetry: func(err *fantasy.ProviderError, delay time.Duration) {
 362			// TODO: implement
 363		},
 364		OnToolCall: func(tc fantasy.ToolCallContent) error {
 365			toolCall := message.ToolCall{
 366				ID:               tc.ToolCallID,
 367				Name:             tc.ToolName,
 368				Input:            tc.Input,
 369				ProviderExecuted: false,
 370				Finished:         true,
 371			}
 372			currentAssistant.AddToolCall(toolCall)
 373			return a.messages.Update(genCtx, *currentAssistant)
 374		},
 375		PreToolExecute: func(ctx context.Context, toolCall fantasy.ToolCall) (context.Context, *fantasy.ToolCall, error) {
 376			return a.executePreToolUseHook(ctx, call.SessionID, toolCall, currentAssistant)
 377		},
 378		OnToolResult: func(result fantasy.ToolResultContent) error {
 379			var resultContent string
 380			isError := false
 381			switch result.Result.GetType() {
 382			case fantasy.ToolResultContentTypeText:
 383				r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result)
 384				if ok {
 385					resultContent = r.Text
 386				}
 387			case fantasy.ToolResultContentTypeError:
 388				r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result)
 389				if ok {
 390					isError = true
 391					resultContent = r.Error.Error()
 392				}
 393			case fantasy.ToolResultContentTypeMedia:
 394				// TODO: handle this message type
 395			}
 396			toolResult := message.ToolResult{
 397				ToolCallID: result.ToolCallID,
 398				Name:       result.ToolName,
 399				Content:    resultContent,
 400				IsError:    isError,
 401				Metadata:   result.ClientMetadata,
 402			}
 403			// Attach hook result if available
 404			if hookRes, ok := postToolHookResults.Get(result.ToolCallID); ok {
 405				toolResult.HookResult = &hookRes
 406			}
 407			_, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
 408				Role: message.Tool,
 409				Parts: []message.ContentPart{
 410					toolResult,
 411				},
 412			})
 413			if createMsgErr != nil {
 414				return createMsgErr
 415			}
 416			return nil
 417		},
 418		PostToolExecute: func(ctx context.Context, toolCall fantasy.ToolCall, response fantasy.ToolResponse, executionTimeMs int64) (*fantasy.ToolResponse, error) {
 419			modifiedResponse, hookResult, err := a.executePostToolUseHook(ctx, call.SessionID, toolCall, response, executionTimeMs)
 420			if hookResult != nil {
 421				// Store for OnToolResult callback
 422				postToolHookResults.Set(toolCall.ID, *hookResult)
 423			}
 424			return modifiedResponse, err
 425		},
 426		OnStepFinish: func(stepResult fantasy.StepResult) error {
 427			finishReason := message.FinishReasonUnknown
 428			switch stepResult.FinishReason {
 429			case fantasy.FinishReasonLength:
 430				finishReason = message.FinishReasonMaxTokens
 431			case fantasy.FinishReasonStop:
 432				finishReason = message.FinishReasonEndTurn
 433			case fantasy.FinishReasonToolCalls:
 434				finishReason = message.FinishReasonToolUse
 435			}
 436			currentAssistant.AddFinish(finishReason, "", "")
 437			a.updateSessionUsage(a.largeModel, &currentSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
 438			sessionLock.Lock()
 439			_, sessionErr := a.sessions.Save(genCtx, currentSession)
 440			sessionLock.Unlock()
 441			if sessionErr != nil {
 442				return sessionErr
 443			}
 444			return a.messages.Update(genCtx, *currentAssistant)
 445		},
 446		StopWhen: []fantasy.StopCondition{
 447			func(_ []fantasy.StepResult) bool {
 448				cw := int64(a.largeModel.CatwalkCfg.ContextWindow)
 449				tokens := currentSession.CompletionTokens + currentSession.PromptTokens
 450				remaining := cw - tokens
 451				var threshold int64
 452				if cw > 200_000 {
 453					threshold = 20_000
 454				} else {
 455					threshold = int64(float64(cw) * 0.2)
 456				}
 457				if (remaining <= threshold) && !a.disableAutoSummarize {
 458					shouldSummarize = true
 459					return true
 460				}
 461				return false
 462			},
 463		},
 464	})
 465
 466	a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))
 467
 468	if err != nil {
 469		isCancelErr := errors.Is(err, context.Canceled)
 470		isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
 471		isHookDenied := errors.Is(err, ErrHookDenied)
 472
 473		// Set stop reason for defer
 474		if isCancelErr {
 475			stopReason = "cancelled"
 476		} else if isPermissionErr || isHookDenied {
 477			stopReason = "permission_denied"
 478		} else {
 479			stopReason = "error"
 480		}
 481
 482		if currentAssistant == nil {
 483			return result, err
 484		}
 485		// Ensure we finish thinking on error to close the reasoning state.
 486		currentAssistant.FinishThinking()
 487		toolCalls := currentAssistant.ToolCalls()
 488		// INFO: we use the parent context here because the genCtx has been cancelled.
 489		msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
 490		if createErr != nil {
 491			return nil, createErr
 492		}
 493		for _, tc := range toolCalls {
 494			if !tc.Finished {
 495				tc.Finished = true
 496				tc.Input = "{}"
 497				currentAssistant.AddToolCall(tc)
 498				updateErr := a.messages.Update(ctx, *currentAssistant)
 499				if updateErr != nil {
 500					return nil, updateErr
 501				}
 502			}
 503
 504			found := false
 505			for _, msg := range msgs {
 506				if msg.Role == message.Tool {
 507					for _, tr := range msg.ToolResults() {
 508						if tr.ToolCallID == tc.ID {
 509							found = true
 510							break
 511						}
 512					}
 513				}
 514				if found {
 515					break
 516				}
 517			}
 518			if found {
 519				continue
 520			}
 521			content := "There was an error while executing the tool"
 522			if isCancelErr {
 523				content = "Tool execution canceled by user"
 524			} else if isPermissionErr {
 525				content = "User denied permission"
 526			} else if isHookDenied {
 527				content = "Hook denied execution"
 528			}
 529			toolResult := message.ToolResult{
 530				ToolCallID: tc.ID,
 531				Name:       tc.Name,
 532				Content:    content,
 533				IsError:    true,
 534			}
 535			_, createErr = a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
 536				Role: message.Tool,
 537				Parts: []message.ContentPart{
 538					toolResult,
 539				},
 540			})
 541			if createErr != nil {
 542				return nil, createErr
 543			}
 544		}
 545		var fantasyErr *fantasy.Error
 546		var providerErr *fantasy.ProviderError
 547		const defaultTitle = "Provider Error"
 548		if isCancelErr {
 549			currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
 550		} else if isPermissionErr {
 551			currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "")
 552		} else if isHookDenied {
 553			currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "Hook denied execution", "")
 554		} else if errors.As(err, &providerErr) {
 555			currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message)
 556		} else if errors.As(err, &fantasyErr) {
 557			currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(fantasyErr.Title), defaultTitle), fantasyErr.Message)
 558		} else {
 559			currentAssistant.AddFinish(message.FinishReasonError, defaultTitle, err.Error())
 560		}
 561		// Note: we use the parent context here because the genCtx has been
 562		// cancelled.
 563		updateErr := a.messages.Update(ctx, *currentAssistant)
 564		if updateErr != nil {
 565			return nil, updateErr
 566		}
 567		return nil, err
 568	}
 569	wg.Wait()
 570
 571	// Set completion reason for stop hook
 572	stopReason = "completed"
 573
 574	if shouldSummarize {
 575		a.activeRequests.Del(call.SessionID)
 576		if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
 577			return nil, summarizeErr
 578		}
 579		// If the agent wasn't done...
 580		if len(currentAssistant.ToolCalls()) > 0 {
 581			existing, ok := a.messageQueue.Get(call.SessionID)
 582			if !ok {
 583				existing = []SessionAgentCall{}
 584			}
 585			call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
 586			existing = append(existing, call)
 587			a.messageQueue.Set(call.SessionID, existing)
 588		}
 589	}
 590
 591	// Release active request before processing queued messages.
 592	a.activeRequests.Del(call.SessionID)
 593	cancel()
 594
 595	queuedMessages, ok := a.messageQueue.Get(call.SessionID)
 596	if !ok || len(queuedMessages) == 0 {
 597		return result, err
 598	}
 599	// There are queued messages restart the loop.
 600	firstQueuedMessage := queuedMessages[0]
 601	a.messageQueue.Set(call.SessionID, queuedMessages[1:])
 602	return a.Run(ctx, firstQueuedMessage)
 603}
 604
 605func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
 606	if a.IsSessionBusy(sessionID) {
 607		return ErrSessionBusy
 608	}
 609
 610	currentSession, err := a.sessions.Get(ctx, sessionID)
 611	if err != nil {
 612		return fmt.Errorf("failed to get session: %w", err)
 613	}
 614	msgs, err := a.getSessionMessages(ctx, currentSession)
 615	if err != nil {
 616		return err
 617	}
 618	if len(msgs) == 0 {
 619		// Nothing to summarize.
 620		return nil
 621	}
 622
 623	aiMsgs, _ := a.preparePrompt(msgs)
 624
 625	genCtx, cancel := context.WithCancel(ctx)
 626	a.activeRequests.Set(sessionID, cancel)
 627	defer a.activeRequests.Del(sessionID)
 628	defer cancel()
 629
 630	agent := fantasy.NewAgent(a.largeModel.Model,
 631		fantasy.WithSystemPrompt(string(summaryPrompt)),
 632	)
 633	summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
 634		Role:             message.Assistant,
 635		Model:            a.largeModel.Model.Model(),
 636		Provider:         a.largeModel.Model.Provider(),
 637		IsSummaryMessage: true,
 638	})
 639	if err != nil {
 640		return err
 641	}
 642
 643	resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
 644		Prompt:          "Provide a detailed summary of our conversation above.",
 645		Messages:        aiMsgs,
 646		ProviderOptions: opts,
 647		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
 648			prepared.Messages = options.Messages
 649			if a.systemPromptPrefix != "" {
 650				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
 651			}
 652			return callContext, prepared, nil
 653		},
 654		OnReasoningDelta: func(id string, text string) error {
 655			summaryMessage.AppendReasoningContent(text)
 656			return a.messages.Update(genCtx, summaryMessage)
 657		},
 658		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
 659			// Handle anthropic signature.
 660			if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
 661				if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
 662					summaryMessage.AppendReasoningSignature(signature.Signature)
 663				}
 664			}
 665			summaryMessage.FinishThinking()
 666			return a.messages.Update(genCtx, summaryMessage)
 667		},
 668		OnTextDelta: func(id, text string) error {
 669			summaryMessage.AppendContent(text)
 670			return a.messages.Update(genCtx, summaryMessage)
 671		},
 672	})
 673	if err != nil {
 674		isCancelErr := errors.Is(err, context.Canceled)
 675		if isCancelErr {
 676			// User cancelled summarize we need to remove the summary message.
 677			deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
 678			return deleteErr
 679		}
 680		return err
 681	}
 682
 683	summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
 684	err = a.messages.Update(genCtx, summaryMessage)
 685	if err != nil {
 686		return err
 687	}
 688
 689	var openrouterCost *float64
 690	for _, step := range resp.Steps {
 691		stepCost := a.openrouterCost(step.ProviderMetadata)
 692		if stepCost != nil {
 693			newCost := *stepCost
 694			if openrouterCost != nil {
 695				newCost += *openrouterCost
 696			}
 697			openrouterCost = &newCost
 698		}
 699	}
 700
 701	a.updateSessionUsage(a.largeModel, &currentSession, resp.TotalUsage, openrouterCost)
 702
 703	// Just in case, get just the last usage info.
 704	usage := resp.Response.Usage
 705	currentSession.SummaryMessageID = summaryMessage.ID
 706	currentSession.CompletionTokens = usage.OutputTokens
 707	currentSession.PromptTokens = 0
 708	_, err = a.sessions.Save(genCtx, currentSession)
 709	return err
 710}
 711
 712func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
 713	if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
 714		return fantasy.ProviderOptions{}
 715	}
 716	return fantasy.ProviderOptions{
 717		anthropic.Name: &anthropic.ProviderCacheControlOptions{
 718			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
 719		},
 720		bedrock.Name: &anthropic.ProviderCacheControlOptions{
 721			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
 722		},
 723	}
 724}
 725
 726func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
 727	var attachmentParts []message.ContentPart
 728	for _, attachment := range call.Attachments {
 729		attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
 730	}
 731	parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
 732	parts = append(parts, attachmentParts...)
 733	msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
 734		Role:  message.User,
 735		Parts: parts,
 736	})
 737	if err != nil {
 738		return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
 739	}
 740	return msg, nil
 741}
 742
 743func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
 744	var history []fantasy.Message
 745	for _, m := range msgs {
 746		if len(m.Parts) == 0 {
 747			continue
 748		}
 749		// Assistant message without content or tool calls (cancelled before it
 750		// returned anything).
 751		if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
 752			continue
 753		}
 754		history = append(history, m.ToAIMessage()...)
 755	}
 756
 757	var files []fantasy.FilePart
 758	for _, attachment := range attachments {
 759		files = append(files, fantasy.FilePart{
 760			Filename:  attachment.FileName,
 761			Data:      attachment.Content,
 762			MediaType: attachment.MimeType,
 763		})
 764	}
 765
 766	return history, files
 767}
 768
 769func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
 770	msgs, err := a.messages.List(ctx, session.ID)
 771	if err != nil {
 772		return nil, fmt.Errorf("failed to list messages: %w", err)
 773	}
 774
 775	if session.SummaryMessageID != "" {
 776		summaryMsgInex := -1
 777		for i, msg := range msgs {
 778			if msg.ID == session.SummaryMessageID {
 779				summaryMsgInex = i
 780				break
 781			}
 782		}
 783		if summaryMsgInex != -1 {
 784			msgs = msgs[summaryMsgInex:]
 785			msgs[0].Role = message.User
 786		}
 787	}
 788	return msgs, nil
 789}
 790
 791func (a *sessionAgent) generateTitle(ctx context.Context, session *session.Session, prompt string) {
 792	if prompt == "" {
 793		return
 794	}
 795
 796	var maxOutput int64 = 40
 797	if a.smallModel.CatwalkCfg.CanReason {
 798		maxOutput = a.smallModel.CatwalkCfg.DefaultMaxTokens
 799	}
 800
 801	agent := fantasy.NewAgent(a.smallModel.Model,
 802		fantasy.WithSystemPrompt(string(titlePrompt)+"\n /no_think"),
 803		fantasy.WithMaxOutputTokens(maxOutput),
 804	)
 805
 806	resp, err := agent.Stream(ctx, fantasy.AgentStreamCall{
 807		Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", prompt),
 808		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
 809			prepared.Messages = options.Messages
 810			if a.systemPromptPrefix != "" {
 811				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
 812			}
 813			return callContext, prepared, nil
 814		},
 815	})
 816	if err != nil {
 817		slog.Error("error generating title", "err", err)
 818		return
 819	}
 820
 821	title := resp.Response.Content.Text()
 822
 823	title = strings.ReplaceAll(title, "\n", " ")
 824
 825	// Remove thinking tags if present.
 826	if idx := strings.Index(title, "</think>"); idx > 0 {
 827		title = title[idx+len("</think>"):]
 828	}
 829
 830	title = strings.TrimSpace(title)
 831	if title == "" {
 832		slog.Warn("failed to generate title", "warn", "empty title")
 833		return
 834	}
 835
 836	session.Title = title
 837
 838	var openrouterCost *float64
 839	for _, step := range resp.Steps {
 840		stepCost := a.openrouterCost(step.ProviderMetadata)
 841		if stepCost != nil {
 842			newCost := *stepCost
 843			if openrouterCost != nil {
 844				newCost += *openrouterCost
 845			}
 846			openrouterCost = &newCost
 847		}
 848	}
 849
 850	a.updateSessionUsage(a.smallModel, session, resp.TotalUsage, openrouterCost)
 851	_, saveErr := a.sessions.Save(ctx, *session)
 852	if saveErr != nil {
 853		slog.Error("failed to save session title & usage", "error", saveErr)
 854		return
 855	}
 856}
 857
 858func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
 859	openrouterMetadata, ok := metadata[openrouter.Name]
 860	if !ok {
 861		return nil
 862	}
 863
 864	opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
 865	if !ok {
 866		return nil
 867	}
 868	return &opts.Usage.Cost
 869}
 870
 871func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
 872	modelConfig := model.CatwalkCfg
 873	cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
 874		modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
 875		modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
 876		modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
 877
 878	a.eventTokensUsed(session.ID, model, usage, cost)
 879
 880	if overrideCost != nil {
 881		session.Cost += *overrideCost
 882	} else {
 883		session.Cost += cost
 884	}
 885
 886	session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
 887	session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
 888}
 889
 890func (a *sessionAgent) Cancel(sessionID string) {
 891	// Cancel regular requests.
 892	if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
 893		slog.Info("Request cancellation initiated", "session_id", sessionID)
 894		cancel()
 895	}
 896
 897	// Also check for summarize requests.
 898	if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
 899		slog.Info("Summarize cancellation initiated", "session_id", sessionID)
 900		cancel()
 901	}
 902
 903	if a.QueuedPrompts(sessionID) > 0 {
 904		slog.Info("Clearing queued prompts", "session_id", sessionID)
 905		a.messageQueue.Del(sessionID)
 906	}
 907}
 908
 909func (a *sessionAgent) ClearQueue(sessionID string) {
 910	if a.QueuedPrompts(sessionID) > 0 {
 911		slog.Info("Clearing queued prompts", "session_id", sessionID)
 912		a.messageQueue.Del(sessionID)
 913	}
 914}
 915
 916func (a *sessionAgent) CancelAll() {
 917	if !a.IsBusy() {
 918		return
 919	}
 920	for key := range a.activeRequests.Seq2() {
 921		a.Cancel(key) // key is sessionID
 922	}
 923
 924	timeout := time.After(5 * time.Second)
 925	for a.IsBusy() {
 926		select {
 927		case <-timeout:
 928			return
 929		default:
 930			time.Sleep(200 * time.Millisecond)
 931		}
 932	}
 933}
 934
 935func (a *sessionAgent) IsBusy() bool {
 936	var busy bool
 937	for cancelFunc := range a.activeRequests.Seq() {
 938		if cancelFunc != nil {
 939			busy = true
 940			break
 941		}
 942	}
 943	return busy
 944}
 945
 946func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
 947	_, busy := a.activeRequests.Get(sessionID)
 948	return busy
 949}
 950
 951func (a *sessionAgent) QueuedPrompts(sessionID string) int {
 952	l, ok := a.messageQueue.Get(sessionID)
 953	if !ok {
 954		return 0
 955	}
 956	return len(l)
 957}
 958
 959func (a *sessionAgent) SetModels(large Model, small Model) {
 960	a.largeModel = large
 961	a.smallModel = small
 962}
 963
 964func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
 965	a.tools = tools
 966}
 967
 968func (a *sessionAgent) Model() Model {
 969	return a.largeModel
 970}
 971
 972// executePromptSubmitHook executes the user-prompt-submit hook and applies modifications to the call.
 973// Only runs for main agent (not sub-agents).
 974func (a *sessionAgent) executePromptSubmitHook(ctx context.Context, msg *message.Message, isFirstMessage bool) error {
 975	// Skip if sub-agent or no hooks manager.
 976	if a.isSubAgent || a.hooksManager == nil {
 977		return nil
 978	}
 979
 980	// Convert attachments to file paths.
 981	attachmentPaths := make([]string, len(msg.BinaryContent()))
 982	for i, att := range msg.BinaryContent() {
 983		attachmentPaths[i] = att.Path
 984	}
 985
 986	hookResult, err := a.hooksManager.ExecuteUserPromptSubmit(ctx, msg.SessionID, a.workingDir, hooks.UserPromptSubmitData{
 987		Prompt:         msg.Content().Text,
 988		Attachments:    attachmentPaths,
 989		Model:          a.largeModel.CatwalkCfg.ID,
 990		Provider:       a.largeModel.Model.Provider(),
 991		IsFirstMessage: isFirstMessage,
 992	})
 993	if err != nil {
 994		return fmt.Errorf("hook execution failed: %w", err)
 995	}
 996
 997	// Apply hook modifications to the prompt.
 998	if hookResult.ModifiedPrompt != nil {
 999		for i, part := range msg.Parts {
1000			if _, ok := part.(message.TextContent); ok {
1001				msg.Parts[i] = message.TextContent{Text: *hookResult.ModifiedPrompt}
1002			}
1003		}
1004	}
1005	msg.AddHookResult(hookResult)
1006	err = a.messages.Update(ctx, *msg)
1007	if err != nil {
1008		return err
1009	}
1010	// If hook returned Continue: false, stop execution.
1011	if !hookResult.Continue {
1012		return ErrHookExecutionStop
1013	}
1014	return nil
1015}
1016
1017// executePreToolUseHook executes the pre-tool-use hook and applies modifications.
1018// Only runs for main agent (not sub-agents).
1019func (a *sessionAgent) executePreToolUseHook(ctx context.Context, sessionID string, toolCall fantasy.ToolCall, currentAssistant *message.Message) (context.Context, *fantasy.ToolCall, error) {
1020	// Skip if sub-agent or no hooks manager.
1021	if a.isSubAgent || a.hooksManager == nil {
1022		return ctx, nil, nil
1023	}
1024
1025	// Parse tool input to map
1026	var toolInput map[string]any
1027	if err := json.Unmarshal([]byte(toolCall.Input), &toolInput); err != nil {
1028		// If we can't parse the input, skip the hook
1029		return ctx, nil, nil
1030	}
1031
1032	hookResult, err := a.hooksManager.ExecutePreToolUse(ctx, sessionID, a.workingDir, hooks.PreToolUseData{
1033		ToolName:   toolCall.Name,
1034		ToolCallID: toolCall.ID,
1035		ToolInput:  toolInput,
1036	})
1037	if err != nil {
1038		return ctx, nil, fmt.Errorf("pre-tool-use hook execution failed: %w", err)
1039	}
1040
1041	// Store hook result in the current assistant's tool call
1042	for _, tc := range currentAssistant.ToolCalls() {
1043		if tc.ID == toolCall.ID {
1044			tc.HookResult = &hookResult
1045			currentAssistant.AddToolCall(tc)
1046			if updateErr := a.messages.Update(ctx, *currentAssistant); updateErr != nil {
1047				slog.Error("failed to update assistant message with pre-hook result", "error", updateErr)
1048			}
1049			break
1050		}
1051	}
1052
1053	// If hook returned Continue: false, deny execution.
1054	if !hookResult.Continue {
1055		return ctx, nil, ErrHookDenied
1056	}
1057
1058	// Set permission in context for tools to use
1059	if hookResult.Permission != "" {
1060		ctx = tools.SetHookPermissionInContext(ctx, hookResult.Permission)
1061	}
1062
1063	// Apply modified input if present.
1064	if len(hookResult.ModifiedInput) > 0 {
1065		// Merge modified input with original
1066		for k, v := range hookResult.ModifiedInput {
1067			toolInput[k] = v
1068		}
1069
1070		modifiedInputJSON, err := json.Marshal(toolInput)
1071		if err != nil {
1072			return ctx, nil, fmt.Errorf("failed to marshal modified input: %w", err)
1073		}
1074
1075		modifiedCall := toolCall
1076		modifiedCall.Input = string(modifiedInputJSON)
1077		return ctx, &modifiedCall, nil
1078	}
1079
1080	return ctx, nil, nil
1081}
1082
1083// executePostToolUseHook executes the post-tool-use hook and applies modifications.
1084// Only runs for main agent (not sub-agents).
1085func (a *sessionAgent) executePostToolUseHook(ctx context.Context, sessionID string, toolCall fantasy.ToolCall, response fantasy.ToolResponse, executionTimeMs int64) (*fantasy.ToolResponse, *hooks.HookResult, error) {
1086	// Skip if sub-agent or no hooks manager.
1087	if a.isSubAgent || a.hooksManager == nil {
1088		return nil, nil, nil
1089	}
1090
1091	// Parse tool input to map
1092	var toolInput map[string]any
1093	if err := json.Unmarshal([]byte(toolCall.Input), &toolInput); err != nil {
1094		return nil, nil, nil
1095	}
1096
1097	// Parse tool output to map
1098	toolOutput := map[string]any{
1099		"success": !response.IsError,
1100		"content": response.Content,
1101	}
1102	if response.Metadata != "" {
1103		toolOutput["metadata"] = response.Metadata
1104	}
1105
1106	hookResult, err := a.hooksManager.ExecutePostToolUse(ctx, sessionID, a.workingDir, hooks.PostToolUseData{
1107		ToolName:        toolCall.Name,
1108		ToolCallID:      toolCall.ID,
1109		ToolInput:       toolInput,
1110		ToolOutput:      toolOutput,
1111		ExecutionTimeMs: executionTimeMs,
1112	})
1113	if err != nil {
1114		return nil, nil, fmt.Errorf("post-tool-use hook execution failed: %w", err)
1115	}
1116
1117	// If hook returned Continue: false, return error to stop execution.
1118	if !hookResult.Continue {
1119		return nil, &hookResult, ErrHookDenied
1120	}
1121
1122	// Apply modified output if present.
1123	if len(hookResult.ModifiedOutput) > 0 {
1124		modifiedResponse := response
1125
1126		// Apply modifications
1127		if content, ok := hookResult.ModifiedOutput["content"].(string); ok {
1128			modifiedResponse.Content = content
1129		}
1130		if success, ok := hookResult.ModifiedOutput["success"].(bool); ok {
1131			modifiedResponse.IsError = !success
1132		}
1133		if metadata, ok := hookResult.ModifiedOutput["metadata"].(string); ok {
1134			modifiedResponse.Metadata = metadata
1135		}
1136
1137		return &modifiedResponse, &hookResult, nil
1138	}
1139
1140	return nil, &hookResult, nil
1141}
1142
1143// executeStopHook executes the stop hook when agent loop ends.
1144// Only runs for main agent (not sub-agents). Errors are logged but don't fail.
1145func (a *sessionAgent) executeStopHook(ctx context.Context, sessionID, reason string) {
1146	// Skip if sub-agent or no hooks manager.
1147	if a.isSubAgent || a.hooksManager == nil {
1148		return
1149	}
1150
1151	// Use a fresh context with timeout to ensure hook runs even if parent is cancelled
1152	hookCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
1153	defer cancel()
1154
1155	_, err := a.hooksManager.ExecuteStop(hookCtx, sessionID, a.workingDir, hooks.StopData{
1156		Reason: reason,
1157	})
1158	if err != nil {
1159		slog.Error("stop hook execution failed", "session_id", sessionID, "reason", reason, "error", err)
1160	}
1161}