agent.go

   1// Package agent is the core orchestration layer for Crush AI agents.
   2//
   3// It provides session-based AI agent functionality for managing
   4// conversations, tool execution, and message handling. It coordinates
   5// interactions between language models, messages, sessions, and tools while
   6// handling features like automatic summarization, queuing, and token
   7// management.
   8package agent
   9
  10import (
  11	"cmp"
  12	"context"
  13	_ "embed"
  14	"encoding/base64"
  15	"errors"
  16	"fmt"
  17	"log/slog"
  18	"os"
  19	"strconv"
  20	"strings"
  21	"sync"
  22	"time"
  23
  24	"charm.land/fantasy"
  25	"charm.land/fantasy/providers/anthropic"
  26	"charm.land/fantasy/providers/bedrock"
  27	"charm.land/fantasy/providers/google"
  28	"charm.land/fantasy/providers/openai"
  29	"charm.land/fantasy/providers/openrouter"
  30	"charm.land/lipgloss/v2"
  31	"github.com/charmbracelet/catwalk/pkg/catwalk"
  32	"github.com/charmbracelet/crush/internal/agent/hyper"
  33	"github.com/charmbracelet/crush/internal/agent/tools"
  34	"github.com/charmbracelet/crush/internal/config"
  35	"github.com/charmbracelet/crush/internal/csync"
  36	"github.com/charmbracelet/crush/internal/message"
  37	"github.com/charmbracelet/crush/internal/permission"
  38	"github.com/charmbracelet/crush/internal/session"
  39	"github.com/charmbracelet/crush/internal/stringext"
  40)
  41
// titlePrompt holds the contents of templates/title.md, embedded at build
// time. It is used as the system prompt when generating a session title.
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt holds the contents of templates/summary.md, embedded at build
// time. It is used as the system prompt when summarizing a session.
//
//go:embed templates/summary.md
var summaryPrompt []byte
  47
// SessionAgentCall describes a single prompt submitted to a SessionAgent.
//
// SessionID and Prompt are required: Run rejects a call where either is
// empty. The remaining fields are forwarded to the underlying model call.
type SessionAgentCall struct {
	// SessionID identifies the session the prompt belongs to.
	SessionID string
	// Prompt is the user's text.
	Prompt string
	// ProviderOptions are passed through to the model call unchanged.
	ProviderOptions fantasy.ProviderOptions
	// Attachments are stored with the user message and sent as files.
	Attachments []message.Attachment
	// MaxOutputTokens is passed to the model call as the output token limit.
	MaxOutputTokens int64
	// Optional sampling parameters; nil pointers are forwarded as nil.
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
}
  60
// SessionAgent is the session-scoped agent API: it runs prompts, manages
// per-session queues and cancellation, and can summarize a session.
type SessionAgent interface {
	// Run executes the call against its session. The implementation in this
	// file queues the call and returns (nil, nil) when the session is busy.
	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
	// SetModels replaces the large and small models.
	SetModels(large Model, small Model)
	// SetTools replaces the tool set.
	SetTools(tools []fantasy.AgentTool)
	// Cancel cancels the active request for the session and drops its queue.
	Cancel(sessionID string)
	// CancelAll cancels every active request.
	CancelAll()
	// IsSessionBusy reports whether the session has an active request.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any session has an active request.
	IsBusy() bool
	// QueuedPrompts returns how many prompts are queued for the session.
	QueuedPrompts(sessionID string) int
	// QueuedPromptsList returns the queued prompt texts for the session.
	QueuedPromptsList(sessionID string) []string
	// ClearQueue drops all queued prompts for the session.
	ClearQueue(sessionID string)
	// Summarize replaces the session's working history with a model-written
	// summary. The string argument is the session ID.
	Summarize(context.Context, string, fantasy.ProviderOptions) error
	// Model returns the large model.
	Model() Model
}
  75
// Model bundles a language model with its catalog metadata and the
// configuration entry that selected it.
type Model struct {
	// Model is the language model used for generation.
	Model fantasy.LanguageModel
	// CatwalkCfg is the catalog entry; this file reads pricing, context
	// window, reasoning and image-support fields from it.
	CatwalkCfg catwalk.Model
	// ModelCfg is the selected-model configuration; this file reads its
	// Model and Provider fields.
	ModelCfg config.SelectedModel
}
  81
// sessionAgent is the SessionAgent implementation returned by
// NewSessionAgent.
type sessionAgent struct {
	// largeModel drives Run and Summarize; smallModel generates titles.
	largeModel Model
	smallModel Model
	// systemPromptPrefix is prepended as an extra system message when set.
	systemPromptPrefix string
	// systemPrompt is the system prompt used by Run.
	systemPrompt string
	// isSubAgent suppresses the todo-list reminder in preparePrompt.
	isSubAgent bool
	// tools are the tools offered to the model in Run.
	tools []fantasy.AgentTool
	// sessions and messages are the persistence services.
	sessions session.Service
	messages message.Service
	// disableAutoSummarize turns off the context-window stop condition in Run.
	disableAutoSummarize bool
	isYolo               bool

	// messageQueue holds calls received while a session was busy, keyed by
	// session ID.
	messageQueue *csync.Map[string, []SessionAgentCall]
	// activeRequests holds the cancel function of each in-flight request,
	// keyed by session ID.
	activeRequests *csync.Map[string, context.CancelFunc]
}
  97
