agent.go

   1// Package agent is the core orchestration layer for Crush AI agents.
   2//
   3// It provides session-based AI agent functionality for managing
   4// conversations, tool execution, and message handling. It coordinates
   5// interactions between language models, messages, sessions, and tools while
   6// handling features like automatic summarization, queuing, and token
   7// management.
   8package agent
   9
  10import (
  11	"cmp"
  12	"context"
  13	_ "embed"
  14	"encoding/base64"
  15	"errors"
  16	"fmt"
  17	"log/slog"
  18	"os"
  19	"regexp"
  20	"strconv"
  21	"strings"
  22	"sync"
  23	"time"
  24
  25	"charm.land/fantasy"
  26	"charm.land/fantasy/providers/anthropic"
  27	"charm.land/fantasy/providers/bedrock"
  28	"charm.land/fantasy/providers/google"
  29	"charm.land/fantasy/providers/openai"
  30	"charm.land/fantasy/providers/openrouter"
  31	"charm.land/lipgloss/v2"
  32	"github.com/charmbracelet/catwalk/pkg/catwalk"
  33	"github.com/charmbracelet/crush/internal/agent/hyper"
  34	"github.com/charmbracelet/crush/internal/agent/tools"
  35	"github.com/charmbracelet/crush/internal/config"
  36	"github.com/charmbracelet/crush/internal/csync"
  37	"github.com/charmbracelet/crush/internal/message"
  38	"github.com/charmbracelet/crush/internal/permission"
  39	"github.com/charmbracelet/crush/internal/session"
  40	"github.com/charmbracelet/crush/internal/stringext"
  41)
  42
// defaultSessionName is the title stored for a session when generateTitle
// cannot produce one: both models failed, the response was nil, or the
// cleaned-up title came out empty.
const defaultSessionName = "Untitled Session"

// titlePrompt is the system prompt for generateTitle (which appends a
// " /no_think" suffix to it).
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt is the system prompt used by Summarize.
//
//go:embed templates/summary.md
var summaryPrompt []byte

// Used to remove <think> tags from generated titles. The match is non-greedy
// and `.` does not cross newlines; generateTitle replaces newlines with spaces
// before applying it, so multi-line think blocks are still removed.
var thinkTagRegex = regexp.MustCompile(`<think>.*?</think>`)
  53
// SessionAgentCall describes one prompt submitted to a SessionAgent.
type SessionAgentCall struct {
	// SessionID is the target session; Run rejects "" with ErrSessionMissing.
	SessionID        string
	// Prompt is the user text; Run rejects an empty prompt (ErrEmptyPrompt)
	// unless a text attachment is present.
	Prompt           string
	// ProviderOptions are passed through to the model request.
	ProviderOptions  fantasy.ProviderOptions
	// Attachments are stored with the user message; text attachments are
	// folded into the prompt, others are sent as files.
	Attachments      []message.Attachment
	// MaxOutputTokens is always forwarded by address, so 0 is sent as an
	// explicit 0. NOTE(review): confirm providers treat 0 as "no limit".
	MaxOutputTokens  int64
	// The optional sampling parameters below are forwarded unchanged;
	// presumably nil means "use the provider default".
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
}
  66
// SessionAgent runs prompts against language models on behalf of chat
// sessions, queuing prompts for sessions that already have a request in
// flight.
type SessionAgent interface {
	// Run sends a prompt within a session and streams the model's response
	// into the message store. If the session is busy, the call is queued and
	// (nil, nil) is returned.
	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
	// SetModels replaces the large and small models (implementation not
	// shown here).
	SetModels(large Model, small Model)
	// SetTools replaces the tools offered to the model (implementation not
	// shown here).
	SetTools(tools []fantasy.AgentTool)
	// SetSystemPrompt replaces the system prompt (implementation not shown
	// here).
	SetSystemPrompt(systemPrompt string)
	// Cancel aborts the active request for a session and drops its queued
	// prompts.
	Cancel(sessionID string)
	// CancelAll cancels every active request and waits, up to 5 seconds, for
	// the agent to stop reporting busy.
	CancelAll()
	// IsSessionBusy reports whether a session has a request in flight
	// (presumably via activeRequests — confirm).
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any session has a request in flight
	// (presumably via activeRequests — confirm).
	IsBusy() bool
	// QueuedPrompts returns how many prompts are queued for a session.
	QueuedPrompts(sessionID string) int
	// QueuedPromptsList returns the queued prompts for a session
	// (implementation not shown here).
	QueuedPromptsList(sessionID string) []string
	// ClearQueue drops all prompts queued for a session.
	ClearQueue(sessionID string)
	// Summarize condenses a session's history into a single summary message;
	// it returns ErrSessionBusy if the session has a request in flight.
	Summarize(context.Context, string, fantasy.ProviderOptions) error
	// Model returns the agent's model (presumably the large one — confirm).
	Model() Model
}
  82
// Model bundles a language model client with the metadata the agent needs
// about it.
type Model struct {
	// Model is the client that executes requests.
	Model      fantasy.LanguageModel
	// CatwalkCfg is catalog metadata; this file reads its context window,
	// per-1M-token pricing, reasoning/image support, default max tokens and
	// display name.
	CatwalkCfg catwalk.Model
	// ModelCfg is the selected model configuration; its Model and Provider
	// fields label the assistant messages Run creates.
	ModelCfg   config.SelectedModel
}
  88
