agent.go

   1// Package agent is the core orchestration layer for Crush AI agents.
   2//
   3// It provides session-based AI agent functionality for managing
   4// conversations, tool execution, and message handling. It coordinates
   5// interactions between language models, messages, sessions, and tools while
   6// handling features like automatic summarization, queuing, and token
   7// management.
   8package agent
   9
  10import (
  11	"cmp"
  12	"context"
  13	_ "embed"
  14	"encoding/base64"
  15	"errors"
  16	"fmt"
  17	"log/slog"
  18	"os"
  19	"regexp"
  20	"strconv"
  21	"strings"
  22	"sync"
  23	"time"
  24
  25	"charm.land/fantasy"
  26	"charm.land/fantasy/providers/anthropic"
  27	"charm.land/fantasy/providers/bedrock"
  28	"charm.land/fantasy/providers/google"
  29	"charm.land/fantasy/providers/openai"
  30	"charm.land/fantasy/providers/openrouter"
  31	"charm.land/lipgloss/v2"
  32	"github.com/charmbracelet/catwalk/pkg/catwalk"
  33	"github.com/charmbracelet/crush/internal/agent/hyper"
  34	"github.com/charmbracelet/crush/internal/agent/tools"
  35	"github.com/charmbracelet/crush/internal/config"
  36	"github.com/charmbracelet/crush/internal/csync"
  37	"github.com/charmbracelet/crush/internal/message"
  38	"github.com/charmbracelet/crush/internal/permission"
  39	"github.com/charmbracelet/crush/internal/session"
  40	"github.com/charmbracelet/crush/internal/stringext"
  41)
  42
// titlePrompt is the system prompt used by generateTitle, embedded at build
// time from templates/title.md.
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt is the system prompt used by Summarize, embedded at build
// time from templates/summary.md.
//
//go:embed templates/summary.md
var summaryPrompt []byte

// thinkTagRegex matches a <think>...</think> block (non-greedy) so it can be
// stripped from generated titles. Without the (?s) flag `.` does not match
// newlines, which is why generateTitle flattens newlines before applying it.
var thinkTagRegex = regexp.MustCompile(`<think>.*?</think>`)
  51
// SessionAgentCall describes a single prompt to run through a SessionAgent.
//
// SessionID and Prompt are required: Run rejects an empty Prompt with
// ErrEmptyPrompt and an empty SessionID with ErrSessionMissing.
//
// Attachments are stored as binary parts on the persisted user message.
// Non-text attachments are also sent to the model as file parts, and text
// attachments are folded into the prompt via
// message.PromptWithTextAttachments.
//
// ProviderOptions, MaxOutputTokens and the optional sampling parameters
// (Temperature, TopP, TopK, FrequencyPenalty, PresencePenalty) are forwarded
// as-is to fantasy.AgentStreamCall. ProviderOptions is also reused if Run
// triggers an automatic summarization.
type SessionAgentCall struct {
	SessionID        string
	Prompt           string
	ProviderOptions  fantasy.ProviderOptions
	Attachments      []message.Attachment
	MaxOutputTokens  int64
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
}
  64
// SessionAgent runs prompts for chat sessions against a language model and
// manages each session's in-flight request and queue of pending prompts.
type SessionAgent interface {
	// Run sends a prompt for the call's session, or queues it (returning
	// nil, nil) when that session already has an active request.
	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
	// SetModels replaces the large (conversation) and small (title) models.
	SetModels(large Model, small Model)
	// SetTools replaces the tools offered to the model.
	SetTools(tools []fantasy.AgentTool)
	// Cancel aborts the session's active request(s) and clears its queue.
	Cancel(sessionID string)
	// CancelAll cancels every active request and waits briefly for them to
	// finish.
	CancelAll()
	// IsSessionBusy reports whether the session has an active request.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any session has an active request.
	IsBusy() bool
	// QueuedPrompts returns how many prompts are queued for the session.
	QueuedPrompts(sessionID string) int
	// QueuedPromptsList returns the text of the prompts queued for the
	// session.
	QueuedPromptsList(sessionID string) []string
	// ClearQueue discards the prompts queued for the session.
	ClearQueue(sessionID string)
	// Summarize condenses the session's conversation into a summary message
	// that subsequent requests start from.
	Summarize(context.Context, string, fantasy.ProviderOptions) error
	// Model returns the large model.
	Model() Model
}
  79
// Model pairs a fantasy language model client (Model) with its catwalk
// catalog entry (CatwalkCfg: context window, per-token pricing and capability
// flags such as SupportsImages and CanReason) and the configured model
// selection (ModelCfg: the selected model and provider IDs).
type Model struct {
	Model      fantasy.LanguageModel
	CatwalkCfg catwalk.Model
	ModelCfg   config.SelectedModel
}
  85
