agent.go

   1// Package agent is the core orchestration layer for Crush AI agents.
   2//
   3// It provides session-based AI agent functionality for managing
   4// conversations, tool execution, and message handling. It coordinates
   5// interactions between language models, messages, sessions, and tools while
   6// handling features like automatic summarization, queuing, and token
   7// management.
   8package agent
   9
  10import (
  11	"cmp"
  12	"context"
  13	_ "embed"
  14	"encoding/base64"
  15	"errors"
  16	"fmt"
  17	"log/slog"
  18	"os"
  19	"regexp"
  20	"strconv"
  21	"strings"
  22	"sync"
  23	"time"
  24
  25	"charm.land/catwalk/pkg/catwalk"
  26	"charm.land/fantasy"
  27	"charm.land/fantasy/providers/anthropic"
  28	"charm.land/fantasy/providers/bedrock"
  29	"charm.land/fantasy/providers/google"
  30	"charm.land/fantasy/providers/openai"
  31	"charm.land/fantasy/providers/openrouter"
  32	"charm.land/fantasy/providers/vercel"
  33	"charm.land/lipgloss/v2"
  34	"github.com/charmbracelet/crush/internal/agent/hyper"
  35	"github.com/charmbracelet/crush/internal/agent/notify"
  36	"github.com/charmbracelet/crush/internal/agent/tools"
  37	"github.com/charmbracelet/crush/internal/agent/tools/mcp"
  38	"github.com/charmbracelet/crush/internal/config"
  39	"github.com/charmbracelet/crush/internal/csync"
  40	"github.com/charmbracelet/crush/internal/message"
  41	"github.com/charmbracelet/crush/internal/permission"
  42	"github.com/charmbracelet/crush/internal/pubsub"
  43	"github.com/charmbracelet/crush/internal/session"
  44	"github.com/charmbracelet/crush/internal/stringext"
  45	"github.com/charmbracelet/crush/internal/version"
  46	"github.com/charmbracelet/x/exp/charmtone"
  47)
  48
const (
	// DefaultSessionName is the title a session falls back to when title
	// generation fails or produces an empty title.
	DefaultSessionName = "Untitled Session"

	// Auto-summarization thresholds. A run stops to summarize once the
	// remaining context (context window minus prompt and completion tokens)
	// drops to or below a threshold chosen by the model's context window:
	//   - windows larger than largeContextWindowThreshold use the fixed
	//     largeContextWindowBuffer token count;
	//   - all other windows use smallContextWindowRatio of the window.
	largeContextWindowThreshold = 200_000
	largeContextWindowBuffer    = 20_000
	smallContextWindowRatio     = 0.2
)
  57
// userAgent is the User-Agent value passed to the fantasy agents built in this
// package; it embeds the Crush version.
var userAgent = fmt.Sprintf("Charm-Crush/%s (https://charm.land/crush)", version.Version)

// titlePrompt is the embedded system prompt used for session title generation.
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt is the embedded system prompt used when summarizing a session.
//
//go:embed templates/summary.md
var summaryPrompt []byte

// thinkTagRegex matches a <think>...</think> span so it can be removed from
// generated titles. "." does not match newlines here, so the text must be
// flattened to a single line before the regex is applied.
var thinkTagRegex = regexp.MustCompile(`<think>.*?</think>`)
  68
// SessionAgentCall describes a single prompt submitted to a session agent.
type SessionAgentCall struct {
	// SessionID identifies the session the prompt belongs to; Run rejects an
	// empty value with ErrSessionMissing.
	SessionID string
	// Prompt is the user's text. It is stored as the text part of the user
	// message.
	Prompt string
	// ProviderOptions is forwarded unchanged to the model stream call.
	ProviderOptions fantasy.ProviderOptions
	// Attachments are stored as binary parts of the user message; non-text
	// attachments are additionally sent to the model as files.
	Attachments []message.Attachment
	// The sampling/limit settings below are forwarded as-is to the model
	// stream call.
	MaxOutputTokens  int64
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
	// NonInteractive suppresses the "agent finished" notification at the end
	// of a run.
	NonInteractive bool
}
  82
// SessionAgent is the interface returned by NewSessionAgent. It covers running
// prompts against a session, swapping models/tools/system prompt, cancelling
// work, inspecting busy state and the per-session prompt queue, and
// summarizing a session.
type SessionAgent interface {
	// Run executes the call against its session. When the session is busy
	// the call is queued instead and Run returns (nil, nil).
	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
	SetModels(large Model, small Model)
	SetTools(tools []fantasy.AgentTool)
	SetSystemPrompt(systemPrompt string)
	Cancel(sessionID string)
	CancelAll()
	IsSessionBusy(sessionID string) bool
	IsBusy() bool
	QueuedPrompts(sessionID string) int
	QueuedPromptsList(sessionID string) []string
	ClearQueue(sessionID string)
	// Summarize replaces the session's working history with a generated
	// summary message. It returns ErrSessionBusy when the session is busy.
	Summarize(context.Context, string, fantasy.ProviderOptions) error
	Model() Model
}
  98
// Model bundles a language model with the two pieces of configuration the
// agent consults about it.
type Model struct {
	// Model is the language model the fantasy agents stream from.
	Model fantasy.LanguageModel
	// CatwalkCfg carries model metadata read by the agent, such as the
	// context window, display name, image support, and per-token costs.
	CatwalkCfg catwalk.Model
	// ModelCfg is the selected-model configuration; its Model and Provider
	// values are recorded on assistant messages.
	ModelCfg config.SelectedModel
}
 104
