agent.go

   1// Package agent is the core orchestration layer for Crush AI agents.
   2//
   3// It provides session-based AI agent functionality for managing
   4// conversations, tool execution, and message handling. It coordinates
   5// interactions between language models, messages, sessions, and tools while
   6// handling features like automatic summarization, queuing, and token
   7// management.
   8package agent
   9
  10import (
  11	"cmp"
  12	"context"
  13	_ "embed"
  14	"encoding/base64"
  15	"errors"
  16	"fmt"
  17	"log/slog"
  18	"os"
  19	"regexp"
  20	"strconv"
  21	"strings"
  22	"sync"
  23	"time"
  24
  25	"charm.land/fantasy"
  26	"charm.land/fantasy/providers/anthropic"
  27	"charm.land/fantasy/providers/bedrock"
  28	"charm.land/fantasy/providers/google"
  29	"charm.land/fantasy/providers/openai"
  30	"charm.land/fantasy/providers/openrouter"
  31	"charm.land/lipgloss/v2"
  32	"github.com/charmbracelet/catwalk/pkg/catwalk"
  33	"github.com/charmbracelet/crush/internal/agent/hyper"
  34	"github.com/charmbracelet/crush/internal/agent/tools"
  35	"github.com/charmbracelet/crush/internal/config"
  36	"github.com/charmbracelet/crush/internal/csync"
  37	"github.com/charmbracelet/crush/internal/message"
  38	"github.com/charmbracelet/crush/internal/permission"
  39	"github.com/charmbracelet/crush/internal/session"
  40	"github.com/charmbracelet/crush/internal/stringext"
  41)
  42
// defaultSessionName is the title stored when title generation fails with
// both models, returns no response, or produces an empty title.
const defaultSessionName = "Untitled Session"

// titlePrompt is the system prompt for the agent built by generateTitle.
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt is the system prompt for the agent built by Summarize.
//
//go:embed templates/summary.md
var summaryPrompt []byte

// thinkTagRegex matches <think>...</think> spans (non-greedy) so they can be
// removed from generated titles. `.` does not match newlines in this pattern;
// generateTitle replaces newlines with spaces before applying it, so
// multi-line think blocks are still matched.
var thinkTagRegex = regexp.MustCompile(`<think>.*?</think>`)
  53
// SessionAgentCall is a single prompt submitted to SessionAgent.Run.
type SessionAgentCall struct {
	// SessionID identifies the target session. Run rejects an empty ID with
	// ErrSessionMissing.
	SessionID        string
	// Prompt is the user's text. Run accepts an empty Prompt only when
	// Attachments contains a text attachment; otherwise it returns
	// ErrEmptyPrompt.
	Prompt           string
	// ProviderOptions are passed to the model stream for this call and, when
	// Run decides to auto-summarize, to Summarize as well.
	ProviderOptions  fantasy.ProviderOptions
	// Attachments are stored on the user message. Text attachments are
	// folded into the prompt text; the others are sent as file parts.
	Attachments      []message.Attachment
	// MaxOutputTokens is always passed to the stream by address, so a zero
	// value is forwarded as well (NOTE(review): confirm how 0 is treated).
	MaxOutputTokens  int64
	// The sampling parameters below are forwarded unchanged; nil presumably
	// leaves the provider/model default in effect.
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
}
  66
// SessionAgent runs prompts against chat sessions, queueing prompts for
// sessions that already have an active request.
type SessionAgent interface {
	// Run executes call, or queues it and returns (nil, nil) when the session
	// is busy.
	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
	SetModels(large Model, small Model)
	SetTools(tools []fantasy.AgentTool)
	SetSystemPrompt(systemPrompt string)
	// Cancel cancels the in-flight request (and summarization) for a session.
	Cancel(sessionID string)
	CancelAll()
	// IsSessionBusy reports whether a request is active for the session; Run
	// queues new calls while it returns true.
	IsSessionBusy(sessionID string) bool
	IsBusy() bool
	QueuedPrompts(sessionID string) int
	QueuedPromptsList(sessionID string) []string
	ClearQueue(sessionID string)
	// Summarize replaces the session's history with a model-written summary.
	// Arguments: context, session ID, provider options.
	Summarize(context.Context, string, fantasy.ProviderOptions) error
	Model() Model
}
  82
// Model bundles a language model with its catalog metadata and the
// user-selected configuration it was created from.
type Model struct {
	// Model is the client used to stream completions.
	Model      fantasy.LanguageModel
	// CatwalkCfg is catalog metadata: context window, pricing, reasoning and
	// image support, default max tokens, display name.
	CatwalkCfg catwalk.Model
	// ModelCfg is the selected model configuration; Run records its Model
	// and Provider on the assistant messages it creates.
	ModelCfg   config.SelectedModel
}
  88
// sessionAgent is the SessionAgent implementation returned by NewSessionAgent.
type sessionAgent struct {
	// Mutable configuration, wrapped in csync containers; Run/Summarize copy
	// these at the start of a request to avoid racing with SetModels/SetTools.
	largeModel         *csync.Value[Model]
	smallModel         *csync.Value[Model]
	systemPromptPrefix *csync.Value[string]
	systemPrompt       *csync.Value[string]
	tools              *csync.Slice[fantasy.AgentTool]

	// isSubAgent suppresses the todo-list system reminder in preparePrompt.
	isSubAgent           bool
	sessions             session.Service
	messages             message.Service
	// disableAutoSummarize turns off Run's context-window stop condition.
	disableAutoSummarize bool
	// isYolo is copied from SessionAgentOptions.IsYolo; it is not read by the
	// code visible in this part of the file.
	isYolo               bool

	// messageQueue holds calls submitted while a session was busy, keyed by
	// session ID.
	messageQueue   *csync.Map[string, []SessionAgentCall]
	// activeRequests maps a session ID to the cancel func of its in-flight
	// request; presence of an entry is what makes a session "busy".
	activeRequests *csync.Map[string, context.CancelFunc]
}
 105