// SessionAgentOptions carries the dependencies and settings that
// NewSessionAgent copies into a new agent.
type SessionAgentOptions struct {
	// LargeModel is used for prompts and summaries; SmallModel for titles.
	LargeModel Model
	SmallModel Model
	// SystemPromptPrefix, when non-empty, is sent as a leading system message.
	SystemPromptPrefix string
	// SystemPrompt is the main system prompt.
	SystemPrompt string
	// IsSubAgent disables the todo-list reminder added to the history.
	IsSubAgent bool
	// DisableAutoSummarize disables automatic summarization in Run.
	DisableAutoSummarize bool
	IsYolo               bool
	// Sessions and Messages are the persistence services.
	Sessions session.Service
	Messages message.Service
	// Tools are the tools made available to the model.
	Tools []fantasy.AgentTool
}
 110
 111func NewSessionAgent(
 112	opts SessionAgentOptions,
 113) SessionAgent {
 114	return &sessionAgent{
 115		largeModel:           opts.LargeModel,
 116		smallModel:           opts.SmallModel,
 117		systemPromptPrefix:   opts.SystemPromptPrefix,
 118		systemPrompt:         opts.SystemPrompt,
 119		isSubAgent:           opts.IsSubAgent,
 120		sessions:             opts.Sessions,
 121		messages:             opts.Messages,
 122		disableAutoSummarize: opts.DisableAutoSummarize,
 123		tools:                opts.Tools,
 124		isYolo:               opts.IsYolo,
 125		messageQueue:         csync.NewMap[string, []SessionAgentCall](),
 126		activeRequests:       csync.NewMap[string, context.CancelFunc](),
 127	}
 128}
 129
 130func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
 131	if call.Prompt == "" {
 132		return nil, ErrEmptyPrompt
 133	}
 134	if call.SessionID == "" {
 135		return nil, ErrSessionMissing
 136	}
 137
 138	// Queue the message if busy
 139	if a.IsSessionBusy(call.SessionID) {
 140		existing, ok := a.messageQueue.Get(call.SessionID)
 141		if !ok {
 142			existing = []SessionAgentCall{}
 143		}
 144		existing = append(existing, call)
 145		a.messageQueue.Set(call.SessionID, existing)
 146		return nil, nil
 147	}
 148
 149	if len(a.tools) > 0 {
 150		// Add Anthropic caching to the last tool.
 151		a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
 152	}
 153
 154	agent := fantasy.NewAgent(
 155		a.largeModel.Model,
 156		fantasy.WithSystemPrompt(a.systemPrompt),
 157		fantasy.WithTools(a.tools...),
 158	)
 159
 160	sessionLock := sync.Mutex{}
 161	currentSession, err := a.sessions.Get(ctx, call.SessionID)
 162	if err != nil {
 163		return nil, fmt.Errorf("failed to get session: %w", err)
 164	}
 165
 166	msgs, err := a.getSessionMessages(ctx, currentSession)
 167	if err != nil {
 168		return nil, fmt.Errorf("failed to get session messages: %w", err)
 169	}
 170
 171	var wg sync.WaitGroup
 172	// Generate title if first message.
 173	if len(msgs) == 0 {
 174		titleCtx := ctx // Copy to avoid race with ctx reassignment below.
 175		wg.Go(func() {
 176			a.generateTitle(titleCtx, call.SessionID, call.Prompt)
 177		})
 178	}
 179
 180	// Add the user message to the session.
 181	_, err = a.createUserMessage(ctx, call)
 182	if err != nil {
 183		return nil, err
 184	}
 185
 186	// Add the session to the context.
 187	ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)
 188
 189	genCtx, cancel := context.WithCancel(ctx)
 190	a.activeRequests.Set(call.SessionID, cancel)
 191
 192	defer cancel()
 193	defer a.activeRequests.Del(call.SessionID)
 194
 195	history, files := a.preparePrompt(msgs, call.Attachments...)
 196
 197	startTime := time.Now()
 198	a.eventPromptSent(call.SessionID)
 199
 200	var currentAssistant *message.Message
 201	var shouldSummarize bool
 202	result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
 203		Prompt:           call.Prompt,
 204		Files:            files,
 205		Messages:         history,
 206		ProviderOptions:  call.ProviderOptions,
 207		MaxOutputTokens:  &call.MaxOutputTokens,
 208		TopP:             call.TopP,
 209		Temperature:      call.Temperature,
 210		PresencePenalty:  call.PresencePenalty,
 211		TopK:             call.TopK,
 212		FrequencyPenalty: call.FrequencyPenalty,
 213		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
 214			prepared.Messages = options.Messages
 215			for i := range prepared.Messages {
 216				prepared.Messages[i].ProviderOptions = nil
 217			}
 218
 219			queuedCalls, _ := a.messageQueue.Get(call.SessionID)
 220			a.messageQueue.Del(call.SessionID)
 221			for _, queued := range queuedCalls {
 222				userMessage, createErr := a.createUserMessage(callContext, queued)
 223				if createErr != nil {
 224					return callContext, prepared, createErr
 225				}
 226				prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
 227			}
 228
 229			prepared.Messages = a.workaroundProviderMediaLimitations(prepared.Messages)
 230
 231			lastSystemRoleInx := 0
 232			systemMessageUpdated := false
 233			for i, msg := range prepared.Messages {
 234				// Only add cache control to the last message.
 235				if msg.Role == fantasy.MessageRoleSystem {
 236					lastSystemRoleInx = i
 237				} else if !systemMessageUpdated {
 238					prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
 239					systemMessageUpdated = true
 240				}
 241				// Than add cache control to the last 2 messages.
 242				if i > len(prepared.Messages)-3 {
 243					prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
 244				}
 245			}
 246
 247			if promptPrefix := a.promptPrefix(); promptPrefix != "" {
 248				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(promptPrefix)}, prepared.Messages...)
 249			}
 250
 251			var assistantMsg message.Message
 252			assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
 253				Role:     message.Assistant,
 254				Parts:    []message.ContentPart{},
 255				Model:    a.largeModel.ModelCfg.Model,
 256				Provider: a.largeModel.ModelCfg.Provider,
 257			})
 258			if err != nil {
 259				return callContext, prepared, err
 260			}
 261			callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
 262			callContext = context.WithValue(callContext, tools.SupportsImagesContextKey, a.largeModel.CatwalkCfg.SupportsImages)
 263			callContext = context.WithValue(callContext, tools.ModelNameContextKey, a.largeModel.CatwalkCfg.Name)
 264			currentAssistant = &assistantMsg
 265			return callContext, prepared, err
 266		},
 267		OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
 268			currentAssistant.AppendReasoningContent(reasoning.Text)
 269			return a.messages.Update(genCtx, *currentAssistant)
 270		},
 271		OnReasoningDelta: func(id string, text string) error {
 272			currentAssistant.AppendReasoningContent(text)
 273			return a.messages.Update(genCtx, *currentAssistant)
 274		},
 275		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
 276			// handle anthropic signature
 277			if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
 278				if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
 279					currentAssistant.AppendReasoningSignature(reasoning.Signature)
 280				}
 281			}
 282			if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
 283				if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
 284					currentAssistant.AppendThoughtSignature(reasoning.Signature, reasoning.ToolID)
 285				}
 286			}
 287			if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
 288				if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
 289					currentAssistant.SetReasoningResponsesData(reasoning)
 290				}
 291			}
 292			currentAssistant.FinishThinking()
 293			return a.messages.Update(genCtx, *currentAssistant)
 294		},
 295		OnTextDelta: func(id string, text string) error {
 296			// Strip leading newline from initial text content. This is is
 297			// particularly important in non-interactive mode where leading
 298			// newlines are very visible.
 299			if len(currentAssistant.Parts) == 0 {
 300				text = strings.TrimPrefix(text, "\n")
 301			}
 302
 303			currentAssistant.AppendContent(text)
 304			return a.messages.Update(genCtx, *currentAssistant)
 305		},
 306		OnToolInputStart: func(id string, toolName string) error {
 307			toolCall := message.ToolCall{
 308				ID:               id,
 309				Name:             toolName,
 310				ProviderExecuted: false,
 311				Finished:         false,
 312			}
 313			currentAssistant.AddToolCall(toolCall)
 314			return a.messages.Update(genCtx, *currentAssistant)
 315		},
 316		OnRetry: func(err *fantasy.ProviderError, delay time.Duration) {
 317			// TODO: implement
 318		},
 319		OnToolCall: func(tc fantasy.ToolCallContent) error {
 320			toolCall := message.ToolCall{
 321				ID:               tc.ToolCallID,
 322				Name:             tc.ToolName,
 323				Input:            tc.Input,
 324				ProviderExecuted: false,
 325				Finished:         true,
 326			}
 327			currentAssistant.AddToolCall(toolCall)
 328			return a.messages.Update(genCtx, *currentAssistant)
 329		},
 330		OnToolResult: func(result fantasy.ToolResultContent) error {
 331			toolResult := a.convertToToolResult(result)
 332			_, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
 333				Role: message.Tool,
 334				Parts: []message.ContentPart{
 335					toolResult,
 336				},
 337			})
 338			return createMsgErr
 339		},
 340		OnStepFinish: func(stepResult fantasy.StepResult) error {
 341			finishReason := message.FinishReasonUnknown
 342			switch stepResult.FinishReason {
 343			case fantasy.FinishReasonLength:
 344				finishReason = message.FinishReasonMaxTokens
 345			case fantasy.FinishReasonStop:
 346				finishReason = message.FinishReasonEndTurn
 347			case fantasy.FinishReasonToolCalls:
 348				finishReason = message.FinishReasonToolUse
 349			}
 350			currentAssistant.AddFinish(finishReason, "", "")
 351			sessionLock.Lock()
 352			updatedSession, getSessionErr := a.sessions.Get(genCtx, call.SessionID)
 353			if getSessionErr != nil {
 354				sessionLock.Unlock()
 355				return getSessionErr
 356			}
 357			a.updateSessionUsage(a.largeModel, &updatedSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
 358			_, sessionErr := a.sessions.Save(genCtx, updatedSession)
 359			sessionLock.Unlock()
 360			if sessionErr != nil {
 361				return sessionErr
 362			}
 363			return a.messages.Update(genCtx, *currentAssistant)
 364		},
 365		StopWhen: []fantasy.StopCondition{
 366			func(_ []fantasy.StepResult) bool {
 367				cw := int64(a.largeModel.CatwalkCfg.ContextWindow)
 368				tokens := currentSession.CompletionTokens + currentSession.PromptTokens
 369				remaining := cw - tokens
 370				var threshold int64
 371				if cw > 200_000 {
 372					threshold = 20_000
 373				} else {
 374					threshold = int64(float64(cw) * 0.2)
 375				}
 376				if (remaining <= threshold) && !a.disableAutoSummarize {
 377					shouldSummarize = true
 378					return true
 379				}
 380				return false
 381			},
 382		},
 383	})
 384
 385	a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))
 386
 387	if err != nil {
 388		isCancelErr := errors.Is(err, context.Canceled)
 389		isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
 390		if currentAssistant == nil {
 391			return result, err
 392		}
 393		// Ensure we finish thinking on error to close the reasoning state.
 394		currentAssistant.FinishThinking()
 395		toolCalls := currentAssistant.ToolCalls()
 396		// INFO: we use the parent context here because the genCtx has been cancelled.
 397		msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
 398		if createErr != nil {
 399			return nil, createErr
 400		}
 401		for _, tc := range toolCalls {
 402			if !tc.Finished {
 403				tc.Finished = true
 404				tc.Input = "{}"
 405				currentAssistant.AddToolCall(tc)
 406				updateErr := a.messages.Update(ctx, *currentAssistant)
 407				if updateErr != nil {
 408					return nil, updateErr
 409				}
 410			}
 411
 412			found := false
 413			for _, msg := range msgs {
 414				if msg.Role == message.Tool {
 415					for _, tr := range msg.ToolResults() {
 416						if tr.ToolCallID == tc.ID {
 417							found = true
 418							break
 419						}
 420					}
 421				}
 422				if found {
 423					break
 424				}
 425			}
 426			if found {
 427				continue
 428			}
 429			content := "There was an error while executing the tool"
 430			if isCancelErr {
 431				content = "Tool execution canceled by user"
 432			} else if isPermissionErr {
 433				content = "User denied permission"
 434			}
 435			toolResult := message.ToolResult{
 436				ToolCallID: tc.ID,
 437				Name:       tc.Name,
 438				Content:    content,
 439				IsError:    true,
 440			}
 441			_, createErr = a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
 442				Role: message.Tool,
 443				Parts: []message.ContentPart{
 444					toolResult,
 445				},
 446			})
 447			if createErr != nil {
 448				return nil, createErr
 449			}
 450		}
 451		var fantasyErr *fantasy.Error
 452		var providerErr *fantasy.ProviderError
 453		const defaultTitle = "Provider Error"
 454		if isCancelErr {
 455			currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
 456		} else if isPermissionErr {
 457			currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "")
 458		} else if errors.Is(err, hyper.ErrNoCredits) {
 459			url := hyper.BaseURL()
 460			link := lipgloss.NewStyle().Hyperlink(url, "id=hyper").Render(url)
 461			currentAssistant.AddFinish(message.FinishReasonError, "No credits", "You're out of credits. Add more at "+link)
 462		} else if errors.As(err, &providerErr) {
 463			if providerErr.Message == "The requested model is not supported." {
 464				url := "https://github.com/settings/copilot/features"
 465				link := lipgloss.NewStyle().Hyperlink(url, "id=hyper").Render(url)
 466				currentAssistant.AddFinish(
 467					message.FinishReasonError,
 468					"Copilot model not enabled",
 469					fmt.Sprintf("%q is not enabled in Copilot. Go to the following page to enable it. Then, wait a minute before trying again. %s", a.largeModel.CatwalkCfg.Name, link),
 470				)
 471			} else {
 472				currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message)
 473			}
 474		} else if errors.As(err, &fantasyErr) {
 475			currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(fantasyErr.Title), defaultTitle), fantasyErr.Message)
 476		} else {
 477			currentAssistant.AddFinish(message.FinishReasonError, defaultTitle, err.Error())
 478		}
 479		// Note: we use the parent context here because the genCtx has been
 480		// cancelled.
 481		updateErr := a.messages.Update(ctx, *currentAssistant)
 482		if updateErr != nil {
 483			return nil, updateErr
 484		}
 485		return nil, err
 486	}
 487	wg.Wait()
 488
 489	if shouldSummarize {
 490		a.activeRequests.Del(call.SessionID)
 491		if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
 492			return nil, summarizeErr
 493		}
 494		// If the agent wasn't done...
 495		if len(currentAssistant.ToolCalls()) > 0 {
 496			existing, ok := a.messageQueue.Get(call.SessionID)
 497			if !ok {
 498				existing = []SessionAgentCall{}
 499			}
 500			call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
 501			existing = append(existing, call)
 502			a.messageQueue.Set(call.SessionID, existing)
 503		}
 504	}
 505
 506	// Release active request before processing queued messages.
 507	a.activeRequests.Del(call.SessionID)
 508	cancel()
 509
 510	queuedMessages, ok := a.messageQueue.Get(call.SessionID)
 511	if !ok || len(queuedMessages) == 0 {
 512		return result, err
 513	}
 514	// There are queued messages restart the loop.
 515	firstQueuedMessage := queuedMessages[0]
 516	a.messageQueue.Set(call.SessionID, queuedMessages[1:])
 517	return a.Run(ctx, firstQueuedMessage)
 518}
 519
 520func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
 521	if a.IsSessionBusy(sessionID) {
 522		return ErrSessionBusy
 523	}
 524
 525	currentSession, err := a.sessions.Get(ctx, sessionID)
 526	if err != nil {
 527		return fmt.Errorf("failed to get session: %w", err)
 528	}
 529	msgs, err := a.getSessionMessages(ctx, currentSession)
 530	if err != nil {
 531		return err
 532	}
 533	if len(msgs) == 0 {
 534		// Nothing to summarize.
 535		return nil
 536	}
 537
 538	aiMsgs, _ := a.preparePrompt(msgs)
 539
 540	genCtx, cancel := context.WithCancel(ctx)
 541	a.activeRequests.Set(sessionID, cancel)
 542	defer a.activeRequests.Del(sessionID)
 543	defer cancel()
 544
 545	agent := fantasy.NewAgent(a.largeModel.Model,
 546		fantasy.WithSystemPrompt(string(summaryPrompt)),
 547	)
 548	summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
 549		Role:             message.Assistant,
 550		Model:            a.largeModel.Model.Model(),
 551		Provider:         a.largeModel.Model.Provider(),
 552		IsSummaryMessage: true,
 553	})
 554	if err != nil {
 555		return err
 556	}
 557
 558	summaryPromptText := "Provide a detailed summary of our conversation above."
 559	if len(currentSession.Todos) > 0 {
 560		summaryPromptText += "\n\n## Current Todo List\n\n"
 561		for _, t := range currentSession.Todos {
 562			summaryPromptText += fmt.Sprintf("- [%s] %s\n", t.Status, t.Content)
 563		}
 564		summaryPromptText += "\nInclude these tasks and their statuses in your summary. "
 565		summaryPromptText += "Instruct the resuming assistant to use the `todos` tool to continue tracking progress on these tasks."
 566	}
 567
 568	resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
 569		Prompt:          summaryPromptText,
 570		Messages:        aiMsgs,
 571		ProviderOptions: opts,
 572		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
 573			prepared.Messages = options.Messages
 574			if a.systemPromptPrefix != "" {
 575				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
 576			}
 577			return callContext, prepared, nil
 578		},
 579		OnReasoningDelta: func(id string, text string) error {
 580			summaryMessage.AppendReasoningContent(text)
 581			return a.messages.Update(genCtx, summaryMessage)
 582		},
 583		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
 584			// Handle anthropic signature.
 585			if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
 586				if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
 587					summaryMessage.AppendReasoningSignature(signature.Signature)
 588				}
 589			}
 590			summaryMessage.FinishThinking()
 591			return a.messages.Update(genCtx, summaryMessage)
 592		},
 593		OnTextDelta: func(id, text string) error {
 594			summaryMessage.AppendContent(text)
 595			return a.messages.Update(genCtx, summaryMessage)
 596		},
 597	})
 598	if err != nil {
 599		isCancelErr := errors.Is(err, context.Canceled)
 600		if isCancelErr {
 601			// User cancelled summarize we need to remove the summary message.
 602			deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
 603			return deleteErr
 604		}
 605		return err
 606	}
 607
 608	summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
 609	err = a.messages.Update(genCtx, summaryMessage)
 610	if err != nil {
 611		return err
 612	}
 613
 614	var openrouterCost *float64
 615	for _, step := range resp.Steps {
 616		stepCost := a.openrouterCost(step.ProviderMetadata)
 617		if stepCost != nil {
 618			newCost := *stepCost
 619			if openrouterCost != nil {
 620				newCost += *openrouterCost
 621			}
 622			openrouterCost = &newCost
 623		}
 624	}
 625
 626	a.updateSessionUsage(a.largeModel, &currentSession, resp.TotalUsage, openrouterCost)
 627
 628	// Just in case, get just the last usage info.
 629	usage := resp.Response.Usage
 630	currentSession.SummaryMessageID = summaryMessage.ID
 631	currentSession.CompletionTokens = usage.OutputTokens
 632	currentSession.PromptTokens = 0
 633	_, err = a.sessions.Save(genCtx, currentSession)
 634	return err
 635}
 636
 637func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
 638	if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
 639		return fantasy.ProviderOptions{}
 640	}
 641	return fantasy.ProviderOptions{
 642		anthropic.Name: &anthropic.ProviderCacheControlOptions{
 643			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
 644		},
 645		bedrock.Name: &anthropic.ProviderCacheControlOptions{
 646			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
 647		},
 648	}
 649}
 650
 651func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
 652	var attachmentParts []message.ContentPart
 653	for _, attachment := range call.Attachments {
 654		attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
 655	}
 656	parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
 657	parts = append(parts, attachmentParts...)
 658	msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
 659		Role:  message.User,
 660		Parts: parts,
 661	})
 662	if err != nil {
 663		return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
 664	}
 665	return msg, nil
 666}
 667
 668func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
 669	var history []fantasy.Message
 670	if !a.isSubAgent {
 671		history = append(history, fantasy.NewUserMessage(
 672			fmt.Sprintf("<system_reminder>%s</system_reminder>",
 673				`This is a reminder that your todo list is currently empty. DO NOT mention this to the user explicitly because they are already aware.
 674If you are working on tasks that would benefit from a todo list please use the "todos" tool to create one.
 675If not, please feel free to ignore. Again do not mention this message to the user.`,
 676			),
 677		))
 678	}
 679	for _, m := range msgs {
 680		if len(m.Parts) == 0 {
 681			continue
 682		}
 683		// Assistant message without content or tool calls (cancelled before it
 684		// returned anything).
 685		if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
 686			continue
 687		}
 688		history = append(history, m.ToAIMessage()...)
 689	}
 690
 691	var files []fantasy.FilePart
 692	for _, attachment := range attachments {
 693		files = append(files, fantasy.FilePart{
 694			Filename:  attachment.FileName,
 695			Data:      attachment.Content,
 696			MediaType: attachment.MimeType,
 697		})
 698	}
 699
 700	return history, files
 701}
 702
 703func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
 704	msgs, err := a.messages.List(ctx, session.ID)
 705	if err != nil {
 706		return nil, fmt.Errorf("failed to list messages: %w", err)
 707	}
 708
 709	if session.SummaryMessageID != "" {
 710		summaryMsgInex := -1
 711		for i, msg := range msgs {
 712			if msg.ID == session.SummaryMessageID {
 713				summaryMsgInex = i
 714				break
 715			}
 716		}
 717		if summaryMsgInex != -1 {
 718			msgs = msgs[summaryMsgInex:]
 719			msgs[0].Role = message.User
 720		}
 721	}
 722	return msgs, nil
 723}
 724
 725func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, prompt string) {
 726	if prompt == "" {
 727		return
 728	}
 729
 730	var maxOutput int64 = 40
 731	if a.smallModel.CatwalkCfg.CanReason {
 732		maxOutput = a.smallModel.CatwalkCfg.DefaultMaxTokens
 733	}
 734
 735	agent := fantasy.NewAgent(a.smallModel.Model,
 736		fantasy.WithSystemPrompt(string(titlePrompt)+"\n /no_think"),
 737		fantasy.WithMaxOutputTokens(maxOutput),
 738	)
 739
 740	resp, err := agent.Stream(ctx, fantasy.AgentStreamCall{
 741		Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", prompt),
 742		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
 743			prepared.Messages = options.Messages
 744			if a.systemPromptPrefix != "" {
 745				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
 746			}
 747			return callContext, prepared, nil
 748		},
 749	})
 750	if err != nil {
 751		slog.Error("error generating title", "err", err)
 752		return
 753	}
 754
 755	title := resp.Response.Content.Text()
 756
 757	title = strings.ReplaceAll(title, "\n", " ")
 758
 759	// Remove thinking tags if present.
 760	if idx := strings.Index(title, "</think>"); idx > 0 {
 761		title = title[idx+len("</think>"):]
 762	}
 763
 764	title = strings.TrimSpace(title)
 765	if title == "" {
 766		slog.Warn("failed to generate title", "warn", "empty title")
 767		return
 768	}
 769
 770	// Calculate usage and cost.
 771	var openrouterCost *float64
 772	for _, step := range resp.Steps {
 773		stepCost := a.openrouterCost(step.ProviderMetadata)
 774		if stepCost != nil {
 775			newCost := *stepCost
 776			if openrouterCost != nil {
 777				newCost += *openrouterCost
 778			}
 779			openrouterCost = &newCost
 780		}
 781	}
 782
 783	modelConfig := a.smallModel.CatwalkCfg
 784	cost := modelConfig.CostPer1MInCached/1e6*float64(resp.TotalUsage.CacheCreationTokens) +
 785		modelConfig.CostPer1MOutCached/1e6*float64(resp.TotalUsage.CacheReadTokens) +
 786		modelConfig.CostPer1MIn/1e6*float64(resp.TotalUsage.InputTokens) +
 787		modelConfig.CostPer1MOut/1e6*float64(resp.TotalUsage.OutputTokens)
 788
 789	if a.isClaudeCode() {
 790		cost = 0
 791	}
 792
 793	// Use override cost if available (e.g., from OpenRouter).
 794	if openrouterCost != nil {
 795		cost = *openrouterCost
 796	}
 797
 798	promptTokens := resp.TotalUsage.InputTokens + resp.TotalUsage.CacheCreationTokens
 799	completionTokens := resp.TotalUsage.OutputTokens + resp.TotalUsage.CacheReadTokens
 800
 801	// Atomically update only title and usage fields to avoid overriding other
 802	// concurrent session updates.
 803	saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, title, promptTokens, completionTokens, cost)
 804	if saveErr != nil {
 805		slog.Error("failed to save session title & usage", "error", saveErr)
 806		return
 807	}
 808}
 809
 810func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
 811	openrouterMetadata, ok := metadata[openrouter.Name]
 812	if !ok {
 813		return nil
 814	}
 815
 816	opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
 817	if !ok {
 818		return nil
 819	}
 820	return &opts.Usage.Cost
 821}
 822
 823func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
 824	modelConfig := model.CatwalkCfg
 825	cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
 826		modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
 827		modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
 828		modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
 829
 830	if a.isClaudeCode() {
 831		cost = 0
 832	}
 833
 834	a.eventTokensUsed(session.ID, model, usage, cost)
 835
 836	if overrideCost != nil {
 837		session.Cost += *overrideCost
 838	} else {
 839		session.Cost += cost
 840	}
 841
 842	session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
 843	session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
 844}
 845
 846func (a *sessionAgent) Cancel(sessionID string) {
 847	// Cancel regular requests.
 848	if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
 849		slog.Info("Request cancellation initiated", "session_id", sessionID)
 850		cancel()
 851	}
 852
 853	// Also check for summarize requests.
 854	if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
 855		slog.Info("Summarize cancellation initiated", "session_id", sessionID)
 856		cancel()
 857	}
 858
 859	if a.QueuedPrompts(sessionID) > 0 {
 860		slog.Info("Clearing queued prompts", "session_id", sessionID)
 861		a.messageQueue.Del(sessionID)
 862	}
 863}
 864
 865func (a *sessionAgent) ClearQueue(sessionID string) {
 866	if a.QueuedPrompts(sessionID) > 0 {
 867		slog.Info("Clearing queued prompts", "session_id", sessionID)
 868		a.messageQueue.Del(sessionID)
 869	}
 870}
 871
 872func (a *sessionAgent) CancelAll() {
 873	if !a.IsBusy() {
 874		return
 875	}
 876	for key := range a.activeRequests.Seq2() {
 877		a.Cancel(key) // key is sessionID
 878	}
 879
 880	timeout := time.After(5 * time.Second)
 881	for a.IsBusy() {
 882		select {
 883		case <-timeout:
 884			return
 885		default:
 886			time.Sleep(200 * time.Millisecond)
 887		}
 888	}
 889}
 890
 891func (a *sessionAgent) IsBusy() bool {
 892	var busy bool
 893	for cancelFunc := range a.activeRequests.Seq() {
 894		if cancelFunc != nil {
 895			busy = true
 896			break
 897		}
 898	}
 899	return busy
 900}
 901
 902func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
 903	_, busy := a.activeRequests.Get(sessionID)
 904	return busy
 905}
 906
 907func (a *sessionAgent) QueuedPrompts(sessionID string) int {
 908	l, ok := a.messageQueue.Get(sessionID)
 909	if !ok {
 910		return 0
 911	}
 912	return len(l)
 913}
 914
 915func (a *sessionAgent) QueuedPromptsList(sessionID string) []string {
 916	l, ok := a.messageQueue.Get(sessionID)
 917	if !ok {
 918		return nil
 919	}
 920	prompts := make([]string, len(l))
 921	for i, call := range l {
 922		prompts[i] = call.Prompt
 923	}
 924	return prompts
 925}
 926
 927func (a *sessionAgent) SetModels(large Model, small Model) {
 928	a.largeModel = large
 929	a.smallModel = small
 930}
 931
