agent.go

   1// Package agent is the core orchestration layer for Crush AI agents.
   2//
   3// It provides session-based AI agent functionality for managing
   4// conversations, tool execution, and message handling. It coordinates
   5// interactions between language models, messages, sessions, and tools while
   6// handling features like automatic summarization, queuing, and token
   7// management.
   8package agent
   9
  10import (
  11	"cmp"
  12	"context"
  13	_ "embed"
  14	"encoding/base64"
  15	"errors"
  16	"fmt"
  17	"log/slog"
  18	"os"
  19	"strconv"
  20	"strings"
  21	"sync"
  22	"time"
  23
  24	"charm.land/fantasy"
  25	"charm.land/fantasy/providers/anthropic"
  26	"charm.land/fantasy/providers/bedrock"
  27	"charm.land/fantasy/providers/google"
  28	"charm.land/fantasy/providers/openai"
  29	"charm.land/fantasy/providers/openrouter"
  30	"charm.land/lipgloss/v2"
  31	"github.com/charmbracelet/catwalk/pkg/catwalk"
  32	"github.com/charmbracelet/crush/internal/agent/hyper"
  33	"github.com/charmbracelet/crush/internal/agent/tools"
  34	"github.com/charmbracelet/crush/internal/config"
  35	"github.com/charmbracelet/crush/internal/csync"
  36	"github.com/charmbracelet/crush/internal/message"
  37	"github.com/charmbracelet/crush/internal/permission"
  38	"github.com/charmbracelet/crush/internal/session"
  39	"github.com/charmbracelet/crush/internal/stringext"
  40)
  41
// titlePrompt is the embedded contents of templates/title.md. generateTitle
// uses it as the system prompt when asking the small model for a session
// title.
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt is the embedded contents of templates/summary.md. Summarize
// uses it as the system prompt when condensing a session.
//
//go:embed templates/summary.md
var summaryPrompt []byte
  47
// SessionAgentCall describes a single prompt submitted to a session through
// SessionAgent.Run.
//
// SessionID and Prompt are required: Run rejects the call with
// ErrSessionMissing or ErrEmptyPrompt when either is empty. Attachments are
// stored with the user message and forwarded to the model as files. The
// remaining fields (ProviderOptions, MaxOutputTokens, Temperature, TopP, TopK,
// FrequencyPenalty, PresencePenalty) are passed through unchanged to the
// underlying fantasy stream call.
type SessionAgentCall struct {
	SessionID        string
	Prompt           string
	ProviderOptions  fantasy.ProviderOptions
	Attachments      []message.Attachment
	MaxOutputTokens  int64
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
}
  60
// SessionAgent is the session-scoped agent API. The descriptions below
// reflect the behavior of the sessionAgent implementation in this file.
type SessionAgent interface {
	// Run executes the call against its session. If that session already has
	// an active request the call is queued instead and Run returns (nil, nil).
	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
	// SetModels replaces the large and small models used by the agent.
	SetModels(large Model, small Model)
	// SetTools replaces the set of tools offered to the large model.
	SetTools(tools []fantasy.AgentTool)
	// Cancel cancels the session's active request, if any, and drops its
	// queued prompts.
	Cancel(sessionID string)
	// CancelAll cancels every active request and waits (up to five seconds)
	// for the agent to become idle.
	CancelAll()
	// IsSessionBusy reports whether the session has an active request.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any session has an active request.
	IsBusy() bool
	// QueuedPrompts returns the number of prompts queued for the session.
	QueuedPrompts(sessionID string) int
	// QueuedPromptsList returns the text of the prompts queued for the
	// session, in queue order.
	QueuedPromptsList(sessionID string) []string
	// ClearQueue drops all prompts queued for the session.
	ClearQueue(sessionID string)
	// Summarize asks the large model to summarize the session and records the
	// resulting message as the session's summary.
	Summarize(context.Context, string, fantasy.ProviderOptions) error
	// Model returns the large model currently in use.
	Model() Model
}
  75
// Model bundles a language model with the configuration this package reads
// alongside it.
//
// Model is the fantasy language model used for generation. CatwalkCfg is the
// catwalk model description; this file reads its context window, pricing,
// name and capability flags. ModelCfg is the selected-model configuration;
// this file reads its Model and Provider identifiers.
type Model struct {
	Model      fantasy.LanguageModel
	CatwalkCfg catwalk.Model
	ModelCfg   config.SelectedModel
}
  81
