agent.go

   1// Package agent is the core orchestration layer for Crush AI agents.
   2//
   3// It provides session-based AI agent functionality for managing
   4// conversations, tool execution, and message handling. It coordinates
   5// interactions between language models, messages, sessions, and tools while
   6// handling features like automatic summarization, queuing, and token
   7// management.
   8package agent
   9
  10import (
  11	"cmp"
  12	"context"
  13	_ "embed"
  14	"encoding/base64"
  15	"errors"
  16	"fmt"
  17	"log/slog"
  18	"os"
  19	"regexp"
  20	"strconv"
  21	"strings"
  22	"sync"
  23	"time"
  24
  25	"charm.land/fantasy"
  26	"charm.land/fantasy/providers/anthropic"
  27	"charm.land/fantasy/providers/bedrock"
  28	"charm.land/fantasy/providers/google"
  29	"charm.land/fantasy/providers/openai"
  30	"charm.land/fantasy/providers/openrouter"
  31	"charm.land/lipgloss/v2"
  32	"git.secluded.site/crush/internal/agent/hyper"
  33	"git.secluded.site/crush/internal/agent/tools"
  34	"git.secluded.site/crush/internal/config"
  35	"git.secluded.site/crush/internal/csync"
  36	"git.secluded.site/crush/internal/message"
  37	"git.secluded.site/crush/internal/notification"
  38	"git.secluded.site/crush/internal/permission"
  39	"git.secluded.site/crush/internal/session"
  40	"git.secluded.site/crush/internal/stringext"
  41	"github.com/charmbracelet/catwalk/pkg/catwalk"
  42)
  43
// defaultSessionName is the fallback title used when title generation
// produces an empty string.
const defaultSessionName = "Untitled Session"

// titlePrompt is the system prompt used when generating session titles.
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt is the system prompt used when summarizing a conversation.
//
//go:embed templates/summary.md
var summaryPrompt []byte

// thinkTagRegex removes <think>...</think> blocks from generated titles.
// `.` does not match newlines, so titles are flattened to a single line
// before this is applied.
var thinkTagRegex = regexp.MustCompile(`<think>.*?</think>`)
  54
// SessionAgentCall describes a single prompt submitted to a SessionAgent.
//
// SessionID and Prompt are required; Run rejects empty values with
// ErrSessionMissing / ErrEmptyPrompt. Attachments are stored with the user
// message; non-text attachments are also sent to the model as files.
// ProviderOptions and the sampling fields (Temperature, TopP, TopK,
// FrequencyPenalty, PresencePenalty) are forwarded to the model unchanged.
// MaxOutputTokens is always forwarded, even when zero — NOTE(review):
// confirm providers treat 0 as "unset". NonInteractive suppresses the
// desktop notification sent when a turn completes.
type SessionAgentCall struct {
	SessionID        string
	Prompt           string
	ProviderOptions  fantasy.ProviderOptions
	Attachments      []message.Attachment
	MaxOutputTokens  int64
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
	NonInteractive   bool
}
  68
// SessionAgent runs prompts against language models on behalf of chat
// sessions, queueing prompts for sessions that already have a request in
// flight.
type SessionAgent interface {
	// Run processes a prompt for the call's session. It returns (nil, nil)
	// when the prompt was queued because the session is busy.
	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
	// SetModels sets the large and small models (implementation not shown in
	// this chunk).
	SetModels(large Model, small Model)
	// SetTools sets the tools offered to the model (implementation not shown
	// in this chunk).
	SetTools(tools []fantasy.AgentTool)
	// Cancel cancels the session's in-flight request and/or summarization and
	// drops its queued prompts.
	Cancel(sessionID string)
	// CancelAll cancels every in-flight request and waits up to a few seconds
	// for the agent to become idle.
	CancelAll()
	// IsSessionBusy reports whether the session is busy; Run queues new
	// prompts while it returns true.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any session is busy.
	IsBusy() bool
	// QueuedPrompts presumably returns the number of prompts waiting in the
	// session's queue — verify against the implementation.
	QueuedPrompts(sessionID string) int
	// QueuedPromptsList presumably returns the text of the queued prompts.
	QueuedPromptsList(sessionID string) []string
	// ClearQueue drops all queued prompts for the session.
	ClearQueue(sessionID string)
	// Summarize replaces the session's history with a model-written summary.
	Summarize(context.Context, string, fantasy.ProviderOptions) error
	// Model presumably returns the agent's primary (large) model.
	Model() Model
}
  83
// Model bundles a language model with the metadata the agent needs about it.
//
// Model is the client used for generation. CatwalkCfg is catalog metadata:
// context window, per-token costs, and capabilities such as CanReason and
// SupportsImages. ModelCfg is the user's selection, whose Model and Provider
// IDs are recorded on assistant messages.
type Model struct {
	Model      fantasy.LanguageModel
	CatwalkCfg catwalk.Model
	ModelCfg   config.SelectedModel
}
  89
