agent.go

   1// Package agent is the core orchestration layer for Crush AI agents.
   2//
   3// It provides session-based AI agent functionality for managing
   4// conversations, tool execution, and message handling. It coordinates
   5// interactions between language models, messages, sessions, and tools while
   6// handling features like automatic summarization, queuing, and token
   7// management.
   8package agent
   9
  10import (
  11	"cmp"
  12	"context"
  13	_ "embed"
  14	"encoding/base64"
  15	"errors"
  16	"fmt"
  17	"log/slog"
  18	"os"
  19	"regexp"
  20	"strconv"
  21	"strings"
  22	"sync"
  23	"time"
  24
  25	"charm.land/catwalk/pkg/catwalk"
  26	"charm.land/fantasy"
  27	"charm.land/fantasy/providers/anthropic"
  28	"charm.land/fantasy/providers/bedrock"
  29	"charm.land/fantasy/providers/google"
  30	"charm.land/fantasy/providers/openai"
  31	"charm.land/fantasy/providers/openrouter"
  32	"charm.land/fantasy/providers/vercel"
  33	"charm.land/lipgloss/v2"
  34	"github.com/charmbracelet/crush/internal/agent/hyper"
  35	"github.com/charmbracelet/crush/internal/agent/notify"
  36	"github.com/charmbracelet/crush/internal/agent/tools"
  37	"github.com/charmbracelet/crush/internal/agent/tools/mcp"
  38	"github.com/charmbracelet/crush/internal/config"
  39	"github.com/charmbracelet/crush/internal/csync"
  40	"github.com/charmbracelet/crush/internal/message"
  41	"github.com/charmbracelet/crush/internal/permission"
  42	"github.com/charmbracelet/crush/internal/pubsub"
  43	"github.com/charmbracelet/crush/internal/session"
  44	"github.com/charmbracelet/crush/internal/stringext"
  45	"github.com/charmbracelet/crush/internal/version"
  46	"github.com/charmbracelet/x/exp/charmtone"
  47)
  48
const (
	// DefaultSessionName is the title stored for a session when title
	// generation fails or produces an empty title.
	DefaultSessionName = "Untitled Session"

	// Constants for auto-summarization thresholds.
	//
	// When the model's context window is larger than
	// largeContextWindowThreshold tokens, summarization is triggered once
	// largeContextWindowBuffer or fewer tokens remain. For smaller windows it
	// is triggered once the remaining tokens drop to smallContextWindowRatio
	// of the window.
	largeContextWindowThreshold = 200_000
	largeContextWindowBuffer    = 20_000
	smallContextWindowRatio     = 0.2
)
  57
// userAgent is the User-Agent value handed to the fantasy agents created by
// the session agent; it embeds the running Crush version.
var userAgent = fmt.Sprintf("Charm-Crush/%s (https://charm.land/crush)", version.Version)

// titlePrompt is the embedded system prompt used when generating session
// titles.
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt is the embedded system prompt used when summarizing a
// session.
//
//go:embed templates/summary.md
var summaryPrompt []byte

// Used to remove <think> tags from generated titles. The pattern does not
// span newlines; generateTitle replaces newlines with spaces before applying
// it.
var thinkTagRegex = regexp.MustCompile(`<think>.*?</think>`)
  68
// SessionAgentCall describes a single prompt submitted to a session agent.
type SessionAgentCall struct {
	// SessionID identifies the session the prompt belongs to. Run rejects an
	// empty value with ErrSessionMissing.
	SessionID string
	// Prompt is the user's text. It may be empty only when Attachments
	// contains a text attachment.
	Prompt string
	// ProviderOptions is forwarded unchanged to the streaming call.
	ProviderOptions fantasy.ProviderOptions
	// Attachments are stored with the user message; non-text attachments are
	// also sent to the model as files.
	Attachments []message.Attachment
	// MaxOutputTokens and the sampling parameters below are forwarded
	// unchanged to the streaming call.
	MaxOutputTokens  int64
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
	// NonInteractive suppresses the "agent finished" notification that Run
	// otherwise publishes after a successful turn.
	NonInteractive bool
}
  82
// SessionAgent is the interface through which callers drive a session-bound
// agent: submitting prompts, replacing models, tools and the system prompt,
// cancelling work, and inspecting or clearing queued prompts.
type SessionAgent interface {
	// Run submits a prompt to a session. If the session is already busy the
	// call is queued and Run returns (nil, nil).
	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
	SetModels(large Model, small Model)
	SetTools(tools []fantasy.AgentTool)
	SetSystemPrompt(systemPrompt string)
	Cancel(sessionID string)
	CancelAll()
	IsSessionBusy(sessionID string) bool
	IsBusy() bool
	QueuedPrompts(sessionID string) int
	QueuedPromptsList(sessionID string) []string
	ClearQueue(sessionID string)
	// Summarize condenses the history of the session with the given ID into
	// a summary message. It returns ErrSessionBusy if the session is busy.
	Summarize(context.Context, string, fantasy.ProviderOptions) error
	Model() Model
}
  98
// Model bundles a language model with the metadata and configuration the
// agent consults when using it.
type Model struct {
	// Model is the handle passed to fantasy.NewAgent.
	Model fantasy.LanguageModel
	// CatwalkCfg supplies the model's name, context window, pricing and
	// capability flags (e.g. SupportsImages, CanReason).
	CatwalkCfg catwalk.Model
	// ModelCfg supplies the model and provider identifiers recorded on
	// assistant messages.
	ModelCfg config.SelectedModel
}
 104
