agent.go

   1// Package agent is the core orchestration layer for Crush AI agents.
   2//
   3// It provides session-based AI agent functionality for managing
   4// conversations, tool execution, and message handling. It coordinates
   5// interactions between language models, messages, sessions, and tools while
   6// handling features like automatic summarization, queuing, and token
   7// management.
   8package agent
   9
  10import (
  11	"cmp"
  12	"context"
  13	_ "embed"
  14	"encoding/base64"
  15	"errors"
  16	"fmt"
  17	"log/slog"
  18	"os"
  19	"strconv"
  20	"strings"
  21	"sync"
  22	"time"
  23
  24	"charm.land/fantasy"
  25	"charm.land/fantasy/providers/anthropic"
  26	"charm.land/fantasy/providers/bedrock"
  27	"charm.land/fantasy/providers/google"
  28	"charm.land/fantasy/providers/openai"
  29	"charm.land/fantasy/providers/openrouter"
  30	"git.secluded.site/crush/internal/agent/tools"
  31	"git.secluded.site/crush/internal/config"
  32	"git.secluded.site/crush/internal/csync"
  33	"git.secluded.site/crush/internal/message"
  34	"git.secluded.site/crush/internal/notification"
  35	"git.secluded.site/crush/internal/permission"
  36	"git.secluded.site/crush/internal/session"
  37	"git.secluded.site/crush/internal/stringext"
  38	"github.com/charmbracelet/catwalk/pkg/catwalk"
  39)
  40
  41//go:embed templates/title.md
  42var titlePrompt []byte
  43
  44//go:embed templates/summary.md
  45var summaryPrompt []byte
  46
  47type SessionAgentCall struct {
  48	SessionID        string
  49	Prompt           string
  50	ProviderOptions  fantasy.ProviderOptions
  51	Attachments      []message.Attachment
  52	MaxOutputTokens  int64
  53	Temperature      *float64
  54	TopP             *float64
  55	TopK             *int64
  56	FrequencyPenalty *float64
  57	PresencePenalty  *float64
  58	NonInteractive   bool
  59}
  60
  61type SessionAgent interface {
  62	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
  63	SetModels(large Model, small Model)
  64	SetTools(tools []fantasy.AgentTool)
  65	Cancel(sessionID string)
  66	CancelAll()
  67	IsSessionBusy(sessionID string) bool
  68	IsBusy() bool
  69	QueuedPrompts(sessionID string) int
  70	QueuedPromptsList(sessionID string) []string
  71	ClearQueue(sessionID string)
  72	Summarize(context.Context, string, fantasy.ProviderOptions) error
  73	Model() Model
  74}
  75
  76type Model struct {
  77	Model      fantasy.LanguageModel
  78	CatwalkCfg catwalk.Model
  79	ModelCfg   config.SelectedModel
  80}
  81
  82type sessionAgent struct {
  83	largeModel           Model
  84	smallModel           Model
  85	systemPromptPrefix   string
  86	systemPrompt         string
  87	isSubAgent           bool
  88	tools                []fantasy.AgentTool
  89	sessions             session.Service
  90	messages             message.Service
  91	disableAutoSummarize bool
  92	isYolo               bool
  93
  94	messageQueue   *csync.Map[string, []SessionAgentCall]
  95	activeRequests *csync.Map[string, context.CancelFunc]
  96}
  97
  98type SessionAgentOptions struct {
  99	LargeModel           Model
 100	SmallModel           Model
 101	SystemPromptPrefix   string
 102	SystemPrompt         string
 103	IsSubAgent           bool
 104	DisableAutoSummarize bool
 105	IsYolo               bool
 106	Sessions             session.Service
 107	Messages             message.Service
 108	Tools                []fantasy.AgentTool
 109}
 110
 111func NewSessionAgent(
 112	opts SessionAgentOptions,
 113) SessionAgent {
 114	return &sessionAgent{
 115		largeModel:           opts.LargeModel,
 116		smallModel:           opts.SmallModel,
 117		systemPromptPrefix:   opts.SystemPromptPrefix,
 118		systemPrompt:         opts.SystemPrompt,
 119		isSubAgent:           opts.IsSubAgent,
 120		sessions:             opts.Sessions,
 121		messages:             opts.Messages,
 122		disableAutoSummarize: opts.DisableAutoSummarize,
 123		tools:                opts.Tools,
 124		isYolo:               opts.IsYolo,
 125		messageQueue:         csync.NewMap[string, []SessionAgentCall](),
 126		activeRequests:       csync.NewMap[string, context.CancelFunc](),
 127	}
 128}
 129
 130func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
 131	if call.Prompt == "" {
 132		return nil, ErrEmptyPrompt
 133	}
 134	if call.SessionID == "" {
 135		return nil, ErrSessionMissing
 136	}
 137
 138	// Queue the message if busy
 139	if a.IsSessionBusy(call.SessionID) {
 140		existing, ok := a.messageQueue.Get(call.SessionID)
 141		if !ok {
 142			existing = []SessionAgentCall{}
 143		}
 144		existing = append(existing, call)
 145		a.messageQueue.Set(call.SessionID, existing)
 146		return nil, nil
 147	}
 148
 149	if len(a.tools) > 0 {
 150		// Add Anthropic caching to the last tool.
 151		a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
 152	}
 153
 154	agent := fantasy.NewAgent(
 155		a.largeModel.Model,
 156		fantasy.WithSystemPrompt(a.systemPrompt),
 157		fantasy.WithTools(a.tools...),
 158	)
 159
 160	sessionLock := sync.Mutex{}
 161	currentSession, err := a.sessions.Get(ctx, call.SessionID)
 162	if err != nil {
 163		return nil, fmt.Errorf("failed to get session: %w", err)
 164	}
 165
 166	msgs, err := a.getSessionMessages(ctx, currentSession)
 167	if err != nil {
 168		return nil, fmt.Errorf("failed to get session messages: %w", err)
 169	}
 170
 171	var wg sync.WaitGroup
 172	// Generate title if first message.
 173	if len(msgs) == 0 {
 174		wg.Go(func() {
 175			sessionLock.Lock()
 176			a.generateTitle(ctx, &currentSession, call.Prompt)
 177			sessionLock.Unlock()
 178		})
 179	}
 180
 181	// Add the user message to the session.
 182	_, err = a.createUserMessage(ctx, call)
 183	if err != nil {
 184		return nil, err
 185	}
 186
 187	// Add the session to the context.
 188	ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)
 189
 190	genCtx, cancel := context.WithCancel(ctx)
 191	a.activeRequests.Set(call.SessionID, cancel)
 192
 193	defer cancel()
 194	defer a.activeRequests.Del(call.SessionID)
 195
 196	history, files := a.preparePrompt(msgs, call.Attachments...)
 197
 198	startTime := time.Now()
 199	a.eventPromptSent(call.SessionID)
 200
 201	var currentAssistant *message.Message
 202	var shouldSummarize bool
 203	result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
 204		Prompt:           call.Prompt,
 205		Files:            files,
 206		Messages:         history,
 207		ProviderOptions:  call.ProviderOptions,
 208		MaxOutputTokens:  &call.MaxOutputTokens,
 209		TopP:             call.TopP,
 210		Temperature:      call.Temperature,
 211		PresencePenalty:  call.PresencePenalty,
 212		TopK:             call.TopK,
 213		FrequencyPenalty: call.FrequencyPenalty,
 214		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
 215			prepared.Messages = options.Messages
 216			for i := range prepared.Messages {
 217				prepared.Messages[i].ProviderOptions = nil
 218			}
 219
 220			queuedCalls, _ := a.messageQueue.Get(call.SessionID)
 221			a.messageQueue.Del(call.SessionID)
 222			for _, queued := range queuedCalls {
 223				userMessage, createErr := a.createUserMessage(callContext, queued)
 224				if createErr != nil {
 225					return callContext, prepared, createErr
 226				}
 227				prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
 228			}
 229
 230			prepared.Messages = a.workaroundProviderMediaLimitations(prepared.Messages)
 231
 232			lastSystemRoleInx := 0
 233			systemMessageUpdated := false
 234			for i, msg := range prepared.Messages {
 235				// Only add cache control to the last message.
 236				if msg.Role == fantasy.MessageRoleSystem {
 237					lastSystemRoleInx = i
 238				} else if !systemMessageUpdated {
 239					prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
 240					systemMessageUpdated = true
 241				}
 242				// Than add cache control to the last 2 messages.
 243				if i > len(prepared.Messages)-3 {
 244					prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
 245				}
 246			}
 247
 248			if promptPrefix := a.promptPrefix(); promptPrefix != "" {
 249				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(promptPrefix)}, prepared.Messages...)
 250			}
 251
 252			var assistantMsg message.Message
 253			assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
 254				Role:     message.Assistant,
 255				Parts:    []message.ContentPart{},
 256				Model:    a.largeModel.ModelCfg.Model,
 257				Provider: a.largeModel.ModelCfg.Provider,
 258			})
 259			if err != nil {
 260				return callContext, prepared, err
 261			}
 262			callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
 263			callContext = context.WithValue(callContext, tools.SupportsImagesContextKey, a.largeModel.CatwalkCfg.SupportsImages)
 264			callContext = context.WithValue(callContext, tools.ModelNameContextKey, a.largeModel.CatwalkCfg.Name)
 265			currentAssistant = &assistantMsg
 266			return callContext, prepared, err
 267		},
 268		OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
 269			currentAssistant.AppendReasoningContent(reasoning.Text)
 270			return a.messages.Update(genCtx, *currentAssistant)
 271		},
 272		OnReasoningDelta: func(id string, text string) error {
 273			currentAssistant.AppendReasoningContent(text)
 274			return a.messages.Update(genCtx, *currentAssistant)
 275		},
 276		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
 277			// handle anthropic signature
 278			if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
 279				if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
 280					currentAssistant.AppendReasoningSignature(reasoning.Signature)
 281				}
 282			}
 283			if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
 284				if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
 285					currentAssistant.AppendThoughtSignature(reasoning.Signature, reasoning.ToolID)
 286				}
 287			}
 288			if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
 289				if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
 290					currentAssistant.SetReasoningResponsesData(reasoning)
 291				}
 292			}
 293			currentAssistant.FinishThinking()
 294			return a.messages.Update(genCtx, *currentAssistant)
 295		},
 296		OnTextDelta: func(id string, text string) error {
 297			// Strip leading newline from initial text content. This is is
 298			// particularly important in non-interactive mode where leading
 299			// newlines are very visible.
 300			if len(currentAssistant.Parts) == 0 {
 301				text = strings.TrimPrefix(text, "\n")
 302			}
 303
 304			currentAssistant.AppendContent(text)
 305			return a.messages.Update(genCtx, *currentAssistant)
 306		},
 307		OnToolInputStart: func(id string, toolName string) error {
 308			toolCall := message.ToolCall{
 309				ID:               id,
 310				Name:             toolName,
 311				ProviderExecuted: false,
 312				Finished:         false,
 313			}
 314			currentAssistant.AddToolCall(toolCall)
 315			return a.messages.Update(genCtx, *currentAssistant)
 316		},
 317		OnRetry: func(err *fantasy.ProviderError, delay time.Duration) {
 318			// TODO: implement
 319		},
 320		OnToolCall: func(tc fantasy.ToolCallContent) error {
 321			toolCall := message.ToolCall{
 322				ID:               tc.ToolCallID,
 323				Name:             tc.ToolName,
 324				Input:            tc.Input,
 325				ProviderExecuted: false,
 326				Finished:         true,
 327			}
 328			currentAssistant.AddToolCall(toolCall)
 329			return a.messages.Update(genCtx, *currentAssistant)
 330		},
 331		OnToolResult: func(result fantasy.ToolResultContent) error {
 332			toolResult := a.convertToToolResult(result)
 333			_, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
 334				Role: message.Tool,
 335				Parts: []message.ContentPart{
 336					toolResult,
 337				},
 338			})
 339			return createMsgErr
 340		},
 341		OnStepFinish: func(stepResult fantasy.StepResult) error {
 342			finishReason := message.FinishReasonUnknown
 343			switch stepResult.FinishReason {
 344			case fantasy.FinishReasonLength:
 345				finishReason = message.FinishReasonMaxTokens
 346			case fantasy.FinishReasonStop:
 347				finishReason = message.FinishReasonEndTurn
 348			case fantasy.FinishReasonToolCalls:
 349				finishReason = message.FinishReasonToolUse
 350			}
 351			currentAssistant.AddFinish(finishReason, "", "")
 352			sessionLock.Lock()
 353			updatedSession, getSessionErr := a.sessions.Get(genCtx, call.SessionID)
 354			if getSessionErr != nil {
 355				sessionLock.Unlock()
 356				return getSessionErr
 357			}
 358			a.updateSessionUsage(a.largeModel, &updatedSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
 359			_, sessionErr := a.sessions.Save(genCtx, updatedSession)
 360			sessionLock.Unlock()
 361			if sessionErr != nil {
 362				return sessionErr
 363			}
 364			return a.messages.Update(genCtx, *currentAssistant)
 365		},
 366		StopWhen: []fantasy.StopCondition{
 367			func(_ []fantasy.StepResult) bool {
 368				cw := int64(a.largeModel.CatwalkCfg.ContextWindow)
 369				tokens := currentSession.CompletionTokens + currentSession.PromptTokens
 370				remaining := cw - tokens
 371				var threshold int64
 372				if cw > 200_000 {
 373					threshold = 20_000
 374				} else {
 375					threshold = int64(float64(cw) * 0.2)
 376				}
 377				if (remaining <= threshold) && !a.disableAutoSummarize {
 378					shouldSummarize = true
 379					return true
 380				}
 381				return false
 382			},
 383		},
 384	})
 385
 386	a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))
 387
 388	if err != nil {
 389		isCancelErr := errors.Is(err, context.Canceled)
 390		isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
 391		if currentAssistant == nil {
 392			return result, err
 393		}
 394		// Ensure we finish thinking on error to close the reasoning state.
 395		currentAssistant.FinishThinking()
 396		toolCalls := currentAssistant.ToolCalls()
 397		// INFO: we use the parent context here because the genCtx has been cancelled.
 398		msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
 399		if createErr != nil {
 400			return nil, createErr
 401		}
 402		for _, tc := range toolCalls {
 403			if !tc.Finished {
 404				tc.Finished = true
 405				tc.Input = "{}"
 406				currentAssistant.AddToolCall(tc)
 407				updateErr := a.messages.Update(ctx, *currentAssistant)
 408				if updateErr != nil {
 409					return nil, updateErr
 410				}
 411			}
 412
 413			found := false
 414			for _, msg := range msgs {
 415				if msg.Role == message.Tool {
 416					for _, tr := range msg.ToolResults() {
 417						if tr.ToolCallID == tc.ID {
 418							found = true
 419							break
 420						}
 421					}
 422				}
 423				if found {
 424					break
 425				}
 426			}
 427			if found {
 428				continue
 429			}
 430			content := "There was an error while executing the tool"
 431			if isCancelErr {
 432				content = "Tool execution canceled by user"
 433			} else if isPermissionErr {
 434				content = "User denied permission"
 435			}
 436			toolResult := message.ToolResult{
 437				ToolCallID: tc.ID,
 438				Name:       tc.Name,
 439				Content:    content,
 440				IsError:    true,
 441			}
 442			_, createErr = a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
 443				Role: message.Tool,
 444				Parts: []message.ContentPart{
 445					toolResult,
 446				},
 447			})
 448			if createErr != nil {
 449				return nil, createErr
 450			}
 451		}
 452		var fantasyErr *fantasy.Error
 453		var providerErr *fantasy.ProviderError
 454		const defaultTitle = "Provider Error"
 455		if isCancelErr {
 456			currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
 457		} else if isPermissionErr {
 458			currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "")
 459		} else if errors.As(err, &providerErr) {
 460			currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message)
 461		} else if errors.As(err, &fantasyErr) {
 462			currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(fantasyErr.Title), defaultTitle), fantasyErr.Message)
 463		} else {
 464			currentAssistant.AddFinish(message.FinishReasonError, defaultTitle, err.Error())
 465		}
 466		// Note: we use the parent context here because the genCtx has been
 467		// cancelled.
 468		updateErr := a.messages.Update(ctx, *currentAssistant)
 469		if updateErr != nil {
 470			return nil, updateErr
 471		}
 472		return nil, err
 473	}
 474	wg.Wait()
 475
 476	// Send notification that agent has finished its turn (skip for nested/non-interactive sessions).
 477	if !call.NonInteractive {
 478		notifBody := fmt.Sprintf("Agent's turn completed in \"%s\"", currentSession.Title)
 479		_ = notification.Send("Crush is waiting...", notifBody)
 480	}
 481
 482	if shouldSummarize {
 483		a.activeRequests.Del(call.SessionID)
 484		if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
 485			return nil, summarizeErr
 486		}
 487		// If the agent wasn't done...
 488		if len(currentAssistant.ToolCalls()) > 0 {
 489			existing, ok := a.messageQueue.Get(call.SessionID)
 490			if !ok {
 491				existing = []SessionAgentCall{}
 492			}
 493			call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
 494			existing = append(existing, call)
 495			a.messageQueue.Set(call.SessionID, existing)
 496		}
 497	}
 498
 499	// Release active request before processing queued messages.
 500	a.activeRequests.Del(call.SessionID)
 501	cancel()
 502
 503	queuedMessages, ok := a.messageQueue.Get(call.SessionID)
 504	if !ok || len(queuedMessages) == 0 {
 505		return result, err
 506	}
 507	// There are queued messages restart the loop.
 508	firstQueuedMessage := queuedMessages[0]
 509	a.messageQueue.Set(call.SessionID, queuedMessages[1:])
 510	return a.Run(ctx, firstQueuedMessage)
 511}
 512
 513func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
 514	if a.IsSessionBusy(sessionID) {
 515		return ErrSessionBusy
 516	}
 517
 518	currentSession, err := a.sessions.Get(ctx, sessionID)
 519	if err != nil {
 520		return fmt.Errorf("failed to get session: %w", err)
 521	}
 522	msgs, err := a.getSessionMessages(ctx, currentSession)
 523	if err != nil {
 524		return err
 525	}
 526	if len(msgs) == 0 {
 527		// Nothing to summarize.
 528		return nil
 529	}
 530
 531	aiMsgs, _ := a.preparePrompt(msgs)
 532
 533	genCtx, cancel := context.WithCancel(ctx)
 534	a.activeRequests.Set(sessionID, cancel)
 535	defer a.activeRequests.Del(sessionID)
 536	defer cancel()
 537
 538	agent := fantasy.NewAgent(a.largeModel.Model,
 539		fantasy.WithSystemPrompt(string(summaryPrompt)),
 540	)
 541	summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
 542		Role:             message.Assistant,
 543		Model:            a.largeModel.Model.Model(),
 544		Provider:         a.largeModel.Model.Provider(),
 545		IsSummaryMessage: true,
 546	})
 547	if err != nil {
 548		return err
 549	}
 550
 551	summaryPromptText := "Provide a detailed summary of our conversation above."
 552	if len(currentSession.Todos) > 0 {
 553		summaryPromptText += "\n\n## Current Todo List\n\n"
 554		for _, t := range currentSession.Todos {
 555			summaryPromptText += fmt.Sprintf("- [%s] %s\n", t.Status, t.Content)
 556		}
 557		summaryPromptText += "\nInclude these tasks and their statuses in your summary. "
 558		summaryPromptText += "Instruct the resuming assistant to use the `todos` tool to continue tracking progress on these tasks."
 559	}
 560
 561	resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
 562		Prompt:          summaryPromptText,
 563		Messages:        aiMsgs,
 564		ProviderOptions: opts,
 565		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
 566			prepared.Messages = options.Messages
 567			if a.systemPromptPrefix != "" {
 568				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
 569			}
 570			return callContext, prepared, nil
 571		},
 572		OnReasoningDelta: func(id string, text string) error {
 573			summaryMessage.AppendReasoningContent(text)
 574			return a.messages.Update(genCtx, summaryMessage)
 575		},
 576		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
 577			// Handle anthropic signature.
 578			if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
 579				if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
 580					summaryMessage.AppendReasoningSignature(signature.Signature)
 581				}
 582			}
 583			summaryMessage.FinishThinking()
 584			return a.messages.Update(genCtx, summaryMessage)
 585		},
 586		OnTextDelta: func(id, text string) error {
 587			summaryMessage.AppendContent(text)
 588			return a.messages.Update(genCtx, summaryMessage)
 589		},
 590	})
 591	if err != nil {
 592		isCancelErr := errors.Is(err, context.Canceled)
 593		if isCancelErr {
 594			// User cancelled summarize we need to remove the summary message.
 595			deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
 596			return deleteErr
 597		}
 598		return err
 599	}
 600
 601	summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
 602	err = a.messages.Update(genCtx, summaryMessage)
 603	if err != nil {
 604		return err
 605	}
 606
 607	var openrouterCost *float64
 608	for _, step := range resp.Steps {
 609		stepCost := a.openrouterCost(step.ProviderMetadata)
 610		if stepCost != nil {
 611			newCost := *stepCost
 612			if openrouterCost != nil {
 613				newCost += *openrouterCost
 614			}
 615			openrouterCost = &newCost
 616		}
 617	}
 618
 619	a.updateSessionUsage(a.largeModel, &currentSession, resp.TotalUsage, openrouterCost)
 620
 621	// Just in case, get just the last usage info.
 622	usage := resp.Response.Usage
 623	currentSession.SummaryMessageID = summaryMessage.ID
 624	currentSession.CompletionTokens = usage.OutputTokens
 625	currentSession.PromptTokens = 0
 626	_, err = a.sessions.Save(genCtx, currentSession)
 627	return err
 628}
 629
 630func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
 631	if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
 632		return fantasy.ProviderOptions{}
 633	}
 634	return fantasy.ProviderOptions{
 635		anthropic.Name: &anthropic.ProviderCacheControlOptions{
 636			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
 637		},
 638		bedrock.Name: &anthropic.ProviderCacheControlOptions{
 639			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
 640		},
 641	}
 642}
 643
 644func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
 645	var attachmentParts []message.ContentPart
 646	for _, attachment := range call.Attachments {
 647		attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
 648	}
 649	parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
 650	parts = append(parts, attachmentParts...)
 651	msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
 652		Role:  message.User,
 653		Parts: parts,
 654	})
 655	if err != nil {
 656		return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
 657	}
 658	return msg, nil
 659}
 660
 661func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
 662	var history []fantasy.Message
 663	if !a.isSubAgent {
 664		history = append(history, fantasy.NewUserMessage(
 665			fmt.Sprintf("<system_reminder>%s</system_reminder>",
 666				`This is a reminder that your todo list is currently empty. DO NOT mention this to the user explicitly because they are already aware.
 667If you are working on tasks that would benefit from a todo list please use the "todos" tool to create one.
 668If not, please feel free to ignore. Again do not mention this message to the user.`,
 669			),
 670		))
 671	}
 672	for _, m := range msgs {
 673		if len(m.Parts) == 0 {
 674			continue
 675		}
 676		// Assistant message without content or tool calls (cancelled before it
 677		// returned anything).
 678		if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
 679			continue
 680		}
 681		history = append(history, m.ToAIMessage()...)
 682	}
 683
 684	var files []fantasy.FilePart
 685	for _, attachment := range attachments {
 686		files = append(files, fantasy.FilePart{
 687			Filename:  attachment.FileName,
 688			Data:      attachment.Content,
 689			MediaType: attachment.MimeType,
 690		})
 691	}
 692
 693	return history, files
 694}
 695
 696func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
 697	msgs, err := a.messages.List(ctx, session.ID)
 698	if err != nil {
 699		return nil, fmt.Errorf("failed to list messages: %w", err)
 700	}
 701
 702	if session.SummaryMessageID != "" {
 703		summaryMsgInex := -1
 704		for i, msg := range msgs {
 705			if msg.ID == session.SummaryMessageID {
 706				summaryMsgInex = i
 707				break
 708			}
 709		}
 710		if summaryMsgInex != -1 {
 711			msgs = msgs[summaryMsgInex:]
 712			msgs[0].Role = message.User
 713		}
 714	}
 715	return msgs, nil
 716}
 717
 718func (a *sessionAgent) generateTitle(ctx context.Context, session *session.Session, prompt string) {
 719	if prompt == "" {
 720		return
 721	}
 722
 723	var maxOutput int64 = 40
 724	if a.smallModel.CatwalkCfg.CanReason {
 725		maxOutput = a.smallModel.CatwalkCfg.DefaultMaxTokens
 726	}
 727
 728	agent := fantasy.NewAgent(a.smallModel.Model,
 729		fantasy.WithSystemPrompt(string(titlePrompt)+"\n /no_think"),
 730		fantasy.WithMaxOutputTokens(maxOutput),
 731	)
 732
 733	resp, err := agent.Stream(ctx, fantasy.AgentStreamCall{
 734		Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", prompt),
 735		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
 736			prepared.Messages = options.Messages
 737			if a.systemPromptPrefix != "" {
 738				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
 739			}
 740			return callContext, prepared, nil
 741		},
 742	})
 743	if err != nil {
 744		slog.Error("error generating title", "err", err)
 745		return
 746	}
 747
 748	title := resp.Response.Content.Text()
 749
 750	title = strings.ReplaceAll(title, "\n", " ")
 751
 752	// Remove thinking tags if present.
 753	if idx := strings.Index(title, "</think>"); idx > 0 {
 754		title = title[idx+len("</think>"):]
 755	}
 756
 757	title = strings.TrimSpace(title)
 758	if title == "" {
 759		slog.Warn("failed to generate title", "warn", "empty title")
 760		return
 761	}
 762
 763	session.Title = title
 764
 765	var openrouterCost *float64
 766	for _, step := range resp.Steps {
 767		stepCost := a.openrouterCost(step.ProviderMetadata)
 768		if stepCost != nil {
 769			newCost := *stepCost
 770			if openrouterCost != nil {
 771				newCost += *openrouterCost
 772			}
 773			openrouterCost = &newCost
 774		}
 775	}
 776
 777	a.updateSessionUsage(a.smallModel, session, resp.TotalUsage, openrouterCost)
 778	_, saveErr := a.sessions.Save(ctx, *session)
 779	if saveErr != nil {
 780		slog.Error("failed to save session title & usage", "error", saveErr)
 781		return
 782	}
 783}
 784
 785func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
 786	openrouterMetadata, ok := metadata[openrouter.Name]
 787	if !ok {
 788		return nil
 789	}
 790
 791	opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
 792	if !ok {
 793		return nil
 794	}
 795	return &opts.Usage.Cost
 796}
 797
 798func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
 799	modelConfig := model.CatwalkCfg
 800	cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
 801		modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
 802		modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
 803		modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
 804
 805	if a.isClaudeCode() {
 806		cost = 0
 807	}
 808
 809	a.eventTokensUsed(session.ID, model, usage, cost)
 810
 811	if overrideCost != nil {
 812		session.Cost += *overrideCost
 813	} else {
 814		session.Cost += cost
 815	}
 816
 817	session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
 818	session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
 819}
 820
 821func (a *sessionAgent) Cancel(sessionID string) {
 822	// Cancel regular requests.
 823	if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
 824		slog.Info("Request cancellation initiated", "session_id", sessionID)
 825		cancel()
 826	}
 827
 828	// Also check for summarize requests.
 829	if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
 830		slog.Info("Summarize cancellation initiated", "session_id", sessionID)
 831		cancel()
 832	}
 833
 834	if a.QueuedPrompts(sessionID) > 0 {
 835		slog.Info("Clearing queued prompts", "session_id", sessionID)
 836		a.messageQueue.Del(sessionID)
 837	}
 838}
 839
 840func (a *sessionAgent) ClearQueue(sessionID string) {
 841	if a.QueuedPrompts(sessionID) > 0 {
 842		slog.Info("Clearing queued prompts", "session_id", sessionID)
 843		a.messageQueue.Del(sessionID)
 844	}
 845}
 846
 847func (a *sessionAgent) CancelAll() {
 848	if !a.IsBusy() {
 849		return
 850	}
 851	for key := range a.activeRequests.Seq2() {
 852		a.Cancel(key) // key is sessionID
 853	}
 854
 855	timeout := time.After(5 * time.Second)
 856	for a.IsBusy() {
 857		select {
 858		case <-timeout:
 859			return
 860		default:
 861			time.Sleep(200 * time.Millisecond)
 862		}
 863	}
 864}
 865
 866func (a *sessionAgent) IsBusy() bool {
 867	var busy bool
 868	for cancelFunc := range a.activeRequests.Seq() {
 869		if cancelFunc != nil {
 870			busy = true
 871			break
 872		}
 873	}
 874	return busy
 875}
 876
 877func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
 878	_, busy := a.activeRequests.Get(sessionID)
 879	return busy
 880}
 881
 882func (a *sessionAgent) QueuedPrompts(sessionID string) int {
 883	l, ok := a.messageQueue.Get(sessionID)
 884	if !ok {
 885		return 0
 886	}
 887	return len(l)
 888}
 889
 890func (a *sessionAgent) QueuedPromptsList(sessionID string) []string {
 891	l, ok := a.messageQueue.Get(sessionID)
 892	if !ok {
 893		return nil
 894	}
 895	prompts := make([]string, len(l))
 896	for i, call := range l {
 897		prompts[i] = call.Prompt
 898	}
 899	return prompts
 900}
 901
 902func (a *sessionAgent) SetModels(large Model, small Model) {
 903	a.largeModel = large
 904	a.smallModel = small
 905}
 906
 907func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
 908	a.tools = tools
 909}
 910
 911func (a *sessionAgent) Model() Model {
 912	return a.largeModel
 913}
 914
 915func (a *sessionAgent) promptPrefix() string {
 916	if a.isClaudeCode() {
 917		return "You are Claude Code, Anthropic's official CLI for Claude."
 918	}
 919	return a.systemPromptPrefix
 920}
 921
 922func (a *sessionAgent) isClaudeCode() bool {
 923	cfg := config.Get()
 924	pc, ok := cfg.Providers.Get(a.largeModel.ModelCfg.Provider)
 925	return ok && pc.ID == string(catwalk.InferenceProviderAnthropic) && pc.OAuthToken != nil
 926}
 927
 928// convertToToolResult converts a fantasy tool result to a message tool result.
 929func (a *sessionAgent) convertToToolResult(result fantasy.ToolResultContent) message.ToolResult {
 930	baseResult := message.ToolResult{
 931		ToolCallID: result.ToolCallID,
 932		Name:       result.ToolName,
 933		Metadata:   result.ClientMetadata,
 934	}
 935
 936	switch result.Result.GetType() {
 937	case fantasy.ToolResultContentTypeText:
 938		if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result); ok {
 939			baseResult.Content = r.Text
 940		}
 941	case fantasy.ToolResultContentTypeError:
 942		if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result); ok {
 943			baseResult.Content = r.Error.Error()
 944			baseResult.IsError = true
 945		}
 946	case fantasy.ToolResultContentTypeMedia:
 947		if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](result.Result); ok {
 948			content := r.Text
 949			if content == "" {
 950				content = fmt.Sprintf("Loaded %s content", r.MediaType)
 951			}
 952			baseResult.Content = content
 953			baseResult.Data = r.Data
 954			baseResult.MIMEType = r.MediaType
 955		}
 956	}
 957
 958	return baseResult
 959}
 960
 961// workaroundProviderMediaLimitations converts media content in tool results to
 962// user messages for providers that don't natively support images in tool results.
 963//
 964// Problem: OpenAI, Google, OpenRouter, and other OpenAI-compatible providers
 965// don't support sending images/media in tool result messages - they only accept
 966// text in tool results. However, they DO support images in user messages.
 967//
 968// If we send media in tool results to these providers, the API returns an error.
 969//
 970// Solution: For these providers, we:
 971//  1. Replace the media in the tool result with a text placeholder
 972//  2. Inject a user message immediately after with the image as a file attachment
 973//  3. This maintains the tool execution flow while working around API limitations
 974//
 975// Anthropic and Bedrock support images natively in tool results, so we skip
 976// this workaround for them.
 977//
 978// Example transformation:
 979//
 980//	BEFORE: [tool result: image data]
 981//	AFTER:  [tool result: "Image loaded - see attached"], [user: image attachment]
 982func (a *sessionAgent) workaroundProviderMediaLimitations(messages []fantasy.Message) []fantasy.Message {
 983	providerSupportsMedia := a.largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderAnthropic) ||
 984		a.largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderBedrock)
 985
 986	if providerSupportsMedia {
 987		return messages
 988	}
 989
 990	convertedMessages := make([]fantasy.Message, 0, len(messages))
 991
 992	for _, msg := range messages {
 993		if msg.Role != fantasy.MessageRoleTool {
 994			convertedMessages = append(convertedMessages, msg)
 995			continue
 996		}
 997
 998		textParts := make([]fantasy.MessagePart, 0, len(msg.Content))
 999		var mediaFiles []fantasy.FilePart
1000
1001		for _, part := range msg.Content {
1002			toolResult, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
1003			if !ok {
1004				textParts = append(textParts, part)
1005				continue
1006			}
1007
1008			if media, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](toolResult.Output); ok {
1009				decoded, err := base64.StdEncoding.DecodeString(media.Data)
1010				if err != nil {
1011					slog.Warn("failed to decode media data", "error", err)
1012					textParts = append(textParts, part)
1013					continue
1014				}
1015
1016				mediaFiles = append(mediaFiles, fantasy.FilePart{
1017					Data:      decoded,
1018					MediaType: media.MediaType,
1019					Filename:  fmt.Sprintf("tool-result-%s", toolResult.ToolCallID),
1020				})
1021
1022				textParts = append(textParts, fantasy.ToolResultPart{
1023					ToolCallID: toolResult.ToolCallID,
1024					Output: fantasy.ToolResultOutputContentText{
1025						Text: "[Image/media content loaded - see attached file]",
1026					},
1027					ProviderOptions: toolResult.ProviderOptions,
1028				})
1029			} else {
1030				textParts = append(textParts, part)
1031			}
1032		}
1033
1034		convertedMessages = append(convertedMessages, fantasy.Message{
1035			Role:    fantasy.MessageRoleTool,
1036			Content: textParts,
1037		})
1038
1039		if len(mediaFiles) > 0 {
1040			convertedMessages = append(convertedMessages, fantasy.NewUserMessage(
1041				"Here is the media content from the tool result:",
1042				mediaFiles...,
1043			))
1044		}
1045	}
1046
1047	return convertedMessages
1048}