agent.go

   1// Package agent is the core orchestration layer for Crush AI agents.
   2//
   3// It provides session-based AI agent functionality for managing
   4// conversations, tool execution, and message handling. It coordinates
   5// interactions between language models, messages, sessions, and tools while
   6// handling features like automatic summarization, queuing, and token
   7// management.
   8package agent
   9
  10import (
  11	"cmp"
  12	"context"
  13	_ "embed"
  14	"encoding/base64"
  15	"errors"
  16	"fmt"
  17	"log/slog"
  18	"os"
  19	"regexp"
  20	"strconv"
  21	"strings"
  22	"sync"
  23	"time"
  24
  25	"charm.land/fantasy"
  26	"charm.land/fantasy/providers/anthropic"
  27	"charm.land/fantasy/providers/bedrock"
  28	"charm.land/fantasy/providers/google"
  29	"charm.land/fantasy/providers/openai"
  30	"charm.land/fantasy/providers/openrouter"
  31	"charm.land/lipgloss/v2"
  32	"github.com/charmbracelet/catwalk/pkg/catwalk"
  33	"github.com/charmbracelet/crush/internal/agent/hyper"
  34	"github.com/charmbracelet/crush/internal/agent/tools"
  35	"github.com/charmbracelet/crush/internal/config"
  36	"github.com/charmbracelet/crush/internal/csync"
  37	"github.com/charmbracelet/crush/internal/message"
  38	"github.com/charmbracelet/crush/internal/permission"
  39	"github.com/charmbracelet/crush/internal/session"
  40	"github.com/charmbracelet/crush/internal/stringext"
  41)
  42
// defaultSessionName is the title stored for a session when title generation
// fails or produces an empty string.
const defaultSessionName = "Untitled Session"

// titlePrompt holds the embedded system prompt used when generating session
// titles.
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt holds the embedded system prompt used when summarizing a
// session.
//
//go:embed templates/summary.md
var summaryPrompt []byte

// thinkTagRegex matches a <think>...</think> span so it can be stripped from
// generated titles. The pattern has no (?s) flag, so "." does not match
// newlines; generateTitle replaces newlines with spaces before applying it.
var thinkTagRegex = regexp.MustCompile(`<think>.*?</think>`)
  53
// SessionAgentCall describes a single prompt submitted to a session agent.
type SessionAgentCall struct {
	// SessionID identifies the session the prompt belongs to. Run rejects an
	// empty value with ErrSessionMissing.
	SessionID string
	// Prompt is the user's text. Run rejects an empty prompt with
	// ErrEmptyPrompt unless Attachments contains a text attachment.
	Prompt string
	// ProviderOptions is forwarded unchanged to the model stream call (and to
	// Summarize when auto-summarization kicks in).
	ProviderOptions fantasy.ProviderOptions
	// Attachments are stored on the user message; non-text attachments are
	// additionally sent to the model as files.
	Attachments []message.Attachment
	// MaxOutputTokens is forwarded to the model stream call.
	MaxOutputTokens int64
	// The remaining fields are optional sampling parameters that are
	// forwarded to the model stream call as-is.
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
}
  66
// SessionAgent is the behavior exposed by a session-based agent. The
// implementation in this file is sessionAgent.
//
// NOTE(review): several of the methods below (SetModels, SetTools,
// SetSystemPrompt, IsSessionBusy, IsBusy, QueuedPrompts, QueuedPromptsList,
// Model) are implemented outside the part of the file reviewed here; their
// notes are inferred from how they are called and should be confirmed.
type SessionAgent interface {
	// Run sends the call to the model, or queues it when the session is
	// already busy (in which case it returns a nil result and nil error).
	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
	SetModels(large Model, small Model)
	SetTools(tools []fantasy.AgentTool)
	SetSystemPrompt(systemPrompt string)
	// Cancel cancels the in-flight request for the session and drops its
	// queued prompts.
	Cancel(sessionID string)
	// CancelAll cancels every active request and waits up to five seconds
	// for the agent to become idle.
	CancelAll()
	// IsSessionBusy is consulted by Run and Summarize to decide whether the
	// session already has work in flight.
	IsSessionBusy(sessionID string) bool
	// IsBusy is polled by CancelAll until it reports false.
	IsBusy() bool
	// QueuedPrompts is used by Cancel and ClearQueue as the count of prompts
	// waiting for the session.
	QueuedPrompts(sessionID string) int
	QueuedPromptsList(sessionID string) []string
	// ClearQueue drops every queued prompt for the session.
	ClearQueue(sessionID string)
	// Summarize replaces the session's history with a model-written summary.
	// It returns ErrSessionBusy when the session has a request in flight.
	Summarize(context.Context, string, fantasy.ProviderOptions) error
	Model() Model
}
  82
// Model bundles a language model with the metadata the agent needs about it.
type Model struct {
	// Model is the language model handed to fantasy.NewAgent.
	Model fantasy.LanguageModel
	// CatwalkCfg supplies the model's catalog data read by this file: context
	// window, per-token pricing, display name, image support and reasoning
	// capability.
	CatwalkCfg catwalk.Model
	// ModelCfg is the selected-model configuration; its Model and Provider
	// values are recorded on assistant messages created by Run.
	ModelCfg config.SelectedModel
}
  88