// sessionAgent is the SessionAgent implementation returned by
// NewSessionAgent.
type sessionAgent struct {
	// Settings that may be replaced while the agent is in use; they are read
	// through the csync wrappers at the start of each operation.
	largeModel         *csync.Value[Model]
	smallModel         *csync.Value[Model]
	systemPromptPrefix *csync.Value[string]
	systemPrompt       *csync.Value[string]
	tools              *csync.Slice[fantasy.AgentTool]

	// isSubAgent disables the todo-list reminder that is otherwise prepended
	// to the prompt history.
	isSubAgent bool
	sessions   session.Service
	messages   message.Service
	// disableAutoSummarize turns off the context-window stop condition that
	// triggers summarization in Run.
	disableAutoSummarize bool
	isYolo               bool
	// notify, when non-nil, receives an event after each successful
	// interactive turn.
	notify pubsub.Publisher[notify.Notification]

	// messageQueue holds calls received while their session was busy, keyed
	// by session ID.
	messageQueue *csync.Map[string, []SessionAgentCall]
	// activeRequests holds the cancel function of the in-flight request for
	// each session ID.
	activeRequests *csync.Map[string, context.CancelFunc]
}
 122
// SessionAgentOptions carries the models, prompts, services and flags that
// NewSessionAgent copies into a new agent.
type SessionAgentOptions struct {
	LargeModel         Model
	SmallModel         Model
	SystemPromptPrefix string
	SystemPrompt       string
	// IsSubAgent disables the todo-list reminder prepended to the history.
	IsSubAgent bool
	// DisableAutoSummarize turns off automatic summarization when the
	// context window is nearly full.
	DisableAutoSummarize bool
	IsYolo               bool
	Sessions             session.Service
	Messages             message.Service
	Tools                []fantasy.AgentTool
	// Notify may be nil, in which case no "agent finished" events are
	// published.
	Notify pubsub.Publisher[notify.Notification]
}
 136
 137func NewSessionAgent(
 138	opts SessionAgentOptions,
 139) SessionAgent {
 140	return &sessionAgent{
 141		largeModel:           csync.NewValue(opts.LargeModel),
 142		smallModel:           csync.NewValue(opts.SmallModel),
 143		systemPromptPrefix:   csync.NewValue(opts.SystemPromptPrefix),
 144		systemPrompt:         csync.NewValue(opts.SystemPrompt),
 145		isSubAgent:           opts.IsSubAgent,
 146		sessions:             opts.Sessions,
 147		messages:             opts.Messages,
 148		disableAutoSummarize: opts.DisableAutoSummarize,
 149		tools:                csync.NewSliceFrom(opts.Tools),
 150		isYolo:               opts.IsYolo,
 151		notify:               opts.Notify,
 152		messageQueue:         csync.NewMap[string, []SessionAgentCall](),
 153		activeRequests:       csync.NewMap[string, context.CancelFunc](),
 154	}
 155}
 156