// SetTools replaces the agent's tool list with the given slice. The slice is
// stored as-is, not copied.
func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
	a.tools = tools
}
 935
// Model returns the agent's large model.
func (a *sessionAgent) Model() Model {
	return a.largeModel
}
 939
 940func (a *sessionAgent) promptPrefix() string {
 941	if a.isClaudeCode() {
 942		return "You are Claude Code, Anthropic's official CLI for Claude."
 943	}
 944	return a.systemPromptPrefix
 945}
 946
 947func (a *sessionAgent) isClaudeCode() bool {
 948	cfg := config.Get()
 949	pc, ok := cfg.Providers.Get(a.largeModel.ModelCfg.Provider)
 950	return ok && pc.ID == string(catwalk.InferenceProviderAnthropic) && pc.OAuthToken != nil
 951}
 952
 953// convertToToolResult converts a fantasy tool result to a message tool result.
 954func (a *sessionAgent) convertToToolResult(result fantasy.ToolResultContent) message.ToolResult {
 955	baseResult := message.ToolResult{
 956		ToolCallID: result.ToolCallID,
 957		Name:       result.ToolName,
 958		Metadata:   result.ClientMetadata,
 959	}
 960
 961	switch result.Result.GetType() {
 962	case fantasy.ToolResultContentTypeText:
 963		if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result); ok {
 964			baseResult.Content = r.Text
 965		}
 966	case fantasy.ToolResultContentTypeError:
 967		if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result); ok {
 968			baseResult.Content = r.Error.Error()
 969			baseResult.IsError = true
 970		}
 971	case fantasy.ToolResultContentTypeMedia:
 972		if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](result.Result); ok {
 973			content := r.Text
 974			if content == "" {
 975				content = fmt.Sprintf("Loaded %s content", r.MediaType)
 976			}
 977			baseResult.Content = content
 978			baseResult.Data = r.Data
 979			baseResult.MIMEType = r.MediaType
 980		}
 981	}
 982
 983	return baseResult
 984}
 985
 986// workaroundProviderMediaLimitations converts media content in tool results to
 987// user messages for providers that don't natively support images in tool results.
 988//
 989// Problem: OpenAI, Google, OpenRouter, and other OpenAI-compatible providers
 990// don't support sending images/media in tool result messages - they only accept
 991// text in tool results. However, they DO support images in user messages.
 992//
 993// If we send media in tool results to these providers, the API returns an error.
 994//
 995// Solution: For these providers, we:
 996//  1. Replace the media in the tool result with a text placeholder
 997//  2. Inject a user message immediately after with the image as a file attachment
 998//  3. This maintains the tool execution flow while working around API limitations
 999//