// sessionAgent is the SessionAgent implementation returned by
// NewSessionAgent.
type sessionAgent struct {
	// largeModel drives conversations and summaries; smallModel is tried
	// first for title generation, with largeModel as the fallback.
	largeModel Model
	smallModel Model
	// systemPromptPrefix, when non-empty, is prepended as an extra system
	// message for summary and title generation.
	systemPromptPrefix string
	// systemPrompt is the system prompt passed to the agent built in Run.
	systemPrompt string
	// isSubAgent suppresses the todo-list reminder that preparePrompt
	// otherwise places at the start of the history.
	isSubAgent bool
	// tools are the tools offered to the model in Run.
	tools    []fantasy.AgentTool
	sessions session.Service
	messages message.Service
	// disableAutoSummarize prevents Run's stop condition from triggering an
	// automatic summary when the context window is nearly full.
	disableAutoSummarize bool
	// isYolo is stored from the options; it is not read in the part of the
	// file reviewed here.
	isYolo bool

	// messageQueue holds, per session ID, the calls that arrived while the
	// session was busy.
	messageQueue *csync.Map[string, []SessionAgentCall]
	// activeRequests holds, per session ID, the cancel function of the
	// request currently in flight.
	activeRequests *csync.Map[string, context.CancelFunc]
}
 104
