agent.go

   1// Package agent is the core orchestration layer for Crush AI agents.
   2//
   3// It provides session-based AI agent functionality for managing
   4// conversations, tool execution, and message handling. It coordinates
   5// interactions between language models, messages, sessions, and tools while
   6// handling features like automatic summarization, queuing, and token
   7// management.
   8package agent
   9
  10import (
  11	"cmp"
  12	"context"
  13	_ "embed"
  14	"encoding/base64"
  15	"errors"
  16	"fmt"
  17	"log/slog"
  18	"os"
  19	"regexp"
  20	"strconv"
  21	"strings"
  22	"sync"
  23	"time"
  24
  25	"charm.land/catwalk/pkg/catwalk"
  26	"charm.land/fantasy"
  27	"charm.land/fantasy/providers/anthropic"
  28	"charm.land/fantasy/providers/bedrock"
  29	"charm.land/fantasy/providers/google"
  30	"charm.land/fantasy/providers/openai"
  31	"charm.land/fantasy/providers/openrouter"
  32	"charm.land/fantasy/providers/vercel"
  33	"charm.land/lipgloss/v2"
  34	"github.com/charmbracelet/crush/internal/agent/hyper"
  35	"github.com/charmbracelet/crush/internal/agent/tools"
  36	"github.com/charmbracelet/crush/internal/agent/tools/mcp"
  37	"github.com/charmbracelet/crush/internal/config"
  38	"github.com/charmbracelet/crush/internal/csync"
  39	"github.com/charmbracelet/crush/internal/message"
  40	"github.com/charmbracelet/crush/internal/permission"
  41	"github.com/charmbracelet/crush/internal/session"
  42	"github.com/charmbracelet/crush/internal/stringext"
  43	"github.com/charmbracelet/crush/internal/version"
  44	"github.com/charmbracelet/x/exp/charmtone"
  45)
  46
const (
	// DefaultSessionName is the title stored for a session when no title
	// could be generated: both title models failed, the response was nil,
	// or the cleaned-up title came back empty.
	DefaultSessionName = "Untitled Session"

	// Constants for auto-summarization thresholds. Run stops the agent loop
	// (and then summarizes the session, unless auto-summarize is disabled)
	// once the remaining context window, ContextWindow minus the session's
	// PromptTokens and CompletionTokens, drops to or below a threshold:
	//   - models whose context window is larger than
	//     largeContextWindowThreshold keep a fixed buffer of
	//     largeContextWindowBuffer tokens;
	//   - smaller models keep smallContextWindowRatio of their window.
	largeContextWindowThreshold = 200_000
	largeContextWindowBuffer    = 20_000
	smallContextWindowRatio     = 0.2
)
  55
// titlePrompt is the system prompt generateTitle gives the title model.
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt is the system prompt Summarize gives the summarizing model.
//
//go:embed templates/summary.md
var summaryPrompt []byte

// Used to remove <think> tags from generated titles.
//
// The match is non-greedy and, because the pattern has no (?s) flag, "."
// does not cross newlines; generateTitle replaces newlines with spaces before
// applying it so multi-line think blocks are still removed.
var thinkTagRegex = regexp.MustCompile(`<think>.*?</think>`)
  64
// SessionAgentCall describes a single prompt submitted to a SessionAgent.
type SessionAgentCall struct {
	// SessionID identifies the session; Run rejects an empty value with
	// ErrSessionMissing.
	SessionID string
	// Prompt is the user's text. It may be empty only when Attachments
	// contains a text attachment (otherwise Run returns ErrEmptyPrompt).
	Prompt string
	// ProviderOptions are passed to the model stream, and to Summarize when
	// Run triggers auto-summarization.
	ProviderOptions fantasy.ProviderOptions
	// Attachments are stored on the user message. Text attachments are
	// merged into the prompt; the others are sent as file parts.
	Attachments []message.Attachment
	// MaxOutputTokens is always passed to the stream by pointer, so a zero
	// value is forwarded as 0. NOTE(review): confirm how providers treat 0.
	MaxOutputTokens int64
	// Sampling parameters forwarded to the stream unchanged; nil pointers are
	// forwarded as nil.
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
}
  77
// SessionAgent runs AI agent turns for chat sessions, keyed by session ID.
//
// NOTE(review): only Run and Summarize are implemented in this part of the
// file; the descriptions of the other methods follow their names and the
// fields they presumably manage — confirm against their implementations.
type SessionAgent interface {
	// Run sends a prompt to the large model for call.SessionID and streams
	// the response into the session. If the session is busy, the call is
	// queued and Run returns (nil, nil).
	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
	// SetModels replaces the large and small models; Run, Summarize and
	// title generation read the current models when they start.
	SetModels(large Model, small Model)
	// SetTools replaces the tools offered to the model; Run copies the tool
	// list when it starts.
	SetTools(tools []fantasy.AgentTool)
	// SetSystemPrompt replaces the system prompt; Run reads it when it
	// starts.
	SetSystemPrompt(systemPrompt string)
	// Cancel stops the in-flight request for sessionID, if any.
	Cancel(sessionID string)
	// CancelAll stops all in-flight requests.
	CancelAll()
	// IsSessionBusy reports whether sessionID has a request in flight.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any session has a request in flight.
	IsBusy() bool
	// QueuedPrompts returns the number of calls queued for sessionID.
	QueuedPrompts(sessionID string) int
	// QueuedPromptsList returns the prompts queued for sessionID.
	QueuedPromptsList(sessionID string) []string
	// ClearQueue discards the calls queued for sessionID.
	ClearQueue(sessionID string)
	// Summarize condenses the session history (context, session ID,
	// provider options); it fails with ErrSessionBusy while a request for
	// the session is active.
	Summarize(context.Context, string, fantasy.ProviderOptions) error
	// Model returns the agent's current model (presumably the large one —
	// confirm).
	Model() Model
}
  93
