agent.go

   1// Package agent is the core orchestration layer for Crush AI agents.
   2//
   3// It provides session-based AI agent functionality for managing
   4// conversations, tool execution, and message handling. It coordinates
   5// interactions between language models, messages, sessions, and tools while
   6// handling features like automatic summarization, queuing, and token
   7// management.
   8package agent
   9
  10import (
  11	"cmp"
  12	"context"
  13	_ "embed"
  14	"encoding/base64"
  15	"errors"
  16	"fmt"
  17	"log/slog"
  18	"os"
  19	"strconv"
  20	"strings"
  21	"sync"
  22	"time"
  23
  24	"charm.land/fantasy"
  25	"charm.land/fantasy/providers/anthropic"
  26	"charm.land/fantasy/providers/bedrock"
  27	"charm.land/fantasy/providers/google"
  28	"charm.land/fantasy/providers/openai"
  29	"charm.land/fantasy/providers/openrouter"
  30	"charm.land/lipgloss/v2"
  31	"github.com/charmbracelet/catwalk/pkg/catwalk"
  32	"github.com/charmbracelet/crush/internal/agent/hyper"
  33	"github.com/charmbracelet/crush/internal/agent/tools"
  34	"github.com/charmbracelet/crush/internal/config"
  35	"github.com/charmbracelet/crush/internal/csync"
  36	"github.com/charmbracelet/crush/internal/message"
  37	"github.com/charmbracelet/crush/internal/permission"
  38	"github.com/charmbracelet/crush/internal/session"
  39	"github.com/charmbracelet/crush/internal/stringext"
  40)
  41
// titlePrompt holds the contents of templates/title.md, embedded at build
// time. generateTitle uses it as the system prompt when asking the small
// model for a session title.
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt holds the contents of templates/summary.md, embedded at build
// time. Summarize uses it as the system prompt for the summarization request.
//
//go:embed templates/summary.md
var summaryPrompt []byte
  47
// SessionAgentCall describes one prompt submitted to a SessionAgent together
// with the per-request generation parameters.
type SessionAgentCall struct {
	// SessionID identifies the session the prompt belongs to. Run rejects an
	// empty value with ErrSessionMissing.
	SessionID string
	// Prompt is the user's text. Run rejects an empty value with
	// ErrEmptyPrompt.
	Prompt string
	// ProviderOptions is forwarded unchanged to the model stream call (and to
	// Summarize when Run triggers an automatic summary).
	ProviderOptions fantasy.ProviderOptions
	// Attachments are stored as binary parts on the user message; non-text
	// attachments are additionally sent to the model as files.
	Attachments []message.Attachment
	// MaxOutputTokens is passed to the stream call as the output token limit.
	MaxOutputTokens int64
	// The sampling parameters below are optional and are forwarded to the
	// stream call as-is (nil pointers included).
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
}
  60
// SessionAgent is the session-scoped agent API: it runs prompts, tracks and
// cancels in-flight requests, manages per-session prompt queues, and
// summarizes conversations. The behavior notes below describe the
// sessionAgent implementation in this file.
type SessionAgent interface {
	// Run processes a prompt for the call's session. If that session already
	// has an active request the call is queued and Run returns (nil, nil).
	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
	// SetModels replaces the large and small models used by the agent.
	SetModels(large Model, small Model)
	// SetTools replaces the set of tools offered to the model.
	SetTools(tools []fantasy.AgentTool)
	// Cancel aborts the active request for the session, if any, and drops the
	// session's queued prompts.
	Cancel(sessionID string)
	// CancelAll cancels every active request and waits a bounded time for the
	// agent to become idle.
	CancelAll()
	// IsSessionBusy reports whether the session has an active request.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any session has an active request.
	IsBusy() bool
	// QueuedPrompts returns the number of prompts queued for the session.
	QueuedPrompts(sessionID string) int
	// QueuedPromptsList returns the prompt text of each queued call, in queue
	// order, or nil when the session has no queue entry.
	QueuedPromptsList(sessionID string) []string
	// ClearQueue drops all prompts queued for the session.
	ClearQueue(sessionID string)
	// Summarize generates a summary message for the session (string argument
	// is the session ID) and records it as the session's summary.
	Summarize(context.Context, string, fantasy.ProviderOptions) error
	// Model returns the large model currently in use.
	Model() Model
}
  75
// Model bundles a language model with the metadata and configuration the
// agent needs alongside it.
type Model struct {
	// Model is the language model handed to fantasy.NewAgent.
	Model fantasy.LanguageModel
	// CatwalkCfg is the catalog metadata for the model; this file reads its
	// name, context window, pricing, reasoning and image-support fields.
	CatwalkCfg catwalk.Model
	// ModelCfg is the selected-model configuration; its Model and Provider
	// values are recorded on assistant messages created by Run.
	ModelCfg config.SelectedModel
}
  81