// SessionAgentOptions carries the dependencies and settings used by
// NewSessionAgent. Each field is copied onto the sessionAgent field of the
// same (lower-cased) name.
type SessionAgentOptions struct {
	LargeModel           Model
	SmallModel           Model
	SystemPromptPrefix   string
	SystemPrompt         string
	IsSubAgent           bool
	DisableAutoSummarize bool
	IsYolo               bool
	Sessions             session.Service
	Messages             message.Service
	Tools                []fantasy.AgentTool
}
 117
 118func NewSessionAgent(
 119	opts SessionAgentOptions,
 120) SessionAgent {
 121	return &sessionAgent{
 122		largeModel:           opts.LargeModel,
 123		smallModel:           opts.SmallModel,
 124		systemPromptPrefix:   opts.SystemPromptPrefix,
 125		systemPrompt:         opts.SystemPrompt,
 126		isSubAgent:           opts.IsSubAgent,
 127		sessions:             opts.Sessions,
 128		messages:             opts.Messages,
 129		disableAutoSummarize: opts.DisableAutoSummarize,
 130		tools:                opts.Tools,
 131		isYolo:               opts.IsYolo,
 132		messageQueue:         csync.NewMap[string, []SessionAgentCall](),
 133		activeRequests:       csync.NewMap[string, context.CancelFunc](),
 134	}
 135}
 136
 137func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
 138	if call.Prompt == "" && !message.ContainsTextAttachment(call.Attachments) {
 139		return nil, ErrEmptyPrompt
 140	}
 141	if call.SessionID == "" {
 142		return nil, ErrSessionMissing
 143	}
 144
 145	// Queue the message if busy
 146	if a.IsSessionBusy(call.SessionID) {
 147		existing, ok := a.messageQueue.Get(call.SessionID)
 148		if !ok {
 149			existing = []SessionAgentCall{}
 150		}
 151		existing = append(existing, call)
 152		a.messageQueue.Set(call.SessionID, existing)
 153		return nil, nil
 154	}
 155
 156	if len(a.tools) > 0 {
 157		// Add Anthropic caching to the last tool.
 158		a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
 159	}
 160
 161	agent := fantasy.NewAgent(
 162		a.largeModel.Model,
 163		fantasy.WithSystemPrompt(a.systemPrompt),
 164		fantasy.WithTools(a.tools...),
 165	)
 166
 167	sessionLock := sync.Mutex{}
 168	currentSession, err := a.sessions.Get(ctx, call.SessionID)
 169	if err != nil {
 170		return nil, fmt.Errorf("failed to get session: %w", err)
 171	}
 172
 173	msgs, err := a.getSessionMessages(ctx, currentSession)
 174	if err != nil {
 175		return nil, fmt.Errorf("failed to get session messages: %w", err)
 176	}
 177
 178	var wg sync.WaitGroup
 179	// Generate title if first message.
 180	if len(msgs) == 0 {
 181		titleCtx := ctx // Copy to avoid race with ctx reassignment below.
 182		wg.Go(func() {
 183			a.generateTitle(titleCtx, call.SessionID, call.Prompt)
 184		})
 185	}
 186
 187	// Add the user message to the session.
 188	_, err = a.createUserMessage(ctx, call)
 189	if err != nil {
 190		return nil, err
 191	}
 192
 193	// Add the session to the context.
 194	ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)
 195
 196	genCtx, cancel := context.WithCancel(ctx)
 197	a.activeRequests.Set(call.SessionID, cancel)
 198
 199	defer cancel()
 200	defer a.activeRequests.Del(call.SessionID)
 201
 202	history, files := a.preparePrompt(msgs, call.Attachments...)
 203
 204	startTime := time.Now()
 205	a.eventPromptSent(call.SessionID)
 206
 207	var currentAssistant *message.Message
 208	var shouldSummarize bool
 209	result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
 210		Prompt:           message.PromptWithTextAttachments(call.Prompt, call.Attachments),
 211		Files:            files,
 212		Messages:         history,
 213		ProviderOptions:  call.ProviderOptions,
 214		MaxOutputTokens:  &call.MaxOutputTokens,
 215		TopP:             call.TopP,
 216		Temperature:      call.Temperature,
 217		PresencePenalty:  call.PresencePenalty,
 218		TopK:             call.TopK,
 219		FrequencyPenalty: call.FrequencyPenalty,
 220		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
 221			prepared.Messages = options.Messages
 222			for i := range prepared.Messages {
 223				prepared.Messages[i].ProviderOptions = nil
 224			}
 225
 226			queuedCalls, _ := a.messageQueue.Get(call.SessionID)
 227			a.messageQueue.Del(call.SessionID)
 228			for _, queued := range queuedCalls {
 229				userMessage, createErr := a.createUserMessage(callContext, queued)
 230				if createErr != nil {
 231					return callContext, prepared, createErr
 232				}
 233				prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
 234			}
 235
 236			prepared.Messages = a.workaroundProviderMediaLimitations(prepared.Messages)
 237
 238			lastSystemRoleInx := 0
 239			systemMessageUpdated := false
 240			for i, msg := range prepared.Messages {
 241				// Only add cache control to the last message.
 242				if msg.Role == fantasy.MessageRoleSystem {
 243					lastSystemRoleInx = i
 244				} else if !systemMessageUpdated {
 245					prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
 246					systemMessageUpdated = true
 247				}
 248				// Than add cache control to the last 2 messages.
 249				if i > len(prepared.Messages)-3 {
 250					prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
 251				}
 252			}
 253
 254			if promptPrefix := a.promptPrefix(); promptPrefix != "" {
 255				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(promptPrefix)}, prepared.Messages...)
 256			}
 257
 258			var assistantMsg message.Message
 259			assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
 260				Role:     message.Assistant,
 261				Parts:    []message.ContentPart{},
 262				Model:    a.largeModel.ModelCfg.Model,
 263				Provider: a.largeModel.ModelCfg.Provider,
 264			})
 265			if err != nil {
 266				return callContext, prepared, err
 267			}
 268			callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
 269			callContext = context.WithValue(callContext, tools.SupportsImagesContextKey, a.largeModel.CatwalkCfg.SupportsImages)
 270			callContext = context.WithValue(callContext, tools.ModelNameContextKey, a.largeModel.CatwalkCfg.Name)
 271			currentAssistant = &assistantMsg
 272			return callContext, prepared, err
 273		},
 274		OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
 275			currentAssistant.AppendReasoningContent(reasoning.Text)
 276			return a.messages.Update(genCtx, *currentAssistant)
 277		},
 278		OnReasoningDelta: func(id string, text string) error {
 279			currentAssistant.AppendReasoningContent(text)
 280			return a.messages.Update(genCtx, *currentAssistant)
 281		},
 282		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
 283			// handle anthropic signature
 284			if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
 285				if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
 286					currentAssistant.AppendReasoningSignature(reasoning.Signature)
 287				}
 288			}
 289			if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
 290				if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
 291					currentAssistant.AppendThoughtSignature(reasoning.Signature, reasoning.ToolID)
 292				}
 293			}
 294			if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
 295				if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
 296					currentAssistant.SetReasoningResponsesData(reasoning)
 297				}
 298			}
 299			currentAssistant.FinishThinking()
 300			return a.messages.Update(genCtx, *currentAssistant)
 301		},
 302		OnTextDelta: func(id string, text string) error {
 303			// Strip leading newline from initial text content. This is is
 304			// particularly important in non-interactive mode where leading
 305			// newlines are very visible.
 306			if len(currentAssistant.Parts) == 0 {
 307				text = strings.TrimPrefix(text, "\n")
 308			}
 309
 310			currentAssistant.AppendContent(text)
 311			return a.messages.Update(genCtx, *currentAssistant)
 312		},
 313		OnToolInputStart: func(id string, toolName string) error {
 314			toolCall := message.ToolCall{
 315				ID:               id,
 316				Name:             toolName,
 317				ProviderExecuted: false,
 318				Finished:         false,
 319			}
 320			currentAssistant.AddToolCall(toolCall)
 321			return a.messages.Update(genCtx, *currentAssistant)
 322		},
 323		OnRetry: func(err *fantasy.ProviderError, delay time.Duration) {
 324			// TODO: implement
 325		},
 326		OnToolCall: func(tc fantasy.ToolCallContent) error {
 327			toolCall := message.ToolCall{
 328				ID:               tc.ToolCallID,
 329				Name:             tc.ToolName,
 330				Input:            tc.Input,
 331				ProviderExecuted: false,
 332				Finished:         true,
 333			}
 334			currentAssistant.AddToolCall(toolCall)
 335			return a.messages.Update(genCtx, *currentAssistant)
 336		},
 337		OnToolResult: func(result fantasy.ToolResultContent) error {
 338			toolResult := a.convertToToolResult(result)
 339			_, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
 340				Role: message.Tool,
 341				Parts: []message.ContentPart{
 342					toolResult,
 343				},
 344			})
 345			return createMsgErr
 346		},
 347		OnStepFinish: func(stepResult fantasy.StepResult) error {
 348			finishReason := message.FinishReasonUnknown
 349			switch stepResult.FinishReason {
 350			case fantasy.FinishReasonLength:
 351				finishReason = message.FinishReasonMaxTokens
 352			case fantasy.FinishReasonStop:
 353				finishReason = message.FinishReasonEndTurn
 354			case fantasy.FinishReasonToolCalls:
 355				finishReason = message.FinishReasonToolUse
 356			}
 357			currentAssistant.AddFinish(finishReason, "", "")
 358			sessionLock.Lock()
 359			updatedSession, getSessionErr := a.sessions.Get(genCtx, call.SessionID)
 360			if getSessionErr != nil {
 361				sessionLock.Unlock()
 362				return getSessionErr
 363			}
 364			a.updateSessionUsage(a.largeModel, &updatedSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
 365			_, sessionErr := a.sessions.Save(genCtx, updatedSession)
 366			sessionLock.Unlock()
 367			if sessionErr != nil {
 368				return sessionErr
 369			}
 370			return a.messages.Update(genCtx, *currentAssistant)
 371		},
 372		StopWhen: []fantasy.StopCondition{
 373			func(_ []fantasy.StepResult) bool {
 374				cw := int64(a.largeModel.CatwalkCfg.ContextWindow)
 375				tokens := currentSession.CompletionTokens + currentSession.PromptTokens
 376				remaining := cw - tokens
 377				var threshold int64
 378				if cw > 200_000 {
 379					threshold = 20_000
 380				} else {
 381					threshold = int64(float64(cw) * 0.2)
 382				}
 383				if (remaining <= threshold) && !a.disableAutoSummarize {
 384					shouldSummarize = true
 385					return true
 386				}
 387				return false
 388			},
 389		},
 390	})
 391
 392	a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))
 393
 394	if err != nil {
 395		isCancelErr := errors.Is(err, context.Canceled)
 396		isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
 397		if currentAssistant == nil {
 398			return result, err
 399		}
 400		// Ensure we finish thinking on error to close the reasoning state.
 401		currentAssistant.FinishThinking()
 402		toolCalls := currentAssistant.ToolCalls()
 403		// INFO: we use the parent context here because the genCtx has been cancelled.
 404		msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
 405		if createErr != nil {
 406			return nil, createErr
 407		}
 408		for _, tc := range toolCalls {
 409			if !tc.Finished {
 410				tc.Finished = true
 411				tc.Input = "{}"
 412				currentAssistant.AddToolCall(tc)
 413				updateErr := a.messages.Update(ctx, *currentAssistant)
 414				if updateErr != nil {
 415					return nil, updateErr
 416				}
 417			}
 418
 419			found := false
 420			for _, msg := range msgs {
 421				if msg.Role == message.Tool {
 422					for _, tr := range msg.ToolResults() {
 423						if tr.ToolCallID == tc.ID {
 424							found = true
 425							break
 426						}
 427					}
 428				}
 429				if found {
 430					break
 431				}
 432			}
 433			if found {
 434				continue
 435			}
 436			content := "There was an error while executing the tool"
 437			if isCancelErr {
 438				content = "Tool execution canceled by user"
 439			} else if isPermissionErr {
 440				content = "User denied permission"
 441			}
 442			toolResult := message.ToolResult{
 443				ToolCallID: tc.ID,
 444				Name:       tc.Name,
 445				Content:    content,
 446				IsError:    true,
 447			}
 448			_, createErr = a.messages.Create(ctx, currentAssistant.SessionID, message.CreateMessageParams{
 449				Role: message.Tool,
 450				Parts: []message.ContentPart{
 451					toolResult,
 452				},
 453			})
 454			if createErr != nil {
 455				return nil, createErr
 456			}
 457		}
 458		var fantasyErr *fantasy.Error
 459		var providerErr *fantasy.ProviderError
 460		const defaultTitle = "Provider Error"
 461		if isCancelErr {
 462			currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
 463		} else if isPermissionErr {
 464			currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "")
 465		} else if errors.Is(err, hyper.ErrNoCredits) {
 466			url := hyper.BaseURL()
 467			link := lipgloss.NewStyle().Hyperlink(url, "id=hyper").Render(url)
 468			currentAssistant.AddFinish(message.FinishReasonError, "No credits", "You're out of credits. Add more at "+link)
 469		} else if errors.As(err, &providerErr) {
 470			if providerErr.Message == "The requested model is not supported." {
 471				url := "https://github.com/settings/copilot/features"
 472				link := lipgloss.NewStyle().Hyperlink(url, "id=hyper").Render(url)
 473				currentAssistant.AddFinish(
 474					message.FinishReasonError,
 475					"Copilot model not enabled",
 476					fmt.Sprintf("%q is not enabled in Copilot. Go to the following page to enable it. Then, wait 5 minutes before trying again. %s", a.largeModel.CatwalkCfg.Name, link),
 477				)
 478			} else {
 479				currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message)
 480			}
 481		} else if errors.As(err, &fantasyErr) {
 482			currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(fantasyErr.Title), defaultTitle), fantasyErr.Message)
 483		} else {
 484			currentAssistant.AddFinish(message.FinishReasonError, defaultTitle, err.Error())
 485		}
 486		// Note: we use the parent context here because the genCtx has been
 487		// cancelled.
 488		updateErr := a.messages.Update(ctx, *currentAssistant)
 489		if updateErr != nil {
 490			return nil, updateErr
 491		}
 492		return nil, err
 493	}
 494	wg.Wait()
 495
 496	if shouldSummarize {
 497		a.activeRequests.Del(call.SessionID)
 498		if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
 499			return nil, summarizeErr
 500		}
 501		// If the agent wasn't done...
 502		if len(currentAssistant.ToolCalls()) > 0 {
 503			existing, ok := a.messageQueue.Get(call.SessionID)
 504			if !ok {
 505				existing = []SessionAgentCall{}
 506			}
 507			call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
 508			existing = append(existing, call)
 509			a.messageQueue.Set(call.SessionID, existing)
 510		}
 511	}
 512
 513	// Release active request before processing queued messages.
 514	a.activeRequests.Del(call.SessionID)
 515	cancel()
 516
 517	queuedMessages, ok := a.messageQueue.Get(call.SessionID)
 518	if !ok || len(queuedMessages) == 0 {
 519		return result, err
 520	}
 521	// There are queued messages restart the loop.
 522	firstQueuedMessage := queuedMessages[0]
 523	a.messageQueue.Set(call.SessionID, queuedMessages[1:])
 524	return a.Run(ctx, firstQueuedMessage)
 525}
 526
 527func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
 528	if a.IsSessionBusy(sessionID) {
 529		return ErrSessionBusy
 530	}
 531
 532	currentSession, err := a.sessions.Get(ctx, sessionID)
 533	if err != nil {
 534		return fmt.Errorf("failed to get session: %w", err)
 535	}
 536	msgs, err := a.getSessionMessages(ctx, currentSession)
 537	if err != nil {
 538		return err
 539	}
 540	if len(msgs) == 0 {
 541		// Nothing to summarize.
 542		return nil
 543	}
 544
 545	aiMsgs, _ := a.preparePrompt(msgs)
 546
 547	genCtx, cancel := context.WithCancel(ctx)
 548	a.activeRequests.Set(sessionID, cancel)
 549	defer a.activeRequests.Del(sessionID)
 550	defer cancel()
 551
 552	agent := fantasy.NewAgent(a.largeModel.Model,
 553		fantasy.WithSystemPrompt(string(summaryPrompt)),
 554	)
 555	summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
 556		Role:             message.Assistant,
 557		Model:            a.largeModel.Model.Model(),
 558		Provider:         a.largeModel.Model.Provider(),
 559		IsSummaryMessage: true,
 560	})
 561	if err != nil {
 562		return err
 563	}
 564
 565	summaryPromptText := buildSummaryPrompt(currentSession.Todos)
 566
 567	resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
 568		Prompt:          summaryPromptText,
 569		Messages:        aiMsgs,
 570		ProviderOptions: opts,
 571		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
 572			prepared.Messages = options.Messages
 573			if a.systemPromptPrefix != "" {
 574				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
 575			}
 576			return callContext, prepared, nil
 577		},
 578		OnReasoningDelta: func(id string, text string) error {
 579			summaryMessage.AppendReasoningContent(text)
 580			return a.messages.Update(genCtx, summaryMessage)
 581		},
 582		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
 583			// Handle anthropic signature.
 584			if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
 585				if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
 586					summaryMessage.AppendReasoningSignature(signature.Signature)
 587				}
 588			}
 589			summaryMessage.FinishThinking()
 590			return a.messages.Update(genCtx, summaryMessage)
 591		},
 592		OnTextDelta: func(id, text string) error {
 593			summaryMessage.AppendContent(text)
 594			return a.messages.Update(genCtx, summaryMessage)
 595		},
 596	})
 597	if err != nil {
 598		isCancelErr := errors.Is(err, context.Canceled)
 599		if isCancelErr {
 600			// User cancelled summarize we need to remove the summary message.
 601			deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
 602			return deleteErr
 603		}
 604		return err
 605	}
 606
 607	summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
 608	err = a.messages.Update(genCtx, summaryMessage)
 609	if err != nil {
 610		return err
 611	}
 612
 613	var openrouterCost *float64
 614	for _, step := range resp.Steps {
 615		stepCost := a.openrouterCost(step.ProviderMetadata)
 616		if stepCost != nil {
 617			newCost := *stepCost
 618			if openrouterCost != nil {
 619				newCost += *openrouterCost
 620			}
 621			openrouterCost = &newCost
 622		}
 623	}
 624
 625	a.updateSessionUsage(a.largeModel, &currentSession, resp.TotalUsage, openrouterCost)
 626
 627	// Just in case, get just the last usage info.
 628	usage := resp.Response.Usage
 629	currentSession.SummaryMessageID = summaryMessage.ID
 630	currentSession.CompletionTokens = usage.OutputTokens
 631	currentSession.PromptTokens = 0
 632	_, err = a.sessions.Save(genCtx, currentSession)
 633	return err
 634}
 635
 636func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
 637	if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
 638		return fantasy.ProviderOptions{}
 639	}
 640	return fantasy.ProviderOptions{
 641		anthropic.Name: &anthropic.ProviderCacheControlOptions{
 642			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
 643		},
 644		bedrock.Name: &anthropic.ProviderCacheControlOptions{
 645			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
 646		},
 647	}
 648}
 649
 650func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
 651	parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
 652	var attachmentParts []message.ContentPart
 653	for _, attachment := range call.Attachments {
 654		attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
 655	}
 656	parts = append(parts, attachmentParts...)
 657	msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
 658		Role:  message.User,
 659		Parts: parts,
 660	})
 661	if err != nil {
 662		return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
 663	}
 664	return msg, nil
 665}
 666
 667func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
 668	var history []fantasy.Message
 669	if !a.isSubAgent {
 670		history = append(history, fantasy.NewUserMessage(
 671			fmt.Sprintf("<system_reminder>%s</system_reminder>",
 672				`This is a reminder that your todo list is currently empty. DO NOT mention this to the user explicitly because they are already aware.
 673If you are working on tasks that would benefit from a todo list please use the "todos" tool to create one.
 674If not, please feel free to ignore. Again do not mention this message to the user.`,
 675			),
 676		))
 677	}
 678	for _, m := range msgs {
 679		if len(m.Parts) == 0 {
 680			continue
 681		}
 682		// Assistant message without content or tool calls (cancelled before it
 683		// returned anything).
 684		if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
 685			continue
 686		}
 687		history = append(history, m.ToAIMessage()...)
 688	}
 689
 690	var files []fantasy.FilePart
 691	for _, attachment := range attachments {
 692		if attachment.IsText() {
 693			continue
 694		}
 695		files = append(files, fantasy.FilePart{
 696			Filename:  attachment.FileName,
 697			Data:      attachment.Content,
 698			MediaType: attachment.MimeType,
 699		})
 700	}
 701
 702	return history, files
 703}
 704
 705func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
 706	msgs, err := a.messages.List(ctx, session.ID)
 707	if err != nil {
 708		return nil, fmt.Errorf("failed to list messages: %w", err)
 709	}
 710
 711	if session.SummaryMessageID != "" {
 712		summaryMsgInex := -1
 713		for i, msg := range msgs {
 714			if msg.ID == session.SummaryMessageID {
 715				summaryMsgInex = i
 716				break
 717			}
 718		}
 719		if summaryMsgInex != -1 {
 720			msgs = msgs[summaryMsgInex:]
 721			msgs[0].Role = message.User
 722		}
 723	}
 724	return msgs, nil
 725}
 726
 727// generateTitle generates a session titled based on the initial prompt.
 728func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, userPrompt string) {
 729	if userPrompt == "" {
 730		return
 731	}
 732
 733	var maxOutputTokens int64 = 40
 734	if a.smallModel.CatwalkCfg.CanReason {
 735		maxOutputTokens = a.smallModel.CatwalkCfg.DefaultMaxTokens
 736	}
 737
 738	newAgent := func(m fantasy.LanguageModel, p []byte, tok int64) fantasy.Agent {
 739		return fantasy.NewAgent(m,
 740			fantasy.WithSystemPrompt(string(p)+"\n /no_think"),
 741			fantasy.WithMaxOutputTokens(tok),
 742		)
 743	}
 744
 745	streamCall := fantasy.AgentStreamCall{
 746		Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", userPrompt),
 747		PrepareStep: func(callCtx context.Context, opts fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
 748			prepared.Messages = opts.Messages
 749			if a.systemPromptPrefix != "" {
 750				prepared.Messages = append([]fantasy.Message{
 751					fantasy.NewSystemMessage(a.systemPromptPrefix),
 752				}, prepared.Messages...)
 753			}
 754			return callCtx, prepared, nil
 755		},
 756	}
 757
 758	// Use the small model to generate the title.
 759	model := &a.smallModel
 760	agent := newAgent(model.Model, titlePrompt, maxOutputTokens)
 761	resp, err := agent.Stream(ctx, streamCall)
 762	if err == nil {
 763		// We successfully generated a title with the small model.
 764		slog.Info("generated title with small model")
 765	} else {
 766		// It didn't work. Let's try with the big model.
 767		slog.Error("error generating title with small model; trying big model", "err", err)
 768		model = &a.largeModel
 769		agent = newAgent(model.Model, titlePrompt, maxOutputTokens)
 770		resp, err = agent.Stream(ctx, streamCall)
 771		if err == nil {
 772			slog.Info("generated title with large model")
 773		} else {
 774			// Welp, the large model didn't work either. Use the default
 775			// session name and return.
 776			slog.Error("error generating title with large model", "err", err)
 777			saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, defaultSessionName, 0, 0, 0)
 778			if saveErr != nil {
 779				slog.Error("failed to save session title and usage", "error", saveErr)
 780			}
 781			return
 782		}
 783	}
 784
 785	if resp == nil {
 786		// Actually, we didn't get a response so we can't. Use the default
 787		// session name and return.
 788		slog.Error("response is nil; can't generate title")
 789		saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, defaultSessionName, 0, 0, 0)
 790		if saveErr != nil {
 791			slog.Error("failed to save session title and usage", "error", saveErr)
 792		}
 793		return
 794	}
 795
 796	// Clean up title.
 797	var title string
 798	title = strings.ReplaceAll(resp.Response.Content.Text(), "\n", " ")
 799	slog.Info("generated title", "title", title)
 800
 801	// Remove thinking tags if present.
 802	title = thinkTagRegex.ReplaceAllString(title, "")
 803
 804	title = strings.TrimSpace(title)
 805	if title == "" {
 806		slog.Warn("empty title; using fallback")
 807		title = defaultSessionName
 808	}
 809
 810	// Calculate usage and cost.
 811	var openrouterCost *float64
 812	for _, step := range resp.Steps {
 813		stepCost := a.openrouterCost(step.ProviderMetadata)
 814		if stepCost != nil {
 815			newCost := *stepCost
 816			if openrouterCost != nil {
 817				newCost += *openrouterCost
 818			}
 819			openrouterCost = &newCost
 820		}
 821	}
 822
 823	modelConfig := model.CatwalkCfg
 824	cost := modelConfig.CostPer1MInCached/1e6*float64(resp.TotalUsage.CacheCreationTokens) +
 825		modelConfig.CostPer1MOutCached/1e6*float64(resp.TotalUsage.CacheReadTokens) +
 826		modelConfig.CostPer1MIn/1e6*float64(resp.TotalUsage.InputTokens) +
 827		modelConfig.CostPer1MOut/1e6*float64(resp.TotalUsage.OutputTokens)
 828
 829	// Use override cost if available (e.g., from OpenRouter).
 830	if openrouterCost != nil {
 831		cost = *openrouterCost
 832	}
 833
 834	promptTokens := resp.TotalUsage.InputTokens + resp.TotalUsage.CacheCreationTokens
 835	completionTokens := resp.TotalUsage.OutputTokens + resp.TotalUsage.CacheReadTokens
 836
 837	// Atomically update only title and usage fields to avoid overriding other
 838	// concurrent session updates.
 839	saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, title, promptTokens, completionTokens, cost)
 840	if saveErr != nil {
 841		slog.Error("failed to save session title and usage", "error", saveErr)
 842		return
 843	}
 844}
 845
 846func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
 847	openrouterMetadata, ok := metadata[openrouter.Name]
 848	if !ok {
 849		return nil
 850	}
 851
 852	opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
 853	if !ok {
 854		return nil
 855	}
 856	return &opts.Usage.Cost
 857}
 858
 859func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
 860	modelConfig := model.CatwalkCfg
 861	cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
 862		modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
 863		modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
 864		modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
 865
 866	a.eventTokensUsed(session.ID, model, usage, cost)
 867
 868	if overrideCost != nil {
 869		session.Cost += *overrideCost
 870	} else {
 871		session.Cost += cost
 872	}
 873
 874	session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
 875	session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
 876}
 877
 878func (a *sessionAgent) Cancel(sessionID string) {
 879	// Cancel regular requests. Don't use Take() here - we need the entry to
 880	// remain in activeRequests so IsBusy() returns true until the goroutine
 881	// fully completes (including error handling that may access the DB).
 882	// The defer in processRequest will clean up the entry.
 883	if cancel, ok := a.activeRequests.Get(sessionID); ok && cancel != nil {
 884		slog.Info("Request cancellation initiated", "session_id", sessionID)
 885		cancel()
 886	}
 887
 888	// Also check for summarize requests.
 889	if cancel, ok := a.activeRequests.Get(sessionID + "-summarize"); ok && cancel != nil {
 890		slog.Info("Summarize cancellation initiated", "session_id", sessionID)
 891		cancel()
 892	}
 893
 894	if a.QueuedPrompts(sessionID) > 0 {
 895		slog.Info("Clearing queued prompts", "session_id", sessionID)
 896		a.messageQueue.Del(sessionID)
 897	}
 898}
 899
 900func (a *sessionAgent) ClearQueue(sessionID string) {
 901	if a.QueuedPrompts(sessionID) > 0 {
 902		slog.Info("Clearing queued prompts", "session_id", sessionID)
 903		a.messageQueue.Del(sessionID)
 904	}
 905}
 906
 907func (a *sessionAgent) CancelAll() {
 908	if !a.IsBusy() {
 909		return
 910	}
 911	for key := range a.activeRequests.Seq2() {
 912		a.Cancel(key) // key is sessionID
 913	}
 914
 915	timeout := time.After(5 * time.Second)
 916	for a.IsBusy() {
 917		select {
 918		case <-timeout:
 919			return
 920		default:
 921			time.Sleep(200 * time.Millisecond)
 922		}
 923	}
 924}
 925
 926func (a *sessionAgent) IsBusy() bool {
 927	var busy bool
 928	for cancelFunc := range a.activeRequests.Seq() {
 929		if cancelFunc != nil {
 930			busy = true
 931			break
 932		}
 933	}
 934	return busy
 935}
 936
 937func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
 938	_, busy := a.activeRequests.Get(sessionID)
 939	return busy
 940}
 941
 942func (a *sessionAgent) QueuedPrompts(sessionID string) int {
 943	l, ok := a.messageQueue.Get(sessionID)
 944	if !ok {
 945		return 0
 946	}
 947	return len(l)
 948}
 949
 950func (a *sessionAgent) QueuedPromptsList(sessionID string) []string {
 951	l, ok := a.messageQueue.Get(sessionID)
 952	if !ok {
 953		return nil
 954	}
 955	prompts := make([]string, len(l))
 956	for i, call := range l {
 957		prompts[i] = call.Prompt
 958	}
 959	return prompts
 960}
 961
 962func (a *sessionAgent) SetModels(large Model, small Model) {
 963	a.largeModel = large
 964	a.smallModel = small
 965}
 966
