agent.go

   1// Package agent is the core orchestration layer for Crush AI agents.
   2//
   3// It provides session-based AI agent functionality for managing
   4// conversations, tool execution, and message handling. It coordinates
   5// interactions between language models, messages, sessions, and tools while
   6// handling features like automatic summarization, queuing, and token
   7// management.
   8package agent
   9
  10import (
  11	"cmp"
  12	"context"
  13	_ "embed"
  14	"encoding/base64"
  15	"errors"
  16	"fmt"
  17	"log/slog"
  18	"os"
  19	"regexp"
  20	"strconv"
  21	"strings"
  22	"sync"
  23	"time"
  24
  25	"charm.land/catwalk/pkg/catwalk"
  26	"charm.land/fantasy"
  27	"charm.land/fantasy/providers/anthropic"
  28	"charm.land/fantasy/providers/bedrock"
  29	"charm.land/fantasy/providers/google"
  30	"charm.land/fantasy/providers/openai"
  31	"charm.land/fantasy/providers/openrouter"
  32	"charm.land/fantasy/providers/vercel"
  33	"charm.land/lipgloss/v2"
  34	"github.com/charmbracelet/crush/internal/agent/hyper"
  35	"github.com/charmbracelet/crush/internal/agent/notify"
  36	"github.com/charmbracelet/crush/internal/agent/tools"
  37	"github.com/charmbracelet/crush/internal/agent/tools/mcp"
  38	"github.com/charmbracelet/crush/internal/config"
  39	"github.com/charmbracelet/crush/internal/csync"
  40	"github.com/charmbracelet/crush/internal/message"
  41	"github.com/charmbracelet/crush/internal/permission"
  42	"github.com/charmbracelet/crush/internal/pubsub"
  43	"github.com/charmbracelet/crush/internal/session"
  44	"github.com/charmbracelet/crush/internal/stringext"
  45	"github.com/charmbracelet/crush/internal/version"
  46	"github.com/charmbracelet/x/exp/charmtone"
  47)
  48
const (
	// DefaultSessionName is the title stored for a session when no title
	// could be generated for it.
	DefaultSessionName = "Untitled Session"

	// Constants for auto-summarization thresholds. A model whose context
	// window is larger than largeContextWindowThreshold tokens is summarized
	// once largeContextWindowBuffer or fewer tokens remain; any other model
	// is summarized once the remaining tokens drop to smallContextWindowRatio
	// of its context window.
	largeContextWindowThreshold = 200_000
	largeContextWindowBuffer    = 20_000
	smallContextWindowRatio     = 0.2
)
  57
// titlePrompt is the embedded system prompt used when generating a session
// title.
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt is the embedded system prompt used when summarizing a
// session.
//
//go:embed templates/summary.md
var summaryPrompt []byte
  63
  64// Used to remove <think> tags from generated titles.
  65var thinkTagRegex = regexp.MustCompile(`<think>.*?</think>`)
  66
// SessionAgentCall describes a single prompt to run against a session.
type SessionAgentCall struct {
	// SessionID identifies the session the prompt belongs to. Run rejects
	// calls where it is empty.
	SessionID string
	// Prompt is the user's text. Run rejects calls where it is empty unless
	// the attachments include a text attachment.
	Prompt string
	// ProviderOptions is forwarded to the model stream call.
	ProviderOptions fantasy.ProviderOptions
	// Attachments are stored with the user message; non-text attachments are
	// also sent to the model as files.
	Attachments []message.Attachment
	// The sampling parameters below are forwarded to the model stream call
	// as-is.
	MaxOutputTokens  int64
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
	// NonInteractive suppresses the "agent finished" notification that Run
	// otherwise publishes.
	NonInteractive bool
}
  80
// SessionAgent runs prompts against sessions and exposes controls for the
// models, tools, system prompt, per-session queue and cancellation.
type SessionAgent interface {
	// Run executes the call against its session. When the session is already
	// busy the call is queued and Run returns a nil result and a nil error.
	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
	SetModels(large Model, small Model)
	SetTools(tools []fantasy.AgentTool)
	SetSystemPrompt(systemPrompt string)
	Cancel(sessionID string)
	CancelAll()
	IsSessionBusy(sessionID string) bool
	IsBusy() bool
	QueuedPrompts(sessionID string) int
	QueuedPromptsList(sessionID string) []string
	ClearQueue(sessionID string)
	// Summarize generates a summary message for the given session ID using
	// the supplied provider options.
	Summarize(context.Context, string, fantasy.ProviderOptions) error
	Model() Model
}
  96
// Model bundles a language model with the metadata the agent needs about it.
type Model struct {
	// Model is the language model that requests are sent to.
	Model fantasy.LanguageModel
	// CatwalkCfg is the catalog entry for the model; the agent reads its
	// context window, display name, image support and pricing from it.
	CatwalkCfg catwalk.Model
	// ModelCfg is the selected-model configuration; its Model and Provider
	// values are recorded on assistant messages.
	ModelCfg config.SelectedModel
}
 102
// sessionAgent is the SessionAgent implementation returned by
// NewSessionAgent.
type sessionAgent struct {
	// Values that can be replaced while the agent is in use; they are read
	// through the csync wrappers.
	largeModel         *csync.Value[Model]
	smallModel         *csync.Value[Model]
	systemPromptPrefix *csync.Value[string]
	systemPrompt       *csync.Value[string]
	tools              *csync.Slice[fantasy.AgentTool]

	// Settings and services fixed at construction time.
	isSubAgent           bool
	sessions             session.Service
	messages             message.Service
	disableAutoSummarize bool
	isYolo               bool
	notify               pubsub.Publisher[notify.Notification]

	// messageQueue holds calls received while a session was busy, keyed by
	// session ID. activeRequests holds the cancel function of the request
	// currently running for each session ID.
	messageQueue   *csync.Map[string, []SessionAgentCall]
	activeRequests *csync.Map[string, context.CancelFunc]
}
 120
