agent.go

   1// Package agent is the core orchestration layer for Crush AI agents.
   2//
   3// It provides session-based AI agent functionality for managing
   4// conversations, tool execution, and message handling. It coordinates
   5// interactions between language models, messages, sessions, and tools while
   6// handling features like automatic summarization, queuing, and token
   7// management.
   8package agent
   9
  10import (
  11	"cmp"
  12	"context"
  13	_ "embed"
  14	"encoding/base64"
  15	"errors"
  16	"fmt"
  17	"log/slog"
  18	"os"
  19	"strconv"
  20	"strings"
  21	"sync"
  22	"time"
  23
  24	"charm.land/fantasy"
  25	"charm.land/fantasy/providers/anthropic"
  26	"charm.land/fantasy/providers/bedrock"
  27	"charm.land/fantasy/providers/google"
  28	"charm.land/fantasy/providers/openai"
  29	"charm.land/fantasy/providers/openrouter"
  30	"charm.land/lipgloss/v2"
  31	"github.com/charmbracelet/catwalk/pkg/catwalk"
  32	"github.com/charmbracelet/crush/internal/agent/hyper"
  33	"github.com/charmbracelet/crush/internal/agent/tools"
  34	"github.com/charmbracelet/crush/internal/config"
  35	"github.com/charmbracelet/crush/internal/csync"
  36	"github.com/charmbracelet/crush/internal/message"
  37	"github.com/charmbracelet/crush/internal/permission"
  38	"github.com/charmbracelet/crush/internal/session"
  39	"github.com/charmbracelet/crush/internal/stringext"
  40)
  41
// titlePrompt is the system prompt used by generateTitle, embedded from
// templates/title.md at build time.
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt is the system prompt used by Summarize, embedded from
// templates/summary.md at build time.
//
//go:embed templates/summary.md
var summaryPrompt []byte
  47
// SessionAgentCall describes a single prompt submitted to a SessionAgent.
// When the target session is busy, Run queues the whole value, so every
// field travels with the queued prompt.
type SessionAgentCall struct {
	// SessionID identifies the session; Run rejects an empty value with
	// ErrSessionMissing.
	SessionID string
	// Prompt is the user's text; Run rejects an empty value with
	// ErrEmptyPrompt.
	Prompt string
	// ProviderOptions are forwarded to the model on the stream call.
	ProviderOptions fantasy.ProviderOptions
	// Attachments are stored on the user message as binary parts and sent to
	// the model as files.
	Attachments []message.Attachment
	// The remaining generation settings are forwarded unchanged to
	// fantasy.AgentStreamCall. NOTE(review): nil pointers presumably mean
	// "use the provider default" — confirm in fantasy.
	MaxOutputTokens  int64
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
}
  60
// SessionAgent runs prompts against a language model on behalf of chat
// sessions. It queues prompts for busy sessions and supports cancellation
// and history summarization.
type SessionAgent interface {
	// Run submits a prompt for call.SessionID. In the sessionAgent
	// implementation, a call for a busy session is queued and Run returns
	// (nil, nil).
	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
	// SetModels replaces the large and small models.
	SetModels(large Model, small Model)
	// SetTools replaces the tools offered to the model.
	SetTools(tools []fantasy.AgentTool)
	// Cancel aborts the session's in-flight request, if any, and drops its
	// queued prompts.
	Cancel(sessionID string)
	// CancelAll cancels every in-flight request. The sessionAgent
	// implementation waits up to 5 seconds for them to finish.
	CancelAll()
	// IsSessionBusy reports whether sessionID has an in-flight request.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any session has an in-flight request.
	IsBusy() bool
	// QueuedPrompts returns how many prompts are waiting for sessionID.
	QueuedPrompts(sessionID string) int
	// QueuedPromptsList returns the text of the prompts waiting for
	// sessionID, in queue order.
	QueuedPromptsList(sessionID string) []string
	// ClearQueue drops all prompts waiting for sessionID.
	ClearQueue(sessionID string)
	// Summarize condenses the session's history into a summary message.
	Summarize(context.Context, string, fantasy.ProviderOptions) error
	// Model returns the large model.
	Model() Model
}
  75
// Model bundles a language model client with the metadata the agent needs
// to drive it.
type Model struct {
	// Model is the fantasy client used to stream completions.
	Model fantasy.LanguageModel
	// CatwalkCfg holds catalog metadata (context window, pricing, reasoning
	// and image support, display name). It is consulted for auto-summarize,
	// cost tracking and tool context values.
	CatwalkCfg catwalk.Model
	// ModelCfg is the selected model configuration. Its Model and Provider
	// fields are recorded on assistant messages, and Provider is used to look
	// up provider config.
	ModelCfg config.SelectedModel
}
  81
