agent.go

   1// Package agent is the core orchestration layer for Crush AI agents.
   2//
   3// It provides session-based AI agent functionality for managing
   4// conversations, tool execution, and message handling. It coordinates
   5// interactions between language models, messages, sessions, and tools while
   6// handling features like automatic summarization, queuing, and token
   7// management.
   8package agent
   9
  10import (
  11	"cmp"
  12	"context"
  13	_ "embed"
  14	"encoding/base64"
  15	"errors"
  16	"fmt"
  17	"log/slog"
  18	"os"
  19	"regexp"
  20	"strconv"
  21	"strings"
  22	"sync"
  23	"time"
  24
  25	"charm.land/fantasy"
  26	"charm.land/fantasy/providers/anthropic"
  27	"charm.land/fantasy/providers/bedrock"
  28	"charm.land/fantasy/providers/google"
  29	"charm.land/fantasy/providers/openai"
  30	"charm.land/fantasy/providers/openrouter"
  31	"charm.land/lipgloss/v2"
  32	"github.com/charmbracelet/catwalk/pkg/catwalk"
  33	"github.com/charmbracelet/crush/internal/agent/hyper"
  34	"github.com/charmbracelet/crush/internal/agent/tools"
  35	"github.com/charmbracelet/crush/internal/config"
  36	"github.com/charmbracelet/crush/internal/csync"
  37	"github.com/charmbracelet/crush/internal/message"
  38	"github.com/charmbracelet/crush/internal/permission"
  39	"github.com/charmbracelet/crush/internal/session"
  40	"github.com/charmbracelet/crush/internal/stringext"
  41	"github.com/charmbracelet/x/exp/charmtone"
  42)
  43
const (
	// defaultSessionName is the title saved for a session when title
	// generation fails or yields an empty title.
	defaultSessionName = "Untitled Session"

	// Constants for auto-summarization thresholds.
	//
	// Run stops the agent loop and summarizes once the remaining context
	// (context window minus the session's prompt and completion tokens) is
	// at or below a threshold. For models whose context window is larger
	// than largeContextWindowThreshold the threshold is the fixed
	// largeContextWindowBuffer; otherwise it is smallContextWindowRatio of
	// the context window.
	largeContextWindowThreshold = 200_000
	largeContextWindowBuffer    = 20_000
	smallContextWindowRatio     = 0.2
)
  52
// titlePrompt is the embedded system prompt used by generateTitle.
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt is the embedded system prompt used by Summarize.
//
//go:embed templates/summary.md
var summaryPrompt []byte

// thinkTagRegex is used to remove <think>...</think> spans from generated
// titles. generateTitle replaces newlines with spaces before applying it, so
// the fact that "." does not match newlines is not a problem there.
var thinkTagRegex = regexp.MustCompile(`<think>.*?</think>`)
  61
// SessionAgentCall describes a single prompt submitted to a SessionAgent.
//
// SessionID is required; Run returns ErrSessionMissing when it is empty.
// Prompt is required unless Attachments contains a text attachment; otherwise
// Run returns ErrEmptyPrompt. Attachments are stored on the user message, and
// the non-text ones are also forwarded to the model as files.
//
// ProviderOptions, MaxOutputTokens, Temperature, TopP, TopK, FrequencyPenalty
// and PresencePenalty are passed through to the underlying
// fantasy.AgentStreamCall unchanged.
type SessionAgentCall struct {
	SessionID        string
	Prompt           string
	ProviderOptions  fantasy.ProviderOptions
	Attachments      []message.Attachment
	MaxOutputTokens  int64
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
}
  74
// SessionAgent is the session-scoped agent API: it runs prompts against a
// session, manages the per-session prompt queue, and exposes cancellation and
// summarization.
type SessionAgent interface {
	// Run executes the call against its session. When the session is already
	// busy the call is queued and Run returns (nil, nil).
	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
	// SetModels replaces the large and small models.
	// NOTE(review): implementation is outside this view.
	SetModels(large Model, small Model)
	// SetTools replaces the tool set offered to the model.
	// NOTE(review): implementation is outside this view.
	SetTools(tools []fantasy.AgentTool)
	// SetSystemPrompt replaces the system prompt.
	// NOTE(review): implementation is outside this view.
	SetSystemPrompt(systemPrompt string)
	// Cancel cancels the in-flight request for the given session, if any.
	Cancel(sessionID string)
	// CancelAll presumably cancels the in-flight requests of every session;
	// implementation is outside this view.
	CancelAll()
	// IsSessionBusy reports whether the session has a request in flight. Run
	// and Summarize consult it before starting work.
	IsSessionBusy(sessionID string) bool
	// IsBusy presumably reports whether any session has a request in flight;
	// implementation is outside this view.
	IsBusy() bool
	// QueuedPrompts presumably returns the number of prompts queued for the
	// session; implementation is outside this view.
	QueuedPrompts(sessionID string) int
	// QueuedPromptsList presumably returns the text of the prompts queued for
	// the session; implementation is outside this view.
	QueuedPromptsList(sessionID string) []string
	// ClearQueue presumably discards the prompts queued for the session;
	// implementation is outside this view.
	ClearQueue(sessionID string)
	// Summarize replaces the session's working history with a model-written
	// summary. It returns ErrSessionBusy when the session has a request in
	// flight.
	Summarize(context.Context, string, fantasy.ProviderOptions) error
	// Model presumably returns the current large model; implementation is
	// outside this view.
	Model() Model
}
  90