// sessionAgent is the SessionAgent implementation returned by
// NewSessionAgent.
//
// largeModel drives Run and Summarize; smallModel is used only for title
// generation. systemPrompt is the main system prompt, and systemPromptPrefix
// is an optional extra system message prepended to each request. isSubAgent
// suppresses the todo-list reminder that preparePrompt otherwise injects.
// disableAutoSummarize turns off the context-window check that triggers
// Summarize from Run. isYolo is stored from the options but not read in the
// portion of the file visible here.
//
// messageQueue holds prompts submitted while a session was busy, keyed by
// session ID. activeRequests maps a session ID to the cancel function of its
// in-flight request; presence of a key is what marks a session as busy.
type sessionAgent struct {
	largeModel           Model
	smallModel           Model
	systemPromptPrefix   string
	systemPrompt         string
	isSubAgent           bool
	tools                []fantasy.AgentTool
	sessions             session.Service
	messages             message.Service
	disableAutoSummarize bool
	isYolo               bool

	messageQueue   *csync.Map[string, []SessionAgentCall]
	activeRequests *csync.Map[string, context.CancelFunc]
}
  97
// SessionAgentOptions carries the dependencies and settings NewSessionAgent
// copies into a new agent. Each field maps one-to-one onto the sessionAgent
// field of the same (lower-cased) name.
type SessionAgentOptions struct {
	LargeModel           Model
	SmallModel           Model
	SystemPromptPrefix   string
	SystemPrompt         string
	IsSubAgent           bool
	DisableAutoSummarize bool
	IsYolo               bool
	Sessions             session.Service
	Messages             message.Service
	Tools                []fantasy.AgentTool
}
 110
 111func NewSessionAgent(
 112	opts SessionAgentOptions,
 113) SessionAgent {
 114	return &sessionAgent{
 115		largeModel:           opts.LargeModel,
 116		smallModel:           opts.SmallModel,
 117		systemPromptPrefix:   opts.SystemPromptPrefix,
 118		systemPrompt:         opts.SystemPrompt,
 119		isSubAgent:           opts.IsSubAgent,
 120		sessions:             opts.Sessions,
 121		messages:             opts.Messages,
 122		disableAutoSummarize: opts.DisableAutoSummarize,
 123		tools:                opts.Tools,
 124		isYolo:               opts.IsYolo,
 125		messageQueue:         csync.NewMap[string, []SessionAgentCall](),
 126		activeRequests:       csync.NewMap[string, context.CancelFunc](),
 127	}
 128}
 129
 130func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
 131	if call.Prompt == "" {
 132		return nil, ErrEmptyPrompt
 133	}
 134	if call.SessionID == "" {
 135		return nil, ErrSessionMissing
 136	}
 137
 138	// Queue the message if busy
 139	if a.IsSessionBusy(call.SessionID) {
 140		existing, ok := a.messageQueue.Get(call.SessionID)
 141		if !ok {
 142			existing = []SessionAgentCall{}
 143		}
 144		existing = append(existing, call)
 145		a.messageQueue.Set(call.SessionID, existing)
 146		return nil, nil
 147	}
 148
 149	if len(a.tools) > 0 {
 150		// Add Anthropic caching to the last tool.
 151		a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
 152	}
 153
 154	agent := fantasy.NewAgent(
 155		a.largeModel.Model,
 156		fantasy.WithSystemPrompt(a.systemPrompt),
 157		fantasy.WithTools(a.tools...),
 158	)
 159
 160	sessionLock := sync.Mutex{}
 161	currentSession, err := a.sessions.Get(ctx, call.SessionID)
 162	if err != nil {
 163		return nil, fmt.Errorf("failed to get session: %w", err)
 164	}
 165
 166	msgs, err := a.getSessionMessages(ctx, currentSession)
 167	if err != nil {
 168		return nil, fmt.Errorf("failed to get session messages: %w", err)
 169	}
 170
 171	var wg sync.WaitGroup
 172	// Generate title if first message.
 173	if len(msgs) == 0 {
 174		wg.Go(func() {
 175			sessionLock.Lock()
 176			a.generateTitle(ctx, &currentSession, call.Prompt)
 177			sessionLock.Unlock()
 178		})
 179	}
 180
 181	// Add the user message to the session.
 182	_, err = a.createUserMessage(ctx, call)
 183	if err != nil {
 184		return nil, err
 185	}
 186
 187	// Add the session to the context.
 188	ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)
 189
 190	genCtx, cancel := context.WithCancel(ctx)
 191	a.activeRequests.Set(call.SessionID, cancel)
 192
 193	defer cancel()
 194	defer a.activeRequests.Del(call.SessionID)
 195
 196	history, files := a.preparePrompt(msgs, call.Attachments...)
 197
 198	startTime := time.Now()
 199	a.eventPromptSent(call.SessionID)
 200
 201	var currentAssistant *message.Message
 202	var shouldSummarize bool
 203	result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
 204		Prompt:           call.Prompt,
 205		Files:            files,
 206		Messages:         history,
 207		ProviderOptions:  call.ProviderOptions,
 208		MaxOutputTokens:  &call.MaxOutputTokens,
 209		TopP:             call.TopP,
 210		Temperature:      call.Temperature,
 211		PresencePenalty:  call.PresencePenalty,
 212		TopK:             call.TopK,
 213		FrequencyPenalty: call.FrequencyPenalty,
 214		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
 215			prepared.Messages = options.Messages
 216			for i := range prepared.Messages {
 217				prepared.Messages[i].ProviderOptions = nil
 218			}
 219
 220			queuedCalls, _ := a.messageQueue.Get(call.SessionID)
 221			a.messageQueue.Del(call.SessionID)
 222			for _, queued := range queuedCalls {
 223				userMessage, createErr := a.createUserMessage(callContext, queued)
 224				if createErr != nil {
 225					return callContext, prepared, createErr
 226				}
 227				prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
 228			}
 229
 230			prepared.Messages = a.workaroundProviderMediaLimitations(prepared.Messages)
 231
 232			lastSystemRoleInx := 0
 233			systemMessageUpdated := false
 234			for i, msg := range prepared.Messages {
 235				// Only add cache control to the last message.
 236				if msg.Role == fantasy.MessageRoleSystem {
 237					lastSystemRoleInx = i
 238				} else if !systemMessageUpdated {
 239					prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
 240					systemMessageUpdated = true
 241				}
 242				// Than add cache control to the last 2 messages.
 243				if i > len(prepared.Messages)-3 {
 244					prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
 245				}
 246			}
 247
 248			if promptPrefix := a.promptPrefix(); promptPrefix != "" {
 249				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(promptPrefix)}, prepared.Messages...)
 250			}
 251
 252			var assistantMsg message.Message
 253			assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
 254				Role:     message.Assistant,
 255				Parts:    []message.ContentPart{},
 256				Model:    a.largeModel.ModelCfg.Model,
 257				Provider: a.largeModel.ModelCfg.Provider,
 258			})
 259			if err != nil {
 260				return callContext, prepared, err
 261			}
 262			callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
 263			callContext = context.WithValue(callContext, tools.SupportsImagesContextKey, a.largeModel.CatwalkCfg.SupportsImages)
 264			callContext = context.WithValue(callContext, tools.ModelNameContextKey, a.largeModel.CatwalkCfg.Name)
 265			currentAssistant = &assistantMsg
 266			return callContext, prepared, err
 267		},
 268		OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
 269			currentAssistant.AppendReasoningContent(reasoning.Text)
 270			return a.messages.Update(genCtx, *currentAssistant)
 271		},
 272		OnReasoningDelta: func(id string, text string) error {
 273			currentAssistant.AppendReasoningContent(text)
 274			return a.messages.Update(genCtx, *currentAssistant)
 275		},
 276		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
 277			// handle anthropic signature
 278			if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
 279				if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
 280					currentAssistant.AppendReasoningSignature(reasoning.Signature)
 281				}
 282			}
 283			if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
 284				if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
 285					currentAssistant.AppendThoughtSignature(reasoning.Signature, reasoning.ToolID)
 286				}
 287			}
 288			if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
 289				if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
 290					currentAssistant.SetReasoningResponsesData(reasoning)
 291				}
 292			}
 293			currentAssistant.FinishThinking()
 294			return a.messages.Update(genCtx, *currentAssistant)
 295		},
 296		OnTextDelta: func(id string, text string) error {
 297			// Strip leading newline from initial text content. This is is
 298			// particularly important in non-interactive mode where leading
 299			// newlines are very visible.
 300			if len(currentAssistant.Parts) == 0 {
 301				text = strings.TrimPrefix(text, "\n")
 302			}
 303
 304			currentAssistant.AppendContent(text)
 305			return a.messages.Update(genCtx, *currentAssistant)
 306		},
 307		OnToolInputStart: func(id string, toolName string) error {
 308			toolCall := message.ToolCall{
 309				ID:               id,
 310				Name:             toolName,
 311				ProviderExecuted: false,
 312				Finished:         false,
 313			}
 314			currentAssistant.AddToolCall(toolCall)
 315			return a.messages.Update(genCtx, *currentAssistant)
 316		},
 317		OnRetry: func(err *fantasy.ProviderError, delay time.Duration) {
 318			// TODO: implement
 319		},
 320		OnToolCall: func(tc fantasy.ToolCallContent) error {
 321			toolCall := message.ToolCall{
 322				ID:               tc.ToolCallID,
 323				Name:             tc.ToolName,
 324				Input:            tc.Input,
 325				ProviderExecuted: false,
 326				Finished:         true,
 327			}
 328			currentAssistant.AddToolCall(toolCall)
 329			return a.messages.Update(genCtx, *currentAssistant)
 330		},
 331		OnToolResult: func(result fantasy.ToolResultContent) error {
 332			toolResult := a.convertToToolResult(result)
 333			_, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
 334				Role: message.Tool,
 335				Parts: []message.ContentPart{
 336					toolResult,
 337				},
 338			})
 339			return createMsgErr
 340		},
 341		OnStepFinish: func(stepResult fantasy.StepResult) error {
 342			finishReason := message.FinishReasonUnknown
 343			switch stepResult.FinishReason {
 344			case fantasy.FinishReasonLength:
 345				finishReason = message.FinishReasonMaxTokens
 346			case fantasy.FinishReasonStop:
 347				finishReason = message.FinishReasonEndTurn
 348			case fantasy.FinishReasonToolCalls:
 349				finishReason = message.FinishReasonToolUse
 350			}
 351			currentAssistant.AddFinish(finishReason, "", "")
 352			sessionLock.Lock()
 353			updatedSession, getSessionErr := a.sessions.Get(genCtx, call.SessionID)
 354			if getSessionErr != nil {
 355				sessionLock.Unlock()
 356				return getSessionErr
 357			}
 358			a.updateSessionUsage(a.largeModel, &updatedSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
 359			_, sessionErr := a.sessions.Save(genCtx, updatedSession)
 360			sessionLock.Unlock()
 361			if sessionErr != nil {
 362				return sessionErr
 363			}
 364			return a.messages.Update(genCtx, *currentAssistant)
 365		},
 366		StopWhen: []fantasy.StopCondition{
 367			func(_ []fantasy.StepResult) bool {
 368				cw := int64(a.largeModel.CatwalkCfg.ContextWindow)
 369				tokens := currentSession.CompletionTokens + currentSession.PromptTokens
 370				remaining := cw - tokens
 371				var threshold int64
 372				if cw > 200_000 {
 373					threshold = 20_000
 374				} else {
 375					threshold = int64(float64(cw) * 0.2)
 376				}
 377				if (remaining <= threshold) && !a.disableAutoSummarize {
 378					shouldSummarize = true
 379					return true
 380				}
 381				return false
 382			},
 383		},
 384	})
 385
 386	a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))
 387
 388	if err != nil {
 389		isCancelErr := errors.Is(err, context.Canceled)
 390		isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
 391		if currentAssistant == nil {
 392			return result, err
 393		}
 394		// Ensure we finish thinking on error to close the reasoning state.
 395		currentAssistant.FinishThinking()
 396		toolCalls := currentAssistant.ToolCalls()
 397		// INFO: we use the parent context here because the genCtx has been cancelled.
 398		msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
 399		if createErr != nil {
 400			return nil, createErr
 401		}
 402		for _, tc := range toolCalls {
 403			if !tc.Finished {
 404				tc.Finished = true
 405				tc.Input = "{}"
 406				currentAssistant.AddToolCall(tc)
 407				updateErr := a.messages.Update(ctx, *currentAssistant)
 408				if updateErr != nil {
 409					return nil, updateErr
 410				}
 411			}
 412
 413			found := false
 414			for _, msg := range msgs {
 415				if msg.Role == message.Tool {
 416					for _, tr := range msg.ToolResults() {
 417						if tr.ToolCallID == tc.ID {
 418							found = true
 419							break
 420						}
 421					}
 422				}
 423				if found {
 424					break
 425				}
 426			}
 427			if found {
 428				continue
 429			}
 430			content := "There was an error while executing the tool"
 431			if isCancelErr {
 432				content = "Tool execution canceled by user"
 433			} else if isPermissionErr {
 434				content = "User denied permission"
 435			}
 436			toolResult := message.ToolResult{
 437				ToolCallID: tc.ID,
 438				Name:       tc.Name,
 439				Content:    content,
 440				IsError:    true,
 441			}
 442			_, createErr = a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
 443				Role: message.Tool,
 444				Parts: []message.ContentPart{
 445					toolResult,
 446				},
 447			})
 448			if createErr != nil {
 449				return nil, createErr
 450			}
 451		}
 452		var fantasyErr *fantasy.Error
 453		var providerErr *fantasy.ProviderError
 454		const defaultTitle = "Provider Error"
 455		if isCancelErr {
 456			currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
 457		} else if isPermissionErr {
 458			currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "")
 459		} else if errors.Is(err, hyper.ErrNoCredits) {
 460			url := hyper.BaseURL()
 461			link := lipgloss.NewStyle().Hyperlink(url, "id=hyper").Render(url)
 462			currentAssistant.AddFinish(message.FinishReasonError, "No credits", "You're out of credits. Add more at "+link)
 463		} else if errors.As(err, &providerErr) {
 464			currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message)
 465		} else if errors.As(err, &fantasyErr) {
 466			currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(fantasyErr.Title), defaultTitle), fantasyErr.Message)
 467		} else {
 468			currentAssistant.AddFinish(message.FinishReasonError, defaultTitle, err.Error())
 469		}
 470		// Note: we use the parent context here because the genCtx has been
 471		// cancelled.
 472		updateErr := a.messages.Update(ctx, *currentAssistant)
 473		if updateErr != nil {
 474			return nil, updateErr
 475		}
 476		return nil, err
 477	}
 478	wg.Wait()
 479
 480	if shouldSummarize {
 481		a.activeRequests.Del(call.SessionID)
 482		if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
 483			return nil, summarizeErr
 484		}
 485		// If the agent wasn't done...
 486		if len(currentAssistant.ToolCalls()) > 0 {
 487			existing, ok := a.messageQueue.Get(call.SessionID)
 488			if !ok {
 489				existing = []SessionAgentCall{}
 490			}
 491			call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
 492			existing = append(existing, call)
 493			a.messageQueue.Set(call.SessionID, existing)
 494		}
 495	}
 496
 497	// Release active request before processing queued messages.
 498	a.activeRequests.Del(call.SessionID)
 499	cancel()
 500
 501	queuedMessages, ok := a.messageQueue.Get(call.SessionID)
 502	if !ok || len(queuedMessages) == 0 {
 503		return result, err
 504	}
 505	// There are queued messages restart the loop.
 506	firstQueuedMessage := queuedMessages[0]
 507	a.messageQueue.Set(call.SessionID, queuedMessages[1:])
 508	return a.Run(ctx, firstQueuedMessage)
 509}
 510
 511func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
 512	if a.IsSessionBusy(sessionID) {
 513		return ErrSessionBusy
 514	}
 515
 516	currentSession, err := a.sessions.Get(ctx, sessionID)
 517	if err != nil {
 518		return fmt.Errorf("failed to get session: %w", err)
 519	}
 520	msgs, err := a.getSessionMessages(ctx, currentSession)
 521	if err != nil {
 522		return err
 523	}
 524	if len(msgs) == 0 {
 525		// Nothing to summarize.
 526		return nil
 527	}
 528
 529	aiMsgs, _ := a.preparePrompt(msgs)
 530
 531	genCtx, cancel := context.WithCancel(ctx)
 532	a.activeRequests.Set(sessionID, cancel)
 533	defer a.activeRequests.Del(sessionID)
 534	defer cancel()
 535
 536	agent := fantasy.NewAgent(a.largeModel.Model,
 537		fantasy.WithSystemPrompt(string(summaryPrompt)),
 538	)
 539	summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
 540		Role:             message.Assistant,
 541		Model:            a.largeModel.Model.Model(),
 542		Provider:         a.largeModel.Model.Provider(),
 543		IsSummaryMessage: true,
 544	})
 545	if err != nil {
 546		return err
 547	}
 548
 549	summaryPromptText := "Provide a detailed summary of our conversation above."
 550	if len(currentSession.Todos) > 0 {
 551		summaryPromptText += "\n\n## Current Todo List\n\n"
 552		for _, t := range currentSession.Todos {
 553			summaryPromptText += fmt.Sprintf("- [%s] %s\n", t.Status, t.Content)
 554		}
 555		summaryPromptText += "\nInclude these tasks and their statuses in your summary. "
 556		summaryPromptText += "Instruct the resuming assistant to use the `todos` tool to continue tracking progress on these tasks."
 557	}
 558
 559	resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
 560		Prompt:          summaryPromptText,
 561		Messages:        aiMsgs,
 562		ProviderOptions: opts,
 563		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
 564			prepared.Messages = options.Messages
 565			if a.systemPromptPrefix != "" {
 566				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
 567			}
 568			return callContext, prepared, nil
 569		},
 570		OnReasoningDelta: func(id string, text string) error {
 571			summaryMessage.AppendReasoningContent(text)
 572			return a.messages.Update(genCtx, summaryMessage)
 573		},
 574		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
 575			// Handle anthropic signature.
 576			if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
 577				if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
 578					summaryMessage.AppendReasoningSignature(signature.Signature)
 579				}
 580			}
 581			summaryMessage.FinishThinking()
 582			return a.messages.Update(genCtx, summaryMessage)
 583		},
 584		OnTextDelta: func(id, text string) error {
 585			summaryMessage.AppendContent(text)
 586			return a.messages.Update(genCtx, summaryMessage)
 587		},
 588	})
 589	if err != nil {
 590		isCancelErr := errors.Is(err, context.Canceled)
 591		if isCancelErr {
 592			// User cancelled summarize we need to remove the summary message.
 593			deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
 594			return deleteErr
 595		}
 596		return err
 597	}
 598
 599	summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
 600	err = a.messages.Update(genCtx, summaryMessage)
 601	if err != nil {
 602		return err
 603	}
 604
 605	var openrouterCost *float64
 606	for _, step := range resp.Steps {
 607		stepCost := a.openrouterCost(step.ProviderMetadata)
 608		if stepCost != nil {
 609			newCost := *stepCost
 610			if openrouterCost != nil {
 611				newCost += *openrouterCost
 612			}
 613			openrouterCost = &newCost
 614		}
 615	}
 616
 617	a.updateSessionUsage(a.largeModel, &currentSession, resp.TotalUsage, openrouterCost)
 618
 619	// Just in case, get just the last usage info.
 620	usage := resp.Response.Usage
 621	currentSession.SummaryMessageID = summaryMessage.ID
 622	currentSession.CompletionTokens = usage.OutputTokens
 623	currentSession.PromptTokens = 0
 624	_, err = a.sessions.Save(genCtx, currentSession)
 625	return err
 626}
 627
 628func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
 629	if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
 630		return fantasy.ProviderOptions{}
 631	}
 632	return fantasy.ProviderOptions{
 633		anthropic.Name: &anthropic.ProviderCacheControlOptions{
 634			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
 635		},
 636		bedrock.Name: &anthropic.ProviderCacheControlOptions{
 637			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
 638		},
 639	}
 640}
 641
 642func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
 643	var attachmentParts []message.ContentPart
 644	for _, attachment := range call.Attachments {
 645		attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
 646	}
 647	parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
 648	parts = append(parts, attachmentParts...)
 649	msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
 650		Role:  message.User,
 651		Parts: parts,
 652	})
 653	if err != nil {
 654		return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
 655	}
 656	return msg, nil
 657}
 658
 659func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
 660	var history []fantasy.Message
 661	if !a.isSubAgent {
 662		history = append(history, fantasy.NewUserMessage(
 663			fmt.Sprintf("<system_reminder>%s</system_reminder>",
 664				`This is a reminder that your todo list is currently empty. DO NOT mention this to the user explicitly because they are already aware.
 665If you are working on tasks that would benefit from a todo list please use the "todos" tool to create one.
 666If not, please feel free to ignore. Again do not mention this message to the user.`,
 667			),
 668		))
 669	}
 670	for _, m := range msgs {
 671		if len(m.Parts) == 0 {
 672			continue
 673		}
 674		// Assistant message without content or tool calls (cancelled before it
 675		// returned anything).
 676		if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
 677			continue
 678		}
 679		history = append(history, m.ToAIMessage()...)
 680	}
 681
 682	var files []fantasy.FilePart
 683	for _, attachment := range attachments {
 684		files = append(files, fantasy.FilePart{
 685			Filename:  attachment.FileName,
 686			Data:      attachment.Content,
 687			MediaType: attachment.MimeType,
 688		})
 689	}
 690
 691	return history, files
 692}
 693
 694func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
 695	msgs, err := a.messages.List(ctx, session.ID)
 696	if err != nil {
 697		return nil, fmt.Errorf("failed to list messages: %w", err)
 698	}
 699
 700	if session.SummaryMessageID != "" {
 701		summaryMsgInex := -1
 702		for i, msg := range msgs {
 703			if msg.ID == session.SummaryMessageID {
 704				summaryMsgInex = i
 705				break
 706			}
 707		}
 708		if summaryMsgInex != -1 {
 709			msgs = msgs[summaryMsgInex:]
 710			msgs[0].Role = message.User
 711		}
 712	}
 713	return msgs, nil
 714}
 715
 716func (a *sessionAgent) generateTitle(ctx context.Context, session *session.Session, prompt string) {
 717	if prompt == "" {
 718		return
 719	}
 720
 721	var maxOutput int64 = 40
 722	if a.smallModel.CatwalkCfg.CanReason {
 723		maxOutput = a.smallModel.CatwalkCfg.DefaultMaxTokens
 724	}
 725
 726	agent := fantasy.NewAgent(a.smallModel.Model,
 727		fantasy.WithSystemPrompt(string(titlePrompt)+"\n /no_think"),
 728		fantasy.WithMaxOutputTokens(maxOutput),
 729	)
 730
 731	resp, err := agent.Stream(ctx, fantasy.AgentStreamCall{
 732		Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", prompt),
 733		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
 734			prepared.Messages = options.Messages
 735			if a.systemPromptPrefix != "" {
 736				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
 737			}
 738			return callContext, prepared, nil
 739		},
 740	})
 741	if err != nil {
 742		slog.Error("error generating title", "err", err)
 743		return
 744	}
 745
 746	title := resp.Response.Content.Text()
 747
 748	title = strings.ReplaceAll(title, "\n", " ")
 749
 750	// Remove thinking tags if present.
 751	if idx := strings.Index(title, "</think>"); idx > 0 {
 752		title = title[idx+len("</think>"):]
 753	}
 754
 755	title = strings.TrimSpace(title)
 756	if title == "" {
 757		slog.Warn("failed to generate title", "warn", "empty title")
 758		return
 759	}
 760
 761	session.Title = title
 762
 763	var openrouterCost *float64
 764	for _, step := range resp.Steps {
 765		stepCost := a.openrouterCost(step.ProviderMetadata)
 766		if stepCost != nil {
 767			newCost := *stepCost
 768			if openrouterCost != nil {
 769				newCost += *openrouterCost
 770			}
 771			openrouterCost = &newCost
 772		}
 773	}
 774
 775	a.updateSessionUsage(a.smallModel, session, resp.TotalUsage, openrouterCost)
 776	_, saveErr := a.sessions.Save(ctx, *session)
 777	if saveErr != nil {
 778		slog.Error("failed to save session title & usage", "error", saveErr)
 779		return
 780	}
 781}
 782
 783func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
 784	openrouterMetadata, ok := metadata[openrouter.Name]
 785	if !ok {
 786		return nil
 787	}
 788
 789	opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
 790	if !ok {
 791		return nil
 792	}
 793	return &opts.Usage.Cost
 794}
 795
 796func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
 797	modelConfig := model.CatwalkCfg
 798	cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
 799		modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
 800		modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
 801		modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
 802
 803	if a.isClaudeCode() {
 804		cost = 0
 805	}
 806
 807	a.eventTokensUsed(session.ID, model, usage, cost)
 808
 809	if overrideCost != nil {
 810		session.Cost += *overrideCost
 811	} else {
 812		session.Cost += cost
 813	}
 814
 815	session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
 816	session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
 817}
 818
 819func (a *sessionAgent) Cancel(sessionID string) {
 820	// Cancel regular requests.
 821	if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
 822		slog.Info("Request cancellation initiated", "session_id", sessionID)
 823		cancel()
 824	}
 825
 826	// Also check for summarize requests.
 827	if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
 828		slog.Info("Summarize cancellation initiated", "session_id", sessionID)
 829		cancel()
 830	}
 831
 832	if a.QueuedPrompts(sessionID) > 0 {
 833		slog.Info("Clearing queued prompts", "session_id", sessionID)
 834		a.messageQueue.Del(sessionID)
 835	}
 836}
 837
 838func (a *sessionAgent) ClearQueue(sessionID string) {
 839	if a.QueuedPrompts(sessionID) > 0 {
 840		slog.Info("Clearing queued prompts", "session_id", sessionID)
 841		a.messageQueue.Del(sessionID)
 842	}
 843}
 844
 845func (a *sessionAgent) CancelAll() {
 846	if !a.IsBusy() {
 847		return
 848	}
 849	for key := range a.activeRequests.Seq2() {
 850		a.Cancel(key) // key is sessionID
 851	}
 852
 853	timeout := time.After(5 * time.Second)
 854	for a.IsBusy() {
 855		select {
 856		case <-timeout:
 857			return
 858		default:
 859			time.Sleep(200 * time.Millisecond)
 860		}
 861	}
 862}
 863
 864func (a *sessionAgent) IsBusy() bool {
 865	var busy bool
 866	for cancelFunc := range a.activeRequests.Seq() {
 867		if cancelFunc != nil {
 868			busy = true
 869			break
 870		}
 871	}
 872	return busy
 873}
 874
 875func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
 876	_, busy := a.activeRequests.Get(sessionID)
 877	return busy
 878}
 879
 880func (a *sessionAgent) QueuedPrompts(sessionID string) int {
 881	l, ok := a.messageQueue.Get(sessionID)
 882	if !ok {
 883		return 0
 884	}
 885	return len(l)
 886}
 887
 888func (a *sessionAgent) QueuedPromptsList(sessionID string) []string {
 889	l, ok := a.messageQueue.Get(sessionID)
 890	if !ok {
 891		return nil
 892	}
 893	prompts := make([]string, len(l))
 894	for i, call := range l {
 895		prompts[i] = call.Prompt
 896	}
 897	return prompts
 898}
 899
 900func (a *sessionAgent) SetModels(large Model, small Model) {
 901	a.largeModel = large
 902	a.smallModel = small
 903}
 904
 905func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
 906	a.tools = tools
 907}
 908
 909func (a *sessionAgent) Model() Model {
 910	return a.largeModel
 911}
 912
 913func (a *sessionAgent) promptPrefix() string {
 914	if a.isClaudeCode() {
 915		return "You are Claude Code, Anthropic's official CLI for Claude."
 916	}
 917	return a.systemPromptPrefix
 918}
 919
 920func (a *sessionAgent) isClaudeCode() bool {
 921	cfg := config.Get()
 922	pc, ok := cfg.Providers.Get(a.largeModel.ModelCfg.Provider)
 923	return ok && pc.ID == string(catwalk.InferenceProviderAnthropic) && pc.OAuthToken != nil
 924}
 925
 926// convertToToolResult converts a fantasy tool result to a message tool result.
 927func (a *sessionAgent) convertToToolResult(result fantasy.ToolResultContent) message.ToolResult {
 928	baseResult := message.ToolResult{
 929		ToolCallID: result.ToolCallID,
 930		Name:       result.ToolName,
 931		Metadata:   result.ClientMetadata,
 932	}
 933
 934	switch result.Result.GetType() {
 935	case fantasy.ToolResultContentTypeText:
 936		if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result); ok {
 937			baseResult.Content = r.Text
 938		}
 939	case fantasy.ToolResultContentTypeError:
 940		if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result); ok {
 941			baseResult.Content = r.Error.Error()
 942			baseResult.IsError = true
 943		}
 944	case fantasy.ToolResultContentTypeMedia:
 945		if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](result.Result); ok {
 946			content := r.Text
 947			if content == "" {
 948				content = fmt.Sprintf("Loaded %s content", r.MediaType)
 949			}
 950			baseResult.Content = content
 951			baseResult.Data = r.Data
 952			baseResult.MIMEType = r.MediaType
 953		}
 954	}
 955
 956	return baseResult
 957}
 958
 959// workaroundProviderMediaLimitations converts media content in tool results to
 960// user messages for providers that don't natively support images in tool results.
 961//
 962// Problem: OpenAI, Google, OpenRouter, and other OpenAI-compatible providers
 963// don't support sending images/media in tool result messages - they only accept
 964// text in tool results. However, they DO support images in user messages.
 965//
 966// If we send media in tool results to these providers, the API returns an error.
 967//
 968// Solution: For these providers, we:
 969//  1. Replace the media in the tool result with a text placeholder
 970//  2. Inject a user message immediately after with the image as a file attachment
 971//  3. This maintains the tool execution flow while working around API limitations
 972//
 973// Anthropic and Bedrock support images natively in tool results, so we skip
 974// this workaround for them.
 975//
 976// Example transformation:
 977//
 978//	BEFORE: [tool result: image data]
 979//	AFTER:  [tool result: "Image loaded - see attached"], [user: image attachment]
 980func (a *sessionAgent) workaroundProviderMediaLimitations(messages []fantasy.Message) []fantasy.Message {
 981	providerSupportsMedia := a.largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderAnthropic) ||
 982		a.largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderBedrock)
 983
 984	if providerSupportsMedia {
 985		return messages
 986	}
 987
 988	convertedMessages := make([]fantasy.Message, 0, len(messages))
 989
 990	for _, msg := range messages {
 991		if msg.Role != fantasy.MessageRoleTool {
 992			convertedMessages = append(convertedMessages, msg)
 993			continue
 994		}
 995
 996		textParts := make([]fantasy.MessagePart, 0, len(msg.Content))
 997		var mediaFiles []fantasy.FilePart
 998
 999		for _, part := range msg.Content {
1000			toolResult, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
1001			if !ok {
1002				textParts = append(textParts, part)
1003				continue
1004			}
1005
1006			if media, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](toolResult.Output); ok {
1007				decoded, err := base64.StdEncoding.DecodeString(media.Data)
1008				if err != nil {
1009					slog.Warn("failed to decode media data", "error", err)
1010					textParts = append(textParts, part)
1011					continue
1012				}
1013
1014				mediaFiles = append(mediaFiles, fantasy.FilePart{
1015					Data:      decoded,
1016					MediaType: media.MediaType,
1017					Filename:  fmt.Sprintf("tool-result-%s", toolResult.ToolCallID),
1018				})
1019
1020				textParts = append(textParts, fantasy.ToolResultPart{
1021					ToolCallID: toolResult.ToolCallID,
1022					Output: fantasy.ToolResultOutputContentText{
1023						Text: "[Image/media content loaded - see attached file]",
1024					},
1025					ProviderOptions: toolResult.ProviderOptions,
1026				})
1027			} else {
1028				textParts = append(textParts, part)
1029			}
1030		}
1031
1032		convertedMessages = append(convertedMessages, fantasy.Message{
1033			Role:    fantasy.MessageRoleTool,
1034			Content: textParts,
1035		})
1036
1037		if len(mediaFiles) > 0 {
1038			convertedMessages = append(convertedMessages, fantasy.NewUserMessage(
1039				"Here is the media content from the tool result:",
1040				mediaFiles...,
1041			))
1042		}
1043	}
1044
1045	return convertedMessages
1046}