agent.go

   1// Package agent is the core orchestration layer for Crush AI agents.
   2//
   3// It provides session-based AI agent functionality for managing
   4// conversations, tool execution, and message handling. It coordinates
   5// interactions between language models, messages, sessions, and tools while
   6// handling features like automatic summarization, queuing, and token
   7// management.
   8package agent
   9
  10import (
  11	"cmp"
  12	"context"
  13	_ "embed"
  14	"encoding/base64"
  15	"errors"
  16	"fmt"
  17	"log/slog"
  18	"os"
  19	"regexp"
  20	"strconv"
  21	"strings"
  22	"sync"
  23	"time"
  24
  25	"charm.land/fantasy"
  26	"charm.land/fantasy/providers/anthropic"
  27	"charm.land/fantasy/providers/bedrock"
  28	"charm.land/fantasy/providers/google"
  29	"charm.land/fantasy/providers/openai"
  30	"charm.land/fantasy/providers/openrouter"
  31	"charm.land/lipgloss/v2"
  32	"github.com/charmbracelet/catwalk/pkg/catwalk"
  33	"github.com/charmbracelet/crush/internal/agent/hyper"
  34	"github.com/charmbracelet/crush/internal/agent/tools"
  35	"github.com/charmbracelet/crush/internal/config"
  36	"github.com/charmbracelet/crush/internal/csync"
  37	"github.com/charmbracelet/crush/internal/message"
  38	"github.com/charmbracelet/crush/internal/permission"
  39	"github.com/charmbracelet/crush/internal/session"
  40	"github.com/charmbracelet/crush/internal/stringext"
  41)
  42
// defaultSessionName is the fallback session title used when title generation
// yields an empty string.
const defaultSessionName = "Untitled Session"

// titlePrompt holds the embedded system prompt used when asking a model to
// generate a session title.
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt holds the embedded system prompt used when asking a model to
// summarize a session.
//
//go:embed templates/summary.md
var summaryPrompt []byte

// thinkTagRegex matches a <think>...</think> pair (non-greedy) so it can be
// stripped from generated titles. The pattern does not match across newlines;
// generateTitle replaces newlines with spaces before applying it.
var thinkTagRegex = regexp.MustCompile(`<think>.*?</think>`)
  53
// SessionAgentCall describes a single prompt submitted to a SessionAgent.
type SessionAgentCall struct {
	// SessionID identifies the session the prompt belongs to. Run rejects an
	// empty value with ErrSessionMissing.
	SessionID        string
	// Prompt is the user's text. Run rejects an empty value with
	// ErrEmptyPrompt.
	Prompt           string
	// ProviderOptions is forwarded unchanged to the underlying stream call.
	ProviderOptions  fantasy.ProviderOptions
	// Attachments are stored with the user message; non-text attachments are
	// also sent to the model as file parts.
	Attachments      []message.Attachment
	// MaxOutputTokens is forwarded to the stream call as the output limit.
	MaxOutputTokens  int64
	// The remaining fields are optional sampling parameters that are passed
	// through to the stream call as-is (nil means "not set").
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
}
  66
// SessionAgent is the session-scoped agent API: it runs prompts, tracks and
// cancels in-flight requests, manages per-session prompt queues and
// summarizes conversations.
type SessionAgent interface {
	// Run executes the call against its session, or queues it when the
	// session already has an active request (returning a nil result and nil
	// error in that case).
	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
	// SetModels replaces the large and small models.
	// NOTE(review): implementation is not visible in this part of the file.
	SetModels(large Model, small Model)
	// SetTools replaces the tool set.
	// NOTE(review): implementation is not visible in this part of the file.
	SetTools(tools []fantasy.AgentTool)
	// Cancel cancels the active request for the session, if any, and drops
	// the session's queued prompts.
	Cancel(sessionID string)
	// CancelAll cancels every active request and waits up to five seconds
	// for the agent to become idle.
	CancelAll()
	// IsSessionBusy reports whether the session has an active request.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any session has an active request.
	IsBusy() bool
	// QueuedPrompts returns the number of prompts queued for the session.
	QueuedPrompts(sessionID string) int
	// QueuedPromptsList returns the queued prompts for the session.
	// NOTE(review): implementation is not visible in this part of the file.
	QueuedPromptsList(sessionID string) []string
	// ClearQueue drops the session's queued prompts without cancelling the
	// active request.
	ClearQueue(sessionID string)
	// Summarize generates a summary message for the session and records it
	// as the session's summary. It returns ErrSessionBusy when the session
	// has an active request.
	Summarize(context.Context, string, fantasy.ProviderOptions) error
	// Model returns one of the agent's models.
	// NOTE(review): presumably the large model — implementation not visible
	// here; confirm.
	Model() Model
}
  81
// Model bundles a language model with the metadata and configuration the
// agent needs alongside it.
type Model struct {
	// Model is the language model used to create fantasy agents.
	Model      fantasy.LanguageModel
	// CatwalkCfg carries model metadata read by this package, such as the
	// context window, per-token costs, and capability flags.
	CatwalkCfg catwalk.Model
	// ModelCfg is the selected-model configuration; its Model and Provider
	// values are recorded on assistant messages.
	ModelCfg   config.SelectedModel
}
  87
