agent.go

   1// Package agent is the core orchestration layer for Crush AI agents.
   2//
   3// It provides session-based AI agent functionality for managing
   4// conversations, tool execution, and message handling. It coordinates
   5// interactions between language models, messages, sessions, and tools while
   6// handling features like automatic summarization, queuing, and token
   7// management.
   8package agent
   9
  10import (
  11	"cmp"
  12	"context"
  13	_ "embed"
  14	"encoding/base64"
  15	"errors"
  16	"fmt"
  17	"log/slog"
  18	"os"
  19	"regexp"
  20	"strconv"
  21	"strings"
  22	"sync"
  23	"time"
  24
  25	"charm.land/fantasy"
  26	"charm.land/fantasy/providers/anthropic"
  27	"charm.land/fantasy/providers/bedrock"
  28	"charm.land/fantasy/providers/google"
  29	"charm.land/fantasy/providers/openai"
  30	"charm.land/fantasy/providers/openrouter"
  31	"charm.land/lipgloss/v2"
  32	"github.com/charmbracelet/catwalk/pkg/catwalk"
  33	"github.com/charmbracelet/crush/internal/agent/hyper"
  34	"github.com/charmbracelet/crush/internal/agent/tools"
  35	"github.com/charmbracelet/crush/internal/config"
  36	"github.com/charmbracelet/crush/internal/csync"
  37	"github.com/charmbracelet/crush/internal/message"
  38	"github.com/charmbracelet/crush/internal/permission"
  39	"github.com/charmbracelet/crush/internal/session"
  40	"github.com/charmbracelet/crush/internal/stringext"
  41)
  42
// defaultSessionName is the title a session receives when no title could be
// generated for it (see generateTitle).
const defaultSessionName = "Untitled Session"

// titlePrompt is the system prompt used by generateTitle, embedded from
// templates/title.md at build time.
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt is the system prompt used by Summarize, embedded from
// templates/summary.md at build time.
//
//go:embed templates/summary.md
var summaryPrompt []byte

// Used to remove <think> tags from generated titles. Note that `.` does not
// match newlines in this pattern; generateTitle replaces newlines with spaces
// before applying it, so multi-line think blocks are still removed.
var thinkTagRegex = regexp.MustCompile(`<think>.*?</think>`)
  53
// SessionAgentCall is a single prompt submitted to a SessionAgent.
type SessionAgentCall struct {
	// SessionID identifies the target session. Run rejects an empty value
	// with ErrSessionMissing.
	SessionID        string
	// Prompt is the user's text. Run accepts an empty Prompt only when
	// Attachments contains a text attachment.
	Prompt           string
	// ProviderOptions are forwarded to the stream call, and to Summarize
	// when auto-summarization triggers.
	ProviderOptions  fantasy.ProviderOptions
	// Attachments are stored on the user message; non-text attachments are
	// also sent to the model as file parts.
	Attachments      []message.Attachment
	// The remaining fields are forwarded as-is to the model stream call.
	MaxOutputTokens  int64
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
}
  66
  67type SessionAgent interface {
  68	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
  69	SetModels(large Model, small Model)
  70	SetTools(tools []fantasy.AgentTool)
  71	Cancel(sessionID string)
  72	CancelAll()
  73	IsSessionBusy(sessionID string) bool
  74	IsBusy() bool
  75	QueuedPrompts(sessionID string) int
  76	QueuedPromptsList(sessionID string) []string
  77	ClearQueue(sessionID string)
  78	Summarize(context.Context, string, fantasy.ProviderOptions) error
  79	Model() Model
  80}
  81
// Model bundles a language model client with the catalog metadata and the
// selected-model configuration it was created from.
type Model struct {
	// Model is the client used to run requests.
	Model      fantasy.LanguageModel
	// CatwalkCfg holds catalog metadata such as context window, pricing,
	// reasoning and image support.
	CatwalkCfg catwalk.Model
	// ModelCfg is the selected model configuration (model and provider IDs).
	ModelCfg   config.SelectedModel
}
  87
// sessionAgent is the default SessionAgent implementation.
type sessionAgent struct {
	// largeModel runs Run and Summarize and is the fallback for titles.
	largeModel           Model
	// smallModel is tried first for title generation.
	smallModel           Model
	// systemPromptPrefix is prepended as a system message by Summarize and
	// generateTitle.
	systemPromptPrefix   string
	// systemPrompt is the system prompt of the agent built by Run.
	systemPrompt         string
	// isSubAgent suppresses the todo-list reminder in preparePrompt.
	isSubAgent           bool
	// tools are offered to the model in Run; the last one is given
	// cache-control provider options.
	tools                []fantasy.AgentTool
	sessions             session.Service
	messages             message.Service
	// disableAutoSummarize turns off Run's auto-summarize stop condition.
	disableAutoSummarize bool
	isYolo               bool

	// messageQueue holds calls submitted while their session was busy,
	// keyed by session ID.
	messageQueue   *csync.Map[string, []SessionAgentCall]
	// activeRequests maps a session ID to the cancel func of its in-flight
	// request.
	activeRequests *csync.Map[string, context.CancelFunc]
}
 103
