agent.go

   1// Package agent is the core orchestration layer for Crush AI agents.
   2//
   3// It provides session-based AI agent functionality for managing
   4// conversations, tool execution, and message handling. It coordinates
   5// interactions between language models, messages, sessions, and tools while
   6// handling features like automatic summarization, queuing, and token
   7// management.
   8package agent
   9
  10import (
  11	"cmp"
  12	"context"
  13	_ "embed"
  14	"encoding/base64"
  15	"errors"
  16	"fmt"
  17	"log/slog"
  18	"os"
  19	"regexp"
  20	"strconv"
  21	"strings"
  22	"sync"
  23	"time"
  24
  25	"charm.land/fantasy"
  26	"charm.land/fantasy/providers/anthropic"
  27	"charm.land/fantasy/providers/bedrock"
  28	"charm.land/fantasy/providers/google"
  29	"charm.land/fantasy/providers/openai"
  30	"charm.land/fantasy/providers/openrouter"
  31	"charm.land/lipgloss/v2"
  32	"github.com/charmbracelet/catwalk/pkg/catwalk"
  33	"github.com/charmbracelet/crush/internal/agent/hyper"
  34	"github.com/charmbracelet/crush/internal/agent/tools"
  35	"github.com/charmbracelet/crush/internal/config"
  36	"github.com/charmbracelet/crush/internal/csync"
  37	"github.com/charmbracelet/crush/internal/message"
  38	"github.com/charmbracelet/crush/internal/permission"
  39	"github.com/charmbracelet/crush/internal/session"
  40	"github.com/charmbracelet/crush/internal/stringext"
  41)
  42
// defaultSessionName is the title stored on a session when title generation
// fails or produces an empty title.
const defaultSessionName = "Untitled Session"

// titlePrompt is the embedded system prompt used when asking a model to
// generate a session title.
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt is the embedded system prompt used when asking the large
// model to summarize a session.
//
//go:embed templates/summary.md
var summaryPrompt []byte

// thinkTagRegex matches a <think>...</think> span so it can be removed from
// generated titles. The title has its newlines replaced by spaces before the
// regex is applied, so the span is always on a single line.
var thinkTagRegex = regexp.MustCompile(`<think>.*?</think>`)
  53
// SessionAgentCall describes a single prompt submitted to a session agent.
type SessionAgentCall struct {
	// SessionID identifies the session the prompt belongs to. Run rejects
	// calls with an empty SessionID.
	SessionID string
	// Prompt is the user's text. Run rejects an empty Prompt unless the call
	// carries a text attachment.
	Prompt string
	// ProviderOptions is forwarded unchanged to the model stream call.
	ProviderOptions fantasy.ProviderOptions
	// Attachments are stored on the user message; non-text attachments are
	// also sent to the model as files.
	Attachments []message.Attachment
	// MaxOutputTokens is forwarded to the model stream call.
	MaxOutputTokens int64
	// The sampling parameters below are forwarded to the model stream call
	// as-is; a nil pointer leaves the parameter unset.
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
}
  66
// SessionAgent is the interface exposed by session-based agents. The notes
// below describe the sessionAgent implementation in this package.
type SessionAgent interface {
	// Run sends a prompt to the model for the call's session. When the
	// session is already busy the call is queued and (nil, nil) is returned.
	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
	SetModels(large Model, small Model)
	SetTools(tools []fantasy.AgentTool)
	SetSystemPrompt(systemPrompt string)
	// Cancel cancels the active request of the session, if any, and drops
	// its queued prompts.
	Cancel(sessionID string)
	// CancelAll cancels every active request and waits up to five seconds
	// for the agent to stop being busy.
	CancelAll()
	IsSessionBusy(sessionID string) bool
	IsBusy() bool
	QueuedPrompts(sessionID string) int
	QueuedPromptsList(sessionID string) []string
	// ClearQueue drops the prompts queued for the session.
	ClearQueue(sessionID string)
	// Summarize replaces the session's history with a model-generated
	// summary message. It returns ErrSessionBusy when the session is busy.
	Summarize(context.Context, string, fantasy.ProviderOptions) error
	Model() Model
}
  82
// Model bundles a language model with the metadata the agent needs about it.
type Model struct {
	// Model is the language model handed to fantasy.NewAgent.
	Model fantasy.LanguageModel
	// CatwalkCfg supplies model metadata read by the agent, such as the
	// context window, per-token costs, and display name.
	CatwalkCfg catwalk.Model
	// ModelCfg is the selected-model configuration; its Model and Provider
	// values are recorded on assistant messages.
	ModelCfg config.SelectedModel
}
  88
// sessionAgent is the SessionAgent implementation returned by
// NewSessionAgent.
type sessionAgent struct {
	// largeModel handles prompts and summaries; smallModel is tried first
	// for title generation.
	largeModel Model
	smallModel Model
	// systemPromptPrefix, when non-empty, is prepended as an extra system
	// message to summary and title requests.
	systemPromptPrefix string
	systemPrompt       string
	// isSubAgent suppresses the todo-list reminder that is otherwise placed
	// at the start of the prompt history.
	isSubAgent bool
	tools      []fantasy.AgentTool
	sessions   session.Service
	messages   message.Service
	// disableAutoSummarize turns off the context-window check that stops a
	// run and triggers Summarize.
	disableAutoSummarize bool
	isYolo               bool

	// messageQueue holds calls received while a session was busy, keyed by
	// session ID.
	messageQueue *csync.Map[string, []SessionAgentCall]
	// activeRequests holds the cancel function of the in-flight request of
	// each session, keyed by session ID.
	activeRequests *csync.Map[string, context.CancelFunc]
}
 104