// SessionAgentOptions carries the initial configuration and dependencies
// passed to NewSessionAgent.
type SessionAgentOptions struct {
	// LargeModel handles prompts and summaries; SmallModel is tried first
	// for title generation.
	LargeModel Model
	SmallModel Model
	// SystemPromptPrefix, when non-empty, is sent as an extra system message
	// ahead of all other messages.
	SystemPromptPrefix string
	SystemPrompt       string
	// IsSubAgent disables the todo-list reminder that is otherwise prepended
	// to the history.
	IsSubAgent bool
	// DisableAutoSummarize turns off summarization when the context window
	// is nearly full.
	DisableAutoSummarize bool
	IsYolo               bool
	Sessions             session.Service
	Messages             message.Service
	Tools                []fantasy.AgentTool
	// Notify may be nil, in which case no notifications are published.
	Notify pubsub.Publisher[notify.Notification]
}
 134
 135func NewSessionAgent(
 136	opts SessionAgentOptions,
 137) SessionAgent {
 138	return &sessionAgent{
 139		largeModel:           csync.NewValue(opts.LargeModel),
 140		smallModel:           csync.NewValue(opts.SmallModel),
 141		systemPromptPrefix:   csync.NewValue(opts.SystemPromptPrefix),
 142		systemPrompt:         csync.NewValue(opts.SystemPrompt),
 143		isSubAgent:           opts.IsSubAgent,
 144		sessions:             opts.Sessions,
 145		messages:             opts.Messages,
 146		disableAutoSummarize: opts.DisableAutoSummarize,
 147		tools:                csync.NewSliceFrom(opts.Tools),
 148		isYolo:               opts.IsYolo,
 149		notify:               opts.Notify,
 150		messageQueue:         csync.NewMap[string, []SessionAgentCall](),
 151		activeRequests:       csync.NewMap[string, context.CancelFunc](),
 152	}
 153}
 154