// sessionAgent is the SessionAgent implementation built by NewSessionAgent.
type sessionAgent struct {
	// Settings that can change while the agent is alive; they are held in
	// csync containers and snapshotted at the start of each operation.
	largeModel         *csync.Value[Model]
	smallModel         *csync.Value[Model]
	systemPromptPrefix *csync.Value[string] // prepended as an extra system message on each step when non-empty
	systemPrompt       *csync.Value[string]
	tools              *csync.Slice[fantasy.AgentTool]

	isSubAgent           bool // when false, preparePrompt prepends the todo-list reminder
	sessions             session.Service
	messages             message.Service
	disableAutoSummarize bool // when true, Run never stops a stream to summarize
	isYolo               bool
	notify               pubsub.Publisher[notify.Notification] // may be nil; Run checks before publishing

	// messageQueue holds calls submitted while their session was busy, keyed
	// by session ID.
	messageQueue *csync.Map[string, []SessionAgentCall]
	// activeRequests maps a session ID to the cancel function of its
	// in-flight request.
	activeRequests *csync.Map[string, context.CancelFunc]
}
 122
// SessionAgentOptions is the set of dependencies and initial settings consumed
// by NewSessionAgent. Each field is copied into the corresponding field of the
// agent being built.
type SessionAgentOptions struct {
	LargeModel           Model
	SmallModel           Model
	SystemPromptPrefix   string
	SystemPrompt         string
	IsSubAgent           bool
	DisableAutoSummarize bool
	IsYolo               bool
	Sessions             session.Service
	Messages             message.Service
	Tools                []fantasy.AgentTool
	Notify               pubsub.Publisher[notify.Notification]
}
 136
 137func NewSessionAgent(
 138	opts SessionAgentOptions,
 139) SessionAgent {
 140	return &sessionAgent{
 141		largeModel:           csync.NewValue(opts.LargeModel),
 142		smallModel:           csync.NewValue(opts.SmallModel),
 143		systemPromptPrefix:   csync.NewValue(opts.SystemPromptPrefix),
 144		systemPrompt:         csync.NewValue(opts.SystemPrompt),
 145		isSubAgent:           opts.IsSubAgent,
 146		sessions:             opts.Sessions,
 147		messages:             opts.Messages,
 148		disableAutoSummarize: opts.DisableAutoSummarize,
 149		tools:                csync.NewSliceFrom(opts.Tools),
 150		isYolo:               opts.IsYolo,
 151		notify:               opts.Notify,
 152		messageQueue:         csync.NewMap[string, []SessionAgentCall](),
 153		activeRequests:       csync.NewMap[string, context.CancelFunc](),
 154	}
 155}
 156