// sessionAgent is the SessionAgent implementation.
type sessionAgent struct {
	// largeModel runs prompts and summaries; smallModel is tried first for
	// title generation.
	largeModel           Model
	smallModel           Model
	// systemPromptPrefix is prepended as an extra system message by
	// Summarize and generateTitle.
	systemPromptPrefix   string
	systemPrompt         string
	// isSubAgent suppresses the todo-list reminder in preparePrompt.
	isSubAgent           bool
	tools                []fantasy.AgentTool
	sessions             session.Service
	messages             message.Service
	// disableAutoSummarize turns off Run's context-window stop condition.
	disableAutoSummarize bool
	// isYolo is not referenced in this part of the file.
	isYolo               bool

	// messageQueue holds calls submitted while a session was busy, keyed by
	// session ID.
	messageQueue   *csync.Map[string, []SessionAgentCall]
	// activeRequests maps session IDs to the cancel func of their in-flight
	// Run or Summarize call.
	activeRequests *csync.Map[string, context.CancelFunc]
}
 104
// SessionAgentOptions configures NewSessionAgent; each field is copied into
// the corresponding sessionAgent field.
type SessionAgentOptions struct {
	// LargeModel runs prompts and summaries; SmallModel is tried first for
	// titles.
	LargeModel           Model
	SmallModel           Model
	// SystemPromptPrefix is prepended as an extra system message by
	// Summarize and title generation.
	SystemPromptPrefix   string
	SystemPrompt         string
	// IsSubAgent omits the todo-list reminder from the prompt history.
	IsSubAgent           bool
	// DisableAutoSummarize disables summarizing when the context window is
	// nearly full.
	DisableAutoSummarize bool
	IsYolo               bool
	Sessions             session.Service
	Messages             message.Service
	Tools                []fantasy.AgentTool
}
 117
 118func NewSessionAgent(
 119	opts SessionAgentOptions,
 120) SessionAgent {
 121	return &sessionAgent{
 122		largeModel:           opts.LargeModel,
 123		smallModel:           opts.SmallModel,
 124		systemPromptPrefix:   opts.SystemPromptPrefix,
 125		systemPrompt:         opts.SystemPrompt,
 126		isSubAgent:           opts.IsSubAgent,
 127		sessions:             opts.Sessions,
 128		messages:             opts.Messages,
 129		disableAutoSummarize: opts.DisableAutoSummarize,
 130		tools:                opts.Tools,
 131		isYolo:               opts.IsYolo,
 132		messageQueue:         csync.NewMap[string, []SessionAgentCall](),
 133		activeRequests:       csync.NewMap[string, context.CancelFunc](),
 134	}
 135}
 136
 137func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
 138	if call.Prompt == "" && !message.ContainsTextAttachment(call.Attachments) {
 139		return nil, ErrEmptyPrompt
 140	}
 141	if call.SessionID == "" {
 142		return nil, ErrSessionMissing
 143	}
 144
 145	// Queue the message if busy
 146	if a.IsSessionBusy(call.SessionID) {
 147		existing, ok := a.messageQueue.Get(call.SessionID)
 148		if !ok {
 149			existing = []SessionAgentCall{}
 150		}
 151		existing = append(existing, call)
 152		a.messageQueue.Set(call.SessionID, existing)
 153		return nil, nil
 154	}
 155
 156	if len(a.tools) > 0 {
 157		// Add Anthropic caching to the last tool.
 158		a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
 159	}
 160
 161	agent := fantasy.NewAgent(
 162		a.largeModel.Model,
 163		fantasy.WithSystemPrompt(a.systemPrompt),
 164		fantasy.WithTools(a.tools...),
 165	)
 166
 167	sessionLock := sync.Mutex{}
 168	currentSession, err := a.sessions.Get(ctx, call.SessionID)
 169	if err != nil {
 170		return nil, fmt.Errorf("failed to get session: %w", err)
 171	}
 172
 173	msgs, err := a.getSessionMessages(ctx, currentSession)
 174	if err != nil {
 175		return nil, fmt.Errorf("failed to get session messages: %w", err)
 176	}
 177
 178	var wg sync.WaitGroup
 179	// Generate title if first message.
 180	if len(msgs) == 0 {
 181		titleCtx := ctx // Copy to avoid race with ctx reassignment below.
 182		wg.Go(func() {
 183			a.generateTitle(titleCtx, call.SessionID, call.Prompt)
 184		})
 185	}
 186	defer wg.Wait()
 187
 188	// Add the user message to the session.
 189	_, err = a.createUserMessage(ctx, call)
 190	if err != nil {
 191		return nil, err
 192	}
 193
 194	// Add the session to the context.
 195	ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)
 196
 197	genCtx, cancel := context.WithCancel(ctx)
 198	a.activeRequests.Set(call.SessionID, cancel)
 199
 200	defer cancel()
 201	defer a.activeRequests.Del(call.SessionID)
 202
 203	history, files := a.preparePrompt(msgs, call.Attachments...)
 204
 205	startTime := time.Now()
 206	a.eventPromptSent(call.SessionID)
 207
 208	var currentAssistant *message.Message
 209	var shouldSummarize bool
 210	result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
 211		Prompt:           message.PromptWithTextAttachments(call.Prompt, call.Attachments),
 212		Files:            files,
 213		Messages:         history,
 214		ProviderOptions:  call.ProviderOptions,
 215		MaxOutputTokens:  &call.MaxOutputTokens,
 216		TopP:             call.TopP,
 217		Temperature:      call.Temperature,
 218		PresencePenalty:  call.PresencePenalty,
 219		TopK:             call.TopK,
 220		FrequencyPenalty: call.FrequencyPenalty,
 221		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
 222			prepared.Messages = options.Messages
 223			for i := range prepared.Messages {
 224				prepared.Messages[i].ProviderOptions = nil
 225			}
 226
 227			queuedCalls, _ := a.messageQueue.Get(call.SessionID)
 228			a.messageQueue.Del(call.SessionID)
 229			for _, queued := range queuedCalls {
 230				userMessage, createErr := a.createUserMessage(callContext, queued)
 231				if createErr != nil {
 232					return callContext, prepared, createErr
 233				}
 234				prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
 235			}
 236
 237			prepared.Messages = a.workaroundProviderMediaLimitations(prepared.Messages)
 238
 239			lastSystemRoleInx := 0
 240			systemMessageUpdated := false
 241			for i, msg := range prepared.Messages {
 242				// Only add cache control to the last message.
 243				if msg.Role == fantasy.MessageRoleSystem {
 244					lastSystemRoleInx = i
 245				} else if !systemMessageUpdated {
 246					prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
 247					systemMessageUpdated = true
 248				}
 249				// Than add cache control to the last 2 messages.
 250				if i > len(prepared.Messages)-3 {
 251					prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
 252				}
 253			}
 254
 255			if promptPrefix := a.promptPrefix(); promptPrefix != "" {
 256				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(promptPrefix)}, prepared.Messages...)
 257			}
 258
 259			var assistantMsg message.Message
 260			assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
 261				Role:     message.Assistant,
 262				Parts:    []message.ContentPart{},
 263				Model:    a.largeModel.ModelCfg.Model,
 264				Provider: a.largeModel.ModelCfg.Provider,
 265			})
 266			if err != nil {
 267				return callContext, prepared, err
 268			}
 269			callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
 270			callContext = context.WithValue(callContext, tools.SupportsImagesContextKey, a.largeModel.CatwalkCfg.SupportsImages)
 271			callContext = context.WithValue(callContext, tools.ModelNameContextKey, a.largeModel.CatwalkCfg.Name)
 272			currentAssistant = &assistantMsg
 273			return callContext, prepared, err
 274		},
 275		OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
 276			currentAssistant.AppendReasoningContent(reasoning.Text)
 277			return a.messages.Update(genCtx, *currentAssistant)
 278		},
 279		OnReasoningDelta: func(id string, text string) error {
 280			currentAssistant.AppendReasoningContent(text)
 281			return a.messages.Update(genCtx, *currentAssistant)
 282		},
 283		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
 284			// handle anthropic signature
 285			if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
 286				if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
 287					currentAssistant.AppendReasoningSignature(reasoning.Signature)
 288				}
 289			}
 290			if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
 291				if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
 292					currentAssistant.AppendThoughtSignature(reasoning.Signature, reasoning.ToolID)
 293				}
 294			}
 295			if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
 296				if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
 297					currentAssistant.SetReasoningResponsesData(reasoning)
 298				}
 299			}
 300			currentAssistant.FinishThinking()
 301			return a.messages.Update(genCtx, *currentAssistant)
 302		},
 303		OnTextDelta: func(id string, text string) error {
 304			// Strip leading newline from initial text content. This is is
 305			// particularly important in non-interactive mode where leading
 306			// newlines are very visible.
 307			if len(currentAssistant.Parts) == 0 {
 308				text = strings.TrimPrefix(text, "\n")
 309			}
 310
 311			currentAssistant.AppendContent(text)
 312			return a.messages.Update(genCtx, *currentAssistant)
 313		},
 314		OnToolInputStart: func(id string, toolName string) error {
 315			toolCall := message.ToolCall{
 316				ID:               id,
 317				Name:             toolName,
 318				ProviderExecuted: false,
 319				Finished:         false,
 320			}
 321			currentAssistant.AddToolCall(toolCall)
 322			return a.messages.Update(genCtx, *currentAssistant)
 323		},
 324		OnRetry: func(err *fantasy.ProviderError, delay time.Duration) {
 325			// TODO: implement
 326		},
 327		OnToolCall: func(tc fantasy.ToolCallContent) error {
 328			toolCall := message.ToolCall{
 329				ID:               tc.ToolCallID,
 330				Name:             tc.ToolName,
 331				Input:            tc.Input,
 332				ProviderExecuted: false,
 333				Finished:         true,
 334			}
 335			currentAssistant.AddToolCall(toolCall)
 336			return a.messages.Update(genCtx, *currentAssistant)
 337		},
 338		OnToolResult: func(result fantasy.ToolResultContent) error {
 339			toolResult := a.convertToToolResult(result)
 340			_, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
 341				Role: message.Tool,
 342				Parts: []message.ContentPart{
 343					toolResult,
 344				},
 345			})
 346			return createMsgErr
 347		},
 348		OnStepFinish: func(stepResult fantasy.StepResult) error {
 349			finishReason := message.FinishReasonUnknown
 350			switch stepResult.FinishReason {
 351			case fantasy.FinishReasonLength:
 352				finishReason = message.FinishReasonMaxTokens
 353			case fantasy.FinishReasonStop:
 354				finishReason = message.FinishReasonEndTurn
 355			case fantasy.FinishReasonToolCalls:
 356				finishReason = message.FinishReasonToolUse
 357			}
 358			currentAssistant.AddFinish(finishReason, "", "")
 359			sessionLock.Lock()
 360			updatedSession, getSessionErr := a.sessions.Get(genCtx, call.SessionID)
 361			if getSessionErr != nil {
 362				sessionLock.Unlock()
 363				return getSessionErr
 364			}
 365			a.updateSessionUsage(a.largeModel, &updatedSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
 366			_, sessionErr := a.sessions.Save(genCtx, updatedSession)
 367			sessionLock.Unlock()
 368			if sessionErr != nil {
 369				return sessionErr
 370			}
 371			return a.messages.Update(genCtx, *currentAssistant)
 372		},
 373		StopWhen: []fantasy.StopCondition{
 374			func(_ []fantasy.StepResult) bool {
 375				cw := int64(a.largeModel.CatwalkCfg.ContextWindow)
 376				tokens := currentSession.CompletionTokens + currentSession.PromptTokens
 377				remaining := cw - tokens
 378				var threshold int64
 379				if cw > 200_000 {
 380					threshold = 20_000
 381				} else {
 382					threshold = int64(float64(cw) * 0.2)
 383				}
 384				if (remaining <= threshold) && !a.disableAutoSummarize {
 385					shouldSummarize = true
 386					return true
 387				}
 388				return false
 389			},
 390		},
 391	})
 392
 393	a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))
 394
 395	if err != nil {
 396		isCancelErr := errors.Is(err, context.Canceled)
 397		isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
 398		if currentAssistant == nil {
 399			return result, err
 400		}
 401		// Ensure we finish thinking on error to close the reasoning state.
 402		currentAssistant.FinishThinking()
 403		toolCalls := currentAssistant.ToolCalls()
 404		// INFO: we use the parent context here because the genCtx has been cancelled.
 405		msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
 406		if createErr != nil {
 407			return nil, createErr
 408		}
 409		for _, tc := range toolCalls {
 410			if !tc.Finished {
 411				tc.Finished = true
 412				tc.Input = "{}"
 413				currentAssistant.AddToolCall(tc)
 414				updateErr := a.messages.Update(ctx, *currentAssistant)
 415				if updateErr != nil {
 416					return nil, updateErr
 417				}
 418			}
 419
 420			found := false
 421			for _, msg := range msgs {
 422				if msg.Role == message.Tool {
 423					for _, tr := range msg.ToolResults() {
 424						if tr.ToolCallID == tc.ID {
 425							found = true
 426							break
 427						}
 428					}
 429				}
 430				if found {
 431					break
 432				}
 433			}
 434			if found {
 435				continue
 436			}
 437			content := "There was an error while executing the tool"
 438			if isCancelErr {
 439				content = "Tool execution canceled by user"
 440			} else if isPermissionErr {
 441				content = "User denied permission"
 442			}
 443			toolResult := message.ToolResult{
 444				ToolCallID: tc.ID,
 445				Name:       tc.Name,
 446				Content:    content,
 447				IsError:    true,
 448			}
 449			_, createErr = a.messages.Create(ctx, currentAssistant.SessionID, message.CreateMessageParams{
 450				Role: message.Tool,
 451				Parts: []message.ContentPart{
 452					toolResult,
 453				},
 454			})
 455			if createErr != nil {
 456				return nil, createErr
 457			}
 458		}
 459		var fantasyErr *fantasy.Error
 460		var providerErr *fantasy.ProviderError
 461		const defaultTitle = "Provider Error"
 462		if isCancelErr {
 463			currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
 464		} else if isPermissionErr {
 465			currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "")
 466		} else if errors.Is(err, hyper.ErrNoCredits) {
 467			url := hyper.BaseURL()
 468			link := lipgloss.NewStyle().Hyperlink(url, "id=hyper").Render(url)
 469			currentAssistant.AddFinish(message.FinishReasonError, "No credits", "You're out of credits. Add more at "+link)
 470		} else if errors.As(err, &providerErr) {
 471			if providerErr.Message == "The requested model is not supported." {
 472				url := "https://github.com/settings/copilot/features"
 473				link := lipgloss.NewStyle().Hyperlink(url, "id=hyper").Render(url)
 474				currentAssistant.AddFinish(
 475					message.FinishReasonError,
 476					"Copilot model not enabled",
 477					fmt.Sprintf("%q is not enabled in Copilot. Go to the following page to enable it. Then, wait 5 minutes before trying again. %s", a.largeModel.CatwalkCfg.Name, link),
 478				)
 479			} else {
 480				currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message)
 481			}
 482		} else if errors.As(err, &fantasyErr) {
 483			currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(fantasyErr.Title), defaultTitle), fantasyErr.Message)
 484		} else {
 485			currentAssistant.AddFinish(message.FinishReasonError, defaultTitle, err.Error())
 486		}
 487		// Note: we use the parent context here because the genCtx has been
 488		// cancelled.
 489		updateErr := a.messages.Update(ctx, *currentAssistant)
 490		if updateErr != nil {
 491			return nil, updateErr
 492		}
 493		return nil, err
 494	}
 495
 496	if shouldSummarize {
 497		a.activeRequests.Del(call.SessionID)
 498		if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
 499			return nil, summarizeErr
 500		}
 501		// If the agent wasn't done...
 502		if len(currentAssistant.ToolCalls()) > 0 {
 503			existing, ok := a.messageQueue.Get(call.SessionID)
 504			if !ok {
 505				existing = []SessionAgentCall{}
 506			}
 507			call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
 508			existing = append(existing, call)
 509			a.messageQueue.Set(call.SessionID, existing)
 510		}
 511	}
 512
 513	// Release active request before processing queued messages.
 514	a.activeRequests.Del(call.SessionID)
 515	cancel()
 516
 517	queuedMessages, ok := a.messageQueue.Get(call.SessionID)
 518	if !ok || len(queuedMessages) == 0 {
 519		return result, err
 520	}
 521	// There are queued messages restart the loop.
 522	firstQueuedMessage := queuedMessages[0]
 523	a.messageQueue.Set(call.SessionID, queuedMessages[1:])
 524	return a.Run(ctx, firstQueuedMessage)
 525}
 526
 527func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
 528	if a.IsSessionBusy(sessionID) {
 529		return ErrSessionBusy
 530	}
 531
 532	currentSession, err := a.sessions.Get(ctx, sessionID)
 533	if err != nil {
 534		return fmt.Errorf("failed to get session: %w", err)
 535	}
 536	msgs, err := a.getSessionMessages(ctx, currentSession)
 537	if err != nil {
 538		return err
 539	}
 540	if len(msgs) == 0 {
 541		// Nothing to summarize.
 542		return nil
 543	}
 544
 545	aiMsgs, _ := a.preparePrompt(msgs)
 546
 547	genCtx, cancel := context.WithCancel(ctx)
 548	a.activeRequests.Set(sessionID, cancel)
 549	defer a.activeRequests.Del(sessionID)
 550	defer cancel()
 551
 552	agent := fantasy.NewAgent(a.largeModel.Model,
 553		fantasy.WithSystemPrompt(string(summaryPrompt)),
 554	)
 555	summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
 556		Role:             message.Assistant,
 557		Model:            a.largeModel.Model.Model(),
 558		Provider:         a.largeModel.Model.Provider(),
 559		IsSummaryMessage: true,
 560	})
 561	if err != nil {
 562		return err
 563	}
 564
 565	summaryPromptText := buildSummaryPrompt(currentSession.Todos)
 566
 567	resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
 568		Prompt:          summaryPromptText,
 569		Messages:        aiMsgs,
 570		ProviderOptions: opts,
 571		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
 572			prepared.Messages = options.Messages
 573			if a.systemPromptPrefix != "" {
 574				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
 575			}
 576			return callContext, prepared, nil
 577		},
 578		OnReasoningDelta: func(id string, text string) error {
 579			summaryMessage.AppendReasoningContent(text)
 580			return a.messages.Update(genCtx, summaryMessage)
 581		},
 582		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
 583			// Handle anthropic signature.
 584			if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
 585				if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
 586					summaryMessage.AppendReasoningSignature(signature.Signature)
 587				}
 588			}
 589			summaryMessage.FinishThinking()
 590			return a.messages.Update(genCtx, summaryMessage)
 591		},
 592		OnTextDelta: func(id, text string) error {
 593			summaryMessage.AppendContent(text)
 594			return a.messages.Update(genCtx, summaryMessage)
 595		},
 596	})
 597	if err != nil {
 598		isCancelErr := errors.Is(err, context.Canceled)
 599		if isCancelErr {
 600			// User cancelled summarize we need to remove the summary message.
 601			deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
 602			return deleteErr
 603		}
 604		return err
 605	}
 606
 607	summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
 608	err = a.messages.Update(genCtx, summaryMessage)
 609	if err != nil {
 610		return err
 611	}
 612
 613	var openrouterCost *float64
 614	for _, step := range resp.Steps {
 615		stepCost := a.openrouterCost(step.ProviderMetadata)
 616		if stepCost != nil {
 617			newCost := *stepCost
 618			if openrouterCost != nil {
 619				newCost += *openrouterCost
 620			}
 621			openrouterCost = &newCost
 622		}
 623	}
 624
 625	a.updateSessionUsage(a.largeModel, &currentSession, resp.TotalUsage, openrouterCost)
 626
 627	// Just in case, get just the last usage info.
 628	usage := resp.Response.Usage
 629	currentSession.SummaryMessageID = summaryMessage.ID
 630	currentSession.CompletionTokens = usage.OutputTokens
 631	currentSession.PromptTokens = 0
 632	_, err = a.sessions.Save(genCtx, currentSession)
 633	return err
 634}
 635
 636func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
 637	if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
 638		return fantasy.ProviderOptions{}
 639	}
 640	return fantasy.ProviderOptions{
 641		anthropic.Name: &anthropic.ProviderCacheControlOptions{
 642			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
 643		},
 644		bedrock.Name: &anthropic.ProviderCacheControlOptions{
 645			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
 646		},
 647	}
 648}
 649
 650func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
 651	parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
 652	var attachmentParts []message.ContentPart
 653	for _, attachment := range call.Attachments {
 654		attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
 655	}
 656	parts = append(parts, attachmentParts...)
 657	msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
 658		Role:  message.User,
 659		Parts: parts,
 660	})
 661	if err != nil {
 662		return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
 663	}
 664	return msg, nil
 665}
 666
 667func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
 668	var history []fantasy.Message
 669	if !a.isSubAgent {
 670		history = append(history, fantasy.NewUserMessage(
 671			fmt.Sprintf("<system_reminder>%s</system_reminder>",
 672				`This is a reminder that your todo list is currently empty. DO NOT mention this to the user explicitly because they are already aware.
 673If you are working on tasks that would benefit from a todo list please use the "todos" tool to create one.
 674If not, please feel free to ignore. Again do not mention this message to the user.`,
 675			),
 676		))
 677	}
 678	for _, m := range msgs {
 679		if len(m.Parts) == 0 {
 680			continue
 681		}
 682		// Assistant message without content or tool calls (cancelled before it
 683		// returned anything).
 684		if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
 685			continue
 686		}
 687		history = append(history, m.ToAIMessage()...)
 688	}
 689
 690	var files []fantasy.FilePart
 691	for _, attachment := range attachments {
 692		if attachment.IsText() {
 693			continue
 694		}
 695		files = append(files, fantasy.FilePart{
 696			Filename:  attachment.FileName,
 697			Data:      attachment.Content,
 698			MediaType: attachment.MimeType,
 699		})
 700	}
 701
 702	return history, files
 703}
 704
 705func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
 706	msgs, err := a.messages.List(ctx, session.ID)
 707	if err != nil {
 708		return nil, fmt.Errorf("failed to list messages: %w", err)
 709	}
 710
 711	if session.SummaryMessageID != "" {
 712		summaryMsgInex := -1
 713		for i, msg := range msgs {
 714			if msg.ID == session.SummaryMessageID {
 715				summaryMsgInex = i
 716				break
 717			}
 718		}
 719		if summaryMsgInex != -1 {
 720			msgs = msgs[summaryMsgInex:]
 721			msgs[0].Role = message.User
 722		}
 723	}
 724	return msgs, nil
 725}
 726
 727// generateTitle generates a session titled based on the initial prompt.
 728func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, userPrompt string) {
 729	if userPrompt == "" {
 730		return
 731	}
 732
 733	var maxOutputTokens int64 = 40
 734	if a.smallModel.CatwalkCfg.CanReason {
 735		maxOutputTokens = a.smallModel.CatwalkCfg.DefaultMaxTokens
 736	}
 737
 738	newAgent := func(m fantasy.LanguageModel, p []byte, tok int64) fantasy.Agent {
 739		return fantasy.NewAgent(m,
 740			fantasy.WithSystemPrompt(string(p)+"\n /no_think"),
 741			fantasy.WithMaxOutputTokens(tok),
 742		)
 743	}
 744
 745	streamCall := fantasy.AgentStreamCall{
 746		Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", userPrompt),
 747		PrepareStep: func(callCtx context.Context, opts fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
 748			prepared.Messages = opts.Messages
 749			if a.systemPromptPrefix != "" {
 750				prepared.Messages = append([]fantasy.Message{
 751					fantasy.NewSystemMessage(a.systemPromptPrefix),
 752				}, prepared.Messages...)
 753			}
 754			return callCtx, prepared, nil
 755		},
 756	}
 757
 758	// Use the small model to generate the title.
 759	model := &a.smallModel
 760	agent := newAgent(model.Model, titlePrompt, maxOutputTokens)
 761	resp, err := agent.Stream(ctx, streamCall)
 762	if err == nil {
 763		// We successfully generated a title with the small model.
 764		slog.Info("generated title with small model")
 765	} else {
 766		// It didn't work. Let's try with the big model.
 767		slog.Error("error generating title with small model; trying big model", "err", err)
 768		model = &a.largeModel
 769		agent = newAgent(model.Model, titlePrompt, maxOutputTokens)
 770		resp, err = agent.Stream(ctx, streamCall)
 771		if err == nil {
 772			slog.Info("generated title with large model")
 773		} else {
 774			// Welp, the large model didn't work either. Use the default
 775			// session name and return.
 776			slog.Error("error generating title with large model", "err", err)
 777			saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, defaultSessionName, 0, 0, 0)
 778			if saveErr != nil {
 779				slog.Error("failed to save session title and usage", "error", saveErr)
 780			}
 781			return
 782		}
 783	}
 784
 785	if resp == nil {
 786		// Actually, we didn't get a response so we can't. Use the default
 787		// session name and return.
 788		slog.Error("response is nil; can't generate title")
 789		saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, defaultSessionName, 0, 0, 0)
 790		if saveErr != nil {
 791			slog.Error("failed to save session title and usage", "error", saveErr)
 792		}
 793		return
 794	}
 795
 796	// Clean up title.
 797	var title string
 798	title = strings.ReplaceAll(resp.Response.Content.Text(), "\n", " ")
 799	slog.Info("generated title", "title", title)
 800
 801	// Remove thinking tags if present.
 802	title = thinkTagRegex.ReplaceAllString(title, "")
 803
 804	title = strings.TrimSpace(title)
 805	if title == "" {
 806		slog.Warn("empty title; using fallback")
 807		title = defaultSessionName
 808	}
 809
 810	// Calculate usage and cost.
 811	var openrouterCost *float64
 812	for _, step := range resp.Steps {
 813		stepCost := a.openrouterCost(step.ProviderMetadata)
 814		if stepCost != nil {
 815			newCost := *stepCost
 816			if openrouterCost != nil {
 817				newCost += *openrouterCost
 818			}
 819			openrouterCost = &newCost
 820		}
 821	}
 822
 823	modelConfig := model.CatwalkCfg
 824	cost := modelConfig.CostPer1MInCached/1e6*float64(resp.TotalUsage.CacheCreationTokens) +
 825		modelConfig.CostPer1MOutCached/1e6*float64(resp.TotalUsage.CacheReadTokens) +
 826		modelConfig.CostPer1MIn/1e6*float64(resp.TotalUsage.InputTokens) +
 827		modelConfig.CostPer1MOut/1e6*float64(resp.TotalUsage.OutputTokens)
 828
 829	// Use override cost if available (e.g., from OpenRouter).
 830	if openrouterCost != nil {
 831		cost = *openrouterCost
 832	}
 833
 834	promptTokens := resp.TotalUsage.InputTokens + resp.TotalUsage.CacheCreationTokens
 835	completionTokens := resp.TotalUsage.OutputTokens + resp.TotalUsage.CacheReadTokens
 836
 837	// Atomically update only title and usage fields to avoid overriding other
 838	// concurrent session updates.
 839	saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, title, promptTokens, completionTokens, cost)
 840	if saveErr != nil {
 841		slog.Error("failed to save session title and usage", "error", saveErr)
 842		return
 843	}
 844}
 845
 846func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
 847	openrouterMetadata, ok := metadata[openrouter.Name]
 848	if !ok {
 849		return nil
 850	}
 851
 852	opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
 853	if !ok {
 854		return nil
 855	}
 856	return &opts.Usage.Cost
 857}
 858
 859func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
 860	modelConfig := model.CatwalkCfg
 861	cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
 862		modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
 863		modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
 864		modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
 865
 866	a.eventTokensUsed(session.ID, model, usage, cost)
 867
 868	if overrideCost != nil {
 869		session.Cost += *overrideCost
 870	} else {
 871		session.Cost += cost
 872	}
 873
 874	session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
 875	session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
 876}
 877
 878func (a *sessionAgent) Cancel(sessionID string) {
 879	// Cancel regular requests. Don't use Take() here - we need the entry to
 880	// remain in activeRequests so IsBusy() returns true until the goroutine
 881	// fully completes (including error handling that may access the DB).
 882	// The defer in processRequest will clean up the entry.
 883	if cancel, ok := a.activeRequests.Get(sessionID); ok && cancel != nil {
 884		slog.Info("Request cancellation initiated", "session_id", sessionID)
 885		cancel()
 886	}
 887
 888	// Also check for summarize requests.
 889	if cancel, ok := a.activeRequests.Get(sessionID + "-summarize"); ok && cancel != nil {
 890		slog.Info("Summarize cancellation initiated", "session_id", sessionID)
 891		cancel()
 892	}
 893
 894	if a.QueuedPrompts(sessionID) > 0 {
 895		slog.Info("Clearing queued prompts", "session_id", sessionID)
 896		a.messageQueue.Del(sessionID)
 897	}
 898}
 899
 900func (a *sessionAgent) ClearQueue(sessionID string) {
 901	if a.QueuedPrompts(sessionID) > 0 {
 902		slog.Info("Clearing queued prompts", "session_id", sessionID)
 903		a.messageQueue.Del(sessionID)
 904	}
 905}
 906
 907func (a *sessionAgent) CancelAll() {
 908	if !a.IsBusy() {
 909		return
 910	}
 911	for key := range a.activeRequests.Seq2() {
 912		a.Cancel(key) // key is sessionID
 913	}
 914
 915	timeout := time.After(5 * time.Second)
 916	for a.IsBusy() {
 917		select {
 918		case <-timeout:
 919			return
 920		default:
 921			time.Sleep(200 * time.Millisecond)
 922		}
 923	}
 924}
 925
 926func (a *sessionAgent) IsBusy() bool {
 927	var busy bool
 928	for cancelFunc := range a.activeRequests.Seq() {
 929		if cancelFunc != nil {
 930			busy = true
 931			break
 932		}
 933	}
 934	return busy
 935}
 936
 937func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
 938	_, busy := a.activeRequests.Get(sessionID)
 939	return busy
 940}
 941
 942func (a *sessionAgent) QueuedPrompts(sessionID string) int {
 943	l, ok := a.messageQueue.Get(sessionID)
 944	if !ok {
 945		return 0
 946	}
 947	return len(l)
 948}
 949
 950func (a *sessionAgent) QueuedPromptsList(sessionID string) []string {
 951	l, ok := a.messageQueue.Get(sessionID)
 952	if !ok {
 953		return nil
 954	}
 955	prompts := make([]string, len(l))
 956	for i, call := range l {
 957		prompts[i] = call.Prompt
 958	}
 959	return prompts
 960}
 961
 962func (a *sessionAgent) SetModels(large Model, small Model) {
 963	a.largeModel = large
 964	a.smallModel = small
 965}
 966
 967func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
 968	a.tools = tools
 969}
 970
 971func (a *sessionAgent) SetSystemPrompt(systemPrompt string) {
 972	a.systemPrompt = systemPrompt
 973}
 974
 975func (a *sessionAgent) Model() Model {
 976	return a.largeModel
 977}
 978
 979func (a *sessionAgent) promptPrefix() string {
 980	return a.systemPromptPrefix
 981}
 982
 983// convertToToolResult converts a fantasy tool result to a message tool result.
 984func (a *sessionAgent) convertToToolResult(result fantasy.ToolResultContent) message.ToolResult {
 985	baseResult := message.ToolResult{
 986		ToolCallID: result.ToolCallID,
 987		Name:       result.ToolName,
 988		Metadata:   result.ClientMetadata,
 989	}
 990
 991	switch result.Result.GetType() {
 992	case fantasy.ToolResultContentTypeText:
 993		if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result); ok {
 994			baseResult.Content = r.Text
 995		}
 996	case fantasy.ToolResultContentTypeError:
 997		if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result); ok {
 998			baseResult.Content = r.Error.Error()
 999			baseResult.IsError = true
1000		}
1001	case fantasy.ToolResultContentTypeMedia:
1002		if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](result.Result); ok {
1003			content := r.Text
1004			if content == "" {
1005				content = fmt.Sprintf("Loaded %s content", r.MediaType)
1006			}
1007			baseResult.Content = content
1008			baseResult.Data = r.Data
1009			baseResult.MIMEType = r.MediaType
1010		}
1011	}
1012
1013	return baseResult
1014}
1015
1016// workaroundProviderMediaLimitations converts media content in tool results to
1017// user messages for providers that don't natively support images in tool results.
1018//
1019// Problem: OpenAI, Google, OpenRouter, and other OpenAI-compatible providers
1020// don't support sending images/media in tool result messages - they only accept
1021// text in tool results. However, they DO support images in user messages.
1022//
1023// If we send media in tool results to these providers, the API returns an error.
1024//
1025// Solution: For these providers, we:
1026//  1. Replace the media in the tool result with a text placeholder
1027//  2. Inject a user message immediately after with the image as a file attachment
1028//  3. This maintains the tool execution flow while working around API limitations
1029//
1030// Anthropic and Bedrock support images natively in tool results, so we skip
1031// this workaround for them.
1032//
1033// Example transformation:
1034//
1035//	BEFORE: [tool result: image data]
1036//	AFTER:  [tool result: "Image loaded - see attached"], [user: image attachment]
1037func (a *sessionAgent) workaroundProviderMediaLimitations(messages []fantasy.Message) []fantasy.Message {
1038	providerSupportsMedia := a.largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderAnthropic) ||
1039		a.largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderBedrock)
1040
1041	if providerSupportsMedia {
1042		return messages
1043	}
1044
1045	convertedMessages := make([]fantasy.Message, 0, len(messages))
1046
1047	for _, msg := range messages {
1048		if msg.Role != fantasy.MessageRoleTool {
1049			convertedMessages = append(convertedMessages, msg)
1050			continue
1051		}
1052
1053		textParts := make([]fantasy.MessagePart, 0, len(msg.Content))
1054		var mediaFiles []fantasy.FilePart
1055
1056		for _, part := range msg.Content {
1057			toolResult, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
1058			if !ok {
1059				textParts = append(textParts, part)
1060				continue
1061			}
1062
1063			if media, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](toolResult.Output); ok {
1064				decoded, err := base64.StdEncoding.DecodeString(media.Data)
1065				if err != nil {
1066					slog.Warn("failed to decode media data", "error", err)
1067					textParts = append(textParts, part)
1068					continue
1069				}
1070
1071				mediaFiles = append(mediaFiles, fantasy.FilePart{
1072					Data:      decoded,
1073					MediaType: media.MediaType,
1074					Filename:  fmt.Sprintf("tool-result-%s", toolResult.ToolCallID),
1075				})
1076
1077				textParts = append(textParts, fantasy.ToolResultPart{
1078					ToolCallID: toolResult.ToolCallID,
1079					Output: fantasy.ToolResultOutputContentText{
1080						Text: "[Image/media content loaded - see attached file]",
1081					},
1082					ProviderOptions: toolResult.ProviderOptions,
1083				})
1084			} else {
1085				textParts = append(textParts, part)
1086			}
1087		}
1088
1089		convertedMessages = append(convertedMessages, fantasy.Message{
1090			Role:    fantasy.MessageRoleTool,
1091			Content: textParts,
1092		})
1093
1094		if len(mediaFiles) > 0 {
1095			convertedMessages = append(convertedMessages, fantasy.NewUserMessage(
1096				"Here is the media content from the tool result:",
1097				mediaFiles...,
1098			))
1099		}
1100	}
1101
1102	return convertedMessages
1103}
1104
1105// buildSummaryPrompt constructs the prompt text for session summarization.
1106func buildSummaryPrompt(todos []session.Todo) string {
1107	var sb strings.Builder
1108	sb.WriteString("Provide a detailed summary of our conversation above.")
1109	if len(todos) > 0 {
1110		sb.WriteString("\n\n## Current Todo List\n\n")
1111		for _, t := range todos {
1112			fmt.Fprintf(&sb, "- [%s] %s\n", t.Status, t.Content)
1113		}
1114		sb.WriteString("\nInclude these tasks and their statuses in your summary. ")
1115		sb.WriteString("Instruct the resuming assistant to use the `todos` tool to continue tracking progress on these tasks.")
1116	}
1117	return sb.String()
1118}