// Run executes call against its session and streams the model's output into
// the message store.
//
// It returns ErrEmptyPrompt when the prompt is empty and there is no text
// attachment, and ErrSessionMissing when no session ID is given. If the
// session already has an active request the call is appended to the session's
// queue and Run returns (nil, nil) immediately.
//
// Otherwise Run stores the user message, registers a cancel function for the
// session, and streams a response from the large model. On a stream error the
// current assistant message is finalized with a finish reason describing the
// error, unfinished tool calls are closed, and missing tool results are
// filled in with an error result. On success, the session is summarized if
// the context window is nearly exhausted, and any queued calls for the
// session are then processed by calling Run again.
func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
	if call.Prompt == "" && !message.ContainsTextAttachment(call.Attachments) {
		return nil, ErrEmptyPrompt
	}
	if call.SessionID == "" {
		return nil, ErrSessionMissing
	}

	// Queue the message if busy
	if a.IsSessionBusy(call.SessionID) {
		existing, ok := a.messageQueue.Get(call.SessionID)
		if !ok {
			existing = []SessionAgentCall{}
		}
		existing = append(existing, call)
		a.messageQueue.Set(call.SessionID, existing)
		return nil, nil
	}

	// Copy mutable fields under lock to avoid races with SetTools/SetModels.
	agentTools := a.tools.Copy()
	largeModel := a.largeModel.Get()
	systemPrompt := a.systemPrompt.Get()
	promptPrefix := a.systemPromptPrefix.Get()
	var instructions strings.Builder

	// Collect the instructions advertised by every connected MCP server.
	for _, server := range mcp.GetStates() {
		if server.State != mcp.StateConnected {
			continue
		}
		if s := server.Client.InitializeResult().Instructions; s != "" {
			instructions.WriteString(s)
			instructions.WriteString("\n\n")
		}
	}

	if s := instructions.String(); s != "" {
		systemPrompt += "\n\n<mcp-instructions>\n" + s + "\n</mcp-instructions>"
	}

	if len(agentTools) > 0 {
		// Add Anthropic caching to the last tool.
		agentTools[len(agentTools)-1].SetProviderOptions(a.getCacheControlOptions())
	}

	agent := fantasy.NewAgent(
		largeModel.Model,
		fantasy.WithSystemPrompt(systemPrompt),
		fantasy.WithTools(agentTools...),
		fantasy.WithUserAgent("Charm Crush/"+version.Version),
	)

	// sessionLock serializes the session read-modify-write in OnStepFinish.
	sessionLock := sync.Mutex{}
	currentSession, err := a.sessions.Get(ctx, call.SessionID)
	if err != nil {
		return nil, fmt.Errorf("failed to get session: %w", err)
	}

	msgs, err := a.getSessionMessages(ctx, currentSession)
	if err != nil {
		return nil, fmt.Errorf("failed to get session messages: %w", err)
	}

	var wg sync.WaitGroup
	// Generate title if first message.
	if len(msgs) == 0 {
		titleCtx := ctx // Copy to avoid race with ctx reassignment below.
		wg.Go(func() {
			a.generateTitle(titleCtx, call.SessionID, call.Prompt)
		})
	}
	// Run does not return before title generation has finished.
	defer wg.Wait()

	// Add the user message to the session.
	_, err = a.createUserMessage(ctx, call)
	if err != nil {
		return nil, err
	}

	// Add the session to the context.
	ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)

	// Register the cancel function so the session's request can be canceled
	// from outside while it is streaming.
	genCtx, cancel := context.WithCancel(ctx)
	a.activeRequests.Set(call.SessionID, cancel)

	defer cancel()
	defer a.activeRequests.Del(call.SessionID)

	history, files := a.preparePrompt(msgs, call.Attachments...)

	startTime := time.Now()
	a.eventPromptSent(call.SessionID)

	// currentAssistant is assigned by PrepareStep at the start of every step
	// and is then mutated by the streaming callbacks below.
	var currentAssistant *message.Message
	var shouldSummarize bool
	result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
		Prompt:           message.PromptWithTextAttachments(call.Prompt, call.Attachments),
		Files:            files,
		Messages:         history,
		ProviderOptions:  call.ProviderOptions,
		MaxOutputTokens:  &call.MaxOutputTokens,
		TopP:             call.TopP,
		Temperature:      call.Temperature,
		PresencePenalty:  call.PresencePenalty,
		TopK:             call.TopK,
		FrequencyPenalty: call.FrequencyPenalty,
		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
			// Clear any provider options left on the messages; cache control
			// is re-applied further down.
			prepared.Messages = options.Messages
			for i := range prepared.Messages {
				prepared.Messages[i].ProviderOptions = nil
			}

			// Use latest tools (updated by SetTools when MCP tools change).
			prepared.Tools = a.tools.Copy()

			// Drain the session's queue: every queued call is stored as a
			// user message and appended to this step's messages.
			queuedCalls, _ := a.messageQueue.Get(call.SessionID)
			a.messageQueue.Del(call.SessionID)
			for _, queued := range queuedCalls {
				userMessage, createErr := a.createUserMessage(callContext, queued)
				if createErr != nil {
					return callContext, prepared, createErr
				}
				prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
			}

			prepared.Messages = a.workaroundProviderMediaLimitations(prepared.Messages, largeModel)

			lastSystemRoleInx := 0
			systemMessageUpdated := false
			for i, msg := range prepared.Messages {
				// Only add cache control to the last message.
				if msg.Role == fantasy.MessageRoleSystem {
					lastSystemRoleInx = i
				} else if !systemMessageUpdated {
					prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
					systemMessageUpdated = true
				}
				// Then add cache control to the last 2 messages.
				if i > len(prepared.Messages)-3 {
					prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
				}
			}

			if promptPrefix != "" {
				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(promptPrefix)}, prepared.Messages...)
			}

			// Create the empty assistant message that this step's callbacks
			// fill in.
			var assistantMsg message.Message
			assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
				Role:     message.Assistant,
				Parts:    []message.ContentPart{},
				Model:    largeModel.ModelCfg.Model,
				Provider: largeModel.ModelCfg.Provider,
			})
			if err != nil {
				return callContext, prepared, err
			}
			callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
			callContext = context.WithValue(callContext, tools.SupportsImagesContextKey, largeModel.CatwalkCfg.SupportsImages)
			callContext = context.WithValue(callContext, tools.ModelNameContextKey, largeModel.CatwalkCfg.Name)
			currentAssistant = &assistantMsg
			return callContext, prepared, err
		},
		OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
			currentAssistant.AppendReasoningContent(reasoning.Text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnReasoningDelta: func(id string, text string) error {
			currentAssistant.AppendReasoningContent(text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
			// handle anthropic signature
			if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
				if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
					currentAssistant.AppendReasoningSignature(reasoning.Signature)
				}
			}
			if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
				if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
					currentAssistant.AppendThoughtSignature(reasoning.Signature, reasoning.ToolID)
				}
			}
			if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
				if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
					currentAssistant.SetReasoningResponsesData(reasoning)
				}
			}
			currentAssistant.FinishThinking()
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnTextDelta: func(id string, text string) error {
			// Strip leading newline from initial text content. This is
			// particularly important in non-interactive mode where leading
			// newlines are very visible.
			if len(currentAssistant.Parts) == 0 {
				text = strings.TrimPrefix(text, "\n")
			}

			currentAssistant.AppendContent(text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnToolInputStart: func(id string, toolName string) error {
			// Record the tool call as unfinished as soon as its input starts
			// streaming.
			toolCall := message.ToolCall{
				ID:               id,
				Name:             toolName,
				ProviderExecuted: false,
				Finished:         false,
			}
			currentAssistant.AddToolCall(toolCall)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnRetry: func(err *fantasy.ProviderError, delay time.Duration) {
			// TODO: implement
		},
		OnToolCall: func(tc fantasy.ToolCallContent) error {
			// The complete tool call, including its input, is now known.
			toolCall := message.ToolCall{
				ID:               tc.ToolCallID,
				Name:             tc.ToolName,
				Input:            tc.Input,
				ProviderExecuted: false,
				Finished:         true,
			}
			currentAssistant.AddToolCall(toolCall)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnToolResult: func(result fantasy.ToolResultContent) error {
			// Each tool result is stored as its own tool-role message.
			toolResult := a.convertToToolResult(result)
			_, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
				Role: message.Tool,
				Parts: []message.ContentPart{
					toolResult,
				},
			})
			return createMsgErr
		},
		OnStepFinish: func(stepResult fantasy.StepResult) error {
			// Map the provider finish reason onto the message finish reason;
			// anything unrecognized stays FinishReasonUnknown.
			finishReason := message.FinishReasonUnknown
			switch stepResult.FinishReason {
			case fantasy.FinishReasonLength:
				finishReason = message.FinishReasonMaxTokens
			case fantasy.FinishReasonStop:
				finishReason = message.FinishReasonEndTurn
			case fantasy.FinishReasonToolCalls:
				finishReason = message.FinishReasonToolUse
			}
			currentAssistant.AddFinish(finishReason, "", "")
			sessionLock.Lock()
			defer sessionLock.Unlock()

			// Reload the session before applying this step's usage, then
			// save it and keep the result for the stop condition below.
			updatedSession, getSessionErr := a.sessions.Get(ctx, call.SessionID)
			if getSessionErr != nil {
				return getSessionErr
			}
			a.updateSessionUsage(largeModel, &updatedSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
			_, sessionErr := a.sessions.Save(ctx, updatedSession)
			if sessionErr != nil {
				return sessionErr
			}
			currentSession = updatedSession
			return a.messages.Update(genCtx, *currentAssistant)
		},
		StopWhen: []fantasy.StopCondition{
			// Stop and request a summary when the remaining context window
			// falls to the threshold, unless auto-summarize is disabled.
			func(_ []fantasy.StepResult) bool {
				cw := int64(largeModel.CatwalkCfg.ContextWindow)
				tokens := currentSession.CompletionTokens + currentSession.PromptTokens
				remaining := cw - tokens
				var threshold int64
				if cw > largeContextWindowThreshold {
					threshold = largeContextWindowBuffer
				} else {
					threshold = int64(float64(cw) * smallContextWindowRatio)
				}
				if (remaining <= threshold) && !a.disableAutoSummarize {
					shouldSummarize = true
					return true
				}
				return false
			},
			// Stop when the recent steps repeat the same tool calls.
			func(steps []fantasy.StepResult) bool {
				return hasRepeatedToolCalls(steps, loopDetectionWindowSize, loopDetectionMaxRepeats)
			},
		},
	})

	a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))

	if err != nil {
		isCancelErr := errors.Is(err, context.Canceled)
		isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
		// No assistant message was created, so there is nothing to finalize.
		if currentAssistant == nil {
			return result, err
		}
		// Ensure we finish thinking on error to close the reasoning state.
		currentAssistant.FinishThinking()
		toolCalls := currentAssistant.ToolCalls()
		// INFO: we use the parent context here because the genCtx has been cancelled.
		msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
		if createErr != nil {
			return nil, createErr
		}
		for _, tc := range toolCalls {
			// Close tool calls whose input never finished streaming.
			if !tc.Finished {
				tc.Finished = true
				tc.Input = "{}"
				currentAssistant.AddToolCall(tc)
				updateErr := a.messages.Update(ctx, *currentAssistant)
				if updateErr != nil {
					return nil, updateErr
				}
			}

			// Look for an existing tool result for this call; only calls
			// without one get a synthetic error result.
			found := false
			for _, msg := range msgs {
				if msg.Role == message.Tool {
					for _, tr := range msg.ToolResults() {
						if tr.ToolCallID == tc.ID {
							found = true
							break
						}
					}
				}
				if found {
					break
				}
			}
			if found {
				continue
			}
			content := "There was an error while executing the tool"
			if isCancelErr {
				content = "Tool execution canceled by user"
			} else if isPermissionErr {
				content = "User denied permission"
			}
			toolResult := message.ToolResult{
				ToolCallID: tc.ID,
				Name:       tc.Name,
				Content:    content,
				IsError:    true,
			}
			_, createErr = a.messages.Create(ctx, currentAssistant.SessionID, message.CreateMessageParams{
				Role: message.Tool,
				Parts: []message.ContentPart{
					toolResult,
				},
			})
			if createErr != nil {
				return nil, createErr
			}
		}
		// Record a finish reason on the assistant message that describes the
		// error.
		var fantasyErr *fantasy.Error
		var providerErr *fantasy.ProviderError
		const defaultTitle = "Provider Error"
		linkStyle := lipgloss.NewStyle().Foreground(charmtone.Guac).Underline(true)
		if isCancelErr {
			currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
		} else if isPermissionErr {
			currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "")
		} else if errors.Is(err, hyper.ErrNoCredits) {
			url := hyper.BaseURL()
			link := linkStyle.Hyperlink(url, "id=hyper").Render(url)
			currentAssistant.AddFinish(message.FinishReasonError, "No credits", "You're out of credits. Add more at "+link)
		} else if errors.As(err, &providerErr) {
			if providerErr.Message == "The requested model is not supported." {
				url := "https://github.com/settings/copilot/features"
				link := linkStyle.Hyperlink(url, "id=copilot").Render(url)
				currentAssistant.AddFinish(
					message.FinishReasonError,
					"Copilot model not enabled",
					fmt.Sprintf("%q is not enabled in Copilot. Go to the following page to enable it. Then, wait 5 minutes before trying again. %s", largeModel.CatwalkCfg.Name, link),
				)
			} else {
				currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message)
			}
		} else if errors.As(err, &fantasyErr) {
			currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(fantasyErr.Title), defaultTitle), fantasyErr.Message)
		} else {
			currentAssistant.AddFinish(message.FinishReasonError, defaultTitle, err.Error())
		}
		// Note: we use the parent context here because the genCtx has been
		// cancelled.
		updateErr := a.messages.Update(ctx, *currentAssistant)
		if updateErr != nil {
			return nil, updateErr
		}
		return nil, err
	}

	// Send notification that agent has finished its turn (skip for
	// nested/non-interactive sessions).
	if !call.NonInteractive && a.notify != nil {
		a.notify.Publish(pubsub.CreatedEvent, notify.Notification{
			SessionID:    call.SessionID,
			SessionTitle: currentSession.Title,
			Type:         notify.TypeAgentFinished,
		})
	}

	if shouldSummarize {
		// The active request is removed first because Summarize checks
		// IsSessionBusy before doing any work.
		a.activeRequests.Del(call.SessionID)
		if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
			return nil, summarizeErr
		}
		// If the agent wasn't done...
		if len(currentAssistant.ToolCalls()) > 0 {
			existing, ok := a.messageQueue.Get(call.SessionID)
			if !ok {
				existing = []SessionAgentCall{}
			}
			call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
			existing = append(existing, call)
			a.messageQueue.Set(call.SessionID, existing)
		}
	}

	// Release active request before processing queued messages.
	a.activeRequests.Del(call.SessionID)
	cancel()

	queuedMessages, ok := a.messageQueue.Get(call.SessionID)
	if !ok || len(queuedMessages) == 0 {
		return result, err
	}
	// There are queued messages restart the loop.
	firstQueuedMessage := queuedMessages[0]
	a.messageQueue.Set(call.SessionID, queuedMessages[1:])
	return a.Run(ctx, firstQueuedMessage)
}
 584
 585func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
 586	if a.IsSessionBusy(sessionID) {
 587		return ErrSessionBusy
 588	}
 589
 590	// Copy mutable fields under lock to avoid races with SetModels.
 591	largeModel := a.largeModel.Get()
 592	systemPromptPrefix := a.systemPromptPrefix.Get()
 593
 594	currentSession, err := a.sessions.Get(ctx, sessionID)
 595	if err != nil {
 596		return fmt.Errorf("failed to get session: %w", err)
 597	}
 598	msgs, err := a.getSessionMessages(ctx, currentSession)
 599	if err != nil {
 600		return err
 601	}
 602	if len(msgs) == 0 {
 603		// Nothing to summarize.
 604		return nil
 605	}
 606
 607	aiMsgs, _ := a.preparePrompt(msgs)
 608
 609	genCtx, cancel := context.WithCancel(ctx)
 610	a.activeRequests.Set(sessionID, cancel)
 611	defer a.activeRequests.Del(sessionID)
 612	defer cancel()
 613
 614	agent := fantasy.NewAgent(largeModel.Model,
 615		fantasy.WithSystemPrompt(string(summaryPrompt)),
 616		fantasy.WithUserAgent("Charm Crush/"+version.Version),
 617	)
 618	summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
 619		Role:             message.Assistant,
 620		Model:            largeModel.Model.Model(),
 621		Provider:         largeModel.Model.Provider(),
 622		IsSummaryMessage: true,
 623	})
 624	if err != nil {
 625		return err
 626	}
 627
 628	summaryPromptText := buildSummaryPrompt(currentSession.Todos)
 629
 630	resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
 631		Prompt:          summaryPromptText,
 632		Messages:        aiMsgs,
 633		ProviderOptions: opts,
 634		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
 635			prepared.Messages = options.Messages
 636			if systemPromptPrefix != "" {
 637				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(systemPromptPrefix)}, prepared.Messages...)
 638			}
 639			return callContext, prepared, nil
 640		},
 641		OnReasoningDelta: func(id string, text string) error {
 642			summaryMessage.AppendReasoningContent(text)
 643			return a.messages.Update(genCtx, summaryMessage)
 644		},
 645		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
 646			// Handle anthropic signature.
 647			if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
 648				if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
 649					summaryMessage.AppendReasoningSignature(signature.Signature)
 650				}
 651			}
 652			summaryMessage.FinishThinking()
 653			return a.messages.Update(genCtx, summaryMessage)
 654		},
 655		OnTextDelta: func(id, text string) error {
 656			summaryMessage.AppendContent(text)
 657			return a.messages.Update(genCtx, summaryMessage)
 658		},
 659	})
 660	if err != nil {
 661		isCancelErr := errors.Is(err, context.Canceled)
 662		if isCancelErr {
 663			// User cancelled summarize we need to remove the summary message.
 664			deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
 665			return deleteErr
 666		}
 667		return err
 668	}
 669
 670	summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
 671	err = a.messages.Update(genCtx, summaryMessage)
 672	if err != nil {
 673		return err
 674	}
 675
 676	var openrouterCost *float64
 677	for _, step := range resp.Steps {
 678		stepCost := a.openrouterCost(step.ProviderMetadata)
 679		if stepCost != nil {
 680			newCost := *stepCost
 681			if openrouterCost != nil {
 682				newCost += *openrouterCost
 683			}
 684			openrouterCost = &newCost
 685		}
 686	}
 687
 688	a.updateSessionUsage(largeModel, &currentSession, resp.TotalUsage, openrouterCost)
 689
 690	// Just in case, get just the last usage info.
 691	usage := resp.Response.Usage
 692	currentSession.SummaryMessageID = summaryMessage.ID
 693	currentSession.CompletionTokens = usage.OutputTokens
 694	currentSession.PromptTokens = 0
 695	_, err = a.sessions.Save(genCtx, currentSession)
 696	return err
 697}
 698
 699func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
 700	if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
 701		return fantasy.ProviderOptions{}
 702	}
 703	return fantasy.ProviderOptions{
 704		anthropic.Name: &anthropic.ProviderCacheControlOptions{
 705			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
 706		},
 707		bedrock.Name: &anthropic.ProviderCacheControlOptions{
 708			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
 709		},
 710		vercel.Name: &anthropic.ProviderCacheControlOptions{
 711			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
 712		},
 713	}
 714}
 715
 716func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
 717	parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
 718	var attachmentParts []message.ContentPart
 719	for _, attachment := range call.Attachments {
 720		attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
 721	}
 722	parts = append(parts, attachmentParts...)
 723	msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
 724		Role:  message.User,
 725		Parts: parts,
 726	})
 727	if err != nil {
 728		return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
 729	}
 730	return msg, nil
 731}
 732
 733func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
 734	var history []fantasy.Message
 735	if !a.isSubAgent {
 736		history = append(history, fantasy.NewUserMessage(
 737			fmt.Sprintf("<system_reminder>%s</system_reminder>",
 738				`This is a reminder that your todo list is currently empty. DO NOT mention this to the user explicitly because they are already aware.
 739If you are working on tasks that would benefit from a todo list please use the "todos" tool to create one.
 740If not, please feel free to ignore. Again do not mention this message to the user.`,
 741			),
 742		))
 743	}
 744	for _, m := range msgs {
 745		if len(m.Parts) == 0 {
 746			continue
 747		}
 748		// Assistant message without content or tool calls (cancelled before it
 749		// returned anything).
 750		if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
 751			continue
 752		}
 753		history = append(history, m.ToAIMessage()...)
 754	}
 755
 756	var files []fantasy.FilePart
 757	for _, attachment := range attachments {
 758		if attachment.IsText() {
 759			continue
 760		}
 761		files = append(files, fantasy.FilePart{
 762			Filename:  attachment.FileName,
 763			Data:      attachment.Content,
 764			MediaType: attachment.MimeType,
 765		})
 766	}
 767
 768	return history, files
 769}
 770
 771func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
 772	msgs, err := a.messages.List(ctx, session.ID)
 773	if err != nil {
 774		return nil, fmt.Errorf("failed to list messages: %w", err)
 775	}
 776
 777	if session.SummaryMessageID != "" {
 778		summaryMsgIndex := -1
 779		for i, msg := range msgs {
 780			if msg.ID == session.SummaryMessageID {
 781				summaryMsgIndex = i
 782				break
 783			}
 784		}
 785		if summaryMsgIndex != -1 {
 786			msgs = msgs[summaryMsgIndex:]
 787			msgs[0].Role = message.User
 788		}
 789	}
 790	return msgs, nil
 791}
 792