// sessionAgent is the SessionAgent implementation.
//
// Fields:
//   - largeModel runs conversation turns and summaries; smallModel generates
//     session titles.
//   - systemPrompt is the main system prompt. When systemPromptPrefix is
//     non-empty, it is sent as an extra leading system message.
//   - isSubAgent suppresses the todo-list reminder added by preparePrompt.
//   - disableAutoSummarize turns off Run's context-window stop condition.
//   - isYolo mirrors SessionAgentOptions.IsYolo. It is not read by the code
//     in this section (TODO confirm where it is consumed).
//   - messageQueue holds calls submitted while a session was busy, keyed by
//     session ID, oldest first.
//   - activeRequests maps a session ID to the cancel func of its in-flight
//     Run or Summarize. The presence of a key is what IsSessionBusy reports.
type sessionAgent struct {
	largeModel           Model
	smallModel           Model
	systemPromptPrefix   string
	systemPrompt         string
	isSubAgent           bool
	tools                []fantasy.AgentTool
	sessions             session.Service
	messages             message.Service
	disableAutoSummarize bool
	isYolo               bool

	messageQueue   *csync.Map[string, []SessionAgentCall]
	activeRequests *csync.Map[string, context.CancelFunc]
}
 101
// SessionAgentOptions configures NewSessionAgent. Each field is copied
// verbatim into the corresponding field of the created agent. See
// SessionAgentCall and the agent methods for how they are used; for example,
// DisableAutoSummarize turns off automatic summarization when the context
// window runs low, and IsSubAgent omits the todo-list reminder from prompts.
type SessionAgentOptions struct {
	LargeModel           Model
	SmallModel           Model
	SystemPromptPrefix   string
	SystemPrompt         string
	IsSubAgent           bool
	DisableAutoSummarize bool
	IsYolo               bool
	Sessions             session.Service
	Messages             message.Service
	Tools                []fantasy.AgentTool
}
 114
 115func NewSessionAgent(
 116	opts SessionAgentOptions,
 117) SessionAgent {
 118	return &sessionAgent{
 119		largeModel:           opts.LargeModel,
 120		smallModel:           opts.SmallModel,
 121		systemPromptPrefix:   opts.SystemPromptPrefix,
 122		systemPrompt:         opts.SystemPrompt,
 123		isSubAgent:           opts.IsSubAgent,
 124		sessions:             opts.Sessions,
 125		messages:             opts.Messages,
 126		disableAutoSummarize: opts.DisableAutoSummarize,
 127		tools:                opts.Tools,
 128		isYolo:               opts.IsYolo,
 129		messageQueue:         csync.NewMap[string, []SessionAgentCall](),
 130		activeRequests:       csync.NewMap[string, context.CancelFunc](),
 131	}
 132}
 133
 134func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
 135	if call.Prompt == "" {
 136		return nil, ErrEmptyPrompt
 137	}
 138	if call.SessionID == "" {
 139		return nil, ErrSessionMissing
 140	}
 141
 142	// Queue the message if busy
 143	if a.IsSessionBusy(call.SessionID) {
 144		existing, ok := a.messageQueue.Get(call.SessionID)
 145		if !ok {
 146			existing = []SessionAgentCall{}
 147		}
 148		existing = append(existing, call)
 149		a.messageQueue.Set(call.SessionID, existing)
 150		return nil, nil
 151	}
 152
 153	if len(a.tools) > 0 {
 154		// Add Anthropic caching to the last tool.
 155		a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
 156	}
 157
 158	agent := fantasy.NewAgent(
 159		a.largeModel.Model,
 160		fantasy.WithSystemPrompt(a.systemPrompt),
 161		fantasy.WithTools(a.tools...),
 162	)
 163
 164	sessionLock := sync.Mutex{}
 165	currentSession, err := a.sessions.Get(ctx, call.SessionID)
 166	if err != nil {
 167		return nil, fmt.Errorf("failed to get session: %w", err)
 168	}
 169
 170	msgs, err := a.getSessionMessages(ctx, currentSession)
 171	if err != nil {
 172		return nil, fmt.Errorf("failed to get session messages: %w", err)
 173	}
 174
 175	var wg sync.WaitGroup
 176	// Generate title if first message.
 177	if len(msgs) == 0 {
 178		titleCtx := ctx // Copy to avoid race with ctx reassignment below.
 179		wg.Go(func() {
 180			a.generateTitle(titleCtx, call.SessionID, call.Prompt)
 181		})
 182	}
 183
 184	// Add the user message to the session.
 185	_, err = a.createUserMessage(ctx, call)
 186	if err != nil {
 187		return nil, err
 188	}
 189
 190	// Add the session to the context.
 191	ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)
 192
 193	genCtx, cancel := context.WithCancel(ctx)
 194	a.activeRequests.Set(call.SessionID, cancel)
 195
 196	defer cancel()
 197	defer a.activeRequests.Del(call.SessionID)
 198
 199	history, files := a.preparePrompt(msgs, call.Attachments...)
 200
 201	startTime := time.Now()
 202	a.eventPromptSent(call.SessionID)
 203
 204	var currentAssistant *message.Message
 205	var shouldSummarize bool
 206	result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
 207		Prompt:           message.PromptWithTextAttachments(call.Prompt, call.Attachments),
 208		Files:            files,
 209		Messages:         history,
 210		ProviderOptions:  call.ProviderOptions,
 211		MaxOutputTokens:  &call.MaxOutputTokens,
 212		TopP:             call.TopP,
 213		Temperature:      call.Temperature,
 214		PresencePenalty:  call.PresencePenalty,
 215		TopK:             call.TopK,
 216		FrequencyPenalty: call.FrequencyPenalty,
 217		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
 218			prepared.Messages = options.Messages
 219			for i := range prepared.Messages {
 220				prepared.Messages[i].ProviderOptions = nil
 221			}
 222
 223			queuedCalls, _ := a.messageQueue.Get(call.SessionID)
 224			a.messageQueue.Del(call.SessionID)
 225			for _, queued := range queuedCalls {
 226				userMessage, createErr := a.createUserMessage(callContext, queued)
 227				if createErr != nil {
 228					return callContext, prepared, createErr
 229				}
 230				prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
 231			}
 232
 233			prepared.Messages = a.workaroundProviderMediaLimitations(prepared.Messages)
 234
 235			lastSystemRoleInx := 0
 236			systemMessageUpdated := false
 237			for i, msg := range prepared.Messages {
 238				// Only add cache control to the last message.
 239				if msg.Role == fantasy.MessageRoleSystem {
 240					lastSystemRoleInx = i
 241				} else if !systemMessageUpdated {
 242					prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
 243					systemMessageUpdated = true
 244				}
 245				// Than add cache control to the last 2 messages.
 246				if i > len(prepared.Messages)-3 {
 247					prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
 248				}
 249			}
 250
 251			if promptPrefix := a.promptPrefix(); promptPrefix != "" {
 252				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(promptPrefix)}, prepared.Messages...)
 253			}
 254
 255			var assistantMsg message.Message
 256			assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
 257				Role:     message.Assistant,
 258				Parts:    []message.ContentPart{},
 259				Model:    a.largeModel.ModelCfg.Model,
 260				Provider: a.largeModel.ModelCfg.Provider,
 261			})
 262			if err != nil {
 263				return callContext, prepared, err
 264			}
 265			callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
 266			callContext = context.WithValue(callContext, tools.SupportsImagesContextKey, a.largeModel.CatwalkCfg.SupportsImages)
 267			callContext = context.WithValue(callContext, tools.ModelNameContextKey, a.largeModel.CatwalkCfg.Name)
 268			currentAssistant = &assistantMsg
 269			return callContext, prepared, err
 270		},
 271		OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
 272			currentAssistant.AppendReasoningContent(reasoning.Text)
 273			return a.messages.Update(genCtx, *currentAssistant)
 274		},
 275		OnReasoningDelta: func(id string, text string) error {
 276			currentAssistant.AppendReasoningContent(text)
 277			return a.messages.Update(genCtx, *currentAssistant)
 278		},
 279		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
 280			// handle anthropic signature
 281			if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
 282				if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
 283					currentAssistant.AppendReasoningSignature(reasoning.Signature)
 284				}
 285			}
 286			if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
 287				if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
 288					currentAssistant.AppendThoughtSignature(reasoning.Signature, reasoning.ToolID)
 289				}
 290			}
 291			if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
 292				if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
 293					currentAssistant.SetReasoningResponsesData(reasoning)
 294				}
 295			}
 296			currentAssistant.FinishThinking()
 297			return a.messages.Update(genCtx, *currentAssistant)
 298		},
 299		OnTextDelta: func(id string, text string) error {
 300			// Strip leading newline from initial text content. This is is
 301			// particularly important in non-interactive mode where leading
 302			// newlines are very visible.
 303			if len(currentAssistant.Parts) == 0 {
 304				text = strings.TrimPrefix(text, "\n")
 305			}
 306
 307			currentAssistant.AppendContent(text)
 308			return a.messages.Update(genCtx, *currentAssistant)
 309		},
 310		OnToolInputStart: func(id string, toolName string) error {
 311			toolCall := message.ToolCall{
 312				ID:               id,
 313				Name:             toolName,
 314				ProviderExecuted: false,
 315				Finished:         false,
 316			}
 317			currentAssistant.AddToolCall(toolCall)
 318			return a.messages.Update(genCtx, *currentAssistant)
 319		},
 320		OnRetry: func(err *fantasy.ProviderError, delay time.Duration) {
 321			// TODO: implement
 322		},
 323		OnToolCall: func(tc fantasy.ToolCallContent) error {
 324			toolCall := message.ToolCall{
 325				ID:               tc.ToolCallID,
 326				Name:             tc.ToolName,
 327				Input:            tc.Input,
 328				ProviderExecuted: false,
 329				Finished:         true,
 330			}
 331			currentAssistant.AddToolCall(toolCall)
 332			return a.messages.Update(genCtx, *currentAssistant)
 333		},
 334		OnToolResult: func(result fantasy.ToolResultContent) error {
 335			toolResult := a.convertToToolResult(result)
 336			_, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
 337				Role: message.Tool,
 338				Parts: []message.ContentPart{
 339					toolResult,
 340				},
 341			})
 342			return createMsgErr
 343		},
 344		OnStepFinish: func(stepResult fantasy.StepResult) error {
 345			finishReason := message.FinishReasonUnknown
 346			switch stepResult.FinishReason {
 347			case fantasy.FinishReasonLength:
 348				finishReason = message.FinishReasonMaxTokens
 349			case fantasy.FinishReasonStop:
 350				finishReason = message.FinishReasonEndTurn
 351			case fantasy.FinishReasonToolCalls:
 352				finishReason = message.FinishReasonToolUse
 353			}
 354			currentAssistant.AddFinish(finishReason, "", "")
 355			sessionLock.Lock()
 356			updatedSession, getSessionErr := a.sessions.Get(genCtx, call.SessionID)
 357			if getSessionErr != nil {
 358				sessionLock.Unlock()
 359				return getSessionErr
 360			}
 361			a.updateSessionUsage(a.largeModel, &updatedSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
 362			_, sessionErr := a.sessions.Save(genCtx, updatedSession)
 363			sessionLock.Unlock()
 364			if sessionErr != nil {
 365				return sessionErr
 366			}
 367			return a.messages.Update(genCtx, *currentAssistant)
 368		},
 369		StopWhen: []fantasy.StopCondition{
 370			func(_ []fantasy.StepResult) bool {
 371				cw := int64(a.largeModel.CatwalkCfg.ContextWindow)
 372				tokens := currentSession.CompletionTokens + currentSession.PromptTokens
 373				remaining := cw - tokens
 374				var threshold int64
 375				if cw > 200_000 {
 376					threshold = 20_000
 377				} else {
 378					threshold = int64(float64(cw) * 0.2)
 379				}
 380				if (remaining <= threshold) && !a.disableAutoSummarize {
 381					shouldSummarize = true
 382					return true
 383				}
 384				return false
 385			},
 386		},
 387	})
 388
 389	a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))
 390
 391	if err != nil {
 392		isCancelErr := errors.Is(err, context.Canceled)
 393		isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
 394		if currentAssistant == nil {
 395			return result, err
 396		}
 397		// Ensure we finish thinking on error to close the reasoning state.
 398		currentAssistant.FinishThinking()
 399		toolCalls := currentAssistant.ToolCalls()
 400		// INFO: we use the parent context here because the genCtx has been cancelled.
 401		msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
 402		if createErr != nil {
 403			return nil, createErr
 404		}
 405		for _, tc := range toolCalls {
 406			if !tc.Finished {
 407				tc.Finished = true
 408				tc.Input = "{}"
 409				currentAssistant.AddToolCall(tc)
 410				updateErr := a.messages.Update(ctx, *currentAssistant)
 411				if updateErr != nil {
 412					return nil, updateErr
 413				}
 414			}
 415
 416			found := false
 417			for _, msg := range msgs {
 418				if msg.Role == message.Tool {
 419					for _, tr := range msg.ToolResults() {
 420						if tr.ToolCallID == tc.ID {
 421							found = true
 422							break
 423						}
 424					}
 425				}
 426				if found {
 427					break
 428				}
 429			}
 430			if found {
 431				continue
 432			}
 433			content := "There was an error while executing the tool"
 434			if isCancelErr {
 435				content = "Tool execution canceled by user"
 436			} else if isPermissionErr {
 437				content = "User denied permission"
 438			}
 439			toolResult := message.ToolResult{
 440				ToolCallID: tc.ID,
 441				Name:       tc.Name,
 442				Content:    content,
 443				IsError:    true,
 444			}
 445			_, createErr = a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
 446				Role: message.Tool,
 447				Parts: []message.ContentPart{
 448					toolResult,
 449				},
 450			})
 451			if createErr != nil {
 452				return nil, createErr
 453			}
 454		}
 455		var fantasyErr *fantasy.Error
 456		var providerErr *fantasy.ProviderError
 457		const defaultTitle = "Provider Error"
 458		if isCancelErr {
 459			currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
 460		} else if isPermissionErr {
 461			currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "")
 462		} else if errors.Is(err, hyper.ErrNoCredits) {
 463			url := hyper.BaseURL()
 464			link := lipgloss.NewStyle().Hyperlink(url, "id=hyper").Render(url)
 465			currentAssistant.AddFinish(message.FinishReasonError, "No credits", "You're out of credits. Add more at "+link)
 466		} else if errors.As(err, &providerErr) {
 467			if providerErr.Message == "The requested model is not supported." {
 468				url := "https://github.com/settings/copilot/features"
 469				link := lipgloss.NewStyle().Hyperlink(url, "id=hyper").Render(url)
 470				currentAssistant.AddFinish(
 471					message.FinishReasonError,
 472					"Copilot model not enabled",
 473					fmt.Sprintf("%q is not enabled in Copilot. Go to the following page to enable it. Then, wait a minute before trying again. %s", a.largeModel.CatwalkCfg.Name, link),
 474				)
 475			} else {
 476				currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message)
 477			}
 478		} else if errors.As(err, &fantasyErr) {
 479			currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(fantasyErr.Title), defaultTitle), fantasyErr.Message)
 480		} else {
 481			currentAssistant.AddFinish(message.FinishReasonError, defaultTitle, err.Error())
 482		}
 483		// Note: we use the parent context here because the genCtx has been
 484		// cancelled.
 485		updateErr := a.messages.Update(ctx, *currentAssistant)
 486		if updateErr != nil {
 487			return nil, updateErr
 488		}
 489		return nil, err
 490	}
 491	wg.Wait()
 492
 493	if shouldSummarize {
 494		a.activeRequests.Del(call.SessionID)
 495		if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
 496			return nil, summarizeErr
 497		}
 498		// If the agent wasn't done...
 499		if len(currentAssistant.ToolCalls()) > 0 {
 500			existing, ok := a.messageQueue.Get(call.SessionID)
 501			if !ok {
 502				existing = []SessionAgentCall{}
 503			}
 504			call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
 505			existing = append(existing, call)
 506			a.messageQueue.Set(call.SessionID, existing)
 507		}
 508	}
 509
 510	// Release active request before processing queued messages.
 511	a.activeRequests.Del(call.SessionID)
 512	cancel()
 513
 514	queuedMessages, ok := a.messageQueue.Get(call.SessionID)
 515	if !ok || len(queuedMessages) == 0 {
 516		return result, err
 517	}
 518	// There are queued messages restart the loop.
 519	firstQueuedMessage := queuedMessages[0]
 520	a.messageQueue.Set(call.SessionID, queuedMessages[1:])
 521	return a.Run(ctx, firstQueuedMessage)
 522}
 523
 524func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
 525	if a.IsSessionBusy(sessionID) {
 526		return ErrSessionBusy
 527	}
 528
 529	currentSession, err := a.sessions.Get(ctx, sessionID)
 530	if err != nil {
 531		return fmt.Errorf("failed to get session: %w", err)
 532	}
 533	msgs, err := a.getSessionMessages(ctx, currentSession)
 534	if err != nil {
 535		return err
 536	}
 537	if len(msgs) == 0 {
 538		// Nothing to summarize.
 539		return nil
 540	}
 541
 542	aiMsgs, _ := a.preparePrompt(msgs)
 543
 544	genCtx, cancel := context.WithCancel(ctx)
 545	a.activeRequests.Set(sessionID, cancel)
 546	defer a.activeRequests.Del(sessionID)
 547	defer cancel()
 548
 549	agent := fantasy.NewAgent(a.largeModel.Model,
 550		fantasy.WithSystemPrompt(string(summaryPrompt)),
 551	)
 552	summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
 553		Role:             message.Assistant,
 554		Model:            a.largeModel.Model.Model(),
 555		Provider:         a.largeModel.Model.Provider(),
 556		IsSummaryMessage: true,
 557	})
 558	if err != nil {
 559		return err
 560	}
 561
 562	summaryPromptText := "Provide a detailed summary of our conversation above."
 563	if len(currentSession.Todos) > 0 {
 564		summaryPromptText += "\n\n## Current Todo List\n\n"
 565		for _, t := range currentSession.Todos {
 566			summaryPromptText += fmt.Sprintf("- [%s] %s\n", t.Status, t.Content)
 567		}
 568		summaryPromptText += "\nInclude these tasks and their statuses in your summary. "
 569		summaryPromptText += "Instruct the resuming assistant to use the `todos` tool to continue tracking progress on these tasks."
 570	}
 571
 572	resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
 573		Prompt:          summaryPromptText,
 574		Messages:        aiMsgs,
 575		ProviderOptions: opts,
 576		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
 577			prepared.Messages = options.Messages
 578			if a.systemPromptPrefix != "" {
 579				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
 580			}
 581			return callContext, prepared, nil
 582		},
 583		OnReasoningDelta: func(id string, text string) error {
 584			summaryMessage.AppendReasoningContent(text)
 585			return a.messages.Update(genCtx, summaryMessage)
 586		},
 587		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
 588			// Handle anthropic signature.
 589			if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
 590				if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
 591					summaryMessage.AppendReasoningSignature(signature.Signature)
 592				}
 593			}
 594			summaryMessage.FinishThinking()
 595			return a.messages.Update(genCtx, summaryMessage)
 596		},
 597		OnTextDelta: func(id, text string) error {
 598			summaryMessage.AppendContent(text)
 599			return a.messages.Update(genCtx, summaryMessage)
 600		},
 601	})
 602	if err != nil {
 603		isCancelErr := errors.Is(err, context.Canceled)
 604		if isCancelErr {
 605			// User cancelled summarize we need to remove the summary message.
 606			deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
 607			return deleteErr
 608		}
 609		return err
 610	}
 611
 612	summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
 613	err = a.messages.Update(genCtx, summaryMessage)
 614	if err != nil {
 615		return err
 616	}
 617
 618	var openrouterCost *float64
 619	for _, step := range resp.Steps {
 620		stepCost := a.openrouterCost(step.ProviderMetadata)
 621		if stepCost != nil {
 622			newCost := *stepCost
 623			if openrouterCost != nil {
 624				newCost += *openrouterCost
 625			}
 626			openrouterCost = &newCost
 627		}
 628	}
 629
 630	a.updateSessionUsage(a.largeModel, &currentSession, resp.TotalUsage, openrouterCost)
 631
 632	// Just in case, get just the last usage info.
 633	usage := resp.Response.Usage
 634	currentSession.SummaryMessageID = summaryMessage.ID
 635	currentSession.CompletionTokens = usage.OutputTokens
 636	currentSession.PromptTokens = 0
 637	_, err = a.sessions.Save(genCtx, currentSession)
 638	return err
 639}
 640
 641func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
 642	if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
 643		return fantasy.ProviderOptions{}
 644	}
 645	return fantasy.ProviderOptions{
 646		anthropic.Name: &anthropic.ProviderCacheControlOptions{
 647			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
 648		},
 649		bedrock.Name: &anthropic.ProviderCacheControlOptions{
 650			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
 651		},
 652	}
 653}
 654
 655func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
 656	parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
 657	var attachmentParts []message.ContentPart
 658	for _, attachment := range call.Attachments {
 659		attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
 660	}
 661	parts = append(parts, attachmentParts...)
 662	msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
 663		Role:  message.User,
 664		Parts: parts,
 665	})
 666	if err != nil {
 667		return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
 668	}
 669	return msg, nil
 670}
 671
 672func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
 673	var history []fantasy.Message
 674	if !a.isSubAgent {
 675		history = append(history, fantasy.NewUserMessage(
 676			fmt.Sprintf("<system_reminder>%s</system_reminder>",
 677				`This is a reminder that your todo list is currently empty. DO NOT mention this to the user explicitly because they are already aware.
 678If you are working on tasks that would benefit from a todo list please use the "todos" tool to create one.
 679If not, please feel free to ignore. Again do not mention this message to the user.`,
 680			),
 681		))
 682	}
 683	for _, m := range msgs {
 684		if len(m.Parts) == 0 {
 685			continue
 686		}
 687		// Assistant message without content or tool calls (cancelled before it
 688		// returned anything).
 689		if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
 690			continue
 691		}
 692		history = append(history, m.ToAIMessage()...)
 693	}
 694
 695	var files []fantasy.FilePart
 696	for _, attachment := range attachments {
 697		if attachment.IsText() {
 698			continue
 699		}
 700		files = append(files, fantasy.FilePart{
 701			Filename:  attachment.FileName,
 702			Data:      attachment.Content,
 703			MediaType: attachment.MimeType,
 704		})
 705	}
 706
 707	return history, files
 708}
 709
 710func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
 711	msgs, err := a.messages.List(ctx, session.ID)
 712	if err != nil {
 713		return nil, fmt.Errorf("failed to list messages: %w", err)
 714	}
 715
 716	if session.SummaryMessageID != "" {
 717		summaryMsgInex := -1
 718		for i, msg := range msgs {
 719			if msg.ID == session.SummaryMessageID {
 720				summaryMsgInex = i
 721				break
 722			}
 723		}
 724		if summaryMsgInex != -1 {
 725			msgs = msgs[summaryMsgInex:]
 726			msgs[0].Role = message.User
 727		}
 728	}
 729	return msgs, nil
 730}
 731
 732func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, prompt string) {
 733	if prompt == "" {
 734		return
 735	}
 736
 737	var maxOutput int64 = 40
 738	if a.smallModel.CatwalkCfg.CanReason {
 739		maxOutput = a.smallModel.CatwalkCfg.DefaultMaxTokens
 740	}
 741
 742	agent := fantasy.NewAgent(a.smallModel.Model,
 743		fantasy.WithSystemPrompt(string(titlePrompt)+"\n /no_think"),
 744		fantasy.WithMaxOutputTokens(maxOutput),
 745	)
 746
 747	resp, err := agent.Stream(ctx, fantasy.AgentStreamCall{
 748		Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", prompt),
 749		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
 750			prepared.Messages = options.Messages
 751			if a.systemPromptPrefix != "" {
 752				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
 753			}
 754			return callContext, prepared, nil
 755		},
 756	})
 757	if err != nil {
 758		slog.Error("error generating title", "err", err)
 759		return
 760	}
 761
 762	title := resp.Response.Content.Text()
 763
 764	title = strings.ReplaceAll(title, "\n", " ")
 765
 766	// Remove thinking tags if present.
 767	title = thinkTagRegex.ReplaceAllString(title, "")
 768
 769	title = strings.TrimSpace(title)
 770	if title == "" {
 771		slog.Warn("failed to generate title", "warn", "empty title")
 772		return
 773	}
 774
 775	// Calculate usage and cost.
 776	var openrouterCost *float64
 777	for _, step := range resp.Steps {
 778		stepCost := a.openrouterCost(step.ProviderMetadata)
 779		if stepCost != nil {
 780			newCost := *stepCost
 781			if openrouterCost != nil {
 782				newCost += *openrouterCost
 783			}
 784			openrouterCost = &newCost
 785		}
 786	}
 787
 788	modelConfig := a.smallModel.CatwalkCfg
 789	cost := modelConfig.CostPer1MInCached/1e6*float64(resp.TotalUsage.CacheCreationTokens) +
 790		modelConfig.CostPer1MOutCached/1e6*float64(resp.TotalUsage.CacheReadTokens) +
 791		modelConfig.CostPer1MIn/1e6*float64(resp.TotalUsage.InputTokens) +
 792		modelConfig.CostPer1MOut/1e6*float64(resp.TotalUsage.OutputTokens)
 793
 794	if a.isClaudeCode() {
 795		cost = 0
 796	}
 797
 798	// Use override cost if available (e.g., from OpenRouter).
 799	if openrouterCost != nil {
 800		cost = *openrouterCost
 801	}
 802
 803	promptTokens := resp.TotalUsage.InputTokens + resp.TotalUsage.CacheCreationTokens
 804	completionTokens := resp.TotalUsage.OutputTokens + resp.TotalUsage.CacheReadTokens
 805
 806	// Atomically update only title and usage fields to avoid overriding other
 807	// concurrent session updates.
 808	saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, title, promptTokens, completionTokens, cost)
 809	if saveErr != nil {
 810		slog.Error("failed to save session title & usage", "error", saveErr)
 811		return
 812	}
 813}
 814
 815func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
 816	openrouterMetadata, ok := metadata[openrouter.Name]
 817	if !ok {
 818		return nil
 819	}
 820
 821	opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
 822	if !ok {
 823		return nil
 824	}
 825	return &opts.Usage.Cost
 826}
 827
 828func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
 829	modelConfig := model.CatwalkCfg
 830	cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
 831		modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
 832		modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
 833		modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
 834
 835	if a.isClaudeCode() {
 836		cost = 0
 837	}
 838
 839	a.eventTokensUsed(session.ID, model, usage, cost)
 840
 841	if overrideCost != nil {
 842		session.Cost += *overrideCost
 843	} else {
 844		session.Cost += cost
 845	}
 846
 847	session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
 848	session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
 849}
 850
 851func (a *sessionAgent) Cancel(sessionID string) {
 852	// Cancel regular requests.
 853	if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
 854		slog.Info("Request cancellation initiated", "session_id", sessionID)
 855		cancel()
 856	}
 857
 858	// Also check for summarize requests.
 859	if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
 860		slog.Info("Summarize cancellation initiated", "session_id", sessionID)
 861		cancel()
 862	}
 863
 864	if a.QueuedPrompts(sessionID) > 0 {
 865		slog.Info("Clearing queued prompts", "session_id", sessionID)
 866		a.messageQueue.Del(sessionID)
 867	}
 868}
 869
 870func (a *sessionAgent) ClearQueue(sessionID string) {
 871	if a.QueuedPrompts(sessionID) > 0 {
 872		slog.Info("Clearing queued prompts", "session_id", sessionID)
 873		a.messageQueue.Del(sessionID)
 874	}
 875}
 876
 877func (a *sessionAgent) CancelAll() {
 878	if !a.IsBusy() {
 879		return
 880	}
 881	for key := range a.activeRequests.Seq2() {
 882		a.Cancel(key) // key is sessionID
 883	}
 884
 885	timeout := time.After(5 * time.Second)
 886	for a.IsBusy() {
 887		select {
 888		case <-timeout:
 889			return
 890		default:
 891			time.Sleep(200 * time.Millisecond)
 892		}
 893	}
 894}
 895
 896func (a *sessionAgent) IsBusy() bool {
 897	var busy bool
 898	for cancelFunc := range a.activeRequests.Seq() {
 899		if cancelFunc != nil {
 900			busy = true
 901			break
 902		}
 903	}
 904	return busy
 905}
 906