// Run executes call against its session and returns the result of the model
// stream.
//
// It returns ErrEmptyPrompt when the prompt is empty and
// message.ContainsTextAttachment reports false for the attachments, and
// ErrSessionMissing when call.SessionID is empty. If the session is busy the
// call is appended to the session's queue and Run returns (nil, nil).
//
// Otherwise Run stores the user message, streams a response from the large
// model while persisting assistant and tool messages as events arrive, and
// updates the session's usage after every step. If the stream fails, the
// current assistant message is finalized with a finish reason describing the
// error, every tool call without a stored result receives an error result,
// and the stream error is returned. If the context-window stop condition
// fired, the session is summarized and, when the last assistant message
// contained tool calls, the original request is re-queued. Finally, if calls
// are queued for the session, Run recurses with the first one.
func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
	if call.Prompt == "" && !message.ContainsTextAttachment(call.Attachments) {
		return nil, ErrEmptyPrompt
	}
	if call.SessionID == "" {
		return nil, ErrSessionMissing
	}

	// Queue the message if busy
	if a.IsSessionBusy(call.SessionID) {
		existing, ok := a.messageQueue.Get(call.SessionID)
		if !ok {
			existing = []SessionAgentCall{}
		}
		existing = append(existing, call)
		a.messageQueue.Set(call.SessionID, existing)
		return nil, nil
	}

	// Copy mutable fields under lock to avoid races with SetTools/SetModels.
	agentTools := a.tools.Copy()
	largeModel := a.largeModel.Get()
	systemPrompt := a.systemPrompt.Get()
	promptPrefix := a.systemPromptPrefix.Get()
	var instructions strings.Builder

	// Collect the instructions advertised by every connected MCP server.
	for _, server := range mcp.GetStates() {
		if server.State != mcp.StateConnected {
			continue
		}
		if s := server.Client.InitializeResult().Instructions; s != "" {
			instructions.WriteString(s)
			instructions.WriteString("\n\n")
		}
	}

	if s := instructions.String(); s != "" {
		systemPrompt += "\n\n<mcp-instructions>\n" + s + "\n</mcp-instructions>"
	}

	if len(agentTools) > 0 {
		// Add Anthropic caching to the last tool.
		agentTools[len(agentTools)-1].SetProviderOptions(a.getCacheControlOptions())
	}

	agent := fantasy.NewAgent(
		largeModel.Model,
		fantasy.WithSystemPrompt(systemPrompt),
		fantasy.WithTools(agentTools...),
		fantasy.WithUserAgent(userAgent),
	)

	// sessionLock serializes the session read-modify-write in OnStepFinish.
	sessionLock := sync.Mutex{}
	currentSession, err := a.sessions.Get(ctx, call.SessionID)
	if err != nil {
		return nil, fmt.Errorf("failed to get session: %w", err)
	}

	msgs, err := a.getSessionMessages(ctx, currentSession)
	if err != nil {
		return nil, fmt.Errorf("failed to get session messages: %w", err)
	}

	var wg sync.WaitGroup
	// Generate title if first message.
	if len(msgs) == 0 {
		titleCtx := ctx // Copy to avoid race with ctx reassignment below.
		wg.Go(func() {
			a.generateTitle(titleCtx, call.SessionID, call.Prompt)
		})
	}
	// Run does not return until title generation has finished.
	defer wg.Wait()

	// Add the user message to the session.
	_, err = a.createUserMessage(ctx, call)
	if err != nil {
		return nil, err
	}

	// Add the session to the context.
	ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)

	// Registering the cancel function is what marks the request as active for
	// this session; it is removed again on every exit path.
	genCtx, cancel := context.WithCancel(ctx)
	a.activeRequests.Set(call.SessionID, cancel)

	defer cancel()
	defer a.activeRequests.Del(call.SessionID)

	history, files := a.preparePrompt(msgs, call.Attachments...)

	startTime := time.Now()
	a.eventPromptSent(call.SessionID)

	// currentAssistant is assigned by PrepareStep at the start of each step and
	// is mutated by the streaming callbacks below.
	var currentAssistant *message.Message
	var shouldSummarize bool
	result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
		Prompt:           message.PromptWithTextAttachments(call.Prompt, call.Attachments),
		Files:            files,
		Messages:         history,
		ProviderOptions:  call.ProviderOptions,
		MaxOutputTokens:  &call.MaxOutputTokens,
		TopP:             call.TopP,
		Temperature:      call.Temperature,
		PresencePenalty:  call.PresencePenalty,
		TopK:             call.TopK,
		FrequencyPenalty: call.FrequencyPenalty,
		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
			// Start from a clean slate: provider options are re-applied
			// below to the messages that should carry cache control.
			prepared.Messages = options.Messages
			for i := range prepared.Messages {
				prepared.Messages[i].ProviderOptions = nil
			}

			// Use latest tools (updated by SetTools when MCP tools change).
			prepared.Tools = a.tools.Copy()

			// Drain prompts queued while this run was busy and feed them
			// into the current step as additional user messages.
			queuedCalls, _ := a.messageQueue.Get(call.SessionID)
			a.messageQueue.Del(call.SessionID)
			for _, queued := range queuedCalls {
				userMessage, createErr := a.createUserMessage(callContext, queued)
				if createErr != nil {
					return callContext, prepared, createErr
				}
				prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
			}

			prepared.Messages = a.workaroundProviderMediaLimitations(prepared.Messages, largeModel)

			lastSystemRoleInx := 0
			systemMessageUpdated := false
			for i, msg := range prepared.Messages {
				// Only add cache control to the last message.
				if msg.Role == fantasy.MessageRoleSystem {
					lastSystemRoleInx = i
				} else if !systemMessageUpdated {
					prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
					systemMessageUpdated = true
				}
				// Then add cache control to the last 2 messages.
				if i > len(prepared.Messages)-3 {
					prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
				}
			}

			if promptPrefix != "" {
				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(promptPrefix)}, prepared.Messages...)
			}

			// Create the (initially empty) assistant message that the
			// streaming callbacks will fill in for this step.
			var assistantMsg message.Message
			assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
				Role:     message.Assistant,
				Parts:    []message.ContentPart{},
				Model:    largeModel.ModelCfg.Model,
				Provider: largeModel.ModelCfg.Provider,
			})
			if err != nil {
				return callContext, prepared, err
			}
			callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
			callContext = context.WithValue(callContext, tools.SupportsImagesContextKey, largeModel.CatwalkCfg.SupportsImages)
			callContext = context.WithValue(callContext, tools.ModelNameContextKey, largeModel.CatwalkCfg.Name)
			currentAssistant = &assistantMsg
			return callContext, prepared, err
		},
		OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
			currentAssistant.AppendReasoningContent(reasoning.Text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnReasoningDelta: func(id string, text string) error {
			currentAssistant.AppendReasoningContent(text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
			// handle anthropic signature
			if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
				if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
					currentAssistant.AppendReasoningSignature(reasoning.Signature)
				}
			}
			if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
				if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
					currentAssistant.AppendThoughtSignature(reasoning.Signature, reasoning.ToolID)
				}
			}
			if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
				if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
					currentAssistant.SetReasoningResponsesData(reasoning)
				}
			}
			currentAssistant.FinishThinking()
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnTextDelta: func(id string, text string) error {
			// Strip leading newline from initial text content. This is
			// particularly important in non-interactive mode where leading
			// newlines are very visible.
			if len(currentAssistant.Parts) == 0 {
				text = strings.TrimPrefix(text, "\n")
			}

			currentAssistant.AppendContent(text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnToolInputStart: func(id string, toolName string) error {
			toolCall := message.ToolCall{
				ID:               id,
				Name:             toolName,
				ProviderExecuted: false,
				Finished:         false,
			}
			currentAssistant.AddToolCall(toolCall)
			// Use parent ctx instead of genCtx to ensure the update succeeds
			// even if the request is canceled mid-stream
			return a.messages.Update(ctx, *currentAssistant)
		},
		OnRetry: func(err *fantasy.ProviderError, delay time.Duration) {
			// TODO: implement
		},
		OnToolCall: func(tc fantasy.ToolCallContent) error {
			toolCall := message.ToolCall{
				ID:               tc.ToolCallID,
				Name:             tc.ToolName,
				Input:            tc.Input,
				ProviderExecuted: false,
				Finished:         true,
			}
			currentAssistant.AddToolCall(toolCall)
			// Use parent ctx instead of genCtx to ensure the update succeeds
			// even if the request is canceled mid-stream
			return a.messages.Update(ctx, *currentAssistant)
		},
		OnToolResult: func(result fantasy.ToolResultContent) error {
			toolResult := a.convertToToolResult(result)
			// Use parent ctx instead of genCtx to ensure the message is created
			// even if the request is canceled mid-stream
			_, createMsgErr := a.messages.Create(ctx, currentAssistant.SessionID, message.CreateMessageParams{
				Role: message.Tool,
				Parts: []message.ContentPart{
					toolResult,
				},
			})
			return createMsgErr
		},
		OnStepFinish: func(stepResult fantasy.StepResult) error {
			// Map the provider finish reason onto the message finish reason;
			// anything unrecognized stays FinishReasonUnknown.
			finishReason := message.FinishReasonUnknown
			switch stepResult.FinishReason {
			case fantasy.FinishReasonLength:
				finishReason = message.FinishReasonMaxTokens
			case fantasy.FinishReasonStop:
				finishReason = message.FinishReasonEndTurn
			case fantasy.FinishReasonToolCalls:
				finishReason = message.FinishReasonToolUse
			}
			currentAssistant.AddFinish(finishReason, "", "")
			sessionLock.Lock()
			defer sessionLock.Unlock()

			// Re-read the session before saving so the usage update is applied
			// to the latest stored state rather than the copy loaded at the
			// start of the run.
			updatedSession, getSessionErr := a.sessions.Get(ctx, call.SessionID)
			if getSessionErr != nil {
				return getSessionErr
			}
			a.updateSessionUsage(largeModel, &updatedSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
			_, sessionErr := a.sessions.Save(ctx, updatedSession)
			if sessionErr != nil {
				return sessionErr
			}
			currentSession = updatedSession
			return a.messages.Update(genCtx, *currentAssistant)
		},
		StopWhen: []fantasy.StopCondition{
			// Stop when the remaining context window falls to or below the
			// auto-summarization threshold (unless auto-summarize is disabled).
			func(_ []fantasy.StepResult) bool {
				cw := int64(largeModel.CatwalkCfg.ContextWindow)
				tokens := currentSession.CompletionTokens + currentSession.PromptTokens
				remaining := cw - tokens
				var threshold int64
				if cw > largeContextWindowThreshold {
					threshold = largeContextWindowBuffer
				} else {
					threshold = int64(float64(cw) * smallContextWindowRatio)
				}
				if (remaining <= threshold) && !a.disableAutoSummarize {
					shouldSummarize = true
					return true
				}
				return false
			},
			// Stop when the recent steps repeat the same tool calls.
			func(steps []fantasy.StepResult) bool {
				return hasRepeatedToolCalls(steps, loopDetectionWindowSize, loopDetectionMaxRepeats)
			},
		},
	})

	a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))

	if err != nil {
		isCancelErr := errors.Is(err, context.Canceled)
		isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
		// No assistant message was ever created, so there is nothing to
		// finalize.
		if currentAssistant == nil {
			return result, err
		}
		// Ensure we finish thinking on error to close the reasoning state.
		currentAssistant.FinishThinking()
		toolCalls := currentAssistant.ToolCalls()
		// INFO: we use the parent context here because the genCtx has been cancelled.
		msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
		if createErr != nil {
			return nil, createErr
		}
		for _, tc := range toolCalls {
			// Close any tool call whose input never finished streaming.
			if !tc.Finished {
				tc.Finished = true
				tc.Input = "{}"
				currentAssistant.AddToolCall(tc)
				updateErr := a.messages.Update(ctx, *currentAssistant)
				if updateErr != nil {
					return nil, updateErr
				}
			}

			// Skip tool calls that already have a stored result.
			found := false
			for _, msg := range msgs {
				if msg.Role == message.Tool {
					for _, tr := range msg.ToolResults() {
						if tr.ToolCallID == tc.ID {
							found = true
							break
						}
					}
				}
				if found {
					break
				}
			}
			if found {
				continue
			}
			// Otherwise store an error result for the call.
			content := "There was an error while executing the tool"
			if isCancelErr {
				content = "Error: user cancelled assistant tool calling"
			} else if isPermissionErr {
				content = "User denied permission"
			}
			toolResult := message.ToolResult{
				ToolCallID: tc.ID,
				Name:       tc.Name,
				Content:    content,
				IsError:    true,
			}
			_, createErr = a.messages.Create(ctx, currentAssistant.SessionID, message.CreateMessageParams{
				Role: message.Tool,
				Parts: []message.ContentPart{
					toolResult,
				},
			})
			if createErr != nil {
				return nil, createErr
			}
		}
		// Record a finish reason on the assistant message that matches the
		// kind of failure.
		var fantasyErr *fantasy.Error
		var providerErr *fantasy.ProviderError
		const defaultTitle = "Provider Error"
		linkStyle := lipgloss.NewStyle().Foreground(charmtone.Guac).Underline(true)
		if isCancelErr {
			currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
		} else if isPermissionErr {
			currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "")
		} else if errors.Is(err, hyper.ErrNoCredits) {
			url := hyper.BaseURL()
			link := linkStyle.Hyperlink(url, "id=hyper").Render(url)
			currentAssistant.AddFinish(message.FinishReasonError, "No credits", "You're out of credits. Add more at "+link)
		} else if errors.As(err, &providerErr) {
			if providerErr.Message == "The requested model is not supported." {
				url := "https://github.com/settings/copilot/features"
				link := linkStyle.Hyperlink(url, "id=copilot").Render(url)
				currentAssistant.AddFinish(
					message.FinishReasonError,
					"Copilot model not enabled",
					fmt.Sprintf("%q is not enabled in Copilot. Go to the following page to enable it. Then, wait 5 minutes before trying again. %s", largeModel.CatwalkCfg.Name, link),
				)
			} else {
				currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message)
			}
		} else if errors.As(err, &fantasyErr) {
			currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(fantasyErr.Title), defaultTitle), fantasyErr.Message)
		} else {
			currentAssistant.AddFinish(message.FinishReasonError, defaultTitle, err.Error())
		}
		// Note: we use the parent context here because the genCtx has been
		// cancelled.
		updateErr := a.messages.Update(ctx, *currentAssistant)
		if updateErr != nil {
			return nil, updateErr
		}
		return nil, err
	}

	// Send notification that agent has finished its turn (skip for
	// nested/non-interactive sessions).
	if !call.NonInteractive && a.notify != nil {
		a.notify.Publish(pubsub.CreatedEvent, notify.Notification{
			SessionID:    call.SessionID,
			SessionTitle: currentSession.Title,
			Type:         notify.TypeAgentFinished,
		})
	}

	if shouldSummarize {
		// Drop our own active-request entry before calling Summarize, which
		// registers its own cancel function for the session.
		a.activeRequests.Del(call.SessionID)
		if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
			return nil, summarizeErr
		}
		// If the agent wasn't done...
		if len(currentAssistant.ToolCalls()) > 0 {
			existing, ok := a.messageQueue.Get(call.SessionID)
			if !ok {
				existing = []SessionAgentCall{}
			}
			call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
			existing = append(existing, call)
			a.messageQueue.Set(call.SessionID, existing)
		}
	}

	// Release active request before processing queued messages.
	a.activeRequests.Del(call.SessionID)
	cancel()

	queuedMessages, ok := a.messageQueue.Get(call.SessionID)
	if !ok || len(queuedMessages) == 0 {
		return result, err
	}
	// There are queued messages restart the loop.
	firstQueuedMessage := queuedMessages[0]
	a.messageQueue.Set(call.SessionID, queuedMessages[1:])
	return a.Run(ctx, firstQueuedMessage)
}
 592
 593func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
 594	if a.IsSessionBusy(sessionID) {
 595		return ErrSessionBusy
 596	}
 597
 598	// Copy mutable fields under lock to avoid races with SetModels.
 599	largeModel := a.largeModel.Get()
 600	systemPromptPrefix := a.systemPromptPrefix.Get()
 601
 602	currentSession, err := a.sessions.Get(ctx, sessionID)
 603	if err != nil {
 604		return fmt.Errorf("failed to get session: %w", err)
 605	}
 606	msgs, err := a.getSessionMessages(ctx, currentSession)
 607	if err != nil {
 608		return err
 609	}
 610	if len(msgs) == 0 {
 611		// Nothing to summarize.
 612		return nil
 613	}
 614
 615	aiMsgs, _ := a.preparePrompt(msgs)
 616
 617	genCtx, cancel := context.WithCancel(ctx)
 618	a.activeRequests.Set(sessionID, cancel)
 619	defer a.activeRequests.Del(sessionID)
 620	defer cancel()
 621
 622	agent := fantasy.NewAgent(largeModel.Model,
 623		fantasy.WithSystemPrompt(string(summaryPrompt)),
 624		fantasy.WithUserAgent(userAgent),
 625	)
 626	summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
 627		Role:             message.Assistant,
 628		Model:            largeModel.Model.Model(),
 629		Provider:         largeModel.Model.Provider(),
 630		IsSummaryMessage: true,
 631	})
 632	if err != nil {
 633		return err
 634	}
 635
 636	summaryPromptText := buildSummaryPrompt(currentSession.Todos)
 637
 638	resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
 639		Prompt:          summaryPromptText,
 640		Messages:        aiMsgs,
 641		ProviderOptions: opts,
 642		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
 643			prepared.Messages = options.Messages
 644			if systemPromptPrefix != "" {
 645				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(systemPromptPrefix)}, prepared.Messages...)
 646			}
 647			return callContext, prepared, nil
 648		},
 649		OnReasoningDelta: func(id string, text string) error {
 650			summaryMessage.AppendReasoningContent(text)
 651			return a.messages.Update(genCtx, summaryMessage)
 652		},
 653		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
 654			// Handle anthropic signature.
 655			if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
 656				if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
 657					summaryMessage.AppendReasoningSignature(signature.Signature)
 658				}
 659			}
 660			summaryMessage.FinishThinking()
 661			return a.messages.Update(genCtx, summaryMessage)
 662		},
 663		OnTextDelta: func(id, text string) error {
 664			summaryMessage.AppendContent(text)
 665			return a.messages.Update(genCtx, summaryMessage)
 666		},
 667	})
 668	if err != nil {
 669		isCancelErr := errors.Is(err, context.Canceled)
 670		if isCancelErr {
 671			// User cancelled summarize we need to remove the summary message.
 672			deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
 673			return deleteErr
 674		}
 675		return err
 676	}
 677
 678	summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
 679	err = a.messages.Update(genCtx, summaryMessage)
 680	if err != nil {
 681		return err
 682	}
 683
 684	var openrouterCost *float64
 685	for _, step := range resp.Steps {
 686		stepCost := a.openrouterCost(step.ProviderMetadata)
 687		if stepCost != nil {
 688			newCost := *stepCost
 689			if openrouterCost != nil {
 690				newCost += *openrouterCost
 691			}
 692			openrouterCost = &newCost
 693		}
 694	}
 695
 696	a.updateSessionUsage(largeModel, &currentSession, resp.TotalUsage, openrouterCost)
 697
 698	// Just in case, get just the last usage info.
 699	usage := resp.Response.Usage
 700	currentSession.SummaryMessageID = summaryMessage.ID
 701	currentSession.CompletionTokens = usage.OutputTokens
 702	currentSession.PromptTokens = 0
 703	_, err = a.sessions.Save(genCtx, currentSession)
 704	return err
 705}
 706
 707func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
 708	if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
 709		return fantasy.ProviderOptions{}
 710	}
 711	return fantasy.ProviderOptions{
 712		anthropic.Name: &anthropic.ProviderCacheControlOptions{
 713			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
 714		},
 715		bedrock.Name: &anthropic.ProviderCacheControlOptions{
 716			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
 717		},
 718		vercel.Name: &anthropic.ProviderCacheControlOptions{
 719			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
 720		},
 721	}
 722}
 723
 724func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
 725	parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
 726	var attachmentParts []message.ContentPart
 727	for _, attachment := range call.Attachments {
 728		attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
 729	}
 730	parts = append(parts, attachmentParts...)
 731	msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
 732		Role:  message.User,
 733		Parts: parts,
 734	})
 735	if err != nil {
 736		return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
 737	}
 738	return msg, nil
 739}
 740
 741func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
 742	var history []fantasy.Message
 743	if !a.isSubAgent {
 744		history = append(history, fantasy.NewUserMessage(
 745			fmt.Sprintf("<system_reminder>%s</system_reminder>",
 746				`This is a reminder that your todo list is currently empty. DO NOT mention this to the user explicitly because they are already aware.
 747If you are working on tasks that would benefit from a todo list please use the "todos" tool to create one.
 748If not, please feel free to ignore. Again do not mention this message to the user.`,
 749			),
 750		))
 751	}
 752	for _, m := range msgs {
 753		if len(m.Parts) == 0 {
 754			continue
 755		}
 756		// Assistant message without content or tool calls (cancelled before it
 757		// returned anything).
 758		if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
 759			continue
 760		}
 761		history = append(history, m.ToAIMessage()...)
 762	}
 763
 764	var files []fantasy.FilePart
 765	for _, attachment := range attachments {
 766		if attachment.IsText() {
 767			continue
 768		}
 769		files = append(files, fantasy.FilePart{
 770			Filename:  attachment.FileName,
 771			Data:      attachment.Content,
 772			MediaType: attachment.MimeType,
 773		})
 774	}
 775
 776	return history, files
 777}
 778
 779func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
 780	msgs, err := a.messages.List(ctx, session.ID)
 781	if err != nil {
 782		return nil, fmt.Errorf("failed to list messages: %w", err)
 783	}
 784
 785	if session.SummaryMessageID != "" {
 786		summaryMsgIndex := -1
 787		for i, msg := range msgs {
 788			if msg.ID == session.SummaryMessageID {
 789				summaryMsgIndex = i
 790				break
 791			}
 792		}
 793		if summaryMsgIndex != -1 {
 794			msgs = msgs[summaryMsgIndex:]
 795			msgs[0].Role = message.User
 796		}
 797	}
 798	return msgs, nil
 799}
 800
 801// generateTitle generates a session titled based on the initial prompt.
 802func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, userPrompt string) {
 803	if userPrompt == "" {
 804		return
 805	}
 806
 807	smallModel := a.smallModel.Get()
 808	largeModel := a.largeModel.Get()
 809	systemPromptPrefix := a.systemPromptPrefix.Get()
 810
 811	var maxOutputTokens int64 = 40
 812	if smallModel.CatwalkCfg.CanReason {
 813		maxOutputTokens = smallModel.CatwalkCfg.DefaultMaxTokens
 814	}
 815
 816	newAgent := func(m fantasy.LanguageModel, p []byte, tok int64) fantasy.Agent {
 817		return fantasy.NewAgent(m,
 818			fantasy.WithSystemPrompt(string(p)+"\n /no_think"),
 819			fantasy.WithMaxOutputTokens(tok),
 820			fantasy.WithUserAgent(userAgent),
 821		)
 822	}
 823
 824	streamCall := fantasy.AgentStreamCall{
 825		Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", userPrompt),
 826		PrepareStep: func(callCtx context.Context, opts fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
 827			prepared.Messages = opts.Messages
 828			if systemPromptPrefix != "" {
 829				prepared.Messages = append([]fantasy.Message{
 830					fantasy.NewSystemMessage(systemPromptPrefix),
 831				}, prepared.Messages...)
 832			}
 833			return callCtx, prepared, nil
 834		},
 835	}
 836
 837	// Use the small model to generate the title.
 838	model := smallModel
 839	agent := newAgent(model.Model, titlePrompt, maxOutputTokens)
 840	resp, err := agent.Stream(ctx, streamCall)
 841	if err == nil {
 842		// We successfully generated a title with the small model.
 843		slog.Debug("Generated title with small model")
 844	} else {
 845		// It didn't work. Let's try with the big model.
 846		slog.Error("Error generating title with small model; trying big model", "err", err)
 847		model = largeModel
 848		agent = newAgent(model.Model, titlePrompt, maxOutputTokens)
 849		resp, err = agent.Stream(ctx, streamCall)
 850		if err == nil {
 851			slog.Debug("Generated title with large model")
 852		} else {
 853			// Welp, the large model didn't work either. Use the default
 854			// session name and return.
 855			slog.Error("Error generating title with large model", "err", err)
 856			saveErr := a.sessions.Rename(ctx, sessionID, DefaultSessionName)
 857			if saveErr != nil {
 858				slog.Error("Failed to save session title", "error", saveErr)
 859			}
 860			return
 861		}
 862	}
 863
 864	if resp == nil {
 865		// Actually, we didn't get a response so we can't. Use the default
 866		// session name and return.
 867		slog.Error("Response is nil; can't generate title")
 868		saveErr := a.sessions.Rename(ctx, sessionID, DefaultSessionName)
 869		if saveErr != nil {
 870			slog.Error("Failed to save session title", "error", saveErr)
 871		}
 872		return
 873	}
 874
 875	// Clean up title.
 876	var title string
 877	title = strings.ReplaceAll(resp.Response.Content.Text(), "\n", " ")
 878
 879	// Remove thinking tags if present.
 880	title = thinkTagRegex.ReplaceAllString(title, "")
 881
 882	title = strings.TrimSpace(title)
 883	title = cmp.Or(title, DefaultSessionName)
 884
 885	// Calculate usage and cost.
 886	var openrouterCost *float64
 887	for _, step := range resp.Steps {
 888		stepCost := a.openrouterCost(step.ProviderMetadata)
 889		if stepCost != nil {
 890			newCost := *stepCost
 891			if openrouterCost != nil {
 892				newCost += *openrouterCost
 893			}
 894			openrouterCost = &newCost
 895		}
 896	}
 897
 898	modelConfig := model.CatwalkCfg
 899	cost := modelConfig.CostPer1MInCached/1e6*float64(resp.TotalUsage.CacheCreationTokens) +
 900		modelConfig.CostPer1MOutCached/1e6*float64(resp.TotalUsage.CacheReadTokens) +
 901		modelConfig.CostPer1MIn/1e6*float64(resp.TotalUsage.InputTokens) +
 902		modelConfig.CostPer1MOut/1e6*float64(resp.TotalUsage.OutputTokens)
 903
 904	// Use override cost if available (e.g., from OpenRouter).
 905	if openrouterCost != nil {
 906		cost = *openrouterCost
 907	}
 908
 909	promptTokens := resp.TotalUsage.InputTokens + resp.TotalUsage.CacheCreationTokens
 910	completionTokens := resp.TotalUsage.OutputTokens
 911
 912	// Atomically update only title and usage fields to avoid overriding other
 913	// concurrent session updates.
 914	saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, title, promptTokens, completionTokens, cost)
 915	if saveErr != nil {
 916		slog.Error("Failed to save session title and usage", "error", saveErr)
 917		return
 918	}
 919}
 920
 921func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
 922	openrouterMetadata, ok := metadata[openrouter.Name]
 923	if !ok {
 924		return nil
 925	}
 926
 927	opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
 928	if !ok {
 929		return nil
 930	}
 931	return &opts.Usage.Cost
 932}
 933
 934func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
 935	modelConfig := model.CatwalkCfg
 936	cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
 937		modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
 938		modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
 939		modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
 940
 941	a.eventTokensUsed(session.ID, model, usage, cost)
 942
 943	if overrideCost != nil {
 944		session.Cost += *overrideCost
 945	} else {
 946		session.Cost += cost
 947	}
 948
 949	session.CompletionTokens = usage.OutputTokens
 950	session.PromptTokens = usage.InputTokens + usage.CacheReadTokens
 951}
 952
 953func (a *sessionAgent) Cancel(sessionID string) {
 954	// Cancel regular requests. Don't use Take() here - we need the entry to
 955	// remain in activeRequests so IsBusy() returns true until the goroutine
 956	// fully completes (including error handling that may access the DB).
 957	// The defer in processRequest will clean up the entry.
 958	if cancel, ok := a.activeRequests.Get(sessionID); ok && cancel != nil {
 959		slog.Debug("Request cancellation initiated", "session_id", sessionID)
 960		cancel()
 961	}
 962
 963	// Also check for summarize requests.
 964	if cancel, ok := a.activeRequests.Get(sessionID + "-summarize"); ok && cancel != nil {
 965		slog.Debug("Summarize cancellation initiated", "session_id", sessionID)
 966		cancel()
 967	}
 968
 969	if a.QueuedPrompts(sessionID) > 0 {
 970		slog.Debug("Clearing queued prompts", "session_id", sessionID)
 971		a.messageQueue.Del(sessionID)
 972	}
 973}
 974
 975func (a *sessionAgent) ClearQueue(sessionID string) {
 976	if a.QueuedPrompts(sessionID) > 0 {
 977		slog.Debug("Clearing queued prompts", "session_id", sessionID)
 978		a.messageQueue.Del(sessionID)
 979	}
 980}
 981
 982func (a *sessionAgent) CancelAll() {
 983	if !a.IsBusy() {
 984		return
 985	}
 986	for key := range a.activeRequests.Seq2() {
 987		a.Cancel(key) // key is sessionID
 988	}
 989
 990	timeout := time.After(5 * time.Second)
 991	for a.IsBusy() {
 992		select {
 993		case <-timeout:
 994			return
 995		default:
 996			time.Sleep(200 * time.Millisecond)
 997		}
 998	}
 999}