// sessionAgent is the SessionAgent implementation returned by
// NewSessionAgent.
type sessionAgent struct {
	// largeModel drives conversations and summaries; smallModel generates
	// session titles.
	largeModel Model
	smallModel Model
	// systemPromptPrefix, when non-empty, is sent as an extra leading system
	// message (Run substitutes the Claude Code preamble for Anthropic OAuth;
	// see promptPrefix).
	systemPromptPrefix string
	systemPrompt       string
	// isSubAgent suppresses the todo-list reminder added by preparePrompt.
	isSubAgent bool
	tools      []fantasy.AgentTool
	sessions   session.Service
	messages   message.Service
	// disableAutoSummarize turns off the context-window stop condition in Run.
	disableAutoSummarize bool
	// NOTE(review): isYolo is not read in the visible part of this file;
	// confirm it is consumed elsewhere.
	isYolo bool

	// messageQueue holds prompts submitted while a session was busy, keyed by
	// session ID.
	messageQueue *csync.Map[string, []SessionAgentCall]
	// activeRequests maps a session ID to the cancel func of its in-flight
	// request; an entry's presence is what makes the session busy.
	activeRequests *csync.Map[string, context.CancelFunc]
}
  97
// SessionAgentOptions configures NewSessionAgent. Each field is copied into
// the matching sessionAgent field.
type SessionAgentOptions struct {
	LargeModel           Model
	SmallModel           Model
	SystemPromptPrefix   string
	SystemPrompt         string
	IsSubAgent           bool
	DisableAutoSummarize bool
	IsYolo               bool
	Sessions             session.Service
	Messages             message.Service
	Tools                []fantasy.AgentTool
}
 110
 111func NewSessionAgent(
 112	opts SessionAgentOptions,
 113) SessionAgent {
 114	return &sessionAgent{
 115		largeModel:           opts.LargeModel,
 116		smallModel:           opts.SmallModel,
 117		systemPromptPrefix:   opts.SystemPromptPrefix,
 118		systemPrompt:         opts.SystemPrompt,
 119		isSubAgent:           opts.IsSubAgent,
 120		sessions:             opts.Sessions,
 121		messages:             opts.Messages,
 122		disableAutoSummarize: opts.DisableAutoSummarize,
 123		tools:                opts.Tools,
 124		isYolo:               opts.IsYolo,
 125		messageQueue:         csync.NewMap[string, []SessionAgentCall](),
 126		activeRequests:       csync.NewMap[string, context.CancelFunc](),
 127	}
 128}
 129
 130func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
 131	if call.Prompt == "" {
 132		return nil, ErrEmptyPrompt
 133	}
 134	if call.SessionID == "" {
 135		return nil, ErrSessionMissing
 136	}
 137
 138	// Queue the message if busy
 139	if a.IsSessionBusy(call.SessionID) {
 140		existing, ok := a.messageQueue.Get(call.SessionID)
 141		if !ok {
 142			existing = []SessionAgentCall{}
 143		}
 144		existing = append(existing, call)
 145		a.messageQueue.Set(call.SessionID, existing)
 146		return nil, nil
 147	}
 148
 149	if len(a.tools) > 0 {
 150		// Add Anthropic caching to the last tool.
 151		a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
 152	}
 153
 154	agent := fantasy.NewAgent(
 155		a.largeModel.Model,
 156		fantasy.WithSystemPrompt(a.systemPrompt),
 157		fantasy.WithTools(a.tools...),
 158	)
 159
 160	sessionLock := sync.Mutex{}
 161	currentSession, err := a.sessions.Get(ctx, call.SessionID)
 162	if err != nil {
 163		return nil, fmt.Errorf("failed to get session: %w", err)
 164	}
 165
 166	msgs, err := a.getSessionMessages(ctx, currentSession)
 167	if err != nil {
 168		return nil, fmt.Errorf("failed to get session messages: %w", err)
 169	}
 170
 171	var wg sync.WaitGroup
 172	// Generate title if first message.
 173	if len(msgs) == 0 {
 174		wg.Go(func() {
 175			sessionLock.Lock()
 176			a.generateTitle(ctx, &currentSession, call.Prompt)
 177			sessionLock.Unlock()
 178		})
 179	}
 180
 181	// Add the user message to the session.
 182	_, err = a.createUserMessage(ctx, call)
 183	if err != nil {
 184		return nil, err
 185	}
 186
 187	// Add the session to the context.
 188	ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)
 189
 190	genCtx, cancel := context.WithCancel(ctx)
 191	a.activeRequests.Set(call.SessionID, cancel)
 192
 193	defer cancel()
 194	defer a.activeRequests.Del(call.SessionID)
 195
 196	history, files := a.preparePrompt(msgs, call.Attachments...)
 197
 198	startTime := time.Now()
 199	a.eventPromptSent(call.SessionID)
 200
 201	var currentAssistant *message.Message
 202	var shouldSummarize bool
 203	result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
 204		Prompt:           call.Prompt,
 205		Files:            files,
 206		Messages:         history,
 207		ProviderOptions:  call.ProviderOptions,
 208		MaxOutputTokens:  &call.MaxOutputTokens,
 209		TopP:             call.TopP,
 210		Temperature:      call.Temperature,
 211		PresencePenalty:  call.PresencePenalty,
 212		TopK:             call.TopK,
 213		FrequencyPenalty: call.FrequencyPenalty,
 214		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
 215			prepared.Messages = options.Messages
 216			for i := range prepared.Messages {
 217				prepared.Messages[i].ProviderOptions = nil
 218			}
 219
 220			queuedCalls, _ := a.messageQueue.Get(call.SessionID)
 221			a.messageQueue.Del(call.SessionID)
 222			for _, queued := range queuedCalls {
 223				userMessage, createErr := a.createUserMessage(callContext, queued)
 224				if createErr != nil {
 225					return callContext, prepared, createErr
 226				}
 227				prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
 228			}
 229
 230			prepared.Messages = a.workaroundProviderMediaLimitations(prepared.Messages)
 231
 232			lastSystemRoleInx := 0
 233			systemMessageUpdated := false
 234			for i, msg := range prepared.Messages {
 235				// Only add cache control to the last message.
 236				if msg.Role == fantasy.MessageRoleSystem {
 237					lastSystemRoleInx = i
 238				} else if !systemMessageUpdated {
 239					prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
 240					systemMessageUpdated = true
 241				}
 242				// Than add cache control to the last 2 messages.
 243				if i > len(prepared.Messages)-3 {
 244					prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
 245				}
 246			}
 247
 248			if promptPrefix := a.promptPrefix(); promptPrefix != "" {
 249				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(promptPrefix)}, prepared.Messages...)
 250			}
 251
 252			var assistantMsg message.Message
 253			assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
 254				Role:     message.Assistant,
 255				Parts:    []message.ContentPart{},
 256				Model:    a.largeModel.ModelCfg.Model,
 257				Provider: a.largeModel.ModelCfg.Provider,
 258			})
 259			if err != nil {
 260				return callContext, prepared, err
 261			}
 262			callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
 263			callContext = context.WithValue(callContext, tools.SupportsImagesContextKey, a.largeModel.CatwalkCfg.SupportsImages)
 264			callContext = context.WithValue(callContext, tools.ModelNameContextKey, a.largeModel.CatwalkCfg.Name)
 265			currentAssistant = &assistantMsg
 266			return callContext, prepared, err
 267		},
 268		OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
 269			currentAssistant.AppendReasoningContent(reasoning.Text)
 270			return a.messages.Update(genCtx, *currentAssistant)
 271		},
 272		OnReasoningDelta: func(id string, text string) error {
 273			currentAssistant.AppendReasoningContent(text)
 274			return a.messages.Update(genCtx, *currentAssistant)
 275		},
 276		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
 277			// handle anthropic signature
 278			if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
 279				if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
 280					currentAssistant.AppendReasoningSignature(reasoning.Signature)
 281				}
 282			}
 283			if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
 284				if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
 285					currentAssistant.AppendThoughtSignature(reasoning.Signature, reasoning.ToolID)
 286				}
 287			}
 288			if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
 289				if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
 290					currentAssistant.SetReasoningResponsesData(reasoning)
 291				}
 292			}
 293			currentAssistant.FinishThinking()
 294			return a.messages.Update(genCtx, *currentAssistant)
 295		},
 296		OnTextDelta: func(id string, text string) error {
 297			// Strip leading newline from initial text content. This is is
 298			// particularly important in non-interactive mode where leading
 299			// newlines are very visible.
 300			if len(currentAssistant.Parts) == 0 {
 301				text = strings.TrimPrefix(text, "\n")
 302			}
 303
 304			currentAssistant.AppendContent(text)
 305			return a.messages.Update(genCtx, *currentAssistant)
 306		},
 307		OnToolInputStart: func(id string, toolName string) error {
 308			toolCall := message.ToolCall{
 309				ID:               id,
 310				Name:             toolName,
 311				ProviderExecuted: false,
 312				Finished:         false,
 313			}
 314			currentAssistant.AddToolCall(toolCall)
 315			return a.messages.Update(genCtx, *currentAssistant)
 316		},
 317		OnRetry: func(err *fantasy.ProviderError, delay time.Duration) {
 318			// TODO: implement
 319		},
 320		OnToolCall: func(tc fantasy.ToolCallContent) error {
 321			toolCall := message.ToolCall{
 322				ID:               tc.ToolCallID,
 323				Name:             tc.ToolName,
 324				Input:            tc.Input,
 325				ProviderExecuted: false,
 326				Finished:         true,
 327			}
 328			currentAssistant.AddToolCall(toolCall)
 329			return a.messages.Update(genCtx, *currentAssistant)
 330		},
 331		OnToolResult: func(result fantasy.ToolResultContent) error {
 332			toolResult := a.convertToToolResult(result)
 333			_, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
 334				Role: message.Tool,
 335				Parts: []message.ContentPart{
 336					toolResult,
 337				},
 338			})
 339			return createMsgErr
 340		},
 341		OnStepFinish: func(stepResult fantasy.StepResult) error {
 342			finishReason := message.FinishReasonUnknown
 343			switch stepResult.FinishReason {
 344			case fantasy.FinishReasonLength:
 345				finishReason = message.FinishReasonMaxTokens
 346			case fantasy.FinishReasonStop:
 347				finishReason = message.FinishReasonEndTurn
 348			case fantasy.FinishReasonToolCalls:
 349				finishReason = message.FinishReasonToolUse
 350			}
 351			currentAssistant.AddFinish(finishReason, "", "")
 352			sessionLock.Lock()
 353			updatedSession, getSessionErr := a.sessions.Get(genCtx, call.SessionID)
 354			if getSessionErr != nil {
 355				sessionLock.Unlock()
 356				return getSessionErr
 357			}
 358			a.updateSessionUsage(a.largeModel, &updatedSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
 359			_, sessionErr := a.sessions.Save(genCtx, updatedSession)
 360			sessionLock.Unlock()
 361			if sessionErr != nil {
 362				return sessionErr
 363			}
 364			return a.messages.Update(genCtx, *currentAssistant)
 365		},
 366		StopWhen: []fantasy.StopCondition{
 367			func(_ []fantasy.StepResult) bool {
 368				cw := int64(a.largeModel.CatwalkCfg.ContextWindow)
 369				tokens := currentSession.CompletionTokens + currentSession.PromptTokens
 370				remaining := cw - tokens
 371				var threshold int64
 372				if cw > 200_000 {
 373					threshold = 20_000
 374				} else {
 375					threshold = int64(float64(cw) * 0.2)
 376				}
 377				if (remaining <= threshold) && !a.disableAutoSummarize {
 378					shouldSummarize = true
 379					return true
 380				}
 381				return false
 382			},
 383		},
 384	})
 385
 386	a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))
 387
 388	if err != nil {
 389		isCancelErr := errors.Is(err, context.Canceled)
 390		isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
 391		if currentAssistant == nil {
 392			return result, err
 393		}
 394		// Ensure we finish thinking on error to close the reasoning state.
 395		currentAssistant.FinishThinking()
 396		toolCalls := currentAssistant.ToolCalls()
 397		// INFO: we use the parent context here because the genCtx has been cancelled.
 398		msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
 399		if createErr != nil {
 400			return nil, createErr
 401		}
 402		for _, tc := range toolCalls {
 403			if !tc.Finished {
 404				tc.Finished = true
 405				tc.Input = "{}"
 406				currentAssistant.AddToolCall(tc)
 407				updateErr := a.messages.Update(ctx, *currentAssistant)
 408				if updateErr != nil {
 409					return nil, updateErr
 410				}
 411			}
 412
 413			found := false
 414			for _, msg := range msgs {
 415				if msg.Role == message.Tool {
 416					for _, tr := range msg.ToolResults() {
 417						if tr.ToolCallID == tc.ID {
 418							found = true
 419							break
 420						}
 421					}
 422				}
 423				if found {
 424					break
 425				}
 426			}
 427			if found {
 428				continue
 429			}
 430			content := "There was an error while executing the tool"
 431			if isCancelErr {
 432				content = "Tool execution canceled by user"
 433			} else if isPermissionErr {
 434				content = "User denied permission"
 435			}
 436			toolResult := message.ToolResult{
 437				ToolCallID: tc.ID,
 438				Name:       tc.Name,
 439				Content:    content,
 440				IsError:    true,
 441			}
 442			_, createErr = a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
 443				Role: message.Tool,
 444				Parts: []message.ContentPart{
 445					toolResult,
 446				},
 447			})
 448			if createErr != nil {
 449				return nil, createErr
 450			}
 451		}
 452		var fantasyErr *fantasy.Error
 453		var providerErr *fantasy.ProviderError
 454		const defaultTitle = "Provider Error"
 455		if isCancelErr {
 456			currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
 457		} else if isPermissionErr {
 458			currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "")
 459		} else if errors.Is(err, hyper.ErrNoCredits) {
 460			url := hyper.BaseURL()
 461			link := lipgloss.NewStyle().Hyperlink(url, "id=hyper").Render(url)
 462			currentAssistant.AddFinish(message.FinishReasonError, "No credits", "You're out of credits. Add more at "+link)
 463		} else if errors.As(err, &providerErr) {
 464			if providerErr.Message == "The requested model is not supported." {
 465				url := "https://github.com/settings/copilot/features"
 466				link := lipgloss.NewStyle().Hyperlink(url, "id=hyper").Render(url)
 467				currentAssistant.AddFinish(
 468					message.FinishReasonError,
 469					"Copilot model not enabled",
 470					fmt.Sprintf("%q is not enabled in Copilot. Go to the following page to enable it. Then, wait a minute before trying again. %s", a.largeModel.CatwalkCfg.Name, link),
 471				)
 472			} else {
 473				currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message)
 474			}
 475		} else if errors.As(err, &fantasyErr) {
 476			currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(fantasyErr.Title), defaultTitle), fantasyErr.Message)
 477		} else {
 478			currentAssistant.AddFinish(message.FinishReasonError, defaultTitle, err.Error())
 479		}
 480		// Note: we use the parent context here because the genCtx has been
 481		// cancelled.
 482		updateErr := a.messages.Update(ctx, *currentAssistant)
 483		if updateErr != nil {
 484			return nil, updateErr
 485		}
 486		return nil, err
 487	}
 488	wg.Wait()
 489
 490	if shouldSummarize {
 491		a.activeRequests.Del(call.SessionID)
 492		if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
 493			return nil, summarizeErr
 494		}
 495		// If the agent wasn't done...
 496		if len(currentAssistant.ToolCalls()) > 0 {
 497			existing, ok := a.messageQueue.Get(call.SessionID)
 498			if !ok {
 499				existing = []SessionAgentCall{}
 500			}
 501			call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
 502			existing = append(existing, call)
 503			a.messageQueue.Set(call.SessionID, existing)
 504		}
 505	}
 506
 507	// Release active request before processing queued messages.
 508	a.activeRequests.Del(call.SessionID)
 509	cancel()
 510
 511	queuedMessages, ok := a.messageQueue.Get(call.SessionID)
 512	if !ok || len(queuedMessages) == 0 {
 513		return result, err
 514	}
 515	// There are queued messages restart the loop.
 516	firstQueuedMessage := queuedMessages[0]
 517	a.messageQueue.Set(call.SessionID, queuedMessages[1:])
 518	return a.Run(ctx, firstQueuedMessage)
 519}
 520
 521func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
 522	if a.IsSessionBusy(sessionID) {
 523		return ErrSessionBusy
 524	}
 525
 526	currentSession, err := a.sessions.Get(ctx, sessionID)
 527	if err != nil {
 528		return fmt.Errorf("failed to get session: %w", err)
 529	}
 530	msgs, err := a.getSessionMessages(ctx, currentSession)
 531	if err != nil {
 532		return err
 533	}
 534	if len(msgs) == 0 {
 535		// Nothing to summarize.
 536		return nil
 537	}
 538
 539	aiMsgs, _ := a.preparePrompt(msgs)
 540
 541	genCtx, cancel := context.WithCancel(ctx)
 542	a.activeRequests.Set(sessionID, cancel)
 543	defer a.activeRequests.Del(sessionID)
 544	defer cancel()
 545
 546	agent := fantasy.NewAgent(a.largeModel.Model,
 547		fantasy.WithSystemPrompt(string(summaryPrompt)),
 548	)
 549	summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
 550		Role:             message.Assistant,
 551		Model:            a.largeModel.Model.Model(),
 552		Provider:         a.largeModel.Model.Provider(),
 553		IsSummaryMessage: true,
 554	})
 555	if err != nil {
 556		return err
 557	}
 558
 559	summaryPromptText := "Provide a detailed summary of our conversation above."
 560	if len(currentSession.Todos) > 0 {
 561		summaryPromptText += "\n\n## Current Todo List\n\n"
 562		for _, t := range currentSession.Todos {
 563			summaryPromptText += fmt.Sprintf("- [%s] %s\n", t.Status, t.Content)
 564		}
 565		summaryPromptText += "\nInclude these tasks and their statuses in your summary. "
 566		summaryPromptText += "Instruct the resuming assistant to use the `todos` tool to continue tracking progress on these tasks."
 567	}
 568
 569	resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
 570		Prompt:          summaryPromptText,
 571		Messages:        aiMsgs,
 572		ProviderOptions: opts,
 573		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
 574			prepared.Messages = options.Messages
 575			if a.systemPromptPrefix != "" {
 576				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
 577			}
 578			return callContext, prepared, nil
 579		},
 580		OnReasoningDelta: func(id string, text string) error {
 581			summaryMessage.AppendReasoningContent(text)
 582			return a.messages.Update(genCtx, summaryMessage)
 583		},
 584		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
 585			// Handle anthropic signature.
 586			if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
 587				if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
 588					summaryMessage.AppendReasoningSignature(signature.Signature)
 589				}
 590			}
 591			summaryMessage.FinishThinking()
 592			return a.messages.Update(genCtx, summaryMessage)
 593		},
 594		OnTextDelta: func(id, text string) error {
 595			summaryMessage.AppendContent(text)
 596			return a.messages.Update(genCtx, summaryMessage)
 597		},
 598	})
 599	if err != nil {
 600		isCancelErr := errors.Is(err, context.Canceled)
 601		if isCancelErr {
 602			// User cancelled summarize we need to remove the summary message.
 603			deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
 604			return deleteErr
 605		}
 606		return err
 607	}
 608
 609	summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
 610	err = a.messages.Update(genCtx, summaryMessage)
 611	if err != nil {
 612		return err
 613	}
 614
 615	var openrouterCost *float64
 616	for _, step := range resp.Steps {
 617		stepCost := a.openrouterCost(step.ProviderMetadata)
 618		if stepCost != nil {
 619			newCost := *stepCost
 620			if openrouterCost != nil {
 621				newCost += *openrouterCost
 622			}
 623			openrouterCost = &newCost
 624		}
 625	}
 626
 627	a.updateSessionUsage(a.largeModel, &currentSession, resp.TotalUsage, openrouterCost)
 628
 629	// Just in case, get just the last usage info.
 630	usage := resp.Response.Usage
 631	currentSession.SummaryMessageID = summaryMessage.ID
 632	currentSession.CompletionTokens = usage.OutputTokens
 633	currentSession.PromptTokens = 0
 634	_, err = a.sessions.Save(genCtx, currentSession)
 635	return err
 636}
 637
 638func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
 639	if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
 640		return fantasy.ProviderOptions{}
 641	}
 642	return fantasy.ProviderOptions{
 643		anthropic.Name: &anthropic.ProviderCacheControlOptions{
 644			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
 645		},
 646		bedrock.Name: &anthropic.ProviderCacheControlOptions{
 647			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
 648		},
 649	}
 650}
 651
 652func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
 653	var attachmentParts []message.ContentPart
 654	for _, attachment := range call.Attachments {
 655		attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
 656	}
 657	parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
 658	parts = append(parts, attachmentParts...)
 659	msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
 660		Role:  message.User,
 661		Parts: parts,
 662	})
 663	if err != nil {
 664		return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
 665	}
 666	return msg, nil
 667}
 668
 669func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
 670	var history []fantasy.Message
 671	if !a.isSubAgent {
 672		history = append(history, fantasy.NewUserMessage(
 673			fmt.Sprintf("<system_reminder>%s</system_reminder>",
 674				`This is a reminder that your todo list is currently empty. DO NOT mention this to the user explicitly because they are already aware.
 675If you are working on tasks that would benefit from a todo list please use the "todos" tool to create one.
 676If not, please feel free to ignore. Again do not mention this message to the user.`,
 677			),
 678		))
 679	}
 680	for _, m := range msgs {
 681		if len(m.Parts) == 0 {
 682			continue
 683		}
 684		// Assistant message without content or tool calls (cancelled before it
 685		// returned anything).
 686		if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
 687			continue
 688		}
 689		history = append(history, m.ToAIMessage()...)
 690	}
 691
 692	var files []fantasy.FilePart
 693	for _, attachment := range attachments {
 694		files = append(files, fantasy.FilePart{
 695			Filename:  attachment.FileName,
 696			Data:      attachment.Content,
 697			MediaType: attachment.MimeType,
 698		})
 699	}
 700
 701	return history, files
 702}
 703
 704func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
 705	msgs, err := a.messages.List(ctx, session.ID)
 706	if err != nil {
 707		return nil, fmt.Errorf("failed to list messages: %w", err)
 708	}
 709
 710	if session.SummaryMessageID != "" {
 711		summaryMsgInex := -1
 712		for i, msg := range msgs {
 713			if msg.ID == session.SummaryMessageID {
 714				summaryMsgInex = i
 715				break
 716			}
 717		}
 718		if summaryMsgInex != -1 {
 719			msgs = msgs[summaryMsgInex:]
 720			msgs[0].Role = message.User
 721		}
 722	}
 723	return msgs, nil
 724}
 725
 726func (a *sessionAgent) generateTitle(ctx context.Context, session *session.Session, prompt string) {
 727	if prompt == "" {
 728		return
 729	}
 730
 731	var maxOutput int64 = 40
 732	if a.smallModel.CatwalkCfg.CanReason {
 733		maxOutput = a.smallModel.CatwalkCfg.DefaultMaxTokens
 734	}
 735
 736	agent := fantasy.NewAgent(a.smallModel.Model,
 737		fantasy.WithSystemPrompt(string(titlePrompt)+"\n /no_think"),
 738		fantasy.WithMaxOutputTokens(maxOutput),
 739	)
 740
 741	resp, err := agent.Stream(ctx, fantasy.AgentStreamCall{
 742		Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", prompt),
 743		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
 744			prepared.Messages = options.Messages
 745			if a.systemPromptPrefix != "" {
 746				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
 747			}
 748			return callContext, prepared, nil
 749		},
 750	})
 751	if err != nil {
 752		slog.Error("error generating title", "err", err)
 753		return
 754	}
 755
 756	title := resp.Response.Content.Text()
 757
 758	title = strings.ReplaceAll(title, "\n", " ")
 759
 760	// Remove thinking tags if present.
 761	if idx := strings.Index(title, "</think>"); idx > 0 {
 762		title = title[idx+len("</think>"):]
 763	}
 764
 765	title = strings.TrimSpace(title)
 766	if title == "" {
 767		slog.Warn("failed to generate title", "warn", "empty title")
 768		return
 769	}
 770
 771	session.Title = title
 772
 773	var openrouterCost *float64
 774	for _, step := range resp.Steps {
 775		stepCost := a.openrouterCost(step.ProviderMetadata)
 776		if stepCost != nil {
 777			newCost := *stepCost
 778			if openrouterCost != nil {
 779				newCost += *openrouterCost
 780			}
 781			openrouterCost = &newCost
 782		}
 783	}
 784
 785	a.updateSessionUsage(a.smallModel, session, resp.TotalUsage, openrouterCost)
 786	_, saveErr := a.sessions.Save(ctx, *session)
 787	if saveErr != nil {
 788		slog.Error("failed to save session title & usage", "error", saveErr)
 789		return
 790	}
 791}
 792
 793func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
 794	openrouterMetadata, ok := metadata[openrouter.Name]
 795	if !ok {
 796		return nil
 797	}
 798
 799	opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
 800	if !ok {
 801		return nil
 802	}
 803	return &opts.Usage.Cost
 804}
 805
 806func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
 807	modelConfig := model.CatwalkCfg
 808	cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
 809		modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
 810		modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
 811		modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
 812
 813	if a.isClaudeCode() {
 814		cost = 0
 815	}
 816
 817	a.eventTokensUsed(session.ID, model, usage, cost)
 818
 819	if overrideCost != nil {
 820		session.Cost += *overrideCost
 821	} else {
 822		session.Cost += cost
 823	}
 824
 825	session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
 826	session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
 827}
 828
 829func (a *sessionAgent) Cancel(sessionID string) {
 830	// Cancel regular requests.
 831	if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
 832		slog.Info("Request cancellation initiated", "session_id", sessionID)
 833		cancel()
 834	}
 835
 836	// Also check for summarize requests.
 837	if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
 838		slog.Info("Summarize cancellation initiated", "session_id", sessionID)
 839		cancel()
 840	}
 841
 842	if a.QueuedPrompts(sessionID) > 0 {
 843		slog.Info("Clearing queued prompts", "session_id", sessionID)
 844		a.messageQueue.Del(sessionID)
 845	}
 846}
 847
 848func (a *sessionAgent) ClearQueue(sessionID string) {
 849	if a.QueuedPrompts(sessionID) > 0 {
 850		slog.Info("Clearing queued prompts", "session_id", sessionID)
 851		a.messageQueue.Del(sessionID)
 852	}
 853}
 854
 855func (a *sessionAgent) CancelAll() {
 856	if !a.IsBusy() {
 857		return
 858	}
 859	for key := range a.activeRequests.Seq2() {
 860		a.Cancel(key) // key is sessionID
 861	}
 862
 863	timeout := time.After(5 * time.Second)
 864	for a.IsBusy() {
 865		select {
 866		case <-timeout:
 867			return
 868		default:
 869			time.Sleep(200 * time.Millisecond)
 870		}
 871	}
 872}
 873
 874func (a *sessionAgent) IsBusy() bool {
 875	var busy bool
 876	for cancelFunc := range a.activeRequests.Seq() {
 877		if cancelFunc != nil {
 878			busy = true
 879			break
 880		}
 881	}
 882	return busy
 883}
 884
 885func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
 886	_, busy := a.activeRequests.Get(sessionID)
 887	return busy
 888}
 889
 890func (a *sessionAgent) QueuedPrompts(sessionID string) int {
 891	l, ok := a.messageQueue.Get(sessionID)
 892	if !ok {
 893		return 0
 894	}
 895	return len(l)
 896}
 897
 898func (a *sessionAgent) QueuedPromptsList(sessionID string) []string {
 899	l, ok := a.messageQueue.Get(sessionID)
 900	if !ok {
 901		return nil
 902	}
 903	prompts := make([]string, len(l))
 904	for i, call := range l {
 905		prompts[i] = call.Prompt
 906	}
 907	return prompts
 908}
 909
 910func (a *sessionAgent) SetModels(large Model, small Model) {
 911	a.largeModel = large
 912	a.smallModel = small
 913}
 914
 915func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
 916	a.tools = tools
 917}
 918
 919func (a *sessionAgent) Model() Model {
 920	return a.largeModel
 921}
 922
 923func (a *sessionAgent) promptPrefix() string {
 924	if a.isClaudeCode() {
 925		return "You are Claude Code, Anthropic's official CLI for Claude."
 926	}
 927	return a.systemPromptPrefix
 928}
 929
 930func (a *sessionAgent) isClaudeCode() bool {
 931	cfg := config.Get()
 932	pc, ok := cfg.Providers.Get(a.largeModel.ModelCfg.Provider)
 933	return ok && pc.ID == string(catwalk.InferenceProviderAnthropic) && pc.OAuthToken != nil
 934}
 935
 936// convertToToolResult converts a fantasy tool result to a message tool result.
 937func (a *sessionAgent) convertToToolResult(result fantasy.ToolResultContent) message.ToolResult {
 938	baseResult := message.ToolResult{
 939		ToolCallID: result.ToolCallID,
 940		Name:       result.ToolName,
 941		Metadata:   result.ClientMetadata,
 942	}
 943
 944	switch result.Result.GetType() {
 945	case fantasy.ToolResultContentTypeText:
 946		if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result); ok {
 947			baseResult.Content = r.Text
 948		}
 949	case fantasy.ToolResultContentTypeError:
 950		if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result); ok {
 951			baseResult.Content = r.Error.Error()
 952			baseResult.IsError = true
 953		}
 954	case fantasy.ToolResultContentTypeMedia:
 955		if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](result.Result); ok {
 956			content := r.Text
 957			if content == "" {
 958				content = fmt.Sprintf("Loaded %s content", r.MediaType)
 959			}
 960			baseResult.Content = content
 961			baseResult.Data = r.Data
 962			baseResult.MIMEType = r.MediaType
 963		}
 964	}
 965
 966	return baseResult
 967}
 968
 969// workaroundProviderMediaLimitations converts media content in tool results to
 970// user messages for providers that don't natively support images in tool results.
 971//
 972// Problem: OpenAI, Google, OpenRouter, and other OpenAI-compatible providers
 973// don't support sending images/media in tool result messages - they only accept
 974// text in tool results. However, they DO support images in user messages.
 975//
 976// If we send media in tool results to these providers, the API returns an error.
 977//
 978// Solution: For these providers, we:
 979//  1. Replace the media in the tool result with a text placeholder
 980//  2. Inject a user message immediately after with the image as a file attachment
 981//  3. This maintains the tool execution flow while working around API limitations
 982//
 983// Anthropic and Bedrock support images natively in tool results, so we skip
 984// this workaround for them.
 985//
 986// Example transformation:
 987//
 988//	BEFORE: [tool result: image data]
 989//	AFTER:  [tool result: "Image loaded - see attached"], [user: image attachment]
 990func (a *sessionAgent) workaroundProviderMediaLimitations(messages []fantasy.Message) []fantasy.Message {
 991	providerSupportsMedia := a.largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderAnthropic) ||
 992		a.largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderBedrock)
 993
 994	if providerSupportsMedia {
 995		return messages
 996	}
 997
 998	convertedMessages := make([]fantasy.Message, 0, len(messages))
 999
