agent.go

   1// Package agent is the core orchestration layer for Crush AI agents.
   2//
   3// It provides session-based AI agent functionality for managing
   4// conversations, tool execution, and message handling. It coordinates
   5// interactions between language models, messages, sessions, and tools while
   6// handling features like automatic summarization, queuing, and token
   7// management.
   8package agent
   9
  10import (
  11	"cmp"
  12	"context"
  13	_ "embed"
  14	"encoding/base64"
  15	"errors"
  16	"fmt"
  17	"log/slog"
  18	"os"
  19	"regexp"
  20	"strconv"
  21	"strings"
  22	"sync"
  23	"time"
  24
  25	"charm.land/catwalk/pkg/catwalk"
  26	"charm.land/fantasy"
  27	"charm.land/fantasy/providers/anthropic"
  28	"charm.land/fantasy/providers/bedrock"
  29	"charm.land/fantasy/providers/google"
  30	"charm.land/fantasy/providers/openai"
  31	"charm.land/fantasy/providers/openrouter"
  32	"charm.land/fantasy/providers/vercel"
  33	"charm.land/lipgloss/v2"
  34	"github.com/charmbracelet/crush/internal/agent/hyper"
  35	"github.com/charmbracelet/crush/internal/agent/tools"
  36	"github.com/charmbracelet/crush/internal/agent/tools/mcp"
  37	"github.com/charmbracelet/crush/internal/config"
  38	"github.com/charmbracelet/crush/internal/csync"
  39	"github.com/charmbracelet/crush/internal/message"
  40	"github.com/charmbracelet/crush/internal/permission"
  41	"github.com/charmbracelet/crush/internal/session"
  42	"github.com/charmbracelet/crush/internal/stringext"
  43	"github.com/charmbracelet/x/exp/charmtone"
  44)
  45
const (
	// defaultSessionName is the title stored for a session when title
	// generation fails or produces an empty title.
	defaultSessionName = "Untitled Session"

	// Thresholds for auto-summarization, measured in tokens. For context
	// windows larger than largeContextWindowThreshold, a run stops to
	// summarize once largeContextWindowBuffer or fewer tokens remain. For
	// smaller windows it stops once the remaining tokens drop to
	// smallContextWindowRatio of the window size.
	largeContextWindowThreshold = 200_000
	largeContextWindowBuffer    = 20_000
	smallContextWindowRatio     = 0.2
)
  54
// titlePrompt is the system prompt used when generating session titles,
// embedded from templates/title.md at build time.
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt is the system prompt used when summarizing a session,
// embedded from templates/summary.md at build time.
//
//go:embed templates/summary.md
var summaryPrompt []byte

// thinkTagRegex matches a <think>...</think> span on a single line. It is
// used to remove such spans from generated titles; generateTitle replaces
// newlines with spaces before applying it, so multi-line spans are matched as
// well.
var thinkTagRegex = regexp.MustCompile(`<think>.*?</think>`)
  63
// SessionAgentCall describes a single prompt to run against a session,
// together with the sampling parameters forwarded to the model.
type SessionAgentCall struct {
	// SessionID identifies the session the prompt belongs to. Run rejects
	// calls with an empty SessionID.
	SessionID string
	// Prompt is the user's text. It may be empty only when Attachments
	// contains a text attachment.
	Prompt string
	// ProviderOptions is passed through to the stream call unchanged.
	ProviderOptions fantasy.ProviderOptions
	// Attachments are stored with the user message; non-text attachments are
	// also sent to the model as files.
	Attachments []message.Attachment
	// MaxOutputTokens is forwarded to the stream call.
	MaxOutputTokens int64
	// Optional sampling parameters, forwarded to the stream call as given
	// (nil pointers are passed through as nil).
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
}
  76
// SessionAgent runs prompts against sessions and exposes control over the
// models, tools, prompt queue, and in-flight requests behind them.
//
// NOTE(review): only Run and Summarize are implemented in the part of the
// file reviewed here. The comments on the other methods describe intent
// inferred from their names and should be confirmed against the
// implementations.
type SessionAgent interface {
	// Run executes the call against its session. When the session is already
	// busy the call is queued and Run returns (nil, nil).
	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
	// SetModels presumably replaces the large and small models used by
	// subsequent runs.
	SetModels(large Model, small Model)
	// SetTools presumably replaces the tools offered to the model.
	SetTools(tools []fantasy.AgentTool)
	// SetSystemPrompt presumably replaces the system prompt.
	SetSystemPrompt(systemPrompt string)
	// Cancel presumably cancels the in-flight request of one session.
	Cancel(sessionID string)
	// CancelAll presumably cancels the in-flight requests of all sessions.
	CancelAll()
	// IsSessionBusy reports whether the session is busy; Run and Summarize
	// consult it before starting work.
	IsSessionBusy(sessionID string) bool
	// IsBusy presumably reports whether any session is busy.
	IsBusy() bool
	// QueuedPrompts presumably returns the number of prompts queued for the
	// session.
	QueuedPrompts(sessionID string) int
	// QueuedPromptsList presumably returns the text of the prompts queued for
	// the session.
	QueuedPromptsList(sessionID string) []string
	// ClearQueue presumably drops the prompts queued for the session.
	ClearQueue(sessionID string)
	// Summarize replaces the session's working history with a model-written
	// summary. It returns ErrSessionBusy when the session is busy.
	Summarize(context.Context, string, fantasy.ProviderOptions) error
	// Model presumably returns the model currently used for runs.
	Model() Model
}
  92
// Model bundles a language model with the metadata the agent needs about it.
type Model struct {
	// Model is the language model handed to fantasy.NewAgent.
	Model fantasy.LanguageModel
	// CatwalkCfg is the catalog entry for the model. The agent reads the
	// context window, pricing, and capability flags from it.
	CatwalkCfg catwalk.Model
	// ModelCfg is the selected-model configuration. Its Model and Provider
	// values are recorded on the assistant messages created by Run.
	ModelCfg config.SelectedModel
}
  98