// SessionAgentOptions configures NewSessionAgent. Each field initializes the
// sessionAgent field of the same name.
type SessionAgentOptions struct {
	LargeModel           Model
	SmallModel           Model
	SystemPromptPrefix   string
	SystemPrompt         string
	IsSubAgent           bool
	DisableAutoSummarize bool
	IsYolo               bool
	Sessions             session.Service
	Messages             message.Service
	Tools                []fantasy.AgentTool
}
 116
 117func NewSessionAgent(
 118	opts SessionAgentOptions,
 119) SessionAgent {
 120	return &sessionAgent{
 121		largeModel:           opts.LargeModel,
 122		smallModel:           opts.SmallModel,
 123		systemPromptPrefix:   opts.SystemPromptPrefix,
 124		systemPrompt:         opts.SystemPrompt,
 125		isSubAgent:           opts.IsSubAgent,
 126		sessions:             opts.Sessions,
 127		messages:             opts.Messages,
 128		disableAutoSummarize: opts.DisableAutoSummarize,
 129		tools:                opts.Tools,
 130		isYolo:               opts.IsYolo,
 131		messageQueue:         csync.NewMap[string, []SessionAgentCall](),
 132		activeRequests:       csync.NewMap[string, context.CancelFunc](),
 133	}
 134}
 135
 136func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
 137	if call.Prompt == "" && !message.ContainsTextAttachment(call.Attachments) {
 138		return nil, ErrEmptyPrompt
 139	}
 140	if call.SessionID == "" {
 141		return nil, ErrSessionMissing
 142	}
 143
 144	// Queue the message if busy
 145	if a.IsSessionBusy(call.SessionID) {
 146		existing, ok := a.messageQueue.Get(call.SessionID)
 147		if !ok {
 148			existing = []SessionAgentCall{}
 149		}
 150		existing = append(existing, call)
 151		a.messageQueue.Set(call.SessionID, existing)
 152		return nil, nil
 153	}
 154
 155	if len(a.tools) > 0 {
 156		// Add Anthropic caching to the last tool.
 157		a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
 158	}
 159
 160	agent := fantasy.NewAgent(
 161		a.largeModel.Model,
 162		fantasy.WithSystemPrompt(a.systemPrompt),
 163		fantasy.WithTools(a.tools...),
 164	)
 165
 166	sessionLock := sync.Mutex{}
 167	currentSession, err := a.sessions.Get(ctx, call.SessionID)
 168	if err != nil {
 169		return nil, fmt.Errorf("failed to get session: %w", err)
 170	}
 171
 172	msgs, err := a.getSessionMessages(ctx, currentSession)
 173	if err != nil {
 174		return nil, fmt.Errorf("failed to get session messages: %w", err)
 175	}
 176
 177	var wg sync.WaitGroup
 178	// Generate title if first message.
 179	if len(msgs) == 0 {
 180		titleCtx := ctx // Copy to avoid race with ctx reassignment below.
 181		wg.Go(func() {
 182			a.generateTitle(titleCtx, call.SessionID, call.Prompt)
 183		})
 184	}
 185
 186	// Add the user message to the session.
 187	_, err = a.createUserMessage(ctx, call)
 188	if err != nil {
 189		return nil, err
 190	}
 191
 192	// Add the session to the context.
 193	ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)
 194
 195	genCtx, cancel := context.WithCancel(ctx)
 196	a.activeRequests.Set(call.SessionID, cancel)
 197
 198	defer cancel()
 199	defer a.activeRequests.Del(call.SessionID)
 200
 201	history, files := a.preparePrompt(msgs, call.Attachments...)
 202
 203	startTime := time.Now()
 204	a.eventPromptSent(call.SessionID)
 205
 206	var currentAssistant *message.Message
 207	var shouldSummarize bool
 208	result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
 209		Prompt:           message.PromptWithTextAttachments(call.Prompt, call.Attachments),
 210		Files:            files,
 211		Messages:         history,
 212		ProviderOptions:  call.ProviderOptions,
 213		MaxOutputTokens:  &call.MaxOutputTokens,
 214		TopP:             call.TopP,
 215		Temperature:      call.Temperature,
 216		PresencePenalty:  call.PresencePenalty,
 217		TopK:             call.TopK,
 218		FrequencyPenalty: call.FrequencyPenalty,
 219		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
 220			prepared.Messages = options.Messages
 221			for i := range prepared.Messages {
 222				prepared.Messages[i].ProviderOptions = nil
 223			}
 224
 225			queuedCalls, _ := a.messageQueue.Get(call.SessionID)
 226			a.messageQueue.Del(call.SessionID)
 227			for _, queued := range queuedCalls {
 228				userMessage, createErr := a.createUserMessage(callContext, queued)
 229				if createErr != nil {
 230					return callContext, prepared, createErr
 231				}
 232				prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
 233			}
 234
 235			prepared.Messages = a.workaroundProviderMediaLimitations(prepared.Messages)
 236
 237			lastSystemRoleInx := 0
 238			systemMessageUpdated := false
 239			for i, msg := range prepared.Messages {
 240				// Only add cache control to the last message.
 241				if msg.Role == fantasy.MessageRoleSystem {
 242					lastSystemRoleInx = i
 243				} else if !systemMessageUpdated {
 244					prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
 245					systemMessageUpdated = true
 246				}
 247				// Than add cache control to the last 2 messages.
 248				if i > len(prepared.Messages)-3 {
 249					prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
 250				}
 251			}
 252
 253			if promptPrefix := a.promptPrefix(); promptPrefix != "" {
 254				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(promptPrefix)}, prepared.Messages...)
 255			}
 256
 257			var assistantMsg message.Message
 258			assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
 259				Role:     message.Assistant,
 260				Parts:    []message.ContentPart{},
 261				Model:    a.largeModel.ModelCfg.Model,
 262				Provider: a.largeModel.ModelCfg.Provider,
 263			})
 264			if err != nil {
 265				return callContext, prepared, err
 266			}
 267			callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
 268			callContext = context.WithValue(callContext, tools.SupportsImagesContextKey, a.largeModel.CatwalkCfg.SupportsImages)
 269			callContext = context.WithValue(callContext, tools.ModelNameContextKey, a.largeModel.CatwalkCfg.Name)
 270			currentAssistant = &assistantMsg
 271			return callContext, prepared, err
 272		},
 273		OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
 274			currentAssistant.AppendReasoningContent(reasoning.Text)
 275			return a.messages.Update(genCtx, *currentAssistant)
 276		},
 277		OnReasoningDelta: func(id string, text string) error {
 278			currentAssistant.AppendReasoningContent(text)
 279			return a.messages.Update(genCtx, *currentAssistant)
 280		},
 281		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
 282			// handle anthropic signature
 283			if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
 284				if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
 285					currentAssistant.AppendReasoningSignature(reasoning.Signature)
 286				}
 287			}
 288			if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
 289				if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
 290					currentAssistant.AppendThoughtSignature(reasoning.Signature, reasoning.ToolID)
 291				}
 292			}
 293			if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
 294				if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
 295					currentAssistant.SetReasoningResponsesData(reasoning)
 296				}
 297			}
 298			currentAssistant.FinishThinking()
 299			return a.messages.Update(genCtx, *currentAssistant)
 300		},
 301		OnTextDelta: func(id string, text string) error {
 302			// Strip leading newline from initial text content. This is is
 303			// particularly important in non-interactive mode where leading
 304			// newlines are very visible.
 305			if len(currentAssistant.Parts) == 0 {
 306				text = strings.TrimPrefix(text, "\n")
 307			}
 308
 309			currentAssistant.AppendContent(text)
 310			return a.messages.Update(genCtx, *currentAssistant)
 311		},
 312		OnToolInputStart: func(id string, toolName string) error {
 313			toolCall := message.ToolCall{
 314				ID:               id,
 315				Name:             toolName,
 316				ProviderExecuted: false,
 317				Finished:         false,
 318			}
 319			currentAssistant.AddToolCall(toolCall)
 320			return a.messages.Update(genCtx, *currentAssistant)
 321		},
 322		OnRetry: func(err *fantasy.ProviderError, delay time.Duration) {
 323			// TODO: implement
 324		},
 325		OnToolCall: func(tc fantasy.ToolCallContent) error {
 326			toolCall := message.ToolCall{
 327				ID:               tc.ToolCallID,
 328				Name:             tc.ToolName,
 329				Input:            tc.Input,
 330				ProviderExecuted: false,
 331				Finished:         true,
 332			}
 333			currentAssistant.AddToolCall(toolCall)
 334			return a.messages.Update(genCtx, *currentAssistant)
 335		},
 336		OnToolResult: func(result fantasy.ToolResultContent) error {
 337			toolResult := a.convertToToolResult(result)
 338			_, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
 339				Role: message.Tool,
 340				Parts: []message.ContentPart{
 341					toolResult,
 342				},
 343			})
 344			return createMsgErr
 345		},
 346		OnStepFinish: func(stepResult fantasy.StepResult) error {
 347			finishReason := message.FinishReasonUnknown
 348			switch stepResult.FinishReason {
 349			case fantasy.FinishReasonLength:
 350				finishReason = message.FinishReasonMaxTokens
 351			case fantasy.FinishReasonStop:
 352				finishReason = message.FinishReasonEndTurn
 353			case fantasy.FinishReasonToolCalls:
 354				finishReason = message.FinishReasonToolUse
 355			}
 356			currentAssistant.AddFinish(finishReason, "", "")
 357			sessionLock.Lock()
 358			updatedSession, getSessionErr := a.sessions.Get(genCtx, call.SessionID)
 359			if getSessionErr != nil {
 360				sessionLock.Unlock()
 361				return getSessionErr
 362			}
 363			a.updateSessionUsage(a.largeModel, &updatedSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
 364			_, sessionErr := a.sessions.Save(genCtx, updatedSession)
 365			sessionLock.Unlock()
 366			if sessionErr != nil {
 367				return sessionErr
 368			}
 369			return a.messages.Update(genCtx, *currentAssistant)
 370		},
 371		StopWhen: []fantasy.StopCondition{
 372			func(_ []fantasy.StepResult) bool {
 373				cw := int64(a.largeModel.CatwalkCfg.ContextWindow)
 374				tokens := currentSession.CompletionTokens + currentSession.PromptTokens
 375				remaining := cw - tokens
 376				var threshold int64
 377				if cw > 200_000 {
 378					threshold = 20_000
 379				} else {
 380					threshold = int64(float64(cw) * 0.2)
 381				}
 382				if (remaining <= threshold) && !a.disableAutoSummarize {
 383					shouldSummarize = true
 384					return true
 385				}
 386				return false
 387			},
 388		},
 389	})
 390
 391	a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))
 392
 393	if err != nil {
 394		isCancelErr := errors.Is(err, context.Canceled)
 395		isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
 396		if currentAssistant == nil {
 397			return result, err
 398		}
 399		// Ensure we finish thinking on error to close the reasoning state.
 400		currentAssistant.FinishThinking()
 401		toolCalls := currentAssistant.ToolCalls()
 402		// INFO: we use the parent context here because the genCtx has been cancelled.
 403		msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
 404		if createErr != nil {
 405			return nil, createErr
 406		}
 407		for _, tc := range toolCalls {
 408			if !tc.Finished {
 409				tc.Finished = true
 410				tc.Input = "{}"
 411				currentAssistant.AddToolCall(tc)
 412				updateErr := a.messages.Update(ctx, *currentAssistant)
 413				if updateErr != nil {
 414					return nil, updateErr
 415				}
 416			}
 417
 418			found := false
 419			for _, msg := range msgs {
 420				if msg.Role == message.Tool {
 421					for _, tr := range msg.ToolResults() {
 422						if tr.ToolCallID == tc.ID {
 423							found = true
 424							break
 425						}
 426					}
 427				}
 428				if found {
 429					break
 430				}
 431			}
 432			if found {
 433				continue
 434			}
 435			content := "There was an error while executing the tool"
 436			if isCancelErr {
 437				content = "Tool execution canceled by user"
 438			} else if isPermissionErr {
 439				content = "User denied permission"
 440			}
 441			toolResult := message.ToolResult{
 442				ToolCallID: tc.ID,
 443				Name:       tc.Name,
 444				Content:    content,
 445				IsError:    true,
 446			}
 447			_, createErr = a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
 448				Role: message.Tool,
 449				Parts: []message.ContentPart{
 450					toolResult,
 451				},
 452			})
 453			if createErr != nil {
 454				return nil, createErr
 455			}
 456		}
 457		var fantasyErr *fantasy.Error
 458		var providerErr *fantasy.ProviderError
 459		const defaultTitle = "Provider Error"
 460		if isCancelErr {
 461			currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
 462		} else if isPermissionErr {
 463			currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "")
 464		} else if errors.Is(err, hyper.ErrNoCredits) {
 465			url := hyper.BaseURL()
 466			link := lipgloss.NewStyle().Hyperlink(url, "id=hyper").Render(url)
 467			currentAssistant.AddFinish(message.FinishReasonError, "No credits", "You're out of credits. Add more at "+link)
 468		} else if errors.As(err, &providerErr) {
 469			if providerErr.Message == "The requested model is not supported." {
 470				url := "https://github.com/settings/copilot/features"
 471				link := lipgloss.NewStyle().Hyperlink(url, "id=hyper").Render(url)
 472				currentAssistant.AddFinish(
 473					message.FinishReasonError,
 474					"Copilot model not enabled",
 475					fmt.Sprintf("%q is not enabled in Copilot. Go to the following page to enable it. Then, wait 5 minutes before trying again. %s", a.largeModel.CatwalkCfg.Name, link),
 476				)
 477			} else {
 478				currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message)
 479			}
 480		} else if errors.As(err, &fantasyErr) {
 481			currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(fantasyErr.Title), defaultTitle), fantasyErr.Message)
 482		} else {
 483			currentAssistant.AddFinish(message.FinishReasonError, defaultTitle, err.Error())
 484		}
 485		// Note: we use the parent context here because the genCtx has been
 486		// cancelled.
 487		updateErr := a.messages.Update(ctx, *currentAssistant)
 488		if updateErr != nil {
 489			return nil, updateErr
 490		}
 491		return nil, err
 492	}
 493	wg.Wait()
 494
 495	if shouldSummarize {
 496		a.activeRequests.Del(call.SessionID)
 497		if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
 498			return nil, summarizeErr
 499		}
 500		// If the agent wasn't done...
 501		if len(currentAssistant.ToolCalls()) > 0 {
 502			existing, ok := a.messageQueue.Get(call.SessionID)
 503			if !ok {
 504				existing = []SessionAgentCall{}
 505			}
 506			call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
 507			existing = append(existing, call)
 508			a.messageQueue.Set(call.SessionID, existing)
 509		}
 510	}
 511
 512	// Release active request before processing queued messages.
 513	a.activeRequests.Del(call.SessionID)
 514	cancel()
 515
 516	queuedMessages, ok := a.messageQueue.Get(call.SessionID)
 517	if !ok || len(queuedMessages) == 0 {
 518		return result, err
 519	}
 520	// There are queued messages restart the loop.
 521	firstQueuedMessage := queuedMessages[0]
 522	a.messageQueue.Set(call.SessionID, queuedMessages[1:])
 523	return a.Run(ctx, firstQueuedMessage)
 524}
 525
 526func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
 527	if a.IsSessionBusy(sessionID) {
 528		return ErrSessionBusy
 529	}
 530
 531	currentSession, err := a.sessions.Get(ctx, sessionID)
 532	if err != nil {
 533		return fmt.Errorf("failed to get session: %w", err)
 534	}
 535	msgs, err := a.getSessionMessages(ctx, currentSession)
 536	if err != nil {
 537		return err
 538	}
 539	if len(msgs) == 0 {
 540		// Nothing to summarize.
 541		return nil
 542	}
 543
 544	aiMsgs, _ := a.preparePrompt(msgs)
 545
 546	genCtx, cancel := context.WithCancel(ctx)
 547	a.activeRequests.Set(sessionID, cancel)
 548	defer a.activeRequests.Del(sessionID)
 549	defer cancel()
 550
 551	agent := fantasy.NewAgent(a.largeModel.Model,
 552		fantasy.WithSystemPrompt(string(summaryPrompt)),
 553	)
 554	summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
 555		Role:             message.Assistant,
 556		Model:            a.largeModel.Model.Model(),
 557		Provider:         a.largeModel.Model.Provider(),
 558		IsSummaryMessage: true,
 559	})
 560	if err != nil {
 561		return err
 562	}
 563
 564	summaryPromptText := buildSummaryPrompt(currentSession.Todos)
 565
 566	resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
 567		Prompt:          summaryPromptText,
 568		Messages:        aiMsgs,
 569		ProviderOptions: opts,
 570		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
 571			prepared.Messages = options.Messages
 572			if a.systemPromptPrefix != "" {
 573				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
 574			}
 575			return callContext, prepared, nil
 576		},
 577		OnReasoningDelta: func(id string, text string) error {
 578			summaryMessage.AppendReasoningContent(text)
 579			return a.messages.Update(genCtx, summaryMessage)
 580		},
 581		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
 582			// Handle anthropic signature.
 583			if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
 584				if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
 585					summaryMessage.AppendReasoningSignature(signature.Signature)
 586				}
 587			}
 588			summaryMessage.FinishThinking()
 589			return a.messages.Update(genCtx, summaryMessage)
 590		},
 591		OnTextDelta: func(id, text string) error {
 592			summaryMessage.AppendContent(text)
 593			return a.messages.Update(genCtx, summaryMessage)
 594		},
 595	})
 596	if err != nil {
 597		isCancelErr := errors.Is(err, context.Canceled)
 598		if isCancelErr {
 599			// User cancelled summarize we need to remove the summary message.
 600			deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
 601			return deleteErr
 602		}
 603		return err
 604	}
 605
 606	summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
 607	err = a.messages.Update(genCtx, summaryMessage)
 608	if err != nil {
 609		return err
 610	}
 611
 612	var openrouterCost *float64
 613	for _, step := range resp.Steps {
 614		stepCost := a.openrouterCost(step.ProviderMetadata)
 615		if stepCost != nil {
 616			newCost := *stepCost
 617			if openrouterCost != nil {
 618				newCost += *openrouterCost
 619			}
 620			openrouterCost = &newCost
 621		}
 622	}
 623
 624	a.updateSessionUsage(a.largeModel, &currentSession, resp.TotalUsage, openrouterCost)
 625
 626	// Just in case, get just the last usage info.
 627	usage := resp.Response.Usage
 628	currentSession.SummaryMessageID = summaryMessage.ID
 629	currentSession.CompletionTokens = usage.OutputTokens
 630	currentSession.PromptTokens = 0
 631	_, err = a.sessions.Save(genCtx, currentSession)
 632	return err
 633}
 634
 635func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
 636	if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
 637		return fantasy.ProviderOptions{}
 638	}
 639	return fantasy.ProviderOptions{
 640		anthropic.Name: &anthropic.ProviderCacheControlOptions{
 641			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
 642		},
 643		bedrock.Name: &anthropic.ProviderCacheControlOptions{
 644			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
 645		},
 646	}
 647}
 648
 649func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
 650	parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
 651	var attachmentParts []message.ContentPart
 652	for _, attachment := range call.Attachments {
 653		attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
 654	}
 655	parts = append(parts, attachmentParts...)
 656	msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
 657		Role:  message.User,
 658		Parts: parts,
 659	})
 660	if err != nil {
 661		return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
 662	}
 663	return msg, nil
 664}
 665
 666func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
 667	var history []fantasy.Message
 668	if !a.isSubAgent {
 669		history = append(history, fantasy.NewUserMessage(
 670			fmt.Sprintf("<system_reminder>%s</system_reminder>",
 671				`This is a reminder that your todo list is currently empty. DO NOT mention this to the user explicitly because they are already aware.
 672If you are working on tasks that would benefit from a todo list please use the "todos" tool to create one.
 673If not, please feel free to ignore. Again do not mention this message to the user.`,
 674			),
 675		))
 676	}
 677	for _, m := range msgs {
 678		if len(m.Parts) == 0 {
 679			continue
 680		}
 681		// Assistant message without content or tool calls (cancelled before it
 682		// returned anything).
 683		if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
 684			continue
 685		}
 686		history = append(history, m.ToAIMessage()...)
 687	}
 688
 689	var files []fantasy.FilePart
 690	for _, attachment := range attachments {
 691		if attachment.IsText() {
 692			continue
 693		}
 694		files = append(files, fantasy.FilePart{
 695			Filename:  attachment.FileName,
 696			Data:      attachment.Content,
 697			MediaType: attachment.MimeType,
 698		})
 699	}
 700
 701	return history, files
 702}
 703
 704func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
 705	msgs, err := a.messages.List(ctx, session.ID)
 706	if err != nil {
 707		return nil, fmt.Errorf("failed to list messages: %w", err)
 708	}
 709
 710	if session.SummaryMessageID != "" {
 711		summaryMsgInex := -1
 712		for i, msg := range msgs {
 713			if msg.ID == session.SummaryMessageID {
 714				summaryMsgInex = i
 715				break
 716			}
 717		}
 718		if summaryMsgInex != -1 {
 719			msgs = msgs[summaryMsgInex:]
 720			msgs[0].Role = message.User
 721		}
 722	}
 723	return msgs, nil
 724}
 725
 726// generateTitle generates a session titled based on the initial prompt.
 727func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, userPrompt string) {
 728	if userPrompt == "" {
 729		return
 730	}
 731
 732	var maxOutputTokens int64 = 40
 733	if a.smallModel.CatwalkCfg.CanReason {
 734		maxOutputTokens = a.smallModel.CatwalkCfg.DefaultMaxTokens
 735	}
 736
 737	newAgent := func(m fantasy.LanguageModel, p []byte, tok int64) fantasy.Agent {
 738		return fantasy.NewAgent(m,
 739			fantasy.WithSystemPrompt(string(p)+"\n /no_think"),
 740			fantasy.WithMaxOutputTokens(tok),
 741		)
 742	}
 743
 744	streamCall := fantasy.AgentStreamCall{
 745		Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", userPrompt),
 746		PrepareStep: func(callCtx context.Context, opts fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
 747			prepared.Messages = opts.Messages
 748			if a.systemPromptPrefix != "" {
 749				prepared.Messages = append([]fantasy.Message{
 750					fantasy.NewSystemMessage(a.systemPromptPrefix),
 751				}, prepared.Messages...)
 752			}
 753			return callCtx, prepared, nil
 754		},
 755	}
 756
 757	// Use the small model to generate the title.
 758	model := &a.smallModel
 759	agent := newAgent(model.Model, titlePrompt, maxOutputTokens)
 760	resp, err := agent.Stream(ctx, streamCall)
 761	if err == nil {
 762		// We successfully generated a title with the small model.
 763		slog.Info("generated title with small model")
 764	} else {
 765		// It didn't work. Let's try with the big model.
 766		slog.Error("error generating title with small model; trying big model", "err", err)
 767		model = &a.largeModel
 768		agent = newAgent(model.Model, titlePrompt, maxOutputTokens)
 769		resp, err = agent.Stream(ctx, streamCall)
 770		if err == nil {
 771			slog.Info("generated title with large model")
 772		} else {
 773			// Welp, the large model didn't work either. Use the default
 774			// session name and return.
 775			slog.Error("error generating title with large model", "err", err)
 776			saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, defaultSessionName, 0, 0, 0)
 777			if saveErr != nil {
 778				slog.Error("failed to save session title and usage", "error", saveErr)
 779			}
 780			return
 781		}
 782	}
 783
 784	if resp == nil {
 785		// Actually, we didn't get a response so we can't. Use the default
 786		// session name and return.
 787		slog.Error("response is nil; can't generate title")
 788		saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, defaultSessionName, 0, 0, 0)
 789		if saveErr != nil {
 790			slog.Error("failed to save session title and usage", "error", saveErr)
 791		}
 792		return
 793	}
 794
 795	// Clean up title.
 796	var title string
 797	title = strings.ReplaceAll(resp.Response.Content.Text(), "\n", " ")
 798	slog.Info("generated title", "title", title)
 799
 800	// Remove thinking tags if present.
 801	title = thinkTagRegex.ReplaceAllString(title, "")
 802
 803	title = strings.TrimSpace(title)
 804	if title == "" {
 805		slog.Warn("empty title; using fallback")
 806		title = defaultSessionName
 807	}
 808
 809	// Calculate usage and cost.
 810	var openrouterCost *float64
 811	for _, step := range resp.Steps {
 812		stepCost := a.openrouterCost(step.ProviderMetadata)
 813		if stepCost != nil {
 814			newCost := *stepCost
 815			if openrouterCost != nil {
 816				newCost += *openrouterCost
 817			}
 818			openrouterCost = &newCost
 819		}
 820	}
 821
 822	modelConfig := model.CatwalkCfg
 823	cost := modelConfig.CostPer1MInCached/1e6*float64(resp.TotalUsage.CacheCreationTokens) +
 824		modelConfig.CostPer1MOutCached/1e6*float64(resp.TotalUsage.CacheReadTokens) +
 825		modelConfig.CostPer1MIn/1e6*float64(resp.TotalUsage.InputTokens) +
 826		modelConfig.CostPer1MOut/1e6*float64(resp.TotalUsage.OutputTokens)
 827
 828	// Use override cost if available (e.g., from OpenRouter).
 829	if openrouterCost != nil {
 830		cost = *openrouterCost
 831	}
 832
 833	promptTokens := resp.TotalUsage.InputTokens + resp.TotalUsage.CacheCreationTokens
 834	completionTokens := resp.TotalUsage.OutputTokens + resp.TotalUsage.CacheReadTokens
 835
 836	// Atomically update only title and usage fields to avoid overriding other
 837	// concurrent session updates.
 838	saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, title, promptTokens, completionTokens, cost)
 839	if saveErr != nil {
 840		slog.Error("failed to save session title and usage", "error", saveErr)
 841		return
 842	}
 843}
 844
 845func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
 846	openrouterMetadata, ok := metadata[openrouter.Name]
 847	if !ok {
 848		return nil
 849	}
 850
 851	opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
 852	if !ok {
 853		return nil
 854	}
 855	return &opts.Usage.Cost
 856}
 857
 858func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
 859	modelConfig := model.CatwalkCfg
 860	cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
 861		modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
 862		modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
 863		modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
 864
 865	a.eventTokensUsed(session.ID, model, usage, cost)
 866
 867	if overrideCost != nil {
 868		session.Cost += *overrideCost
 869	} else {
 870		session.Cost += cost
 871	}
 872
 873	session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
 874	session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
 875}
 876
 877func (a *sessionAgent) Cancel(sessionID string) {
 878	// Cancel regular requests.
 879	if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
 880		slog.Info("Request cancellation initiated", "session_id", sessionID)
 881		cancel()
 882	}
 883
 884	// Also check for summarize requests.
 885	if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
 886		slog.Info("Summarize cancellation initiated", "session_id", sessionID)
 887		cancel()
 888	}
 889
 890	if a.QueuedPrompts(sessionID) > 0 {
 891		slog.Info("Clearing queued prompts", "session_id", sessionID)
 892		a.messageQueue.Del(sessionID)
 893	}
 894}
 895
 896func (a *sessionAgent) ClearQueue(sessionID string) {
 897	if a.QueuedPrompts(sessionID) > 0 {
 898		slog.Info("Clearing queued prompts", "session_id", sessionID)
 899		a.messageQueue.Del(sessionID)
 900	}
 901}
 902
 903func (a *sessionAgent) CancelAll() {
 904	if !a.IsBusy() {
 905		return
 906	}
 907	for key := range a.activeRequests.Seq2() {
 908		a.Cancel(key) // key is sessionID
 909	}
 910
 911	timeout := time.After(5 * time.Second)
 912	for a.IsBusy() {
 913		select {
 914		case <-timeout:
 915			return
 916		default:
 917			time.Sleep(200 * time.Millisecond)
 918		}
 919	}
 920}
 921
 922func (a *sessionAgent) IsBusy() bool {
 923	var busy bool
 924	for cancelFunc := range a.activeRequests.Seq() {
 925		if cancelFunc != nil {
 926			busy = true
 927			break
 928		}
 929	}
 930	return busy
 931}
 932
 933func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
 934	_, busy := a.activeRequests.Get(sessionID)
 935	return busy
 936}
 937
 938func (a *sessionAgent) QueuedPrompts(sessionID string) int {
 939	l, ok := a.messageQueue.Get(sessionID)
 940	if !ok {
 941		return 0
 942	}
 943	return len(l)
 944}
 945
 946func (a *sessionAgent) QueuedPromptsList(sessionID string) []string {
 947	l, ok := a.messageQueue.Get(sessionID)
 948	if !ok {
 949		return nil
 950	}
 951	prompts := make([]string, len(l))
 952	for i, call := range l {
 953		prompts[i] = call.Prompt
 954	}
 955	return prompts
 956}
 957
 958func (a *sessionAgent) SetModels(large Model, small Model) {
 959	a.largeModel = large
 960	a.smallModel = small
 961}
 962
 963func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
 964	a.tools = tools
 965}
 966
 967func (a *sessionAgent) Model() Model {
 968	return a.largeModel
 969}
 970
 971func (a *sessionAgent) promptPrefix() string {
 972	return a.systemPromptPrefix
 973}
 974
 975// convertToToolResult converts a fantasy tool result to a message tool result.
 976func (a *sessionAgent) convertToToolResult(result fantasy.ToolResultContent) message.ToolResult {
 977	baseResult := message.ToolResult{
 978		ToolCallID: result.ToolCallID,
 979		Name:       result.ToolName,
 980		Metadata:   result.ClientMetadata,
 981	}
 982
 983	switch result.Result.GetType() {
 984	case fantasy.ToolResultContentTypeText:
 985		if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result); ok {
 986			baseResult.Content = r.Text
 987		}
 988	case fantasy.ToolResultContentTypeError:
 989		if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result); ok {
 990			baseResult.Content = r.Error.Error()
 991			baseResult.IsError = true
 992		}
 993	case fantasy.ToolResultContentTypeMedia:
 994		if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](result.Result); ok {
 995			content := r.Text
 996			if content == "" {
 997				content = fmt.Sprintf("Loaded %s content", r.MediaType)
 998			}
 999			baseResult.Content = content
1000			baseResult.Data = r.Data
1001			baseResult.MIMEType = r.MediaType
1002		}
1003	}
1004
1005	return baseResult
1006}
1007
1008// workaroundProviderMediaLimitations converts media content in tool results to
1009// user messages for providers that don't natively support images in tool results.
1010//
1011// Problem: OpenAI, Google, OpenRouter, and other OpenAI-compatible providers
1012// don't support sending images/media in tool result messages - they only accept
1013// text in tool results. However, they DO support images in user messages.
1014//
1015// If we send media in tool results to these providers, the API returns an error.
1016//
1017// Solution: For these providers, we:
1018//  1. Replace the media in the tool result with a text placeholder
1019//  2. Inject a user message immediately after with the image as a file attachment
1020//  3. This maintains the tool execution flow while working around API limitations
1021//
1022// Anthropic and Bedrock support images natively in tool results, so we skip
1023// this workaround for them.
1024//
1025// Example transformation:
1026//
1027//	BEFORE: [tool result: image data]
1028//	AFTER:  [tool result: "Image loaded - see attached"], [user: image attachment]
1029func (a *sessionAgent) workaroundProviderMediaLimitations(messages []fantasy.Message) []fantasy.Message {
1030	providerSupportsMedia := a.largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderAnthropic) ||
1031		a.largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderBedrock)
1032
1033	if providerSupportsMedia {
1034		return messages
1035	}
1036
1037	convertedMessages := make([]fantasy.Message, 0, len(messages))
1038
1039	for _, msg := range messages {
1040		if msg.Role != fantasy.MessageRoleTool {
1041			convertedMessages = append(convertedMessages, msg)
1042			continue
1043		}
1044
1045		textParts := make([]fantasy.MessagePart, 0, len(msg.Content))
1046		var mediaFiles []fantasy.FilePart
1047
1048		for _, part := range msg.Content {
1049			toolResult, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
1050			if !ok {
1051				textParts = append(textParts, part)
1052				continue
1053			}
1054
1055			if media, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](toolResult.Output); ok {
1056				decoded, err := base64.StdEncoding.DecodeString(media.Data)
1057				if err != nil {
1058					slog.Warn("failed to decode media data", "error", err)
1059					textParts = append(textParts, part)
1060					continue
1061				}
1062
1063				mediaFiles = append(mediaFiles, fantasy.FilePart{
1064					Data:      decoded,
1065					MediaType: media.MediaType,
1066					Filename:  fmt.Sprintf("tool-result-%s", toolResult.ToolCallID),
1067				})
1068
1069				textParts = append(textParts, fantasy.ToolResultPart{
1070					ToolCallID: toolResult.ToolCallID,
1071					Output: fantasy.ToolResultOutputContentText{
1072						Text: "[Image/media content loaded - see attached file]",
1073					},
1074					ProviderOptions: toolResult.ProviderOptions,
1075				})
1076			} else {
1077				textParts = append(textParts, part)
1078			}
1079		}
1080
1081		convertedMessages = append(convertedMessages, fantasy.Message{
1082			Role:    fantasy.MessageRoleTool,
1083			Content: textParts,
1084		})
1085
1086		if len(mediaFiles) > 0 {
1087			convertedMessages = append(convertedMessages, fantasy.NewUserMessage(
1088				"Here is the media content from the tool result:",
1089				mediaFiles...,
1090			))
1091		}
1092	}
1093
1094	return convertedMessages
1095}
1096
1097// buildSummaryPrompt constructs the prompt text for session summarization.
1098func buildSummaryPrompt(todos []session.Todo) string {
1099	var sb strings.Builder
1100	sb.WriteString("Provide a detailed summary of our conversation above.")
1101	if len(todos) > 0 {
1102		sb.WriteString("\n\n## Current Todo List\n\n")
1103		for _, t := range todos {
1104			fmt.Fprintf(&sb, "- [%s] %s\n", t.Status, t.Content)
1105		}
1106		sb.WriteString("\nInclude these tasks and their statuses in your summary. ")
1107		sb.WriteString("Instruct the resuming assistant to use the `todos` tool to continue tracking progress on these tasks.")
1108	}
1109	return sb.String()
1110}