1000// Anthropic and Bedrock support images natively in tool results, so we skip
1001// this workaround for them.
1002//
1003// Example transformation:
1004//
1005//	BEFORE: [tool result: image data]
1006//	AFTER:  [tool result: "Image loaded - see attached"], [user: image attachment]
1007func (a *sessionAgent) workaroundProviderMediaLimitations(messages []fantasy.Message) []fantasy.Message {
1008	providerSupportsMedia := a.largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderAnthropic) ||
1009		a.largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderBedrock)
1010
1011	if providerSupportsMedia {
1012		return messages
1013	}
1014
1015	convertedMessages := make([]fantasy.Message, 0, len(messages))
1016
1017	for _, msg := range messages {
1018		if msg.Role != fantasy.MessageRoleTool {
1019			convertedMessages = append(convertedMessages, msg)
1020			continue
1021		}
1022
1023		textParts := make([]fantasy.MessagePart, 0, len(msg.Content))
1024		var mediaFiles []fantasy.FilePart
1025
1026		for _, part := range msg.Content {
1027			toolResult, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
1028			if !ok {
1029				textParts = append(textParts, part)
1030				continue
1031			}
1032
1033			if media, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](toolResult.Output); ok {
1034				decoded, err := base64.StdEncoding.DecodeString(media.Data)
1035				if err != nil {
1036					slog.Warn("failed to decode media data", "error", err)
1037					textParts = append(textParts, part)
1038					continue
1039				}
1040
1041				mediaFiles = append(mediaFiles, fantasy.FilePart{
1042					Data:      decoded,
1043					MediaType: media.MediaType,
1044					Filename:  fmt.Sprintf("tool-result-%s", toolResult.ToolCallID),
1045				})
1046
1047				textParts = append(textParts, fantasy.ToolResultPart{
1048					ToolCallID: toolResult.ToolCallID,
1049					Output: fantasy.ToolResultOutputContentText{
1050						Text: "[Image/media content loaded - see attached file]",
1051					},
1052					ProviderOptions: toolResult.ProviderOptions,
1053				})
1054			} else {
1055				textParts = append(textParts, part)
1056			}
1057		}
1058
1059		convertedMessages = append(convertedMessages, fantasy.Message{
1060			Role:    fantasy.MessageRoleTool,
1061			Content: textParts,
1062		})
1063
1064		if len(mediaFiles) > 0 {
1065			convertedMessages = append(convertedMessages, fantasy.NewUserMessage(
1066				"Here is the media content from the tool result:",
1067				mediaFiles...,
1068			))
1069		}
1070	}
1071
1072	return convertedMessages
1073}