// sessionAgent is the SessionAgent implementation returned by
// NewSessionAgent.
type sessionAgent struct {
	// Mutable configuration, held in csync wrappers and copied out at the
	// start of each run.
	largeModel         *csync.Value[Model]
	smallModel         *csync.Value[Model]
	systemPromptPrefix *csync.Value[string]
	systemPrompt       *csync.Value[string]
	tools              *csync.Slice[fantasy.AgentTool]

	// isSubAgent suppresses the todo-list reminder that preparePrompt
	// otherwise puts at the start of the history.
	isSubAgent bool
	sessions   session.Service
	messages   message.Service
	// disableAutoSummarize turns off the context-window stop condition in
	// Run.
	disableAutoSummarize bool
	// isYolo is not read in the code reviewed here.
	// NOTE(review): confirm its meaning against the rest of the file.
	isYolo bool

	// messageQueue holds, per session ID, the calls that arrived while the
	// session was busy.
	messageQueue *csync.Map[string, []SessionAgentCall]
	// activeRequests holds, per session ID, the cancel function of the
	// request currently in flight.
	activeRequests *csync.Map[string, context.CancelFunc]
}
 115
// SessionAgentOptions carries the initial configuration and dependencies for
// NewSessionAgent. Each field is copied into the corresponding sessionAgent
// field.
type SessionAgentOptions struct {
	// LargeModel is used for runs and summaries; SmallModel is tried first
	// for title generation.
	LargeModel Model
	SmallModel Model
	// SystemPromptPrefix, when non-empty, is prepended as a separate system
	// message to every request.
	SystemPromptPrefix string
	SystemPrompt       string
	// IsSubAgent suppresses the todo-list reminder at the start of the
	// history.
	IsSubAgent bool
	// DisableAutoSummarize turns off automatic summarization when the context
	// window is nearly full.
	DisableAutoSummarize bool
	IsYolo               bool
	// Sessions and Messages are the persistence services for session state
	// and messages.
	Sessions session.Service
	Messages message.Service
	// Tools are the tools initially offered to the model.
	Tools []fantasy.AgentTool
}
 128
 129func NewSessionAgent(
 130	opts SessionAgentOptions,
 131) SessionAgent {
 132	return &sessionAgent{
 133		largeModel:           csync.NewValue(opts.LargeModel),
 134		smallModel:           csync.NewValue(opts.SmallModel),
 135		systemPromptPrefix:   csync.NewValue(opts.SystemPromptPrefix),
 136		systemPrompt:         csync.NewValue(opts.SystemPrompt),
 137		isSubAgent:           opts.IsSubAgent,
 138		sessions:             opts.Sessions,
 139		messages:             opts.Messages,
 140		disableAutoSummarize: opts.DisableAutoSummarize,
 141		tools:                csync.NewSliceFrom(opts.Tools),
 142		isYolo:               opts.IsYolo,
 143		messageQueue:         csync.NewMap[string, []SessionAgentCall](),
 144		activeRequests:       csync.NewMap[string, context.CancelFunc](),
 145	}
 146}
 147
