agent.go

   1// Package agent is the core orchestration layer for Crush AI agents.
   2//
   3// It provides session-based AI agent functionality for managing
   4// conversations, tool execution, and message handling. It coordinates
   5// interactions between language models, messages, sessions, and tools while
   6// handling features like automatic summarization, queuing, and token
   7// management.
   8package agent
   9
  10import (
  11	"cmp"
  12	"context"
  13	_ "embed"
  14	"encoding/base64"
  15	"errors"
  16	"fmt"
  17	"log/slog"
  18	"os"
  19	"regexp"
  20	"strconv"
  21	"strings"
  22	"sync"
  23	"time"
  24
  25	"charm.land/fantasy"
  26	"charm.land/fantasy/providers/anthropic"
  27	"charm.land/fantasy/providers/bedrock"
  28	"charm.land/fantasy/providers/google"
  29	"charm.land/fantasy/providers/openai"
  30	"charm.land/fantasy/providers/openrouter"
  31	"charm.land/lipgloss/v2"
  32	"github.com/charmbracelet/catwalk/pkg/catwalk"
  33	"github.com/charmbracelet/crush/internal/agent/hyper"
  34	"github.com/charmbracelet/crush/internal/agent/tools"
  35	"github.com/charmbracelet/crush/internal/config"
  36	"github.com/charmbracelet/crush/internal/csync"
  37	"github.com/charmbracelet/crush/internal/message"
  38	"github.com/charmbracelet/crush/internal/permission"
  39	"github.com/charmbracelet/crush/internal/session"
  40	"github.com/charmbracelet/crush/internal/stringext"
  41)
  42
// defaultSessionName is the title stored on a session when title generation
// fails or yields an empty string.
const defaultSessionName = "Untitled Session"

// titlePrompt is the embedded system prompt used when generating session
// titles.
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt is the embedded system prompt used when summarizing a
// session.
//
//go:embed templates/summary.md
var summaryPrompt []byte

// thinkTagRegex matches <think>...</think> spans so they can be removed from
// generated titles. The pattern has no "s" flag, so "." does not match
// newlines; generateTitle replaces newlines with spaces before applying it.
var thinkTagRegex = regexp.MustCompile(`<think>.*?</think>`)
  53
// SessionAgentCall describes a single user request to be run against a
// session.
type SessionAgentCall struct {
	// SessionID identifies the target session; Run rejects an empty value.
	SessionID string
	// Prompt is the user's text; Run rejects an empty value.
	Prompt string
	// ProviderOptions is forwarded unchanged to the model stream call.
	ProviderOptions fantasy.ProviderOptions
	// Attachments are stored with the user message; non-text attachments are
	// also sent to the model as file parts.
	Attachments []message.Attachment
	// MaxOutputTokens is forwarded to the model stream call.
	MaxOutputTokens int64
	// The remaining fields are optional sampling parameters forwarded to the
	// model stream call as-is (nil pointers are passed through).
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
}
  66
// SessionAgent is the interface through which callers drive a session-based
// agent.
type SessionAgent interface {
	// Run executes the call against its session. When the session is busy the
	// call is queued and Run returns (nil, nil).
	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
	// SetModels sets the large and small models.
	// NOTE(review): implementation not visible in this chunk.
	SetModels(large Model, small Model)
	// SetTools sets the tools offered to the model.
	// NOTE(review): implementation not visible in this chunk.
	SetTools(tools []fantasy.AgentTool)
	// Cancel cancels the active request for the session, if any, and clears
	// the session's queued prompts.
	Cancel(sessionID string)
	// CancelAll cancels every active request.
	CancelAll()
	// IsSessionBusy reports whether the session is busy; Run and Summarize
	// consult it before starting work.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether the agent is busy.
	// NOTE(review): presumably true when any session is busy; confirm.
	IsBusy() bool
	// QueuedPrompts returns the number of prompts queued for the session.
	QueuedPrompts(sessionID string) int
	// QueuedPromptsList returns the prompts queued for the session.
	// NOTE(review): implementation not visible in this chunk.
	QueuedPromptsList(sessionID string) []string
	// ClearQueue drops all prompts queued for the session.
	ClearQueue(sessionID string)
	// Summarize generates a summary message for the session and records it
	// as the session's summary. It returns ErrSessionBusy when the session
	// is busy.
	Summarize(context.Context, string, fantasy.ProviderOptions) error
	// Model returns a Model.
	// NOTE(review): presumably the large model; confirm against the
	// implementation.
	Model() Model
}
  81
// Model bundles a language model with the configuration describing it.
type Model struct {
	// Model is the language model used to construct fantasy agents.
	Model fantasy.LanguageModel
	// CatwalkCfg supplies the model's metadata read by this package: context
	// window, per-token pricing, display name, and capability flags.
	CatwalkCfg catwalk.Model
	// ModelCfg is the selected-model configuration; its Model and Provider
	// values are recorded on assistant messages created by Run.
	ModelCfg config.SelectedModel
}
  87