// generateTitle generates a session title based on the initial prompt.
 794func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, userPrompt string) {
 795	if userPrompt == "" {
 796		return
 797	}
 798
 799	smallModel := a.smallModel.Get()
 800	largeModel := a.largeModel.Get()
 801	systemPromptPrefix := a.systemPromptPrefix.Get()
 802
 803	var maxOutputTokens int64 = 40
 804	if smallModel.CatwalkCfg.CanReason {
 805		maxOutputTokens = smallModel.CatwalkCfg.DefaultMaxTokens
 806	}
 807
 808	newAgent := func(m fantasy.LanguageModel, p []byte, tok int64) fantasy.Agent {
 809		return fantasy.NewAgent(m,
 810			fantasy.WithSystemPrompt(string(p)+"\n /no_think"),
 811			fantasy.WithMaxOutputTokens(tok),
 812			fantasy.WithUserAgent("Charm Crush/"+version.Version),
 813		)
 814	}
 815
 816	streamCall := fantasy.AgentStreamCall{
 817		Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", userPrompt),
 818		PrepareStep: func(callCtx context.Context, opts fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
 819			prepared.Messages = opts.Messages
 820			if systemPromptPrefix != "" {
 821				prepared.Messages = append([]fantasy.Message{
 822					fantasy.NewSystemMessage(systemPromptPrefix),
 823				}, prepared.Messages...)
 824			}
 825			return callCtx, prepared, nil
 826		},
 827	}
 828
 829	// Use the small model to generate the title.
 830	model := smallModel
 831	agent := newAgent(model.Model, titlePrompt, maxOutputTokens)
 832	resp, err := agent.Stream(ctx, streamCall)
 833	if err == nil {
 834		// We successfully generated a title with the small model.
 835		slog.Debug("Generated title with small model")
 836	} else {
 837		// It didn't work. Let's try with the big model.
 838		slog.Error("Error generating title with small model; trying big model", "err", err)
 839		model = largeModel
 840		agent = newAgent(model.Model, titlePrompt, maxOutputTokens)
 841		resp, err = agent.Stream(ctx, streamCall)
 842		if err == nil {
 843			slog.Debug("Generated title with large model")
 844		} else {
 845			// Welp, the large model didn't work either. Use the default
 846			// session name and return.
 847			slog.Error("Error generating title with large model", "err", err)
 848			saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, DefaultSessionName, 0, 0, 0)
 849			if saveErr != nil {
 850				slog.Error("Failed to save session title and usage", "error", saveErr)
 851			}
 852			return
 853		}
 854	}
 855
 856	if resp == nil {
 857		// Actually, we didn't get a response so we can't. Use the default
 858		// session name and return.
 859		slog.Error("Response is nil; can't generate title")
 860		saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, DefaultSessionName, 0, 0, 0)
 861		if saveErr != nil {
 862			slog.Error("Failed to save session title and usage", "error", saveErr)
 863		}
 864		return
 865	}
 866
 867	// Clean up title.
 868	var title string
 869	title = strings.ReplaceAll(resp.Response.Content.Text(), "\n", " ")
 870
 871	// Remove thinking tags if present.
 872	title = thinkTagRegex.ReplaceAllString(title, "")
 873
 874	title = strings.TrimSpace(title)
 875	title = cmp.Or(title, DefaultSessionName)
 876
 877	// Calculate usage and cost.
 878	var openrouterCost *float64
 879	for _, step := range resp.Steps {
 880		stepCost := a.openrouterCost(step.ProviderMetadata)
 881		if stepCost != nil {
 882			newCost := *stepCost
 883			if openrouterCost != nil {
 884				newCost += *openrouterCost
 885			}
 886			openrouterCost = &newCost
 887		}
 888	}
 889
 890	modelConfig := model.CatwalkCfg
 891	cost := modelConfig.CostPer1MInCached/1e6*float64(resp.TotalUsage.CacheCreationTokens) +
 892		modelConfig.CostPer1MOutCached/1e6*float64(resp.TotalUsage.CacheReadTokens) +
 893		modelConfig.CostPer1MIn/1e6*float64(resp.TotalUsage.InputTokens) +
 894		modelConfig.CostPer1MOut/1e6*float64(resp.TotalUsage.OutputTokens)
 895
 896	// Use override cost if available (e.g., from OpenRouter).
 897	if openrouterCost != nil {
 898		cost = *openrouterCost
 899	}
 900
 901	promptTokens := resp.TotalUsage.InputTokens + resp.TotalUsage.CacheCreationTokens
 902	completionTokens := resp.TotalUsage.OutputTokens
 903
 904	// Atomically update only title and usage fields to avoid overriding other
 905	// concurrent session updates.
 906	saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, title, promptTokens, completionTokens, cost)
 907	if saveErr != nil {
 908		slog.Error("Failed to save session title and usage", "error", saveErr)
 909		return
 910	}
 911}
 912
 913func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
 914	openrouterMetadata, ok := metadata[openrouter.Name]
 915	if !ok {
 916		return nil
 917	}
 918
 919	opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
 920	if !ok {
 921		return nil
 922	}
 923	return &opts.Usage.Cost
 924}
 925
 926func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
 927	modelConfig := model.CatwalkCfg
 928	cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
 929		modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
 930		modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
 931		modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
 932
 933	a.eventTokensUsed(session.ID, model, usage, cost)
 934
 935	if overrideCost != nil {
 936		session.Cost += *overrideCost
 937	} else {
 938		session.Cost += cost
 939	}
 940
 941	session.CompletionTokens = usage.OutputTokens
 942	session.PromptTokens = usage.InputTokens + usage.CacheReadTokens
 943}
 944
 945func (a *sessionAgent) Cancel(sessionID string) {
 946	// Cancel regular requests. Don't use Take() here - we need the entry to
 947	// remain in activeRequests so IsBusy() returns true until the goroutine
 948	// fully completes (including error handling that may access the DB).
 949	// The defer in processRequest will clean up the entry.
 950	if cancel, ok := a.activeRequests.Get(sessionID); ok && cancel != nil {
 951		slog.Debug("Request cancellation initiated", "session_id", sessionID)
 952		cancel()
 953	}
 954
 955	// Also check for summarize requests.
 956	if cancel, ok := a.activeRequests.Get(sessionID + "-summarize"); ok && cancel != nil {
 957		slog.Debug("Summarize cancellation initiated", "session_id", sessionID)
 958		cancel()
 959	}
 960
 961	if a.QueuedPrompts(sessionID) > 0 {
 962		slog.Debug("Clearing queued prompts", "session_id", sessionID)
 963		a.messageQueue.Del(sessionID)
 964	}
 965}
 966
 967func (a *sessionAgent) ClearQueue(sessionID string) {
 968	if a.QueuedPrompts(sessionID) > 0 {
 969		slog.Debug("Clearing queued prompts", "session_id", sessionID)
 970		a.messageQueue.Del(sessionID)
 971	}
 972}
 973
 974func (a *sessionAgent) CancelAll() {
 975	if !a.IsBusy() {
 976		return
 977	}
 978	for key := range a.activeRequests.Seq2() {
 979		a.Cancel(key) // key is sessionID
 980	}
 981
 982	timeout := time.After(5 * time.Second)
 983	for a.IsBusy() {
 984		select {
 985		case <-timeout:
 986			return
 987		default:
 988			time.Sleep(200 * time.Millisecond)
 989		}
 990	}
 991}
 992
 993func (a *sessionAgent) IsBusy() bool {
 994	var busy bool
 995	for cancelFunc := range a.activeRequests.Seq() {
 996		if cancelFunc != nil {
 997			busy = true
 998			break
 999		}
1000	}
1001	return busy
1002}
1003
1004func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
1005	_, busy := a.activeRequests.Get(sessionID)
1006	return busy
1007}
1008
1009func (a *sessionAgent) QueuedPrompts(sessionID string) int {
1010	l, ok := a.messageQueue.Get(sessionID)
1011	if !ok {
1012		return 0
1013	}
1014	return len(l)
1015}
1016
1017func (a *sessionAgent) QueuedPromptsList(sessionID string) []string {
1018	l, ok := a.messageQueue.Get(sessionID)
1019	if !ok {
1020		return nil
1021	}
1022	prompts := make([]string, len(l))
1023	for i, call := range l {
1024		prompts[i] = call.Prompt
1025	}
1026	return prompts
1027}
1028
// SetModels replaces the agent's large and small models with the given ones.
// The large model is stored first, then the small one.
func (a *sessionAgent) SetModels(large Model, small Model) {
	a.largeModel.Set(large)
	a.smallModel.Set(small)
}
1033
// SetTools hands the given tools to the agent's tool collection via SetSlice.
func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
	a.tools.SetSlice(tools)
}
1037
// SetSystemPrompt stores systemPrompt as the agent's system prompt.
func (a *sessionAgent) SetSystemPrompt(systemPrompt string) {
	a.systemPrompt.Set(systemPrompt)
}
1041
// Model returns the agent's current large model.
func (a *sessionAgent) Model() Model {
	return a.largeModel.Get()
}
1045
1046// convertToToolResult converts a fantasy tool result to a message tool result.
1047func (a *sessionAgent) convertToToolResult(result fantasy.ToolResultContent) message.ToolResult {
1048	baseResult := message.ToolResult{
1049		ToolCallID: result.ToolCallID,
1050		Name:       result.ToolName,
1051		Metadata:   result.ClientMetadata,
1052	}
1053
1054	switch result.Result.GetType() {
1055	case fantasy.ToolResultContentTypeText:
1056		if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result); ok {
1057			baseResult.Content = r.Text
1058		}
1059	case fantasy.ToolResultContentTypeError:
1060		if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result); ok {
1061			baseResult.Content = r.Error.Error()
1062			baseResult.IsError = true
1063		}
1064	case fantasy.ToolResultContentTypeMedia:
1065		if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](result.Result); ok {
1066			content := r.Text
1067			if content == "" {
1068				content = fmt.Sprintf("Loaded %s content", r.MediaType)
1069			}
1070			baseResult.Content = content
1071			baseResult.Data = r.Data
1072			baseResult.MIMEType = r.MediaType
1073		}
1074	}
1075
1076	return baseResult
1077}
1078
1079// workaroundProviderMediaLimitations converts media content in tool results to
1080// user messages for providers that don't natively support images in tool results.
1081//
1082// Problem: OpenAI, Google, OpenRouter, and other OpenAI-compatible providers
1083// don't support sending images/media in tool result messages - they only accept
1084// text in tool results. However, they DO support images in user messages.
1085//
1086// If we send media in tool results to these providers, the API returns an error.
1087//
1088// Solution: For these providers, we:
1089//  1. Replace the media in the tool result with a text placeholder
1090//  2. Inject a user message immediately after with the image as a file attachment
1091//  3. This maintains the tool execution flow while working around API limitations
1092//
1093// Anthropic and Bedrock support images natively in tool results, so we skip
1094// this workaround for them.
1095//
1096// Example transformation:
1097//
1098//	BEFORE: [tool result: image data]
1099//	AFTER:  [tool result: "Image loaded - see attached"], [user: image attachment]
1100func (a *sessionAgent) workaroundProviderMediaLimitations(messages []fantasy.Message, largeModel Model) []fantasy.Message {
1101	providerSupportsMedia := largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderAnthropic) ||
1102		largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderBedrock)
1103
1104	if providerSupportsMedia {
1105		return messages
1106	}
1107
1108	convertedMessages := make([]fantasy.Message, 0, len(messages))
1109
1110	for _, msg := range messages {
1111		if msg.Role != fantasy.MessageRoleTool {
1112			convertedMessages = append(convertedMessages, msg)
1113			continue
1114		}
1115
1116		textParts := make([]fantasy.MessagePart, 0, len(msg.Content))
1117		var mediaFiles []fantasy.FilePart
1118
1119		for _, part := range msg.Content {
1120			toolResult, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
1121			if !ok {
1122				textParts = append(textParts, part)
1123				continue
1124			}
1125
1126			if media, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](toolResult.Output); ok {
1127				decoded, err := base64.StdEncoding.DecodeString(media.Data)
1128				if err != nil {
1129					slog.Warn("Failed to decode media data", "error", err)
1130					textParts = append(textParts, part)
1131					continue
1132				}
1133
1134				mediaFiles = append(mediaFiles, fantasy.FilePart{
1135					Data:      decoded,
1136					MediaType: media.MediaType,
1137					Filename:  fmt.Sprintf("tool-result-%s", toolResult.ToolCallID),
1138				})
1139
1140				textParts = append(textParts, fantasy.ToolResultPart{
1141					ToolCallID: toolResult.ToolCallID,
1142					Output: fantasy.ToolResultOutputContentText{
1143						Text: "[Image/media content loaded - see attached file]",
1144					},
1145					ProviderOptions: toolResult.ProviderOptions,
1146				})
1147			} else {
1148				textParts = append(textParts, part)
1149			}
1150		}
1151
1152		convertedMessages = append(convertedMessages, fantasy.Message{
1153			Role:    fantasy.MessageRoleTool,
1154			Content: textParts,
1155		})
1156
1157		if len(mediaFiles) > 0 {
1158			convertedMessages = append(convertedMessages, fantasy.NewUserMessage(
1159				"Here is the media content from the tool result:",
1160				mediaFiles...,
1161			))
1162		}
1163	}
1164
1165	return convertedMessages
1166}
1167
1168// buildSummaryPrompt constructs the prompt text for session summarization.
1169func buildSummaryPrompt(todos []session.Todo) string {
1170	var sb strings.Builder
1171	sb.WriteString("Provide a detailed summary of our conversation above.")
1172	if len(todos) > 0 {
1173		sb.WriteString("\n\n## Current Todo List\n\n")
1174		for _, t := range todos {
1175			fmt.Fprintf(&sb, "- [%s] %s\n", t.Status, t.Content)
1176		}
1177		sb.WriteString("\nInclude these tasks and their statuses in your summary. ")
1178		sb.WriteString("Instruct the resuming assistant to use the `todos` tool to continue tracking progress on these tasks.")
1179	}
1180	return sb.String()
1181}