// Run executes one prompt against the session named by call.SessionID and
// streams the model's output into the session's messages.
//
// It returns ErrEmptyPrompt when the prompt is empty and there is no text
// attachment, and ErrSessionMissing when the session ID is empty. When the
// session is already busy the call is appended to the session's queue and Run
// returns (nil, nil).
//
// Otherwise Run stores the user message, registers a cancel function for the
// session, and streams the response, creating one assistant message per step.
// For the first message of a session it also generates a title in the
// background and waits for that before returning.
//
// If streaming fails, Run closes any unfinished tool calls, stores an error
// tool result for every tool call that has none, records the finish reason on
// the assistant message, and returns the original error.
//
// If the context-window stop condition fired, the session is summarized, and
// the call is re-queued when the last assistant message contains tool calls.
// Finally, if prompts are queued for the session, Run calls itself with the
// first one and returns that result.
func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
	if call.Prompt == "" && !message.ContainsTextAttachment(call.Attachments) {
		return nil, ErrEmptyPrompt
	}
	if call.SessionID == "" {
		return nil, ErrSessionMissing
	}

	// Queue the message if busy. It is picked up either by PrepareStep of the
	// running request or by the queue handling at the end of Run.
	if a.IsSessionBusy(call.SessionID) {
		existing, ok := a.messageQueue.Get(call.SessionID)
		if !ok {
			existing = []SessionAgentCall{}
		}
		existing = append(existing, call)
		a.messageQueue.Set(call.SessionID, existing)
		return nil, nil
	}

	// Copy mutable fields under lock to avoid races with SetTools/SetModels.
	agentTools := a.tools.Copy()
	largeModel := a.largeModel.Get()
	systemPrompt := a.systemPrompt.Get()
	promptPrefix := a.systemPromptPrefix.Get()
	var instructions strings.Builder

	// Collect the instructions announced by every connected MCP server.
	for _, server := range mcp.GetStates() {
		if server.State != mcp.StateConnected {
			continue
		}
		if s := server.Client.InitializeResult().Instructions; s != "" {
			instructions.WriteString(s)
			instructions.WriteString("\n\n")
		}
	}

	// Append them to the system prompt inside an <mcp-instructions> element.
	if s := instructions.String(); s != "" {
		systemPrompt += "\n\n<mcp-instructions>\n" + s + "\n</mcp-instructions>"
	}

	if len(agentTools) > 0 {
		// Add Anthropic caching to the last tool.
		agentTools[len(agentTools)-1].SetProviderOptions(a.getCacheControlOptions())
	}

	agent := fantasy.NewAgent(
		largeModel.Model,
		fantasy.WithSystemPrompt(systemPrompt),
		fantasy.WithTools(agentTools...),
	)

	// sessionLock guards the session read-modify-write in OnStepFinish.
	sessionLock := sync.Mutex{}
	currentSession, err := a.sessions.Get(ctx, call.SessionID)
	if err != nil {
		return nil, fmt.Errorf("failed to get session: %w", err)
	}

	msgs, err := a.getSessionMessages(ctx, currentSession)
	if err != nil {
		return nil, fmt.Errorf("failed to get session messages: %w", err)
	}

	var wg sync.WaitGroup
	// Generate title if first message.
	if len(msgs) == 0 {
		titleCtx := ctx // Copy to avoid race with ctx reassignment below.
		wg.Go(func() {
			a.generateTitle(titleCtx, call.SessionID, call.Prompt)
		})
	}
	// Do not return before title generation has finished.
	defer wg.Wait()

	// Add the user message to the session.
	_, err = a.createUserMessage(ctx, call)
	if err != nil {
		return nil, err
	}

	// Add the session to the context.
	ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)

	// genCtx is the cancellable context for this request. Its cancel function
	// is registered under the session ID and removed again when Run returns.
	genCtx, cancel := context.WithCancel(ctx)
	a.activeRequests.Set(call.SessionID, cancel)

	defer cancel()
	defer a.activeRequests.Del(call.SessionID)

	history, files := a.preparePrompt(msgs, call.Attachments...)

	startTime := time.Now()
	a.eventPromptSent(call.SessionID)

	// currentAssistant is the assistant message of the step in progress. It is
	// set by PrepareStep and filled in by the streaming callbacks below.
	var currentAssistant *message.Message
	// shouldSummarize is set by the context-window stop condition.
	var shouldSummarize bool
	result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
		Prompt:           message.PromptWithTextAttachments(call.Prompt, call.Attachments),
		Files:            files,
		Messages:         history,
		ProviderOptions:  call.ProviderOptions,
		MaxOutputTokens:  &call.MaxOutputTokens,
		TopP:             call.TopP,
		Temperature:      call.Temperature,
		PresencePenalty:  call.PresencePenalty,
		TopK:             call.TopK,
		FrequencyPenalty: call.FrequencyPenalty,
		// PrepareStep clears per-message provider options, folds queued
		// prompts into the conversation, applies cache-control options, and
		// creates the assistant message for the step.
		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
			prepared.Messages = options.Messages
			for i := range prepared.Messages {
				prepared.Messages[i].ProviderOptions = nil
			}

			// Drain the queue: prompts that arrived while the session was busy
			// are stored as user messages and appended to this step's input.
			queuedCalls, _ := a.messageQueue.Get(call.SessionID)
			a.messageQueue.Del(call.SessionID)
			for _, queued := range queuedCalls {
				userMessage, createErr := a.createUserMessage(callContext, queued)
				if createErr != nil {
					return callContext, prepared, createErr
				}
				prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
			}

			prepared.Messages = a.workaroundProviderMediaLimitations(prepared.Messages, largeModel)

			lastSystemRoleInx := 0
			systemMessageUpdated := false
			for i, msg := range prepared.Messages {
				// Add cache control once, to the last system message seen
				// before the first non-system message. If no system message
				// precedes it, index 0 is used.
				if msg.Role == fantasy.MessageRoleSystem {
					lastSystemRoleInx = i
				} else if !systemMessageUpdated {
					prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
					systemMessageUpdated = true
				}
				// Then add cache control to the last 2 messages.
				if i > len(prepared.Messages)-3 {
					prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
				}
			}

			// The prefix goes in front as its own system message, after the
			// cache-control pass, so it carries no cache-control options.
			if promptPrefix != "" {
				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(promptPrefix)}, prepared.Messages...)
			}

			// Create the (initially empty) assistant message for this step.
			var assistantMsg message.Message
			assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
				Role:     message.Assistant,
				Parts:    []message.ContentPart{},
				Model:    largeModel.ModelCfg.Model,
				Provider: largeModel.ModelCfg.Provider,
			})
			if err != nil {
				return callContext, prepared, err
			}
			// Put the message ID and model details into the returned context
			// under the tools package's context keys.
			callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
			callContext = context.WithValue(callContext, tools.SupportsImagesContextKey, largeModel.CatwalkCfg.SupportsImages)
			callContext = context.WithValue(callContext, tools.ModelNameContextKey, largeModel.CatwalkCfg.Name)
			currentAssistant = &assistantMsg
			return callContext, prepared, err
		},
		// The callbacks below update currentAssistant and persist it after
		// every event.
		OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
			currentAssistant.AppendReasoningContent(reasoning.Text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnReasoningDelta: func(id string, text string) error {
			currentAssistant.AppendReasoningContent(text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
			// Handle the Anthropic signature.
			if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
				if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
					currentAssistant.AppendReasoningSignature(reasoning.Signature)
				}
			}
			// Handle the Google thought signature.
			if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
				if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
					currentAssistant.AppendThoughtSignature(reasoning.Signature, reasoning.ToolID)
				}
			}
			// Handle the OpenAI Responses reasoning metadata.
			if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
				if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
					currentAssistant.SetReasoningResponsesData(reasoning)
				}
			}
			currentAssistant.FinishThinking()
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnTextDelta: func(id string, text string) error {
			// Strip leading newline from initial text content. This is
			// particularly important in non-interactive mode where leading
			// newlines are very visible.
			if len(currentAssistant.Parts) == 0 {
				text = strings.TrimPrefix(text, "\n")
			}

			currentAssistant.AppendContent(text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnToolInputStart: func(id string, toolName string) error {
			// Record the tool call as unfinished; its input is not known yet.
			toolCall := message.ToolCall{
				ID:               id,
				Name:             toolName,
				ProviderExecuted: false,
				Finished:         false,
			}
			currentAssistant.AddToolCall(toolCall)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnRetry: func(err *fantasy.ProviderError, delay time.Duration) {
			// TODO: implement
		},
		OnToolCall: func(tc fantasy.ToolCallContent) error {
			// Record the tool call with its input, marked as finished.
			toolCall := message.ToolCall{
				ID:               tc.ToolCallID,
				Name:             tc.ToolName,
				Input:            tc.Input,
				ProviderExecuted: false,
				Finished:         true,
			}
			currentAssistant.AddToolCall(toolCall)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnToolResult: func(result fantasy.ToolResultContent) error {
			// Each tool result is stored as its own tool-role message.
			toolResult := a.convertToToolResult(result)
			_, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
				Role: message.Tool,
				Parts: []message.ContentPart{
					toolResult,
				},
			})
			return createMsgErr
		},
		OnStepFinish: func(stepResult fantasy.StepResult) error {
			// Map the provider finish reason to the message finish reason;
			// anything not listed stays FinishReasonUnknown.
			finishReason := message.FinishReasonUnknown
			switch stepResult.FinishReason {
			case fantasy.FinishReasonLength:
				finishReason = message.FinishReasonMaxTokens
			case fantasy.FinishReasonStop:
				finishReason = message.FinishReasonEndTurn
			case fantasy.FinishReasonToolCalls:
				finishReason = message.FinishReasonToolUse
			}
			currentAssistant.AddFinish(finishReason, "", "")
			sessionLock.Lock()
			defer sessionLock.Unlock()

			// Re-read the session before applying this step's usage. Note that
			// the session is read and saved with ctx, not genCtx.
			updatedSession, getSessionErr := a.sessions.Get(ctx, call.SessionID)
			if getSessionErr != nil {
				return getSessionErr
			}
			a.updateSessionUsage(largeModel, &updatedSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
			_, sessionErr := a.sessions.Save(ctx, updatedSession)
			if sessionErr != nil {
				return sessionErr
			}
			// The context-window stop condition below reads currentSession.
			currentSession = updatedSession
			return a.messages.Update(genCtx, *currentAssistant)
		},
		StopWhen: []fantasy.StopCondition{
			// Stop when the remaining context window falls to the threshold,
			// unless auto-summarization is disabled.
			func(_ []fantasy.StepResult) bool {
				cw := int64(largeModel.CatwalkCfg.ContextWindow)
				tokens := currentSession.CompletionTokens + currentSession.PromptTokens
				remaining := cw - tokens
				var threshold int64
				if cw > largeContextWindowThreshold {
					threshold = largeContextWindowBuffer
				} else {
					threshold = int64(float64(cw) * smallContextWindowRatio)
				}
				if (remaining <= threshold) && !a.disableAutoSummarize {
					shouldSummarize = true
					return true
				}
				return false
			},
			// Stop when hasRepeatedToolCalls reports a tool-call loop.
			func(steps []fantasy.StepResult) bool {
				return hasRepeatedToolCalls(steps, loopDetectionWindowSize, loopDetectionMaxRepeats)
			},
		},
	})

	a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))

	if err != nil {
		isCancelErr := errors.Is(err, context.Canceled)
		isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
		// No assistant message was created, so there is nothing to clean up.
		if currentAssistant == nil {
			return result, err
		}
		// Ensure we finish thinking on error to close the reasoning state.
		currentAssistant.FinishThinking()
		toolCalls := currentAssistant.ToolCalls()
		// INFO: we use the parent context here because the genCtx has been cancelled.
		msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
		if createErr != nil {
			return nil, createErr
		}
		for _, tc := range toolCalls {
			// Close tool calls that never received their input, using an empty
			// JSON object as the input.
			if !tc.Finished {
				tc.Finished = true
				tc.Input = "{}"
				currentAssistant.AddToolCall(tc)
				updateErr := a.messages.Update(ctx, *currentAssistant)
				if updateErr != nil {
					return nil, updateErr
				}
			}

			// Check whether a tool result for this call was already stored.
			found := false
			for _, msg := range msgs {
				if msg.Role == message.Tool {
					for _, tr := range msg.ToolResults() {
						if tr.ToolCallID == tc.ID {
							found = true
							break
						}
					}
				}
				if found {
					break
				}
			}
			if found {
				continue
			}
			// Otherwise store an error result so the tool call does not stay
			// without one.
			content := "There was an error while executing the tool"
			if isCancelErr {
				content = "Tool execution canceled by user"
			} else if isPermissionErr {
				content = "User denied permission"
			}
			toolResult := message.ToolResult{
				ToolCallID: tc.ID,
				Name:       tc.Name,
				Content:    content,
				IsError:    true,
			}
			_, createErr = a.messages.Create(ctx, currentAssistant.SessionID, message.CreateMessageParams{
				Role: message.Tool,
				Parts: []message.ContentPart{
					toolResult,
				},
			})
			if createErr != nil {
				return nil, createErr
			}
		}
		// Record why the run ended on the assistant message, from the most
		// specific error kind to the least specific.
		var fantasyErr *fantasy.Error
		var providerErr *fantasy.ProviderError
		const defaultTitle = "Provider Error"
		linkStyle := lipgloss.NewStyle().Foreground(charmtone.Guac).Underline(true)
		if isCancelErr {
			currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
		} else if isPermissionErr {
			currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "")
		} else if errors.Is(err, hyper.ErrNoCredits) {
			url := hyper.BaseURL()
			link := linkStyle.Hyperlink(url, "id=hyper").Render(url)
			currentAssistant.AddFinish(message.FinishReasonError, "No credits", "You're out of credits. Add more at "+link)
		} else if errors.As(err, &providerErr) {
			// This exact provider message is treated as "model not enabled in
			// Copilot" and answered with a link to the Copilot settings page.
			if providerErr.Message == "The requested model is not supported." {
				url := "https://github.com/settings/copilot/features"
				link := linkStyle.Hyperlink(url, "id=copilot").Render(url)
				currentAssistant.AddFinish(
					message.FinishReasonError,
					"Copilot model not enabled",
					fmt.Sprintf("%q is not enabled in Copilot. Go to the following page to enable it. Then, wait 5 minutes before trying again. %s", largeModel.CatwalkCfg.Name, link),
				)
			} else {
				currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message)
			}
		} else if errors.As(err, &fantasyErr) {
			currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(fantasyErr.Title), defaultTitle), fantasyErr.Message)
		} else {
			currentAssistant.AddFinish(message.FinishReasonError, defaultTitle, err.Error())
		}
		// Note: we use the parent context here because the genCtx has been
		// cancelled.
		updateErr := a.messages.Update(ctx, *currentAssistant)
		if updateErr != nil {
			return nil, updateErr
		}
		return nil, err
	}

	if shouldSummarize {
		// Summarize returns ErrSessionBusy for a busy session, so this
		// request is deregistered first.
		a.activeRequests.Del(call.SessionID)
		if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
			return nil, summarizeErr
		}
		// If the agent wasn't done (the last assistant message still has tool
		// calls), re-queue the original request so work continues after the
		// summary.
		if len(currentAssistant.ToolCalls()) > 0 {
			existing, ok := a.messageQueue.Get(call.SessionID)
			if !ok {
				existing = []SessionAgentCall{}
			}
			call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
			existing = append(existing, call)
			a.messageQueue.Set(call.SessionID, existing)
		}
	}

	// Release active request before processing queued messages.
	a.activeRequests.Del(call.SessionID)
	cancel()

	queuedMessages, ok := a.messageQueue.Get(call.SessionID)
	if !ok || len(queuedMessages) == 0 {
		return result, err
	}
	// There are queued messages; restart the loop with the first one and leave
	// the rest in the queue.
	firstQueuedMessage := queuedMessages[0]
	a.messageQueue.Set(call.SessionID, queuedMessages[1:])
	return a.Run(ctx, firstQueuedMessage)
}
 563
 564func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
 565	if a.IsSessionBusy(sessionID) {
 566		return ErrSessionBusy
 567	}
 568
 569	// Copy mutable fields under lock to avoid races with SetModels.
 570	largeModel := a.largeModel.Get()
 571	systemPromptPrefix := a.systemPromptPrefix.Get()
 572
 573	currentSession, err := a.sessions.Get(ctx, sessionID)
 574	if err != nil {
 575		return fmt.Errorf("failed to get session: %w", err)
 576	}
 577	msgs, err := a.getSessionMessages(ctx, currentSession)
 578	if err != nil {
 579		return err
 580	}
 581	if len(msgs) == 0 {
 582		// Nothing to summarize.
 583		return nil
 584	}
 585
 586	aiMsgs, _ := a.preparePrompt(msgs)
 587
 588	genCtx, cancel := context.WithCancel(ctx)
 589	a.activeRequests.Set(sessionID, cancel)
 590	defer a.activeRequests.Del(sessionID)
 591	defer cancel()
 592
 593	agent := fantasy.NewAgent(largeModel.Model,
 594		fantasy.WithSystemPrompt(string(summaryPrompt)),
 595	)
 596	summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
 597		Role:             message.Assistant,
 598		Model:            largeModel.Model.Model(),
 599		Provider:         largeModel.Model.Provider(),
 600		IsSummaryMessage: true,
 601	})
 602	if err != nil {
 603		return err
 604	}
 605
 606	summaryPromptText := buildSummaryPrompt(currentSession.Todos)
 607
 608	resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
 609		Prompt:          summaryPromptText,
 610		Messages:        aiMsgs,
 611		ProviderOptions: opts,
 612		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
 613			prepared.Messages = options.Messages
 614			if systemPromptPrefix != "" {
 615				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(systemPromptPrefix)}, prepared.Messages...)
 616			}
 617			return callContext, prepared, nil
 618		},
 619		OnReasoningDelta: func(id string, text string) error {
 620			summaryMessage.AppendReasoningContent(text)
 621			return a.messages.Update(genCtx, summaryMessage)
 622		},
 623		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
 624			// Handle anthropic signature.
 625			if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
 626				if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
 627					summaryMessage.AppendReasoningSignature(signature.Signature)
 628				}
 629			}
 630			summaryMessage.FinishThinking()
 631			return a.messages.Update(genCtx, summaryMessage)
 632		},
 633		OnTextDelta: func(id, text string) error {
 634			summaryMessage.AppendContent(text)
 635			return a.messages.Update(genCtx, summaryMessage)
 636		},
 637	})
 638	if err != nil {
 639		isCancelErr := errors.Is(err, context.Canceled)
 640		if isCancelErr {
 641			// User cancelled summarize we need to remove the summary message.
 642			deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
 643			return deleteErr
 644		}
 645		return err
 646	}
 647
 648	summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
 649	err = a.messages.Update(genCtx, summaryMessage)
 650	if err != nil {
 651		return err
 652	}
 653
 654	var openrouterCost *float64
 655	for _, step := range resp.Steps {
 656		stepCost := a.openrouterCost(step.ProviderMetadata)
 657		if stepCost != nil {
 658			newCost := *stepCost
 659			if openrouterCost != nil {
 660				newCost += *openrouterCost
 661			}
 662			openrouterCost = &newCost
 663		}
 664	}
 665
 666	a.updateSessionUsage(largeModel, &currentSession, resp.TotalUsage, openrouterCost)
 667
 668	// Just in case, get just the last usage info.
 669	usage := resp.Response.Usage
 670	currentSession.SummaryMessageID = summaryMessage.ID
 671	currentSession.CompletionTokens = usage.OutputTokens
 672	currentSession.PromptTokens = 0
 673	_, err = a.sessions.Save(genCtx, currentSession)
 674	return err
 675}
 676
 677func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
 678	if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
 679		return fantasy.ProviderOptions{}
 680	}
 681	return fantasy.ProviderOptions{
 682		anthropic.Name: &anthropic.ProviderCacheControlOptions{
 683			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
 684		},
 685		bedrock.Name: &anthropic.ProviderCacheControlOptions{
 686			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
 687		},
 688		vercel.Name: &anthropic.ProviderCacheControlOptions{
 689			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
 690		},
 691	}
 692}
 693
 694func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
 695	parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
 696	var attachmentParts []message.ContentPart
 697	for _, attachment := range call.Attachments {
 698		attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
 699	}
 700	parts = append(parts, attachmentParts...)
 701	msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
 702		Role:  message.User,
 703		Parts: parts,
 704	})
 705	if err != nil {
 706		return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
 707	}
 708	return msg, nil
 709}
 710
 711func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
 712	var history []fantasy.Message
 713	if !a.isSubAgent {
 714		history = append(history, fantasy.NewUserMessage(
 715			fmt.Sprintf("<system_reminder>%s</system_reminder>",
 716				`This is a reminder that your todo list is currently empty. DO NOT mention this to the user explicitly because they are already aware.
 717If you are working on tasks that would benefit from a todo list please use the "todos" tool to create one.
 718If not, please feel free to ignore. Again do not mention this message to the user.`,
 719			),
 720		))
 721	}
 722	for _, m := range msgs {
 723		if len(m.Parts) == 0 {
 724			continue
 725		}
 726		// Assistant message without content or tool calls (cancelled before it
 727		// returned anything).
 728		if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
 729			continue
 730		}
 731		history = append(history, m.ToAIMessage()...)
 732	}
 733
 734	var files []fantasy.FilePart
 735	for _, attachment := range attachments {
 736		if attachment.IsText() {
 737			continue
 738		}
 739		files = append(files, fantasy.FilePart{
 740			Filename:  attachment.FileName,
 741			Data:      attachment.Content,
 742			MediaType: attachment.MimeType,
 743		})
 744	}
 745
 746	return history, files
 747}
 748
 749func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
 750	msgs, err := a.messages.List(ctx, session.ID)
 751	if err != nil {
 752		return nil, fmt.Errorf("failed to list messages: %w", err)
 753	}
 754
 755	if session.SummaryMessageID != "" {
 756		summaryMsgIndex := -1
 757		for i, msg := range msgs {
 758			if msg.ID == session.SummaryMessageID {
 759				summaryMsgIndex = i
 760				break
 761			}
 762		}
 763		if summaryMsgIndex != -1 {
 764			msgs = msgs[summaryMsgIndex:]
 765			msgs[0].Role = message.User
 766		}
 767	}
 768	return msgs, nil
 769}
 770
 771// generateTitle generates a session titled based on the initial prompt.
 772func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, userPrompt string) {
 773	if userPrompt == "" {
 774		return
 775	}
 776
 777	smallModel := a.smallModel.Get()
 778	largeModel := a.largeModel.Get()
 779	systemPromptPrefix := a.systemPromptPrefix.Get()
 780
 781	var maxOutputTokens int64 = 40
 782	if smallModel.CatwalkCfg.CanReason {
 783		maxOutputTokens = smallModel.CatwalkCfg.DefaultMaxTokens
 784	}
 785
 786	newAgent := func(m fantasy.LanguageModel, p []byte, tok int64) fantasy.Agent {
 787		return fantasy.NewAgent(m,
 788			fantasy.WithSystemPrompt(string(p)+"\n /no_think"),
 789			fantasy.WithMaxOutputTokens(tok),
 790		)
 791	}
 792
 793	streamCall := fantasy.AgentStreamCall{
 794		Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", userPrompt),
 795		PrepareStep: func(callCtx context.Context, opts fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
 796			prepared.Messages = opts.Messages
 797			if systemPromptPrefix != "" {
 798				prepared.Messages = append([]fantasy.Message{
 799					fantasy.NewSystemMessage(systemPromptPrefix),
 800				}, prepared.Messages...)
 801			}
 802			return callCtx, prepared, nil
 803		},
 804	}
 805
 806	// Use the small model to generate the title.
 807	model := smallModel
 808	agent := newAgent(model.Model, titlePrompt, maxOutputTokens)
 809	resp, err := agent.Stream(ctx, streamCall)
 810	if err == nil {
 811		// We successfully generated a title with the small model.
 812		slog.Debug("Generated title with small model")
 813	} else {
 814		// It didn't work. Let's try with the big model.
 815		slog.Error("Error generating title with small model; trying big model", "err", err)
 816		model = largeModel
 817		agent = newAgent(model.Model, titlePrompt, maxOutputTokens)
 818		resp, err = agent.Stream(ctx, streamCall)
 819		if err == nil {
 820			slog.Debug("Generated title with large model")
 821		} else {
 822			// Welp, the large model didn't work either. Use the default
 823			// session name and return.
 824			slog.Error("Error generating title with large model", "err", err)
 825			saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, defaultSessionName, 0, 0, 0)
 826			if saveErr != nil {
 827				slog.Error("Failed to save session title and usage", "error", saveErr)
 828			}
 829			return
 830		}
 831	}
 832
 833	if resp == nil {
 834		// Actually, we didn't get a response so we can't. Use the default
 835		// session name and return.
 836		slog.Error("Response is nil; can't generate title")
 837		saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, defaultSessionName, 0, 0, 0)
 838		if saveErr != nil {
 839			slog.Error("Failed to save session title and usage", "error", saveErr)
 840		}
 841		return
 842	}
 843
 844	// Clean up title.
 845	var title string
 846	title = strings.ReplaceAll(resp.Response.Content.Text(), "\n", " ")
 847
 848	// Remove thinking tags if present.
 849	title = thinkTagRegex.ReplaceAllString(title, "")
 850
 851	title = strings.TrimSpace(title)
 852	title = cmp.Or(title, defaultSessionName)
 853
 854	// Calculate usage and cost.
 855	var openrouterCost *float64
 856	for _, step := range resp.Steps {
 857		stepCost := a.openrouterCost(step.ProviderMetadata)
 858		if stepCost != nil {
 859			newCost := *stepCost
 860			if openrouterCost != nil {
 861				newCost += *openrouterCost
 862			}
 863			openrouterCost = &newCost
 864		}
 865	}
 866
 867	modelConfig := model.CatwalkCfg
 868	cost := modelConfig.CostPer1MInCached/1e6*float64(resp.TotalUsage.CacheCreationTokens) +
 869		modelConfig.CostPer1MOutCached/1e6*float64(resp.TotalUsage.CacheReadTokens) +
 870		modelConfig.CostPer1MIn/1e6*float64(resp.TotalUsage.InputTokens) +
 871		modelConfig.CostPer1MOut/1e6*float64(resp.TotalUsage.OutputTokens)
 872
 873	// Use override cost if available (e.g., from OpenRouter).
 874	if openrouterCost != nil {
 875		cost = *openrouterCost
 876	}
 877
 878	promptTokens := resp.TotalUsage.InputTokens + resp.TotalUsage.CacheCreationTokens
 879	completionTokens := resp.TotalUsage.OutputTokens
 880
 881	// Atomically update only title and usage fields to avoid overriding other
 882	// concurrent session updates.
 883	saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, title, promptTokens, completionTokens, cost)
 884	if saveErr != nil {
 885		slog.Error("Failed to save session title and usage", "error", saveErr)
 886		return
 887	}
 888}
 889
 890func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
 891	openrouterMetadata, ok := metadata[openrouter.Name]
 892	if !ok {
 893		return nil
 894	}
 895
 896	opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
 897	if !ok {
 898		return nil
 899	}
 900	return &opts.Usage.Cost
 901}
 902
 903func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
 904	modelConfig := model.CatwalkCfg
 905	cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
 906		modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
 907		modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
 908		modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
 909
 910	a.eventTokensUsed(session.ID, model, usage, cost)
 911
 912	if overrideCost != nil {
 913		session.Cost += *overrideCost
 914	} else {
 915		session.Cost += cost
 916	}
 917
 918	session.CompletionTokens = usage.OutputTokens
 919	session.PromptTokens = usage.InputTokens + usage.CacheReadTokens
 920}
 921
 922func (a *sessionAgent) Cancel(sessionID string) {
 923	// Cancel regular requests. Don't use Take() here - we need the entry to
 924	// remain in activeRequests so IsBusy() returns true until the goroutine
 925	// fully completes (including error handling that may access the DB).
 926	// The defer in processRequest will clean up the entry.
 927	if cancel, ok := a.activeRequests.Get(sessionID); ok && cancel != nil {
 928		slog.Debug("Request cancellation initiated", "session_id", sessionID)
 929		cancel()
 930	}
 931
 932	// Also check for summarize requests.
 933	if cancel, ok := a.activeRequests.Get(sessionID + "-summarize"); ok && cancel != nil {
 934		slog.Debug("Summarize cancellation initiated", "session_id", sessionID)
 935		cancel()
 936	}
 937
 938	if a.QueuedPrompts(sessionID) > 0 {
 939		slog.Debug("Clearing queued prompts", "session_id", sessionID)
 940		a.messageQueue.Del(sessionID)
 941	}
 942}
 943
 944func (a *sessionAgent) ClearQueue(sessionID string) {
 945	if a.QueuedPrompts(sessionID) > 0 {
 946		slog.Debug("Clearing queued prompts", "session_id", sessionID)
 947		a.messageQueue.Del(sessionID)
 948	}
 949}
 950
 951func (a *sessionAgent) CancelAll() {
 952	if !a.IsBusy() {
 953		return
 954	}
 955	for key := range a.activeRequests.Seq2() {
 956		a.Cancel(key) // key is sessionID
 957	}
 958
 959	timeout := time.After(5 * time.Second)
 960	for a.IsBusy() {
 961		select {
 962		case <-timeout:
 963			return
 964		default:
 965			time.Sleep(200 * time.Millisecond)
 966		}
 967	}
 968}
 969
 970func (a *sessionAgent) IsBusy() bool {
 971	var busy bool
 972	for cancelFunc := range a.activeRequests.Seq() {
 973		if cancelFunc != nil {
 974			busy = true
 975			break
 976		}
 977	}
 978	return busy
 979}
 980
 981func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
 982	_, busy := a.activeRequests.Get(sessionID)
 983	return busy
 984}
 985
 986func (a *sessionAgent) QueuedPrompts(sessionID string) int {
 987	l, ok := a.messageQueue.Get(sessionID)
 988	if !ok {
 989		return 0
 990	}
 991	return len(l)
 992}
 993
 994func (a *sessionAgent) QueuedPromptsList(sessionID string) []string {
 995	l, ok := a.messageQueue.Get(sessionID)
 996	if !ok {
 997		return nil
 998	}
 999	prompts := make([]string, len(l))
1000	for i, call := range l {
1001		prompts[i] = call.Prompt
1002	}
1003	return prompts
1004}
1005
// SetModels stores the given large and small models on the agent, replacing
// whatever was set before.
func (a *sessionAgent) SetModels(large Model, small Model) {
	a.largeModel.Set(large)
	a.smallModel.Set(small)
}
1010
// SetTools replaces the agent's tool list with the given slice.
func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
	a.tools.SetSlice(tools)
}
1014
// SetSystemPrompt replaces the agent's stored system prompt.
func (a *sessionAgent) SetSystemPrompt(systemPrompt string) {
	a.systemPrompt.Set(systemPrompt)
}
1018
// Model returns the agent's currently stored large model.
func (a *sessionAgent) Model() Model {
	return a.largeModel.Get()
}
1022
1023// convertToToolResult converts a fantasy tool result to a message tool result.
1024func (a *sessionAgent) convertToToolResult(result fantasy.ToolResultContent) message.ToolResult {
1025	baseResult := message.ToolResult{
1026		ToolCallID: result.ToolCallID,
1027		Name:       result.ToolName,
1028		Metadata:   result.ClientMetadata,
1029	}
1030
1031	switch result.Result.GetType() {
1032	case fantasy.ToolResultContentTypeText:
1033		if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result); ok {
1034			baseResult.Content = r.Text
1035		}
1036	case fantasy.ToolResultContentTypeError:
1037		if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result); ok {
1038			baseResult.Content = r.Error.Error()
1039			baseResult.IsError = true
1040		}
1041	case fantasy.ToolResultContentTypeMedia:
1042		if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](result.Result); ok {
1043			content := r.Text
1044			if content == "" {
1045				content = fmt.Sprintf("Loaded %s content", r.MediaType)
1046			}
1047			baseResult.Content = content
1048			baseResult.Data = r.Data
1049			baseResult.MIMEType = r.MediaType
1050		}
1051	}
1052
1053	return baseResult
1054}
1055
1056// workaroundProviderMediaLimitations converts media content in tool results to
1057// user messages for providers that don't natively support images in tool results.
1058//
1059// Problem: OpenAI, Google, OpenRouter, and other OpenAI-compatible providers
1060// don't support sending images/media in tool result messages - they only accept
1061// text in tool results. However, they DO support images in user messages.
1062//
1063// If we send media in tool results to these providers, the API returns an error.
1064//
1065// Solution: For these providers, we:
1066//  1. Replace the media in the tool result with a text placeholder
1067//  2. Inject a user message immediately after with the image as a file attachment
1068//  3. This maintains the tool execution flow while working around API limitations
1069//
1070// Anthropic and Bedrock support images natively in tool results, so we skip
1071// this workaround for them.
1072//
1073// Example transformation:
1074//
1075//	BEFORE: [tool result: image data]
1076//	AFTER:  [tool result: "Image loaded - see attached"], [user: image attachment]
1077func (a *sessionAgent) workaroundProviderMediaLimitations(messages []fantasy.Message, largeModel Model) []fantasy.Message {
1078	providerSupportsMedia := largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderAnthropic) ||
1079		largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderBedrock)
1080
1081	if providerSupportsMedia {
1082		return messages
1083	}
1084
1085	convertedMessages := make([]fantasy.Message, 0, len(messages))
1086
1087	for _, msg := range messages {
1088		if msg.Role != fantasy.MessageRoleTool {
1089			convertedMessages = append(convertedMessages, msg)
1090			continue
1091		}
1092
1093		textParts := make([]fantasy.MessagePart, 0, len(msg.Content))
1094		var mediaFiles []fantasy.FilePart
1095
1096		for _, part := range msg.Content {
1097			toolResult, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
1098			if !ok {
1099				textParts = append(textParts, part)
1100				continue
1101			}
1102
1103			if media, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](toolResult.Output); ok {
1104				decoded, err := base64.StdEncoding.DecodeString(media.Data)
1105				if err != nil {
1106					slog.Warn("Failed to decode media data", "error", err)
1107					textParts = append(textParts, part)
1108					continue
1109				}
1110
1111				mediaFiles = append(mediaFiles, fantasy.FilePart{
1112					Data:      decoded,
1113					MediaType: media.MediaType,
1114					Filename:  fmt.Sprintf("tool-result-%s", toolResult.ToolCallID),
1115				})
1116
1117				textParts = append(textParts, fantasy.ToolResultPart{
1118					ToolCallID: toolResult.ToolCallID,
1119					Output: fantasy.ToolResultOutputContentText{
1120						Text: "[Image/media content loaded - see attached file]",
1121					},
1122					ProviderOptions: toolResult.ProviderOptions,
1123				})
1124			} else {
1125				textParts = append(textParts, part)
1126			}
1127		}
1128
1129		convertedMessages = append(convertedMessages, fantasy.Message{
1130			Role:    fantasy.MessageRoleTool,
1131			Content: textParts,
1132		})
1133
1134		if len(mediaFiles) > 0 {
1135			convertedMessages = append(convertedMessages, fantasy.NewUserMessage(
1136				"Here is the media content from the tool result:",
1137				mediaFiles...,
1138			))
1139		}
1140	}
1141
1142	return convertedMessages
1143}
1144
1145// buildSummaryPrompt constructs the prompt text for session summarization.
1146func buildSummaryPrompt(todos []session.Todo) string {
1147	var sb strings.Builder
1148	sb.WriteString("Provide a detailed summary of our conversation above.")
1149	if len(todos) > 0 {
1150		sb.WriteString("\n\n## Current Todo List\n\n")
1151		for _, t := range todos {
1152			fmt.Fprintf(&sb, "- [%s] %s\n", t.Status, t.Content)
1153		}
1154		sb.WriteString("\nInclude these tasks and their statuses in your summary. ")
1155		sb.WriteString("Instruct the resuming assistant to use the `todos` tool to continue tracking progress on these tasks.")
1156	}
1157	return sb.String()
1158}