// sessionAgent is the SessionAgent implementation returned by NewSessionAgent.
type sessionAgent struct {
	// largeModel drives Run and Summarize; smallModel drives generateTitle.
	largeModel Model
	smallModel Model
	// systemPromptPrefix, when non-empty, is prepended to requests as an extra
	// system message (see promptPrefix for the override applied in Run).
	systemPromptPrefix string
	// systemPrompt is the system prompt given to the agent created in Run.
	systemPrompt string
	// isSubAgent suppresses the todo-list reminder added by preparePrompt.
	isSubAgent bool
	// tools is the tool set offered to the model in Run.
	tools []fantasy.AgentTool
	// sessions and messages are the persistence services for sessions and
	// their messages.
	sessions session.Service
	messages message.Service
	// disableAutoSummarize turns off the context-window check that stops a
	// run early in order to summarize.
	disableAutoSummarize bool
	// isYolo is stored from the options; it is not read in the portion of the
	// file visible here. NOTE(review): confirm where it is consumed.
	isYolo bool

	// messageQueue holds, per session ID, the calls submitted while that
	// session was busy.
	messageQueue *csync.Map[string, []SessionAgentCall]
	// activeRequests holds, per session ID, the cancel function of the
	// request currently in flight; presence of a key marks the session busy.
	activeRequests *csync.Map[string, context.CancelFunc]
}
  97