// sessionAgent is the SessionAgent implementation returned by
// NewSessionAgent.
type sessionAgent struct {
	// largeModel serves regular runs and summaries; smallModel is tried first
	// for title generation, with largeModel as the fallback.
	largeModel Model
	smallModel Model
	// systemPromptPrefix, when non-empty, is prepended as an extra system
	// message for summary and title generation.
	systemPromptPrefix string
	// systemPrompt is the system prompt used for regular runs.
	systemPrompt string
	// isSubAgent suppresses the todo-list reminder that preparePrompt
	// otherwise puts at the start of the history.
	isSubAgent bool
	// tools are the tools offered to the model during Run.
	tools []fantasy.AgentTool
	// sessions and messages are the persistence services for sessions and
	// their messages.
	sessions session.Service
	messages message.Service
	// disableAutoSummarize turns off the context-window check that stops a
	// run and summarizes the session.
	disableAutoSummarize bool
	// isYolo is copied from the options.
	// NOTE(review): not read in the visible part of this file.
	isYolo bool

	// messageQueue holds calls received while a session was busy, keyed by
	// session ID.
	messageQueue *csync.Map[string, []SessionAgentCall]
	// activeRequests holds the cancel function of the in-flight request for
	// each session ID.
	activeRequests *csync.Map[string, context.CancelFunc]
}
 103