// IsSessionBusy reports whether activeRequests has an entry keyed exactly by
// sessionID.
//
// Unlike IsBusy, a present-but-nil cancel func still counts as busy. The
// sessionID + "-summarize" entry is not consulted.
func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
	_, busy := a.activeRequests.Get(sessionID)
	return busy
}
 911
 912func (a *sessionAgent) QueuedPrompts(sessionID string) int {
 913	l, ok := a.messageQueue.Get(sessionID)
 914	if !ok {
 915		return 0
 916	}
 917	return len(l)
 918}
 919
 920func (a *sessionAgent) QueuedPromptsList(sessionID string) []string {
 921	l, ok := a.messageQueue.Get(sessionID)
 922	if !ok {
 923		return nil
 924	}
 925	prompts := make([]string, len(l))
 926	for i, call := range l {
 927		prompts[i] = call.Prompt
 928	}
 929	return prompts
 930}
 931
 932func (a *sessionAgent) SetModels(large Model, small Model) {
 933	a.largeModel = large
 934	a.smallModel = small
 935}
 936
 937func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
 938	a.tools = tools
 939}
 940
// Model returns the agent's large model, as last set by SetModels or at
// construction.
func (a *sessionAgent) Model() Model {
	return a.largeModel
}
 944
 945func (a *sessionAgent) promptPrefix() string {
 946	if a.isClaudeCode() {
 947		return "You are Claude Code, Anthropic's official CLI for Claude."
 948	}
 949	return a.systemPromptPrefix
 950}
 951
 952func (a *sessionAgent) isClaudeCode() bool {
 953	cfg := config.Get()
 954	pc, ok := cfg.Providers.Get(a.largeModel.ModelCfg.Provider)
 955	return ok && pc.ID == string(catwalk.InferenceProviderAnthropic) && pc.OAuthToken != nil
 956}
 957
 958// convertToToolResult converts a fantasy tool result to a message tool result.
 959func (a *sessionAgent) convertToToolResult(result fantasy.ToolResultContent) message.ToolResult {
 960	baseResult := message.ToolResult{
 961		ToolCallID: result.ToolCallID,
 962		Name:       result.ToolName,
 963		Metadata:   result.ClientMetadata,
 964	}
 965
 966	switch result.Result.GetType() {
 967	case fantasy.ToolResultContentTypeText:
 968		if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result); ok {
 969			baseResult.Content = r.Text
 970		}
 971	case fantasy.ToolResultContentTypeError:
 972		if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result); ok {
 973			baseResult.Content = r.Error.Error()
 974			baseResult.IsError = true
 975		}
 976	case fantasy.ToolResultContentTypeMedia:
 977		if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](result.Result); ok {
 978			content := r.Text
 979			if content == "" {
 980				content = fmt.Sprintf("Loaded %s content", r.MediaType)
 981			}
 982			baseResult.Content = content
 983			baseResult.Data = r.Data
 984			baseResult.MIMEType = r.MediaType
 985		}
 986	}
 987
 988	return baseResult
 989}
 990
 991// workaroundProviderMediaLimitations converts media content in tool results to
 992// user messages for providers that don't natively support images in tool results.
 993//
 994// Problem: OpenAI, Google, OpenRouter, and other OpenAI-compatible providers
 995// don't support sending images/media in tool result messages - they only accept
 996// text in tool results. However, they DO support images in user messages.
 997//
 998// If we send media in tool results to these providers, the API returns an error.
 999//