// Model bundles a language model with its catalog and configuration data.
//
// Model is the handle handed to fantasy.NewAgent. CatwalkCfg supplies the
// model's catalog data read by this package (context window, per-million
// token prices, image support, display name, reasoning capability and default
// max tokens). ModelCfg supplies the selected model and provider identifiers
// recorded on assistant messages.
type Model struct {
	Model      fantasy.LanguageModel
	CatwalkCfg catwalk.Model
	ModelCfg   config.SelectedModel
}
  96
// sessionAgent is the SessionAgent implementation returned by
// NewSessionAgent.
type sessionAgent struct {
	// Settings that can be replaced while the agent is in use. They are held
	// in csync wrappers, and Run/Summarize/generateTitle take a snapshot of
	// them before doing any work.
	largeModel         *csync.Value[Model]
	smallModel         *csync.Value[Model]
	systemPromptPrefix *csync.Value[string]
	systemPrompt       *csync.Value[string]
	tools              *csync.Slice[fantasy.AgentTool]

	// Set once by NewSessionAgent. isSubAgent suppresses the todo-list
	// reminder in preparePrompt; disableAutoSummarize turns off the
	// context-window stop condition in Run. isYolo is not read in the visible
	// part of this file.
	isSubAgent           bool
	sessions             session.Service
	messages             message.Service
	disableAutoSummarize bool
	isYolo               bool

	// Per-session runtime state, keyed by session ID: calls queued while the
	// session was busy, and the cancel function of the in-flight request.
	messageQueue   *csync.Map[string, []SessionAgentCall]
	activeRequests *csync.Map[string, context.CancelFunc]
}
 113
// SessionAgentOptions carries the dependencies and initial settings consumed
// by NewSessionAgent. Each field is copied into the corresponding field of the
// new agent; the models, prompts and tools are wrapped in csync containers so
// they can be replaced later.
type SessionAgentOptions struct {
	LargeModel           Model
	SmallModel           Model
	SystemPromptPrefix   string
	SystemPrompt         string
	IsSubAgent           bool
	DisableAutoSummarize bool
	IsYolo               bool
	Sessions             session.Service
	Messages             message.Service
	Tools                []fantasy.AgentTool
}
 126
 127func NewSessionAgent(
 128	opts SessionAgentOptions,
 129) SessionAgent {
 130	return &sessionAgent{
 131		largeModel:           csync.NewValue(opts.LargeModel),
 132		smallModel:           csync.NewValue(opts.SmallModel),
 133		systemPromptPrefix:   csync.NewValue(opts.SystemPromptPrefix),
 134		systemPrompt:         csync.NewValue(opts.SystemPrompt),
 135		isSubAgent:           opts.IsSubAgent,
 136		sessions:             opts.Sessions,
 137		messages:             opts.Messages,
 138		disableAutoSummarize: opts.DisableAutoSummarize,
 139		tools:                csync.NewSliceFrom(opts.Tools),
 140		isYolo:               opts.IsYolo,
 141		messageQueue:         csync.NewMap[string, []SessionAgentCall](),
 142		activeRequests:       csync.NewMap[string, context.CancelFunc](),
 143	}
 144}
 145