// SessionAgentOptions carries the dependencies and settings consumed by
// NewSessionAgent. Each field is copied to the same-named field of the
// resulting agent.
type SessionAgentOptions struct {
	// LargeModel is used for prompts and summaries; SmallModel for titles.
	LargeModel Model
	SmallModel Model
	// SystemPromptPrefix is an optional extra system message placed before
	// the conversation; SystemPrompt is the main system prompt.
	SystemPromptPrefix string
	SystemPrompt       string
	// IsSubAgent suppresses the todo-list reminder in prepared prompts.
	IsSubAgent bool
	// DisableAutoSummarize turns off automatic summarization when the context
	// window is nearly full.
	DisableAutoSummarize bool
	// IsYolo is stored on the agent. NOTE(review): its effect is not visible
	// in this part of the file — confirm against its consumers.
	IsYolo bool
	// Sessions and Messages are the persistence services the agent uses.
	Sessions session.Service
	Messages message.Service
	// Tools is the initial tool set offered to the model.
	Tools []fantasy.AgentTool
}
 110
 111func NewSessionAgent(
 112	opts SessionAgentOptions,
 113) SessionAgent {
 114	return &sessionAgent{
 115		largeModel:           opts.LargeModel,
 116		smallModel:           opts.SmallModel,
 117		systemPromptPrefix:   opts.SystemPromptPrefix,
 118		systemPrompt:         opts.SystemPrompt,
 119		isSubAgent:           opts.IsSubAgent,
 120		sessions:             opts.Sessions,
 121		messages:             opts.Messages,
 122		disableAutoSummarize: opts.DisableAutoSummarize,
 123		tools:                opts.Tools,
 124		isYolo:               opts.IsYolo,
 125		messageQueue:         csync.NewMap[string, []SessionAgentCall](),
 126		activeRequests:       csync.NewMap[string, context.CancelFunc](),
 127	}
 128}
 129
 130func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
 131	if call.Prompt == "" {
 132		return nil, ErrEmptyPrompt
 133	}
 134	if call.SessionID == "" {
 135		return nil, ErrSessionMissing
 136	}
 137
 138	// Queue the message if busy
 139	if a.IsSessionBusy(call.SessionID) {
 140		existing, ok := a.messageQueue.Get(call.SessionID)
 141		if !ok {
 142			existing = []SessionAgentCall{}
 143		}
 144		existing = append(existing, call)
 145		a.messageQueue.Set(call.SessionID, existing)
 146		return nil, nil
 147	}
 148
 149	if len(a.tools) > 0 {
 150		// Add Anthropic caching to the last tool.
 151		a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
 152	}
 153
 154	agent := fantasy.NewAgent(
 155		a.largeModel.Model,
 156		fantasy.WithSystemPrompt(a.systemPrompt),
 157		fantasy.WithTools(a.tools...),
 158	)
 159
 160	sessionLock := sync.Mutex{}
 161	currentSession, err := a.sessions.Get(ctx, call.SessionID)
 162	if err != nil {
 163		return nil, fmt.Errorf("failed to get session: %w", err)
 164	}
 165
 166	msgs, err := a.getSessionMessages(ctx, currentSession)
 167	if err != nil {
 168		return nil, fmt.Errorf("failed to get session messages: %w", err)
 169	}
 170
 171	var wg sync.WaitGroup
 172	// Generate title if first message.
 173	if len(msgs) == 0 {
 174		titleCtx := ctx // Copy to avoid race with ctx reassignment below.
 175		wg.Go(func() {
 176			a.generateTitle(titleCtx, call.SessionID, call.Prompt)
 177		})
 178	}
 179
 180	// Add the user message to the session.
 181	_, err = a.createUserMessage(ctx, call)
 182	if err != nil {
 183		return nil, err
 184	}
 185
 186	// Add the session to the context.
 187	ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)
 188
 189	genCtx, cancel := context.WithCancel(ctx)
 190	a.activeRequests.Set(call.SessionID, cancel)
 191
 192	defer cancel()
 193	defer a.activeRequests.Del(call.SessionID)
 194
 195	history, files := a.preparePrompt(msgs, call.Attachments...)
 196
 197	startTime := time.Now()
 198	a.eventPromptSent(call.SessionID)
 199
 200	var currentAssistant *message.Message
 201	var shouldSummarize bool
 202	result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
 203		Prompt:           message.PromptWithTextAttachments(call.Prompt, call.Attachments),
 204		Files:            files,
 205		Messages:         history,
 206		ProviderOptions:  call.ProviderOptions,
 207		MaxOutputTokens:  &call.MaxOutputTokens,
 208		TopP:             call.TopP,
 209		Temperature:      call.Temperature,
 210		PresencePenalty:  call.PresencePenalty,
 211		TopK:             call.TopK,
 212		FrequencyPenalty: call.FrequencyPenalty,
 213		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
 214			prepared.Messages = options.Messages
 215			for i := range prepared.Messages {
 216				prepared.Messages[i].ProviderOptions = nil
 217			}
 218
 219			queuedCalls, _ := a.messageQueue.Get(call.SessionID)
 220			a.messageQueue.Del(call.SessionID)
 221			for _, queued := range queuedCalls {
 222				userMessage, createErr := a.createUserMessage(callContext, queued)
 223				if createErr != nil {
 224					return callContext, prepared, createErr
 225				}
 226				prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
 227			}
 228
 229			prepared.Messages = a.workaroundProviderMediaLimitations(prepared.Messages)
 230
 231			lastSystemRoleInx := 0
 232			systemMessageUpdated := false
 233			for i, msg := range prepared.Messages {
 234				// Only add cache control to the last message.
 235				if msg.Role == fantasy.MessageRoleSystem {
 236					lastSystemRoleInx = i
 237				} else if !systemMessageUpdated {
 238					prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
 239					systemMessageUpdated = true
 240				}
 241				// Than add cache control to the last 2 messages.
 242				if i > len(prepared.Messages)-3 {
 243					prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
 244				}
 245			}
 246
 247			if promptPrefix := a.promptPrefix(); promptPrefix != "" {
 248				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(promptPrefix)}, prepared.Messages...)
 249			}
 250
 251			var assistantMsg message.Message
 252			assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
 253				Role:     message.Assistant,
 254				Parts:    []message.ContentPart{},
 255				Model:    a.largeModel.ModelCfg.Model,
 256				Provider: a.largeModel.ModelCfg.Provider,
 257			})
 258			if err != nil {
 259				return callContext, prepared, err
 260			}
 261			callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
 262			callContext = context.WithValue(callContext, tools.SupportsImagesContextKey, a.largeModel.CatwalkCfg.SupportsImages)
 263			callContext = context.WithValue(callContext, tools.ModelNameContextKey, a.largeModel.CatwalkCfg.Name)
 264			currentAssistant = &assistantMsg
 265			return callContext, prepared, err
 266		},
 267		OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
 268			currentAssistant.AppendReasoningContent(reasoning.Text)
 269			return a.messages.Update(genCtx, *currentAssistant)
 270		},
 271		OnReasoningDelta: func(id string, text string) error {
 272			currentAssistant.AppendReasoningContent(text)
 273			return a.messages.Update(genCtx, *currentAssistant)
 274		},
 275		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
 276			// handle anthropic signature
 277			if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
 278				if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
 279					currentAssistant.AppendReasoningSignature(reasoning.Signature)
 280				}
 281			}
 282			if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
 283				if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
 284					currentAssistant.AppendThoughtSignature(reasoning.Signature, reasoning.ToolID)
 285				}
 286			}
 287			if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
 288				if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
 289					currentAssistant.SetReasoningResponsesData(reasoning)
 290				}
 291			}
 292			currentAssistant.FinishThinking()
 293			return a.messages.Update(genCtx, *currentAssistant)
 294		},
 295		OnTextDelta: func(id string, text string) error {
 296			// Strip leading newline from initial text content. This is is
 297			// particularly important in non-interactive mode where leading
 298			// newlines are very visible.
 299			if len(currentAssistant.Parts) == 0 {
 300				text = strings.TrimPrefix(text, "\n")
 301			}
 302
 303			currentAssistant.AppendContent(text)
 304			return a.messages.Update(genCtx, *currentAssistant)
 305		},
 306		OnToolInputStart: func(id string, toolName string) error {
 307			toolCall := message.ToolCall{
 308				ID:               id,
 309				Name:             toolName,
 310				ProviderExecuted: false,
 311				Finished:         false,
 312			}
 313			currentAssistant.AddToolCall(toolCall)
 314			return a.messages.Update(genCtx, *currentAssistant)
 315		},
 316		OnRetry: func(err *fantasy.ProviderError, delay time.Duration) {
 317			// TODO: implement
 318		},
 319		OnToolCall: func(tc fantasy.ToolCallContent) error {
 320			toolCall := message.ToolCall{
 321				ID:               tc.ToolCallID,
 322				Name:             tc.ToolName,
 323				Input:            tc.Input,
 324				ProviderExecuted: false,
 325				Finished:         true,
 326			}
 327			currentAssistant.AddToolCall(toolCall)
 328			return a.messages.Update(genCtx, *currentAssistant)
 329		},
 330		OnToolResult: func(result fantasy.ToolResultContent) error {
 331			toolResult := a.convertToToolResult(result)
 332			_, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
 333				Role: message.Tool,
 334				Parts: []message.ContentPart{
 335					toolResult,
 336				},
 337			})
 338			return createMsgErr
 339		},
 340		OnStepFinish: func(stepResult fantasy.StepResult) error {
 341			finishReason := message.FinishReasonUnknown
 342			switch stepResult.FinishReason {
 343			case fantasy.FinishReasonLength:
 344				finishReason = message.FinishReasonMaxTokens
 345			case fantasy.FinishReasonStop:
 346				finishReason = message.FinishReasonEndTurn
 347			case fantasy.FinishReasonToolCalls:
 348				finishReason = message.FinishReasonToolUse
 349			}
 350			currentAssistant.AddFinish(finishReason, "", "")
 351			sessionLock.Lock()
 352			updatedSession, getSessionErr := a.sessions.Get(genCtx, call.SessionID)
 353			if getSessionErr != nil {
 354				sessionLock.Unlock()
 355				return getSessionErr
 356			}
 357			a.updateSessionUsage(a.largeModel, &updatedSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
 358			_, sessionErr := a.sessions.Save(genCtx, updatedSession)
 359			sessionLock.Unlock()
 360			if sessionErr != nil {
 361				return sessionErr
 362			}
 363			return a.messages.Update(genCtx, *currentAssistant)
 364		},
 365		StopWhen: []fantasy.StopCondition{
 366			func(_ []fantasy.StepResult) bool {
 367				cw := int64(a.largeModel.CatwalkCfg.ContextWindow)
 368				tokens := currentSession.CompletionTokens + currentSession.PromptTokens
 369				remaining := cw - tokens
 370				var threshold int64
 371				if cw > 200_000 {
 372					threshold = 20_000
 373				} else {
 374					threshold = int64(float64(cw) * 0.2)
 375				}
 376				if (remaining <= threshold) && !a.disableAutoSummarize {
 377					shouldSummarize = true
 378					return true
 379				}
 380				return false
 381			},
 382		},
 383	})
 384
 385	a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))
 386
 387	if err != nil {
 388		isCancelErr := errors.Is(err, context.Canceled)
 389		isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
 390		if currentAssistant == nil {
 391			return result, err
 392		}
 393		// Ensure we finish thinking on error to close the reasoning state.
 394		currentAssistant.FinishThinking()
 395		toolCalls := currentAssistant.ToolCalls()
 396		// INFO: we use the parent context here because the genCtx has been cancelled.
 397		msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
 398		if createErr != nil {
 399			return nil, createErr
 400		}
 401		for _, tc := range toolCalls {
 402			if !tc.Finished {
 403				tc.Finished = true
 404				tc.Input = "{}"
 405				currentAssistant.AddToolCall(tc)
 406				updateErr := a.messages.Update(ctx, *currentAssistant)
 407				if updateErr != nil {
 408					return nil, updateErr
 409				}
 410			}
 411
 412			found := false
 413			for _, msg := range msgs {
 414				if msg.Role == message.Tool {
 415					for _, tr := range msg.ToolResults() {
 416						if tr.ToolCallID == tc.ID {
 417							found = true
 418							break
 419						}
 420					}
 421				}
 422				if found {
 423					break
 424				}
 425			}
 426			if found {
 427				continue
 428			}
 429			content := "There was an error while executing the tool"
 430			if isCancelErr {
 431				content = "Tool execution canceled by user"
 432			} else if isPermissionErr {
 433				content = "User denied permission"
 434			}
 435			toolResult := message.ToolResult{
 436				ToolCallID: tc.ID,
 437				Name:       tc.Name,
 438				Content:    content,
 439				IsError:    true,
 440			}
 441			_, createErr = a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
 442				Role: message.Tool,
 443				Parts: []message.ContentPart{
 444					toolResult,
 445				},
 446			})
 447			if createErr != nil {
 448				return nil, createErr
 449			}
 450		}
 451		var fantasyErr *fantasy.Error
 452		var providerErr *fantasy.ProviderError
 453		const defaultTitle = "Provider Error"
 454		if isCancelErr {
 455			currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
 456		} else if isPermissionErr {
 457			currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "")
 458		} else if errors.Is(err, hyper.ErrNoCredits) {
 459			url := hyper.BaseURL()
 460			link := lipgloss.NewStyle().Hyperlink(url, "id=hyper").Render(url)
 461			currentAssistant.AddFinish(message.FinishReasonError, "No credits", "You're out of credits. Add more at "+link)
 462		} else if errors.As(err, &providerErr) {
 463			if providerErr.Message == "The requested model is not supported." {
 464				url := "https://github.com/settings/copilot/features"
 465				link := lipgloss.NewStyle().Hyperlink(url, "id=hyper").Render(url)
 466				currentAssistant.AddFinish(
 467					message.FinishReasonError,
 468					"Copilot model not enabled",
 469					fmt.Sprintf("%q is not enabled in Copilot. Go to the following page to enable it. Then, wait a minute before trying again. %s", a.largeModel.CatwalkCfg.Name, link),
 470				)
 471			} else {
 472				currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message)
 473			}
 474		} else if errors.As(err, &fantasyErr) {
 475			currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(fantasyErr.Title), defaultTitle), fantasyErr.Message)
 476		} else {
 477			currentAssistant.AddFinish(message.FinishReasonError, defaultTitle, err.Error())
 478		}
 479		// Note: we use the parent context here because the genCtx has been
 480		// cancelled.
 481		updateErr := a.messages.Update(ctx, *currentAssistant)
 482		if updateErr != nil {
 483			return nil, updateErr
 484		}
 485		return nil, err
 486	}
 487	wg.Wait()
 488
 489	if shouldSummarize {
 490		a.activeRequests.Del(call.SessionID)
 491		if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
 492			return nil, summarizeErr
 493		}
 494		// If the agent wasn't done...
 495		if len(currentAssistant.ToolCalls()) > 0 {
 496			existing, ok := a.messageQueue.Get(call.SessionID)
 497			if !ok {
 498				existing = []SessionAgentCall{}
 499			}
 500			call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
 501			existing = append(existing, call)
 502			a.messageQueue.Set(call.SessionID, existing)
 503		}
 504	}
 505
 506	// Release active request before processing queued messages.
 507	a.activeRequests.Del(call.SessionID)
 508	cancel()
 509
 510	queuedMessages, ok := a.messageQueue.Get(call.SessionID)
 511	if !ok || len(queuedMessages) == 0 {
 512		return result, err
 513	}
 514	// There are queued messages restart the loop.
 515	firstQueuedMessage := queuedMessages[0]
 516	a.messageQueue.Set(call.SessionID, queuedMessages[1:])
 517	return a.Run(ctx, firstQueuedMessage)
 518}
 519
 520func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
 521	if a.IsSessionBusy(sessionID) {
 522		return ErrSessionBusy
 523	}
 524
 525	currentSession, err := a.sessions.Get(ctx, sessionID)
 526	if err != nil {
 527		return fmt.Errorf("failed to get session: %w", err)
 528	}
 529	msgs, err := a.getSessionMessages(ctx, currentSession)
 530	if err != nil {
 531		return err
 532	}
 533	if len(msgs) == 0 {
 534		// Nothing to summarize.
 535		return nil
 536	}
 537
 538	aiMsgs, _ := a.preparePrompt(msgs)
 539
 540	genCtx, cancel := context.WithCancel(ctx)
 541	a.activeRequests.Set(sessionID, cancel)
 542	defer a.activeRequests.Del(sessionID)
 543	defer cancel()
 544
 545	agent := fantasy.NewAgent(a.largeModel.Model,
 546		fantasy.WithSystemPrompt(string(summaryPrompt)),
 547	)
 548	summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
 549		Role:             message.Assistant,
 550		Model:            a.largeModel.Model.Model(),
 551		Provider:         a.largeModel.Model.Provider(),
 552		IsSummaryMessage: true,
 553	})
 554	if err != nil {
 555		return err
 556	}
 557
 558	summaryPromptText := "Provide a detailed summary of our conversation above."
 559	if len(currentSession.Todos) > 0 {
 560		summaryPromptText += "\n\n## Current Todo List\n\n"
 561		for _, t := range currentSession.Todos {
 562			summaryPromptText += fmt.Sprintf("- [%s] %s\n", t.Status, t.Content)
 563		}
 564		summaryPromptText += "\nInclude these tasks and their statuses in your summary. "
 565		summaryPromptText += "Instruct the resuming assistant to use the `todos` tool to continue tracking progress on these tasks."
 566	}
 567
 568	resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
 569		Prompt:          summaryPromptText,
 570		Messages:        aiMsgs,
 571		ProviderOptions: opts,
 572		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
 573			prepared.Messages = options.Messages
 574			if a.systemPromptPrefix != "" {
 575				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
 576			}
 577			return callContext, prepared, nil
 578		},
 579		OnReasoningDelta: func(id string, text string) error {
 580			summaryMessage.AppendReasoningContent(text)
 581			return a.messages.Update(genCtx, summaryMessage)
 582		},
 583		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
 584			// Handle anthropic signature.
 585			if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
 586				if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
 587					summaryMessage.AppendReasoningSignature(signature.Signature)
 588				}
 589			}
 590			summaryMessage.FinishThinking()
 591			return a.messages.Update(genCtx, summaryMessage)
 592		},
 593		OnTextDelta: func(id, text string) error {
 594			summaryMessage.AppendContent(text)
 595			return a.messages.Update(genCtx, summaryMessage)
 596		},
 597	})
 598	if err != nil {
 599		isCancelErr := errors.Is(err, context.Canceled)
 600		if isCancelErr {
 601			// User cancelled summarize we need to remove the summary message.
 602			deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
 603			return deleteErr
 604		}
 605		return err
 606	}
 607
 608	summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
 609	err = a.messages.Update(genCtx, summaryMessage)
 610	if err != nil {
 611		return err
 612	}
 613
 614	var openrouterCost *float64
 615	for _, step := range resp.Steps {
 616		stepCost := a.openrouterCost(step.ProviderMetadata)
 617		if stepCost != nil {
 618			newCost := *stepCost
 619			if openrouterCost != nil {
 620				newCost += *openrouterCost
 621			}
 622			openrouterCost = &newCost
 623		}
 624	}
 625
 626	a.updateSessionUsage(a.largeModel, &currentSession, resp.TotalUsage, openrouterCost)
 627
 628	// Just in case, get just the last usage info.
 629	usage := resp.Response.Usage
 630	currentSession.SummaryMessageID = summaryMessage.ID
 631	currentSession.CompletionTokens = usage.OutputTokens
 632	currentSession.PromptTokens = 0
 633	_, err = a.sessions.Save(genCtx, currentSession)
 634	return err
 635}
 636
 637func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
 638	if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
 639		return fantasy.ProviderOptions{}
 640	}
 641	return fantasy.ProviderOptions{
 642		anthropic.Name: &anthropic.ProviderCacheControlOptions{
 643			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
 644		},
 645		bedrock.Name: &anthropic.ProviderCacheControlOptions{
 646			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
 647		},
 648	}
 649}
 650
 651func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
 652	parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
 653	var attachmentParts []message.ContentPart
 654	for _, attachment := range call.Attachments {
 655		attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
 656	}
 657	parts = append(parts, attachmentParts...)
 658	msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
 659		Role:  message.User,
 660		Parts: parts,
 661	})
 662	if err != nil {
 663		return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
 664	}
 665	return msg, nil
 666}
 667
 668func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
 669	var history []fantasy.Message
 670	if !a.isSubAgent {
 671		history = append(history, fantasy.NewUserMessage(
 672			fmt.Sprintf("<system_reminder>%s</system_reminder>",
 673				`This is a reminder that your todo list is currently empty. DO NOT mention this to the user explicitly because they are already aware.
 674If you are working on tasks that would benefit from a todo list please use the "todos" tool to create one.
 675If not, please feel free to ignore. Again do not mention this message to the user.`,
 676			),
 677		))
 678	}
 679	for _, m := range msgs {
 680		if len(m.Parts) == 0 {
 681			continue
 682		}
 683		// Assistant message without content or tool calls (cancelled before it
 684		// returned anything).
 685		if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
 686			continue
 687		}
 688		history = append(history, m.ToAIMessage()...)
 689	}
 690
 691	var files []fantasy.FilePart
 692	for _, attachment := range attachments {
 693		if attachment.IsText() {
 694			continue
 695		}
 696		files = append(files, fantasy.FilePart{
 697			Filename:  attachment.FileName,
 698			Data:      attachment.Content,
 699			MediaType: attachment.MimeType,
 700		})
 701	}
 702
 703	return history, files
 704}
 705
 706func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
 707	msgs, err := a.messages.List(ctx, session.ID)
 708	if err != nil {
 709		return nil, fmt.Errorf("failed to list messages: %w", err)
 710	}
 711
 712	if session.SummaryMessageID != "" {
 713		summaryMsgInex := -1
 714		for i, msg := range msgs {
 715			if msg.ID == session.SummaryMessageID {
 716				summaryMsgInex = i
 717				break
 718			}
 719		}
 720		if summaryMsgInex != -1 {
 721			msgs = msgs[summaryMsgInex:]
 722			msgs[0].Role = message.User
 723		}
 724	}
 725	return msgs, nil
 726}
 727
 728func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, prompt string) {
 729	if prompt == "" {
 730		return
 731	}
 732
 733	var maxOutput int64 = 40
 734	if a.smallModel.CatwalkCfg.CanReason {
 735		maxOutput = a.smallModel.CatwalkCfg.DefaultMaxTokens
 736	}
 737
 738	agent := fantasy.NewAgent(a.smallModel.Model,
 739		fantasy.WithSystemPrompt(string(titlePrompt)+"\n /no_think"),
 740		fantasy.WithMaxOutputTokens(maxOutput),
 741	)
 742
 743	resp, err := agent.Stream(ctx, fantasy.AgentStreamCall{
 744		Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", prompt),
 745		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
 746			prepared.Messages = options.Messages
 747			if a.systemPromptPrefix != "" {
 748				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
 749			}
 750			return callContext, prepared, nil
 751		},
 752	})
 753	if err != nil {
 754		slog.Error("error generating title", "err", err)
 755		return
 756	}
 757
 758	title := resp.Response.Content.Text()
 759
 760	title = strings.ReplaceAll(title, "\n", " ")
 761
 762	// Remove thinking tags if present.
 763	if idx := strings.Index(title, "</think>"); idx > 0 {
 764		title = title[idx+len("</think>"):]
 765	}
 766
 767	title = strings.TrimSpace(title)
 768	if title == "" {
 769		slog.Warn("failed to generate title", "warn", "empty title")
 770		return
 771	}
 772
 773	// Calculate usage and cost.
 774	var openrouterCost *float64
 775	for _, step := range resp.Steps {
 776		stepCost := a.openrouterCost(step.ProviderMetadata)
 777		if stepCost != nil {
 778			newCost := *stepCost
 779			if openrouterCost != nil {
 780				newCost += *openrouterCost
 781			}
 782			openrouterCost = &newCost
 783		}
 784	}
 785
 786	modelConfig := a.smallModel.CatwalkCfg
 787	cost := modelConfig.CostPer1MInCached/1e6*float64(resp.TotalUsage.CacheCreationTokens) +
 788		modelConfig.CostPer1MOutCached/1e6*float64(resp.TotalUsage.CacheReadTokens) +
 789		modelConfig.CostPer1MIn/1e6*float64(resp.TotalUsage.InputTokens) +
 790		modelConfig.CostPer1MOut/1e6*float64(resp.TotalUsage.OutputTokens)
 791
 792	if a.isClaudeCode() {
 793		cost = 0
 794	}
 795
 796	// Use override cost if available (e.g., from OpenRouter).
 797	if openrouterCost != nil {
 798		cost = *openrouterCost
 799	}
 800
 801	promptTokens := resp.TotalUsage.InputTokens + resp.TotalUsage.CacheCreationTokens
 802	completionTokens := resp.TotalUsage.OutputTokens + resp.TotalUsage.CacheReadTokens
 803
 804	// Atomically update only title and usage fields to avoid overriding other
 805	// concurrent session updates.
 806	saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, title, promptTokens, completionTokens, cost)
 807	if saveErr != nil {
 808		slog.Error("failed to save session title & usage", "error", saveErr)
 809		return
 810	}
 811}
 812
 813func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
 814	openrouterMetadata, ok := metadata[openrouter.Name]
 815	if !ok {
 816		return nil
 817	}
 818
 819	opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
 820	if !ok {
 821		return nil
 822	}
 823	return &opts.Usage.Cost
 824}
 825
 826func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
 827	modelConfig := model.CatwalkCfg
 828	cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
 829		modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
 830		modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
 831		modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
 832
 833	if a.isClaudeCode() {
 834		cost = 0
 835	}
 836
 837	a.eventTokensUsed(session.ID, model, usage, cost)
 838
 839	if overrideCost != nil {
 840		session.Cost += *overrideCost
 841	} else {
 842		session.Cost += cost
 843	}
 844
 845	session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
 846	session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
 847}
 848
 849func (a *sessionAgent) Cancel(sessionID string) {
 850	// Cancel regular requests.
 851	if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
 852		slog.Info("Request cancellation initiated", "session_id", sessionID)
 853		cancel()
 854	}
 855
 856	// Also check for summarize requests.
 857	if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
 858		slog.Info("Summarize cancellation initiated", "session_id", sessionID)
 859		cancel()
 860	}
 861
 862	if a.QueuedPrompts(sessionID) > 0 {
 863		slog.Info("Clearing queued prompts", "session_id", sessionID)
 864		a.messageQueue.Del(sessionID)
 865	}
 866}
 867
 868func (a *sessionAgent) ClearQueue(sessionID string) {
 869	if a.QueuedPrompts(sessionID) > 0 {
 870		slog.Info("Clearing queued prompts", "session_id", sessionID)
 871		a.messageQueue.Del(sessionID)
 872	}
 873}
 874
 875func (a *sessionAgent) CancelAll() {
 876	if !a.IsBusy() {
 877		return
 878	}
 879	for key := range a.activeRequests.Seq2() {
 880		a.Cancel(key) // key is sessionID
 881	}
 882
 883	timeout := time.After(5 * time.Second)
 884	for a.IsBusy() {
 885		select {
 886		case <-timeout:
 887			return
 888		default:
 889			time.Sleep(200 * time.Millisecond)
 890		}
 891	}
 892}
 893
 894func (a *sessionAgent) IsBusy() bool {
 895	var busy bool
 896	for cancelFunc := range a.activeRequests.Seq() {
 897		if cancelFunc != nil {
 898			busy = true
 899			break
 900		}
 901	}
 902	return busy
 903}
 904
 905func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
 906	_, busy := a.activeRequests.Get(sessionID)
 907	return busy
 908}
 909
 910func (a *sessionAgent) QueuedPrompts(sessionID string) int {
 911	l, ok := a.messageQueue.Get(sessionID)
 912	if !ok {
 913		return 0
 914	}
 915	return len(l)
 916}
 917
 918func (a *sessionAgent) QueuedPromptsList(sessionID string) []string {
 919	l, ok := a.messageQueue.Get(sessionID)
 920	if !ok {
 921		return nil
 922	}
 923	prompts := make([]string, len(l))
 924	for i, call := range l {
 925		prompts[i] = call.Prompt
 926	}
 927	return prompts
 928}
 929
 930func (a *sessionAgent) SetModels(large Model, small Model) {
 931	a.largeModel = large
 932	a.smallModel = small
 933}
 934
 935func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
 936	a.tools = tools
 937}
 938
 939func (a *sessionAgent) Model() Model {
 940	return a.largeModel
 941}
 942
 943func (a *sessionAgent) promptPrefix() string {
 944	if a.isClaudeCode() {
 945		return "You are Claude Code, Anthropic's official CLI for Claude."
 946	}
 947	return a.systemPromptPrefix
 948}
 949
 950func (a *sessionAgent) isClaudeCode() bool {
 951	cfg := config.Get()
 952	pc, ok := cfg.Providers.Get(a.largeModel.ModelCfg.Provider)
 953	return ok && pc.ID == string(catwalk.InferenceProviderAnthropic) && pc.OAuthToken != nil
 954}
 955
 956// convertToToolResult converts a fantasy tool result to a message tool result.
 957func (a *sessionAgent) convertToToolResult(result fantasy.ToolResultContent) message.ToolResult {
 958	baseResult := message.ToolResult{
 959		ToolCallID: result.ToolCallID,
 960		Name:       result.ToolName,
 961		Metadata:   result.ClientMetadata,
 962	}
 963
 964	switch result.Result.GetType() {
 965	case fantasy.ToolResultContentTypeText:
 966		if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result); ok {
 967			baseResult.Content = r.Text
 968		}
 969	case fantasy.ToolResultContentTypeError:
 970		if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result); ok {
 971			baseResult.Content = r.Error.Error()
 972			baseResult.IsError = true
 973		}
 974	case fantasy.ToolResultContentTypeMedia:
 975		if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](result.Result); ok {
 976			content := r.Text
 977			if content == "" {
 978				content = fmt.Sprintf("Loaded %s content", r.MediaType)
 979			}
 980			baseResult.Content = content
 981			baseResult.Data = r.Data
 982			baseResult.MIMEType = r.MediaType
 983		}
 984	}
 985
 986	return baseResult
 987}
 988
 989// workaroundProviderMediaLimitations converts media content in tool results to
 990// user messages for providers that don't natively support images in tool results.
 991//
 992// Problem: OpenAI, Google, OpenRouter, and other OpenAI-compatible providers
 993// don't support sending images/media in tool result messages - they only accept
 994// text in tool results. However, they DO support images in user messages.
 995//
 996// If we send media in tool results to these providers, the API returns an error.
 997//
 998// Solution: For these providers, we:
 999//  1. Replace the media in the tool result with a text placeholder