// Run submits call to the agent for call.SessionID and streams the model's
// response into the message store.
//
// It returns ErrEmptyPrompt when the call has neither prompt text nor a text
// attachment, and ErrSessionMissing when no session ID is given. If the
// session is already busy, the call is appended to the session's queue and
// Run returns (nil, nil).
//
// Otherwise Run stores the user message, streams the large model's reply
// (creating one assistant message per step), and saves usage to the session
// after each step. When the remaining context window reaches the
// summarization threshold and auto-summarization is enabled, the stream is
// stopped and the session is summarized. If calls are still queued when the
// turn ends, Run calls itself with the first of them.
//
// On a streaming error, unfinished tool calls are closed, missing tool
// results are filled in with an error result, the assistant message is
// marked with a finish reason describing the failure, and the original error
// is returned.
func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
	if call.Prompt == "" && !message.ContainsTextAttachment(call.Attachments) {
		return nil, ErrEmptyPrompt
	}
	if call.SessionID == "" {
		return nil, ErrSessionMissing
	}

	// Queue the message if busy
	if a.IsSessionBusy(call.SessionID) {
		existing, ok := a.messageQueue.Get(call.SessionID)
		if !ok {
			existing = []SessionAgentCall{}
		}
		existing = append(existing, call)
		a.messageQueue.Set(call.SessionID, existing)
		return nil, nil
	}

	// Copy mutable fields under lock to avoid races with SetTools/SetModels.
	agentTools := a.tools.Copy()
	largeModel := a.largeModel.Get()
	systemPrompt := a.systemPrompt.Get()
	promptPrefix := a.systemPromptPrefix.Get()
	var instructions strings.Builder

	// Collect the instructions advertised by every connected MCP server.
	for _, server := range mcp.GetStates() {
		if server.State != mcp.StateConnected {
			continue
		}
		if s := server.Client.InitializeResult().Instructions; s != "" {
			instructions.WriteString(s)
			instructions.WriteString("\n\n")
		}
	}

	// Append the collected MCP instructions to the system prompt.
	if s := instructions.String(); s != "" {
		systemPrompt += "\n\n<mcp-instructions>\n" + s + "\n</mcp-instructions>"
	}

	if len(agentTools) > 0 {
		// Add Anthropic caching to the last tool.
		agentTools[len(agentTools)-1].SetProviderOptions(a.getCacheControlOptions())
	}

	agent := fantasy.NewAgent(
		largeModel.Model,
		fantasy.WithSystemPrompt(systemPrompt),
		fantasy.WithTools(agentTools...),
		fantasy.WithUserAgent(userAgent),
	)

	// sessionLock serializes the session read-modify-write in OnStepFinish.
	sessionLock := sync.Mutex{}
	currentSession, err := a.sessions.Get(ctx, call.SessionID)
	if err != nil {
		return nil, fmt.Errorf("failed to get session: %w", err)
	}

	msgs, err := a.getSessionMessages(ctx, currentSession)
	if err != nil {
		return nil, fmt.Errorf("failed to get session messages: %w", err)
	}

	var wg sync.WaitGroup
	// Generate title if first message.
	if len(msgs) == 0 {
		titleCtx := ctx // Copy to avoid race with ctx reassignment below.
		wg.Go(func() {
			a.generateTitle(titleCtx, call.SessionID, call.Prompt)
		})
	}
	// Run does not return until title generation has finished.
	defer wg.Wait()

	// Add the user message to the session.
	_, err = a.createUserMessage(ctx, call)
	if err != nil {
		return nil, err
	}

	// Add the session to the context.
	ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)

	// Register the cancel function under the session ID for the duration of
	// the request.
	genCtx, cancel := context.WithCancel(ctx)
	a.activeRequests.Set(call.SessionID, cancel)

	defer cancel()
	defer a.activeRequests.Del(call.SessionID)

	history, files := a.preparePrompt(msgs, call.Attachments...)

	startTime := time.Now()
	a.eventPromptSent(call.SessionID)

	// currentAssistant is the assistant message of the step in progress; it
	// is replaced by PrepareStep at the start of every step and mutated by
	// the streaming callbacks below.
	var currentAssistant *message.Message
	var shouldSummarize bool
	result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
		Prompt:           message.PromptWithTextAttachments(call.Prompt, call.Attachments),
		Files:            files,
		Messages:         history,
		ProviderOptions:  call.ProviderOptions,
		MaxOutputTokens:  &call.MaxOutputTokens,
		TopP:             call.TopP,
		Temperature:      call.Temperature,
		PresencePenalty:  call.PresencePenalty,
		TopK:             call.TopK,
		FrequencyPenalty: call.FrequencyPenalty,
		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
			// Start from the step's messages with any provider options
			// cleared; cache control is re-applied below.
			prepared.Messages = options.Messages
			for i := range prepared.Messages {
				prepared.Messages[i].ProviderOptions = nil
			}

			// Drain prompts queued while the session was busy and append
			// them to this step as user messages.
			queuedCalls, _ := a.messageQueue.Get(call.SessionID)
			a.messageQueue.Del(call.SessionID)
			for _, queued := range queuedCalls {
				userMessage, createErr := a.createUserMessage(callContext, queued)
				if createErr != nil {
					return callContext, prepared, createErr
				}
				prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
			}

			prepared.Messages = a.workaroundProviderMediaLimitations(prepared.Messages, largeModel)

			lastSystemRoleInx := 0
			systemMessageUpdated := false
			for i, msg := range prepared.Messages {
				// Add cache control, once, to the most recent system message
				// seen before the first non-system message (index 0 if no
				// system message was seen).
				if msg.Role == fantasy.MessageRoleSystem {
					lastSystemRoleInx = i
				} else if !systemMessageUpdated {
					prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
					systemMessageUpdated = true
				}
				// Then add cache control to the last 2 messages.
				if i > len(prepared.Messages)-3 {
					prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
				}
			}

			// The prompt prefix is prepended after cache control is applied,
			// so it carries no cache-control options itself.
			if promptPrefix != "" {
				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(promptPrefix)}, prepared.Messages...)
			}

			// Create the empty assistant message that this step's callbacks
			// will fill in.
			var assistantMsg message.Message
			assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
				Role:     message.Assistant,
				Parts:    []message.ContentPart{},
				Model:    largeModel.ModelCfg.Model,
				Provider: largeModel.ModelCfg.Provider,
			})
			if err != nil {
				return callContext, prepared, err
			}
			// Expose the message ID and model details to tools through the
			// step context.
			callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
			callContext = context.WithValue(callContext, tools.SupportsImagesContextKey, largeModel.CatwalkCfg.SupportsImages)
			callContext = context.WithValue(callContext, tools.ModelNameContextKey, largeModel.CatwalkCfg.Name)
			currentAssistant = &assistantMsg
			return callContext, prepared, err
		},
		OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
			currentAssistant.AppendReasoningContent(reasoning.Text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnReasoningDelta: func(id string, text string) error {
			currentAssistant.AppendReasoningContent(text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
			// handle anthropic signature
			if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
				if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
					currentAssistant.AppendReasoningSignature(reasoning.Signature)
				}
			}
			// handle google thought signature
			if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
				if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
					currentAssistant.AppendThoughtSignature(reasoning.Signature, reasoning.ToolID)
				}
			}
			// handle openai responses reasoning data
			if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
				if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
					currentAssistant.SetReasoningResponsesData(reasoning)
				}
			}
			currentAssistant.FinishThinking()
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnTextDelta: func(id string, text string) error {
			// Strip leading newline from initial text content. This is
			// particularly important in non-interactive mode where leading
			// newlines are very visible.
			if len(currentAssistant.Parts) == 0 {
				text = strings.TrimPrefix(text, "\n")
			}

			currentAssistant.AppendContent(text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnToolInputStart: func(id string, toolName string) error {
			// Record the tool call as unfinished until its input arrives.
			toolCall := message.ToolCall{
				ID:               id,
				Name:             toolName,
				ProviderExecuted: false,
				Finished:         false,
			}
			currentAssistant.AddToolCall(toolCall)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnRetry: func(err *fantasy.ProviderError, delay time.Duration) {
			// TODO: implement
		},
		OnToolCall: func(tc fantasy.ToolCallContent) error {
			// The full input is now known; record the call as finished.
			toolCall := message.ToolCall{
				ID:               tc.ToolCallID,
				Name:             tc.ToolName,
				Input:            tc.Input,
				ProviderExecuted: false,
				Finished:         true,
			}
			currentAssistant.AddToolCall(toolCall)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnToolResult: func(result fantasy.ToolResultContent) error {
			// Tool results are stored as their own message with the Tool
			// role.
			toolResult := a.convertToToolResult(result)
			_, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
				Role: message.Tool,
				Parts: []message.ContentPart{
					toolResult,
				},
			})
			return createMsgErr
		},
		OnStepFinish: func(stepResult fantasy.StepResult) error {
			// Map the provider finish reason onto the message finish reason;
			// anything unrecognized stays FinishReasonUnknown.
			finishReason := message.FinishReasonUnknown
			switch stepResult.FinishReason {
			case fantasy.FinishReasonLength:
				finishReason = message.FinishReasonMaxTokens
			case fantasy.FinishReasonStop:
				finishReason = message.FinishReasonEndTurn
			case fantasy.FinishReasonToolCalls:
				finishReason = message.FinishReasonToolUse
			}
			currentAssistant.AddFinish(finishReason, "", "")
			sessionLock.Lock()
			defer sessionLock.Unlock()

			// Reload the session before applying this step's usage, then
			// save it and keep the saved copy for the stop condition below.
			updatedSession, getSessionErr := a.sessions.Get(ctx, call.SessionID)
			if getSessionErr != nil {
				return getSessionErr
			}
			a.updateSessionUsage(largeModel, &updatedSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
			_, sessionErr := a.sessions.Save(ctx, updatedSession)
			if sessionErr != nil {
				return sessionErr
			}
			currentSession = updatedSession
			return a.messages.Update(genCtx, *currentAssistant)
		},
		StopWhen: []fantasy.StopCondition{
			// Stop when the remaining context window reaches the
			// summarization threshold, unless auto-summarize is disabled.
			func(_ []fantasy.StepResult) bool {
				cw := int64(largeModel.CatwalkCfg.ContextWindow)
				tokens := currentSession.CompletionTokens + currentSession.PromptTokens
				remaining := cw - tokens
				var threshold int64
				if cw > largeContextWindowThreshold {
					threshold = largeContextWindowBuffer
				} else {
					threshold = int64(float64(cw) * smallContextWindowRatio)
				}
				if (remaining <= threshold) && !a.disableAutoSummarize {
					shouldSummarize = true
					return true
				}
				return false
			},
			// Stop when the recent steps repeat the same tool calls.
			func(steps []fantasy.StepResult) bool {
				return hasRepeatedToolCalls(steps, loopDetectionWindowSize, loopDetectionMaxRepeats)
			},
		},
	})

	// Reported whether or not the stream succeeded.
	a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))

	if err != nil {
		isCancelErr := errors.Is(err, context.Canceled)
		isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
		// No assistant message was created, so there is nothing to repair.
		if currentAssistant == nil {
			return result, err
		}
		// Ensure we finish thinking on error to close the reasoning state.
		currentAssistant.FinishThinking()
		toolCalls := currentAssistant.ToolCalls()
		// INFO: we use the parent context here because the genCtx has been cancelled.
		msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
		if createErr != nil {
			return nil, createErr
		}
		for _, tc := range toolCalls {
			// Close tool calls whose input never completed, using an empty
			// JSON object as input.
			if !tc.Finished {
				tc.Finished = true
				tc.Input = "{}"
				currentAssistant.AddToolCall(tc)
				updateErr := a.messages.Update(ctx, *currentAssistant)
				if updateErr != nil {
					return nil, updateErr
				}
			}

			// Look for an existing tool result for this call.
			found := false
			for _, msg := range msgs {
				if msg.Role == message.Tool {
					for _, tr := range msg.ToolResults() {
						if tr.ToolCallID == tc.ID {
							found = true
							break
						}
					}
				}
				if found {
					break
				}
			}
			if found {
				continue
			}
			// No result was stored; create an error result so that every
			// tool call has a matching result.
			content := "There was an error while executing the tool"
			if isCancelErr {
				content = "Tool execution canceled by user"
			} else if isPermissionErr {
				content = "User denied permission"
			}
			toolResult := message.ToolResult{
				ToolCallID: tc.ID,
				Name:       tc.Name,
				Content:    content,
				IsError:    true,
			}
			_, createErr = a.messages.Create(ctx, currentAssistant.SessionID, message.CreateMessageParams{
				Role: message.Tool,
				Parts: []message.ContentPart{
					toolResult,
				},
			})
			if createErr != nil {
				return nil, createErr
			}
		}
		// Record a finish reason on the assistant message that describes the
		// failure.
		var fantasyErr *fantasy.Error
		var providerErr *fantasy.ProviderError
		const defaultTitle = "Provider Error"
		linkStyle := lipgloss.NewStyle().Foreground(charmtone.Guac).Underline(true)
		if isCancelErr {
			currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
		} else if isPermissionErr {
			currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "")
		} else if errors.Is(err, hyper.ErrNoCredits) {
			url := hyper.BaseURL()
			link := linkStyle.Hyperlink(url, "id=hyper").Render(url)
			currentAssistant.AddFinish(message.FinishReasonError, "No credits", "You're out of credits. Add more at "+link)
		} else if errors.As(err, &providerErr) {
			// This exact provider message is reported as a Copilot model
			// that has not been enabled.
			if providerErr.Message == "The requested model is not supported." {
				url := "https://github.com/settings/copilot/features"
				link := linkStyle.Hyperlink(url, "id=copilot").Render(url)
				currentAssistant.AddFinish(
					message.FinishReasonError,
					"Copilot model not enabled",
					fmt.Sprintf("%q is not enabled in Copilot. Go to the following page to enable it. Then, wait 5 minutes before trying again. %s", largeModel.CatwalkCfg.Name, link),
				)
			} else {
				currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message)
			}
		} else if errors.As(err, &fantasyErr) {
			currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(fantasyErr.Title), defaultTitle), fantasyErr.Message)
		} else {
			currentAssistant.AddFinish(message.FinishReasonError, defaultTitle, err.Error())
		}
		// Note: we use the parent context here because the genCtx has been
		// cancelled.
		updateErr := a.messages.Update(ctx, *currentAssistant)
		if updateErr != nil {
			return nil, updateErr
		}
		return nil, err
	}

	// Send notification that agent has finished its turn (skip for
	// nested/non-interactive sessions).
	if !call.NonInteractive && a.notify != nil {
		a.notify.Publish(pubsub.CreatedEvent, notify.Notification{
			SessionID:    call.SessionID,
			SessionTitle: currentSession.Title,
			Type:         notify.TypeAgentFinished,
		})
	}

	if shouldSummarize {
		// Remove this request's entry before calling Summarize, which
		// returns ErrSessionBusy for a busy session.
		a.activeRequests.Del(call.SessionID)
		if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
			return nil, summarizeErr
		}
		// If the agent wasn't done...
		if len(currentAssistant.ToolCalls()) > 0 {
			// ...re-queue the original request so that work resumes after
			// the summary.
			existing, ok := a.messageQueue.Get(call.SessionID)
			if !ok {
				existing = []SessionAgentCall{}
			}
			call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
			existing = append(existing, call)
			a.messageQueue.Set(call.SessionID, existing)
		}
	}

	// Release active request before processing queued messages.
	a.activeRequests.Del(call.SessionID)
	cancel()

	queuedMessages, ok := a.messageQueue.Get(call.SessionID)
	if !ok || len(queuedMessages) == 0 {
		return result, err
	}
	// There are queued messages restart the loop.
	firstQueuedMessage := queuedMessages[0]
	a.messageQueue.Set(call.SessionID, queuedMessages[1:])
	return a.Run(ctx, firstQueuedMessage)
}
 583
 584func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
 585	if a.IsSessionBusy(sessionID) {
 586		return ErrSessionBusy
 587	}
 588
 589	// Copy mutable fields under lock to avoid races with SetModels.
 590	largeModel := a.largeModel.Get()
 591	systemPromptPrefix := a.systemPromptPrefix.Get()
 592
 593	currentSession, err := a.sessions.Get(ctx, sessionID)
 594	if err != nil {
 595		return fmt.Errorf("failed to get session: %w", err)
 596	}
 597	msgs, err := a.getSessionMessages(ctx, currentSession)
 598	if err != nil {
 599		return err
 600	}
 601	if len(msgs) == 0 {
 602		// Nothing to summarize.
 603		return nil
 604	}
 605
 606	aiMsgs, _ := a.preparePrompt(msgs)
 607
 608	genCtx, cancel := context.WithCancel(ctx)
 609	a.activeRequests.Set(sessionID, cancel)
 610	defer a.activeRequests.Del(sessionID)
 611	defer cancel()
 612
 613	agent := fantasy.NewAgent(largeModel.Model,
 614		fantasy.WithSystemPrompt(string(summaryPrompt)),
 615		fantasy.WithUserAgent(userAgent),
 616	)
 617	summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
 618		Role:             message.Assistant,
 619		Model:            largeModel.Model.Model(),
 620		Provider:         largeModel.Model.Provider(),
 621		IsSummaryMessage: true,
 622	})
 623	if err != nil {
 624		return err
 625	}
 626
 627	summaryPromptText := buildSummaryPrompt(currentSession.Todos)
 628
 629	resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
 630		Prompt:          summaryPromptText,
 631		Messages:        aiMsgs,
 632		ProviderOptions: opts,
 633		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
 634			prepared.Messages = options.Messages
 635			if systemPromptPrefix != "" {
 636				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(systemPromptPrefix)}, prepared.Messages...)
 637			}
 638			return callContext, prepared, nil
 639		},
 640		OnReasoningDelta: func(id string, text string) error {
 641			summaryMessage.AppendReasoningContent(text)
 642			return a.messages.Update(genCtx, summaryMessage)
 643		},
 644		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
 645			// Handle anthropic signature.
 646			if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
 647				if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
 648					summaryMessage.AppendReasoningSignature(signature.Signature)
 649				}
 650			}
 651			summaryMessage.FinishThinking()
 652			return a.messages.Update(genCtx, summaryMessage)
 653		},
 654		OnTextDelta: func(id, text string) error {
 655			summaryMessage.AppendContent(text)
 656			return a.messages.Update(genCtx, summaryMessage)
 657		},
 658	})
 659	if err != nil {
 660		isCancelErr := errors.Is(err, context.Canceled)
 661		if isCancelErr {
 662			// User cancelled summarize we need to remove the summary message.
 663			deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
 664			return deleteErr
 665		}
 666		return err
 667	}
 668
 669	summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
 670	err = a.messages.Update(genCtx, summaryMessage)
 671	if err != nil {
 672		return err
 673	}
 674
 675	var openrouterCost *float64
 676	for _, step := range resp.Steps {
 677		stepCost := a.openrouterCost(step.ProviderMetadata)
 678		if stepCost != nil {
 679			newCost := *stepCost
 680			if openrouterCost != nil {
 681				newCost += *openrouterCost
 682			}
 683			openrouterCost = &newCost
 684		}
 685	}
 686
 687	a.updateSessionUsage(largeModel, &currentSession, resp.TotalUsage, openrouterCost)
 688
 689	// Just in case, get just the last usage info.
 690	usage := resp.Response.Usage
 691	currentSession.SummaryMessageID = summaryMessage.ID
 692	currentSession.CompletionTokens = usage.OutputTokens
 693	currentSession.PromptTokens = 0
 694	_, err = a.sessions.Save(genCtx, currentSession)
 695	return err
 696}
 697
 698func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
 699	if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
 700		return fantasy.ProviderOptions{}
 701	}
 702	return fantasy.ProviderOptions{
 703		anthropic.Name: &anthropic.ProviderCacheControlOptions{
 704			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
 705		},
 706		bedrock.Name: &anthropic.ProviderCacheControlOptions{
 707			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
 708		},
 709		vercel.Name: &anthropic.ProviderCacheControlOptions{
 710			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
 711		},
 712	}
 713}
 714
 715func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
 716	parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
 717	var attachmentParts []message.ContentPart
 718	for _, attachment := range call.Attachments {
 719		attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
 720	}
 721	parts = append(parts, attachmentParts...)
 722	msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
 723		Role:  message.User,
 724		Parts: parts,
 725	})
 726	if err != nil {
 727		return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
 728	}
 729	return msg, nil
 730}
 731
 732func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
 733	var history []fantasy.Message
 734	if !a.isSubAgent {
 735		history = append(history, fantasy.NewUserMessage(
 736			fmt.Sprintf("<system_reminder>%s</system_reminder>",
 737				`This is a reminder that your todo list is currently empty. DO NOT mention this to the user explicitly because they are already aware.
 738If you are working on tasks that would benefit from a todo list please use the "todos" tool to create one.
 739If not, please feel free to ignore. Again do not mention this message to the user.`,
 740			),
 741		))
 742	}
 743	for _, m := range msgs {
 744		if len(m.Parts) == 0 {
 745			continue
 746		}
 747		// Assistant message without content or tool calls (cancelled before it
 748		// returned anything).
 749		if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
 750			continue
 751		}
 752		history = append(history, m.ToAIMessage()...)
 753	}
 754
 755	var files []fantasy.FilePart
 756	for _, attachment := range attachments {
 757		if attachment.IsText() {
 758			continue
 759		}
 760		files = append(files, fantasy.FilePart{
 761			Filename:  attachment.FileName,
 762			Data:      attachment.Content,
 763			MediaType: attachment.MimeType,
 764		})
 765	}
 766
 767	return history, files
 768}
 769
 770func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
 771	msgs, err := a.messages.List(ctx, session.ID)
 772	if err != nil {
 773		return nil, fmt.Errorf("failed to list messages: %w", err)
 774	}
 775
 776	if session.SummaryMessageID != "" {
 777		summaryMsgIndex := -1
 778		for i, msg := range msgs {
 779			if msg.ID == session.SummaryMessageID {
 780				summaryMsgIndex = i
 781				break
 782			}
 783		}
 784		if summaryMsgIndex != -1 {
 785			msgs = msgs[summaryMsgIndex:]
 786			msgs[0].Role = message.User
 787		}
 788	}
 789	return msgs, nil
 790}
 791
 792// generateTitle generates a session titled based on the initial prompt.
 793func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, userPrompt string) {
 794	if userPrompt == "" {
 795		return
 796	}
 797
 798	smallModel := a.smallModel.Get()
 799	largeModel := a.largeModel.Get()
 800	systemPromptPrefix := a.systemPromptPrefix.Get()
 801
 802	var maxOutputTokens int64 = 40
 803	if smallModel.CatwalkCfg.CanReason {
 804		maxOutputTokens = smallModel.CatwalkCfg.DefaultMaxTokens
 805	}
 806
 807	newAgent := func(m fantasy.LanguageModel, p []byte, tok int64) fantasy.Agent {
 808		return fantasy.NewAgent(m,
 809			fantasy.WithSystemPrompt(string(p)+"\n /no_think"),
 810			fantasy.WithMaxOutputTokens(tok),
 811			fantasy.WithUserAgent(userAgent),
 812		)
 813	}
 814
 815	streamCall := fantasy.AgentStreamCall{
 816		Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", userPrompt),
 817		PrepareStep: func(callCtx context.Context, opts fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
 818			prepared.Messages = opts.Messages
 819			if systemPromptPrefix != "" {
 820				prepared.Messages = append([]fantasy.Message{
 821					fantasy.NewSystemMessage(systemPromptPrefix),
 822				}, prepared.Messages...)
 823			}
 824			return callCtx, prepared, nil
 825		},
 826	}
 827
 828	// Use the small model to generate the title.
 829	model := smallModel
 830	agent := newAgent(model.Model, titlePrompt, maxOutputTokens)
 831	resp, err := agent.Stream(ctx, streamCall)
 832	if err == nil {
 833		// We successfully generated a title with the small model.
 834		slog.Debug("Generated title with small model")
 835	} else {
 836		// It didn't work. Let's try with the big model.
 837		slog.Error("Error generating title with small model; trying big model", "err", err)
 838		model = largeModel
 839		agent = newAgent(model.Model, titlePrompt, maxOutputTokens)
 840		resp, err = agent.Stream(ctx, streamCall)
 841		if err == nil {
 842			slog.Debug("Generated title with large model")
 843		} else {
 844			// Welp, the large model didn't work either. Use the default
 845			// session name and return.
 846			slog.Error("Error generating title with large model", "err", err)
 847			saveErr := a.sessions.Rename(ctx, sessionID, DefaultSessionName)
 848			if saveErr != nil {
 849				slog.Error("Failed to save session title", "error", saveErr)
 850			}
 851			return
 852		}
 853	}
 854
 855	if resp == nil {
 856		// Actually, we didn't get a response so we can't. Use the default
 857		// session name and return.
 858		slog.Error("Response is nil; can't generate title")
 859		saveErr := a.sessions.Rename(ctx, sessionID, DefaultSessionName)
 860		if saveErr != nil {
 861			slog.Error("Failed to save session title", "error", saveErr)
 862		}
 863		return
 864	}
 865
 866	// Clean up title.
 867	var title string
 868	title = strings.ReplaceAll(resp.Response.Content.Text(), "\n", " ")
 869
 870	// Remove thinking tags if present.
 871	title = thinkTagRegex.ReplaceAllString(title, "")
 872
 873	title = strings.TrimSpace(title)
 874	title = cmp.Or(title, DefaultSessionName)
 875
 876	// Calculate usage and cost.
 877	var openrouterCost *float64
 878	for _, step := range resp.Steps {
 879		stepCost := a.openrouterCost(step.ProviderMetadata)
 880		if stepCost != nil {
 881			newCost := *stepCost
 882			if openrouterCost != nil {
 883				newCost += *openrouterCost
 884			}
 885			openrouterCost = &newCost
 886		}
 887	}
 888
 889	modelConfig := model.CatwalkCfg
 890	cost := modelConfig.CostPer1MInCached/1e6*float64(resp.TotalUsage.CacheCreationTokens) +
 891		modelConfig.CostPer1MOutCached/1e6*float64(resp.TotalUsage.CacheReadTokens) +
 892		modelConfig.CostPer1MIn/1e6*float64(resp.TotalUsage.InputTokens) +
 893		modelConfig.CostPer1MOut/1e6*float64(resp.TotalUsage.OutputTokens)
 894
 895	// Use override cost if available (e.g., from OpenRouter).
 896	if openrouterCost != nil {
 897		cost = *openrouterCost
 898	}
 899
 900	promptTokens := resp.TotalUsage.InputTokens + resp.TotalUsage.CacheCreationTokens
 901	completionTokens := resp.TotalUsage.OutputTokens
 902
 903	// Atomically update only title and usage fields to avoid overriding other
 904	// concurrent session updates.
 905	saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, title, promptTokens, completionTokens, cost)
 906	if saveErr != nil {
 907		slog.Error("Failed to save session title and usage", "error", saveErr)
 908		return
 909	}
 910}
 911
 912func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
 913	openrouterMetadata, ok := metadata[openrouter.Name]
 914	if !ok {
 915		return nil
 916	}
 917
 918	opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
 919	if !ok {
 920		return nil
 921	}
 922	return &opts.Usage.Cost
 923}
 924
 925func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
 926	modelConfig := model.CatwalkCfg
 927	cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
 928		modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
 929		modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
 930		modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
 931
 932	a.eventTokensUsed(session.ID, model, usage, cost)
 933
 934	if overrideCost != nil {
 935		session.Cost += *overrideCost
 936	} else {
 937		session.Cost += cost
 938	}
 939
 940	session.CompletionTokens = usage.OutputTokens
 941	session.PromptTokens = usage.InputTokens + usage.CacheReadTokens
 942}
 943
 944func (a *sessionAgent) Cancel(sessionID string) {
 945	// Cancel regular requests. Don't use Take() here - we need the entry to
 946	// remain in activeRequests so IsBusy() returns true until the goroutine
 947	// fully completes (including error handling that may access the DB).
 948	// The defer in processRequest will clean up the entry.
 949	if cancel, ok := a.activeRequests.Get(sessionID); ok && cancel != nil {
 950		slog.Debug("Request cancellation initiated", "session_id", sessionID)
 951		cancel()
 952	}
 953
 954	// Also check for summarize requests.
 955	if cancel, ok := a.activeRequests.Get(sessionID + "-summarize"); ok && cancel != nil {
 956		slog.Debug("Summarize cancellation initiated", "session_id", sessionID)
 957		cancel()
 958	}
 959
 960	if a.QueuedPrompts(sessionID) > 0 {
 961		slog.Debug("Clearing queued prompts", "session_id", sessionID)
 962		a.messageQueue.Del(sessionID)
 963	}
 964}
 965
 966func (a *sessionAgent) ClearQueue(sessionID string) {
 967	if a.QueuedPrompts(sessionID) > 0 {
 968		slog.Debug("Clearing queued prompts", "session_id", sessionID)
 969		a.messageQueue.Del(sessionID)
 970	}
 971}
 972
 973func (a *sessionAgent) CancelAll() {
 974	if !a.IsBusy() {
 975		return
 976	}
 977	for key := range a.activeRequests.Seq2() {
 978		a.Cancel(key) // key is sessionID
 979	}
 980
 981	timeout := time.After(5 * time.Second)
 982	for a.IsBusy() {
 983		select {
 984		case <-timeout:
 985			return
 986		default:
 987			time.Sleep(200 * time.Millisecond)
 988		}
 989	}
 990}
 991
 992func (a *sessionAgent) IsBusy() bool {
 993	var busy bool
 994	for cancelFunc := range a.activeRequests.Seq() {
 995		if cancelFunc != nil {
 996			busy = true
 997			break
 998		}
 999	}
1000	return busy
1001}
1002
1003func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
1004	_, busy := a.activeRequests.Get(sessionID)
1005	return busy
1006}
1007
1008func (a *sessionAgent) QueuedPrompts(sessionID string) int {
1009	l, ok := a.messageQueue.Get(sessionID)
1010	if !ok {
1011		return 0
1012	}
1013	return len(l)
1014}
1015
1016func (a *sessionAgent) QueuedPromptsList(sessionID string) []string {
1017	l, ok := a.messageQueue.Get(sessionID)
1018	if !ok {
1019		return nil
1020	}
1021	prompts := make([]string, len(l))
1022	for i, call := range l {
1023		prompts[i] = call.Prompt
1024	}
1025	return prompts
1026}
1027
1028func (a *sessionAgent) SetModels(large Model, small Model) {
1029	a.largeModel.Set(large)
1030	a.smallModel.Set(small)
1031}
1032
1033func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
1034	a.tools.SetSlice(tools)
1035}
1036
1037func (a *sessionAgent) SetSystemPrompt(systemPrompt string) {
1038	a.systemPrompt.Set(systemPrompt)
1039}
1040
1041func (a *sessionAgent) Model() Model {
1042	return a.largeModel.Get()
1043}
1044
1045// convertToToolResult converts a fantasy tool result to a message tool result.
1046func (a *sessionAgent) convertToToolResult(result fantasy.ToolResultContent) message.ToolResult {
1047	baseResult := message.ToolResult{
1048		ToolCallID: result.ToolCallID,
1049		Name:       result.ToolName,
1050		Metadata:   result.ClientMetadata,
1051	}
1052
1053	switch result.Result.GetType() {
1054	case fantasy.ToolResultContentTypeText:
1055		if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result); ok {
1056			baseResult.Content = r.Text
1057		}
1058	case fantasy.ToolResultContentTypeError:
1059		if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result); ok {
1060			baseResult.Content = r.Error.Error()
1061			baseResult.IsError = true
1062		}
1063	case fantasy.ToolResultContentTypeMedia:
1064		if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](result.Result); ok {
1065			content := r.Text
1066			if content == "" {
1067				content = fmt.Sprintf("Loaded %s content", r.MediaType)
1068			}
1069			baseResult.Content = content
1070			baseResult.Data = r.Data
1071			baseResult.MIMEType = r.MediaType
1072		}
1073	}
1074
1075	return baseResult
1076}
1077
1078// workaroundProviderMediaLimitations converts media content in tool results to
1079// user messages for providers that don't natively support images in tool results.
1080//
1081// Problem: OpenAI, Google, OpenRouter, and other OpenAI-compatible providers
1082// don't support sending images/media in tool result messages - they only accept
1083// text in tool results. However, they DO support images in user messages.
1084//
1085// If we send media in tool results to these providers, the API returns an error.
1086//
1087// Solution: For these providers, we:
1088//  1. Replace the media in the tool result with a text placeholder
1089//  2. Inject a user message immediately after with the image as a file attachment
1090//  3. This maintains the tool execution flow while working around API limitations
1091//
1092// Anthropic and Bedrock support images natively in tool results, so we skip
1093// this workaround for them.
1094//
1095// Example transformation:
1096//
1097//	BEFORE: [tool result: image data]
1098//	AFTER:  [tool result: "Image loaded - see attached"], [user: image attachment]
1099func (a *sessionAgent) workaroundProviderMediaLimitations(messages []fantasy.Message, largeModel Model) []fantasy.Message {
1100	providerSupportsMedia := largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderAnthropic) ||
1101		largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderBedrock)
1102
1103	if providerSupportsMedia {
1104		return messages
1105	}
1106
1107	convertedMessages := make([]fantasy.Message, 0, len(messages))
1108
1109	for _, msg := range messages {
1110		if msg.Role != fantasy.MessageRoleTool {
1111			convertedMessages = append(convertedMessages, msg)
1112			continue
1113		}
1114
1115		textParts := make([]fantasy.MessagePart, 0, len(msg.Content))
1116		var mediaFiles []fantasy.FilePart
1117
1118		for _, part := range msg.Content {
1119			toolResult, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
1120			if !ok {
1121				textParts = append(textParts, part)
1122				continue
1123			}
1124
1125			if media, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](toolResult.Output); ok {
1126				decoded, err := base64.StdEncoding.DecodeString(media.Data)
1127				if err != nil {
1128					slog.Warn("Failed to decode media data", "error", err)
1129					textParts = append(textParts, part)
1130					continue
1131				}
1132
1133				mediaFiles = append(mediaFiles, fantasy.FilePart{
1134					Data:      decoded,
1135					MediaType: media.MediaType,
1136					Filename:  fmt.Sprintf("tool-result-%s", toolResult.ToolCallID),
1137				})
1138
1139				textParts = append(textParts, fantasy.ToolResultPart{
1140					ToolCallID: toolResult.ToolCallID,
1141					Output: fantasy.ToolResultOutputContentText{
1142						Text: "[Image/media content loaded - see attached file]",
1143					},
1144					ProviderOptions: toolResult.ProviderOptions,
1145				})
1146			} else {
1147				textParts = append(textParts, part)
1148			}
1149		}
1150
1151		convertedMessages = append(convertedMessages, fantasy.Message{
1152			Role:    fantasy.MessageRoleTool,
1153			Content: textParts,
1154		})
1155
1156		if len(mediaFiles) > 0 {
1157			convertedMessages = append(convertedMessages, fantasy.NewUserMessage(
1158				"Here is the media content from the tool result:",
1159				mediaFiles...,
1160			))
1161		}
1162	}
1163
1164	return convertedMessages
1165}
1166
1167// buildSummaryPrompt constructs the prompt text for session summarization.
1168func buildSummaryPrompt(todos []session.Todo) string {
1169	var sb strings.Builder
1170	sb.WriteString("Provide a detailed summary of our conversation above.")
1171	if len(todos) > 0 {
1172		sb.WriteString("\n\n## Current Todo List\n\n")
1173		for _, t := range todos {
1174			fmt.Fprintf(&sb, "- [%s] %s\n", t.Status, t.Content)
1175		}
1176		sb.WriteString("\nInclude these tasks and their statuses in your summary. ")
1177		sb.WriteString("Instruct the resuming assistant to use the `todos` tool to continue tracking progress on these tasks.")
1178	}
1179	return sb.String()
1180}