// Run executes call against its session and returns the agent result.
//
// It returns ErrEmptyPrompt when the prompt is empty and no text attachment is
// present, and ErrSessionMissing when no session ID is given. When the session
// already has a request in flight the call is appended to the session's queue
// and Run returns (nil, nil).
//
// Otherwise Run snapshots the current tools, model and prompts, loads the
// session history, starts title generation if the session has no messages
// yet, stores the user message, and streams the model response. The streaming
// callbacks persist the assistant message as it grows, store tool results, and
// update the session's usage after every step. Prompts queued while the run is
// in progress are folded into the conversation at the next step.
//
// The loop stops early and the session is summarized when the remaining
// context window falls to or below the auto-summarize threshold (unless
// auto-summarization is disabled). On a streaming error the assistant message
// is closed with a matching finish reason, and an error tool result is stored
// for every tool call that does not have a result yet.
//
// After a successful run, any prompts still queued for the session are
// processed by calling Run again with the first queued call.
func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
	if call.Prompt == "" && !message.ContainsTextAttachment(call.Attachments) {
		return nil, ErrEmptyPrompt
	}
	if call.SessionID == "" {
		return nil, ErrSessionMissing
	}

	// Queue the message if busy
	if a.IsSessionBusy(call.SessionID) {
		existing, ok := a.messageQueue.Get(call.SessionID)
		if !ok {
			existing = []SessionAgentCall{}
		}
		existing = append(existing, call)
		a.messageQueue.Set(call.SessionID, existing)
		return nil, nil
	}

	// Copy mutable fields under lock to avoid races with SetTools/SetModels.
	agentTools := a.tools.Copy()
	largeModel := a.largeModel.Get()
	systemPrompt := a.systemPrompt.Get()
	promptPrefix := a.systemPromptPrefix.Get()

	if len(agentTools) > 0 {
		// Add Anthropic caching to the last tool.
		agentTools[len(agentTools)-1].SetProviderOptions(a.getCacheControlOptions())
	}

	agent := fantasy.NewAgent(
		largeModel.Model,
		fantasy.WithSystemPrompt(systemPrompt),
		fantasy.WithTools(agentTools...),
	)

	// sessionLock guards the reload/save of the session and the write to
	// currentSession in OnStepFinish.
	sessionLock := sync.Mutex{}
	currentSession, err := a.sessions.Get(ctx, call.SessionID)
	if err != nil {
		return nil, fmt.Errorf("failed to get session: %w", err)
	}

	msgs, err := a.getSessionMessages(ctx, currentSession)
	if err != nil {
		return nil, fmt.Errorf("failed to get session messages: %w", err)
	}

	var wg sync.WaitGroup
	// Generate title if first message.
	if len(msgs) == 0 {
		titleCtx := ctx // Copy to avoid race with ctx reassignment below.
		wg.Go(func() {
			a.generateTitle(titleCtx, call.SessionID, call.Prompt)
		})
	}
	// Run does not return until title generation has finished.
	defer wg.Wait()

	// Add the user message to the session.
	_, err = a.createUserMessage(ctx, call)
	if err != nil {
		return nil, err
	}

	// Add the session to the context.
	ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)

	// Registering the cancel function is what marks the request as in flight
	// and lets Cancel stop it.
	genCtx, cancel := context.WithCancel(ctx)
	a.activeRequests.Set(call.SessionID, cancel)

	defer cancel()
	defer a.activeRequests.Del(call.SessionID)

	history, files := a.preparePrompt(msgs, call.Attachments...)

	startTime := time.Now()
	a.eventPromptSent(call.SessionID)

	// currentAssistant is assigned in PrepareStep (once per step) and mutated
	// by the streaming callbacks below. shouldSummarize is set by the stop
	// condition.
	var currentAssistant *message.Message
	var shouldSummarize bool
	result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
		Prompt:           message.PromptWithTextAttachments(call.Prompt, call.Attachments),
		Files:            files,
		Messages:         history,
		ProviderOptions:  call.ProviderOptions,
		MaxOutputTokens:  &call.MaxOutputTokens,
		TopP:             call.TopP,
		Temperature:      call.Temperature,
		PresencePenalty:  call.PresencePenalty,
		TopK:             call.TopK,
		FrequencyPenalty: call.FrequencyPenalty,
		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
			// Start from the messages of the step with all provider options
			// cleared; cache control is re-applied below.
			prepared.Messages = options.Messages
			for i := range prepared.Messages {
				prepared.Messages[i].ProviderOptions = nil
			}

			// Drain prompts that were queued while this run was busy: store
			// each as a user message and add it to the conversation.
			queuedCalls, _ := a.messageQueue.Get(call.SessionID)
			a.messageQueue.Del(call.SessionID)
			for _, queued := range queuedCalls {
				userMessage, createErr := a.createUserMessage(callContext, queued)
				if createErr != nil {
					return callContext, prepared, createErr
				}
				prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
			}

			prepared.Messages = a.workaroundProviderMediaLimitations(prepared.Messages, largeModel)

			lastSystemRoleInx := 0
			systemMessageUpdated := false
			for i, msg := range prepared.Messages {
				// Track the most recent system message. When the first
				// non-system message is reached, add cache control to the
				// system message tracked so far (done only once).
				if msg.Role == fantasy.MessageRoleSystem {
					lastSystemRoleInx = i
				} else if !systemMessageUpdated {
					prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
					systemMessageUpdated = true
				}
				// Then add cache control to the last 2 messages.
				if i > len(prepared.Messages)-3 {
					prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
				}
			}

			// The prompt prefix goes in front as its own system message.
			if promptPrefix != "" {
				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(promptPrefix)}, prepared.Messages...)
			}

			// Create the (initially empty) assistant message that the
			// callbacks of this step will fill in.
			var assistantMsg message.Message
			assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
				Role:     message.Assistant,
				Parts:    []message.ContentPart{},
				Model:    largeModel.ModelCfg.Model,
				Provider: largeModel.ModelCfg.Provider,
			})
			if err != nil {
				return callContext, prepared, err
			}
			// Expose the message ID and model details through the step
			// context.
			callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
			callContext = context.WithValue(callContext, tools.SupportsImagesContextKey, largeModel.CatwalkCfg.SupportsImages)
			callContext = context.WithValue(callContext, tools.ModelNameContextKey, largeModel.CatwalkCfg.Name)
			currentAssistant = &assistantMsg
			return callContext, prepared, err
		},
		OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
			currentAssistant.AppendReasoningContent(reasoning.Text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnReasoningDelta: func(id string, text string) error {
			currentAssistant.AppendReasoningContent(text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
			// Record provider-specific reasoning metadata, if present.
			// handle anthropic signature
			if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
				if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
					currentAssistant.AppendReasoningSignature(reasoning.Signature)
				}
			}
			if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
				if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
					currentAssistant.AppendThoughtSignature(reasoning.Signature, reasoning.ToolID)
				}
			}
			if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
				if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
					currentAssistant.SetReasoningResponsesData(reasoning)
				}
			}
			currentAssistant.FinishThinking()
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnTextDelta: func(id string, text string) error {
			// Strip leading newline from initial text content. This is
			// particularly important in non-interactive mode where leading
			// newlines are very visible.
			if len(currentAssistant.Parts) == 0 {
				text = strings.TrimPrefix(text, "\n")
			}

			currentAssistant.AppendContent(text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnToolInputStart: func(id string, toolName string) error {
			// Record the tool call as soon as its input starts streaming; it
			// is marked finished in OnToolCall.
			toolCall := message.ToolCall{
				ID:               id,
				Name:             toolName,
				ProviderExecuted: false,
				Finished:         false,
			}
			currentAssistant.AddToolCall(toolCall)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnRetry: func(err *fantasy.ProviderError, delay time.Duration) {
			// TODO: implement
		},
		OnToolCall: func(tc fantasy.ToolCallContent) error {
			toolCall := message.ToolCall{
				ID:               tc.ToolCallID,
				Name:             tc.ToolName,
				Input:            tc.Input,
				ProviderExecuted: false,
				Finished:         true,
			}
			currentAssistant.AddToolCall(toolCall)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnToolResult: func(result fantasy.ToolResultContent) error {
			// Tool results are stored as their own message with the Tool role.
			toolResult := a.convertToToolResult(result)
			_, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
				Role: message.Tool,
				Parts: []message.ContentPart{
					toolResult,
				},
			})
			return createMsgErr
		},
		OnStepFinish: func(stepResult fantasy.StepResult) error {
			// Map the provider finish reason onto the message finish reason;
			// anything unrecognized stays "unknown".
			finishReason := message.FinishReasonUnknown
			switch stepResult.FinishReason {
			case fantasy.FinishReasonLength:
				finishReason = message.FinishReasonMaxTokens
			case fantasy.FinishReasonStop:
				finishReason = message.FinishReasonEndTurn
			case fantasy.FinishReasonToolCalls:
				finishReason = message.FinishReasonToolUse
			}
			currentAssistant.AddFinish(finishReason, "", "")
			// Reload the session before applying usage so the save is based
			// on the stored state rather than on the copy loaded at the start
			// of Run.
			sessionLock.Lock()
			updatedSession, getSessionErr := a.sessions.Get(genCtx, call.SessionID)
			if getSessionErr != nil {
				sessionLock.Unlock()
				return getSessionErr
			}
			a.updateSessionUsage(largeModel, &updatedSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
			_, sessionErr := a.sessions.Save(genCtx, updatedSession)
			if sessionErr == nil {
				currentSession = updatedSession
			}
			sessionLock.Unlock()
			if sessionErr != nil {
				return sessionErr
			}
			return a.messages.Update(genCtx, *currentAssistant)
		},
		StopWhen: []fantasy.StopCondition{
			// Stop (and request a summary) when the remaining context window
			// is at or below the threshold: a fixed buffer for large context
			// windows, a ratio of the window otherwise.
			func(_ []fantasy.StepResult) bool {
				cw := int64(largeModel.CatwalkCfg.ContextWindow)
				tokens := currentSession.CompletionTokens + currentSession.PromptTokens
				remaining := cw - tokens
				var threshold int64
				if cw > largeContextWindowThreshold {
					threshold = largeContextWindowBuffer
				} else {
					threshold = int64(float64(cw) * smallContextWindowRatio)
				}
				if (remaining <= threshold) && !a.disableAutoSummarize {
					shouldSummarize = true
					return true
				}
				return false
			},
		},
	})

	a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))

	if err != nil {
		isCancelErr := errors.Is(err, context.Canceled)
		isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
		// No assistant message was created, so there is nothing to clean up.
		if currentAssistant == nil {
			return result, err
		}
		// Ensure we finish thinking on error to close the reasoning state.
		currentAssistant.FinishThinking()
		toolCalls := currentAssistant.ToolCalls()
		// INFO: we use the parent context here because the genCtx has been cancelled.
		msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
		if createErr != nil {
			return nil, createErr
		}
		// Make sure every tool call is marked finished and has a tool result,
		// synthesizing an error result where none was stored.
		for _, tc := range toolCalls {
			if !tc.Finished {
				tc.Finished = true
				tc.Input = "{}"
				currentAssistant.AddToolCall(tc)
				updateErr := a.messages.Update(ctx, *currentAssistant)
				if updateErr != nil {
					return nil, updateErr
				}
			}

			// Look for an existing result for this tool call.
			found := false
			for _, msg := range msgs {
				if msg.Role == message.Tool {
					for _, tr := range msg.ToolResults() {
						if tr.ToolCallID == tc.ID {
							found = true
							break
						}
					}
				}
				if found {
					break
				}
			}
			if found {
				continue
			}
			content := "There was an error while executing the tool"
			if isCancelErr {
				content = "Tool execution canceled by user"
			} else if isPermissionErr {
				content = "User denied permission"
			}
			toolResult := message.ToolResult{
				ToolCallID: tc.ID,
				Name:       tc.Name,
				Content:    content,
				IsError:    true,
			}
			_, createErr = a.messages.Create(ctx, currentAssistant.SessionID, message.CreateMessageParams{
				Role: message.Tool,
				Parts: []message.ContentPart{
					toolResult,
				},
			})
			if createErr != nil {
				return nil, createErr
			}
		}
		// Close the assistant message with a finish reason that matches the
		// kind of error.
		var fantasyErr *fantasy.Error
		var providerErr *fantasy.ProviderError
		const defaultTitle = "Provider Error"
		linkStyle := lipgloss.NewStyle().Foreground(charmtone.Guac).Underline(true)
		if isCancelErr {
			currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
		} else if isPermissionErr {
			currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "")
		} else if errors.Is(err, hyper.ErrNoCredits) {
			url := hyper.BaseURL()
			link := linkStyle.Hyperlink(url, "id=hyper").Render(url)
			currentAssistant.AddFinish(message.FinishReasonError, "No credits", "You're out of credits. Add more at "+link)
		} else if errors.As(err, &providerErr) {
			// This specific provider message is reported as a Copilot model
			// that has not been enabled, with a link to the settings page.
			if providerErr.Message == "The requested model is not supported." {
				url := "https://github.com/settings/copilot/features"
				link := linkStyle.Hyperlink(url, "id=copilot").Render(url)
				currentAssistant.AddFinish(
					message.FinishReasonError,
					"Copilot model not enabled",
					fmt.Sprintf("%q is not enabled in Copilot. Go to the following page to enable it. Then, wait 5 minutes before trying again. %s", largeModel.CatwalkCfg.Name, link),
				)
			} else {
				currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message)
			}
		} else if errors.As(err, &fantasyErr) {
			currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(fantasyErr.Title), defaultTitle), fantasyErr.Message)
		} else {
			currentAssistant.AddFinish(message.FinishReasonError, defaultTitle, err.Error())
		}
		// Note: we use the parent context here because the genCtx has been
		// cancelled.
		updateErr := a.messages.Update(ctx, *currentAssistant)
		if updateErr != nil {
			return nil, updateErr
		}
		return nil, err
	}

	if shouldSummarize {
		// Summarize refuses to run while the session is busy, so release the
		// active request first.
		a.activeRequests.Del(call.SessionID)
		if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
			return nil, summarizeErr
		}
		// If the agent wasn't done...
		// (the last assistant message still has tool calls), re-queue the
		// original request so work continues after the summary.
		if len(currentAssistant.ToolCalls()) > 0 {
			existing, ok := a.messageQueue.Get(call.SessionID)
			if !ok {
				existing = []SessionAgentCall{}
			}
			call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
			existing = append(existing, call)
			a.messageQueue.Set(call.SessionID, existing)
		}
	}

	// Release active request before processing queued messages.
	a.activeRequests.Del(call.SessionID)
	cancel()

	queuedMessages, ok := a.messageQueue.Get(call.SessionID)
	if !ok || len(queuedMessages) == 0 {
		return result, err
	}
	// There are queued messages restart the loop.
	firstQueuedMessage := queuedMessages[0]
	a.messageQueue.Set(call.SessionID, queuedMessages[1:])
	return a.Run(ctx, firstQueuedMessage)
}
 545