1000// Solution: For these providers, we:
1001//  1. Replace the media in the tool result with a text placeholder
1002//  2. Inject a user message immediately after with the image as a file attachment
1003//  3. This maintains the tool execution flow while working around API limitations
1004//
1005// Anthropic and Bedrock support images natively in tool results, so we skip
1006// this workaround for them.
1007//
1008// Example transformation:
1009//
1010//	BEFORE: [tool result: image data]
1011//	AFTER:  [tool result: "Image loaded - see attached"], [user: image attachment]
1012func (a *sessionAgent) workaroundProviderMediaLimitations(messages []fantasy.Message) []fantasy.Message {
1013	providerSupportsMedia := a.largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderAnthropic) ||
1014		a.largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderBedrock)
1015
1016	if providerSupportsMedia {
1017		return messages
1018	}
1019
1020	convertedMessages := make([]fantasy.Message, 0, len(messages))
1021
1022	for _, msg := range messages {
1023		if msg.Role != fantasy.MessageRoleTool {
1024			convertedMessages = append(convertedMessages, msg)
1025			continue
1026		}
1027
1028		textParts := make([]fantasy.MessagePart, 0, len(msg.Content))
1029		var mediaFiles []fantasy.FilePart
1030
1031		for _, part := range msg.Content {
1032			toolResult, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
1033			if !ok {
1034				textParts = append(textParts, part)
1035				continue
1036			}
1037
1038			if media, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](toolResult.Output); ok {
1039				decoded, err := base64.StdEncoding.DecodeString(media.Data)
1040				if err != nil {
1041					slog.Warn("failed to decode media data", "error", err)
1042					textParts = append(textParts, part)
1043					continue
1044				}
1045
1046				mediaFiles = append(mediaFiles, fantasy.FilePart{
1047					Data:      decoded,
1048					MediaType: media.MediaType,
1049					Filename:  fmt.Sprintf("tool-result-%s", toolResult.ToolCallID),
1050				})
1051
1052				textParts = append(textParts, fantasy.ToolResultPart{
1053					ToolCallID: toolResult.ToolCallID,
1054					Output: fantasy.ToolResultOutputContentText{
1055						Text: "[Image/media content loaded - see attached file]",
1056					},
1057					ProviderOptions: toolResult.ProviderOptions,
1058				})
1059			} else {
1060				textParts = append(textParts, part)
1061			}
1062		}
1063
1064		convertedMessages = append(convertedMessages, fantasy.Message{
1065			Role:    fantasy.MessageRoleTool,
1066			Content: textParts,
1067		})
1068
1069		if len(mediaFiles) > 0 {
1070			convertedMessages = append(convertedMessages, fantasy.NewUserMessage(
1071				"Here is the media content from the tool result:",
1072				mediaFiles...,
1073			))
1074		}
1075	}
1076
1077	return convertedMessages
1078}