agent.go

   1// Package agent is the core orchestration layer for Crush AI agents.
   2//
   3// It provides session-based AI agent functionality for managing
   4// conversations, tool execution, and message handling. It coordinates
   5// interactions between language models, messages, sessions, and tools while
   6// handling features like automatic summarization, queuing, and token
   7// management.
   8package agent
   9
  10import (
  11	"cmp"
  12	"context"
  13	_ "embed"
  14	"encoding/base64"
  15	"errors"
  16	"fmt"
  17	"log/slog"
  18	"os"
  19	"regexp"
  20	"strconv"
  21	"strings"
  22	"sync"
  23	"time"
  24
  25	"charm.land/fantasy"
  26	"charm.land/fantasy/providers/anthropic"
  27	"charm.land/fantasy/providers/bedrock"
  28	"charm.land/fantasy/providers/google"
  29	"charm.land/fantasy/providers/openai"
  30	"charm.land/fantasy/providers/openrouter"
  31	"charm.land/lipgloss/v2"
  32	"github.com/charmbracelet/catwalk/pkg/catwalk"
  33	"github.com/charmbracelet/crush/internal/agent/hyper"
  34	"github.com/charmbracelet/crush/internal/agent/tools"
  35	"github.com/charmbracelet/crush/internal/config"
  36	"github.com/charmbracelet/crush/internal/csync"
  37	"github.com/charmbracelet/crush/internal/message"
  38	"github.com/charmbracelet/crush/internal/permission"
  39	"github.com/charmbracelet/crush/internal/session"
  40	"github.com/charmbracelet/crush/internal/stringext"
  41)
  42
// defaultSessionName is the title stored for a session when title generation
// fails or produces an empty title.
const defaultSessionName = "Untitled Session"

// titlePrompt is the embedded system prompt used when generating session
// titles.
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt is the embedded system prompt used when summarizing a
// session.
//
//go:embed templates/summary.md
var summaryPrompt []byte

// thinkTagRegex matches <think>...</think> spans so they can be removed from
// generated titles. The pattern does not match across newlines, so the title
// is flattened to a single line before it is applied.
var thinkTagRegex = regexp.MustCompile(`<think>.*?</think>`)
  53
// SessionAgentCall describes a single prompt submitted to a session agent.
type SessionAgentCall struct {
	// SessionID identifies the target session. Run rejects an empty value.
	SessionID string
	// Prompt is the user text. Run rejects an empty value.
	Prompt string
	// ProviderOptions is forwarded to the model stream call (and to
	// Summarize when auto-summarization kicks in).
	ProviderOptions fantasy.ProviderOptions
	// Attachments are stored with the user message; non-text attachments
	// are additionally passed to the model as files.
	Attachments []message.Attachment
	// MaxOutputTokens is forwarded to the model stream call.
	MaxOutputTokens int64
	// The sampling parameters below are forwarded to the model stream call
	// unchanged; nil leaves the parameter unset.
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
}
  66
// SessionAgent is the behavior exposed by a session-scoped AI agent.
type SessionAgent interface {
	// Run executes a prompt for a session. When the session is busy the
	// call is queued and (nil, nil) is returned.
	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
	// SetModels replaces the large and small models used by the agent.
	SetModels(large Model, small Model)
	// SetTools replaces the tools offered to the model.
	SetTools(tools []fantasy.AgentTool)
	// Cancel cancels the active request for the session, if any, and drops
	// its queued prompts.
	Cancel(sessionID string)
	// CancelAll cancels the active requests of all sessions.
	CancelAll()
	// IsSessionBusy reports whether the given session has work in flight.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any session has work in flight.
	IsBusy() bool
	// QueuedPrompts returns the number of prompts queued for the session.
	QueuedPrompts(sessionID string) int
	// QueuedPromptsList returns the queued prompts for the session.
	QueuedPromptsList(sessionID string) []string
	// ClearQueue drops all queued prompts for the session.
	ClearQueue(sessionID string)
	// Summarize replaces the session's conversation context with a
	// model-generated summary message.
	Summarize(context.Context, string, fantasy.ProviderOptions) error
	// Model returns a model used by the agent.
	// NOTE(review): presumably the large model; confirm against the
	// implementation, which is not visible here.
	Model() Model
}
  81
// Model bundles a language model with the metadata and configuration the
// agent needs alongside it.
type Model struct {
	// Model is the language model that requests are sent to.
	Model fantasy.LanguageModel
	// CatwalkCfg is the model metadata; this package reads its context
	// window, pricing, reasoning/image capabilities, and display name.
	CatwalkCfg catwalk.Model
	// ModelCfg is the selected-model configuration; its Model and Provider
	// values are recorded on assistant messages.
	ModelCfg config.SelectedModel
}
  87