1000
1001func (a *sessionAgent) IsBusy() bool {
1002	var busy bool
1003	for cancelFunc := range a.activeRequests.Seq() {
1004		if cancelFunc != nil {
1005			busy = true
1006			break
1007		}
1008	}
1009	return busy
1010}
1011
1012func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
1013	_, busy := a.activeRequests.Get(sessionID)
1014	return busy
1015}
1016
1017func (a *sessionAgent) QueuedPrompts(sessionID string) int {
1018	l, ok := a.messageQueue.Get(sessionID)
1019	if !ok {
1020		return 0
1021	}
1022	return len(l)
1023}
1024
1025func (a *sessionAgent) QueuedPromptsList(sessionID string) []string {
1026	l, ok := a.messageQueue.Get(sessionID)
1027	if !ok {
1028		return nil
1029	}
1030	prompts := make([]string, len(l))
1031	for i, call := range l {
1032		prompts[i] = call.Prompt
1033	}
1034	return prompts
1035}
1036
// SetModels stores large and small as the agent's large and small models.
func (a *sessionAgent) SetModels(large Model, small Model) {
	a.largeModel.Set(large)
	a.smallModel.Set(small)
}
1041
// SetTools sets the agent's tools from the given slice.
func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
	a.tools.SetSlice(tools)
}
1045
// SetSystemPrompt stores systemPrompt as the agent's system prompt.
func (a *sessionAgent) SetSystemPrompt(systemPrompt string) {
	a.systemPrompt.Set(systemPrompt)
}
1049
// Model returns the agent's current large model.
func (a *sessionAgent) Model() Model {
	return a.largeModel.Get()
}
1053
1054// convertToToolResult converts a fantasy tool result to a message tool result.
1055func (a *sessionAgent) convertToToolResult(result fantasy.ToolResultContent) message.ToolResult {
1056	baseResult := message.ToolResult{
1057		ToolCallID: result.ToolCallID,
1058		Name:       result.ToolName,
1059		Metadata:   result.ClientMetadata,
1060	}
1061
1062	switch result.Result.GetType() {
1063	case fantasy.ToolResultContentTypeText:
1064		if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result); ok {
1065			baseResult.Content = r.Text
1066		}
1067	case fantasy.ToolResultContentTypeError:
1068		if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result); ok {
1069			baseResult.Content = r.Error.Error()
1070			baseResult.IsError = true
1071		}
1072	case fantasy.ToolResultContentTypeMedia:
1073		if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](result.Result); ok {
1074			content := r.Text
1075			if content == "" {
1076				content = fmt.Sprintf("Loaded %s content", r.MediaType)
1077			}
1078			baseResult.Content = content
1079			baseResult.Data = r.Data
1080			baseResult.MIMEType = r.MediaType
1081		}
1082	}
1083
1084	return baseResult
1085}
1086
1087// workaroundProviderMediaLimitations converts media content in tool results to
1088// user messages for providers that don't natively support images in tool results.
1089//
1090// Problem: OpenAI, Google, OpenRouter, and other OpenAI-compatible providers
1091// don't support sending images/media in tool result messages - they only accept
1092// text in tool results. However, they DO support images in user messages.
1093//
1094// If we send media in tool results to these providers, the API returns an error.
1095//
1096// Solution: For these providers, we:
1097//  1. Replace the media in the tool result with a text placeholder
1098//  2. Inject a user message immediately after with the image as a file attachment
1099//  3. This maintains the tool execution flow while working around API limitations
1100//
1101// Anthropic and Bedrock support images natively in tool results, so we skip
1102// this workaround for them.
1103//
1104// Example transformation:
1105//
1106//	BEFORE: [tool result: image data]
1107//	AFTER:  [tool result: "Image loaded - see attached"], [user: image attachment]
1108func (a *sessionAgent) workaroundProviderMediaLimitations(messages []fantasy.Message, largeModel Model) []fantasy.Message {
1109	providerSupportsMedia := largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderAnthropic) ||
1110		largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderBedrock)
1111
1112	if providerSupportsMedia {
1113		return messages
1114	}
1115
1116	convertedMessages := make([]fantasy.Message, 0, len(messages))
1117
1118	for _, msg := range messages {
1119		if msg.Role != fantasy.MessageRoleTool {
1120			convertedMessages = append(convertedMessages, msg)
1121			continue
1122		}
1123
1124		textParts := make([]fantasy.MessagePart, 0, len(msg.Content))
1125		var mediaFiles []fantasy.FilePart
1126
1127		for _, part := range msg.Content {
1128			toolResult, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
1129			if !ok {
1130				textParts = append(textParts, part)
1131				continue
1132			}
1133
1134			if media, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](toolResult.Output); ok {
1135				decoded, err := base64.StdEncoding.DecodeString(media.Data)
1136				if err != nil {
1137					slog.Warn("Failed to decode media data", "error", err)
1138					textParts = append(textParts, part)
1139					continue
1140				}
1141
1142				mediaFiles = append(mediaFiles, fantasy.FilePart{
1143					Data:      decoded,
1144					MediaType: media.MediaType,
1145					Filename:  fmt.Sprintf("tool-result-%s", toolResult.ToolCallID),
1146				})
1147
1148				textParts = append(textParts, fantasy.ToolResultPart{
1149					ToolCallID: toolResult.ToolCallID,
1150					Output: fantasy.ToolResultOutputContentText{
1151						Text: "[Image/media content loaded - see attached file]",
1152					},
1153					ProviderOptions: toolResult.ProviderOptions,
1154				})
1155			} else {
1156				textParts = append(textParts, part)
1157			}
1158		}
1159
1160		convertedMessages = append(convertedMessages, fantasy.Message{
1161			Role:    fantasy.MessageRoleTool,
1162			Content: textParts,
1163		})
1164
1165		if len(mediaFiles) > 0 {
1166			convertedMessages = append(convertedMessages, fantasy.NewUserMessage(
1167				"Here is the media content from the tool result:",
1168				mediaFiles...,
1169			))
1170		}
1171	}
1172
1173	return convertedMessages
1174}
1175
1176// buildSummaryPrompt constructs the prompt text for session summarization.
1177func buildSummaryPrompt(todos []session.Todo) string {
1178	var sb strings.Builder
1179	sb.WriteString("Provide a detailed summary of our conversation above.")
1180	if len(todos) > 0 {
1181		sb.WriteString("\n\n## Current Todo List\n\n")
1182		for _, t := range todos {
1183			fmt.Fprintf(&sb, "- [%s] %s\n", t.Status, t.Content)
1184		}
1185		sb.WriteString("\nInclude these tasks and their statuses in your summary. ")
1186		sb.WriteString("Instruct the resuming assistant to use the `todos` tool to continue tracking progress on these tasks.")
1187	}
1188	return sb.String()
1189}