1000	for _, msg := range messages {
1001		if msg.Role != fantasy.MessageRoleTool {
1002			convertedMessages = append(convertedMessages, msg)
1003			continue
1004		}
1005
1006		textParts := make([]fantasy.MessagePart, 0, len(msg.Content))
1007		var mediaFiles []fantasy.FilePart
1008
1009		for _, part := range msg.Content {
1010			toolResult, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
1011			if !ok {
1012				textParts = append(textParts, part)
1013				continue
1014			}
1015
1016			if media, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](toolResult.Output); ok {
1017				decoded, err := base64.StdEncoding.DecodeString(media.Data)
1018				if err != nil {
1019					slog.Warn("failed to decode media data", "error", err)
1020					textParts = append(textParts, part)
1021					continue
1022				}
1023
1024				mediaFiles = append(mediaFiles, fantasy.FilePart{
1025					Data:      decoded,
1026					MediaType: media.MediaType,
1027					Filename:  fmt.Sprintf("tool-result-%s", toolResult.ToolCallID),
1028				})
1029
1030				textParts = append(textParts, fantasy.ToolResultPart{
1031					ToolCallID: toolResult.ToolCallID,
1032					Output: fantasy.ToolResultOutputContentText{
1033						Text: "[Image/media content loaded - see attached file]",
1034					},
1035					ProviderOptions: toolResult.ProviderOptions,
1036				})
1037			} else {
1038				textParts = append(textParts, part)
1039			}
1040		}
1041
1042		convertedMessages = append(convertedMessages, fantasy.Message{
1043			Role:    fantasy.MessageRoleTool,
1044			Content: textParts,
1045		})
1046
1047		if len(mediaFiles) > 0 {
1048			convertedMessages = append(convertedMessages, fantasy.NewUserMessage(
1049				"Here is the media content from the tool result:",
1050				mediaFiles...,
1051			))
1052		}
1053	}
1054
1055	return convertedMessages
1056}