// sessionAgent is the SessionAgent implementation returned by
// NewSessionAgent.
type sessionAgent struct {
	// largeModel drives regular runs and summarization, and is the fallback
	// for title generation.
	largeModel           Model
	// smallModel is tried first for title generation.
	smallModel           Model
	// systemPromptPrefix, when non-empty, is prepended as a system message
	// for summarization and title generation.
	systemPromptPrefix   string
	// systemPrompt is the system prompt for regular runs.
	systemPrompt         string
	// isSubAgent suppresses the todo-list reminder that preparePrompt
	// otherwise puts at the start of the history.
	isSubAgent           bool
	// tools are the tools offered to the model; the last one receives the
	// cache-control provider options on each run.
	tools                []fantasy.AgentTool
	sessions             session.Service
	messages             message.Service
	// disableAutoSummarize turns off the automatic summarize-on-low-context
	// stop condition in Run.
	disableAutoSummarize bool
	// NOTE(review): isYolo is not read in the visible part of this file.
	isYolo               bool

	// messageQueue holds, per session ID, calls that arrived while the
	// session was busy.
	messageQueue   *csync.Map[string, []SessionAgentCall]
	// activeRequests holds, per session ID, the cancel function of the
	// in-flight request; presence of a key marks the session as busy.
	activeRequests *csync.Map[string, context.CancelFunc]
}
 103