// SetTools replaces the agent's tool list. The slice is stored as given, not
// copied, so later modifications by the caller are visible to the agent.
func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
	a.tools = tools
}
 970
// SetSystemPrompt replaces the system prompt stored on the agent.
func (a *sessionAgent) SetSystemPrompt(systemPrompt string) {
	a.systemPrompt = systemPrompt
}
 974
// Model returns the agent's large model.
func (a *sessionAgent) Model() Model {
	return a.largeModel
}
 978
// promptPrefix returns the agent's configured system prompt prefix.
func (a *sessionAgent) promptPrefix() string {
	return a.systemPromptPrefix
}
 982
 983// convertToToolResult converts a fantasy tool result to a message tool result.
 984func (a *sessionAgent) convertToToolResult(result fantasy.ToolResultContent) message.ToolResult {
 985	baseResult := message.ToolResult{
 986		ToolCallID: result.ToolCallID,
 987		Name:       result.ToolName,
 988		Metadata:   result.ClientMetadata,
 989	}
 990
 991	switch result.Result.GetType() {
 992	case fantasy.ToolResultContentTypeText:
 993		if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result); ok {
 994			baseResult.Content = r.Text
 995		}
 996	case fantasy.ToolResultContentTypeError:
 997		if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result); ok {
 998			baseResult.Content = r.Error.Error()
 999			baseResult.IsError = true
1000		}
1001	case fantasy.ToolResultContentTypeMedia:
1002		if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](result.Result); ok {
1003			content := r.Text
1004			if content == "" {
1005				content = fmt.Sprintf("Loaded %s content", r.MediaType)
1006			}
1007			baseResult.Content = content
1008			baseResult.Data = r.Data
1009			baseResult.MIMEType = r.MediaType
1010		}
1011	}
1012
1013	return baseResult
1014}
1015
1016// workaroundProviderMediaLimitations converts media content in tool results to
1017// user messages for providers that don't natively support images in tool results.
1018//
1019// Problem: OpenAI, Google, OpenRouter, and other OpenAI-compatible providers
1020// don't support sending images/media in tool result messages - they only accept
1021// text in tool results. However, they DO support images in user messages.
1022//
1023// If we send media in tool results to these providers, the API returns an error.
1024//
1025// Solution: For these providers, we:
1026//  1. Replace the media in the tool result with a text placeholder
1027//  2. Inject a user message immediately after with the image as a file attachment
1028//  3. This maintains the tool execution flow while working around API limitations
1029//
1030// Anthropic and Bedrock support images natively in tool results, so we skip
1031// this workaround for them.
1032//
1033// Example transformation:
1034//
1035//	BEFORE: [tool result: image data]
1036//	AFTER:  [tool result: "Image loaded - see attached"], [user: image attachment]
1037func (a *sessionAgent) workaroundProviderMediaLimitations(messages []fantasy.Message) []fantasy.Message {
1038	providerSupportsMedia := a.largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderAnthropic) ||
1039		a.largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderBedrock)
1040
1041	if providerSupportsMedia {
1042		return messages
1043	}
1044
1045	convertedMessages := make([]fantasy.Message, 0, len(messages))
1046
1047	for _, msg := range messages {
1048		if msg.Role != fantasy.MessageRoleTool {
1049			convertedMessages = append(convertedMessages, msg)
1050			continue
1051		}
1052
1053		textParts := make([]fantasy.MessagePart, 0, len(msg.Content))
1054		var mediaFiles []fantasy.FilePart
1055
1056		for _, part := range msg.Content {
1057			toolResult, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
1058			if !ok {
1059				textParts = append(textParts, part)
1060				continue
1061			}
1062
1063			if media, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](toolResult.Output); ok {
1064				decoded, err := base64.StdEncoding.DecodeString(media.Data)
1065				if err != nil {
1066					slog.Warn("failed to decode media data", "error", err)
1067					textParts = append(textParts, part)
1068					continue
1069				}
1070
1071				mediaFiles = append(mediaFiles, fantasy.FilePart{
1072					Data:      decoded,
1073					MediaType: media.MediaType,
1074					Filename:  fmt.Sprintf("tool-result-%s", toolResult.ToolCallID),
1075				})
1076
1077				textParts = append(textParts, fantasy.ToolResultPart{
1078					ToolCallID: toolResult.ToolCallID,
1079					Output: fantasy.ToolResultOutputContentText{
1080						Text: "[Image/media content loaded - see attached file]",
1081					},
1082					ProviderOptions: toolResult.ProviderOptions,
1083				})
1084			} else {
1085				textParts = append(textParts, part)
1086			}
1087		}
1088
1089		convertedMessages = append(convertedMessages, fantasy.Message{
1090			Role:    fantasy.MessageRoleTool,
1091			Content: textParts,
1092		})
1093
1094		if len(mediaFiles) > 0 {
1095			convertedMessages = append(convertedMessages, fantasy.NewUserMessage(
1096				"Here is the media content from the tool result:",
1097				mediaFiles...,
1098			))
1099		}
1100	}
1101
1102	return convertedMessages
1103}
1104
1105// buildSummaryPrompt constructs the prompt text for session summarization.
1106func buildSummaryPrompt(todos []session.Todo) string {
1107	var sb strings.Builder
1108	sb.WriteString("Provide a detailed summary of our conversation above.")
1109	if len(todos) > 0 {
1110		sb.WriteString("\n\n## Current Todo List\n\n")
1111		for _, t := range todos {
1112			fmt.Fprintf(&sb, "- [%s] %s\n", t.Status, t.Content)
1113		}
1114		sb.WriteString("\nInclude these tasks and their statuses in your summary. ")
1115		sb.WriteString("Instruct the resuming assistant to use the `todos` tool to continue tracking progress on these tasks.")
1116	}
1117	return sb.String()
1118}