// SessionAgentOptions carries the dependencies and settings used by
// NewSessionAgent to build a session agent. Each field is copied onto the
// corresponding field of the agent.
type SessionAgentOptions struct {
	LargeModel           Model
	SmallModel           Model
	SystemPromptPrefix   string
	SystemPrompt         string
	IsSubAgent           bool
	DisableAutoSummarize bool
	IsYolo               bool
	Sessions             session.Service
	Messages             message.Service
	Tools                []fantasy.AgentTool
}
 117
 118func NewSessionAgent(
 119	opts SessionAgentOptions,
 120) SessionAgent {
 121	return &sessionAgent{
 122		largeModel:           opts.LargeModel,
 123		smallModel:           opts.SmallModel,
 124		systemPromptPrefix:   opts.SystemPromptPrefix,
 125		systemPrompt:         opts.SystemPrompt,
 126		isSubAgent:           opts.IsSubAgent,
 127		sessions:             opts.Sessions,
 128		messages:             opts.Messages,
 129		disableAutoSummarize: opts.DisableAutoSummarize,
 130		tools:                opts.Tools,
 131		isYolo:               opts.IsYolo,
 132		messageQueue:         csync.NewMap[string, []SessionAgentCall](),
 133		activeRequests:       csync.NewMap[string, context.CancelFunc](),
 134	}
 135}
 136
 137func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
 138	if call.Prompt == "" && !message.ContainsTextAttachment(call.Attachments) {
 139		return nil, ErrEmptyPrompt
 140	}
 141	if call.SessionID == "" {
 142		return nil, ErrSessionMissing
 143	}
 144
 145	// Queue the message if busy
 146	if a.IsSessionBusy(call.SessionID) {
 147		existing, ok := a.messageQueue.Get(call.SessionID)
 148		if !ok {
 149			existing = []SessionAgentCall{}
 150		}
 151		existing = append(existing, call)
 152		a.messageQueue.Set(call.SessionID, existing)
 153		return nil, nil
 154	}
 155
 156	if len(a.tools) > 0 {
 157		// Add Anthropic caching to the last tool.
 158		a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
 159	}
 160
 161	agent := fantasy.NewAgent(
 162		a.largeModel.Model,
 163		fantasy.WithSystemPrompt(a.systemPrompt),
 164		fantasy.WithTools(a.tools...),
 165	)
 166
 167	sessionLock := sync.Mutex{}
 168	currentSession, err := a.sessions.Get(ctx, call.SessionID)
 169	if err != nil {
 170		return nil, fmt.Errorf("failed to get session: %w", err)
 171	}
 172
 173	msgs, err := a.getSessionMessages(ctx, currentSession)
 174	if err != nil {
 175		return nil, fmt.Errorf("failed to get session messages: %w", err)
 176	}
 177
 178	var wg sync.WaitGroup
 179	// Generate title if first message.
 180	if len(msgs) == 0 {
 181		titleCtx := ctx // Copy to avoid race with ctx reassignment below.
 182		wg.Go(func() {
 183			a.generateTitle(titleCtx, call.SessionID, call.Prompt)
 184		})
 185	}
 186
 187	// Add the user message to the session.
 188	_, err = a.createUserMessage(ctx, call)
 189	if err != nil {
 190		return nil, err
 191	}
 192
 193	// Add the session to the context.
 194	ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)
 195
 196	genCtx, cancel := context.WithCancel(ctx)
 197	a.activeRequests.Set(call.SessionID, cancel)
 198
 199	defer cancel()
 200	defer a.activeRequests.Del(call.SessionID)
 201
 202	history, files := a.preparePrompt(msgs, call.Attachments...)
 203
 204	startTime := time.Now()
 205	a.eventPromptSent(call.SessionID)
 206
 207	var currentAssistant *message.Message
 208	var shouldSummarize bool
 209	result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
 210		Prompt:           message.PromptWithTextAttachments(call.Prompt, call.Attachments),
 211		Files:            files,
 212		Messages:         history,
 213		ProviderOptions:  call.ProviderOptions,
 214		MaxOutputTokens:  &call.MaxOutputTokens,
 215		TopP:             call.TopP,
 216		Temperature:      call.Temperature,
 217		PresencePenalty:  call.PresencePenalty,
 218		TopK:             call.TopK,
 219		FrequencyPenalty: call.FrequencyPenalty,
 220		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
 221			prepared.Messages = options.Messages
 222			for i := range prepared.Messages {
 223				prepared.Messages[i].ProviderOptions = nil
 224			}
 225
 226			queuedCalls, _ := a.messageQueue.Get(call.SessionID)
 227			a.messageQueue.Del(call.SessionID)
 228			for _, queued := range queuedCalls {
 229				userMessage, createErr := a.createUserMessage(callContext, queued)
 230				if createErr != nil {
 231					return callContext, prepared, createErr
 232				}
 233				prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
 234			}
 235
 236			prepared.Messages = a.workaroundProviderMediaLimitations(prepared.Messages)
 237
 238			lastSystemRoleInx := 0
 239			systemMessageUpdated := false
 240			for i, msg := range prepared.Messages {
 241				// Only add cache control to the last message.
 242				if msg.Role == fantasy.MessageRoleSystem {
 243					lastSystemRoleInx = i
 244				} else if !systemMessageUpdated {
 245					prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
 246					systemMessageUpdated = true
 247				}
 248				// Than add cache control to the last 2 messages.
 249				if i > len(prepared.Messages)-3 {
 250					prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
 251				}
 252			}
 253
 254			if promptPrefix := a.promptPrefix(); promptPrefix != "" {
 255				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(promptPrefix)}, prepared.Messages...)
 256			}
 257
 258			var assistantMsg message.Message
 259			assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
 260				Role:     message.Assistant,
 261				Parts:    []message.ContentPart{},
 262				Model:    a.largeModel.ModelCfg.Model,
 263				Provider: a.largeModel.ModelCfg.Provider,
 264			})
 265			if err != nil {
 266				return callContext, prepared, err
 267			}
 268			callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
 269			callContext = context.WithValue(callContext, tools.SupportsImagesContextKey, a.largeModel.CatwalkCfg.SupportsImages)
 270			callContext = context.WithValue(callContext, tools.ModelNameContextKey, a.largeModel.CatwalkCfg.Name)
 271			currentAssistant = &assistantMsg
 272			return callContext, prepared, err
 273		},
 274		OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
 275			currentAssistant.AppendReasoningContent(reasoning.Text)
 276			return a.messages.Update(genCtx, *currentAssistant)
 277		},
 278		OnReasoningDelta: func(id string, text string) error {
 279			currentAssistant.AppendReasoningContent(text)
 280			return a.messages.Update(genCtx, *currentAssistant)
 281		},
 282		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
 283			// handle anthropic signature
 284			if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
 285				if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
 286					currentAssistant.AppendReasoningSignature(reasoning.Signature)
 287				}
 288			}
 289			if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
 290				if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
 291					currentAssistant.AppendThoughtSignature(reasoning.Signature, reasoning.ToolID)
 292				}
 293			}
 294			if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
 295				if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
 296					currentAssistant.SetReasoningResponsesData(reasoning)
 297				}
 298			}
 299			currentAssistant.FinishThinking()
 300			return a.messages.Update(genCtx, *currentAssistant)
 301		},
 302		OnTextDelta: func(id string, text string) error {
 303			// Strip leading newline from initial text content. This is is
 304			// particularly important in non-interactive mode where leading
 305			// newlines are very visible.
 306			if len(currentAssistant.Parts) == 0 {
 307				text = strings.TrimPrefix(text, "\n")
 308			}
 309
 310			currentAssistant.AppendContent(text)
 311			return a.messages.Update(genCtx, *currentAssistant)
 312		},
 313		OnToolInputStart: func(id string, toolName string) error {
 314			toolCall := message.ToolCall{
 315				ID:               id,
 316				Name:             toolName,
 317				ProviderExecuted: false,
 318				Finished:         false,
 319			}
 320			currentAssistant.AddToolCall(toolCall)
 321			return a.messages.Update(genCtx, *currentAssistant)
 322		},
 323		OnRetry: func(err *fantasy.ProviderError, delay time.Duration) {
 324			// TODO: implement
 325		},
 326		OnToolCall: func(tc fantasy.ToolCallContent) error {
 327			toolCall := message.ToolCall{
 328				ID:               tc.ToolCallID,
 329				Name:             tc.ToolName,
 330				Input:            tc.Input,
 331				ProviderExecuted: false,
 332				Finished:         true,
 333			}
 334			currentAssistant.AddToolCall(toolCall)
 335			return a.messages.Update(genCtx, *currentAssistant)
 336		},
 337		OnToolResult: func(result fantasy.ToolResultContent) error {
 338			toolResult := a.convertToToolResult(result)
 339			_, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
 340				Role: message.Tool,
 341				Parts: []message.ContentPart{
 342					toolResult,
 343				},
 344			})
 345			return createMsgErr
 346		},
 347		OnStepFinish: func(stepResult fantasy.StepResult) error {
 348			finishReason := message.FinishReasonUnknown
 349			switch stepResult.FinishReason {
 350			case fantasy.FinishReasonLength:
 351				finishReason = message.FinishReasonMaxTokens
 352			case fantasy.FinishReasonStop:
 353				finishReason = message.FinishReasonEndTurn
 354			case fantasy.FinishReasonToolCalls:
 355				finishReason = message.FinishReasonToolUse
 356			}
 357			currentAssistant.AddFinish(finishReason, "", "")
 358			sessionLock.Lock()
 359			updatedSession, getSessionErr := a.sessions.Get(genCtx, call.SessionID)
 360			if getSessionErr != nil {
 361				sessionLock.Unlock()
 362				return getSessionErr
 363			}
 364			a.updateSessionUsage(a.largeModel, &updatedSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
 365			_, sessionErr := a.sessions.Save(genCtx, updatedSession)
 366			sessionLock.Unlock()
 367			if sessionErr != nil {
 368				return sessionErr
 369			}
 370			return a.messages.Update(genCtx, *currentAssistant)
 371		},
 372		StopWhen: []fantasy.StopCondition{
 373			func(_ []fantasy.StepResult) bool {
 374				cw := int64(a.largeModel.CatwalkCfg.ContextWindow)
 375				tokens := currentSession.CompletionTokens + currentSession.PromptTokens
 376				remaining := cw - tokens
 377				var threshold int64
 378				if cw > 200_000 {
 379					threshold = 20_000
 380				} else {
 381					threshold = int64(float64(cw) * 0.2)
 382				}
 383				if (remaining <= threshold) && !a.disableAutoSummarize {
 384					shouldSummarize = true
 385					return true
 386				}
 387				return false
 388			},
 389		},
 390	})
 391
 392	a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))
 393
 394	if err != nil {
 395		isCancelErr := errors.Is(err, context.Canceled)
 396		isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
 397		if currentAssistant == nil {
 398			return result, err
 399		}
 400		// Ensure we finish thinking on error to close the reasoning state.
 401		currentAssistant.FinishThinking()
 402		toolCalls := currentAssistant.ToolCalls()
 403		// INFO: we use the parent context here because the genCtx has been cancelled.
 404		msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
 405		if createErr != nil {
 406			return nil, createErr
 407		}
 408		for _, tc := range toolCalls {
 409			if !tc.Finished {
 410				tc.Finished = true
 411				tc.Input = "{}"
 412				currentAssistant.AddToolCall(tc)
 413				updateErr := a.messages.Update(ctx, *currentAssistant)
 414				if updateErr != nil {
 415					return nil, updateErr
 416				}
 417			}
 418
 419			found := false
 420			for _, msg := range msgs {
 421				if msg.Role == message.Tool {
 422					for _, tr := range msg.ToolResults() {
 423						if tr.ToolCallID == tc.ID {
 424							found = true
 425							break
 426						}
 427					}
 428				}
 429				if found {
 430					break
 431				}
 432			}
 433			if found {
 434				continue
 435			}
 436			content := "There was an error while executing the tool"
 437			if isCancelErr {
 438				content = "Tool execution canceled by user"
 439			} else if isPermissionErr {
 440				content = "User denied permission"
 441			}
 442			toolResult := message.ToolResult{
 443				ToolCallID: tc.ID,
 444				Name:       tc.Name,
 445				Content:    content,
 446				IsError:    true,
 447			}
 448			_, createErr = a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
 449				Role: message.Tool,
 450				Parts: []message.ContentPart{
 451					toolResult,
 452				},
 453			})
 454			if createErr != nil {
 455				return nil, createErr
 456			}
 457		}
 458		var fantasyErr *fantasy.Error
 459		var providerErr *fantasy.ProviderError
 460		const defaultTitle = "Provider Error"
 461		if isCancelErr {
 462			currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
 463		} else if isPermissionErr {
 464			currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "")
 465		} else if errors.Is(err, hyper.ErrNoCredits) {
 466			url := hyper.BaseURL()
 467			link := lipgloss.NewStyle().Hyperlink(url, "id=hyper").Render(url)
 468			currentAssistant.AddFinish(message.FinishReasonError, "No credits", "You're out of credits. Add more at "+link)
 469		} else if errors.As(err, &providerErr) {
 470			if providerErr.Message == "The requested model is not supported." {
 471				url := "https://github.com/settings/copilot/features"
 472				link := lipgloss.NewStyle().Hyperlink(url, "id=hyper").Render(url)
 473				currentAssistant.AddFinish(
 474					message.FinishReasonError,
 475					"Copilot model not enabled",
 476					fmt.Sprintf("%q is not enabled in Copilot. Go to the following page to enable it. Then, wait 5 minutes before trying again. %s", a.largeModel.CatwalkCfg.Name, link),
 477				)
 478			} else {
 479				currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message)
 480			}
 481		} else if errors.As(err, &fantasyErr) {
 482			currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(fantasyErr.Title), defaultTitle), fantasyErr.Message)
 483		} else {
 484			currentAssistant.AddFinish(message.FinishReasonError, defaultTitle, err.Error())
 485		}
 486		// Note: we use the parent context here because the genCtx has been
 487		// cancelled.
 488		updateErr := a.messages.Update(ctx, *currentAssistant)
 489		if updateErr != nil {
 490			return nil, updateErr
 491		}
 492		return nil, err
 493	}
 494	wg.Wait()
 495
 496	if shouldSummarize {
 497		a.activeRequests.Del(call.SessionID)
 498		if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
 499			return nil, summarizeErr
 500		}
 501		// If the agent wasn't done...
 502		if len(currentAssistant.ToolCalls()) > 0 {
 503			existing, ok := a.messageQueue.Get(call.SessionID)
 504			if !ok {
 505				existing = []SessionAgentCall{}
 506			}
 507			call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
 508			existing = append(existing, call)
 509			a.messageQueue.Set(call.SessionID, existing)
 510		}
 511	}
 512
 513	// Release active request before processing queued messages.
 514	a.activeRequests.Del(call.SessionID)
 515	cancel()
 516
 517	queuedMessages, ok := a.messageQueue.Get(call.SessionID)
 518	if !ok || len(queuedMessages) == 0 {
 519		return result, err
 520	}
 521	// There are queued messages restart the loop.
 522	firstQueuedMessage := queuedMessages[0]
 523	a.messageQueue.Set(call.SessionID, queuedMessages[1:])
 524	return a.Run(ctx, firstQueuedMessage)
 525}
 526
 527func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
 528	if a.IsSessionBusy(sessionID) {
 529		return ErrSessionBusy
 530	}
 531
 532	currentSession, err := a.sessions.Get(ctx, sessionID)
 533	if err != nil {
 534		return fmt.Errorf("failed to get session: %w", err)
 535	}
 536	msgs, err := a.getSessionMessages(ctx, currentSession)
 537	if err != nil {
 538		return err
 539	}
 540	if len(msgs) == 0 {
 541		// Nothing to summarize.
 542		return nil
 543	}
 544
 545	aiMsgs, _ := a.preparePrompt(msgs)
 546
 547	genCtx, cancel := context.WithCancel(ctx)
 548	a.activeRequests.Set(sessionID, cancel)
 549	defer a.activeRequests.Del(sessionID)
 550	defer cancel()
 551
 552	agent := fantasy.NewAgent(a.largeModel.Model,
 553		fantasy.WithSystemPrompt(string(summaryPrompt)),
 554	)
 555	summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
 556		Role:             message.Assistant,
 557		Model:            a.largeModel.Model.Model(),
 558		Provider:         a.largeModel.Model.Provider(),
 559		IsSummaryMessage: true,
 560	})
 561	if err != nil {
 562		return err
 563	}
 564
 565	summaryPromptText := buildSummaryPrompt(currentSession.Todos)
 566
 567	resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
 568		Prompt:          summaryPromptText,
 569		Messages:        aiMsgs,
 570		ProviderOptions: opts,
 571		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
 572			prepared.Messages = options.Messages
 573			if a.systemPromptPrefix != "" {
 574				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
 575			}
 576			return callContext, prepared, nil
 577		},
 578		OnReasoningDelta: func(id string, text string) error {
 579			summaryMessage.AppendReasoningContent(text)
 580			return a.messages.Update(genCtx, summaryMessage)
 581		},
 582		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
 583			// Handle anthropic signature.
 584			if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
 585				if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
 586					summaryMessage.AppendReasoningSignature(signature.Signature)
 587				}
 588			}
 589			summaryMessage.FinishThinking()
 590			return a.messages.Update(genCtx, summaryMessage)
 591		},
 592		OnTextDelta: func(id, text string) error {
 593			summaryMessage.AppendContent(text)
 594			return a.messages.Update(genCtx, summaryMessage)
 595		},
 596	})
 597	if err != nil {
 598		isCancelErr := errors.Is(err, context.Canceled)
 599		if isCancelErr {
 600			// User cancelled summarize we need to remove the summary message.
 601			deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
 602			return deleteErr
 603		}
 604		return err
 605	}
 606
 607	summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
 608	err = a.messages.Update(genCtx, summaryMessage)
 609	if err != nil {
 610		return err
 611	}
 612
 613	var openrouterCost *float64
 614	for _, step := range resp.Steps {
 615		stepCost := a.openrouterCost(step.ProviderMetadata)
 616		if stepCost != nil {
 617			newCost := *stepCost
 618			if openrouterCost != nil {
 619				newCost += *openrouterCost
 620			}
 621			openrouterCost = &newCost
 622		}
 623	}
 624
 625	a.updateSessionUsage(a.largeModel, &currentSession, resp.TotalUsage, openrouterCost)
 626
 627	// Just in case, get just the last usage info.
 628	usage := resp.Response.Usage
 629	currentSession.SummaryMessageID = summaryMessage.ID
 630	currentSession.CompletionTokens = usage.OutputTokens
 631	currentSession.PromptTokens = 0
 632	_, err = a.sessions.Save(genCtx, currentSession)
 633	return err
 634}
 635
 636func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
 637	if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
 638		return fantasy.ProviderOptions{}
 639	}
 640	return fantasy.ProviderOptions{
 641		anthropic.Name: &anthropic.ProviderCacheControlOptions{
 642			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
 643		},
 644		bedrock.Name: &anthropic.ProviderCacheControlOptions{
 645			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
 646		},
 647	}
 648}
 649
 650func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
 651	parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
 652	var attachmentParts []message.ContentPart
 653	for _, attachment := range call.Attachments {
 654		attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
 655	}
 656	parts = append(parts, attachmentParts...)
 657	msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
 658		Role:  message.User,
 659		Parts: parts,
 660	})
 661	if err != nil {
 662		return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
 663	}
 664	return msg, nil
 665}
 666
 667func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
 668	var history []fantasy.Message
 669	if !a.isSubAgent {
 670		history = append(history, fantasy.NewUserMessage(
 671			fmt.Sprintf("<system_reminder>%s</system_reminder>",
 672				`This is a reminder that your todo list is currently empty. DO NOT mention this to the user explicitly because they are already aware.
 673If you are working on tasks that would benefit from a todo list please use the "todos" tool to create one.
 674If not, please feel free to ignore. Again do not mention this message to the user.`,
 675			),
 676		))
 677	}
 678	for _, m := range msgs {
 679		if len(m.Parts) == 0 {
 680			continue
 681		}
 682		// Assistant message without content or tool calls (cancelled before it
 683		// returned anything).
 684		if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
 685			continue
 686		}
 687		history = append(history, m.ToAIMessage()...)
 688	}
 689
 690	var files []fantasy.FilePart
 691	for _, attachment := range attachments {
 692		if attachment.IsText() {
 693			continue
 694		}
 695		files = append(files, fantasy.FilePart{
 696			Filename:  attachment.FileName,
 697			Data:      attachment.Content,
 698			MediaType: attachment.MimeType,
 699		})
 700	}
 701
 702	return history, files
 703}
 704
 705func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
 706	msgs, err := a.messages.List(ctx, session.ID)
 707	if err != nil {
 708		return nil, fmt.Errorf("failed to list messages: %w", err)
 709	}
 710
 711	if session.SummaryMessageID != "" {
 712		summaryMsgInex := -1
 713		for i, msg := range msgs {
 714			if msg.ID == session.SummaryMessageID {
 715				summaryMsgInex = i
 716				break
 717			}
 718		}
 719		if summaryMsgInex != -1 {
 720			msgs = msgs[summaryMsgInex:]
 721			msgs[0].Role = message.User
 722		}
 723	}
 724	return msgs, nil
 725}
 726
 727// generateTitle generates a session titled based on the initial prompt.
 728func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, userPrompt string) {
 729	if userPrompt == "" {
 730		return
 731	}
 732
 733	var maxOutputTokens int64 = 40
 734	if a.smallModel.CatwalkCfg.CanReason {
 735		maxOutputTokens = a.smallModel.CatwalkCfg.DefaultMaxTokens
 736	}
 737
 738	newAgent := func(m fantasy.LanguageModel, p []byte, tok int64) fantasy.Agent {
 739		return fantasy.NewAgent(m,
 740			fantasy.WithSystemPrompt(string(p)+"\n /no_think"),
 741			fantasy.WithMaxOutputTokens(tok),
 742		)
 743	}
 744
 745	streamCall := fantasy.AgentStreamCall{
 746		Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", userPrompt),
 747		PrepareStep: func(callCtx context.Context, opts fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
 748			prepared.Messages = opts.Messages
 749			if a.systemPromptPrefix != "" {
 750				prepared.Messages = append([]fantasy.Message{
 751					fantasy.NewSystemMessage(a.systemPromptPrefix),
 752				}, prepared.Messages...)
 753			}
 754			return callCtx, prepared, nil
 755		},
 756	}
 757
 758	// Use the small model to generate the title.
 759	model := &a.smallModel
 760	agent := newAgent(model.Model, titlePrompt, maxOutputTokens)
 761	resp, err := agent.Stream(ctx, streamCall)
 762	if err == nil {
 763		// We successfully generated a title with the small model.
 764		slog.Info("generated title with small model")
 765	} else {
 766		// It didn't work. Let's try with the big model.
 767		slog.Error("error generating title with small model; trying big model", "err", err)
 768		model = &a.largeModel
 769		agent = newAgent(model.Model, titlePrompt, maxOutputTokens)
 770		resp, err = agent.Stream(ctx, streamCall)
 771		if err == nil {
 772			slog.Info("generated title with large model")
 773		} else {
 774			// Welp, the large model didn't work either. Use the default
 775			// session name and return.
 776			slog.Error("error generating title with large model", "err", err)
 777			saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, defaultSessionName, 0, 0, 0)
 778			if saveErr != nil {
 779				slog.Error("failed to save session title and usage", "error", saveErr)
 780			}
 781			return
 782		}
 783	}
 784
 785	if resp == nil {
 786		// Actually, we didn't get a response so we can't. Use the default
 787		// session name and return.
 788		slog.Error("response is nil; can't generate title")
 789		saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, defaultSessionName, 0, 0, 0)
 790		if saveErr != nil {
 791			slog.Error("failed to save session title and usage", "error", saveErr)
 792		}
 793		return
 794	}
 795
 796	// Clean up title.
 797	var title string
 798	title = strings.ReplaceAll(resp.Response.Content.Text(), "\n", " ")
 799	slog.Info("generated title", "title", title)
 800
 801	// Remove thinking tags if present.
 802	title = thinkTagRegex.ReplaceAllString(title, "")
 803
 804	title = strings.TrimSpace(title)
 805	if title == "" {
 806		slog.Warn("empty title; using fallback")
 807		title = defaultSessionName
 808	}
 809
 810	// Calculate usage and cost.
 811	var openrouterCost *float64
 812	for _, step := range resp.Steps {
 813		stepCost := a.openrouterCost(step.ProviderMetadata)
 814		if stepCost != nil {
 815			newCost := *stepCost
 816			if openrouterCost != nil {
 817				newCost += *openrouterCost
 818			}
 819			openrouterCost = &newCost
 820		}
 821	}
 822
 823	modelConfig := model.CatwalkCfg
 824	cost := modelConfig.CostPer1MInCached/1e6*float64(resp.TotalUsage.CacheCreationTokens) +
 825		modelConfig.CostPer1MOutCached/1e6*float64(resp.TotalUsage.CacheReadTokens) +
 826		modelConfig.CostPer1MIn/1e6*float64(resp.TotalUsage.InputTokens) +
 827		modelConfig.CostPer1MOut/1e6*float64(resp.TotalUsage.OutputTokens)
 828
 829	// Use override cost if available (e.g., from OpenRouter).
 830	if openrouterCost != nil {
 831		cost = *openrouterCost
 832	}
 833
 834	promptTokens := resp.TotalUsage.InputTokens + resp.TotalUsage.CacheCreationTokens
 835	completionTokens := resp.TotalUsage.OutputTokens + resp.TotalUsage.CacheReadTokens
 836
 837	// Atomically update only title and usage fields to avoid overriding other
 838	// concurrent session updates.
 839	saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, title, promptTokens, completionTokens, cost)
 840	if saveErr != nil {
 841		slog.Error("failed to save session title and usage", "error", saveErr)
 842		return
 843	}
 844}
 845
 846func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
 847	openrouterMetadata, ok := metadata[openrouter.Name]
 848	if !ok {
 849		return nil
 850	}
 851
 852	opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
 853	if !ok {
 854		return nil
 855	}
 856	return &opts.Usage.Cost
 857}
 858
 859func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
 860	modelConfig := model.CatwalkCfg
 861	cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
 862		modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
 863		modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
 864		modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
 865
 866	a.eventTokensUsed(session.ID, model, usage, cost)
 867
 868	if overrideCost != nil {
 869		session.Cost += *overrideCost
 870	} else {
 871		session.Cost += cost
 872	}
 873
 874	session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
 875	session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
 876}
 877
 878func (a *sessionAgent) Cancel(sessionID string) {
 879	// Cancel regular requests.
 880	if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
 881		slog.Info("Request cancellation initiated", "session_id", sessionID)
 882		cancel()
 883	}
 884
 885	// Also check for summarize requests.
 886	if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
 887		slog.Info("Summarize cancellation initiated", "session_id", sessionID)
 888		cancel()
 889	}
 890
 891	if a.QueuedPrompts(sessionID) > 0 {
 892		slog.Info("Clearing queued prompts", "session_id", sessionID)
 893		a.messageQueue.Del(sessionID)
 894	}
 895}
 896
 897func (a *sessionAgent) ClearQueue(sessionID string) {
 898	if a.QueuedPrompts(sessionID) > 0 {
 899		slog.Info("Clearing queued prompts", "session_id", sessionID)
 900		a.messageQueue.Del(sessionID)
 901	}
 902}
 903
 904func (a *sessionAgent) CancelAll() {
 905	if !a.IsBusy() {
 906		return
 907	}
 908	for key := range a.activeRequests.Seq2() {
 909		a.Cancel(key) // key is sessionID
 910	}
 911
 912	timeout := time.After(5 * time.Second)
 913	for a.IsBusy() {
 914		select {
 915		case <-timeout:
 916			return
 917		default:
 918			time.Sleep(200 * time.Millisecond)
 919		}
 920	}
 921}
 922
 923func (a *sessionAgent) IsBusy() bool {
 924	var busy bool
 925	for cancelFunc := range a.activeRequests.Seq() {
 926		if cancelFunc != nil {
 927			busy = true
 928			break
 929		}
 930	}
 931	return busy
 932}
 933
 934func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
 935	_, busy := a.activeRequests.Get(sessionID)
 936	return busy
 937}
 938
 939func (a *sessionAgent) QueuedPrompts(sessionID string) int {
 940	l, ok := a.messageQueue.Get(sessionID)
 941	if !ok {
 942		return 0
 943	}
 944	return len(l)
 945}
 946
 947func (a *sessionAgent) QueuedPromptsList(sessionID string) []string {
 948	l, ok := a.messageQueue.Get(sessionID)
 949	if !ok {
 950		return nil
 951	}
 952	prompts := make([]string, len(l))
 953	for i, call := range l {
 954		prompts[i] = call.Prompt
 955	}
 956	return prompts
 957}
 958
 959func (a *sessionAgent) SetModels(large Model, small Model) {
 960	a.largeModel = large
 961	a.smallModel = small
 962}
 963
 964func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
 965	a.tools = tools
 966}
 967
 968func (a *sessionAgent) SetSystemPrompt(systemPrompt string) {
 969	a.systemPrompt = systemPrompt
 970}
 971
 972func (a *sessionAgent) Model() Model {
 973	return a.largeModel
 974}
 975
 976func (a *sessionAgent) promptPrefix() string {
 977	return a.systemPromptPrefix
 978}
 979
 980// convertToToolResult converts a fantasy tool result to a message tool result.
 981func (a *sessionAgent) convertToToolResult(result fantasy.ToolResultContent) message.ToolResult {
 982	baseResult := message.ToolResult{
 983		ToolCallID: result.ToolCallID,
 984		Name:       result.ToolName,
 985		Metadata:   result.ClientMetadata,
 986	}
 987
 988	switch result.Result.GetType() {
 989	case fantasy.ToolResultContentTypeText:
 990		if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result); ok {
 991			baseResult.Content = r.Text
 992		}
 993	case fantasy.ToolResultContentTypeError:
 994		if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result); ok {
 995			baseResult.Content = r.Error.Error()
 996			baseResult.IsError = true
 997		}
 998	case fantasy.ToolResultContentTypeMedia:
 999		if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](result.Result); ok {
1000			content := r.Text
1001			if content == "" {
1002				content = fmt.Sprintf("Loaded %s content", r.MediaType)
1003			}
1004			baseResult.Content = content
1005			baseResult.Data = r.Data
1006			baseResult.MIMEType = r.MediaType
1007		}
1008	}
1009
1010	return baseResult
1011}
1012
1013// workaroundProviderMediaLimitations converts media content in tool results to
1014// user messages for providers that don't natively support images in tool results.
1015//
1016// Problem: OpenAI, Google, OpenRouter, and other OpenAI-compatible providers
1017// don't support sending images/media in tool result messages - they only accept
1018// text in tool results. However, they DO support images in user messages.
1019//
1020// If we send media in tool results to these providers, the API returns an error.
1021//
1022// Solution: For these providers, we:
1023//  1. Replace the media in the tool result with a text placeholder
1024//  2. Inject a user message immediately after with the image as a file attachment
1025//  3. This maintains the tool execution flow while working around API limitations
1026//
1027// Anthropic and Bedrock support images natively in tool results, so we skip
1028// this workaround for them.
1029//
1030// Example transformation:
1031//
1032//	BEFORE: [tool result: image data]
1033//	AFTER:  [tool result: "Image loaded - see attached"], [user: image attachment]
1034func (a *sessionAgent) workaroundProviderMediaLimitations(messages []fantasy.Message) []fantasy.Message {
1035	providerSupportsMedia := a.largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderAnthropic) ||
1036		a.largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderBedrock)
1037
1038	if providerSupportsMedia {
1039		return messages
1040	}
1041
1042	convertedMessages := make([]fantasy.Message, 0, len(messages))
1043
1044	for _, msg := range messages {
1045		if msg.Role != fantasy.MessageRoleTool {
1046			convertedMessages = append(convertedMessages, msg)
1047			continue
1048		}
1049
1050		textParts := make([]fantasy.MessagePart, 0, len(msg.Content))
1051		var mediaFiles []fantasy.FilePart
1052
1053		for _, part := range msg.Content {
1054			toolResult, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
1055			if !ok {
1056				textParts = append(textParts, part)
1057				continue
1058			}
1059
1060			if media, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](toolResult.Output); ok {
1061				decoded, err := base64.StdEncoding.DecodeString(media.Data)
1062				if err != nil {
1063					slog.Warn("failed to decode media data", "error", err)
1064					textParts = append(textParts, part)
1065					continue
1066				}
1067
1068				mediaFiles = append(mediaFiles, fantasy.FilePart{
1069					Data:      decoded,
1070					MediaType: media.MediaType,
1071					Filename:  fmt.Sprintf("tool-result-%s", toolResult.ToolCallID),
1072				})
1073
1074				textParts = append(textParts, fantasy.ToolResultPart{
1075					ToolCallID: toolResult.ToolCallID,
1076					Output: fantasy.ToolResultOutputContentText{
1077						Text: "[Image/media content loaded - see attached file]",
1078					},
1079					ProviderOptions: toolResult.ProviderOptions,
1080				})
1081			} else {
1082				textParts = append(textParts, part)
1083			}
1084		}
1085
1086		convertedMessages = append(convertedMessages, fantasy.Message{
1087			Role:    fantasy.MessageRoleTool,
1088			Content: textParts,
1089		})
1090
1091		if len(mediaFiles) > 0 {
1092			convertedMessages = append(convertedMessages, fantasy.NewUserMessage(
1093				"Here is the media content from the tool result:",
1094				mediaFiles...,
1095			))
1096		}
1097	}
1098
1099	return convertedMessages
1100}
1101
1102// buildSummaryPrompt constructs the prompt text for session summarization.
1103func buildSummaryPrompt(todos []session.Todo) string {
1104	var sb strings.Builder
1105	sb.WriteString("Provide a detailed summary of our conversation above.")
1106	if len(todos) > 0 {
1107		sb.WriteString("\n\n## Current Todo List\n\n")
1108		for _, t := range todos {
1109			fmt.Fprintf(&sb, "- [%s] %s\n", t.Status, t.Content)
1110		}
1111		sb.WriteString("\nInclude these tasks and their statuses in your summary. ")
1112		sb.WriteString("Instruct the resuming assistant to use the `todos` tool to continue tracking progress on these tasks.")
1113	}
1114	return sb.String()
1115}