// sessionAgent is the SessionAgent implementation returned by
// NewSessionAgent.
type sessionAgent struct {
	// largeModel handles prompts and summaries; smallModel is tried first
	// for title generation.
	largeModel Model
	smallModel Model
	// systemPromptPrefix, when non-empty, is prepended as an extra system
	// message to summary and title requests.
	systemPromptPrefix string
	// systemPrompt is the system prompt for regular runs.
	systemPrompt string
	// isSubAgent disables the todo-list reminder injected into the history.
	isSubAgent bool
	// tools are the tools offered to the model during a run.
	tools    []fantasy.AgentTool
	sessions session.Service
	messages message.Service
	// disableAutoSummarize turns off the context-window stop condition.
	disableAutoSummarize bool
	isYolo               bool

	// messageQueue holds prompts received while a session was busy, keyed
	// by session ID.
	messageQueue *csync.Map[string, []SessionAgentCall]
	// activeRequests holds the cancel function of each in-flight request,
	// keyed by session ID.
	activeRequests *csync.Map[string, context.CancelFunc]
}
 103
// SessionAgentOptions carries the settings NewSessionAgent copies into a new
// session agent. Each field maps to the sessionAgent field of the same name.
type SessionAgentOptions struct {
	LargeModel           Model
	SmallModel           Model
	SystemPromptPrefix   string
	SystemPrompt         string
	IsSubAgent           bool
	DisableAutoSummarize bool
	IsYolo               bool
	Sessions             session.Service
	Messages             message.Service
	Tools                []fantasy.AgentTool
}
 116
 117func NewSessionAgent(
 118	opts SessionAgentOptions,
 119) SessionAgent {
 120	return &sessionAgent{
 121		largeModel:           opts.LargeModel,
 122		smallModel:           opts.SmallModel,
 123		systemPromptPrefix:   opts.SystemPromptPrefix,
 124		systemPrompt:         opts.SystemPrompt,
 125		isSubAgent:           opts.IsSubAgent,
 126		sessions:             opts.Sessions,
 127		messages:             opts.Messages,
 128		disableAutoSummarize: opts.DisableAutoSummarize,
 129		tools:                opts.Tools,
 130		isYolo:               opts.IsYolo,
 131		messageQueue:         csync.NewMap[string, []SessionAgentCall](),
 132		activeRequests:       csync.NewMap[string, context.CancelFunc](),
 133	}
 134}
 135
 136func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
 137	if call.Prompt == "" {
 138		return nil, ErrEmptyPrompt
 139	}
 140	if call.SessionID == "" {
 141		return nil, ErrSessionMissing
 142	}
 143
 144	// Queue the message if busy
 145	if a.IsSessionBusy(call.SessionID) {
 146		existing, ok := a.messageQueue.Get(call.SessionID)
 147		if !ok {
 148			existing = []SessionAgentCall{}
 149		}
 150		existing = append(existing, call)
 151		a.messageQueue.Set(call.SessionID, existing)
 152		return nil, nil
 153	}
 154
 155	if len(a.tools) > 0 {
 156		// Add Anthropic caching to the last tool.
 157		a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
 158	}
 159
 160	agent := fantasy.NewAgent(
 161		a.largeModel.Model,
 162		fantasy.WithSystemPrompt(a.systemPrompt),
 163		fantasy.WithTools(a.tools...),
 164	)
 165
 166	sessionLock := sync.Mutex{}
 167	currentSession, err := a.sessions.Get(ctx, call.SessionID)
 168	if err != nil {
 169		return nil, fmt.Errorf("failed to get session: %w", err)
 170	}
 171
 172	msgs, err := a.getSessionMessages(ctx, currentSession)
 173	if err != nil {
 174		return nil, fmt.Errorf("failed to get session messages: %w", err)
 175	}
 176
 177	var wg sync.WaitGroup
 178	// Generate title if first message.
 179	if len(msgs) == 0 {
 180		titleCtx := ctx // Copy to avoid race with ctx reassignment below.
 181		wg.Go(func() {
 182			a.generateTitle(titleCtx, call.SessionID, call.Prompt)
 183		})
 184	}
 185
 186	// Add the user message to the session.
 187	_, err = a.createUserMessage(ctx, call)
 188	if err != nil {
 189		return nil, err
 190	}
 191
 192	// Add the session to the context.
 193	ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)
 194
 195	genCtx, cancel := context.WithCancel(ctx)
 196	a.activeRequests.Set(call.SessionID, cancel)
 197
 198	defer cancel()
 199	defer a.activeRequests.Del(call.SessionID)
 200
 201	history, files := a.preparePrompt(msgs, call.Attachments...)
 202
 203	startTime := time.Now()
 204	a.eventPromptSent(call.SessionID)
 205
 206	var currentAssistant *message.Message
 207	var shouldSummarize bool
 208	result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
 209		Prompt:           message.PromptWithTextAttachments(call.Prompt, call.Attachments),
 210		Files:            files,
 211		Messages:         history,
 212		ProviderOptions:  call.ProviderOptions,
 213		MaxOutputTokens:  &call.MaxOutputTokens,
 214		TopP:             call.TopP,
 215		Temperature:      call.Temperature,
 216		PresencePenalty:  call.PresencePenalty,
 217		TopK:             call.TopK,
 218		FrequencyPenalty: call.FrequencyPenalty,
 219		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
 220			prepared.Messages = options.Messages
 221			for i := range prepared.Messages {
 222				prepared.Messages[i].ProviderOptions = nil
 223			}
 224
 225			queuedCalls, _ := a.messageQueue.Get(call.SessionID)
 226			a.messageQueue.Del(call.SessionID)
 227			for _, queued := range queuedCalls {
 228				userMessage, createErr := a.createUserMessage(callContext, queued)
 229				if createErr != nil {
 230					return callContext, prepared, createErr
 231				}
 232				prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
 233			}
 234
 235			prepared.Messages = a.workaroundProviderMediaLimitations(prepared.Messages)
 236
 237			lastSystemRoleInx := 0
 238			systemMessageUpdated := false
 239			for i, msg := range prepared.Messages {
 240				// Only add cache control to the last message.
 241				if msg.Role == fantasy.MessageRoleSystem {
 242					lastSystemRoleInx = i
 243				} else if !systemMessageUpdated {
 244					prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
 245					systemMessageUpdated = true
 246				}
 247				// Than add cache control to the last 2 messages.
 248				if i > len(prepared.Messages)-3 {
 249					prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
 250				}
 251			}
 252
 253			if promptPrefix := a.promptPrefix(); promptPrefix != "" {
 254				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(promptPrefix)}, prepared.Messages...)
 255			}
 256
 257			var assistantMsg message.Message
 258			assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
 259				Role:     message.Assistant,
 260				Parts:    []message.ContentPart{},
 261				Model:    a.largeModel.ModelCfg.Model,
 262				Provider: a.largeModel.ModelCfg.Provider,
 263			})
 264			if err != nil {
 265				return callContext, prepared, err
 266			}
 267			callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
 268			callContext = context.WithValue(callContext, tools.SupportsImagesContextKey, a.largeModel.CatwalkCfg.SupportsImages)
 269			callContext = context.WithValue(callContext, tools.ModelNameContextKey, a.largeModel.CatwalkCfg.Name)
 270			currentAssistant = &assistantMsg
 271			return callContext, prepared, err
 272		},
 273		OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
 274			currentAssistant.AppendReasoningContent(reasoning.Text)
 275			return a.messages.Update(genCtx, *currentAssistant)
 276		},
 277		OnReasoningDelta: func(id string, text string) error {
 278			currentAssistant.AppendReasoningContent(text)
 279			return a.messages.Update(genCtx, *currentAssistant)
 280		},
 281		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
 282			// handle anthropic signature
 283			if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
 284				if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
 285					currentAssistant.AppendReasoningSignature(reasoning.Signature)
 286				}
 287			}
 288			if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
 289				if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
 290					currentAssistant.AppendThoughtSignature(reasoning.Signature, reasoning.ToolID)
 291				}
 292			}
 293			if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
 294				if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
 295					currentAssistant.SetReasoningResponsesData(reasoning)
 296				}
 297			}
 298			currentAssistant.FinishThinking()
 299			return a.messages.Update(genCtx, *currentAssistant)
 300		},
 301		OnTextDelta: func(id string, text string) error {
 302			// Strip leading newline from initial text content. This is is
 303			// particularly important in non-interactive mode where leading
 304			// newlines are very visible.
 305			if len(currentAssistant.Parts) == 0 {
 306				text = strings.TrimPrefix(text, "\n")
 307			}
 308
 309			currentAssistant.AppendContent(text)
 310			return a.messages.Update(genCtx, *currentAssistant)
 311		},
 312		OnToolInputStart: func(id string, toolName string) error {
 313			toolCall := message.ToolCall{
 314				ID:               id,
 315				Name:             toolName,
 316				ProviderExecuted: false,
 317				Finished:         false,
 318			}
 319			currentAssistant.AddToolCall(toolCall)
 320			return a.messages.Update(genCtx, *currentAssistant)
 321		},
 322		OnRetry: func(err *fantasy.ProviderError, delay time.Duration) {
 323			// TODO: implement
 324		},
 325		OnToolCall: func(tc fantasy.ToolCallContent) error {
 326			toolCall := message.ToolCall{
 327				ID:               tc.ToolCallID,
 328				Name:             tc.ToolName,
 329				Input:            tc.Input,
 330				ProviderExecuted: false,
 331				Finished:         true,
 332			}
 333			currentAssistant.AddToolCall(toolCall)
 334			return a.messages.Update(genCtx, *currentAssistant)
 335		},
 336		OnToolResult: func(result fantasy.ToolResultContent) error {
 337			toolResult := a.convertToToolResult(result)
 338			_, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
 339				Role: message.Tool,
 340				Parts: []message.ContentPart{
 341					toolResult,
 342				},
 343			})
 344			return createMsgErr
 345		},
 346		OnStepFinish: func(stepResult fantasy.StepResult) error {
 347			finishReason := message.FinishReasonUnknown
 348			switch stepResult.FinishReason {
 349			case fantasy.FinishReasonLength:
 350				finishReason = message.FinishReasonMaxTokens
 351			case fantasy.FinishReasonStop:
 352				finishReason = message.FinishReasonEndTurn
 353			case fantasy.FinishReasonToolCalls:
 354				finishReason = message.FinishReasonToolUse
 355			}
 356			currentAssistant.AddFinish(finishReason, "", "")
 357			sessionLock.Lock()
 358			updatedSession, getSessionErr := a.sessions.Get(genCtx, call.SessionID)
 359			if getSessionErr != nil {
 360				sessionLock.Unlock()
 361				return getSessionErr
 362			}
 363			a.updateSessionUsage(a.largeModel, &updatedSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
 364			_, sessionErr := a.sessions.Save(genCtx, updatedSession)
 365			sessionLock.Unlock()
 366			if sessionErr != nil {
 367				return sessionErr
 368			}
 369			return a.messages.Update(genCtx, *currentAssistant)
 370		},
 371		StopWhen: []fantasy.StopCondition{
 372			func(_ []fantasy.StepResult) bool {
 373				cw := int64(a.largeModel.CatwalkCfg.ContextWindow)
 374				tokens := currentSession.CompletionTokens + currentSession.PromptTokens
 375				remaining := cw - tokens
 376				var threshold int64
 377				if cw > 200_000 {
 378					threshold = 20_000
 379				} else {
 380					threshold = int64(float64(cw) * 0.2)
 381				}
 382				if (remaining <= threshold) && !a.disableAutoSummarize {
 383					shouldSummarize = true
 384					return true
 385				}
 386				return false
 387			},
 388		},
 389	})
 390
 391	a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))
 392
 393	if err != nil {
 394		isCancelErr := errors.Is(err, context.Canceled)
 395		isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
 396		if currentAssistant == nil {
 397			return result, err
 398		}
 399		// Ensure we finish thinking on error to close the reasoning state.
 400		currentAssistant.FinishThinking()
 401		toolCalls := currentAssistant.ToolCalls()
 402		// INFO: we use the parent context here because the genCtx has been cancelled.
 403		msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
 404		if createErr != nil {
 405			return nil, createErr
 406		}
 407		for _, tc := range toolCalls {
 408			if !tc.Finished {
 409				tc.Finished = true
 410				tc.Input = "{}"
 411				currentAssistant.AddToolCall(tc)
 412				updateErr := a.messages.Update(ctx, *currentAssistant)
 413				if updateErr != nil {
 414					return nil, updateErr
 415				}
 416			}
 417
 418			found := false
 419			for _, msg := range msgs {
 420				if msg.Role == message.Tool {
 421					for _, tr := range msg.ToolResults() {
 422						if tr.ToolCallID == tc.ID {
 423							found = true
 424							break
 425						}
 426					}
 427				}
 428				if found {
 429					break
 430				}
 431			}
 432			if found {
 433				continue
 434			}
 435			content := "There was an error while executing the tool"
 436			if isCancelErr {
 437				content = "Tool execution canceled by user"
 438			} else if isPermissionErr {
 439				content = "User denied permission"
 440			}
 441			toolResult := message.ToolResult{
 442				ToolCallID: tc.ID,
 443				Name:       tc.Name,
 444				Content:    content,
 445				IsError:    true,
 446			}
 447			_, createErr = a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
 448				Role: message.Tool,
 449				Parts: []message.ContentPart{
 450					toolResult,
 451				},
 452			})
 453			if createErr != nil {
 454				return nil, createErr
 455			}
 456		}
 457		var fantasyErr *fantasy.Error
 458		var providerErr *fantasy.ProviderError
 459		const defaultTitle = "Provider Error"
 460		if isCancelErr {
 461			currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
 462		} else if isPermissionErr {
 463			currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "")
 464		} else if errors.Is(err, hyper.ErrNoCredits) {
 465			url := hyper.BaseURL()
 466			link := lipgloss.NewStyle().Hyperlink(url, "id=hyper").Render(url)
 467			currentAssistant.AddFinish(message.FinishReasonError, "No credits", "You're out of credits. Add more at "+link)
 468		} else if errors.As(err, &providerErr) {
 469			if providerErr.Message == "The requested model is not supported." {
 470				url := "https://github.com/settings/copilot/features"
 471				link := lipgloss.NewStyle().Hyperlink(url, "id=hyper").Render(url)
 472				currentAssistant.AddFinish(
 473					message.FinishReasonError,
 474					"Copilot model not enabled",
 475					fmt.Sprintf("%q is not enabled in Copilot. Go to the following page to enable it. Then, wait a minute before trying again. %s", a.largeModel.CatwalkCfg.Name, link),
 476				)
 477			} else {
 478				currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message)
 479			}
 480		} else if errors.As(err, &fantasyErr) {
 481			currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(fantasyErr.Title), defaultTitle), fantasyErr.Message)
 482		} else {
 483			currentAssistant.AddFinish(message.FinishReasonError, defaultTitle, err.Error())
 484		}
 485		// Note: we use the parent context here because the genCtx has been
 486		// cancelled.
 487		updateErr := a.messages.Update(ctx, *currentAssistant)
 488		if updateErr != nil {
 489			return nil, updateErr
 490		}
 491		return nil, err
 492	}
 493	wg.Wait()
 494
 495	if shouldSummarize {
 496		a.activeRequests.Del(call.SessionID)
 497		if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
 498			return nil, summarizeErr
 499		}
 500		// If the agent wasn't done...
 501		if len(currentAssistant.ToolCalls()) > 0 {
 502			existing, ok := a.messageQueue.Get(call.SessionID)
 503			if !ok {
 504				existing = []SessionAgentCall{}
 505			}
 506			call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
 507			existing = append(existing, call)
 508			a.messageQueue.Set(call.SessionID, existing)
 509		}
 510	}
 511
 512	// Release active request before processing queued messages.
 513	a.activeRequests.Del(call.SessionID)
 514	cancel()
 515
 516	queuedMessages, ok := a.messageQueue.Get(call.SessionID)
 517	if !ok || len(queuedMessages) == 0 {
 518		return result, err
 519	}
 520	// There are queued messages restart the loop.
 521	firstQueuedMessage := queuedMessages[0]
 522	a.messageQueue.Set(call.SessionID, queuedMessages[1:])
 523	return a.Run(ctx, firstQueuedMessage)
 524}
 525
 526func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
 527	if a.IsSessionBusy(sessionID) {
 528		return ErrSessionBusy
 529	}
 530
 531	currentSession, err := a.sessions.Get(ctx, sessionID)
 532	if err != nil {
 533		return fmt.Errorf("failed to get session: %w", err)
 534	}
 535	msgs, err := a.getSessionMessages(ctx, currentSession)
 536	if err != nil {
 537		return err
 538	}
 539	if len(msgs) == 0 {
 540		// Nothing to summarize.
 541		return nil
 542	}
 543
 544	aiMsgs, _ := a.preparePrompt(msgs)
 545
 546	genCtx, cancel := context.WithCancel(ctx)
 547	a.activeRequests.Set(sessionID, cancel)
 548	defer a.activeRequests.Del(sessionID)
 549	defer cancel()
 550
 551	agent := fantasy.NewAgent(a.largeModel.Model,
 552		fantasy.WithSystemPrompt(string(summaryPrompt)),
 553	)
 554	summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
 555		Role:             message.Assistant,
 556		Model:            a.largeModel.Model.Model(),
 557		Provider:         a.largeModel.Model.Provider(),
 558		IsSummaryMessage: true,
 559	})
 560	if err != nil {
 561		return err
 562	}
 563
 564	summaryPromptText := "Provide a detailed summary of our conversation above."
 565	if len(currentSession.Todos) > 0 {
 566		summaryPromptText += "\n\n## Current Todo List\n\n"
 567		for _, t := range currentSession.Todos {
 568			summaryPromptText += fmt.Sprintf("- [%s] %s\n", t.Status, t.Content)
 569		}
 570		summaryPromptText += "\nInclude these tasks and their statuses in your summary. "
 571		summaryPromptText += "Instruct the resuming assistant to use the `todos` tool to continue tracking progress on these tasks."
 572	}
 573
 574	resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
 575		Prompt:          summaryPromptText,
 576		Messages:        aiMsgs,
 577		ProviderOptions: opts,
 578		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
 579			prepared.Messages = options.Messages
 580			if a.systemPromptPrefix != "" {
 581				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
 582			}
 583			return callContext, prepared, nil
 584		},
 585		OnReasoningDelta: func(id string, text string) error {
 586			summaryMessage.AppendReasoningContent(text)
 587			return a.messages.Update(genCtx, summaryMessage)
 588		},
 589		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
 590			// Handle anthropic signature.
 591			if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
 592				if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
 593					summaryMessage.AppendReasoningSignature(signature.Signature)
 594				}
 595			}
 596			summaryMessage.FinishThinking()
 597			return a.messages.Update(genCtx, summaryMessage)
 598		},
 599		OnTextDelta: func(id, text string) error {
 600			summaryMessage.AppendContent(text)
 601			return a.messages.Update(genCtx, summaryMessage)
 602		},
 603	})
 604	if err != nil {
 605		isCancelErr := errors.Is(err, context.Canceled)
 606		if isCancelErr {
 607			// User cancelled summarize we need to remove the summary message.
 608			deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
 609			return deleteErr
 610		}
 611		return err
 612	}
 613
 614	summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
 615	err = a.messages.Update(genCtx, summaryMessage)
 616	if err != nil {
 617		return err
 618	}
 619
 620	var openrouterCost *float64
 621	for _, step := range resp.Steps {
 622		stepCost := a.openrouterCost(step.ProviderMetadata)
 623		if stepCost != nil {
 624			newCost := *stepCost
 625			if openrouterCost != nil {
 626				newCost += *openrouterCost
 627			}
 628			openrouterCost = &newCost
 629		}
 630	}
 631
 632	a.updateSessionUsage(a.largeModel, &currentSession, resp.TotalUsage, openrouterCost)
 633
 634	// Just in case, get just the last usage info.
 635	usage := resp.Response.Usage
 636	currentSession.SummaryMessageID = summaryMessage.ID
 637	currentSession.CompletionTokens = usage.OutputTokens
 638	currentSession.PromptTokens = 0
 639	_, err = a.sessions.Save(genCtx, currentSession)
 640	return err
 641}
 642
 643func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
 644	if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
 645		return fantasy.ProviderOptions{}
 646	}
 647	return fantasy.ProviderOptions{
 648		anthropic.Name: &anthropic.ProviderCacheControlOptions{
 649			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
 650		},
 651		bedrock.Name: &anthropic.ProviderCacheControlOptions{
 652			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
 653		},
 654	}
 655}
 656
 657func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
 658	parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
 659	var attachmentParts []message.ContentPart
 660	for _, attachment := range call.Attachments {
 661		attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
 662	}
 663	parts = append(parts, attachmentParts...)
 664	msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
 665		Role:  message.User,
 666		Parts: parts,
 667	})
 668	if err != nil {
 669		return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
 670	}
 671	return msg, nil
 672}
 673
 674func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
 675	var history []fantasy.Message
 676	if !a.isSubAgent {
 677		history = append(history, fantasy.NewUserMessage(
 678			fmt.Sprintf("<system_reminder>%s</system_reminder>",
 679				`This is a reminder that your todo list is currently empty. DO NOT mention this to the user explicitly because they are already aware.
 680If you are working on tasks that would benefit from a todo list please use the "todos" tool to create one.
 681If not, please feel free to ignore. Again do not mention this message to the user.`,
 682			),
 683		))
 684	}
 685	for _, m := range msgs {
 686		if len(m.Parts) == 0 {
 687			continue
 688		}
 689		// Assistant message without content or tool calls (cancelled before it
 690		// returned anything).
 691		if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
 692			continue
 693		}
 694		history = append(history, m.ToAIMessage()...)
 695	}
 696
 697	var files []fantasy.FilePart
 698	for _, attachment := range attachments {
 699		if attachment.IsText() {
 700			continue
 701		}
 702		files = append(files, fantasy.FilePart{
 703			Filename:  attachment.FileName,
 704			Data:      attachment.Content,
 705			MediaType: attachment.MimeType,
 706		})
 707	}
 708
 709	return history, files
 710}
 711
 712func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
 713	msgs, err := a.messages.List(ctx, session.ID)
 714	if err != nil {
 715		return nil, fmt.Errorf("failed to list messages: %w", err)
 716	}
 717
 718	if session.SummaryMessageID != "" {
 719		summaryMsgInex := -1
 720		for i, msg := range msgs {
 721			if msg.ID == session.SummaryMessageID {
 722				summaryMsgInex = i
 723				break
 724			}
 725		}
 726		if summaryMsgInex != -1 {
 727			msgs = msgs[summaryMsgInex:]
 728			msgs[0].Role = message.User
 729		}
 730	}
 731	return msgs, nil
 732}
 733
 734// generateTitle generates a session titled based on the initial prompt.
 735func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, userPrompt string) {
 736	if userPrompt == "" {
 737		return
 738	}
 739
 740	var maxOutputTokens int64 = 40
 741	if a.smallModel.CatwalkCfg.CanReason {
 742		maxOutputTokens = a.smallModel.CatwalkCfg.DefaultMaxTokens
 743	}
 744
 745	newAgent := func(m fantasy.LanguageModel, p []byte, tok int64) fantasy.Agent {
 746		return fantasy.NewAgent(m,
 747			fantasy.WithSystemPrompt(string(p)+"\n /no_think"),
 748			fantasy.WithMaxOutputTokens(tok),
 749		)
 750	}
 751
 752	streamCall := fantasy.AgentStreamCall{
 753		Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", userPrompt),
 754		PrepareStep: func(callCtx context.Context, opts fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
 755			prepared.Messages = opts.Messages
 756			if a.systemPromptPrefix != "" {
 757				prepared.Messages = append([]fantasy.Message{
 758					fantasy.NewSystemMessage(a.systemPromptPrefix),
 759				}, prepared.Messages...)
 760			}
 761			return callCtx, prepared, nil
 762		},
 763	}
 764
 765	// Use the small model to generate the title.
 766	model := &a.smallModel
 767	agent := newAgent(model.Model, titlePrompt, maxOutputTokens)
 768	resp, err := agent.Stream(ctx, streamCall)
 769	if err == nil {
 770		// We successfully generated a title with the small model.
 771		slog.Info("generated title with small model")
 772	} else {
 773		// It didn't work. Let's try with the big model.
 774		slog.Error("error generating title with small model; trying big model", "err", err)
 775		model = &a.largeModel
 776		agent = newAgent(model.Model, titlePrompt, maxOutputTokens)
 777		resp, err = agent.Stream(ctx, streamCall)
 778		if err == nil {
 779			slog.Info("generated title with large model")
 780		} else {
 781			// Welp, the large model didn't work either. Use the default
 782			// session name and return.
 783			slog.Error("error generating title with large model", "err", err)
 784			saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, defaultSessionName, 0, 0, 0)
 785			if saveErr != nil {
 786				slog.Error("failed to save session title and usage", "error", saveErr)
 787			}
 788			return
 789		}
 790	}
 791
 792	if resp == nil {
 793		// Actually, we didn't get a response so we can't. Use the default
 794		// session name and return.
 795		slog.Error("response is nil; can't generate title")
 796		saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, defaultSessionName, 0, 0, 0)
 797		if saveErr != nil {
 798			slog.Error("failed to save session title and usage", "error", saveErr)
 799		}
 800		return
 801	}
 802
 803	// Clean up title.
 804	var title string
 805	title = strings.ReplaceAll(resp.Response.Content.Text(), "\n", " ")
 806	slog.Info("generated title", "title", title)
 807
 808	// Remove thinking tags if present.
 809	title = thinkTagRegex.ReplaceAllString(title, "")
 810
 811	title = strings.TrimSpace(title)
 812	if title == "" {
 813		slog.Warn("empty title; using fallback")
 814		title = defaultSessionName
 815	}
 816
 817	// Calculate usage and cost.
 818	var openrouterCost *float64
 819	for _, step := range resp.Steps {
 820		stepCost := a.openrouterCost(step.ProviderMetadata)
 821		if stepCost != nil {
 822			newCost := *stepCost
 823			if openrouterCost != nil {
 824				newCost += *openrouterCost
 825			}
 826			openrouterCost = &newCost
 827		}
 828	}
 829
 830	modelConfig := model.CatwalkCfg
 831	cost := modelConfig.CostPer1MInCached/1e6*float64(resp.TotalUsage.CacheCreationTokens) +
 832		modelConfig.CostPer1MOutCached/1e6*float64(resp.TotalUsage.CacheReadTokens) +
 833		modelConfig.CostPer1MIn/1e6*float64(resp.TotalUsage.InputTokens) +
 834		modelConfig.CostPer1MOut/1e6*float64(resp.TotalUsage.OutputTokens)
 835
 836	if a.isClaudeCode() {
 837		cost = 0
 838	}
 839
 840	// Use override cost if available (e.g., from OpenRouter).
 841	if openrouterCost != nil {
 842		cost = *openrouterCost
 843	}
 844
 845	promptTokens := resp.TotalUsage.InputTokens + resp.TotalUsage.CacheCreationTokens
 846	completionTokens := resp.TotalUsage.OutputTokens + resp.TotalUsage.CacheReadTokens
 847
 848	// Atomically update only title and usage fields to avoid overriding other
 849	// concurrent session updates.
 850	saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, title, promptTokens, completionTokens, cost)
 851	if saveErr != nil {
 852		slog.Error("failed to save session title and usage", "error", saveErr)
 853		return
 854	}
 855}
 856
 857func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
 858	openrouterMetadata, ok := metadata[openrouter.Name]
 859	if !ok {
 860		return nil
 861	}
 862
 863	opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
 864	if !ok {
 865		return nil
 866	}
 867	return &opts.Usage.Cost
 868}
 869
 870func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
 871	modelConfig := model.CatwalkCfg
 872	cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
 873		modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
 874		modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
 875		modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
 876
 877	if a.isClaudeCode() {
 878		cost = 0
 879	}
 880
 881	a.eventTokensUsed(session.ID, model, usage, cost)
 882
 883	if overrideCost != nil {
 884		session.Cost += *overrideCost
 885	} else {
 886		session.Cost += cost
 887	}
 888
 889	session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
 890	session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
 891}
 892
 893func (a *sessionAgent) Cancel(sessionID string) {
 894	// Cancel regular requests.
 895	if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
 896		slog.Info("Request cancellation initiated", "session_id", sessionID)
 897		cancel()
 898	}
 899
 900	// Also check for summarize requests.
 901	if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
 902		slog.Info("Summarize cancellation initiated", "session_id", sessionID)
 903		cancel()
 904	}
 905
 906	if a.QueuedPrompts(sessionID) > 0 {
 907		slog.Info("Clearing queued prompts", "session_id", sessionID)
 908		a.messageQueue.Del(sessionID)
 909	}
 910}
 911
 912func (a *sessionAgent) ClearQueue(sessionID string) {
 913	if a.QueuedPrompts(sessionID) > 0 {
 914		slog.Info("Clearing queued prompts", "session_id", sessionID)
 915		a.messageQueue.Del(sessionID)
 916	}
 917}
 918
 919func (a *sessionAgent) CancelAll() {
 920	if !a.IsBusy() {
 921		return
 922	}
 923	for key := range a.activeRequests.Seq2() {
 924		a.Cancel(key) // key is sessionID
 925	}
 926
 927	timeout := time.After(5 * time.Second)
 928	for a.IsBusy() {
 929		select {
 930		case <-timeout:
 931			return
 932		default:
 933			time.Sleep(200 * time.Millisecond)
 934		}
 935	}
 936}
 937
 938func (a *sessionAgent) IsBusy() bool {
 939	var busy bool
 940	for cancelFunc := range a.activeRequests.Seq() {
 941		if cancelFunc != nil {
 942			busy = true
 943			break
 944		}
 945	}
 946	return busy
 947}
 948
 949func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
 950	_, busy := a.activeRequests.Get(sessionID)
 951	return busy
 952}
 953
 954func (a *sessionAgent) QueuedPrompts(sessionID string) int {
 955	l, ok := a.messageQueue.Get(sessionID)
 956	if !ok {
 957		return 0
 958	}
 959	return len(l)
 960}
 961
 962func (a *sessionAgent) QueuedPromptsList(sessionID string) []string {
 963	l, ok := a.messageQueue.Get(sessionID)
 964	if !ok {
 965		return nil
 966	}
 967	prompts := make([]string, len(l))
 968	for i, call := range l {
 969		prompts[i] = call.Prompt
 970	}
 971	return prompts
 972}
 973
 974func (a *sessionAgent) SetModels(large Model, small Model) {
 975	a.largeModel = large
 976	a.smallModel = small
 977}
 978
 979func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
 980	a.tools = tools
 981}
 982
 983func (a *sessionAgent) Model() Model {
 984	return a.largeModel
 985}
 986
 987func (a *sessionAgent) promptPrefix() string {
 988	if a.isClaudeCode() {
 989		return "You are Claude Code, Anthropic's official CLI for Claude."
 990	}
 991	return a.systemPromptPrefix
 992}
 993
 994// XXX: this should be generalized to cover other subscription plans, like Copilot.
 995func (a *sessionAgent) isClaudeCode() bool {
 996	cfg := config.Get()
 997	pc, ok := cfg.Providers.Get(a.largeModel.ModelCfg.Provider)
 998	return ok && pc.ID == string(catwalk.InferenceProviderAnthropic) && pc.OAuthToken != nil
 999}