// SessionAgentOptions configures NewSessionAgent.
type SessionAgentOptions struct {
	// LargeModel handles conversations and summaries; SmallModel is tried
	// first for title generation.
	LargeModel           Model
	SmallModel           Model
	// SystemPromptPrefix, when non-empty, is prepended as an extra system
	// message to every request.
	SystemPromptPrefix   string
	SystemPrompt         string
	// IsSubAgent omits the todo-list reminder from the prompt history.
	IsSubAgent           bool
	// DisableAutoSummarize prevents Run from stopping and summarizing when
	// the context window is nearly full.
	DisableAutoSummarize bool
	IsYolo               bool
	Sessions             session.Service
	Messages             message.Service
	Tools                []fantasy.AgentTool
}
 118
 119func NewSessionAgent(
 120	opts SessionAgentOptions,
 121) SessionAgent {
 122	return &sessionAgent{
 123		largeModel:           csync.NewValue(opts.LargeModel),
 124		smallModel:           csync.NewValue(opts.SmallModel),
 125		systemPromptPrefix:   csync.NewValue(opts.SystemPromptPrefix),
 126		systemPrompt:         csync.NewValue(opts.SystemPrompt),
 127		isSubAgent:           opts.IsSubAgent,
 128		sessions:             opts.Sessions,
 129		messages:             opts.Messages,
 130		disableAutoSummarize: opts.DisableAutoSummarize,
 131		tools:                csync.NewSliceFrom(opts.Tools),
 132		isYolo:               opts.IsYolo,
 133		messageQueue:         csync.NewMap[string, []SessionAgentCall](),
 134		activeRequests:       csync.NewMap[string, context.CancelFunc](),
 135	}
 136}
 137
 138func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
 139	if call.Prompt == "" && !message.ContainsTextAttachment(call.Attachments) {
 140		return nil, ErrEmptyPrompt
 141	}
 142	if call.SessionID == "" {
 143		return nil, ErrSessionMissing
 144	}
 145
 146	// Queue the message if busy
 147	if a.IsSessionBusy(call.SessionID) {
 148		existing, ok := a.messageQueue.Get(call.SessionID)
 149		if !ok {
 150			existing = []SessionAgentCall{}
 151		}
 152		existing = append(existing, call)
 153		a.messageQueue.Set(call.SessionID, existing)
 154		return nil, nil
 155	}
 156
 157	// Copy mutable fields under lock to avoid races with SetTools/SetModels.
 158	agentTools := a.tools.Copy()
 159	largeModel := a.largeModel.Get()
 160	systemPrompt := a.systemPrompt.Get()
 161	promptPrefix := a.systemPromptPrefix.Get()
 162
 163	if len(agentTools) > 0 {
 164		// Add Anthropic caching to the last tool.
 165		agentTools[len(agentTools)-1].SetProviderOptions(a.getCacheControlOptions())
 166	}
 167
 168	agent := fantasy.NewAgent(
 169		largeModel.Model,
 170		fantasy.WithSystemPrompt(systemPrompt),
 171		fantasy.WithTools(agentTools...),
 172	)
 173
 174	sessionLock := sync.Mutex{}
 175	currentSession, err := a.sessions.Get(ctx, call.SessionID)
 176	if err != nil {
 177		return nil, fmt.Errorf("failed to get session: %w", err)
 178	}
 179
 180	msgs, err := a.getSessionMessages(ctx, currentSession)
 181	if err != nil {
 182		return nil, fmt.Errorf("failed to get session messages: %w", err)
 183	}
 184
 185	var wg sync.WaitGroup
 186	// Generate title if first message.
 187	if len(msgs) == 0 {
 188		titleCtx := ctx // Copy to avoid race with ctx reassignment below.
 189		wg.Go(func() {
 190			a.generateTitle(titleCtx, call.SessionID, call.Prompt)
 191		})
 192	}
 193	defer wg.Wait()
 194
 195	// Add the user message to the session.
 196	_, err = a.createUserMessage(ctx, call)
 197	if err != nil {
 198		return nil, err
 199	}
 200
 201	// Add the session to the context.
 202	ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)
 203
 204	genCtx, cancel := context.WithCancel(ctx)
 205	a.activeRequests.Set(call.SessionID, cancel)
 206
 207	defer cancel()
 208	defer a.activeRequests.Del(call.SessionID)
 209
 210	history, files := a.preparePrompt(msgs, call.Attachments...)
 211
 212	startTime := time.Now()
 213	a.eventPromptSent(call.SessionID)
 214
 215	var currentAssistant *message.Message
 216	var shouldSummarize bool
 217	result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
 218		Prompt:           message.PromptWithTextAttachments(call.Prompt, call.Attachments),
 219		Files:            files,
 220		Messages:         history,
 221		ProviderOptions:  call.ProviderOptions,
 222		MaxOutputTokens:  &call.MaxOutputTokens,
 223		TopP:             call.TopP,
 224		Temperature:      call.Temperature,
 225		PresencePenalty:  call.PresencePenalty,
 226		TopK:             call.TopK,
 227		FrequencyPenalty: call.FrequencyPenalty,
 228		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
 229			prepared.Messages = options.Messages
 230			for i := range prepared.Messages {
 231				prepared.Messages[i].ProviderOptions = nil
 232			}
 233
 234			queuedCalls, _ := a.messageQueue.Get(call.SessionID)
 235			a.messageQueue.Del(call.SessionID)
 236			for _, queued := range queuedCalls {
 237				userMessage, createErr := a.createUserMessage(callContext, queued)
 238				if createErr != nil {
 239					return callContext, prepared, createErr
 240				}
 241				prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
 242			}
 243
 244			prepared.Messages = a.workaroundProviderMediaLimitations(prepared.Messages, largeModel)
 245
 246			lastSystemRoleInx := 0
 247			systemMessageUpdated := false
 248			for i, msg := range prepared.Messages {
 249				// Only add cache control to the last message.
 250				if msg.Role == fantasy.MessageRoleSystem {
 251					lastSystemRoleInx = i
 252				} else if !systemMessageUpdated {
 253					prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
 254					systemMessageUpdated = true
 255				}
 256				// Than add cache control to the last 2 messages.
 257				if i > len(prepared.Messages)-3 {
 258					prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
 259				}
 260			}
 261
 262			if promptPrefix != "" {
 263				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(promptPrefix)}, prepared.Messages...)
 264			}
 265
 266			var assistantMsg message.Message
 267			assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
 268				Role:     message.Assistant,
 269				Parts:    []message.ContentPart{},
 270				Model:    largeModel.ModelCfg.Model,
 271				Provider: largeModel.ModelCfg.Provider,
 272			})
 273			if err != nil {
 274				return callContext, prepared, err
 275			}
 276			callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
 277			callContext = context.WithValue(callContext, tools.SupportsImagesContextKey, largeModel.CatwalkCfg.SupportsImages)
 278			callContext = context.WithValue(callContext, tools.ModelNameContextKey, largeModel.CatwalkCfg.Name)
 279			currentAssistant = &assistantMsg
 280			return callContext, prepared, err
 281		},
 282		OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
 283			currentAssistant.AppendReasoningContent(reasoning.Text)
 284			return a.messages.Update(genCtx, *currentAssistant)
 285		},
 286		OnReasoningDelta: func(id string, text string) error {
 287			currentAssistant.AppendReasoningContent(text)
 288			return a.messages.Update(genCtx, *currentAssistant)
 289		},
 290		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
 291			// handle anthropic signature
 292			if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
 293				if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
 294					currentAssistant.AppendReasoningSignature(reasoning.Signature)
 295				}
 296			}
 297			if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
 298				if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
 299					currentAssistant.AppendThoughtSignature(reasoning.Signature, reasoning.ToolID)
 300				}
 301			}
 302			if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
 303				if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
 304					currentAssistant.SetReasoningResponsesData(reasoning)
 305				}
 306			}
 307			currentAssistant.FinishThinking()
 308			return a.messages.Update(genCtx, *currentAssistant)
 309		},
 310		OnTextDelta: func(id string, text string) error {
 311			// Strip leading newline from initial text content. This is is
 312			// particularly important in non-interactive mode where leading
 313			// newlines are very visible.
 314			if len(currentAssistant.Parts) == 0 {
 315				text = strings.TrimPrefix(text, "\n")
 316			}
 317
 318			currentAssistant.AppendContent(text)
 319			return a.messages.Update(genCtx, *currentAssistant)
 320		},
 321		OnToolInputStart: func(id string, toolName string) error {
 322			toolCall := message.ToolCall{
 323				ID:               id,
 324				Name:             toolName,
 325				ProviderExecuted: false,
 326				Finished:         false,
 327			}
 328			currentAssistant.AddToolCall(toolCall)
 329			return a.messages.Update(genCtx, *currentAssistant)
 330		},
 331		OnRetry: func(err *fantasy.ProviderError, delay time.Duration) {
 332			// TODO: implement
 333		},
 334		OnToolCall: func(tc fantasy.ToolCallContent) error {
 335			toolCall := message.ToolCall{
 336				ID:               tc.ToolCallID,
 337				Name:             tc.ToolName,
 338				Input:            tc.Input,
 339				ProviderExecuted: false,
 340				Finished:         true,
 341			}
 342			currentAssistant.AddToolCall(toolCall)
 343			return a.messages.Update(genCtx, *currentAssistant)
 344		},
 345		OnToolResult: func(result fantasy.ToolResultContent) error {
 346			toolResult := a.convertToToolResult(result)
 347			_, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
 348				Role: message.Tool,
 349				Parts: []message.ContentPart{
 350					toolResult,
 351				},
 352			})
 353			return createMsgErr
 354		},
 355		OnStepFinish: func(stepResult fantasy.StepResult) error {
 356			finishReason := message.FinishReasonUnknown
 357			switch stepResult.FinishReason {
 358			case fantasy.FinishReasonLength:
 359				finishReason = message.FinishReasonMaxTokens
 360			case fantasy.FinishReasonStop:
 361				finishReason = message.FinishReasonEndTurn
 362			case fantasy.FinishReasonToolCalls:
 363				finishReason = message.FinishReasonToolUse
 364			}
 365			currentAssistant.AddFinish(finishReason, "", "")
 366			sessionLock.Lock()
 367			updatedSession, getSessionErr := a.sessions.Get(genCtx, call.SessionID)
 368			if getSessionErr != nil {
 369				sessionLock.Unlock()
 370				return getSessionErr
 371			}
 372			a.updateSessionUsage(largeModel, &updatedSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
 373			_, sessionErr := a.sessions.Save(genCtx, updatedSession)
 374			sessionLock.Unlock()
 375			if sessionErr != nil {
 376				return sessionErr
 377			}
 378			return a.messages.Update(genCtx, *currentAssistant)
 379		},
 380		StopWhen: []fantasy.StopCondition{
 381			func(_ []fantasy.StepResult) bool {
 382				cw := int64(largeModel.CatwalkCfg.ContextWindow)
 383				tokens := currentSession.CompletionTokens + currentSession.PromptTokens
 384				remaining := cw - tokens
 385				var threshold int64
 386				if cw > 200_000 {
 387					threshold = 20_000
 388				} else {
 389					threshold = int64(float64(cw) * 0.2)
 390				}
 391				if (remaining <= threshold) && !a.disableAutoSummarize {
 392					shouldSummarize = true
 393					return true
 394				}
 395				return false
 396			},
 397		},
 398	})
 399
 400	a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))
 401
 402	if err != nil {
 403		isCancelErr := errors.Is(err, context.Canceled)
 404		isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
 405		if currentAssistant == nil {
 406			return result, err
 407		}
 408		// Ensure we finish thinking on error to close the reasoning state.
 409		currentAssistant.FinishThinking()
 410		toolCalls := currentAssistant.ToolCalls()
 411		// INFO: we use the parent context here because the genCtx has been cancelled.
 412		msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
 413		if createErr != nil {
 414			return nil, createErr
 415		}
 416		for _, tc := range toolCalls {
 417			if !tc.Finished {
 418				tc.Finished = true
 419				tc.Input = "{}"
 420				currentAssistant.AddToolCall(tc)
 421				updateErr := a.messages.Update(ctx, *currentAssistant)
 422				if updateErr != nil {
 423					return nil, updateErr
 424				}
 425			}
 426
 427			found := false
 428			for _, msg := range msgs {
 429				if msg.Role == message.Tool {
 430					for _, tr := range msg.ToolResults() {
 431						if tr.ToolCallID == tc.ID {
 432							found = true
 433							break
 434						}
 435					}
 436				}
 437				if found {
 438					break
 439				}
 440			}
 441			if found {
 442				continue
 443			}
 444			content := "There was an error while executing the tool"
 445			if isCancelErr {
 446				content = "Tool execution canceled by user"
 447			} else if isPermissionErr {
 448				content = "User denied permission"
 449			}
 450			toolResult := message.ToolResult{
 451				ToolCallID: tc.ID,
 452				Name:       tc.Name,
 453				Content:    content,
 454				IsError:    true,
 455			}
 456			_, createErr = a.messages.Create(ctx, currentAssistant.SessionID, message.CreateMessageParams{
 457				Role: message.Tool,
 458				Parts: []message.ContentPart{
 459					toolResult,
 460				},
 461			})
 462			if createErr != nil {
 463				return nil, createErr
 464			}
 465		}
 466		var fantasyErr *fantasy.Error
 467		var providerErr *fantasy.ProviderError
 468		const defaultTitle = "Provider Error"
 469		if isCancelErr {
 470			currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
 471		} else if isPermissionErr {
 472			currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "")
 473		} else if errors.Is(err, hyper.ErrNoCredits) {
 474			url := hyper.BaseURL()
 475			link := lipgloss.NewStyle().Hyperlink(url, "id=hyper").Render(url)
 476			currentAssistant.AddFinish(message.FinishReasonError, "No credits", "You're out of credits. Add more at "+link)
 477		} else if errors.As(err, &providerErr) {
 478			if providerErr.Message == "The requested model is not supported." {
 479				url := "https://github.com/settings/copilot/features"
 480				link := lipgloss.NewStyle().Hyperlink(url, "id=hyper").Render(url)
 481				currentAssistant.AddFinish(
 482					message.FinishReasonError,
 483					"Copilot model not enabled",
 484					fmt.Sprintf("%q is not enabled in Copilot. Go to the following page to enable it. Then, wait 5 minutes before trying again. %s", largeModel.CatwalkCfg.Name, link),
 485				)
 486			} else {
 487				currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message)
 488			}
 489		} else if errors.As(err, &fantasyErr) {
 490			currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(fantasyErr.Title), defaultTitle), fantasyErr.Message)
 491		} else {
 492			currentAssistant.AddFinish(message.FinishReasonError, defaultTitle, err.Error())
 493		}
 494		// Note: we use the parent context here because the genCtx has been
 495		// cancelled.
 496		updateErr := a.messages.Update(ctx, *currentAssistant)
 497		if updateErr != nil {
 498			return nil, updateErr
 499		}
 500		return nil, err
 501	}
 502
 503	if shouldSummarize {
 504		a.activeRequests.Del(call.SessionID)
 505		if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
 506			return nil, summarizeErr
 507		}
 508		// If the agent wasn't done...
 509		if len(currentAssistant.ToolCalls()) > 0 {
 510			existing, ok := a.messageQueue.Get(call.SessionID)
 511			if !ok {
 512				existing = []SessionAgentCall{}
 513			}
 514			call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
 515			existing = append(existing, call)
 516			a.messageQueue.Set(call.SessionID, existing)
 517		}
 518	}
 519
 520	// Release active request before processing queued messages.
 521	a.activeRequests.Del(call.SessionID)
 522	cancel()
 523
 524	queuedMessages, ok := a.messageQueue.Get(call.SessionID)
 525	if !ok || len(queuedMessages) == 0 {
 526		return result, err
 527	}
 528	// There are queued messages restart the loop.
 529	firstQueuedMessage := queuedMessages[0]
 530	a.messageQueue.Set(call.SessionID, queuedMessages[1:])
 531	return a.Run(ctx, firstQueuedMessage)
 532}
 533
 534func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
 535	if a.IsSessionBusy(sessionID) {
 536		return ErrSessionBusy
 537	}
 538
 539	// Copy mutable fields under lock to avoid races with SetModels.
 540	largeModel := a.largeModel.Get()
 541	systemPromptPrefix := a.systemPromptPrefix.Get()
 542
 543	currentSession, err := a.sessions.Get(ctx, sessionID)
 544	if err != nil {
 545		return fmt.Errorf("failed to get session: %w", err)
 546	}
 547	msgs, err := a.getSessionMessages(ctx, currentSession)
 548	if err != nil {
 549		return err
 550	}
 551	if len(msgs) == 0 {
 552		// Nothing to summarize.
 553		return nil
 554	}
 555
 556	aiMsgs, _ := a.preparePrompt(msgs)
 557
 558	genCtx, cancel := context.WithCancel(ctx)
 559	a.activeRequests.Set(sessionID, cancel)
 560	defer a.activeRequests.Del(sessionID)
 561	defer cancel()
 562
 563	agent := fantasy.NewAgent(largeModel.Model,
 564		fantasy.WithSystemPrompt(string(summaryPrompt)),
 565	)
 566	summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
 567		Role:             message.Assistant,
 568		Model:            largeModel.Model.Model(),
 569		Provider:         largeModel.Model.Provider(),
 570		IsSummaryMessage: true,
 571	})
 572	if err != nil {
 573		return err
 574	}
 575
 576	summaryPromptText := buildSummaryPrompt(currentSession.Todos)
 577
 578	resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
 579		Prompt:          summaryPromptText,
 580		Messages:        aiMsgs,
 581		ProviderOptions: opts,
 582		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
 583			prepared.Messages = options.Messages
 584			if systemPromptPrefix != "" {
 585				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(systemPromptPrefix)}, prepared.Messages...)
 586			}
 587			return callContext, prepared, nil
 588		},
 589		OnReasoningDelta: func(id string, text string) error {
 590			summaryMessage.AppendReasoningContent(text)
 591			return a.messages.Update(genCtx, summaryMessage)
 592		},
 593		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
 594			// Handle anthropic signature.
 595			if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
 596				if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
 597					summaryMessage.AppendReasoningSignature(signature.Signature)
 598				}
 599			}
 600			summaryMessage.FinishThinking()
 601			return a.messages.Update(genCtx, summaryMessage)
 602		},
 603		OnTextDelta: func(id, text string) error {
 604			summaryMessage.AppendContent(text)
 605			return a.messages.Update(genCtx, summaryMessage)
 606		},
 607	})
 608	if err != nil {
 609		isCancelErr := errors.Is(err, context.Canceled)
 610		if isCancelErr {
 611			// User cancelled summarize we need to remove the summary message.
 612			deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
 613			return deleteErr
 614		}
 615		return err
 616	}
 617
 618	summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
 619	err = a.messages.Update(genCtx, summaryMessage)
 620	if err != nil {
 621		return err
 622	}
 623
 624	var openrouterCost *float64
 625	for _, step := range resp.Steps {
 626		stepCost := a.openrouterCost(step.ProviderMetadata)
 627		if stepCost != nil {
 628			newCost := *stepCost
 629			if openrouterCost != nil {
 630				newCost += *openrouterCost
 631			}
 632			openrouterCost = &newCost
 633		}
 634	}
 635
 636	a.updateSessionUsage(largeModel, &currentSession, resp.TotalUsage, openrouterCost)
 637
 638	// Just in case, get just the last usage info.
 639	usage := resp.Response.Usage
 640	currentSession.SummaryMessageID = summaryMessage.ID
 641	currentSession.CompletionTokens = usage.OutputTokens
 642	currentSession.PromptTokens = 0
 643	_, err = a.sessions.Save(genCtx, currentSession)
 644	return err
 645}
 646
 647func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
 648	if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
 649		return fantasy.ProviderOptions{}
 650	}
 651	return fantasy.ProviderOptions{
 652		anthropic.Name: &anthropic.ProviderCacheControlOptions{
 653			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
 654		},
 655		bedrock.Name: &anthropic.ProviderCacheControlOptions{
 656			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
 657		},
 658	}
 659}
 660
 661func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
 662	parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
 663	var attachmentParts []message.ContentPart
 664	for _, attachment := range call.Attachments {
 665		attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
 666	}
 667	parts = append(parts, attachmentParts...)
 668	msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
 669		Role:  message.User,
 670		Parts: parts,
 671	})
 672	if err != nil {
 673		return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
 674	}
 675	return msg, nil
 676}
 677
 678func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
 679	var history []fantasy.Message
 680	if !a.isSubAgent {
 681		history = append(history, fantasy.NewUserMessage(
 682			fmt.Sprintf("<system_reminder>%s</system_reminder>",
 683				`This is a reminder that your todo list is currently empty. DO NOT mention this to the user explicitly because they are already aware.
 684If you are working on tasks that would benefit from a todo list please use the "todos" tool to create one.
 685If not, please feel free to ignore. Again do not mention this message to the user.`,
 686			),
 687		))
 688	}
 689	for _, m := range msgs {
 690		if len(m.Parts) == 0 {
 691			continue
 692		}
 693		// Assistant message without content or tool calls (cancelled before it
 694		// returned anything).
 695		if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
 696			continue
 697		}
 698		history = append(history, m.ToAIMessage()...)
 699	}
 700
 701	var files []fantasy.FilePart
 702	for _, attachment := range attachments {
 703		if attachment.IsText() {
 704			continue
 705		}
 706		files = append(files, fantasy.FilePart{
 707			Filename:  attachment.FileName,
 708			Data:      attachment.Content,
 709			MediaType: attachment.MimeType,
 710		})
 711	}
 712
 713	return history, files
 714}
 715
 716func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
 717	msgs, err := a.messages.List(ctx, session.ID)
 718	if err != nil {
 719		return nil, fmt.Errorf("failed to list messages: %w", err)
 720	}
 721
 722	if session.SummaryMessageID != "" {
 723		summaryMsgInex := -1
 724		for i, msg := range msgs {
 725			if msg.ID == session.SummaryMessageID {
 726				summaryMsgInex = i
 727				break
 728			}
 729		}
 730		if summaryMsgInex != -1 {
 731			msgs = msgs[summaryMsgInex:]
 732			msgs[0].Role = message.User
 733		}
 734	}
 735	return msgs, nil
 736}
 737
 738// generateTitle generates a session titled based on the initial prompt.
 739func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, userPrompt string) {
 740	if userPrompt == "" {
 741		return
 742	}
 743
 744	smallModel := a.smallModel.Get()
 745	largeModel := a.largeModel.Get()
 746	systemPromptPrefix := a.systemPromptPrefix.Get()
 747
 748	var maxOutputTokens int64 = 40
 749	if smallModel.CatwalkCfg.CanReason {
 750		maxOutputTokens = smallModel.CatwalkCfg.DefaultMaxTokens
 751	}
 752
 753	newAgent := func(m fantasy.LanguageModel, p []byte, tok int64) fantasy.Agent {
 754		return fantasy.NewAgent(m,
 755			fantasy.WithSystemPrompt(string(p)+"\n /no_think"),
 756			fantasy.WithMaxOutputTokens(tok),
 757		)
 758	}
 759
 760	streamCall := fantasy.AgentStreamCall{
 761		Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", userPrompt),
 762		PrepareStep: func(callCtx context.Context, opts fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
 763			prepared.Messages = opts.Messages
 764			if systemPromptPrefix != "" {
 765				prepared.Messages = append([]fantasy.Message{
 766					fantasy.NewSystemMessage(systemPromptPrefix),
 767				}, prepared.Messages...)
 768			}
 769			return callCtx, prepared, nil
 770		},
 771	}
 772
 773	// Use the small model to generate the title.
 774	model := smallModel
 775	agent := newAgent(model.Model, titlePrompt, maxOutputTokens)
 776	resp, err := agent.Stream(ctx, streamCall)
 777	if err == nil {
 778		// We successfully generated a title with the small model.
 779		slog.Info("generated title with small model")
 780	} else {
 781		// It didn't work. Let's try with the big model.
 782		slog.Error("error generating title with small model; trying big model", "err", err)
 783		model = largeModel
 784		agent = newAgent(model.Model, titlePrompt, maxOutputTokens)
 785		resp, err = agent.Stream(ctx, streamCall)
 786		if err == nil {
 787			slog.Info("generated title with large model")
 788		} else {
 789			// Welp, the large model didn't work either. Use the default
 790			// session name and return.
 791			slog.Error("error generating title with large model", "err", err)
 792			saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, defaultSessionName, 0, 0, 0)
 793			if saveErr != nil {
 794				slog.Error("failed to save session title and usage", "error", saveErr)
 795			}
 796			return
 797		}
 798	}
 799
 800	if resp == nil {
 801		// Actually, we didn't get a response so we can't. Use the default
 802		// session name and return.
 803		slog.Error("response is nil; can't generate title")
 804		saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, defaultSessionName, 0, 0, 0)
 805		if saveErr != nil {
 806			slog.Error("failed to save session title and usage", "error", saveErr)
 807		}
 808		return
 809	}
 810
 811	// Clean up title.
 812	var title string
 813	title = strings.ReplaceAll(resp.Response.Content.Text(), "\n", " ")
 814	slog.Info("generated title", "title", title)
 815
 816	// Remove thinking tags if present.
 817	title = thinkTagRegex.ReplaceAllString(title, "")
 818
 819	title = strings.TrimSpace(title)
 820	if title == "" {
 821		slog.Warn("empty title; using fallback")
 822		title = defaultSessionName
 823	}
 824
 825	// Calculate usage and cost.
 826	var openrouterCost *float64
 827	for _, step := range resp.Steps {
 828		stepCost := a.openrouterCost(step.ProviderMetadata)
 829		if stepCost != nil {
 830			newCost := *stepCost
 831			if openrouterCost != nil {
 832				newCost += *openrouterCost
 833			}
 834			openrouterCost = &newCost
 835		}
 836	}
 837
 838	modelConfig := model.CatwalkCfg
 839	cost := modelConfig.CostPer1MInCached/1e6*float64(resp.TotalUsage.CacheCreationTokens) +
 840		modelConfig.CostPer1MOutCached/1e6*float64(resp.TotalUsage.CacheReadTokens) +
 841		modelConfig.CostPer1MIn/1e6*float64(resp.TotalUsage.InputTokens) +
 842		modelConfig.CostPer1MOut/1e6*float64(resp.TotalUsage.OutputTokens)
 843
 844	// Use override cost if available (e.g., from OpenRouter).
 845	if openrouterCost != nil {
 846		cost = *openrouterCost
 847	}
 848
 849	promptTokens := resp.TotalUsage.InputTokens + resp.TotalUsage.CacheCreationTokens
 850	completionTokens := resp.TotalUsage.OutputTokens + resp.TotalUsage.CacheReadTokens
 851
 852	// Atomically update only title and usage fields to avoid overriding other
 853	// concurrent session updates.
 854	saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, title, promptTokens, completionTokens, cost)
 855	if saveErr != nil {
 856		slog.Error("failed to save session title and usage", "error", saveErr)
 857		return
 858	}
 859}
 860
 861func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
 862	openrouterMetadata, ok := metadata[openrouter.Name]
 863	if !ok {
 864		return nil
 865	}
 866
 867	opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
 868	if !ok {
 869		return nil
 870	}
 871	return &opts.Usage.Cost
 872}
 873
 874func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
 875	modelConfig := model.CatwalkCfg
 876	cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
 877		modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
 878		modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
 879		modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
 880
 881	a.eventTokensUsed(session.ID, model, usage, cost)
 882
 883	if overrideCost != nil {
 884		session.Cost += *overrideCost
 885	} else {
 886		session.Cost += cost
 887	}
 888
 889	session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
 890	session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
 891}
 892
 893func (a *sessionAgent) Cancel(sessionID string) {
 894	// Cancel regular requests. Don't use Take() here - we need the entry to
 895	// remain in activeRequests so IsBusy() returns true until the goroutine
 896	// fully completes (including error handling that may access the DB).
 897	// The defer in processRequest will clean up the entry.
 898	if cancel, ok := a.activeRequests.Get(sessionID); ok && cancel != nil {
 899		slog.Info("Request cancellation initiated", "session_id", sessionID)
 900		cancel()
 901	}
 902
 903	// Also check for summarize requests.
 904	if cancel, ok := a.activeRequests.Get(sessionID + "-summarize"); ok && cancel != nil {
 905		slog.Info("Summarize cancellation initiated", "session_id", sessionID)
 906		cancel()
 907	}
 908
 909	if a.QueuedPrompts(sessionID) > 0 {
 910		slog.Info("Clearing queued prompts", "session_id", sessionID)
 911		a.messageQueue.Del(sessionID)
 912	}
 913}
 914
 915func (a *sessionAgent) ClearQueue(sessionID string) {
 916	if a.QueuedPrompts(sessionID) > 0 {
 917		slog.Info("Clearing queued prompts", "session_id", sessionID)
 918		a.messageQueue.Del(sessionID)
 919	}
 920}
 921
 922func (a *sessionAgent) CancelAll() {
 923	if !a.IsBusy() {
 924		return
 925	}
 926	for key := range a.activeRequests.Seq2() {
 927		a.Cancel(key) // key is sessionID
 928	}
 929
 930	timeout := time.After(5 * time.Second)
 931	for a.IsBusy() {
 932		select {
 933		case <-timeout:
 934			return
 935		default:
 936			time.Sleep(200 * time.Millisecond)
 937		}
 938	}
 939}
 940
 941func (a *sessionAgent) IsBusy() bool {
 942	var busy bool
 943	for cancelFunc := range a.activeRequests.Seq() {
 944		if cancelFunc != nil {
 945			busy = true
 946			break
 947		}
 948	}
 949	return busy
 950}
 951
 952func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
 953	_, busy := a.activeRequests.Get(sessionID)
 954	return busy
 955}
 956
 957func (a *sessionAgent) QueuedPrompts(sessionID string) int {
 958	l, ok := a.messageQueue.Get(sessionID)
 959	if !ok {
 960		return 0
 961	}
 962	return len(l)
 963}
 964
 965func (a *sessionAgent) QueuedPromptsList(sessionID string) []string {
 966	l, ok := a.messageQueue.Get(sessionID)
 967	if !ok {
 968		return nil
 969	}
 970	prompts := make([]string, len(l))
 971	for i, call := range l {
 972		prompts[i] = call.Prompt
 973	}
 974	return prompts
 975}
 976
 977func (a *sessionAgent) SetModels(large Model, small Model) {
 978	a.largeModel.Set(large)
 979	a.smallModel.Set(small)
 980}
 981
 982func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
 983	a.tools.SetSlice(tools)
 984}
 985
 986func (a *sessionAgent) SetSystemPrompt(systemPrompt string) {
 987	a.systemPrompt.Set(systemPrompt)
 988}
 989
 990func (a *sessionAgent) Model() Model {
 991	return a.largeModel.Get()
 992}
 993
 994// convertToToolResult converts a fantasy tool result to a message tool result.
 995func (a *sessionAgent) convertToToolResult(result fantasy.ToolResultContent) message.ToolResult {
 996	baseResult := message.ToolResult{
 997		ToolCallID: result.ToolCallID,
 998		Name:       result.ToolName,
 999		Metadata:   result.ClientMetadata,
1000	}
1001
1002	switch result.Result.GetType() {
1003	case fantasy.ToolResultContentTypeText:
1004		if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result); ok {
1005			baseResult.Content = r.Text
1006		}
1007	case fantasy.ToolResultContentTypeError:
1008		if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result); ok {
1009			baseResult.Content = r.Error.Error()
1010			baseResult.IsError = true
1011		}
1012	case fantasy.ToolResultContentTypeMedia:
1013		if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](result.Result); ok {
1014			content := r.Text
1015			if content == "" {
1016				content = fmt.Sprintf("Loaded %s content", r.MediaType)
1017			}
1018			baseResult.Content = content
1019			baseResult.Data = r.Data
1020			baseResult.MIMEType = r.MediaType
1021		}
1022	}
1023
1024	return baseResult
1025}
1026
// workaroundProviderMediaLimitations converts media content in tool results to
// user messages for providers that don't natively support images in tool results.
//
// Problem: OpenAI, Google, OpenRouter, and other OpenAI-compatible providers
// don't support sending images/media in tool result messages - they only accept
// text in tool results. However, they DO support images in user messages.
//
// If we send media in tool results to these providers, the API returns an error.
//
// Solution: For these providers, we:
//  1. Replace the media in the tool result with a text placeholder
//  2. Inject a user message immediately after with the image as a file attachment
//  3. This maintains the tool execution flow while working around API limitations
//
// Anthropic and Bedrock support images natively in tool results, so we skip
// this workaround for them.
//
// A media part whose base64 data cannot be decoded is logged and kept
// unchanged. The rewritten tool message is a copy of the original with only
// its Content replaced, so any other message-level fields (such as provider
// options) are preserved.
//
// Example transformation:
//
//	BEFORE: [tool result: image data]
//	AFTER:  [tool result: "Image loaded - see attached"], [user: image attachment]
func (a *sessionAgent) workaroundProviderMediaLimitations(messages []fantasy.Message, largeModel Model) []fantasy.Message {
	providerSupportsMedia := largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderAnthropic) ||
		largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderBedrock)

	if providerSupportsMedia {
		return messages
	}

	convertedMessages := make([]fantasy.Message, 0, len(messages))

	for _, msg := range messages {
		if msg.Role != fantasy.MessageRoleTool {
			convertedMessages = append(convertedMessages, msg)
			continue
		}

		textParts := make([]fantasy.MessagePart, 0, len(msg.Content))
		var mediaFiles []fantasy.FilePart

		for _, part := range msg.Content {
			toolResult, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
			if !ok {
				textParts = append(textParts, part)
				continue
			}

			media, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](toolResult.Output)
			if !ok {
				textParts = append(textParts, part)
				continue
			}

			decoded, err := base64.StdEncoding.DecodeString(media.Data)
			if err != nil {
				// Keep the part as-is rather than dropping it.
				slog.Warn("failed to decode media data", "error", err)
				textParts = append(textParts, part)
				continue
			}

			mediaFiles = append(mediaFiles, fantasy.FilePart{
				Data:      decoded,
				MediaType: media.MediaType,
				Filename:  fmt.Sprintf("tool-result-%s", toolResult.ToolCallID),
			})

			textParts = append(textParts, fantasy.ToolResultPart{
				ToolCallID: toolResult.ToolCallID,
				Output: fantasy.ToolResultOutputContentText{
					Text: "[Image/media content loaded - see attached file]",
				},
				ProviderOptions: toolResult.ProviderOptions,
			})
		}

		// Copy the message instead of building a fresh one. Building a new
		// fantasy.Message{Role, Content} would silently drop any other
		// message-level fields (e.g. ProviderOptions).
		rewritten := msg
		rewritten.Content = textParts
		convertedMessages = append(convertedMessages, rewritten)

		if len(mediaFiles) > 0 {
			convertedMessages = append(convertedMessages, fantasy.NewUserMessage(
				"Here is the media content from the tool result:",
				mediaFiles...,
			))
		}
	}

	return convertedMessages
}
1115
// buildSummaryPrompt returns the instruction text used to ask the model for a
// session summary.
//
// When todos is non-empty, the prompt also contains:
//   - a "Current Todo List" section with one "- [status] content" line per
//     todo, and
//   - instructions to carry those tasks over via the `todos` tool.
func buildSummaryPrompt(todos []session.Todo) string {
	const intro = "Provide a detailed summary of our conversation above."
	if len(todos) == 0 {
		return intro
	}

	var out strings.Builder
	out.WriteString(intro)
	out.WriteString("\n\n## Current Todo List\n\n")
	for i := range todos {
		fmt.Fprintf(&out, "- [%s] %s\n", todos[i].Status, todos[i].Content)
	}
	out.WriteString("\nInclude these tasks and their statuses in your summary. ")
	out.WriteString("Instruct the resuming assistant to use the `todos` tool to continue tracking progress on these tasks.")
	return out.String()
}