// sessionAgent is the default SessionAgent implementation.
//
// largeModel drives conversations and summaries; smallModel is tried first
// for title generation. systemPromptPrefix is prepended as an extra system
// message for summaries and titles; Run uses a.promptPrefix(). isSubAgent
// suppresses the todo-list reminder in preparePrompt. messageQueue holds
// prompts submitted while a session was busy. activeRequests maps a
// session ID to the cancel func of its in-flight request.
type sessionAgent struct {
	largeModel           Model
	smallModel           Model
	systemPromptPrefix   string
	systemPrompt         string
	isSubAgent           bool
	tools                []fantasy.AgentTool
	sessions             session.Service
	messages             message.Service
	disableAutoSummarize bool
	isYolo               bool

	messageQueue   *csync.Map[string, []SessionAgentCall]
	activeRequests *csync.Map[string, context.CancelFunc]
}
 105
// SessionAgentOptions configures NewSessionAgent. Each field is copied into
// the matching field of the returned agent.
type SessionAgentOptions struct {
	LargeModel           Model
	SmallModel           Model
	SystemPromptPrefix   string
	SystemPrompt         string
	IsSubAgent           bool
	DisableAutoSummarize bool
	IsYolo               bool
	Sessions             session.Service
	Messages             message.Service
	Tools                []fantasy.AgentTool
}
 118
 119func NewSessionAgent(
 120	opts SessionAgentOptions,
 121) SessionAgent {
 122	return &sessionAgent{
 123		largeModel:           opts.LargeModel,
 124		smallModel:           opts.SmallModel,
 125		systemPromptPrefix:   opts.SystemPromptPrefix,
 126		systemPrompt:         opts.SystemPrompt,
 127		isSubAgent:           opts.IsSubAgent,
 128		sessions:             opts.Sessions,
 129		messages:             opts.Messages,
 130		disableAutoSummarize: opts.DisableAutoSummarize,
 131		tools:                opts.Tools,
 132		isYolo:               opts.IsYolo,
 133		messageQueue:         csync.NewMap[string, []SessionAgentCall](),
 134		activeRequests:       csync.NewMap[string, context.CancelFunc](),
 135	}
 136}
 137
 138func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
 139	if call.Prompt == "" {
 140		return nil, ErrEmptyPrompt
 141	}
 142	if call.SessionID == "" {
 143		return nil, ErrSessionMissing
 144	}
 145
 146	// Queue the message if busy
 147	if a.IsSessionBusy(call.SessionID) {
 148		existing, ok := a.messageQueue.Get(call.SessionID)
 149		if !ok {
 150			existing = []SessionAgentCall{}
 151		}
 152		existing = append(existing, call)
 153		a.messageQueue.Set(call.SessionID, existing)
 154		return nil, nil
 155	}
 156
 157	if len(a.tools) > 0 {
 158		// Add Anthropic caching to the last tool.
 159		a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
 160	}
 161
 162	agent := fantasy.NewAgent(
 163		a.largeModel.Model,
 164		fantasy.WithSystemPrompt(a.systemPrompt),
 165		fantasy.WithTools(a.tools...),
 166	)
 167
 168	sessionLock := sync.Mutex{}
 169	currentSession, err := a.sessions.Get(ctx, call.SessionID)
 170	if err != nil {
 171		return nil, fmt.Errorf("failed to get session: %w", err)
 172	}
 173
 174	msgs, err := a.getSessionMessages(ctx, currentSession)
 175	if err != nil {
 176		return nil, fmt.Errorf("failed to get session messages: %w", err)
 177	}
 178
 179	var wg sync.WaitGroup
 180	// Generate title if first message.
 181	if len(msgs) == 0 {
 182		titleCtx := ctx // Copy to avoid race with ctx reassignment below.
 183		wg.Go(func() {
 184			a.generateTitle(titleCtx, call.SessionID, call.Prompt)
 185		})
 186	}
 187
 188	// Add the user message to the session.
 189	_, err = a.createUserMessage(ctx, call)
 190	if err != nil {
 191		return nil, err
 192	}
 193
 194	// Add the session to the context.
 195	ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)
 196
 197	genCtx, cancel := context.WithCancel(ctx)
 198	a.activeRequests.Set(call.SessionID, cancel)
 199
 200	defer cancel()
 201	defer a.activeRequests.Del(call.SessionID)
 202
 203	history, files := a.preparePrompt(msgs, call.Attachments...)
 204
 205	startTime := time.Now()
 206	a.eventPromptSent(call.SessionID)
 207
 208	var currentAssistant *message.Message
 209	var shouldSummarize bool
 210	result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
 211		Prompt:           message.PromptWithTextAttachments(call.Prompt, call.Attachments),
 212		Files:            files,
 213		Messages:         history,
 214		ProviderOptions:  call.ProviderOptions,
 215		MaxOutputTokens:  &call.MaxOutputTokens,
 216		TopP:             call.TopP,
 217		Temperature:      call.Temperature,
 218		PresencePenalty:  call.PresencePenalty,
 219		TopK:             call.TopK,
 220		FrequencyPenalty: call.FrequencyPenalty,
 221		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
 222			prepared.Messages = options.Messages
 223			for i := range prepared.Messages {
 224				prepared.Messages[i].ProviderOptions = nil
 225			}
 226
 227			queuedCalls, _ := a.messageQueue.Get(call.SessionID)
 228			a.messageQueue.Del(call.SessionID)
 229			for _, queued := range queuedCalls {
 230				userMessage, createErr := a.createUserMessage(callContext, queued)
 231				if createErr != nil {
 232					return callContext, prepared, createErr
 233				}
 234				prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
 235			}
 236
 237			prepared.Messages = a.workaroundProviderMediaLimitations(prepared.Messages)
 238
 239			lastSystemRoleInx := 0
 240			systemMessageUpdated := false
 241			for i, msg := range prepared.Messages {
 242				// Only add cache control to the last message.
 243				if msg.Role == fantasy.MessageRoleSystem {
 244					lastSystemRoleInx = i
 245				} else if !systemMessageUpdated {
 246					prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
 247					systemMessageUpdated = true
 248				}
 249				// Than add cache control to the last 2 messages.
 250				if i > len(prepared.Messages)-3 {
 251					prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
 252				}
 253			}
 254
 255			if promptPrefix := a.promptPrefix(); promptPrefix != "" {
 256				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(promptPrefix)}, prepared.Messages...)
 257			}
 258
 259			var assistantMsg message.Message
 260			assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
 261				Role:     message.Assistant,
 262				Parts:    []message.ContentPart{},
 263				Model:    a.largeModel.ModelCfg.Model,
 264				Provider: a.largeModel.ModelCfg.Provider,
 265			})
 266			if err != nil {
 267				return callContext, prepared, err
 268			}
 269			callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
 270			callContext = context.WithValue(callContext, tools.SupportsImagesContextKey, a.largeModel.CatwalkCfg.SupportsImages)
 271			callContext = context.WithValue(callContext, tools.ModelNameContextKey, a.largeModel.CatwalkCfg.Name)
 272			currentAssistant = &assistantMsg
 273			return callContext, prepared, err
 274		},
 275		OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
 276			currentAssistant.AppendReasoningContent(reasoning.Text)
 277			return a.messages.Update(genCtx, *currentAssistant)
 278		},
 279		OnReasoningDelta: func(id string, text string) error {
 280			currentAssistant.AppendReasoningContent(text)
 281			return a.messages.Update(genCtx, *currentAssistant)
 282		},
 283		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
 284			// handle anthropic signature
 285			if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
 286				if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
 287					currentAssistant.AppendReasoningSignature(reasoning.Signature)
 288				}
 289			}
 290			if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
 291				if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
 292					currentAssistant.AppendThoughtSignature(reasoning.Signature, reasoning.ToolID)
 293				}
 294			}
 295			if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
 296				if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
 297					currentAssistant.SetReasoningResponsesData(reasoning)
 298				}
 299			}
 300			currentAssistant.FinishThinking()
 301			return a.messages.Update(genCtx, *currentAssistant)
 302		},
 303		OnTextDelta: func(id string, text string) error {
 304			// Strip leading newline from initial text content. This is is
 305			// particularly important in non-interactive mode where leading
 306			// newlines are very visible.
 307			if len(currentAssistant.Parts) == 0 {
 308				text = strings.TrimPrefix(text, "\n")
 309			}
 310
 311			currentAssistant.AppendContent(text)
 312			return a.messages.Update(genCtx, *currentAssistant)
 313		},
 314		OnToolInputStart: func(id string, toolName string) error {
 315			toolCall := message.ToolCall{
 316				ID:               id,
 317				Name:             toolName,
 318				ProviderExecuted: false,
 319				Finished:         false,
 320			}
 321			currentAssistant.AddToolCall(toolCall)
 322			return a.messages.Update(genCtx, *currentAssistant)
 323		},
 324		OnRetry: func(err *fantasy.ProviderError, delay time.Duration) {
 325			// TODO: implement
 326		},
 327		OnToolCall: func(tc fantasy.ToolCallContent) error {
 328			toolCall := message.ToolCall{
 329				ID:               tc.ToolCallID,
 330				Name:             tc.ToolName,
 331				Input:            tc.Input,
 332				ProviderExecuted: false,
 333				Finished:         true,
 334			}
 335			currentAssistant.AddToolCall(toolCall)
 336			return a.messages.Update(genCtx, *currentAssistant)
 337		},
 338		OnToolResult: func(result fantasy.ToolResultContent) error {
 339			toolResult := a.convertToToolResult(result)
 340			_, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
 341				Role: message.Tool,
 342				Parts: []message.ContentPart{
 343					toolResult,
 344				},
 345			})
 346			return createMsgErr
 347		},
 348		OnStepFinish: func(stepResult fantasy.StepResult) error {
 349			finishReason := message.FinishReasonUnknown
 350			switch stepResult.FinishReason {
 351			case fantasy.FinishReasonLength:
 352				finishReason = message.FinishReasonMaxTokens
 353			case fantasy.FinishReasonStop:
 354				finishReason = message.FinishReasonEndTurn
 355			case fantasy.FinishReasonToolCalls:
 356				finishReason = message.FinishReasonToolUse
 357			}
 358			currentAssistant.AddFinish(finishReason, "", "")
 359			sessionLock.Lock()
 360			updatedSession, getSessionErr := a.sessions.Get(genCtx, call.SessionID)
 361			if getSessionErr != nil {
 362				sessionLock.Unlock()
 363				return getSessionErr
 364			}
 365			a.updateSessionUsage(a.largeModel, &updatedSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
 366			_, sessionErr := a.sessions.Save(genCtx, updatedSession)
 367			sessionLock.Unlock()
 368			if sessionErr != nil {
 369				return sessionErr
 370			}
 371			return a.messages.Update(genCtx, *currentAssistant)
 372		},
 373		StopWhen: []fantasy.StopCondition{
 374			func(_ []fantasy.StepResult) bool {
 375				cw := int64(a.largeModel.CatwalkCfg.ContextWindow)
 376				tokens := currentSession.CompletionTokens + currentSession.PromptTokens
 377				remaining := cw - tokens
 378				var threshold int64
 379				if cw > 200_000 {
 380					threshold = 20_000
 381				} else {
 382					threshold = int64(float64(cw) * 0.2)
 383				}
 384				if (remaining <= threshold) && !a.disableAutoSummarize {
 385					shouldSummarize = true
 386					return true
 387				}
 388				return false
 389			},
 390		},
 391	})
 392
 393	a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))
 394
 395	if err != nil {
 396		isCancelErr := errors.Is(err, context.Canceled)
 397		isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
 398		if currentAssistant == nil {
 399			return result, err
 400		}
 401		// Ensure we finish thinking on error to close the reasoning state.
 402		currentAssistant.FinishThinking()
 403		toolCalls := currentAssistant.ToolCalls()
 404		// INFO: we use the parent context here because the genCtx has been cancelled.
 405		msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
 406		if createErr != nil {
 407			return nil, createErr
 408		}
 409		for _, tc := range toolCalls {
 410			if !tc.Finished {
 411				tc.Finished = true
 412				tc.Input = "{}"
 413				currentAssistant.AddToolCall(tc)
 414				updateErr := a.messages.Update(ctx, *currentAssistant)
 415				if updateErr != nil {
 416					return nil, updateErr
 417				}
 418			}
 419
 420			found := false
 421			for _, msg := range msgs {
 422				if msg.Role == message.Tool {
 423					for _, tr := range msg.ToolResults() {
 424						if tr.ToolCallID == tc.ID {
 425							found = true
 426							break
 427						}
 428					}
 429				}
 430				if found {
 431					break
 432				}
 433			}
 434			if found {
 435				continue
 436			}
 437			content := "There was an error while executing the tool"
 438			if isCancelErr {
 439				content = "Tool execution canceled by user"
 440			} else if isPermissionErr {
 441				content = "User denied permission"
 442			}
 443			toolResult := message.ToolResult{
 444				ToolCallID: tc.ID,
 445				Name:       tc.Name,
 446				Content:    content,
 447				IsError:    true,
 448			}
 449			_, createErr = a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
 450				Role: message.Tool,
 451				Parts: []message.ContentPart{
 452					toolResult,
 453				},
 454			})
 455			if createErr != nil {
 456				return nil, createErr
 457			}
 458		}
 459		var fantasyErr *fantasy.Error
 460		var providerErr *fantasy.ProviderError
 461		const defaultTitle = "Provider Error"
 462		if isCancelErr {
 463			currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
 464		} else if isPermissionErr {
 465			currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "")
 466		} else if errors.Is(err, hyper.ErrNoCredits) {
 467			url := hyper.BaseURL()
 468			link := lipgloss.NewStyle().Hyperlink(url, "id=hyper").Render(url)
 469			currentAssistant.AddFinish(message.FinishReasonError, "No credits", "You're out of credits. Add more at "+link)
 470		} else if errors.As(err, &providerErr) {
 471			if providerErr.Message == "The requested model is not supported." {
 472				url := "https://github.com/settings/copilot/features"
 473				link := lipgloss.NewStyle().Hyperlink(url, "id=hyper").Render(url)
 474				currentAssistant.AddFinish(
 475					message.FinishReasonError,
 476					"Copilot model not enabled",
 477					fmt.Sprintf("%q is not enabled in Copilot. Go to the following page to enable it. Then, wait a minute before trying again. %s", a.largeModel.CatwalkCfg.Name, link),
 478				)
 479			} else {
 480				currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message)
 481			}
 482		} else if errors.As(err, &fantasyErr) {
 483			currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(fantasyErr.Title), defaultTitle), fantasyErr.Message)
 484		} else {
 485			currentAssistant.AddFinish(message.FinishReasonError, defaultTitle, err.Error())
 486		}
 487		// Note: we use the parent context here because the genCtx has been
 488		// cancelled.
 489		updateErr := a.messages.Update(ctx, *currentAssistant)
 490		if updateErr != nil {
 491			return nil, updateErr
 492		}
 493		return nil, err
 494	}
 495	wg.Wait()
 496
 497	// Send notification that agent has finished its turn (skip for nested/non-interactive sessions).
 498	if !call.NonInteractive {
 499		notifBody := fmt.Sprintf("Agent's turn completed in \"%s\"", currentSession.Title)
 500		_ = notification.Send("Crush is waiting...", notifBody)
 501	}
 502
 503	if shouldSummarize {
 504		a.activeRequests.Del(call.SessionID)
 505		if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
 506			return nil, summarizeErr
 507		}
 508		// If the agent wasn't done...
 509		if len(currentAssistant.ToolCalls()) > 0 {
 510			existing, ok := a.messageQueue.Get(call.SessionID)
 511			if !ok {
 512				existing = []SessionAgentCall{}
 513			}
 514			call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
 515			existing = append(existing, call)
 516			a.messageQueue.Set(call.SessionID, existing)
 517		}
 518	}
 519
 520	// Release active request before processing queued messages.
 521	a.activeRequests.Del(call.SessionID)
 522	cancel()
 523
 524	queuedMessages, ok := a.messageQueue.Get(call.SessionID)
 525	if !ok || len(queuedMessages) == 0 {
 526		return result, err
 527	}
 528	// There are queued messages restart the loop.
 529	firstQueuedMessage := queuedMessages[0]
 530	a.messageQueue.Set(call.SessionID, queuedMessages[1:])
 531	return a.Run(ctx, firstQueuedMessage)
 532}
 533
 534func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
 535	if a.IsSessionBusy(sessionID) {
 536		return ErrSessionBusy
 537	}
 538
 539	currentSession, err := a.sessions.Get(ctx, sessionID)
 540	if err != nil {
 541		return fmt.Errorf("failed to get session: %w", err)
 542	}
 543	msgs, err := a.getSessionMessages(ctx, currentSession)
 544	if err != nil {
 545		return err
 546	}
 547	if len(msgs) == 0 {
 548		// Nothing to summarize.
 549		return nil
 550	}
 551
 552	aiMsgs, _ := a.preparePrompt(msgs)
 553
 554	genCtx, cancel := context.WithCancel(ctx)
 555	a.activeRequests.Set(sessionID, cancel)
 556	defer a.activeRequests.Del(sessionID)
 557	defer cancel()
 558
 559	agent := fantasy.NewAgent(a.largeModel.Model,
 560		fantasy.WithSystemPrompt(string(summaryPrompt)),
 561	)
 562	summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
 563		Role:             message.Assistant,
 564		Model:            a.largeModel.Model.Model(),
 565		Provider:         a.largeModel.Model.Provider(),
 566		IsSummaryMessage: true,
 567	})
 568	if err != nil {
 569		return err
 570	}
 571
 572	summaryPromptText := "Provide a detailed summary of our conversation above."
 573	if len(currentSession.Todos) > 0 {
 574		summaryPromptText += "\n\n## Current Todo List\n\n"
 575		for _, t := range currentSession.Todos {
 576			summaryPromptText += fmt.Sprintf("- [%s] %s\n", t.Status, t.Content)
 577		}
 578		summaryPromptText += "\nInclude these tasks and their statuses in your summary. "
 579		summaryPromptText += "Instruct the resuming assistant to use the `todos` tool to continue tracking progress on these tasks."
 580	}
 581
 582	resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
 583		Prompt:          summaryPromptText,
 584		Messages:        aiMsgs,
 585		ProviderOptions: opts,
 586		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
 587			prepared.Messages = options.Messages
 588			if a.systemPromptPrefix != "" {
 589				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
 590			}
 591			return callContext, prepared, nil
 592		},
 593		OnReasoningDelta: func(id string, text string) error {
 594			summaryMessage.AppendReasoningContent(text)
 595			return a.messages.Update(genCtx, summaryMessage)
 596		},
 597		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
 598			// Handle anthropic signature.
 599			if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
 600				if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
 601					summaryMessage.AppendReasoningSignature(signature.Signature)
 602				}
 603			}
 604			summaryMessage.FinishThinking()
 605			return a.messages.Update(genCtx, summaryMessage)
 606		},
 607		OnTextDelta: func(id, text string) error {
 608			summaryMessage.AppendContent(text)
 609			return a.messages.Update(genCtx, summaryMessage)
 610		},
 611	})
 612	if err != nil {
 613		isCancelErr := errors.Is(err, context.Canceled)
 614		if isCancelErr {
 615			// User cancelled summarize we need to remove the summary message.
 616			deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
 617			return deleteErr
 618		}
 619		return err
 620	}
 621
 622	summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
 623	err = a.messages.Update(genCtx, summaryMessage)
 624	if err != nil {
 625		return err
 626	}
 627
 628	var openrouterCost *float64
 629	for _, step := range resp.Steps {
 630		stepCost := a.openrouterCost(step.ProviderMetadata)
 631		if stepCost != nil {
 632			newCost := *stepCost
 633			if openrouterCost != nil {
 634				newCost += *openrouterCost
 635			}
 636			openrouterCost = &newCost
 637		}
 638	}
 639
 640	a.updateSessionUsage(a.largeModel, &currentSession, resp.TotalUsage, openrouterCost)
 641
 642	// Just in case, get just the last usage info.
 643	usage := resp.Response.Usage
 644	currentSession.SummaryMessageID = summaryMessage.ID
 645	currentSession.CompletionTokens = usage.OutputTokens
 646	currentSession.PromptTokens = 0
 647	_, err = a.sessions.Save(genCtx, currentSession)
 648	return err
 649}
 650
 651func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
 652	if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
 653		return fantasy.ProviderOptions{}
 654	}
 655	return fantasy.ProviderOptions{
 656		anthropic.Name: &anthropic.ProviderCacheControlOptions{
 657			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
 658		},
 659		bedrock.Name: &anthropic.ProviderCacheControlOptions{
 660			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
 661		},
 662	}
 663}
 664
 665func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
 666	parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
 667	var attachmentParts []message.ContentPart
 668	for _, attachment := range call.Attachments {
 669		attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
 670	}
 671	parts = append(parts, attachmentParts...)
 672	msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
 673		Role:  message.User,
 674		Parts: parts,
 675	})
 676	if err != nil {
 677		return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
 678	}
 679	return msg, nil
 680}
 681
 682func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
 683	var history []fantasy.Message
 684	if !a.isSubAgent {
 685		history = append(history, fantasy.NewUserMessage(
 686			fmt.Sprintf("<system_reminder>%s</system_reminder>",
 687				`This is a reminder that your todo list is currently empty. DO NOT mention this to the user explicitly because they are already aware.
 688If you are working on tasks that would benefit from a todo list please use the "todos" tool to create one.
 689If not, please feel free to ignore. Again do not mention this message to the user.`,
 690			),
 691		))
 692	}
 693	for _, m := range msgs {
 694		if len(m.Parts) == 0 {
 695			continue
 696		}
 697		// Assistant message without content or tool calls (cancelled before it
 698		// returned anything).
 699		if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
 700			continue
 701		}
 702		history = append(history, m.ToAIMessage()...)
 703	}
 704
 705	var files []fantasy.FilePart
 706	for _, attachment := range attachments {
 707		if attachment.IsText() {
 708			continue
 709		}
 710		files = append(files, fantasy.FilePart{
 711			Filename:  attachment.FileName,
 712			Data:      attachment.Content,
 713			MediaType: attachment.MimeType,
 714		})
 715	}
 716
 717	return history, files
 718}
 719
 720func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
 721	msgs, err := a.messages.List(ctx, session.ID)
 722	if err != nil {
 723		return nil, fmt.Errorf("failed to list messages: %w", err)
 724	}
 725
 726	if session.SummaryMessageID != "" {
 727		summaryMsgInex := -1
 728		for i, msg := range msgs {
 729			if msg.ID == session.SummaryMessageID {
 730				summaryMsgInex = i
 731				break
 732			}
 733		}
 734		if summaryMsgInex != -1 {
 735			msgs = msgs[summaryMsgInex:]
 736			msgs[0].Role = message.User
 737		}
 738	}
 739	return msgs, nil
 740}
 741
 742// generateTitle generates a session titled based on the initial prompt.
 743func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, userPrompt string) {
 744	if userPrompt == "" {
 745		return
 746	}
 747
 748	var maxOutputTokens int64 = 40
 749	if a.smallModel.CatwalkCfg.CanReason {
 750		maxOutputTokens = a.smallModel.CatwalkCfg.DefaultMaxTokens
 751	}
 752
 753	newAgent := func(m fantasy.LanguageModel, p []byte, tok int64) fantasy.Agent {
 754		return fantasy.NewAgent(m,
 755			fantasy.WithSystemPrompt(string(p)+"\n /no_think"),
 756			fantasy.WithMaxOutputTokens(tok),
 757		)
 758	}
 759
 760	streamCall := fantasy.AgentStreamCall{
 761		Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", userPrompt),
 762		PrepareStep: func(callCtx context.Context, opts fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
 763			prepared.Messages = opts.Messages
 764			if a.systemPromptPrefix != "" {
 765				prepared.Messages = append([]fantasy.Message{
 766					fantasy.NewSystemMessage(a.systemPromptPrefix),
 767				}, prepared.Messages...)
 768			}
 769			return callCtx, prepared, nil
 770		},
 771	}
 772
 773	// Use the small model to generate the title.
 774	model := &a.smallModel
 775	agent := newAgent(model.Model, titlePrompt, maxOutputTokens)
 776	resp, err := agent.Stream(ctx, streamCall)
 777	if err == nil {
 778		// We successfully generated a title with the small model.
 779		slog.Info("generated title with small model")
 780	} else {
 781		// It didn't work. Let's try with the big model.
 782		slog.Error("error generating title with small model; trying big model", "err", err)
 783		model = &a.largeModel
 784		agent = newAgent(model.Model, titlePrompt, maxOutputTokens)
 785		resp, err = agent.Stream(ctx, streamCall)
 786		if err == nil {
 787			slog.Info("generated title with large model")
 788		} else {
 789			// Welp, the large model didn't work either.
 790			slog.Error("error generating title with large model", "err", err)
 791		}
 792	}
 793
 794	title := strings.ReplaceAll(resp.Response.Content.Text(), "\n", " ")
 795	slog.Info("generated title", "title", title)
 796
 797	// Remove thinking tags if present.
 798	title = thinkTagRegex.ReplaceAllString(title, "")
 799
 800	title = strings.TrimSpace(title)
 801	if title == "" {
 802		slog.Warn("empty title; using fallback")
 803		title = defaultSessionName
 804	}
 805
 806	// Calculate usage and cost.
 807	var openrouterCost *float64
 808	for _, step := range resp.Steps {
 809		stepCost := a.openrouterCost(step.ProviderMetadata)
 810		if stepCost != nil {
 811			newCost := *stepCost
 812			if openrouterCost != nil {
 813				newCost += *openrouterCost
 814			}
 815			openrouterCost = &newCost
 816		}
 817	}
 818
 819	if model == nil {
 820		slog.Error("no model available for cost calculation")
 821		return
 822	}
 823	modelConfig := model.CatwalkCfg
 824	cost := modelConfig.CostPer1MInCached/1e6*float64(resp.TotalUsage.CacheCreationTokens) +
 825		modelConfig.CostPer1MOutCached/1e6*float64(resp.TotalUsage.CacheReadTokens) +
 826		modelConfig.CostPer1MIn/1e6*float64(resp.TotalUsage.InputTokens) +
 827		modelConfig.CostPer1MOut/1e6*float64(resp.TotalUsage.OutputTokens)
 828
 829	if a.isClaudeCode() {
 830		cost = 0
 831	}
 832
 833	// Use override cost if available (e.g., from OpenRouter).
 834	if openrouterCost != nil {
 835		cost = *openrouterCost
 836	}
 837
 838	promptTokens := resp.TotalUsage.InputTokens + resp.TotalUsage.CacheCreationTokens
 839	completionTokens := resp.TotalUsage.OutputTokens + resp.TotalUsage.CacheReadTokens
 840
 841	// Atomically update only title and usage fields to avoid overriding other
 842	// concurrent session updates.
 843	saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, title, promptTokens, completionTokens, cost)
 844	if saveErr != nil {
 845		slog.Error("failed to save session title and usage", "error", saveErr)
 846		return
 847	}
 848}
 849
 850func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
 851	openrouterMetadata, ok := metadata[openrouter.Name]
 852	if !ok {
 853		return nil
 854	}
 855
 856	opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
 857	if !ok {
 858		return nil
 859	}
 860	return &opts.Usage.Cost
 861}
 862
 863func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
 864	modelConfig := model.CatwalkCfg
 865	cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
 866		modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
 867		modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
 868		modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
 869
 870	if a.isClaudeCode() {
 871		cost = 0
 872	}
 873
 874	a.eventTokensUsed(session.ID, model, usage, cost)
 875
 876	if overrideCost != nil {
 877		session.Cost += *overrideCost
 878	} else {
 879		session.Cost += cost
 880	}
 881
 882	session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
 883	session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
 884}
 885
 886func (a *sessionAgent) Cancel(sessionID string) {
 887	// Cancel regular requests.
 888	if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
 889		slog.Info("Request cancellation initiated", "session_id", sessionID)
 890		cancel()
 891	}
 892
 893	// Also check for summarize requests.
 894	if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
 895		slog.Info("Summarize cancellation initiated", "session_id", sessionID)
 896		cancel()
 897	}
 898
 899	if a.QueuedPrompts(sessionID) > 0 {
 900		slog.Info("Clearing queued prompts", "session_id", sessionID)
 901		a.messageQueue.Del(sessionID)
 902	}
 903}
 904
 905func (a *sessionAgent) ClearQueue(sessionID string) {
 906	if a.QueuedPrompts(sessionID) > 0 {
 907		slog.Info("Clearing queued prompts", "session_id", sessionID)
 908		a.messageQueue.Del(sessionID)
 909	}
 910}
 911
 912func (a *sessionAgent) CancelAll() {
 913	if !a.IsBusy() {
 914		return
 915	}
 916	for key := range a.activeRequests.Seq2() {
 917		a.Cancel(key) // key is sessionID
 918	}
 919
 920	timeout := time.After(5 * time.Second)
 921	for a.IsBusy() {
 922		select {
 923		case <-timeout:
 924			return
 925		default:
 926			time.Sleep(200 * time.Millisecond)
 927		}
 928	}
 929}
 930
 931func (a *sessionAgent) IsBusy() bool {
 932	var busy bool
 933	for cancelFunc := range a.activeRequests.Seq() {
 934		if cancelFunc != nil {
 935			busy = true
 936			break
 937		}
 938	}
 939	return busy
 940}
 941
 942func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
 943	_, busy := a.activeRequests.Get(sessionID)
 944	return busy
 945}
 946
 947func (a *sessionAgent) QueuedPrompts(sessionID string) int {
 948	l, ok := a.messageQueue.Get(sessionID)
 949	if !ok {
 950		return 0
 951	}
 952	return len(l)
 953}
 954
 955func (a *sessionAgent) QueuedPromptsList(sessionID string) []string {
 956	l, ok := a.messageQueue.Get(sessionID)
 957	if !ok {
 958		return nil
 959	}
 960	prompts := make([]string, len(l))
 961	for i, call := range l {
 962		prompts[i] = call.Prompt
 963	}
 964	return prompts
 965}
 966
 967func (a *sessionAgent) SetModels(large Model, small Model) {
 968	a.largeModel = large
 969	a.smallModel = small
 970}
 971
 972func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
 973	a.tools = tools
 974}
 975
 976func (a *sessionAgent) Model() Model {
 977	return a.largeModel
 978}
 979
 980func (a *sessionAgent) promptPrefix() string {
 981	if a.isClaudeCode() {
 982		return "You are Claude Code, Anthropic's official CLI for Claude."
 983	}
 984	return a.systemPromptPrefix
 985}
 986
 987// XXX: this should be generalized to cover other subscription plans, like Copilot.
 988func (a *sessionAgent) isClaudeCode() bool {
 989	cfg := config.Get()
 990	pc, ok := cfg.Providers.Get(a.largeModel.ModelCfg.Provider)
 991	return ok && pc.ID == string(catwalk.InferenceProviderAnthropic) && pc.OAuthToken != nil
 992}
 993
 994// convertToToolResult converts a fantasy tool result to a message tool result.
 995func (a *sessionAgent) convertToToolResult(result fantasy.ToolResultContent) message.ToolResult {
 996	baseResult := message.ToolResult{
 997		ToolCallID: result.ToolCallID,
 998		Name:       result.ToolName,
 999		Metadata:   result.ClientMetadata,
1000	}
1001
1002	switch result.Result.GetType() {
1003	case fantasy.ToolResultContentTypeText:
1004		if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result); ok {
1005			baseResult.Content = r.Text
1006		}
1007	case fantasy.ToolResultContentTypeError:
1008		if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result); ok {
1009			baseResult.Content = r.Error.Error()
1010			baseResult.IsError = true
1011		}
1012	case fantasy.ToolResultContentTypeMedia:
1013		if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](result.Result); ok {
1014			content := r.Text
1015			if content == "" {
1016				content = fmt.Sprintf("Loaded %s content", r.MediaType)
1017			}
1018			baseResult.Content = content
1019			baseResult.Data = r.Data
1020			baseResult.MIMEType = r.MediaType
1021		}
1022	}
1023
1024	return baseResult
1025}
1026
1027// workaroundProviderMediaLimitations converts media content in tool results to
1028// user messages for providers that don't natively support images in tool results.
1029//
1030// Problem: OpenAI, Google, OpenRouter, and other OpenAI-compatible providers
1031// don't support sending images/media in tool result messages - they only accept
1032// text in tool results. However, they DO support images in user messages.
1033//
1034// If we send media in tool results to these providers, the API returns an error.
1035//
1036// Solution: For these providers, we:
1037//  1. Replace the media in the tool result with a text placeholder
1038//  2. Inject a user message immediately after with the image as a file attachment
1039//  3. This maintains the tool execution flow while working around API limitations
1040//
1041// Anthropic and Bedrock support images natively in tool results, so we skip
1042// this workaround for them.
1043//
1044// Example transformation:
1045//
1046//	BEFORE: [tool result: image data]
1047//	AFTER:  [tool result: "Image loaded - see attached"], [user: image attachment]
1048func (a *sessionAgent) workaroundProviderMediaLimitations(messages []fantasy.Message) []fantasy.Message {
1049	providerSupportsMedia := a.largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderAnthropic) ||
1050		a.largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderBedrock)
1051
1052	if providerSupportsMedia {
1053		return messages
1054	}
1055
1056	convertedMessages := make([]fantasy.Message, 0, len(messages))
1057
1058	for _, msg := range messages {
1059		if msg.Role != fantasy.MessageRoleTool {
1060			convertedMessages = append(convertedMessages, msg)
1061			continue
1062		}
1063
1064		textParts := make([]fantasy.MessagePart, 0, len(msg.Content))
1065		var mediaFiles []fantasy.FilePart
1066
1067		for _, part := range msg.Content {
1068			toolResult, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
1069			if !ok {
1070				textParts = append(textParts, part)
1071				continue
1072			}
1073
1074			if media, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](toolResult.Output); ok {
1075				decoded, err := base64.StdEncoding.DecodeString(media.Data)
1076				if err != nil {
1077					slog.Warn("failed to decode media data", "error", err)
1078					textParts = append(textParts, part)
1079					continue
1080				}
1081
1082				mediaFiles = append(mediaFiles, fantasy.FilePart{
1083					Data:      decoded,
1084					MediaType: media.MediaType,
1085					Filename:  fmt.Sprintf("tool-result-%s", toolResult.ToolCallID),
1086				})
1087
1088				textParts = append(textParts, fantasy.ToolResultPart{
1089					ToolCallID: toolResult.ToolCallID,
1090					Output: fantasy.ToolResultOutputContentText{
1091						Text: "[Image/media content loaded - see attached file]",
1092					},
1093					ProviderOptions: toolResult.ProviderOptions,
1094				})
1095			} else {
1096				textParts = append(textParts, part)
1097			}
1098		}
1099
1100		convertedMessages = append(convertedMessages, fantasy.Message{
1101			Role:    fantasy.MessageRoleTool,
1102			Content: textParts,
1103		})
1104
1105		if len(mediaFiles) > 0 {
1106			convertedMessages = append(convertedMessages, fantasy.NewUserMessage(
1107				"Here is the media content from the tool result:",
1108				mediaFiles...,
1109			))
1110		}
1111	}
1112
1113	return convertedMessages
1114}