1000
1001// convertToToolResult converts a fantasy tool result to a message tool result.
1002func (a *sessionAgent) convertToToolResult(result fantasy.ToolResultContent) message.ToolResult {
1003	baseResult := message.ToolResult{
1004		ToolCallID: result.ToolCallID,
1005		Name:       result.ToolName,
1006		Metadata:   result.ClientMetadata,
1007	}
1008
1009	switch result.Result.GetType() {
1010	case fantasy.ToolResultContentTypeText:
1011		if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result); ok {
1012			baseResult.Content = r.Text
1013		}
1014	case fantasy.ToolResultContentTypeError:
1015		if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result); ok {
1016			baseResult.Content = r.Error.Error()
1017			baseResult.IsError = true
1018		}
1019	case fantasy.ToolResultContentTypeMedia:
1020		if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](result.Result); ok {
1021			content := r.Text
1022			if content == "" {
1023				content = fmt.Sprintf("Loaded %s content", r.MediaType)
1024			}
1025			baseResult.Content = content
1026			baseResult.Data = r.Data
1027			baseResult.MIMEType = r.MediaType
1028		}
1029	}
1030
1031	return baseResult
1032}
1033
1034// workaroundProviderMediaLimitations converts media content in tool results to
1035// user messages for providers that don't natively support images in tool results.
1036//
1037// Problem: OpenAI, Google, OpenRouter, and other OpenAI-compatible providers
1038// don't support sending images/media in tool result messages - they only accept
1039// text in tool results. However, they DO support images in user messages.
1040//
1041// If we send media in tool results to these providers, the API returns an error.
1042//
1043// Solution: For these providers, we:
1044//  1. Replace the media in the tool result with a text placeholder
1045//  2. Inject a user message immediately after with the image as a file attachment
1046//  3. This maintains the tool execution flow while working around API limitations
1047//
1048// Anthropic and Bedrock support images natively in tool results, so we skip
1049// this workaround for them.
1050//
1051// Example transformation:
1052//
1053//	BEFORE: [tool result: image data]
1054//	AFTER:  [tool result: "Image loaded - see attached"], [user: image attachment]
1055func (a *sessionAgent) workaroundProviderMediaLimitations(messages []fantasy.Message) []fantasy.Message {
1056	providerSupportsMedia := a.largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderAnthropic) ||
1057		a.largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderBedrock)
1058
1059	if providerSupportsMedia {
1060		return messages
1061	}
1062
1063	convertedMessages := make([]fantasy.Message, 0, len(messages))
1064
1065	for _, msg := range messages {
1066		if msg.Role != fantasy.MessageRoleTool {
1067			convertedMessages = append(convertedMessages, msg)
1068			continue
1069		}
1070
1071		textParts := make([]fantasy.MessagePart, 0, len(msg.Content))
1072		var mediaFiles []fantasy.FilePart
1073
1074		for _, part := range msg.Content {
1075			toolResult, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
1076			if !ok {
1077				textParts = append(textParts, part)
1078				continue
1079			}
1080
1081			if media, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](toolResult.Output); ok {
1082				decoded, err := base64.StdEncoding.DecodeString(media.Data)
1083				if err != nil {
1084					slog.Warn("failed to decode media data", "error", err)
1085					textParts = append(textParts, part)
1086					continue
1087				}
1088
1089				mediaFiles = append(mediaFiles, fantasy.FilePart{
1090					Data:      decoded,
1091					MediaType: media.MediaType,
1092					Filename:  fmt.Sprintf("tool-result-%s", toolResult.ToolCallID),
1093				})
1094
1095				textParts = append(textParts, fantasy.ToolResultPart{
1096					ToolCallID: toolResult.ToolCallID,
1097					Output: fantasy.ToolResultOutputContentText{
1098						Text: "[Image/media content loaded - see attached file]",
1099					},
1100					ProviderOptions: toolResult.ProviderOptions,
1101				})
1102			} else {
1103				textParts = append(textParts, part)
1104			}
1105		}
1106
1107		convertedMessages = append(convertedMessages, fantasy.Message{
1108			Role:    fantasy.MessageRoleTool,
1109			Content: textParts,
1110		})
1111
1112		if len(mediaFiles) > 0 {
1113			convertedMessages = append(convertedMessages, fantasy.NewUserMessage(
1114				"Here is the media content from the tool result:",
1115				mediaFiles...,
1116			))
1117		}
1118	}
1119
1120	return convertedMessages
1121}