1000//  2. Inject a user message immediately after with the image as a file attachment
1001//  3. This maintains the tool execution flow while working around API limitations
1002//
1003// Anthropic and Bedrock support images natively in tool results, so we skip
1004// this workaround for them.
1005//
1006// Example transformation:
1007//
1008//	BEFORE: [tool result: image data]
1009//	AFTER:  [tool result: "Image loaded - see attached"], [user: image attachment]
1010func (a *sessionAgent) workaroundProviderMediaLimitations(messages []fantasy.Message) []fantasy.Message {
1011	providerSupportsMedia := a.largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderAnthropic) ||
1012		a.largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderBedrock)
1013
1014	if providerSupportsMedia {
1015		return messages
1016	}
1017
1018	convertedMessages := make([]fantasy.Message, 0, len(messages))
1019
1020	for _, msg := range messages {
1021		if msg.Role != fantasy.MessageRoleTool {
1022			convertedMessages = append(convertedMessages, msg)
1023			continue
1024		}
1025
1026		textParts := make([]fantasy.MessagePart, 0, len(msg.Content))
1027		var mediaFiles []fantasy.FilePart
1028
1029		for _, part := range msg.Content {
1030			toolResult, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
1031			if !ok {
1032				textParts = append(textParts, part)
1033				continue
1034			}
1035
1036			if media, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](toolResult.Output); ok {
1037				decoded, err := base64.StdEncoding.DecodeString(media.Data)
1038				if err != nil {
1039					slog.Warn("failed to decode media data", "error", err)
1040					textParts = append(textParts, part)
1041					continue
1042				}
1043
1044				mediaFiles = append(mediaFiles, fantasy.FilePart{
1045					Data:      decoded,
1046					MediaType: media.MediaType,
1047					Filename:  fmt.Sprintf("tool-result-%s", toolResult.ToolCallID),
1048				})
1049
1050				textParts = append(textParts, fantasy.ToolResultPart{
1051					ToolCallID: toolResult.ToolCallID,
1052					Output: fantasy.ToolResultOutputContentText{
1053						Text: "[Image/media content loaded - see attached file]",
1054					},
1055					ProviderOptions: toolResult.ProviderOptions,
1056				})
1057			} else {
1058				textParts = append(textParts, part)
1059			}
1060		}
1061
1062		convertedMessages = append(convertedMessages, fantasy.Message{
1063			Role:    fantasy.MessageRoleTool,
1064			Content: textParts,
1065		})
1066
1067		if len(mediaFiles) > 0 {
1068			convertedMessages = append(convertedMessages, fantasy.NewUserMessage(
1069				"Here is the media content from the tool result:",
1070				mediaFiles...,
1071			))
1072		}
1073	}
1074
1075	return convertedMessages
1076}