// Model bundles a language model with the metadata the agent needs about it.
type Model struct {
	// Model is the client used to create fantasy agents.
	Model fantasy.LanguageModel
	// CatwalkCfg supplies model capabilities and pricing: context window,
	// image support, reasoning support, default max tokens, display name and
	// per-million-token costs.
	CatwalkCfg catwalk.Model
	// ModelCfg is the user's selection; its Model and Provider are recorded
	// on assistant messages created by Run.
	ModelCfg config.SelectedModel
}
  99
// sessionAgent is the SessionAgent implementation returned by
// NewSessionAgent.
type sessionAgent struct {
	// Replaceable settings, held in csync containers; Run, Summarize and
	// generateTitle take a copy of each at the start of a call.
	largeModel         *csync.Value[Model]
	smallModel         *csync.Value[Model]
	systemPromptPrefix *csync.Value[string]
	systemPrompt       *csync.Value[string]
	tools              *csync.Slice[fantasy.AgentTool]

	// isSubAgent disables the todo <system_reminder> in preparePrompt.
	isSubAgent           bool
	sessions             session.Service
	messages             message.Service
	// disableAutoSummarize prevents Run from stopping to summarize when the
	// context window is nearly full.
	disableAutoSummarize bool
	isYolo               bool

	// messageQueue holds calls that arrived while their session was busy,
	// keyed by session ID; Run drains it between steps and after a turn.
	messageQueue *csync.Map[string, []SessionAgentCall]
	// activeRequests maps a session ID to the cancel func of its in-flight
	// Run or Summarize request.
	activeRequests *csync.Map[string, context.CancelFunc]
}
 116
// SessionAgentOptions configures NewSessionAgent.
type SessionAgentOptions struct {
	// LargeModel runs agent turns and summaries, and is the fallback for
	// title generation.
	LargeModel Model
	// SmallModel is tried first for title generation.
	SmallModel Model
	// SystemPromptPrefix, when non-empty, is prepended as an extra system
	// message to every request.
	SystemPromptPrefix string
	// SystemPrompt is the main system prompt for agent turns.
	SystemPrompt string
	// IsSubAgent omits the todo reminder from the prompt history.
	IsSubAgent bool
	// DisableAutoSummarize turns off summarization near the context limit.
	DisableAutoSummarize bool
	IsYolo               bool
	Sessions             session.Service
	Messages             message.Service
	// Tools are offered to the model during agent turns.
	Tools []fantasy.AgentTool
}
 129
 130func NewSessionAgent(
 131	opts SessionAgentOptions,
 132) SessionAgent {
 133	return &sessionAgent{
 134		largeModel:           csync.NewValue(opts.LargeModel),
 135		smallModel:           csync.NewValue(opts.SmallModel),
 136		systemPromptPrefix:   csync.NewValue(opts.SystemPromptPrefix),
 137		systemPrompt:         csync.NewValue(opts.SystemPrompt),
 138		isSubAgent:           opts.IsSubAgent,
 139		sessions:             opts.Sessions,
 140		messages:             opts.Messages,
 141		disableAutoSummarize: opts.DisableAutoSummarize,
 142		tools:                csync.NewSliceFrom(opts.Tools),
 143		isYolo:               opts.IsYolo,
 144		messageQueue:         csync.NewMap[string, []SessionAgentCall](),
 145		activeRequests:       csync.NewMap[string, context.CancelFunc](),
 146	}
 147}
 148