// Summarize asks the large model to summarize the session's current history
// and records the result as the session's summary message.
//
// It returns ErrSessionBusy when the session has a request in flight, and nil
// without doing anything when the session has no messages. While it runs, the
// session is registered as busy and can be cancelled; if the stream fails with
// context.Canceled the partially written summary message is deleted and the
// result of that deletion is returned.
//
// On success the session's SummaryMessageID points at the new message, its
// cost is increased by the cost of the summary, and its token counters are
// reset to the output tokens of the final response with zero prompt tokens.
func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
	if a.IsSessionBusy(sessionID) {
		return ErrSessionBusy
	}

	// Copy mutable fields under lock to avoid races with SetModels.
	largeModel := a.largeModel.Get()
	systemPromptPrefix := a.systemPromptPrefix.Get()

	currentSession, err := a.sessions.Get(ctx, sessionID)
	if err != nil {
		return fmt.Errorf("failed to get session: %w", err)
	}
	msgs, err := a.getSessionMessages(ctx, currentSession)
	if err != nil {
		return err
	}
	if len(msgs) == 0 {
		// Nothing to summarize.
		return nil
	}

	aiMsgs, _ := a.preparePrompt(msgs)

	// Register the cancel function so the session counts as busy and Cancel
	// can stop the summary. Deferred calls run in reverse order: cancel first,
	// then removal from activeRequests.
	genCtx, cancel := context.WithCancel(ctx)
	a.activeRequests.Set(sessionID, cancel)
	defer a.activeRequests.Del(sessionID)
	defer cancel()

	agent := fantasy.NewAgent(largeModel.Model,
		fantasy.WithSystemPrompt(string(summaryPrompt)),
	)
	// The summary message is created up front with the parent context and
	// filled in by the streaming callbacks below.
	summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
		Role:             message.Assistant,
		Model:            largeModel.Model.Model(),
		Provider:         largeModel.Model.Provider(),
		IsSummaryMessage: true,
	})
	if err != nil {
		return err
	}

	summaryPromptText := buildSummaryPrompt(currentSession.Todos)

	resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
		Prompt:          summaryPromptText,
		Messages:        aiMsgs,
		ProviderOptions: opts,
		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
			// Put the prompt prefix in front as its own system message.
			prepared.Messages = options.Messages
			if systemPromptPrefix != "" {
				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(systemPromptPrefix)}, prepared.Messages...)
			}
			return callContext, prepared, nil
		},
		OnReasoningDelta: func(id string, text string) error {
			summaryMessage.AppendReasoningContent(text)
			return a.messages.Update(genCtx, summaryMessage)
		},
		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
			// Handle anthropic signature.
			// NOTE(review): Run looks this up with anthropic.Name rather than
			// the literal key; presumably the two are equal — confirm.
			if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
				if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
					summaryMessage.AppendReasoningSignature(signature.Signature)
				}
			}
			summaryMessage.FinishThinking()
			return a.messages.Update(genCtx, summaryMessage)
		},
		OnTextDelta: func(id, text string) error {
			summaryMessage.AppendContent(text)
			return a.messages.Update(genCtx, summaryMessage)
		},
	})
	if err != nil {
		isCancelErr := errors.Is(err, context.Canceled)
		if isCancelErr {
			// User cancelled summarize we need to remove the summary message.
			// The parent context is used because genCtx has been cancelled.
			deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
			return deleteErr
		}
		return err
	}

	summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
	err = a.messages.Update(genCtx, summaryMessage)
	if err != nil {
		return err
	}

	// Sum the OpenRouter-reported cost over all steps; it stays nil when no
	// step reported one.
	var openrouterCost *float64
	for _, step := range resp.Steps {
		stepCost := a.openrouterCost(step.ProviderMetadata)
		if stepCost != nil {
			newCost := *stepCost
			if openrouterCost != nil {
				newCost += *openrouterCost
			}
			openrouterCost = &newCost
		}
	}

	a.updateSessionUsage(largeModel, &currentSession, resp.TotalUsage, openrouterCost)

	// Just in case, get just the last usage info.
	// The token counters written by updateSessionUsage are overwritten here:
	// only the final response's output tokens are kept and prompt tokens are
	// reset to zero.
	usage := resp.Response.Usage
	currentSession.SummaryMessageID = summaryMessage.ID
	currentSession.CompletionTokens = usage.OutputTokens
	currentSession.PromptTokens = 0
	_, err = a.sessions.Save(genCtx, currentSession)
	return err
}
 658
 659func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
 660	if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
 661		return fantasy.ProviderOptions{}
 662	}
 663	return fantasy.ProviderOptions{
 664		anthropic.Name: &anthropic.ProviderCacheControlOptions{
 665			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
 666		},
 667		bedrock.Name: &anthropic.ProviderCacheControlOptions{
 668			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
 669		},
 670	}
 671}
 672
 673func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
 674	parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
 675	var attachmentParts []message.ContentPart
 676	for _, attachment := range call.Attachments {
 677		attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
 678	}
 679	parts = append(parts, attachmentParts...)
 680	msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
 681		Role:  message.User,
 682		Parts: parts,
 683	})
 684	if err != nil {
 685		return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
 686	}
 687	return msg, nil
 688}
 689
 690func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
 691	var history []fantasy.Message
 692	if !a.isSubAgent {
 693		history = append(history, fantasy.NewUserMessage(
 694			fmt.Sprintf("<system_reminder>%s</system_reminder>",
 695				`This is a reminder that your todo list is currently empty. DO NOT mention this to the user explicitly because they are already aware.
 696If you are working on tasks that would benefit from a todo list please use the "todos" tool to create one.
 697If not, please feel free to ignore. Again do not mention this message to the user.`,
 698			),
 699		))
 700	}
 701	for _, m := range msgs {
 702		if len(m.Parts) == 0 {
 703			continue
 704		}
 705		// Assistant message without content or tool calls (cancelled before it
 706		// returned anything).
 707		if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
 708			continue
 709		}
 710		history = append(history, m.ToAIMessage()...)
 711	}
 712
 713	var files []fantasy.FilePart
 714	for _, attachment := range attachments {
 715		if attachment.IsText() {
 716			continue
 717		}
 718		files = append(files, fantasy.FilePart{
 719			Filename:  attachment.FileName,
 720			Data:      attachment.Content,
 721			MediaType: attachment.MimeType,
 722		})
 723	}
 724
 725	return history, files
 726}
 727
 728func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
 729	msgs, err := a.messages.List(ctx, session.ID)
 730	if err != nil {
 731		return nil, fmt.Errorf("failed to list messages: %w", err)
 732	}
 733
 734	if session.SummaryMessageID != "" {
 735		summaryMsgIndex := -1
 736		for i, msg := range msgs {
 737			if msg.ID == session.SummaryMessageID {
 738				summaryMsgIndex = i
 739				break
 740			}
 741		}
 742		if summaryMsgIndex != -1 {
 743			msgs = msgs[summaryMsgIndex:]
 744			msgs[0].Role = message.User
 745		}
 746	}
 747	return msgs, nil
 748}
 749
 750// generateTitle generates a session titled based on the initial prompt.
 751func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, userPrompt string) {
 752	if userPrompt == "" {
 753		return
 754	}
 755
 756	smallModel := a.smallModel.Get()
 757	largeModel := a.largeModel.Get()
 758	systemPromptPrefix := a.systemPromptPrefix.Get()
 759
 760	var maxOutputTokens int64 = 40
 761	if smallModel.CatwalkCfg.CanReason {
 762		maxOutputTokens = smallModel.CatwalkCfg.DefaultMaxTokens
 763	}
 764
 765	newAgent := func(m fantasy.LanguageModel, p []byte, tok int64) fantasy.Agent {
 766		return fantasy.NewAgent(m,
 767			fantasy.WithSystemPrompt(string(p)+"\n /no_think"),
 768			fantasy.WithMaxOutputTokens(tok),
 769		)
 770	}
 771
 772	streamCall := fantasy.AgentStreamCall{
 773		Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", userPrompt),
 774		PrepareStep: func(callCtx context.Context, opts fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
 775			prepared.Messages = opts.Messages
 776			if systemPromptPrefix != "" {
 777				prepared.Messages = append([]fantasy.Message{
 778					fantasy.NewSystemMessage(systemPromptPrefix),
 779				}, prepared.Messages...)
 780			}
 781			return callCtx, prepared, nil
 782		},
 783	}
 784
 785	// Use the small model to generate the title.
 786	model := smallModel
 787	agent := newAgent(model.Model, titlePrompt, maxOutputTokens)
 788	resp, err := agent.Stream(ctx, streamCall)
 789	if err == nil {
 790		// We successfully generated a title with the small model.
 791		slog.Info("generated title with small model")
 792	} else {
 793		// It didn't work. Let's try with the big model.
 794		slog.Error("error generating title with small model; trying big model", "err", err)
 795		model = largeModel
 796		agent = newAgent(model.Model, titlePrompt, maxOutputTokens)
 797		resp, err = agent.Stream(ctx, streamCall)
 798		if err == nil {
 799			slog.Info("generated title with large model")
 800		} else {
 801			// Welp, the large model didn't work either. Use the default
 802			// session name and return.
 803			slog.Error("error generating title with large model", "err", err)
 804			saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, defaultSessionName, 0, 0, 0)
 805			if saveErr != nil {
 806				slog.Error("failed to save session title and usage", "error", saveErr)
 807			}
 808			return
 809		}
 810	}
 811
 812	if resp == nil {
 813		// Actually, we didn't get a response so we can't. Use the default
 814		// session name and return.
 815		slog.Error("response is nil; can't generate title")
 816		saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, defaultSessionName, 0, 0, 0)
 817		if saveErr != nil {
 818			slog.Error("failed to save session title and usage", "error", saveErr)
 819		}
 820		return
 821	}
 822
 823	// Clean up title.
 824	var title string
 825	title = strings.ReplaceAll(resp.Response.Content.Text(), "\n", " ")
 826
 827	// Remove thinking tags if present.
 828	title = thinkTagRegex.ReplaceAllString(title, "")
 829
 830	title = strings.TrimSpace(title)
 831	if title == "" {
 832		slog.Warn("empty title; using fallback")
 833		title = defaultSessionName
 834	}
 835
 836	// Calculate usage and cost.
 837	var openrouterCost *float64
 838	for _, step := range resp.Steps {
 839		stepCost := a.openrouterCost(step.ProviderMetadata)
 840		if stepCost != nil {
 841			newCost := *stepCost
 842			if openrouterCost != nil {
 843				newCost += *openrouterCost
 844			}
 845			openrouterCost = &newCost
 846		}
 847	}
 848
 849	modelConfig := model.CatwalkCfg
 850	cost := modelConfig.CostPer1MInCached/1e6*float64(resp.TotalUsage.CacheCreationTokens) +
 851		modelConfig.CostPer1MOutCached/1e6*float64(resp.TotalUsage.CacheReadTokens) +
 852		modelConfig.CostPer1MIn/1e6*float64(resp.TotalUsage.InputTokens) +
 853		modelConfig.CostPer1MOut/1e6*float64(resp.TotalUsage.OutputTokens)
 854
 855	// Use override cost if available (e.g., from OpenRouter).
 856	if openrouterCost != nil {
 857		cost = *openrouterCost
 858	}
 859
 860	promptTokens := resp.TotalUsage.InputTokens + resp.TotalUsage.CacheCreationTokens
 861	completionTokens := resp.TotalUsage.OutputTokens
 862
 863	// Atomically update only title and usage fields to avoid overriding other
 864	// concurrent session updates.
 865	saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, title, promptTokens, completionTokens, cost)
 866	if saveErr != nil {
 867		slog.Error("failed to save session title and usage", "error", saveErr)
 868		return
 869	}
 870}
 871
 872func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
 873	openrouterMetadata, ok := metadata[openrouter.Name]
 874	if !ok {
 875		return nil
 876	}
 877
 878	opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
 879	if !ok {
 880		return nil
 881	}
 882	return &opts.Usage.Cost
 883}
 884
 885func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
 886	modelConfig := model.CatwalkCfg
 887	cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
 888		modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
 889		modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
 890		modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
 891
 892	a.eventTokensUsed(session.ID, model, usage, cost)
 893
 894	if overrideCost != nil {
 895		session.Cost += *overrideCost
 896	} else {
 897		session.Cost += cost
 898	}
 899
 900	session.CompletionTokens = usage.OutputTokens
 901	session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
 902}
 903
 904func (a *sessionAgent) Cancel(sessionID string) {
 905	// Cancel regular requests. Don't use Take() here - we need the entry to
 906	// remain in activeRequests so IsBusy() returns true until the goroutine
 907	// fully completes (including error handling that may access the DB).
 908	// The defer in processRequest will clean up the entry.
 909	if cancel, ok := a.activeRequests.Get(sessionID); ok && cancel != nil {
 910		slog.Info("Request cancellation initiated", "session_id", sessionID)
 911		cancel()
 912	}
 913
 914	// Also check for summarize requests.
 915	if cancel, ok := a.activeRequests.Get(sessionID + "-summarize"); ok && cancel != nil {
 916		slog.Info("Summarize cancellation initiated", "session_id", sessionID)
 917		cancel()
 918	}
 919
 920	if a.QueuedPrompts(sessionID) > 0 {
 921		slog.Info("Clearing queued prompts", "session_id", sessionID)
 922		a.messageQueue.Del(sessionID)
 923	}
 924}
 925
 926func (a *sessionAgent) ClearQueue(sessionID string) {
 927	if a.QueuedPrompts(sessionID) > 0 {
 928		slog.Info("Clearing queued prompts", "session_id", sessionID)
 929		a.messageQueue.Del(sessionID)
 930	}
 931}
 932
 933func (a *sessionAgent) CancelAll() {
 934	if !a.IsBusy() {
 935		return
 936	}
 937	for key := range a.activeRequests.Seq2() {
 938		a.Cancel(key) // key is sessionID
 939	}
 940
 941	timeout := time.After(5 * time.Second)
 942	for a.IsBusy() {
 943		select {
 944		case <-timeout:
 945			return
 946		default:
 947			time.Sleep(200 * time.Millisecond)
 948		}
 949	}
 950}
 951
 952func (a *sessionAgent) IsBusy() bool {
 953	var busy bool
 954	for cancelFunc := range a.activeRequests.Seq() {
 955		if cancelFunc != nil {
 956			busy = true
 957			break
 958		}
 959	}
 960	return busy
 961}
 962
 963func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
 964	_, busy := a.activeRequests.Get(sessionID)
 965	return busy
 966}
 967
 968func (a *sessionAgent) QueuedPrompts(sessionID string) int {
 969	l, ok := a.messageQueue.Get(sessionID)
 970	if !ok {
 971		return 0
 972	}
 973	return len(l)
 974}
 975
 976func (a *sessionAgent) QueuedPromptsList(sessionID string) []string {
 977	l, ok := a.messageQueue.Get(sessionID)
 978	if !ok {
 979		return nil
 980	}
 981	prompts := make([]string, len(l))
 982	for i, call := range l {
 983		prompts[i] = call.Prompt
 984	}
 985	return prompts
 986}
 987