// SessionAgentOptions carries the dependencies and settings used by
// NewSessionAgent to build a SessionAgent. Each field is copied onto the
// corresponding field of the agent.
type SessionAgentOptions struct {
	// LargeModel is the primary model for runs and summaries.
	LargeModel           Model
	// SmallModel is the model tried first for title generation.
	SmallModel           Model
	// SystemPromptPrefix is an optional system message placed ahead of the
	// other messages for summarization and title generation.
	SystemPromptPrefix   string
	// SystemPrompt is the system prompt for regular runs.
	SystemPrompt         string
	// IsSubAgent disables the todo-list reminder added to the history.
	IsSubAgent           bool
	// DisableAutoSummarize turns off automatic summarization in Run.
	DisableAutoSummarize bool
	// NOTE(review): IsYolo is stored but not read in the visible part of
	// this file.
	IsYolo               bool
	// Sessions and Messages are the persistence services for sessions and
	// their messages.
	Sessions             session.Service
	Messages             message.Service
	// Tools are the tools offered to the model.
	Tools                []fantasy.AgentTool
}
 116
 117func NewSessionAgent(
 118	opts SessionAgentOptions,
 119) SessionAgent {
 120	return &sessionAgent{
 121		largeModel:           opts.LargeModel,
 122		smallModel:           opts.SmallModel,
 123		systemPromptPrefix:   opts.SystemPromptPrefix,
 124		systemPrompt:         opts.SystemPrompt,
 125		isSubAgent:           opts.IsSubAgent,
 126		sessions:             opts.Sessions,
 127		messages:             opts.Messages,
 128		disableAutoSummarize: opts.DisableAutoSummarize,
 129		tools:                opts.Tools,
 130		isYolo:               opts.IsYolo,
 131		messageQueue:         csync.NewMap[string, []SessionAgentCall](),
 132		activeRequests:       csync.NewMap[string, context.CancelFunc](),
 133	}
 134}
 135
 136func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
 137	if call.Prompt == "" {
 138		return nil, ErrEmptyPrompt
 139	}
 140	if call.SessionID == "" {
 141		return nil, ErrSessionMissing
 142	}
 143
 144	// Queue the message if busy
 145	if a.IsSessionBusy(call.SessionID) {
 146		existing, ok := a.messageQueue.Get(call.SessionID)
 147		if !ok {
 148			existing = []SessionAgentCall{}
 149		}
 150		existing = append(existing, call)
 151		a.messageQueue.Set(call.SessionID, existing)
 152		return nil, nil
 153	}
 154
 155	if len(a.tools) > 0 {
 156		// Add Anthropic caching to the last tool.
 157		a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
 158	}
 159
 160	agent := fantasy.NewAgent(
 161		a.largeModel.Model,
 162		fantasy.WithSystemPrompt(a.systemPrompt),
 163		fantasy.WithTools(a.tools...),
 164	)
 165
 166	sessionLock := sync.Mutex{}
 167	currentSession, err := a.sessions.Get(ctx, call.SessionID)
 168	if err != nil {
 169		return nil, fmt.Errorf("failed to get session: %w", err)
 170	}
 171
 172	msgs, err := a.getSessionMessages(ctx, currentSession)
 173	if err != nil {
 174		return nil, fmt.Errorf("failed to get session messages: %w", err)
 175	}
 176
 177	var wg sync.WaitGroup
 178	// Generate title if first message.
 179	if len(msgs) == 0 {
 180		titleCtx := ctx // Copy to avoid race with ctx reassignment below.
 181		wg.Go(func() {
 182			a.generateTitle(titleCtx, call.SessionID, call.Prompt)
 183		})
 184	}
 185
 186	// Add the user message to the session.
 187	_, err = a.createUserMessage(ctx, call)
 188	if err != nil {
 189		return nil, err
 190	}
 191
 192	// Add the session to the context.
 193	ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)
 194
 195	genCtx, cancel := context.WithCancel(ctx)
 196	a.activeRequests.Set(call.SessionID, cancel)
 197
 198	defer cancel()
 199	defer a.activeRequests.Del(call.SessionID)
 200
 201	history, files := a.preparePrompt(msgs, call.Attachments...)
 202
 203	startTime := time.Now()
 204	a.eventPromptSent(call.SessionID)
 205
 206	var currentAssistant *message.Message
 207	var shouldSummarize bool
 208	result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
 209		Prompt:           message.PromptWithTextAttachments(call.Prompt, call.Attachments),
 210		Files:            files,
 211		Messages:         history,
 212		ProviderOptions:  call.ProviderOptions,
 213		MaxOutputTokens:  &call.MaxOutputTokens,
 214		TopP:             call.TopP,
 215		Temperature:      call.Temperature,
 216		PresencePenalty:  call.PresencePenalty,
 217		TopK:             call.TopK,
 218		FrequencyPenalty: call.FrequencyPenalty,
 219		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
 220			prepared.Messages = options.Messages
 221			for i := range prepared.Messages {
 222				prepared.Messages[i].ProviderOptions = nil
 223			}
 224
 225			queuedCalls, _ := a.messageQueue.Get(call.SessionID)
 226			a.messageQueue.Del(call.SessionID)
 227			for _, queued := range queuedCalls {
 228				userMessage, createErr := a.createUserMessage(callContext, queued)
 229				if createErr != nil {
 230					return callContext, prepared, createErr
 231				}
 232				prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
 233			}
 234
 235			prepared.Messages = a.workaroundProviderMediaLimitations(prepared.Messages)
 236
 237			lastSystemRoleInx := 0
 238			systemMessageUpdated := false
 239			for i, msg := range prepared.Messages {
 240				// Only add cache control to the last message.
 241				if msg.Role == fantasy.MessageRoleSystem {
 242					lastSystemRoleInx = i
 243				} else if !systemMessageUpdated {
 244					prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
 245					systemMessageUpdated = true
 246				}
 247				// Than add cache control to the last 2 messages.
 248				if i > len(prepared.Messages)-3 {
 249					prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
 250				}
 251			}
 252
 253			if promptPrefix := a.promptPrefix(); promptPrefix != "" {
 254				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(promptPrefix)}, prepared.Messages...)
 255			}
 256
 257			var assistantMsg message.Message
 258			assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
 259				Role:     message.Assistant,
 260				Parts:    []message.ContentPart{},
 261				Model:    a.largeModel.ModelCfg.Model,
 262				Provider: a.largeModel.ModelCfg.Provider,
 263			})
 264			if err != nil {
 265				return callContext, prepared, err
 266			}
 267			callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
 268			callContext = context.WithValue(callContext, tools.SupportsImagesContextKey, a.largeModel.CatwalkCfg.SupportsImages)
 269			callContext = context.WithValue(callContext, tools.ModelNameContextKey, a.largeModel.CatwalkCfg.Name)
 270			currentAssistant = &assistantMsg
 271			return callContext, prepared, err
 272		},
 273		OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
 274			currentAssistant.AppendReasoningContent(reasoning.Text)
 275			return a.messages.Update(genCtx, *currentAssistant)
 276		},
 277		OnReasoningDelta: func(id string, text string) error {
 278			currentAssistant.AppendReasoningContent(text)
 279			return a.messages.Update(genCtx, *currentAssistant)
 280		},
 281		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
 282			// handle anthropic signature
 283			if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
 284				if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
 285					currentAssistant.AppendReasoningSignature(reasoning.Signature)
 286				}
 287			}
 288			if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
 289				if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
 290					currentAssistant.AppendThoughtSignature(reasoning.Signature, reasoning.ToolID)
 291				}
 292			}
 293			if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
 294				if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
 295					currentAssistant.SetReasoningResponsesData(reasoning)
 296				}
 297			}
 298			currentAssistant.FinishThinking()
 299			return a.messages.Update(genCtx, *currentAssistant)
 300		},
 301		OnTextDelta: func(id string, text string) error {
 302			// Strip leading newline from initial text content. This is is
 303			// particularly important in non-interactive mode where leading
 304			// newlines are very visible.
 305			if len(currentAssistant.Parts) == 0 {
 306				text = strings.TrimPrefix(text, "\n")
 307			}
 308
 309			currentAssistant.AppendContent(text)
 310			return a.messages.Update(genCtx, *currentAssistant)
 311		},
 312		OnToolInputStart: func(id string, toolName string) error {
 313			toolCall := message.ToolCall{
 314				ID:               id,
 315				Name:             toolName,
 316				ProviderExecuted: false,
 317				Finished:         false,
 318			}
 319			currentAssistant.AddToolCall(toolCall)
 320			return a.messages.Update(genCtx, *currentAssistant)
 321		},
 322		OnRetry: func(err *fantasy.ProviderError, delay time.Duration) {
 323			// TODO: implement
 324		},
 325		OnToolCall: func(tc fantasy.ToolCallContent) error {
 326			toolCall := message.ToolCall{
 327				ID:               tc.ToolCallID,
 328				Name:             tc.ToolName,
 329				Input:            tc.Input,
 330				ProviderExecuted: false,
 331				Finished:         true,
 332			}
 333			currentAssistant.AddToolCall(toolCall)
 334			return a.messages.Update(genCtx, *currentAssistant)
 335		},
 336		OnToolResult: func(result fantasy.ToolResultContent) error {
 337			toolResult := a.convertToToolResult(result)
 338			_, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
 339				Role: message.Tool,
 340				Parts: []message.ContentPart{
 341					toolResult,
 342				},
 343			})
 344			return createMsgErr
 345		},
 346		OnStepFinish: func(stepResult fantasy.StepResult) error {
 347			finishReason := message.FinishReasonUnknown
 348			switch stepResult.FinishReason {
 349			case fantasy.FinishReasonLength:
 350				finishReason = message.FinishReasonMaxTokens
 351			case fantasy.FinishReasonStop:
 352				finishReason = message.FinishReasonEndTurn
 353			case fantasy.FinishReasonToolCalls:
 354				finishReason = message.FinishReasonToolUse
 355			}
 356			currentAssistant.AddFinish(finishReason, "", "")
 357			sessionLock.Lock()
 358			updatedSession, getSessionErr := a.sessions.Get(genCtx, call.SessionID)
 359			if getSessionErr != nil {
 360				sessionLock.Unlock()
 361				return getSessionErr
 362			}
 363			a.updateSessionUsage(a.largeModel, &updatedSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
 364			_, sessionErr := a.sessions.Save(genCtx, updatedSession)
 365			sessionLock.Unlock()
 366			if sessionErr != nil {
 367				return sessionErr
 368			}
 369			return a.messages.Update(genCtx, *currentAssistant)
 370		},
 371		StopWhen: []fantasy.StopCondition{
 372			func(_ []fantasy.StepResult) bool {
 373				cw := int64(a.largeModel.CatwalkCfg.ContextWindow)
 374				tokens := currentSession.CompletionTokens + currentSession.PromptTokens
 375				remaining := cw - tokens
 376				var threshold int64
 377				if cw > 200_000 {
 378					threshold = 20_000
 379				} else {
 380					threshold = int64(float64(cw) * 0.2)
 381				}
 382				if (remaining <= threshold) && !a.disableAutoSummarize {
 383					shouldSummarize = true
 384					return true
 385				}
 386				return false
 387			},
 388		},
 389	})
 390
 391	a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))
 392
 393	if err != nil {
 394		isCancelErr := errors.Is(err, context.Canceled)
 395		isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
 396		if currentAssistant == nil {
 397			return result, err
 398		}
 399		// Ensure we finish thinking on error to close the reasoning state.
 400		currentAssistant.FinishThinking()
 401		toolCalls := currentAssistant.ToolCalls()
 402		// INFO: we use the parent context here because the genCtx has been cancelled.
 403		msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
 404		if createErr != nil {
 405			return nil, createErr
 406		}
 407		for _, tc := range toolCalls {
 408			if !tc.Finished {
 409				tc.Finished = true
 410				tc.Input = "{}"
 411				currentAssistant.AddToolCall(tc)
 412				updateErr := a.messages.Update(ctx, *currentAssistant)
 413				if updateErr != nil {
 414					return nil, updateErr
 415				}
 416			}
 417
 418			found := false
 419			for _, msg := range msgs {
 420				if msg.Role == message.Tool {
 421					for _, tr := range msg.ToolResults() {
 422						if tr.ToolCallID == tc.ID {
 423							found = true
 424							break
 425						}
 426					}
 427				}
 428				if found {
 429					break
 430				}
 431			}
 432			if found {
 433				continue
 434			}
 435			content := "There was an error while executing the tool"
 436			if isCancelErr {
 437				content = "Tool execution canceled by user"
 438			} else if isPermissionErr {
 439				content = "User denied permission"
 440			}
 441			toolResult := message.ToolResult{
 442				ToolCallID: tc.ID,
 443				Name:       tc.Name,
 444				Content:    content,
 445				IsError:    true,
 446			}
 447			_, createErr = a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
 448				Role: message.Tool,
 449				Parts: []message.ContentPart{
 450					toolResult,
 451				},
 452			})
 453			if createErr != nil {
 454				return nil, createErr
 455			}
 456		}
 457		var fantasyErr *fantasy.Error
 458		var providerErr *fantasy.ProviderError
 459		const defaultTitle = "Provider Error"
 460		if isCancelErr {
 461			currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
 462		} else if isPermissionErr {
 463			currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "")
 464		} else if errors.Is(err, hyper.ErrNoCredits) {
 465			url := hyper.BaseURL()
 466			link := lipgloss.NewStyle().Hyperlink(url, "id=hyper").Render(url)
 467			currentAssistant.AddFinish(message.FinishReasonError, "No credits", "You're out of credits. Add more at "+link)
 468		} else if errors.As(err, &providerErr) {
 469			if providerErr.Message == "The requested model is not supported." {
 470				url := "https://github.com/settings/copilot/features"
 471				link := lipgloss.NewStyle().Hyperlink(url, "id=hyper").Render(url)
 472				currentAssistant.AddFinish(
 473					message.FinishReasonError,
 474					"Copilot model not enabled",
 475					fmt.Sprintf("%q is not enabled in Copilot. Go to the following page to enable it. Then, wait a minute before trying again. %s", a.largeModel.CatwalkCfg.Name, link),
 476				)
 477			} else {
 478				currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message)
 479			}
 480		} else if errors.As(err, &fantasyErr) {
 481			currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(fantasyErr.Title), defaultTitle), fantasyErr.Message)
 482		} else {
 483			currentAssistant.AddFinish(message.FinishReasonError, defaultTitle, err.Error())
 484		}
 485		// Note: we use the parent context here because the genCtx has been
 486		// cancelled.
 487		updateErr := a.messages.Update(ctx, *currentAssistant)
 488		if updateErr != nil {
 489			return nil, updateErr
 490		}
 491		return nil, err
 492	}
 493	wg.Wait()
 494
 495	if shouldSummarize {
 496		a.activeRequests.Del(call.SessionID)
 497		if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
 498			return nil, summarizeErr
 499		}
 500		// If the agent wasn't done...
 501		if len(currentAssistant.ToolCalls()) > 0 {
 502			existing, ok := a.messageQueue.Get(call.SessionID)
 503			if !ok {
 504				existing = []SessionAgentCall{}
 505			}
 506			call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
 507			existing = append(existing, call)
 508			a.messageQueue.Set(call.SessionID, existing)
 509		}
 510	}
 511
 512	// Release active request before processing queued messages.
 513	a.activeRequests.Del(call.SessionID)
 514	cancel()
 515
 516	queuedMessages, ok := a.messageQueue.Get(call.SessionID)
 517	if !ok || len(queuedMessages) == 0 {
 518		return result, err
 519	}
 520	// There are queued messages restart the loop.
 521	firstQueuedMessage := queuedMessages[0]
 522	a.messageQueue.Set(call.SessionID, queuedMessages[1:])
 523	return a.Run(ctx, firstQueuedMessage)
 524}
 525
 526func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
 527	if a.IsSessionBusy(sessionID) {
 528		return ErrSessionBusy
 529	}
 530
 531	currentSession, err := a.sessions.Get(ctx, sessionID)
 532	if err != nil {
 533		return fmt.Errorf("failed to get session: %w", err)
 534	}
 535	msgs, err := a.getSessionMessages(ctx, currentSession)
 536	if err != nil {
 537		return err
 538	}
 539	if len(msgs) == 0 {
 540		// Nothing to summarize.
 541		return nil
 542	}
 543
 544	aiMsgs, _ := a.preparePrompt(msgs)
 545
 546	genCtx, cancel := context.WithCancel(ctx)
 547	a.activeRequests.Set(sessionID, cancel)
 548	defer a.activeRequests.Del(sessionID)
 549	defer cancel()
 550
 551	agent := fantasy.NewAgent(a.largeModel.Model,
 552		fantasy.WithSystemPrompt(string(summaryPrompt)),
 553	)
 554	summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
 555		Role:             message.Assistant,
 556		Model:            a.largeModel.Model.Model(),
 557		Provider:         a.largeModel.Model.Provider(),
 558		IsSummaryMessage: true,
 559	})
 560	if err != nil {
 561		return err
 562	}
 563
 564	summaryPromptText := "Provide a detailed summary of our conversation above."
 565	if len(currentSession.Todos) > 0 {
 566		summaryPromptText += "\n\n## Current Todo List\n\n"
 567		for _, t := range currentSession.Todos {
 568			summaryPromptText += fmt.Sprintf("- [%s] %s\n", t.Status, t.Content)
 569		}
 570		summaryPromptText += "\nInclude these tasks and their statuses in your summary. "
 571		summaryPromptText += "Instruct the resuming assistant to use the `todos` tool to continue tracking progress on these tasks."
 572	}
 573
 574	resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
 575		Prompt:          summaryPromptText,
 576		Messages:        aiMsgs,
 577		ProviderOptions: opts,
 578		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
 579			prepared.Messages = options.Messages
 580			if a.systemPromptPrefix != "" {
 581				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
 582			}
 583			return callContext, prepared, nil
 584		},
 585		OnReasoningDelta: func(id string, text string) error {
 586			summaryMessage.AppendReasoningContent(text)
 587			return a.messages.Update(genCtx, summaryMessage)
 588		},
 589		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
 590			// Handle anthropic signature.
 591			if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
 592				if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
 593					summaryMessage.AppendReasoningSignature(signature.Signature)
 594				}
 595			}
 596			summaryMessage.FinishThinking()
 597			return a.messages.Update(genCtx, summaryMessage)
 598		},
 599		OnTextDelta: func(id, text string) error {
 600			summaryMessage.AppendContent(text)
 601			return a.messages.Update(genCtx, summaryMessage)
 602		},
 603	})
 604	if err != nil {
 605		isCancelErr := errors.Is(err, context.Canceled)
 606		if isCancelErr {
 607			// User cancelled summarize we need to remove the summary message.
 608			deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
 609			return deleteErr
 610		}
 611		return err
 612	}
 613
 614	summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
 615	err = a.messages.Update(genCtx, summaryMessage)
 616	if err != nil {
 617		return err
 618	}
 619
 620	var openrouterCost *float64
 621	for _, step := range resp.Steps {
 622		stepCost := a.openrouterCost(step.ProviderMetadata)
 623		if stepCost != nil {
 624			newCost := *stepCost
 625			if openrouterCost != nil {
 626				newCost += *openrouterCost
 627			}
 628			openrouterCost = &newCost
 629		}
 630	}
 631
 632	a.updateSessionUsage(a.largeModel, &currentSession, resp.TotalUsage, openrouterCost)
 633
 634	// Just in case, get just the last usage info.
 635	usage := resp.Response.Usage
 636	currentSession.SummaryMessageID = summaryMessage.ID
 637	currentSession.CompletionTokens = usage.OutputTokens
 638	currentSession.PromptTokens = 0
 639	_, err = a.sessions.Save(genCtx, currentSession)
 640	return err
 641}
 642
 643func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
 644	if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
 645		return fantasy.ProviderOptions{}
 646	}
 647	return fantasy.ProviderOptions{
 648		anthropic.Name: &anthropic.ProviderCacheControlOptions{
 649			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
 650		},
 651		bedrock.Name: &anthropic.ProviderCacheControlOptions{
 652			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
 653		},
 654	}
 655}
 656
 657func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
 658	parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
 659	var attachmentParts []message.ContentPart
 660	for _, attachment := range call.Attachments {
 661		attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
 662	}
 663	parts = append(parts, attachmentParts...)
 664	msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
 665		Role:  message.User,
 666		Parts: parts,
 667	})
 668	if err != nil {
 669		return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
 670	}
 671	return msg, nil
 672}
 673
 674func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
 675	var history []fantasy.Message
 676	if !a.isSubAgent {
 677		history = append(history, fantasy.NewUserMessage(
 678			fmt.Sprintf("<system_reminder>%s</system_reminder>",
 679				`This is a reminder that your todo list is currently empty. DO NOT mention this to the user explicitly because they are already aware.
 680If you are working on tasks that would benefit from a todo list please use the "todos" tool to create one.
 681If not, please feel free to ignore. Again do not mention this message to the user.`,
 682			),
 683		))
 684	}
 685	for _, m := range msgs {
 686		if len(m.Parts) == 0 {
 687			continue
 688		}
 689		// Assistant message without content or tool calls (cancelled before it
 690		// returned anything).
 691		if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
 692			continue
 693		}
 694		history = append(history, m.ToAIMessage()...)
 695	}
 696
 697	var files []fantasy.FilePart
 698	for _, attachment := range attachments {
 699		if attachment.IsText() {
 700			continue
 701		}
 702		files = append(files, fantasy.FilePart{
 703			Filename:  attachment.FileName,
 704			Data:      attachment.Content,
 705			MediaType: attachment.MimeType,
 706		})
 707	}
 708
 709	return history, files
 710}
 711
 712func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
 713	msgs, err := a.messages.List(ctx, session.ID)
 714	if err != nil {
 715		return nil, fmt.Errorf("failed to list messages: %w", err)
 716	}
 717
 718	if session.SummaryMessageID != "" {
 719		summaryMsgInex := -1
 720		for i, msg := range msgs {
 721			if msg.ID == session.SummaryMessageID {
 722				summaryMsgInex = i
 723				break
 724			}
 725		}
 726		if summaryMsgInex != -1 {
 727			msgs = msgs[summaryMsgInex:]
 728			msgs[0].Role = message.User
 729		}
 730	}
 731	return msgs, nil
 732}
 733
 734// generateTitle generates a session titled based on the initial prompt.
 735func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, userPrompt string) {
 736	if userPrompt == "" {
 737		return
 738	}
 739
 740	var maxOutputTokens int64 = 40
 741	if a.smallModel.CatwalkCfg.CanReason {
 742		maxOutputTokens = a.smallModel.CatwalkCfg.DefaultMaxTokens
 743	}
 744
 745	newAgent := func(m fantasy.LanguageModel, p []byte, tok int64) fantasy.Agent {
 746		return fantasy.NewAgent(m,
 747			fantasy.WithSystemPrompt(string(p)+"\n /no_think"),
 748			fantasy.WithMaxOutputTokens(tok),
 749		)
 750	}
 751
 752	streamCall := fantasy.AgentStreamCall{
 753		Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", userPrompt),
 754		PrepareStep: func(callCtx context.Context, opts fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
 755			prepared.Messages = opts.Messages
 756			if a.systemPromptPrefix != "" {
 757				prepared.Messages = append([]fantasy.Message{
 758					fantasy.NewSystemMessage(a.systemPromptPrefix),
 759				}, prepared.Messages...)
 760			}
 761			return callCtx, prepared, nil
 762		},
 763	}
 764
 765	// Use the small model to generate the title.
 766	model := &a.smallModel
 767	agent := newAgent(model.Model, titlePrompt, maxOutputTokens)
 768	resp, err := agent.Stream(ctx, streamCall)
 769	if err == nil {
 770		// We successfully generated a title with the small model.
 771		slog.Info("generated title with small model")
 772	} else {
 773		// It didn't work. Let's try with the big model.
 774		slog.Error("error generating title with small model; trying big model", "err", err)
 775		model = &a.largeModel
 776		agent = newAgent(model.Model, titlePrompt, maxOutputTokens)
 777		resp, err = agent.Stream(ctx, streamCall)
 778		if err == nil {
 779			slog.Info("generated title with large model")
 780		} else {
 781			// Welp, the large model didn't work either.
 782			slog.Error("error generating title with large model", "err", err)
 783		}
 784	}
 785
 786	title := strings.ReplaceAll(resp.Response.Content.Text(), "\n", " ")
 787	slog.Info("generated title", "title", title)
 788
 789	// Remove thinking tags if present.
 790	title = thinkTagRegex.ReplaceAllString(title, "")
 791
 792	title = strings.TrimSpace(title)
 793	if title == "" {
 794		slog.Warn("empty title; using fallback")
 795		title = defaultSessionName
 796	}
 797
 798	// Calculate usage and cost.
 799	var openrouterCost *float64
 800	for _, step := range resp.Steps {
 801		stepCost := a.openrouterCost(step.ProviderMetadata)
 802		if stepCost != nil {
 803			newCost := *stepCost
 804			if openrouterCost != nil {
 805				newCost += *openrouterCost
 806			}
 807			openrouterCost = &newCost
 808		}
 809	}
 810
 811	if model == nil {
 812		slog.Error("no model available for cost calculation")
 813		return
 814	}
 815	modelConfig := model.CatwalkCfg
 816	cost := modelConfig.CostPer1MInCached/1e6*float64(resp.TotalUsage.CacheCreationTokens) +
 817		modelConfig.CostPer1MOutCached/1e6*float64(resp.TotalUsage.CacheReadTokens) +
 818		modelConfig.CostPer1MIn/1e6*float64(resp.TotalUsage.InputTokens) +
 819		modelConfig.CostPer1MOut/1e6*float64(resp.TotalUsage.OutputTokens)
 820
 821	if a.isClaudeCode() {
 822		cost = 0
 823	}
 824
 825	// Use override cost if available (e.g., from OpenRouter).
 826	if openrouterCost != nil {
 827		cost = *openrouterCost
 828	}
 829
 830	promptTokens := resp.TotalUsage.InputTokens + resp.TotalUsage.CacheCreationTokens
 831	completionTokens := resp.TotalUsage.OutputTokens + resp.TotalUsage.CacheReadTokens
 832
 833	// Atomically update only title and usage fields to avoid overriding other
 834	// concurrent session updates.
 835	saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, title, promptTokens, completionTokens, cost)
 836	if saveErr != nil {
 837		slog.Error("failed to save session title and usage", "error", saveErr)
 838		return
 839	}
 840}
 841
 842func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
 843	openrouterMetadata, ok := metadata[openrouter.Name]
 844	if !ok {
 845		return nil
 846	}
 847
 848	opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
 849	if !ok {
 850		return nil
 851	}
 852	return &opts.Usage.Cost
 853}
 854
 855func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
 856	modelConfig := model.CatwalkCfg
 857	cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
 858		modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
 859		modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
 860		modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
 861
 862	if a.isClaudeCode() {
 863		cost = 0
 864	}
 865
 866	a.eventTokensUsed(session.ID, model, usage, cost)
 867
 868	if overrideCost != nil {
 869		session.Cost += *overrideCost
 870	} else {
 871		session.Cost += cost
 872	}
 873
 874	session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
 875	session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
 876}
 877
 878func (a *sessionAgent) Cancel(sessionID string) {
 879	// Cancel regular requests.
 880	if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
 881		slog.Info("Request cancellation initiated", "session_id", sessionID)
 882		cancel()
 883	}
 884
 885	// Also check for summarize requests.
 886	if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
 887		slog.Info("Summarize cancellation initiated", "session_id", sessionID)
 888		cancel()
 889	}
 890
 891	if a.QueuedPrompts(sessionID) > 0 {
 892		slog.Info("Clearing queued prompts", "session_id", sessionID)
 893		a.messageQueue.Del(sessionID)
 894	}
 895}
 896
 897func (a *sessionAgent) ClearQueue(sessionID string) {
 898	if a.QueuedPrompts(sessionID) > 0 {
 899		slog.Info("Clearing queued prompts", "session_id", sessionID)
 900		a.messageQueue.Del(sessionID)
 901	}
 902}
 903
 904func (a *sessionAgent) CancelAll() {
 905	if !a.IsBusy() {
 906		return
 907	}
 908	for key := range a.activeRequests.Seq2() {
 909		a.Cancel(key) // key is sessionID
 910	}
 911
 912	timeout := time.After(5 * time.Second)
 913	for a.IsBusy() {
 914		select {
 915		case <-timeout:
 916			return
 917		default:
 918			time.Sleep(200 * time.Millisecond)
 919		}
 920	}
 921}
 922
 923func (a *sessionAgent) IsBusy() bool {
 924	var busy bool
 925	for cancelFunc := range a.activeRequests.Seq() {
 926		if cancelFunc != nil {
 927			busy = true
 928			break
 929		}
 930	}
 931	return busy
 932}
 933
 934func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
 935	_, busy := a.activeRequests.Get(sessionID)
 936	return busy
 937}
 938
 939func (a *sessionAgent) QueuedPrompts(sessionID string) int {
 940	l, ok := a.messageQueue.Get(sessionID)
 941	if !ok {
 942		return 0
 943	}
 944	return len(l)
 945}
 946
 947func (a *sessionAgent) QueuedPromptsList(sessionID string) []string {
 948	l, ok := a.messageQueue.Get(sessionID)
 949	if !ok {
 950		return nil
 951	}
 952	prompts := make([]string, len(l))
 953	for i, call := range l {
 954		prompts[i] = call.Prompt
 955	}
 956	return prompts
 957}
 958
 959func (a *sessionAgent) SetModels(large Model, small Model) {
 960	a.largeModel = large
 961	a.smallModel = small
 962}
 963
 964func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
 965	a.tools = tools
 966}
 967
 968func (a *sessionAgent) Model() Model {
 969	return a.largeModel
 970}
 971
 972func (a *sessionAgent) promptPrefix() string {
 973	if a.isClaudeCode() {
 974		return "You are Claude Code, Anthropic's official CLI for Claude."
 975	}
 976	return a.systemPromptPrefix
 977}
 978
 979// XXX: this should be generalized to cover other subscription plans, like Copilot.
 980func (a *sessionAgent) isClaudeCode() bool {
 981	cfg := config.Get()
 982	pc, ok := cfg.Providers.Get(a.largeModel.ModelCfg.Provider)
 983	return ok && pc.ID == string(catwalk.InferenceProviderAnthropic) && pc.OAuthToken != nil
 984}
 985
 986// convertToToolResult converts a fantasy tool result to a message tool result.
 987func (a *sessionAgent) convertToToolResult(result fantasy.ToolResultContent) message.ToolResult {
 988	baseResult := message.ToolResult{
 989		ToolCallID: result.ToolCallID,
 990		Name:       result.ToolName,
 991		Metadata:   result.ClientMetadata,
 992	}
 993
 994	switch result.Result.GetType() {
 995	case fantasy.ToolResultContentTypeText:
 996		if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result); ok {
 997			baseResult.Content = r.Text
 998		}
 999	case fantasy.ToolResultContentTypeError:
1000		if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result); ok {
1001			baseResult.Content = r.Error.Error()
1002			baseResult.IsError = true
1003		}
1004	case fantasy.ToolResultContentTypeMedia:
1005		if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](result.Result); ok {
1006			content := r.Text
1007			if content == "" {
1008				content = fmt.Sprintf("Loaded %s content", r.MediaType)
1009			}
1010			baseResult.Content = content
1011			baseResult.Data = r.Data
1012			baseResult.MIMEType = r.MediaType
1013		}
1014	}
1015
1016	return baseResult
1017}
1018
1019// workaroundProviderMediaLimitations converts media content in tool results to
1020// user messages for providers that don't natively support images in tool results.
1021//
1022// Problem: OpenAI, Google, OpenRouter, and other OpenAI-compatible providers
1023// don't support sending images/media in tool result messages - they only accept
1024// text in tool results. However, they DO support images in user messages.
1025//
1026// If we send media in tool results to these providers, the API returns an error.
1027//
1028// Solution: For these providers, we:
1029//  1. Replace the media in the tool result with a text placeholder
1030//  2. Inject a user message immediately after with the image as a file attachment
1031//  3. This maintains the tool execution flow while working around API limitations
1032//
1033// Anthropic and Bedrock support images natively in tool results, so we skip
1034// this workaround for them.
1035//
1036// Example transformation:
1037//
1038//	BEFORE: [tool result: image data]
1039//	AFTER:  [tool result: "Image loaded - see attached"], [user: image attachment]
1040func (a *sessionAgent) workaroundProviderMediaLimitations(messages []fantasy.Message) []fantasy.Message {
1041	providerSupportsMedia := a.largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderAnthropic) ||
1042		a.largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderBedrock)
1043
1044	if providerSupportsMedia {
1045		return messages
1046	}
1047
1048	convertedMessages := make([]fantasy.Message, 0, len(messages))
1049
1050	for _, msg := range messages {
1051		if msg.Role != fantasy.MessageRoleTool {
1052			convertedMessages = append(convertedMessages, msg)
1053			continue
1054		}
1055
1056		textParts := make([]fantasy.MessagePart, 0, len(msg.Content))
1057		var mediaFiles []fantasy.FilePart
1058
1059		for _, part := range msg.Content {
1060			toolResult, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
1061			if !ok {
1062				textParts = append(textParts, part)
1063				continue
1064			}
1065
1066			if media, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](toolResult.Output); ok {
1067				decoded, err := base64.StdEncoding.DecodeString(media.Data)
1068				if err != nil {
1069					slog.Warn("failed to decode media data", "error", err)
1070					textParts = append(textParts, part)
1071					continue
1072				}
1073
1074				mediaFiles = append(mediaFiles, fantasy.FilePart{
1075					Data:      decoded,
1076					MediaType: media.MediaType,
1077					Filename:  fmt.Sprintf("tool-result-%s", toolResult.ToolCallID),
1078				})
1079
1080				textParts = append(textParts, fantasy.ToolResultPart{
1081					ToolCallID: toolResult.ToolCallID,
1082					Output: fantasy.ToolResultOutputContentText{
1083						Text: "[Image/media content loaded - see attached file]",
1084					},
1085					ProviderOptions: toolResult.ProviderOptions,
1086				})
1087			} else {
1088				textParts = append(textParts, part)
1089			}
1090		}
1091
1092		convertedMessages = append(convertedMessages, fantasy.Message{
1093			Role:    fantasy.MessageRoleTool,
1094			Content: textParts,
1095		})
1096
1097		if len(mediaFiles) > 0 {
1098			convertedMessages = append(convertedMessages, fantasy.NewUserMessage(
1099				"Here is the media content from the tool result:",
1100				mediaFiles...,
1101			))
1102		}
1103	}
1104
1105	return convertedMessages
1106}