// Run executes one agent turn for call.SessionID with the large model. It
// stores what the turn produces (user message, assistant messages, tool
// results, usage) through the session and message services.
//
// Validation:
//   - ErrEmptyPrompt is returned when there is neither a prompt nor a text
//     attachment.
//   - ErrSessionMissing is returned when call.SessionID is empty.
//   - If the session is already busy, call is appended to the session's queue
//     and Run returns (nil, nil) immediately.
//
// Otherwise Run:
//   - builds the system prompt, appending instructions from connected MCP
//     servers inside <mcp-instructions>;
//   - starts title generation in the background when the session has no
//     messages yet, and waits for it before returning;
//   - stores the user message and registers a cancel func in activeRequests
//     while streaming;
//   - streams the response, creating a new assistant message in each
//     PrepareStep. Between steps it injects queued calls as extra user
//     messages, and after each step it updates session usage;
//   - stops early when the remaining context window reaches the
//     auto-summarize threshold or repeated tool calls are detected.
//
// On a stream error:
//   - unfinished tool calls are closed;
//   - tool calls without a stored result get a synthetic error result;
//   - the assistant message records a finish reason describing the error;
//   - (nil, err) is returned.
//
// If the error occurs before any assistant message was created, the stream's
// result and error are returned unchanged.
//
// On success, Run summarizes the session if the threshold was hit. In that
// case it re-queues the original request, wrapped in an explanatory prompt,
// when the last step still had tool calls. It then processes the next queued
// call, if any, by calling itself.
func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
	if call.Prompt == "" && !message.ContainsTextAttachment(call.Attachments) {
		return nil, ErrEmptyPrompt
	}
	if call.SessionID == "" {
		return nil, ErrSessionMissing
	}

	// Queue the message if busy
	// NOTE(review): the busy check and the queue read/append/write below are
	// separate operations, so two concurrent Runs for the same session could
	// both pass the check or lose a queued call — confirm callers serialize.
	if a.IsSessionBusy(call.SessionID) {
		existing, ok := a.messageQueue.Get(call.SessionID)
		if !ok {
			existing = []SessionAgentCall{}
		}
		existing = append(existing, call)
		a.messageQueue.Set(call.SessionID, existing)
		return nil, nil
	}

	// Copy mutable fields under lock to avoid races with SetTools/SetModels.
	agentTools := a.tools.Copy()
	largeModel := a.largeModel.Get()
	systemPrompt := a.systemPrompt.Get()
	promptPrefix := a.systemPromptPrefix.Get()
	var instructions strings.Builder

	// Collect the instructions advertised by every connected MCP server.
	for _, server := range mcp.GetStates() {
		if server.State != mcp.StateConnected {
			continue
		}
		if s := server.Client.InitializeResult().Instructions; s != "" {
			instructions.WriteString(s)
			instructions.WriteString("\n\n")
		}
	}

	if s := instructions.String(); s != "" {
		systemPrompt += "\n\n<mcp-instructions>\n" + s + "\n</mcp-instructions>"
	}

	if len(agentTools) > 0 {
		// Add Anthropic caching to the last tool.
		// NOTE(review): Copy duplicates the slice, but if the tool values are
		// shared references this mutates the tool held by a.tools as well.
		agentTools[len(agentTools)-1].SetProviderOptions(a.getCacheControlOptions())
	}

	agent := fantasy.NewAgent(
		largeModel.Model,
		fantasy.WithSystemPrompt(systemPrompt),
		fantasy.WithTools(agentTools...),
		fantasy.WithUserAgent("Charm Crush/"+version.Version),
	)

	// sessionLock guards the read-modify-write of the session in OnStepFinish.
	sessionLock := sync.Mutex{}
	currentSession, err := a.sessions.Get(ctx, call.SessionID)
	if err != nil {
		return nil, fmt.Errorf("failed to get session: %w", err)
	}

	msgs, err := a.getSessionMessages(ctx, currentSession)
	if err != nil {
		return nil, fmt.Errorf("failed to get session messages: %w", err)
	}

	var wg sync.WaitGroup
	// Generate title if first message.
	if len(msgs) == 0 {
		titleCtx := ctx // Copy to avoid race with ctx reassignment below.
		wg.Go(func() {
			a.generateTitle(titleCtx, call.SessionID, call.Prompt)
		})
	}
	defer wg.Wait()

	// Add the user message to the session.
	_, err = a.createUserMessage(ctx, call)
	if err != nil {
		return nil, err
	}

	// Add the session to the context.
	ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)

	// genCtx is what Cancel-style callers abort via activeRequests; ctx stays
	// usable for cleanup after cancellation.
	genCtx, cancel := context.WithCancel(ctx)
	a.activeRequests.Set(call.SessionID, cancel)

	defer cancel()
	defer a.activeRequests.Del(call.SessionID)

	history, files := a.preparePrompt(msgs, call.Attachments...)

	startTime := time.Now()
	a.eventPromptSent(call.SessionID)

	// currentAssistant is the assistant message of the step being streamed;
	// it is replaced in every PrepareStep.
	var currentAssistant *message.Message
	var shouldSummarize bool
	result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
		Prompt:           message.PromptWithTextAttachments(call.Prompt, call.Attachments),
		Files:            files,
		Messages:         history,
		ProviderOptions:  call.ProviderOptions,
		MaxOutputTokens:  &call.MaxOutputTokens,
		TopP:             call.TopP,
		Temperature:      call.Temperature,
		PresencePenalty:  call.PresencePenalty,
		TopK:             call.TopK,
		FrequencyPenalty: call.FrequencyPenalty,
		// PrepareStep runs before a model step. It:
		//   - resets per-message provider options;
		//   - injects queued calls;
		//   - re-applies cache-control breakpoints;
		//   - creates the assistant message the step streams into.
		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
			prepared.Messages = options.Messages
			for i := range prepared.Messages {
				prepared.Messages[i].ProviderOptions = nil
			}

			// Calls queued while this turn was running are added to the
			// conversation now instead of waiting for the turn to end.
			queuedCalls, _ := a.messageQueue.Get(call.SessionID)
			a.messageQueue.Del(call.SessionID)
			for _, queued := range queuedCalls {
				userMessage, createErr := a.createUserMessage(callContext, queued)
				if createErr != nil {
					return callContext, prepared, createErr
				}
				prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
			}

			prepared.Messages = a.workaroundProviderMediaLimitations(prepared.Messages, largeModel)

			lastSystemRoleInx := 0
			systemMessageUpdated := false
			for i, msg := range prepared.Messages {
				// Only add cache control to the last message.
				// (That is, to the last system message before the first
				// non-system message; index 0 if there is none.)
				if msg.Role == fantasy.MessageRoleSystem {
					lastSystemRoleInx = i
				} else if !systemMessageUpdated {
					prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
					systemMessageUpdated = true
				}
				// Then add cache control to the last 2 messages.
				if i > len(prepared.Messages)-3 {
					prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
				}
			}

			// The prefix is prepended after cache control was applied, so it
			// carries no provider options of its own.
			if promptPrefix != "" {
				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(promptPrefix)}, prepared.Messages...)
			}

			var assistantMsg message.Message
			assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
				Role:     message.Assistant,
				Parts:    []message.ContentPart{},
				Model:    largeModel.ModelCfg.Model,
				Provider: largeModel.ModelCfg.Provider,
			})
			if err != nil {
				return callContext, prepared, err
			}
			// Expose the step's message ID and model capabilities to tools.
			callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
			callContext = context.WithValue(callContext, tools.SupportsImagesContextKey, largeModel.CatwalkCfg.SupportsImages)
			callContext = context.WithValue(callContext, tools.ModelNameContextKey, largeModel.CatwalkCfg.Name)
			currentAssistant = &assistantMsg
			return callContext, prepared, err
		},
		OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
			currentAssistant.AppendReasoningContent(reasoning.Text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnReasoningDelta: func(id string, text string) error {
			currentAssistant.AppendReasoningContent(text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
			// handle anthropic signature
			if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
				if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
					currentAssistant.AppendReasoningSignature(reasoning.Signature)
				}
			}
			// Google thought signatures are tied to a tool ID.
			if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
				if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
					currentAssistant.AppendThoughtSignature(reasoning.Signature, reasoning.ToolID)
				}
			}
			// OpenAI Responses API reasoning data is stored as-is.
			if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
				if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
					currentAssistant.SetReasoningResponsesData(reasoning)
				}
			}
			currentAssistant.FinishThinking()
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnTextDelta: func(id string, text string) error {
			// Strip leading newline from initial text content. This is
			// particularly important in non-interactive mode where leading
			// newlines are very visible.
			if len(currentAssistant.Parts) == 0 {
				text = strings.TrimPrefix(text, "\n")
			}

			currentAssistant.AppendContent(text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		// The tool call is recorded as soon as its input starts streaming and
		// is marked finished in OnToolCall.
		OnToolInputStart: func(id string, toolName string) error {
			toolCall := message.ToolCall{
				ID:               id,
				Name:             toolName,
				ProviderExecuted: false,
				Finished:         false,
			}
			currentAssistant.AddToolCall(toolCall)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnRetry: func(err *fantasy.ProviderError, delay time.Duration) {
			// TODO: implement
		},
		OnToolCall: func(tc fantasy.ToolCallContent) error {
			toolCall := message.ToolCall{
				ID:               tc.ToolCallID,
				Name:             tc.ToolName,
				Input:            tc.Input,
				ProviderExecuted: false,
				Finished:         true,
			}
			currentAssistant.AddToolCall(toolCall)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		// Each tool result is stored as its own Tool-role message.
		OnToolResult: func(result fantasy.ToolResultContent) error {
			toolResult := a.convertToToolResult(result)
			_, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
				Role: message.Tool,
				Parts: []message.ContentPart{
					toolResult,
				},
			})
			return createMsgErr
		},
		OnStepFinish: func(stepResult fantasy.StepResult) error {
			finishReason := message.FinishReasonUnknown
			switch stepResult.FinishReason {
			case fantasy.FinishReasonLength:
				finishReason = message.FinishReasonMaxTokens
			case fantasy.FinishReasonStop:
				finishReason = message.FinishReasonEndTurn
			case fantasy.FinishReasonToolCalls:
				finishReason = message.FinishReasonToolUse
			}
			currentAssistant.AddFinish(finishReason, "", "")
			sessionLock.Lock()
			defer sessionLock.Unlock()

			// Re-read the session so this step's usage is added to the latest
			// stored values; currentSession feeds the StopWhen check below.
			updatedSession, getSessionErr := a.sessions.Get(ctx, call.SessionID)
			if getSessionErr != nil {
				return getSessionErr
			}
			a.updateSessionUsage(largeModel, &updatedSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
			_, sessionErr := a.sessions.Save(ctx, updatedSession)
			if sessionErr != nil {
				return sessionErr
			}
			currentSession = updatedSession
			return a.messages.Update(genCtx, *currentAssistant)
		},
		StopWhen: []fantasy.StopCondition{
			// Stop (and summarize after the stream) when the remaining
			// context window falls to the threshold: a fixed buffer for large
			// windows, a ratio of the window otherwise.
			func(_ []fantasy.StepResult) bool {
				cw := int64(largeModel.CatwalkCfg.ContextWindow)
				tokens := currentSession.CompletionTokens + currentSession.PromptTokens
				remaining := cw - tokens
				var threshold int64
				if cw > largeContextWindowThreshold {
					threshold = largeContextWindowBuffer
				} else {
					threshold = int64(float64(cw) * smallContextWindowRatio)
				}
				if (remaining <= threshold) && !a.disableAutoSummarize {
					shouldSummarize = true
					return true
				}
				return false
			},
			// Stop when the model keeps repeating the same tool calls.
			func(steps []fantasy.StepResult) bool {
				return hasRepeatedToolCalls(steps, loopDetectionWindowSize, loopDetectionMaxRepeats)
			},
		},
	})

	a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))

	if err != nil {
		isCancelErr := errors.Is(err, context.Canceled)
		isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
		if currentAssistant == nil {
			return result, err
		}
		// Ensure we finish thinking on error to close the reasoning state.
		currentAssistant.FinishThinking()
		toolCalls := currentAssistant.ToolCalls()
		// INFO: we use the parent context here because the genCtx has been cancelled.
		msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
		if createErr != nil {
			return nil, createErr
		}
		for _, tc := range toolCalls {
			// Close tool calls whose input never finished streaming.
			if !tc.Finished {
				tc.Finished = true
				tc.Input = "{}"
				currentAssistant.AddToolCall(tc)
				updateErr := a.messages.Update(ctx, *currentAssistant)
				if updateErr != nil {
					return nil, updateErr
				}
			}

			// Skip tool calls that already have a stored result.
			found := false
			for _, msg := range msgs {
				if msg.Role == message.Tool {
					for _, tr := range msg.ToolResults() {
						if tr.ToolCallID == tc.ID {
							found = true
							break
						}
					}
				}
				if found {
					break
				}
			}
			if found {
				continue
			}
			// Store a synthetic error result so every tool call has a result.
			content := "There was an error while executing the tool"
			if isCancelErr {
				content = "Tool execution canceled by user"
			} else if isPermissionErr {
				content = "User denied permission"
			}
			toolResult := message.ToolResult{
				ToolCallID: tc.ID,
				Name:       tc.Name,
				Content:    content,
				IsError:    true,
			}
			_, createErr = a.messages.Create(ctx, currentAssistant.SessionID, message.CreateMessageParams{
				Role: message.Tool,
				Parts: []message.ContentPart{
					toolResult,
				},
			})
			if createErr != nil {
				return nil, createErr
			}
		}
		// Translate the error into a finish reason, title and user-facing
		// message on the assistant message.
		var fantasyErr *fantasy.Error
		var providerErr *fantasy.ProviderError
		const defaultTitle = "Provider Error"
		linkStyle := lipgloss.NewStyle().Foreground(charmtone.Guac).Underline(true)
		if isCancelErr {
			currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
		} else if isPermissionErr {
			currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "")
		} else if errors.Is(err, hyper.ErrNoCredits) {
			url := hyper.BaseURL()
			link := linkStyle.Hyperlink(url, "id=hyper").Render(url)
			currentAssistant.AddFinish(message.FinishReasonError, "No credits", "You're out of credits. Add more at "+link)
		} else if errors.As(err, &providerErr) {
			if providerErr.Message == "The requested model is not supported." {
				url := "https://github.com/settings/copilot/features"
				link := linkStyle.Hyperlink(url, "id=copilot").Render(url)
				currentAssistant.AddFinish(
					message.FinishReasonError,
					"Copilot model not enabled",
					fmt.Sprintf("%q is not enabled in Copilot. Go to the following page to enable it. Then, wait 5 minutes before trying again. %s", largeModel.CatwalkCfg.Name, link),
				)
			} else {
				currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message)
			}
		} else if errors.As(err, &fantasyErr) {
			currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(fantasyErr.Title), defaultTitle), fantasyErr.Message)
		} else {
			currentAssistant.AddFinish(message.FinishReasonError, defaultTitle, err.Error())
		}
		// Note: we use the parent context here because the genCtx has been
		// cancelled.
		updateErr := a.messages.Update(ctx, *currentAssistant)
		if updateErr != nil {
			return nil, updateErr
		}
		return nil, err
	}

	if shouldSummarize {
		// Summarize refuses to run while the session is busy, so release this
		// turn's active request first (presumably IsSessionBusy consults
		// activeRequests — confirm).
		a.activeRequests.Del(call.SessionID)
		if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
			return nil, summarizeErr
		}
		// If the agent wasn't done...
		// ...re-queue the original request so work resumes on the
		// summarized history.
		if len(currentAssistant.ToolCalls()) > 0 {
			existing, ok := a.messageQueue.Get(call.SessionID)
			if !ok {
				existing = []SessionAgentCall{}
			}
			call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
			existing = append(existing, call)
			a.messageQueue.Set(call.SessionID, existing)
		}
	}

	// Release active request before processing queued messages.
	a.activeRequests.Del(call.SessionID)
	cancel()

	queuedMessages, ok := a.messageQueue.Get(call.SessionID)
	if !ok || len(queuedMessages) == 0 {
		return result, err
	}
	// There are queued messages restart the loop.
	firstQueuedMessage := queuedMessages[0]
	a.messageQueue.Set(call.SessionID, queuedMessages[1:])
	return a.Run(ctx, firstQueuedMessage)
}
 565
 566func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
 567	if a.IsSessionBusy(sessionID) {
 568		return ErrSessionBusy
 569	}
 570
 571	// Copy mutable fields under lock to avoid races with SetModels.
 572	largeModel := a.largeModel.Get()
 573	systemPromptPrefix := a.systemPromptPrefix.Get()
 574
 575	currentSession, err := a.sessions.Get(ctx, sessionID)
 576	if err != nil {
 577		return fmt.Errorf("failed to get session: %w", err)
 578	}
 579	msgs, err := a.getSessionMessages(ctx, currentSession)
 580	if err != nil {
 581		return err
 582	}
 583	if len(msgs) == 0 {
 584		// Nothing to summarize.
 585		return nil
 586	}
 587
 588	aiMsgs, _ := a.preparePrompt(msgs)
 589
 590	genCtx, cancel := context.WithCancel(ctx)
 591	a.activeRequests.Set(sessionID, cancel)
 592	defer a.activeRequests.Del(sessionID)
 593	defer cancel()
 594
 595	agent := fantasy.NewAgent(largeModel.Model,
 596		fantasy.WithSystemPrompt(string(summaryPrompt)),
 597		fantasy.WithUserAgent("Charm Crush/"+version.Version),
 598	)
 599	summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
 600		Role:             message.Assistant,
 601		Model:            largeModel.Model.Model(),
 602		Provider:         largeModel.Model.Provider(),
 603		IsSummaryMessage: true,
 604	})
 605	if err != nil {
 606		return err
 607	}
 608
 609	summaryPromptText := buildSummaryPrompt(currentSession.Todos)
 610
 611	resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
 612		Prompt:          summaryPromptText,
 613		Messages:        aiMsgs,
 614		ProviderOptions: opts,
 615		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
 616			prepared.Messages = options.Messages
 617			if systemPromptPrefix != "" {
 618				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(systemPromptPrefix)}, prepared.Messages...)
 619			}
 620			return callContext, prepared, nil
 621		},
 622		OnReasoningDelta: func(id string, text string) error {
 623			summaryMessage.AppendReasoningContent(text)
 624			return a.messages.Update(genCtx, summaryMessage)
 625		},
 626		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
 627			// Handle anthropic signature.
 628			if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
 629				if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
 630					summaryMessage.AppendReasoningSignature(signature.Signature)
 631				}
 632			}
 633			summaryMessage.FinishThinking()
 634			return a.messages.Update(genCtx, summaryMessage)
 635		},
 636		OnTextDelta: func(id, text string) error {
 637			summaryMessage.AppendContent(text)
 638			return a.messages.Update(genCtx, summaryMessage)
 639		},
 640	})
 641	if err != nil {
 642		isCancelErr := errors.Is(err, context.Canceled)
 643		if isCancelErr {
 644			// User cancelled summarize we need to remove the summary message.
 645			deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
 646			return deleteErr
 647		}
 648		return err
 649	}
 650
 651	summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
 652	err = a.messages.Update(genCtx, summaryMessage)
 653	if err != nil {
 654		return err
 655	}
 656
 657	var openrouterCost *float64
 658	for _, step := range resp.Steps {
 659		stepCost := a.openrouterCost(step.ProviderMetadata)
 660		if stepCost != nil {
 661			newCost := *stepCost
 662			if openrouterCost != nil {
 663				newCost += *openrouterCost
 664			}
 665			openrouterCost = &newCost
 666		}
 667	}
 668
 669	a.updateSessionUsage(largeModel, &currentSession, resp.TotalUsage, openrouterCost)
 670
 671	// Just in case, get just the last usage info.
 672	usage := resp.Response.Usage
 673	currentSession.SummaryMessageID = summaryMessage.ID
 674	currentSession.CompletionTokens = usage.OutputTokens
 675	currentSession.PromptTokens = 0
 676	_, err = a.sessions.Save(genCtx, currentSession)
 677	return err
 678}
 679
 680func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
 681	if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
 682		return fantasy.ProviderOptions{}
 683	}
 684	return fantasy.ProviderOptions{
 685		anthropic.Name: &anthropic.ProviderCacheControlOptions{
 686			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
 687		},
 688		bedrock.Name: &anthropic.ProviderCacheControlOptions{
 689			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
 690		},
 691		vercel.Name: &anthropic.ProviderCacheControlOptions{
 692			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
 693		},
 694	}
 695}
 696
 697func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
 698	parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
 699	var attachmentParts []message.ContentPart
 700	for _, attachment := range call.Attachments {
 701		attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
 702	}
 703	parts = append(parts, attachmentParts...)
 704	msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
 705		Role:  message.User,
 706		Parts: parts,
 707	})
 708	if err != nil {
 709		return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
 710	}
 711	return msg, nil
 712}
 713
 714func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
 715	var history []fantasy.Message
 716	if !a.isSubAgent {
 717		history = append(history, fantasy.NewUserMessage(
 718			fmt.Sprintf("<system_reminder>%s</system_reminder>",
 719				`This is a reminder that your todo list is currently empty. DO NOT mention this to the user explicitly because they are already aware.
 720If you are working on tasks that would benefit from a todo list please use the "todos" tool to create one.
 721If not, please feel free to ignore. Again do not mention this message to the user.`,
 722			),
 723		))
 724	}
 725	for _, m := range msgs {
 726		if len(m.Parts) == 0 {
 727			continue
 728		}
 729		// Assistant message without content or tool calls (cancelled before it
 730		// returned anything).
 731		if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
 732			continue
 733		}
 734		history = append(history, m.ToAIMessage()...)
 735	}
 736
 737	var files []fantasy.FilePart
 738	for _, attachment := range attachments {
 739		if attachment.IsText() {
 740			continue
 741		}
 742		files = append(files, fantasy.FilePart{
 743			Filename:  attachment.FileName,
 744			Data:      attachment.Content,
 745			MediaType: attachment.MimeType,
 746		})
 747	}
 748
 749	return history, files
 750}
 751
 752func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
 753	msgs, err := a.messages.List(ctx, session.ID)
 754	if err != nil {
 755		return nil, fmt.Errorf("failed to list messages: %w", err)
 756	}
 757
 758	if session.SummaryMessageID != "" {
 759		summaryMsgIndex := -1
 760		for i, msg := range msgs {
 761			if msg.ID == session.SummaryMessageID {
 762				summaryMsgIndex = i
 763				break
 764			}
 765		}
 766		if summaryMsgIndex != -1 {
 767			msgs = msgs[summaryMsgIndex:]
 768			msgs[0].Role = message.User
 769		}
 770	}
 771	return msgs, nil
 772}
 773
 774// generateTitle generates a session titled based on the initial prompt.
 775func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, userPrompt string) {
 776	if userPrompt == "" {
 777		return
 778	}
 779
 780	smallModel := a.smallModel.Get()
 781	largeModel := a.largeModel.Get()
 782	systemPromptPrefix := a.systemPromptPrefix.Get()
 783
 784	var maxOutputTokens int64 = 40
 785	if smallModel.CatwalkCfg.CanReason {
 786		maxOutputTokens = smallModel.CatwalkCfg.DefaultMaxTokens
 787	}
 788
 789	newAgent := func(m fantasy.LanguageModel, p []byte, tok int64) fantasy.Agent {
 790		return fantasy.NewAgent(m,
 791			fantasy.WithSystemPrompt(string(p)+"\n /no_think"),
 792			fantasy.WithMaxOutputTokens(tok),
 793			fantasy.WithUserAgent("Charm Crush/"+version.Version),
 794		)
 795	}
 796
 797	streamCall := fantasy.AgentStreamCall{
 798		Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", userPrompt),
 799		PrepareStep: func(callCtx context.Context, opts fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
 800			prepared.Messages = opts.Messages
 801			if systemPromptPrefix != "" {
 802				prepared.Messages = append([]fantasy.Message{
 803					fantasy.NewSystemMessage(systemPromptPrefix),
 804				}, prepared.Messages...)
 805			}
 806			return callCtx, prepared, nil
 807		},
 808	}
 809
 810	// Use the small model to generate the title.
 811	model := smallModel
 812	agent := newAgent(model.Model, titlePrompt, maxOutputTokens)
 813	resp, err := agent.Stream(ctx, streamCall)
 814	if err == nil {
 815		// We successfully generated a title with the small model.
 816		slog.Debug("Generated title with small model")
 817	} else {
 818		// It didn't work. Let's try with the big model.
 819		slog.Error("Error generating title with small model; trying big model", "err", err)
 820		model = largeModel
 821		agent = newAgent(model.Model, titlePrompt, maxOutputTokens)
 822		resp, err = agent.Stream(ctx, streamCall)
 823		if err == nil {
 824			slog.Debug("Generated title with large model")
 825		} else {
 826			// Welp, the large model didn't work either. Use the default
 827			// session name and return.
 828			slog.Error("Error generating title with large model", "err", err)
 829			saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, DefaultSessionName, 0, 0, 0)
 830			if saveErr != nil {
 831				slog.Error("Failed to save session title and usage", "error", saveErr)
 832			}
 833			return
 834		}
 835	}
 836
 837	if resp == nil {
 838		// Actually, we didn't get a response so we can't. Use the default
 839		// session name and return.
 840		slog.Error("Response is nil; can't generate title")
 841		saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, DefaultSessionName, 0, 0, 0)
 842		if saveErr != nil {
 843			slog.Error("Failed to save session title and usage", "error", saveErr)
 844		}
 845		return
 846	}
 847
 848	// Clean up title.
 849	var title string
 850	title = strings.ReplaceAll(resp.Response.Content.Text(), "\n", " ")
 851
 852	// Remove thinking tags if present.
 853	title = thinkTagRegex.ReplaceAllString(title, "")
 854
 855	title = strings.TrimSpace(title)
 856	title = cmp.Or(title, DefaultSessionName)
 857
 858	// Calculate usage and cost.
 859	var openrouterCost *float64
 860	for _, step := range resp.Steps {
 861		stepCost := a.openrouterCost(step.ProviderMetadata)
 862		if stepCost != nil {
 863			newCost := *stepCost
 864			if openrouterCost != nil {
 865				newCost += *openrouterCost
 866			}
 867			openrouterCost = &newCost
 868		}
 869	}
 870
 871	modelConfig := model.CatwalkCfg
 872	cost := modelConfig.CostPer1MInCached/1e6*float64(resp.TotalUsage.CacheCreationTokens) +
 873		modelConfig.CostPer1MOutCached/1e6*float64(resp.TotalUsage.CacheReadTokens) +
 874		modelConfig.CostPer1MIn/1e6*float64(resp.TotalUsage.InputTokens) +
 875		modelConfig.CostPer1MOut/1e6*float64(resp.TotalUsage.OutputTokens)
 876
 877	// Use override cost if available (e.g., from OpenRouter).
 878	if openrouterCost != nil {
 879		cost = *openrouterCost
 880	}
 881
 882	promptTokens := resp.TotalUsage.InputTokens + resp.TotalUsage.CacheCreationTokens
 883	completionTokens := resp.TotalUsage.OutputTokens
 884
 885	// Atomically update only title and usage fields to avoid overriding other
 886	// concurrent session updates.
 887	saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, title, promptTokens, completionTokens, cost)
 888	if saveErr != nil {
 889		slog.Error("Failed to save session title and usage", "error", saveErr)
 890		return
 891	}
 892}
 893
 894func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
 895	openrouterMetadata, ok := metadata[openrouter.Name]
 896	if !ok {
 897		return nil
 898	}
 899
 900	opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
 901	if !ok {
 902		return nil
 903	}
 904	return &opts.Usage.Cost
 905}
 906
 907func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
 908	modelConfig := model.CatwalkCfg
 909	cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
 910		modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
 911		modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
 912		modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
 913
 914	a.eventTokensUsed(session.ID, model, usage, cost)
 915
 916	if overrideCost != nil {
 917		session.Cost += *overrideCost
 918	} else {
 919		session.Cost += cost
 920	}
 921
 922	session.CompletionTokens = usage.OutputTokens
 923	session.PromptTokens = usage.InputTokens + usage.CacheReadTokens
 924}
 925
 926func (a *sessionAgent) Cancel(sessionID string) {
 927	// Cancel regular requests. Don't use Take() here - we need the entry to
 928	// remain in activeRequests so IsBusy() returns true until the goroutine
 929	// fully completes (including error handling that may access the DB).
 930	// The defer in processRequest will clean up the entry.
 931	if cancel, ok := a.activeRequests.Get(sessionID); ok && cancel != nil {
 932		slog.Debug("Request cancellation initiated", "session_id", sessionID)
 933		cancel()
 934	}
 935
 936	// Also check for summarize requests.
 937	if cancel, ok := a.activeRequests.Get(sessionID + "-summarize"); ok && cancel != nil {
 938		slog.Debug("Summarize cancellation initiated", "session_id", sessionID)
 939		cancel()
 940	}
 941
 942	if a.QueuedPrompts(sessionID) > 0 {
 943		slog.Debug("Clearing queued prompts", "session_id", sessionID)
 944		a.messageQueue.Del(sessionID)
 945	}
 946}
 947
 948func (a *sessionAgent) ClearQueue(sessionID string) {
 949	if a.QueuedPrompts(sessionID) > 0 {
 950		slog.Debug("Clearing queued prompts", "session_id", sessionID)
 951		a.messageQueue.Del(sessionID)
 952	}
 953}
 954
 955func (a *sessionAgent) CancelAll() {
 956	if !a.IsBusy() {
 957		return
 958	}
 959	for key := range a.activeRequests.Seq2() {
 960		a.Cancel(key) // key is sessionID
 961	}
 962
 963	timeout := time.After(5 * time.Second)
 964	for a.IsBusy() {
 965		select {
 966		case <-timeout:
 967			return
 968		default:
 969			time.Sleep(200 * time.Millisecond)
 970		}
 971	}
 972}
 973
 974func (a *sessionAgent) IsBusy() bool {
 975	var busy bool
 976	for cancelFunc := range a.activeRequests.Seq() {
 977		if cancelFunc != nil {
 978			busy = true
 979			break
 980		}
 981	}
 982	return busy
 983}
 984
 985func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
 986	_, busy := a.activeRequests.Get(sessionID)
 987	return busy
 988}
 989
 990func (a *sessionAgent) QueuedPrompts(sessionID string) int {
 991	l, ok := a.messageQueue.Get(sessionID)
 992	if !ok {
 993		return 0
 994	}
 995	return len(l)
 996}
 997
 998func (a *sessionAgent) QueuedPromptsList(sessionID string) []string {
 999	l, ok := a.messageQueue.Get(sessionID)
1000	if !ok {
1001		return nil
1002	}
1003	prompts := make([]string, len(l))
1004	for i, call := range l {
1005		prompts[i] = call.Prompt
1006	}
1007	return prompts
1008}
1009
1010func (a *sessionAgent) SetModels(large Model, small Model) {
1011	a.largeModel.Set(large)
1012	a.smallModel.Set(small)
1013}
1014
1015func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
1016	a.tools.SetSlice(tools)
1017}
1018
1019func (a *sessionAgent) SetSystemPrompt(systemPrompt string) {
1020	a.systemPrompt.Set(systemPrompt)
1021}
1022
1023func (a *sessionAgent) Model() Model {
1024	return a.largeModel.Get()
1025}
1026
1027// convertToToolResult converts a fantasy tool result to a message tool result.
1028func (a *sessionAgent) convertToToolResult(result fantasy.ToolResultContent) message.ToolResult {
1029	baseResult := message.ToolResult{
1030		ToolCallID: result.ToolCallID,
1031		Name:       result.ToolName,
1032		Metadata:   result.ClientMetadata,
1033	}
1034
1035	switch result.Result.GetType() {
1036	case fantasy.ToolResultContentTypeText:
1037		if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result); ok {
1038			baseResult.Content = r.Text
1039		}
1040	case fantasy.ToolResultContentTypeError:
1041		if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result); ok {
1042			baseResult.Content = r.Error.Error()
1043			baseResult.IsError = true
1044		}
1045	case fantasy.ToolResultContentTypeMedia:
1046		if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](result.Result); ok {
1047			content := r.Text
1048			if content == "" {
1049				content = fmt.Sprintf("Loaded %s content", r.MediaType)
1050			}
1051			baseResult.Content = content
1052			baseResult.Data = r.Data
1053			baseResult.MIMEType = r.MediaType
1054		}
1055	}
1056
1057	return baseResult
1058}
1059
1060// workaroundProviderMediaLimitations converts media content in tool results to
1061// user messages for providers that don't natively support images in tool results.
1062//
1063// Problem: OpenAI, Google, OpenRouter, and other OpenAI-compatible providers
1064// don't support sending images/media in tool result messages - they only accept
1065// text in tool results. However, they DO support images in user messages.
1066//
1067// If we send media in tool results to these providers, the API returns an error.
1068//
1069// Solution: For these providers, we:
1070//  1. Replace the media in the tool result with a text placeholder
1071//  2. Inject a user message immediately after with the image as a file attachment
1072//  3. This maintains the tool execution flow while working around API limitations
1073//
1074// Anthropic and Bedrock support images natively in tool results, so we skip
1075// this workaround for them.
1076//
1077// Example transformation:
1078//
1079//	BEFORE: [tool result: image data]
1080//	AFTER:  [tool result: "Image loaded - see attached"], [user: image attachment]
1081func (a *sessionAgent) workaroundProviderMediaLimitations(messages []fantasy.Message, largeModel Model) []fantasy.Message {
1082	providerSupportsMedia := largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderAnthropic) ||
1083		largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderBedrock)
1084
1085	if providerSupportsMedia {
1086		return messages
1087	}
1088
1089	convertedMessages := make([]fantasy.Message, 0, len(messages))
1090
1091	for _, msg := range messages {
1092		if msg.Role != fantasy.MessageRoleTool {
1093			convertedMessages = append(convertedMessages, msg)
1094			continue
1095		}
1096
1097		textParts := make([]fantasy.MessagePart, 0, len(msg.Content))
1098		var mediaFiles []fantasy.FilePart
1099
1100		for _, part := range msg.Content {
1101			toolResult, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
1102			if !ok {
1103				textParts = append(textParts, part)
1104				continue
1105			}
1106
1107			if media, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](toolResult.Output); ok {
1108				decoded, err := base64.StdEncoding.DecodeString(media.Data)
1109				if err != nil {
1110					slog.Warn("Failed to decode media data", "error", err)
1111					textParts = append(textParts, part)
1112					continue
1113				}
1114
1115				mediaFiles = append(mediaFiles, fantasy.FilePart{
1116					Data:      decoded,
1117					MediaType: media.MediaType,
1118					Filename:  fmt.Sprintf("tool-result-%s", toolResult.ToolCallID),
1119				})
1120
1121				textParts = append(textParts, fantasy.ToolResultPart{
1122					ToolCallID: toolResult.ToolCallID,
1123					Output: fantasy.ToolResultOutputContentText{
1124						Text: "[Image/media content loaded - see attached file]",
1125					},
1126					ProviderOptions: toolResult.ProviderOptions,
1127				})
1128			} else {
1129				textParts = append(textParts, part)
1130			}
1131		}
1132
1133		convertedMessages = append(convertedMessages, fantasy.Message{
1134			Role:    fantasy.MessageRoleTool,
1135			Content: textParts,
1136		})
1137
1138		if len(mediaFiles) > 0 {
1139			convertedMessages = append(convertedMessages, fantasy.NewUserMessage(
1140				"Here is the media content from the tool result:",
1141				mediaFiles...,
1142			))
1143		}
1144	}
1145
1146	return convertedMessages
1147}
1148
1149// buildSummaryPrompt constructs the prompt text for session summarization.
1150func buildSummaryPrompt(todos []session.Todo) string {
1151	var sb strings.Builder
1152	sb.WriteString("Provide a detailed summary of our conversation above.")
1153	if len(todos) > 0 {
1154		sb.WriteString("\n\n## Current Todo List\n\n")
1155		for _, t := range todos {
1156			fmt.Fprintf(&sb, "- [%s] %s\n", t.Status, t.Content)
1157		}
1158		sb.WriteString("\nInclude these tasks and their statuses in your summary. ")
1159		sb.WriteString("Instruct the resuming assistant to use the `todos` tool to continue tracking progress on these tasks.")
1160	}
1161	return sb.String()
1162}