// SessionAgentOptions carries the configuration and dependencies that
// NewSessionAgent copies into a new agent.
type SessionAgentOptions struct {
	// LargeModel and SmallModel are the models the agent will use.
	LargeModel Model
	SmallModel Model
	// SystemPromptPrefix and SystemPrompt are the prompts the agent will use.
	SystemPromptPrefix string
	SystemPrompt       string
	// IsSubAgent marks the agent as a sub-agent.
	IsSubAgent bool
	// DisableAutoSummarize turns off automatic summarization.
	DisableAutoSummarize bool
	// IsYolo is copied onto the agent as-is.
	IsYolo bool
	// Sessions and Messages are the persistence services.
	Sessions session.Service
	Messages message.Service
	// Tools are the tools offered to the model.
	Tools []fantasy.AgentTool
}
 116
 117func NewSessionAgent(
 118	opts SessionAgentOptions,
 119) SessionAgent {
 120	return &sessionAgent{
 121		largeModel:           opts.LargeModel,
 122		smallModel:           opts.SmallModel,
 123		systemPromptPrefix:   opts.SystemPromptPrefix,
 124		systemPrompt:         opts.SystemPrompt,
 125		isSubAgent:           opts.IsSubAgent,
 126		sessions:             opts.Sessions,
 127		messages:             opts.Messages,
 128		disableAutoSummarize: opts.DisableAutoSummarize,
 129		tools:                opts.Tools,
 130		isYolo:               opts.IsYolo,
 131		messageQueue:         csync.NewMap[string, []SessionAgentCall](),
 132		activeRequests:       csync.NewMap[string, context.CancelFunc](),
 133	}
 134}
 135
 136func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
 137	if call.Prompt == "" {
 138		return nil, ErrEmptyPrompt
 139	}
 140	if call.SessionID == "" {
 141		return nil, ErrSessionMissing
 142	}
 143
 144	// Queue the message if busy
 145	if a.IsSessionBusy(call.SessionID) {
 146		existing, ok := a.messageQueue.Get(call.SessionID)
 147		if !ok {
 148			existing = []SessionAgentCall{}
 149		}
 150		existing = append(existing, call)
 151		a.messageQueue.Set(call.SessionID, existing)
 152		return nil, nil
 153	}
 154
 155	if len(a.tools) > 0 {
 156		// Add Anthropic caching to the last tool.
 157		a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
 158	}
 159
 160	agent := fantasy.NewAgent(
 161		a.largeModel.Model,
 162		fantasy.WithSystemPrompt(a.systemPrompt),
 163		fantasy.WithTools(a.tools...),
 164	)
 165
 166	sessionLock := sync.Mutex{}
 167	currentSession, err := a.sessions.Get(ctx, call.SessionID)
 168	if err != nil {
 169		return nil, fmt.Errorf("failed to get session: %w", err)
 170	}
 171
 172	msgs, err := a.getSessionMessages(ctx, currentSession)
 173	if err != nil {
 174		return nil, fmt.Errorf("failed to get session messages: %w", err)
 175	}
 176
 177	var wg sync.WaitGroup
 178	// Generate title if first message.
 179	if len(msgs) == 0 {
 180		titleCtx := ctx // Copy to avoid race with ctx reassignment below.
 181		wg.Go(func() {
 182			a.generateTitle(titleCtx, call.SessionID, call.Prompt)
 183		})
 184	}
 185
 186	// Add the user message to the session.
 187	_, err = a.createUserMessage(ctx, call)
 188	if err != nil {
 189		return nil, err
 190	}
 191
 192	// Add the session to the context.
 193	ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)
 194
 195	genCtx, cancel := context.WithCancel(ctx)
 196	a.activeRequests.Set(call.SessionID, cancel)
 197
 198	defer cancel()
 199	defer a.activeRequests.Del(call.SessionID)
 200
 201	history, files := a.preparePrompt(msgs, call.Attachments...)
 202
 203	startTime := time.Now()
 204	a.eventPromptSent(call.SessionID)
 205
 206	var currentAssistant *message.Message
 207	var shouldSummarize bool
 208	result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
 209		Prompt:           message.PromptWithTextAttachments(call.Prompt, call.Attachments),
 210		Files:            files,
 211		Messages:         history,
 212		ProviderOptions:  call.ProviderOptions,
 213		MaxOutputTokens:  &call.MaxOutputTokens,
 214		TopP:             call.TopP,
 215		Temperature:      call.Temperature,
 216		PresencePenalty:  call.PresencePenalty,
 217		TopK:             call.TopK,
 218		FrequencyPenalty: call.FrequencyPenalty,
 219		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
 220			prepared.Messages = options.Messages
 221			for i := range prepared.Messages {
 222				prepared.Messages[i].ProviderOptions = nil
 223			}
 224
 225			queuedCalls, _ := a.messageQueue.Get(call.SessionID)
 226			a.messageQueue.Del(call.SessionID)
 227			for _, queued := range queuedCalls {
 228				userMessage, createErr := a.createUserMessage(callContext, queued)
 229				if createErr != nil {
 230					return callContext, prepared, createErr
 231				}
 232				prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
 233			}
 234
 235			prepared.Messages = a.workaroundProviderMediaLimitations(prepared.Messages)
 236
 237			lastSystemRoleInx := 0
 238			systemMessageUpdated := false
 239			for i, msg := range prepared.Messages {
 240				// Only add cache control to the last message.
 241				if msg.Role == fantasy.MessageRoleSystem {
 242					lastSystemRoleInx = i
 243				} else if !systemMessageUpdated {
 244					prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
 245					systemMessageUpdated = true
 246				}
 247				// Than add cache control to the last 2 messages.
 248				if i > len(prepared.Messages)-3 {
 249					prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
 250				}
 251			}
 252
 253			if promptPrefix := a.promptPrefix(); promptPrefix != "" {
 254				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(promptPrefix)}, prepared.Messages...)
 255			}
 256
 257			var assistantMsg message.Message
 258			assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
 259				Role:     message.Assistant,
 260				Parts:    []message.ContentPart{},
 261				Model:    a.largeModel.ModelCfg.Model,
 262				Provider: a.largeModel.ModelCfg.Provider,
 263			})
 264			if err != nil {
 265				return callContext, prepared, err
 266			}
 267			callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
 268			callContext = context.WithValue(callContext, tools.SupportsImagesContextKey, a.largeModel.CatwalkCfg.SupportsImages)
 269			callContext = context.WithValue(callContext, tools.ModelNameContextKey, a.largeModel.CatwalkCfg.Name)
 270			currentAssistant = &assistantMsg
 271			return callContext, prepared, err
 272		},
 273		OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
 274			currentAssistant.AppendReasoningContent(reasoning.Text)
 275			return a.messages.Update(genCtx, *currentAssistant)
 276		},
 277		OnReasoningDelta: func(id string, text string) error {
 278			currentAssistant.AppendReasoningContent(text)
 279			return a.messages.Update(genCtx, *currentAssistant)
 280		},
 281		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
 282			// handle anthropic signature
 283			if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
 284				if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
 285					currentAssistant.AppendReasoningSignature(reasoning.Signature)
 286				}
 287			}
 288			if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
 289				if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
 290					currentAssistant.AppendThoughtSignature(reasoning.Signature, reasoning.ToolID)
 291				}
 292			}
 293			if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
 294				if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
 295					currentAssistant.SetReasoningResponsesData(reasoning)
 296				}
 297			}
 298			currentAssistant.FinishThinking()
 299			return a.messages.Update(genCtx, *currentAssistant)
 300		},
 301		OnTextDelta: func(id string, text string) error {
 302			// Strip leading newline from initial text content. This is is
 303			// particularly important in non-interactive mode where leading
 304			// newlines are very visible.
 305			if len(currentAssistant.Parts) == 0 {
 306				text = strings.TrimPrefix(text, "\n")
 307			}
 308
 309			currentAssistant.AppendContent(text)
 310			return a.messages.Update(genCtx, *currentAssistant)
 311		},
 312		OnToolInputStart: func(id string, toolName string) error {
 313			toolCall := message.ToolCall{
 314				ID:               id,
 315				Name:             toolName,
 316				ProviderExecuted: false,
 317				Finished:         false,
 318			}
 319			currentAssistant.AddToolCall(toolCall)
 320			return a.messages.Update(genCtx, *currentAssistant)
 321		},
 322		OnRetry: func(err *fantasy.ProviderError, delay time.Duration) {
 323			// TODO: implement
 324		},
 325		OnToolCall: func(tc fantasy.ToolCallContent) error {
 326			toolCall := message.ToolCall{
 327				ID:               tc.ToolCallID,
 328				Name:             tc.ToolName,
 329				Input:            tc.Input,
 330				ProviderExecuted: false,
 331				Finished:         true,
 332			}
 333			currentAssistant.AddToolCall(toolCall)
 334			return a.messages.Update(genCtx, *currentAssistant)
 335		},
 336		OnToolResult: func(result fantasy.ToolResultContent) error {
 337			toolResult := a.convertToToolResult(result)
 338			_, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
 339				Role: message.Tool,
 340				Parts: []message.ContentPart{
 341					toolResult,
 342				},
 343			})
 344			return createMsgErr
 345		},
 346		OnStepFinish: func(stepResult fantasy.StepResult) error {
 347			finishReason := message.FinishReasonUnknown
 348			switch stepResult.FinishReason {
 349			case fantasy.FinishReasonLength:
 350				finishReason = message.FinishReasonMaxTokens
 351			case fantasy.FinishReasonStop:
 352				finishReason = message.FinishReasonEndTurn
 353			case fantasy.FinishReasonToolCalls:
 354				finishReason = message.FinishReasonToolUse
 355			}
 356			currentAssistant.AddFinish(finishReason, "", "")
 357			sessionLock.Lock()
 358			updatedSession, getSessionErr := a.sessions.Get(genCtx, call.SessionID)
 359			if getSessionErr != nil {
 360				sessionLock.Unlock()
 361				return getSessionErr
 362			}
 363			a.updateSessionUsage(a.largeModel, &updatedSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
 364			_, sessionErr := a.sessions.Save(genCtx, updatedSession)
 365			sessionLock.Unlock()
 366			if sessionErr != nil {
 367				return sessionErr
 368			}
 369			return a.messages.Update(genCtx, *currentAssistant)
 370		},
 371		StopWhen: []fantasy.StopCondition{
 372			func(_ []fantasy.StepResult) bool {
 373				cw := int64(a.largeModel.CatwalkCfg.ContextWindow)
 374				tokens := currentSession.CompletionTokens + currentSession.PromptTokens
 375				remaining := cw - tokens
 376				var threshold int64
 377				if cw > 200_000 {
 378					threshold = 20_000
 379				} else {
 380					threshold = int64(float64(cw) * 0.2)
 381				}
 382				if (remaining <= threshold) && !a.disableAutoSummarize {
 383					shouldSummarize = true
 384					return true
 385				}
 386				return false
 387			},
 388		},
 389	})
 390
 391	a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))
 392
 393	if err != nil {
 394		isCancelErr := errors.Is(err, context.Canceled)
 395		isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
 396		if currentAssistant == nil {
 397			return result, err
 398		}
 399		// Ensure we finish thinking on error to close the reasoning state.
 400		currentAssistant.FinishThinking()
 401		toolCalls := currentAssistant.ToolCalls()
 402		// INFO: we use the parent context here because the genCtx has been cancelled.
 403		msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
 404		if createErr != nil {
 405			return nil, createErr
 406		}
 407		for _, tc := range toolCalls {
 408			if !tc.Finished {
 409				tc.Finished = true
 410				tc.Input = "{}"
 411				currentAssistant.AddToolCall(tc)
 412				updateErr := a.messages.Update(ctx, *currentAssistant)
 413				if updateErr != nil {
 414					return nil, updateErr
 415				}
 416			}
 417
 418			found := false
 419			for _, msg := range msgs {
 420				if msg.Role == message.Tool {
 421					for _, tr := range msg.ToolResults() {
 422						if tr.ToolCallID == tc.ID {
 423							found = true
 424							break
 425						}
 426					}
 427				}
 428				if found {
 429					break
 430				}
 431			}
 432			if found {
 433				continue
 434			}
 435			content := "There was an error while executing the tool"
 436			if isCancelErr {
 437				content = "Tool execution canceled by user"
 438			} else if isPermissionErr {
 439				content = "User denied permission"
 440			}
 441			toolResult := message.ToolResult{
 442				ToolCallID: tc.ID,
 443				Name:       tc.Name,
 444				Content:    content,
 445				IsError:    true,
 446			}
 447			_, createErr = a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
 448				Role: message.Tool,
 449				Parts: []message.ContentPart{
 450					toolResult,
 451				},
 452			})
 453			if createErr != nil {
 454				return nil, createErr
 455			}
 456		}
 457		var fantasyErr *fantasy.Error
 458		var providerErr *fantasy.ProviderError
 459		const defaultTitle = "Provider Error"
 460		if isCancelErr {
 461			currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
 462		} else if isPermissionErr {
 463			currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "")
 464		} else if errors.Is(err, hyper.ErrNoCredits) {
 465			url := hyper.BaseURL()
 466			link := lipgloss.NewStyle().Hyperlink(url, "id=hyper").Render(url)
 467			currentAssistant.AddFinish(message.FinishReasonError, "No credits", "You're out of credits. Add more at "+link)
 468		} else if errors.As(err, &providerErr) {
 469			if providerErr.Message == "The requested model is not supported." {
 470				url := "https://github.com/settings/copilot/features"
 471				link := lipgloss.NewStyle().Hyperlink(url, "id=hyper").Render(url)
 472				currentAssistant.AddFinish(
 473					message.FinishReasonError,
 474					"Copilot model not enabled",
 475					fmt.Sprintf("%q is not enabled in Copilot. Go to the following page to enable it. Then, wait 5 minutes before trying again. %s", a.largeModel.CatwalkCfg.Name, link),
 476				)
 477			} else {
 478				currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message)
 479			}
 480		} else if errors.As(err, &fantasyErr) {
 481			currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(fantasyErr.Title), defaultTitle), fantasyErr.Message)
 482		} else {
 483			currentAssistant.AddFinish(message.FinishReasonError, defaultTitle, err.Error())
 484		}
 485		// Note: we use the parent context here because the genCtx has been
 486		// cancelled.
 487		updateErr := a.messages.Update(ctx, *currentAssistant)
 488		if updateErr != nil {
 489			return nil, updateErr
 490		}
 491		return nil, err
 492	}
 493	wg.Wait()
 494
 495	if shouldSummarize {
 496		a.activeRequests.Del(call.SessionID)
 497		if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
 498			return nil, summarizeErr
 499		}
 500		// If the agent wasn't done...
 501		if len(currentAssistant.ToolCalls()) > 0 {
 502			existing, ok := a.messageQueue.Get(call.SessionID)
 503			if !ok {
 504				existing = []SessionAgentCall{}
 505			}
 506			call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
 507			existing = append(existing, call)
 508			a.messageQueue.Set(call.SessionID, existing)
 509		}
 510	}
 511
 512	// Release active request before processing queued messages.
 513	a.activeRequests.Del(call.SessionID)
 514	cancel()
 515
 516	queuedMessages, ok := a.messageQueue.Get(call.SessionID)
 517	if !ok || len(queuedMessages) == 0 {
 518		return result, err
 519	}
 520	// There are queued messages restart the loop.
 521	firstQueuedMessage := queuedMessages[0]
 522	a.messageQueue.Set(call.SessionID, queuedMessages[1:])
 523	return a.Run(ctx, firstQueuedMessage)
 524}
 525
 526func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
 527	if a.IsSessionBusy(sessionID) {
 528		return ErrSessionBusy
 529	}
 530
 531	currentSession, err := a.sessions.Get(ctx, sessionID)
 532	if err != nil {
 533		return fmt.Errorf("failed to get session: %w", err)
 534	}
 535	msgs, err := a.getSessionMessages(ctx, currentSession)
 536	if err != nil {
 537		return err
 538	}
 539	if len(msgs) == 0 {
 540		// Nothing to summarize.
 541		return nil
 542	}
 543
 544	aiMsgs, _ := a.preparePrompt(msgs)
 545
 546	genCtx, cancel := context.WithCancel(ctx)
 547	a.activeRequests.Set(sessionID, cancel)
 548	defer a.activeRequests.Del(sessionID)
 549	defer cancel()
 550
 551	agent := fantasy.NewAgent(a.largeModel.Model,
 552		fantasy.WithSystemPrompt(string(summaryPrompt)),
 553	)
 554	summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
 555		Role:             message.Assistant,
 556		Model:            a.largeModel.Model.Model(),
 557		Provider:         a.largeModel.Model.Provider(),
 558		IsSummaryMessage: true,
 559	})
 560	if err != nil {
 561		return err
 562	}
 563
 564	summaryPromptText := "Provide a detailed summary of our conversation above."
 565	if len(currentSession.Todos) > 0 {
 566		summaryPromptText += "\n\n## Current Todo List\n\n"
 567		for _, t := range currentSession.Todos {
 568			summaryPromptText += fmt.Sprintf("- [%s] %s\n", t.Status, t.Content)
 569		}
 570		summaryPromptText += "\nInclude these tasks and their statuses in your summary. "
 571		summaryPromptText += "Instruct the resuming assistant to use the `todos` tool to continue tracking progress on these tasks."
 572	}
 573
 574	resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
 575		Prompt:          summaryPromptText,
 576		Messages:        aiMsgs,
 577		ProviderOptions: opts,
 578		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
 579			prepared.Messages = options.Messages
 580			if a.systemPromptPrefix != "" {
 581				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
 582			}
 583			return callContext, prepared, nil
 584		},
 585		OnReasoningDelta: func(id string, text string) error {
 586			summaryMessage.AppendReasoningContent(text)
 587			return a.messages.Update(genCtx, summaryMessage)
 588		},
 589		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
 590			// Handle anthropic signature.
 591			if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
 592				if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
 593					summaryMessage.AppendReasoningSignature(signature.Signature)
 594				}
 595			}
 596			summaryMessage.FinishThinking()
 597			return a.messages.Update(genCtx, summaryMessage)
 598		},
 599		OnTextDelta: func(id, text string) error {
 600			summaryMessage.AppendContent(text)
 601			return a.messages.Update(genCtx, summaryMessage)
 602		},
 603	})
 604	if err != nil {
 605		isCancelErr := errors.Is(err, context.Canceled)
 606		if isCancelErr {
 607			// User cancelled summarize we need to remove the summary message.
 608			deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
 609			return deleteErr
 610		}
 611		return err
 612	}
 613
 614	summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
 615	err = a.messages.Update(genCtx, summaryMessage)
 616	if err != nil {
 617		return err
 618	}
 619
 620	var openrouterCost *float64
 621	for _, step := range resp.Steps {
 622		stepCost := a.openrouterCost(step.ProviderMetadata)
 623		if stepCost != nil {
 624			newCost := *stepCost
 625			if openrouterCost != nil {
 626				newCost += *openrouterCost
 627			}
 628			openrouterCost = &newCost
 629		}
 630	}
 631
 632	a.updateSessionUsage(a.largeModel, &currentSession, resp.TotalUsage, openrouterCost)
 633
 634	// Just in case, get just the last usage info.
 635	usage := resp.Response.Usage
 636	currentSession.SummaryMessageID = summaryMessage.ID
 637	currentSession.CompletionTokens = usage.OutputTokens
 638	currentSession.PromptTokens = 0
 639	_, err = a.sessions.Save(genCtx, currentSession)
 640	return err
 641}
 642
 643func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
 644	if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
 645		return fantasy.ProviderOptions{}
 646	}
 647	return fantasy.ProviderOptions{
 648		anthropic.Name: &anthropic.ProviderCacheControlOptions{
 649			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
 650		},
 651		bedrock.Name: &anthropic.ProviderCacheControlOptions{
 652			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
 653		},
 654	}
 655}
 656
 657func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
 658	parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
 659	var attachmentParts []message.ContentPart
 660	for _, attachment := range call.Attachments {
 661		attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
 662	}
 663	parts = append(parts, attachmentParts...)
 664	msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
 665		Role:  message.User,
 666		Parts: parts,
 667	})
 668	if err != nil {
 669		return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
 670	}
 671	return msg, nil
 672}
 673
 674func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
 675	var history []fantasy.Message
 676	if !a.isSubAgent {
 677		history = append(history, fantasy.NewUserMessage(
 678			fmt.Sprintf("<system_reminder>%s</system_reminder>",
 679				`This is a reminder that your todo list is currently empty. DO NOT mention this to the user explicitly because they are already aware.
 680If you are working on tasks that would benefit from a todo list please use the "todos" tool to create one.
 681If not, please feel free to ignore. Again do not mention this message to the user.`,
 682			),
 683		))
 684	}
 685	for _, m := range msgs {
 686		if len(m.Parts) == 0 {
 687			continue
 688		}
 689		// Assistant message without content or tool calls (cancelled before it
 690		// returned anything).
 691		if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
 692			continue
 693		}
 694		history = append(history, m.ToAIMessage()...)
 695	}
 696
 697	var files []fantasy.FilePart
 698	for _, attachment := range attachments {
 699		if attachment.IsText() {
 700			continue
 701		}
 702		files = append(files, fantasy.FilePart{
 703			Filename:  attachment.FileName,
 704			Data:      attachment.Content,
 705			MediaType: attachment.MimeType,
 706		})
 707	}
 708
 709	return history, files
 710}
 711
 712func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
 713	msgs, err := a.messages.List(ctx, session.ID)
 714	if err != nil {
 715		return nil, fmt.Errorf("failed to list messages: %w", err)
 716	}
 717
 718	if session.SummaryMessageID != "" {
 719		summaryMsgInex := -1
 720		for i, msg := range msgs {
 721			if msg.ID == session.SummaryMessageID {
 722				summaryMsgInex = i
 723				break
 724			}
 725		}
 726		if summaryMsgInex != -1 {
 727			msgs = msgs[summaryMsgInex:]
 728			msgs[0].Role = message.User
 729		}
 730	}
 731	return msgs, nil
 732}
 733
 734// generateTitle generates a session titled based on the initial prompt.
 735func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, userPrompt string) {
 736	if userPrompt == "" {
 737		return
 738	}
 739
 740	var maxOutputTokens int64 = 40
 741	if a.smallModel.CatwalkCfg.CanReason {
 742		maxOutputTokens = a.smallModel.CatwalkCfg.DefaultMaxTokens
 743	}
 744
 745	newAgent := func(m fantasy.LanguageModel, p []byte, tok int64) fantasy.Agent {
 746		return fantasy.NewAgent(m,
 747			fantasy.WithSystemPrompt(string(p)+"\n /no_think"),
 748			fantasy.WithMaxOutputTokens(tok),
 749		)
 750	}
 751
 752	streamCall := fantasy.AgentStreamCall{
 753		Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", userPrompt),
 754		PrepareStep: func(callCtx context.Context, opts fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
 755			prepared.Messages = opts.Messages
 756			if a.systemPromptPrefix != "" {
 757				prepared.Messages = append([]fantasy.Message{
 758					fantasy.NewSystemMessage(a.systemPromptPrefix),
 759				}, prepared.Messages...)
 760			}
 761			return callCtx, prepared, nil
 762		},
 763	}
 764
 765	// Use the small model to generate the title.
 766	model := &a.smallModel
 767	agent := newAgent(model.Model, titlePrompt, maxOutputTokens)
 768	resp, err := agent.Stream(ctx, streamCall)
 769	if err == nil {
 770		// We successfully generated a title with the small model.
 771		slog.Info("generated title with small model")
 772	} else {
 773		// It didn't work. Let's try with the big model.
 774		slog.Error("error generating title with small model; trying big model", "err", err)
 775		model = &a.largeModel
 776		agent = newAgent(model.Model, titlePrompt, maxOutputTokens)
 777		resp, err = agent.Stream(ctx, streamCall)
 778		if err == nil {
 779			slog.Info("generated title with large model")
 780		} else {
 781			// Welp, the large model didn't work either. Use the default
 782			// session name and return.
 783			slog.Error("error generating title with large model", "err", err)
 784			saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, defaultSessionName, 0, 0, 0)
 785			if saveErr != nil {
 786				slog.Error("failed to save session title and usage", "error", saveErr)
 787			}
 788			return
 789		}
 790	}
 791
 792	if resp == nil {
 793		// Actually, we didn't get a response so we can't. Use the default
 794		// session name and return.
 795		slog.Error("response is nil; can't generate title")
 796		saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, defaultSessionName, 0, 0, 0)
 797		if saveErr != nil {
 798			slog.Error("failed to save session title and usage", "error", saveErr)
 799		}
 800		return
 801	}
 802
 803	// Clean up title.
 804	var title string
 805	title = strings.ReplaceAll(resp.Response.Content.Text(), "\n", " ")
 806	slog.Info("generated title", "title", title)
 807
 808	// Remove thinking tags if present.
 809	title = thinkTagRegex.ReplaceAllString(title, "")
 810
 811	title = strings.TrimSpace(title)
 812	if title == "" {
 813		slog.Warn("empty title; using fallback")
 814		title = defaultSessionName
 815	}
 816
 817	// Calculate usage and cost.
 818	var openrouterCost *float64
 819	for _, step := range resp.Steps {
 820		stepCost := a.openrouterCost(step.ProviderMetadata)
 821		if stepCost != nil {
 822			newCost := *stepCost
 823			if openrouterCost != nil {
 824				newCost += *openrouterCost
 825			}
 826			openrouterCost = &newCost
 827		}
 828	}
 829
 830	modelConfig := model.CatwalkCfg
 831	cost := modelConfig.CostPer1MInCached/1e6*float64(resp.TotalUsage.CacheCreationTokens) +
 832		modelConfig.CostPer1MOutCached/1e6*float64(resp.TotalUsage.CacheReadTokens) +
 833		modelConfig.CostPer1MIn/1e6*float64(resp.TotalUsage.InputTokens) +
 834		modelConfig.CostPer1MOut/1e6*float64(resp.TotalUsage.OutputTokens)
 835
 836	// Use override cost if available (e.g., from OpenRouter).
 837	if openrouterCost != nil {
 838		cost = *openrouterCost
 839	}
 840
 841	promptTokens := resp.TotalUsage.InputTokens + resp.TotalUsage.CacheCreationTokens
 842	completionTokens := resp.TotalUsage.OutputTokens + resp.TotalUsage.CacheReadTokens
 843
 844	// Atomically update only title and usage fields to avoid overriding other
 845	// concurrent session updates.
 846	saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, title, promptTokens, completionTokens, cost)
 847	if saveErr != nil {
 848		slog.Error("failed to save session title and usage", "error", saveErr)
 849		return
 850	}
 851}
 852
 853func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
 854	openrouterMetadata, ok := metadata[openrouter.Name]
 855	if !ok {
 856		return nil
 857	}
 858
 859	opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
 860	if !ok {
 861		return nil
 862	}
 863	return &opts.Usage.Cost
 864}
 865
 866func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
 867	modelConfig := model.CatwalkCfg
 868	cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
 869		modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
 870		modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
 871		modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
 872
 873	a.eventTokensUsed(session.ID, model, usage, cost)
 874
 875	if overrideCost != nil {
 876		session.Cost += *overrideCost
 877	} else {
 878		session.Cost += cost
 879	}
 880
 881	session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
 882	session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
 883}
 884
 885func (a *sessionAgent) Cancel(sessionID string) {
 886	// Cancel regular requests.
 887	if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
 888		slog.Info("Request cancellation initiated", "session_id", sessionID)
 889		cancel()
 890	}
 891
 892	// Also check for summarize requests.
 893	if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
 894		slog.Info("Summarize cancellation initiated", "session_id", sessionID)
 895		cancel()
 896	}
 897
 898	if a.QueuedPrompts(sessionID) > 0 {
 899		slog.Info("Clearing queued prompts", "session_id", sessionID)
 900		a.messageQueue.Del(sessionID)
 901	}
 902}
 903
 904func (a *sessionAgent) ClearQueue(sessionID string) {
 905	if a.QueuedPrompts(sessionID) > 0 {
 906		slog.Info("Clearing queued prompts", "session_id", sessionID)
 907		a.messageQueue.Del(sessionID)
 908	}
 909}
 910
 911func (a *sessionAgent) CancelAll() {
 912	if !a.IsBusy() {
 913		return
 914	}
 915	for key := range a.activeRequests.Seq2() {
 916		a.Cancel(key) // key is sessionID
 917	}
 918
 919	timeout := time.After(5 * time.Second)
 920	for a.IsBusy() {
 921		select {
 922		case <-timeout:
 923			return
 924		default:
 925			time.Sleep(200 * time.Millisecond)
 926		}
 927	}
 928}
 929
 930func (a *sessionAgent) IsBusy() bool {
 931	var busy bool
 932	for cancelFunc := range a.activeRequests.Seq() {
 933		if cancelFunc != nil {
 934			busy = true
 935			break
 936		}
 937	}
 938	return busy
 939}
 940
 941func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
 942	_, busy := a.activeRequests.Get(sessionID)
 943	return busy
 944}
 945
 946func (a *sessionAgent) QueuedPrompts(sessionID string) int {
 947	l, ok := a.messageQueue.Get(sessionID)
 948	if !ok {
 949		return 0
 950	}
 951	return len(l)
 952}
 953
 954func (a *sessionAgent) QueuedPromptsList(sessionID string) []string {
 955	l, ok := a.messageQueue.Get(sessionID)
 956	if !ok {
 957		return nil
 958	}
 959	prompts := make([]string, len(l))
 960	for i, call := range l {
 961		prompts[i] = call.Prompt
 962	}
 963	return prompts
 964}
 965
 966func (a *sessionAgent) SetModels(large Model, small Model) {
 967	a.largeModel = large
 968	a.smallModel = small
 969}
 970
 971func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
 972	a.tools = tools
 973}
 974
 975func (a *sessionAgent) Model() Model {
 976	return a.largeModel
 977}
 978
 979func (a *sessionAgent) promptPrefix() string {
 980	return a.systemPromptPrefix
 981}
 982
 983// convertToToolResult converts a fantasy tool result to a message tool result.
 984func (a *sessionAgent) convertToToolResult(result fantasy.ToolResultContent) message.ToolResult {
 985	baseResult := message.ToolResult{
 986		ToolCallID: result.ToolCallID,
 987		Name:       result.ToolName,
 988		Metadata:   result.ClientMetadata,
 989	}
 990
 991	switch result.Result.GetType() {
 992	case fantasy.ToolResultContentTypeText:
 993		if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result); ok {
 994			baseResult.Content = r.Text
 995		}
 996	case fantasy.ToolResultContentTypeError:
 997		if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result); ok {
 998			baseResult.Content = r.Error.Error()
 999			baseResult.IsError = true
1000		}
1001	case fantasy.ToolResultContentTypeMedia:
1002		if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](result.Result); ok {
1003			content := r.Text
1004			if content == "" {
1005				content = fmt.Sprintf("Loaded %s content", r.MediaType)
1006			}
1007			baseResult.Content = content
1008			baseResult.Data = r.Data
1009			baseResult.MIMEType = r.MediaType
1010		}
1011	}
1012
1013	return baseResult
1014}
1015
1016// workaroundProviderMediaLimitations converts media content in tool results to
1017// user messages for providers that don't natively support images in tool results.
1018//
1019// Problem: OpenAI, Google, OpenRouter, and other OpenAI-compatible providers
1020// don't support sending images/media in tool result messages - they only accept
1021// text in tool results. However, they DO support images in user messages.
1022//
1023// If we send media in tool results to these providers, the API returns an error.
1024//
1025// Solution: For these providers, we:
1026//  1. Replace the media in the tool result with a text placeholder
1027//  2. Inject a user message immediately after with the image as a file attachment
1028//  3. This maintains the tool execution flow while working around API limitations
1029//
1030// Anthropic and Bedrock support images natively in tool results, so we skip
1031// this workaround for them.
1032//
1033// Example transformation:
1034//
1035//	BEFORE: [tool result: image data]
1036//	AFTER:  [tool result: "Image loaded - see attached"], [user: image attachment]
1037func (a *sessionAgent) workaroundProviderMediaLimitations(messages []fantasy.Message) []fantasy.Message {
1038	providerSupportsMedia := a.largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderAnthropic) ||
1039		a.largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderBedrock)
1040
1041	if providerSupportsMedia {
1042		return messages
1043	}
1044
1045	convertedMessages := make([]fantasy.Message, 0, len(messages))
1046
1047	for _, msg := range messages {
1048		if msg.Role != fantasy.MessageRoleTool {
1049			convertedMessages = append(convertedMessages, msg)
1050			continue
1051		}
1052
1053		textParts := make([]fantasy.MessagePart, 0, len(msg.Content))
1054		var mediaFiles []fantasy.FilePart
1055
1056		for _, part := range msg.Content {
1057			toolResult, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
1058			if !ok {
1059				textParts = append(textParts, part)
1060				continue
1061			}
1062
1063			if media, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](toolResult.Output); ok {
1064				decoded, err := base64.StdEncoding.DecodeString(media.Data)
1065				if err != nil {
1066					slog.Warn("failed to decode media data", "error", err)
1067					textParts = append(textParts, part)
1068					continue
1069				}
1070
1071				mediaFiles = append(mediaFiles, fantasy.FilePart{
1072					Data:      decoded,
1073					MediaType: media.MediaType,
1074					Filename:  fmt.Sprintf("tool-result-%s", toolResult.ToolCallID),
1075				})
1076
1077				textParts = append(textParts, fantasy.ToolResultPart{
1078					ToolCallID: toolResult.ToolCallID,
1079					Output: fantasy.ToolResultOutputContentText{
1080						Text: "[Image/media content loaded - see attached file]",
1081					},
1082					ProviderOptions: toolResult.ProviderOptions,
1083				})
1084			} else {
1085				textParts = append(textParts, part)
1086			}
1087		}
1088
1089		convertedMessages = append(convertedMessages, fantasy.Message{
1090			Role:    fantasy.MessageRoleTool,
1091			Content: textParts,
1092		})
1093
1094		if len(mediaFiles) > 0 {
1095			convertedMessages = append(convertedMessages, fantasy.NewUserMessage(
1096				"Here is the media content from the tool result:",
1097				mediaFiles...,
1098			))
1099		}
1100	}
1101
1102	return convertedMessages
1103}