agent.go

   1// Package agent is the core orchestration layer for Crush AI agents.
   2//
   3// It provides session-based AI agent functionality for managing
   4// conversations, tool execution, and message handling. It coordinates
   5// interactions between language models, messages, sessions, and tools while
   6// handling features like automatic summarization, queuing, and token
   7// management.
   8package agent
   9
  10import (
  11	"cmp"
  12	"context"
  13	_ "embed"
  14	"encoding/base64"
  15	"errors"
  16	"fmt"
  17	"log/slog"
  18	"os"
  19	"strconv"
  20	"strings"
  21	"sync"
  22	"time"
  23
  24	"charm.land/fantasy"
  25	"charm.land/fantasy/providers/anthropic"
  26	"charm.land/fantasy/providers/bedrock"
  27	"charm.land/fantasy/providers/google"
  28	"charm.land/fantasy/providers/openai"
  29	"charm.land/fantasy/providers/openrouter"
  30	"github.com/charmbracelet/catwalk/pkg/catwalk"
  31	"github.com/charmbracelet/crush/internal/agent/tools"
  32	"github.com/charmbracelet/crush/internal/config"
  33	"github.com/charmbracelet/crush/internal/csync"
  34	"github.com/charmbracelet/crush/internal/message"
  35	"github.com/charmbracelet/crush/internal/permission"
  36	"github.com/charmbracelet/crush/internal/session"
  37	"github.com/charmbracelet/crush/internal/stringext"
  38)
  39
  40//go:embed templates/title.md
  41var titlePrompt []byte
  42
  43//go:embed templates/summary.md
  44var summaryPrompt []byte
  45
  46type SessionAgentCall struct {
  47	SessionID        string
  48	Prompt           string
  49	ProviderOptions  fantasy.ProviderOptions
  50	Attachments      []message.Attachment
  51	MaxOutputTokens  int64
  52	Temperature      *float64
  53	TopP             *float64
  54	TopK             *int64
  55	FrequencyPenalty *float64
  56	PresencePenalty  *float64
  57}
  58
  59type SessionAgent interface {
  60	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
  61	SetModels(large Model, small Model)
  62	SetTools(tools []fantasy.AgentTool)
  63	Cancel(sessionID string)
  64	CancelAll()
  65	IsSessionBusy(sessionID string) bool
  66	IsBusy() bool
  67	QueuedPrompts(sessionID string) int
  68	QueuedPromptsList(sessionID string) []string
  69	ClearQueue(sessionID string)
  70	Summarize(context.Context, string, fantasy.ProviderOptions) error
  71	Model() Model
  72}
  73
  74type Model struct {
  75	Model      fantasy.LanguageModel
  76	CatwalkCfg catwalk.Model
  77	ModelCfg   config.SelectedModel
  78}
  79
  80type sessionAgent struct {
  81	largeModel           Model
  82	smallModel           Model
  83	systemPromptPrefix   string
  84	systemPrompt         string
  85	isSubAgent           bool
  86	tools                []fantasy.AgentTool
  87	sessions             session.Service
  88	messages             message.Service
  89	disableAutoSummarize bool
  90	isYolo               bool
  91
  92	messageQueue   *csync.Map[string, []SessionAgentCall]
  93	activeRequests *csync.Map[string, context.CancelFunc]
  94}
  95
  96type SessionAgentOptions struct {
  97	LargeModel           Model
  98	SmallModel           Model
  99	SystemPromptPrefix   string
 100	SystemPrompt         string
 101	IsSubAgent           bool
 102	DisableAutoSummarize bool
 103	IsYolo               bool
 104	Sessions             session.Service
 105	Messages             message.Service
 106	Tools                []fantasy.AgentTool
 107}
 108
 109func NewSessionAgent(
 110	opts SessionAgentOptions,
 111) SessionAgent {
 112	return &sessionAgent{
 113		largeModel:           opts.LargeModel,
 114		smallModel:           opts.SmallModel,
 115		systemPromptPrefix:   opts.SystemPromptPrefix,
 116		systemPrompt:         opts.SystemPrompt,
 117		isSubAgent:           opts.IsSubAgent,
 118		sessions:             opts.Sessions,
 119		messages:             opts.Messages,
 120		disableAutoSummarize: opts.DisableAutoSummarize,
 121		tools:                opts.Tools,
 122		isYolo:               opts.IsYolo,
 123		messageQueue:         csync.NewMap[string, []SessionAgentCall](),
 124		activeRequests:       csync.NewMap[string, context.CancelFunc](),
 125	}
 126}
 127
 128func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
 129	if call.Prompt == "" {
 130		return nil, ErrEmptyPrompt
 131	}
 132	if call.SessionID == "" {
 133		return nil, ErrSessionMissing
 134	}
 135
 136	// Queue the message if busy
 137	if a.IsSessionBusy(call.SessionID) {
 138		existing, ok := a.messageQueue.Get(call.SessionID)
 139		if !ok {
 140			existing = []SessionAgentCall{}
 141		}
 142		existing = append(existing, call)
 143		a.messageQueue.Set(call.SessionID, existing)
 144		return nil, nil
 145	}
 146
 147	if len(a.tools) > 0 {
 148		// Add Anthropic caching to the last tool.
 149		a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
 150	}
 151
 152	agent := fantasy.NewAgent(
 153		a.largeModel.Model,
 154		fantasy.WithSystemPrompt(a.systemPrompt),
 155		fantasy.WithTools(a.tools...),
 156	)
 157
 158	sessionLock := sync.Mutex{}
 159	currentSession, err := a.sessions.Get(ctx, call.SessionID)
 160	if err != nil {
 161		return nil, fmt.Errorf("failed to get session: %w", err)
 162	}
 163
 164	msgs, err := a.getSessionMessages(ctx, currentSession)
 165	if err != nil {
 166		return nil, fmt.Errorf("failed to get session messages: %w", err)
 167	}
 168
 169	var wg sync.WaitGroup
 170	// Generate title if first message.
 171	if len(msgs) == 0 {
 172		wg.Go(func() {
 173			sessionLock.Lock()
 174			a.generateTitle(ctx, &currentSession, call.Prompt)
 175			sessionLock.Unlock()
 176		})
 177	}
 178
 179	// Add the user message to the session.
 180	_, err = a.createUserMessage(ctx, call)
 181	if err != nil {
 182		return nil, err
 183	}
 184
 185	// Add the session to the context.
 186	ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)
 187
 188	genCtx, cancel := context.WithCancel(ctx)
 189	a.activeRequests.Set(call.SessionID, cancel)
 190
 191	defer cancel()
 192	defer a.activeRequests.Del(call.SessionID)
 193
 194	history, files := a.preparePrompt(msgs, call.Attachments...)
 195
 196	startTime := time.Now()
 197	a.eventPromptSent(call.SessionID)
 198
 199	var currentAssistant *message.Message
 200	var shouldSummarize bool
 201	result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
 202		Prompt:           call.Prompt,
 203		Files:            files,
 204		Messages:         history,
 205		ProviderOptions:  call.ProviderOptions,
 206		MaxOutputTokens:  &call.MaxOutputTokens,
 207		TopP:             call.TopP,
 208		Temperature:      call.Temperature,
 209		PresencePenalty:  call.PresencePenalty,
 210		TopK:             call.TopK,
 211		FrequencyPenalty: call.FrequencyPenalty,
 212		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
 213			prepared.Messages = options.Messages
 214			for i := range prepared.Messages {
 215				prepared.Messages[i].ProviderOptions = nil
 216			}
 217
 218			queuedCalls, _ := a.messageQueue.Get(call.SessionID)
 219			a.messageQueue.Del(call.SessionID)
 220			for _, queued := range queuedCalls {
 221				userMessage, createErr := a.createUserMessage(callContext, queued)
 222				if createErr != nil {
 223					return callContext, prepared, createErr
 224				}
 225				prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
 226			}
 227
 228			prepared.Messages = a.workaroundProviderMediaLimitations(prepared.Messages)
 229
 230			lastSystemRoleInx := 0
 231			systemMessageUpdated := false
 232			for i, msg := range prepared.Messages {
 233				// Only add cache control to the last message.
 234				if msg.Role == fantasy.MessageRoleSystem {
 235					lastSystemRoleInx = i
 236				} else if !systemMessageUpdated {
 237					prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
 238					systemMessageUpdated = true
 239				}
 240				// Than add cache control to the last 2 messages.
 241				if i > len(prepared.Messages)-3 {
 242					prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
 243				}
 244			}
 245
 246			if promptPrefix := a.promptPrefix(); promptPrefix != "" {
 247				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(promptPrefix)}, prepared.Messages...)
 248			}
 249
 250			var assistantMsg message.Message
 251			assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
 252				Role:     message.Assistant,
 253				Parts:    []message.ContentPart{},
 254				Model:    a.largeModel.ModelCfg.Model,
 255				Provider: a.largeModel.ModelCfg.Provider,
 256			})
 257			if err != nil {
 258				return callContext, prepared, err
 259			}
 260			callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
 261			callContext = context.WithValue(callContext, tools.SupportsImagesContextKey, a.largeModel.CatwalkCfg.SupportsImages)
 262			callContext = context.WithValue(callContext, tools.ModelNameContextKey, a.largeModel.CatwalkCfg.Name)
 263			currentAssistant = &assistantMsg
 264			return callContext, prepared, err
 265		},
 266		OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
 267			currentAssistant.AppendReasoningContent(reasoning.Text)
 268			return a.messages.Update(genCtx, *currentAssistant)
 269		},
 270		OnReasoningDelta: func(id string, text string) error {
 271			currentAssistant.AppendReasoningContent(text)
 272			return a.messages.Update(genCtx, *currentAssistant)
 273		},
 274		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
 275			// handle anthropic signature
 276			if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
 277				if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
 278					currentAssistant.AppendReasoningSignature(reasoning.Signature)
 279				}
 280			}
 281			if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
 282				if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
 283					currentAssistant.AppendThoughtSignature(reasoning.Signature, reasoning.ToolID)
 284				}
 285			}
 286			if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
 287				if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
 288					currentAssistant.SetReasoningResponsesData(reasoning)
 289				}
 290			}
 291			currentAssistant.FinishThinking()
 292			return a.messages.Update(genCtx, *currentAssistant)
 293		},
 294		OnTextDelta: func(id string, text string) error {
 295			// Strip leading newline from initial text content. This is is
 296			// particularly important in non-interactive mode where leading
 297			// newlines are very visible.
 298			if len(currentAssistant.Parts) == 0 {
 299				text = strings.TrimPrefix(text, "\n")
 300			}
 301
 302			currentAssistant.AppendContent(text)
 303			return a.messages.Update(genCtx, *currentAssistant)
 304		},
 305		OnToolInputStart: func(id string, toolName string) error {
 306			toolCall := message.ToolCall{
 307				ID:               id,
 308				Name:             toolName,
 309				ProviderExecuted: false,
 310				Finished:         false,
 311			}
 312			currentAssistant.AddToolCall(toolCall)
 313			return a.messages.Update(genCtx, *currentAssistant)
 314		},
 315		OnRetry: func(err *fantasy.ProviderError, delay time.Duration) {
 316			// TODO: implement
 317		},
 318		OnToolCall: func(tc fantasy.ToolCallContent) error {
 319			toolCall := message.ToolCall{
 320				ID:               tc.ToolCallID,
 321				Name:             tc.ToolName,
 322				Input:            tc.Input,
 323				ProviderExecuted: false,
 324				Finished:         true,
 325			}
 326			currentAssistant.AddToolCall(toolCall)
 327			return a.messages.Update(genCtx, *currentAssistant)
 328		},
 329		OnToolResult: func(result fantasy.ToolResultContent) error {
 330			toolResult := a.convertToToolResult(result)
 331			_, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
 332				Role: message.Tool,
 333				Parts: []message.ContentPart{
 334					toolResult,
 335				},
 336			})
 337			return createMsgErr
 338		},
 339		OnStepFinish: func(stepResult fantasy.StepResult) error {
 340			finishReason := message.FinishReasonUnknown
 341			switch stepResult.FinishReason {
 342			case fantasy.FinishReasonLength:
 343				finishReason = message.FinishReasonMaxTokens
 344			case fantasy.FinishReasonStop:
 345				finishReason = message.FinishReasonEndTurn
 346			case fantasy.FinishReasonToolCalls:
 347				finishReason = message.FinishReasonToolUse
 348			}
 349			currentAssistant.AddFinish(finishReason, "", "")
 350			sessionLock.Lock()
 351			updatedSession, getSessionErr := a.sessions.Get(genCtx, call.SessionID)
 352			if getSessionErr != nil {
 353				sessionLock.Unlock()
 354				return getSessionErr
 355			}
 356			a.updateSessionUsage(a.largeModel, &updatedSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
 357			_, sessionErr := a.sessions.Save(genCtx, updatedSession)
 358			sessionLock.Unlock()
 359			if sessionErr != nil {
 360				return sessionErr
 361			}
 362			return a.messages.Update(genCtx, *currentAssistant)
 363		},
 364		StopWhen: []fantasy.StopCondition{
 365			func(_ []fantasy.StepResult) bool {
 366				cw := int64(a.largeModel.CatwalkCfg.ContextWindow)
 367				tokens := currentSession.CompletionTokens + currentSession.PromptTokens
 368				remaining := cw - tokens
 369				var threshold int64
 370				if cw > 200_000 {
 371					threshold = 20_000
 372				} else {
 373					threshold = int64(float64(cw) * 0.2)
 374				}
 375				if (remaining <= threshold) && !a.disableAutoSummarize {
 376					shouldSummarize = true
 377					return true
 378				}
 379				return false
 380			},
 381		},
 382	})
 383
 384	a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))
 385
 386	if err != nil {
 387		isCancelErr := errors.Is(err, context.Canceled)
 388		isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
 389		if currentAssistant == nil {
 390			return result, err
 391		}
 392		// Ensure we finish thinking on error to close the reasoning state.
 393		currentAssistant.FinishThinking()
 394		toolCalls := currentAssistant.ToolCalls()
 395		// INFO: we use the parent context here because the genCtx has been cancelled.
 396		msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
 397		if createErr != nil {
 398			return nil, createErr
 399		}
 400		for _, tc := range toolCalls {
 401			if !tc.Finished {
 402				tc.Finished = true
 403				tc.Input = "{}"
 404				currentAssistant.AddToolCall(tc)
 405				updateErr := a.messages.Update(ctx, *currentAssistant)
 406				if updateErr != nil {
 407					return nil, updateErr
 408				}
 409			}
 410
 411			found := false
 412			for _, msg := range msgs {
 413				if msg.Role == message.Tool {
 414					for _, tr := range msg.ToolResults() {
 415						if tr.ToolCallID == tc.ID {
 416							found = true
 417							break
 418						}
 419					}
 420				}
 421				if found {
 422					break
 423				}
 424			}
 425			if found {
 426				continue
 427			}
 428			content := "There was an error while executing the tool"
 429			if isCancelErr {
 430				content = "Tool execution canceled by user"
 431			} else if isPermissionErr {
 432				content = "User denied permission"
 433			}
 434			toolResult := message.ToolResult{
 435				ToolCallID: tc.ID,
 436				Name:       tc.Name,
 437				Content:    content,
 438				IsError:    true,
 439			}
 440			_, createErr = a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
 441				Role: message.Tool,
 442				Parts: []message.ContentPart{
 443					toolResult,
 444				},
 445			})
 446			if createErr != nil {
 447				return nil, createErr
 448			}
 449		}
 450		var fantasyErr *fantasy.Error
 451		var providerErr *fantasy.ProviderError
 452		const defaultTitle = "Provider Error"
 453		if isCancelErr {
 454			currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
 455		} else if isPermissionErr {
 456			currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "")
 457		} else if errors.As(err, &providerErr) {
 458			currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message)
 459		} else if errors.As(err, &fantasyErr) {
 460			currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(fantasyErr.Title), defaultTitle), fantasyErr.Message)
 461		} else {
 462			currentAssistant.AddFinish(message.FinishReasonError, defaultTitle, err.Error())
 463		}
 464		// Note: we use the parent context here because the genCtx has been
 465		// cancelled.
 466		updateErr := a.messages.Update(ctx, *currentAssistant)
 467		if updateErr != nil {
 468			return nil, updateErr
 469		}
 470		return nil, err
 471	}
 472	wg.Wait()
 473
 474	if shouldSummarize {
 475		a.activeRequests.Del(call.SessionID)
 476		if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
 477			return nil, summarizeErr
 478		}
 479		// If the agent wasn't done...
 480		if len(currentAssistant.ToolCalls()) > 0 {
 481			existing, ok := a.messageQueue.Get(call.SessionID)
 482			if !ok {
 483				existing = []SessionAgentCall{}
 484			}
 485			call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
 486			existing = append(existing, call)
 487			a.messageQueue.Set(call.SessionID, existing)
 488		}
 489	}
 490
 491	// Release active request before processing queued messages.
 492	a.activeRequests.Del(call.SessionID)
 493	cancel()
 494
 495	queuedMessages, ok := a.messageQueue.Get(call.SessionID)
 496	if !ok || len(queuedMessages) == 0 {
 497		return result, err
 498	}
 499	// There are queued messages restart the loop.
 500	firstQueuedMessage := queuedMessages[0]
 501	a.messageQueue.Set(call.SessionID, queuedMessages[1:])
 502	return a.Run(ctx, firstQueuedMessage)
 503}
 504
 505func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
 506	if a.IsSessionBusy(sessionID) {
 507		return ErrSessionBusy
 508	}
 509
 510	currentSession, err := a.sessions.Get(ctx, sessionID)
 511	if err != nil {
 512		return fmt.Errorf("failed to get session: %w", err)
 513	}
 514	msgs, err := a.getSessionMessages(ctx, currentSession)
 515	if err != nil {
 516		return err
 517	}
 518	if len(msgs) == 0 {
 519		// Nothing to summarize.
 520		return nil
 521	}
 522
 523	aiMsgs, _ := a.preparePrompt(msgs)
 524
 525	genCtx, cancel := context.WithCancel(ctx)
 526	a.activeRequests.Set(sessionID, cancel)
 527	defer a.activeRequests.Del(sessionID)
 528	defer cancel()
 529
 530	agent := fantasy.NewAgent(a.largeModel.Model,
 531		fantasy.WithSystemPrompt(string(summaryPrompt)),
 532	)
 533	summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
 534		Role:             message.Assistant,
 535		Model:            a.largeModel.Model.Model(),
 536		Provider:         a.largeModel.Model.Provider(),
 537		IsSummaryMessage: true,
 538	})
 539	if err != nil {
 540		return err
 541	}
 542
 543	summaryPromptText := "Provide a detailed summary of our conversation above."
 544	if len(currentSession.Todos) > 0 {
 545		summaryPromptText += "\n\n## Current Todo List\n\n"
 546		for _, t := range currentSession.Todos {
 547			summaryPromptText += fmt.Sprintf("- [%s] %s\n", t.Status, t.Content)
 548		}
 549		summaryPromptText += "\nInclude these tasks and their statuses in your summary. "
 550		summaryPromptText += "Instruct the resuming assistant to use the `todos` tool to continue tracking progress on these tasks."
 551	}
 552
 553	resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
 554		Prompt:          summaryPromptText,
 555		Messages:        aiMsgs,
 556		ProviderOptions: opts,
 557		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
 558			prepared.Messages = options.Messages
 559			if a.systemPromptPrefix != "" {
 560				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
 561			}
 562			return callContext, prepared, nil
 563		},
 564		OnReasoningDelta: func(id string, text string) error {
 565			summaryMessage.AppendReasoningContent(text)
 566			return a.messages.Update(genCtx, summaryMessage)
 567		},
 568		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
 569			// Handle anthropic signature.
 570			if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
 571				if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
 572					summaryMessage.AppendReasoningSignature(signature.Signature)
 573				}
 574			}
 575			summaryMessage.FinishThinking()
 576			return a.messages.Update(genCtx, summaryMessage)
 577		},
 578		OnTextDelta: func(id, text string) error {
 579			summaryMessage.AppendContent(text)
 580			return a.messages.Update(genCtx, summaryMessage)
 581		},
 582	})
 583	if err != nil {
 584		isCancelErr := errors.Is(err, context.Canceled)
 585		if isCancelErr {
 586			// User cancelled summarize we need to remove the summary message.
 587			deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
 588			return deleteErr
 589		}
 590		return err
 591	}
 592
 593	summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
 594	err = a.messages.Update(genCtx, summaryMessage)
 595	if err != nil {
 596		return err
 597	}
 598
 599	var openrouterCost *float64
 600	for _, step := range resp.Steps {
 601		stepCost := a.openrouterCost(step.ProviderMetadata)
 602		if stepCost != nil {
 603			newCost := *stepCost
 604			if openrouterCost != nil {
 605				newCost += *openrouterCost
 606			}
 607			openrouterCost = &newCost
 608		}
 609	}
 610
 611	a.updateSessionUsage(a.largeModel, &currentSession, resp.TotalUsage, openrouterCost)
 612
 613	// Just in case, get just the last usage info.
 614	usage := resp.Response.Usage
 615	currentSession.SummaryMessageID = summaryMessage.ID
 616	currentSession.CompletionTokens = usage.OutputTokens
 617	currentSession.PromptTokens = 0
 618	_, err = a.sessions.Save(genCtx, currentSession)
 619	return err
 620}
 621
 622func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
 623	if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
 624		return fantasy.ProviderOptions{}
 625	}
 626	return fantasy.ProviderOptions{
 627		anthropic.Name: &anthropic.ProviderCacheControlOptions{
 628			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
 629		},
 630		bedrock.Name: &anthropic.ProviderCacheControlOptions{
 631			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
 632		},
 633	}
 634}
 635
 636func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
 637	var attachmentParts []message.ContentPart
 638	for _, attachment := range call.Attachments {
 639		attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
 640	}
 641	parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
 642	parts = append(parts, attachmentParts...)
 643	msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
 644		Role:  message.User,
 645		Parts: parts,
 646	})
 647	if err != nil {
 648		return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
 649	}
 650	return msg, nil
 651}
 652
 653func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
 654	var history []fantasy.Message
 655	if !a.isSubAgent {
 656		history = append(history, fantasy.NewUserMessage(
 657			fmt.Sprintf("<system_reminder>%s</system_reminder>",
 658				`This is a reminder that your todo list is currently empty. DO NOT mention this to the user explicitly because they are already aware.
 659If you are working on tasks that would benefit from a todo list please use the "todos" tool to create one.
 660If not, please feel free to ignore. Again do not mention this message to the user.`,
 661			),
 662		))
 663	}
 664	for _, m := range msgs {
 665		if len(m.Parts) == 0 {
 666			continue
 667		}
 668		// Assistant message without content or tool calls (cancelled before it
 669		// returned anything).
 670		if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
 671			continue
 672		}
 673		history = append(history, m.ToAIMessage()...)
 674	}
 675
 676	var files []fantasy.FilePart
 677	for _, attachment := range attachments {
 678		files = append(files, fantasy.FilePart{
 679			Filename:  attachment.FileName,
 680			Data:      attachment.Content,
 681			MediaType: attachment.MimeType,
 682		})
 683	}
 684
 685	return history, files
 686}
 687
 688func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
 689	msgs, err := a.messages.List(ctx, session.ID)
 690	if err != nil {
 691		return nil, fmt.Errorf("failed to list messages: %w", err)
 692	}
 693
 694	if session.SummaryMessageID != "" {
 695		summaryMsgInex := -1
 696		for i, msg := range msgs {
 697			if msg.ID == session.SummaryMessageID {
 698				summaryMsgInex = i
 699				break
 700			}
 701		}
 702		if summaryMsgInex != -1 {
 703			msgs = msgs[summaryMsgInex:]
 704			msgs[0].Role = message.User
 705		}
 706	}
 707	return msgs, nil
 708}
 709
 710func (a *sessionAgent) generateTitle(ctx context.Context, session *session.Session, prompt string) {
 711	if prompt == "" {
 712		return
 713	}
 714
 715	var maxOutput int64 = 40
 716	if a.smallModel.CatwalkCfg.CanReason {
 717		maxOutput = a.smallModel.CatwalkCfg.DefaultMaxTokens
 718	}
 719
 720	agent := fantasy.NewAgent(a.smallModel.Model,
 721		fantasy.WithSystemPrompt(string(titlePrompt)+"\n /no_think"),
 722		fantasy.WithMaxOutputTokens(maxOutput),
 723	)
 724
 725	resp, err := agent.Stream(ctx, fantasy.AgentStreamCall{
 726		Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", prompt),
 727		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
 728			prepared.Messages = options.Messages
 729			if a.systemPromptPrefix != "" {
 730				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
 731			}
 732			return callContext, prepared, nil
 733		},
 734	})
 735	if err != nil {
 736		slog.Error("error generating title", "err", err)
 737		return
 738	}
 739
 740	title := resp.Response.Content.Text()
 741
 742	title = strings.ReplaceAll(title, "\n", " ")
 743
 744	// Remove thinking tags if present.
 745	if idx := strings.Index(title, "</think>"); idx > 0 {
 746		title = title[idx+len("</think>"):]
 747	}
 748
 749	title = strings.TrimSpace(title)
 750	if title == "" {
 751		slog.Warn("failed to generate title", "warn", "empty title")
 752		return
 753	}
 754
 755	session.Title = title
 756
 757	var openrouterCost *float64
 758	for _, step := range resp.Steps {
 759		stepCost := a.openrouterCost(step.ProviderMetadata)
 760		if stepCost != nil {
 761			newCost := *stepCost
 762			if openrouterCost != nil {
 763				newCost += *openrouterCost
 764			}
 765			openrouterCost = &newCost
 766		}
 767	}
 768
 769	a.updateSessionUsage(a.smallModel, session, resp.TotalUsage, openrouterCost)
 770	_, saveErr := a.sessions.Save(ctx, *session)
 771	if saveErr != nil {
 772		slog.Error("failed to save session title & usage", "error", saveErr)
 773		return
 774	}
 775}
 776
 777func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
 778	openrouterMetadata, ok := metadata[openrouter.Name]
 779	if !ok {
 780		return nil
 781	}
 782
 783	opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
 784	if !ok {
 785		return nil
 786	}
 787	return &opts.Usage.Cost
 788}
 789
 790func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
 791	modelConfig := model.CatwalkCfg
 792	cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
 793		modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
 794		modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
 795		modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
 796
 797	if a.isClaudeCode() {
 798		cost = 0
 799	}
 800
 801	a.eventTokensUsed(session.ID, model, usage, cost)
 802
 803	if overrideCost != nil {
 804		session.Cost += *overrideCost
 805	} else {
 806		session.Cost += cost
 807	}
 808
 809	session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
 810	session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
 811}
 812
 813func (a *sessionAgent) Cancel(sessionID string) {
 814	// Cancel regular requests.
 815	if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
 816		slog.Info("Request cancellation initiated", "session_id", sessionID)
 817		cancel()
 818	}
 819
 820	// Also check for summarize requests.
 821	if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
 822		slog.Info("Summarize cancellation initiated", "session_id", sessionID)
 823		cancel()
 824	}
 825
 826	if a.QueuedPrompts(sessionID) > 0 {
 827		slog.Info("Clearing queued prompts", "session_id", sessionID)
 828		a.messageQueue.Del(sessionID)
 829	}
 830}
 831
 832func (a *sessionAgent) ClearQueue(sessionID string) {
 833	if a.QueuedPrompts(sessionID) > 0 {
 834		slog.Info("Clearing queued prompts", "session_id", sessionID)
 835		a.messageQueue.Del(sessionID)
 836	}
 837}
 838
 839func (a *sessionAgent) CancelAll() {
 840	if !a.IsBusy() {
 841		return
 842	}
 843	for key := range a.activeRequests.Seq2() {
 844		a.Cancel(key) // key is sessionID
 845	}
 846
 847	timeout := time.After(5 * time.Second)
 848	for a.IsBusy() {
 849		select {
 850		case <-timeout:
 851			return
 852		default:
 853			time.Sleep(200 * time.Millisecond)
 854		}
 855	}
 856}
 857
 858func (a *sessionAgent) IsBusy() bool {
 859	var busy bool
 860	for cancelFunc := range a.activeRequests.Seq() {
 861		if cancelFunc != nil {
 862			busy = true
 863			break
 864		}
 865	}
 866	return busy
 867}
 868
 869func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
 870	_, busy := a.activeRequests.Get(sessionID)
 871	return busy
 872}
 873
 874func (a *sessionAgent) QueuedPrompts(sessionID string) int {
 875	l, ok := a.messageQueue.Get(sessionID)
 876	if !ok {
 877		return 0
 878	}
 879	return len(l)
 880}
 881
 882func (a *sessionAgent) QueuedPromptsList(sessionID string) []string {
 883	l, ok := a.messageQueue.Get(sessionID)
 884	if !ok {
 885		return nil
 886	}
 887	prompts := make([]string, len(l))
 888	for i, call := range l {
 889		prompts[i] = call.Prompt
 890	}
 891	return prompts
 892}
 893
 894func (a *sessionAgent) SetModels(large Model, small Model) {
 895	a.largeModel = large
 896	a.smallModel = small
 897}
 898
 899func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
 900	a.tools = tools
 901}
 902
 903func (a *sessionAgent) Model() Model {
 904	return a.largeModel
 905}
 906
 907func (a *sessionAgent) promptPrefix() string {
 908	if a.isClaudeCode() {
 909		return "You are Claude Code, Anthropic's official CLI for Claude."
 910	}
 911	return a.systemPromptPrefix
 912}
 913
 914func (a *sessionAgent) isClaudeCode() bool {
 915	cfg := config.Get()
 916	pc, ok := cfg.Providers.Get(a.largeModel.ModelCfg.Provider)
 917	return ok && pc.ID == string(catwalk.InferenceProviderAnthropic) && pc.OAuthToken != nil
 918}
 919
 920// convertToToolResult converts a fantasy tool result to a message tool result.
 921func (a *sessionAgent) convertToToolResult(result fantasy.ToolResultContent) message.ToolResult {
 922	baseResult := message.ToolResult{
 923		ToolCallID: result.ToolCallID,
 924		Name:       result.ToolName,
 925		Metadata:   result.ClientMetadata,
 926	}
 927
 928	switch result.Result.GetType() {
 929	case fantasy.ToolResultContentTypeText:
 930		if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result); ok {
 931			baseResult.Content = r.Text
 932		}
 933	case fantasy.ToolResultContentTypeError:
 934		if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result); ok {
 935			baseResult.Content = r.Error.Error()
 936			baseResult.IsError = true
 937		}
 938	case fantasy.ToolResultContentTypeMedia:
 939		if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](result.Result); ok {
 940			content := r.Text
 941			if content == "" {
 942				content = fmt.Sprintf("Loaded %s content", r.MediaType)
 943			}
 944			baseResult.Content = content
 945			baseResult.Data = r.Data
 946			baseResult.MIMEType = r.MediaType
 947		}
 948	}
 949
 950	return baseResult
 951}
 952
 953// workaroundProviderMediaLimitations converts media content in tool results to
 954// user messages for providers that don't natively support images in tool results.
 955//
 956// Problem: OpenAI, Google, OpenRouter, and other OpenAI-compatible providers
 957// don't support sending images/media in tool result messages - they only accept
 958// text in tool results. However, they DO support images in user messages.
 959//
 960// If we send media in tool results to these providers, the API returns an error.
 961//
 962// Solution: For these providers, we:
 963//  1. Replace the media in the tool result with a text placeholder
 964//  2. Inject a user message immediately after with the image as a file attachment
 965//  3. This maintains the tool execution flow while working around API limitations
 966//
 967// Anthropic and Bedrock support images natively in tool results, so we skip
 968// this workaround for them.
 969//
 970// Example transformation:
 971//
 972//	BEFORE: [tool result: image data]
 973//	AFTER:  [tool result: "Image loaded - see attached"], [user: image attachment]
 974func (a *sessionAgent) workaroundProviderMediaLimitations(messages []fantasy.Message) []fantasy.Message {
 975	providerSupportsMedia := a.largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderAnthropic) ||
 976		a.largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderBedrock)
 977
 978	if providerSupportsMedia {
 979		return messages
 980	}
 981
 982	convertedMessages := make([]fantasy.Message, 0, len(messages))
 983
 984	for _, msg := range messages {
 985		if msg.Role != fantasy.MessageRoleTool {
 986			convertedMessages = append(convertedMessages, msg)
 987			continue
 988		}
 989
 990		textParts := make([]fantasy.MessagePart, 0, len(msg.Content))
 991		var mediaFiles []fantasy.FilePart
 992
 993		for _, part := range msg.Content {
 994			toolResult, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
 995			if !ok {
 996				textParts = append(textParts, part)
 997				continue
 998			}
 999
1000			if media, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](toolResult.Output); ok {
1001				decoded, err := base64.StdEncoding.DecodeString(media.Data)
1002				if err != nil {
1003					slog.Warn("failed to decode media data", "error", err)
1004					textParts = append(textParts, part)
1005					continue
1006				}
1007
1008				mediaFiles = append(mediaFiles, fantasy.FilePart{
1009					Data:      decoded,
1010					MediaType: media.MediaType,
1011					Filename:  fmt.Sprintf("tool-result-%s", toolResult.ToolCallID),
1012				})
1013
1014				textParts = append(textParts, fantasy.ToolResultPart{
1015					ToolCallID: toolResult.ToolCallID,
1016					Output: fantasy.ToolResultOutputContentText{
1017						Text: "[Image/media content loaded - see attached file]",
1018					},
1019					ProviderOptions: toolResult.ProviderOptions,
1020				})
1021			} else {
1022				textParts = append(textParts, part)
1023			}
1024		}
1025
1026		convertedMessages = append(convertedMessages, fantasy.Message{
1027			Role:    fantasy.MessageRoleTool,
1028			Content: textParts,
1029		})
1030
1031		if len(mediaFiles) > 0 {
1032			convertedMessages = append(convertedMessages, fantasy.NewUserMessage(
1033				"Here is the media content from the tool result:",
1034				mediaFiles...,
1035			))
1036		}
1037	}
1038
1039	return convertedMessages
1040}