// SetModels stores large and small as the agent's large and small models,
// setting the large model first.
func (a *sessionAgent) SetModels(large Model, small Model) {
	a.largeModel.Set(large)
	a.smallModel.Set(small)
}
 992
// SetTools hands the given tools to the agent's tool collection via SetSlice.
func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
	a.tools.SetSlice(tools)
}
 996
// SetSystemPrompt stores systemPrompt as the agent's system prompt.
func (a *sessionAgent) SetSystemPrompt(systemPrompt string) {
	a.systemPrompt.Set(systemPrompt)
}
1000
// Model returns the agent's currently stored large model.
func (a *sessionAgent) Model() Model {
	return a.largeModel.Get()
}
1004
1005// convertToToolResult converts a fantasy tool result to a message tool result.
1006func (a *sessionAgent) convertToToolResult(result fantasy.ToolResultContent) message.ToolResult {
1007	baseResult := message.ToolResult{
1008		ToolCallID: result.ToolCallID,
1009		Name:       result.ToolName,
1010		Metadata:   result.ClientMetadata,
1011	}
1012
1013	switch result.Result.GetType() {
1014	case fantasy.ToolResultContentTypeText:
1015		if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result); ok {
1016			baseResult.Content = r.Text
1017		}
1018	case fantasy.ToolResultContentTypeError:
1019		if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result); ok {
1020			baseResult.Content = r.Error.Error()
1021			baseResult.IsError = true
1022		}
1023	case fantasy.ToolResultContentTypeMedia:
1024		if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](result.Result); ok {
1025			content := r.Text
1026			if content == "" {
1027				content = fmt.Sprintf("Loaded %s content", r.MediaType)
1028			}
1029			baseResult.Content = content
1030			baseResult.Data = r.Data
1031			baseResult.MIMEType = r.MediaType
1032		}
1033	}
1034
1035	return baseResult
1036}
1037
1038// workaroundProviderMediaLimitations converts media content in tool results to
1039// user messages for providers that don't natively support images in tool results.
1040//
1041// Problem: OpenAI, Google, OpenRouter, and other OpenAI-compatible providers
1042// don't support sending images/media in tool result messages - they only accept
1043// text in tool results. However, they DO support images in user messages.
1044//
1045// If we send media in tool results to these providers, the API returns an error.
1046//
1047// Solution: For these providers, we:
1048//  1. Replace the media in the tool result with a text placeholder
1049//  2. Inject a user message immediately after with the image as a file attachment
1050//  3. This maintains the tool execution flow while working around API limitations
1051//
1052// Anthropic and Bedrock support images natively in tool results, so we skip
1053// this workaround for them.
1054//
1055// Example transformation:
1056//
1057//	BEFORE: [tool result: image data]
1058//	AFTER:  [tool result: "Image loaded - see attached"], [user: image attachment]
1059func (a *sessionAgent) workaroundProviderMediaLimitations(messages []fantasy.Message, largeModel Model) []fantasy.Message {
1060	providerSupportsMedia := largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderAnthropic) ||
1061		largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderBedrock)
1062
1063	if providerSupportsMedia {
1064		return messages
1065	}
1066
1067	convertedMessages := make([]fantasy.Message, 0, len(messages))
1068
1069	for _, msg := range messages {
1070		if msg.Role != fantasy.MessageRoleTool {
1071			convertedMessages = append(convertedMessages, msg)
1072			continue
1073		}
1074
1075		textParts := make([]fantasy.MessagePart, 0, len(msg.Content))
1076		var mediaFiles []fantasy.FilePart
1077
1078		for _, part := range msg.Content {
1079			toolResult, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
1080			if !ok {
1081				textParts = append(textParts, part)
1082				continue
1083			}
1084
1085			if media, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](toolResult.Output); ok {
1086				decoded, err := base64.StdEncoding.DecodeString(media.Data)
1087				if err != nil {
1088					slog.Warn("failed to decode media data", "error", err)
1089					textParts = append(textParts, part)
1090					continue
1091				}
1092
1093				mediaFiles = append(mediaFiles, fantasy.FilePart{
1094					Data:      decoded,
1095					MediaType: media.MediaType,
1096					Filename:  fmt.Sprintf("tool-result-%s", toolResult.ToolCallID),
1097				})
1098
1099				textParts = append(textParts, fantasy.ToolResultPart{
1100					ToolCallID: toolResult.ToolCallID,
1101					Output: fantasy.ToolResultOutputContentText{
1102						Text: "[Image/media content loaded - see attached file]",
1103					},
1104					ProviderOptions: toolResult.ProviderOptions,
1105				})
1106			} else {
1107				textParts = append(textParts, part)
1108			}
1109		}
1110
1111		convertedMessages = append(convertedMessages, fantasy.Message{
1112			Role:    fantasy.MessageRoleTool,
1113			Content: textParts,
1114		})
1115
1116		if len(mediaFiles) > 0 {
1117			convertedMessages = append(convertedMessages, fantasy.NewUserMessage(
1118				"Here is the media content from the tool result:",
1119				mediaFiles...,
1120			))
1121		}
1122	}
1123
1124	return convertedMessages
1125}
1126
1127// buildSummaryPrompt constructs the prompt text for session summarization.
1128func buildSummaryPrompt(todos []session.Todo) string {
1129	var sb strings.Builder
1130	sb.WriteString("Provide a detailed summary of our conversation above.")
1131	if len(todos) > 0 {
1132		sb.WriteString("\n\n## Current Todo List\n\n")
1133		for _, t := range todos {
1134			fmt.Fprintf(&sb, "- [%s] %s\n", t.Status, t.Content)
1135		}
1136		sb.WriteString("\nInclude these tasks and their statuses in your summary. ")
1137		sb.WriteString("Instruct the resuming assistant to use the `todos` tool to continue tracking progress on these tasks.")
1138	}
1139	return sb.String()
1140}