agent.go

   1// Package agent is the core orchestration layer for Crush AI agents.
   2//
   3// It provides session-based AI agent functionality for managing
   4// conversations, tool execution, and message handling. It coordinates
   5// interactions between language models, messages, sessions, and tools while
   6// handling features like automatic summarization, queuing, and token
   7// management.
   8package agent
   9
  10import (
  11	"cmp"
  12	"context"
  13	_ "embed"
  14	"encoding/base64"
  15	"errors"
  16	"fmt"
  17	"log/slog"
  18	"os"
  19	"strconv"
  20	"strings"
  21	"sync"
  22	"time"
  23
  24	"charm.land/fantasy"
  25	"charm.land/fantasy/providers/anthropic"
  26	"charm.land/fantasy/providers/bedrock"
  27	"charm.land/fantasy/providers/google"
  28	"charm.land/fantasy/providers/openai"
  29	"charm.land/fantasy/providers/openrouter"
  30	"git.secluded.site/crush/internal/agent/tools"
  31	"git.secluded.site/crush/internal/config"
  32	"git.secluded.site/crush/internal/csync"
  33	"git.secluded.site/crush/internal/message"
  34	"git.secluded.site/crush/internal/notification"
  35	"git.secluded.site/crush/internal/permission"
  36	"git.secluded.site/crush/internal/session"
  37	"git.secluded.site/crush/internal/stringext"
  38	"github.com/charmbracelet/catwalk/pkg/catwalk"
  39)
  40
  41//go:embed templates/title.md
  42var titlePrompt []byte
  43
  44//go:embed templates/summary.md
  45var summaryPrompt []byte
  46
// SessionAgentCall is a single prompt submitted to a SessionAgent.
//
// SessionID and Prompt are required: Run rejects calls where either is empty.
// Attachments are stored with the user message and forwarded to the model as
// files. ProviderOptions, MaxOutputTokens, Temperature, TopP, TopK,
// FrequencyPenalty and PresencePenalty are passed through to the model stream
// call. NonInteractive suppresses the "turn completed" notification that Run
// sends after a successful turn.
type SessionAgentCall struct {
	SessionID        string
	Prompt           string
	ProviderOptions  fantasy.ProviderOptions
	Attachments      []message.Attachment
	MaxOutputTokens  int64
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
	NonInteractive   bool
}
  60
// SessionAgent runs prompts against language models on a per-session basis.
//
// Besides running and summarizing sessions it exposes the bookkeeping around
// them: swapping models and tools, cancelling one or all active requests,
// querying whether a session (or any session) is busy, and inspecting or
// clearing the prompts queued for a session.
type SessionAgent interface {
	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
	SetModels(large Model, small Model)
	SetTools(tools []fantasy.AgentTool)
	Cancel(sessionID string)
	CancelAll()
	IsSessionBusy(sessionID string) bool
	IsBusy() bool
	QueuedPrompts(sessionID string) int
	ClearQueue(sessionID string)
	Summarize(context.Context, string, fantasy.ProviderOptions) error
	Model() Model
}
  74
// Model bundles a language model with the metadata the agent needs to use it:
// Model is the handle used to create agents, CatwalkCfg supplies the catalog
// data read in this file (context window, pricing, capability flags, name),
// and ModelCfg is the selected model configuration (provider and model IDs).
type Model struct {
	Model      fantasy.LanguageModel
	CatwalkCfg catwalk.Model
	ModelCfg   config.SelectedModel
}
  80
// sessionAgent is the SessionAgent implementation returned by
// NewSessionAgent.
//
// largeModel drives Run and Summarize; smallModel is used for title
// generation. messageQueue holds, per session ID, the calls that arrived
// while that session was busy. activeRequests holds, per session ID, the
// cancel function of the request currently in flight; the presence of a key
// is what marks a session as busy.
type sessionAgent struct {
	largeModel           Model
	smallModel           Model
	systemPromptPrefix   string
	systemPrompt         string
	tools                []fantasy.AgentTool
	sessions             session.Service
	messages             message.Service
	disableAutoSummarize bool
	isYolo               bool

	messageQueue   *csync.Map[string, []SessionAgentCall]
	activeRequests *csync.Map[string, context.CancelFunc]
}
  95
// SessionAgentOptions carries the dependencies and settings that
// NewSessionAgent copies into a new agent: the large and small models, the
// system prompt and its optional prefix, the auto-summarize and yolo flags,
// the session and message services, and the initial tool set.
type SessionAgentOptions struct {
	LargeModel           Model
	SmallModel           Model
	SystemPromptPrefix   string
	SystemPrompt         string
	DisableAutoSummarize bool
	IsYolo               bool
	Sessions             session.Service
	Messages             message.Service
	Tools                []fantasy.AgentTool
}
 107
 108func NewSessionAgent(
 109	opts SessionAgentOptions,
 110) SessionAgent {
 111	return &sessionAgent{
 112		largeModel:           opts.LargeModel,
 113		smallModel:           opts.SmallModel,
 114		systemPromptPrefix:   opts.SystemPromptPrefix,
 115		systemPrompt:         opts.SystemPrompt,
 116		sessions:             opts.Sessions,
 117		messages:             opts.Messages,
 118		disableAutoSummarize: opts.DisableAutoSummarize,
 119		tools:                opts.Tools,
 120		isYolo:               opts.IsYolo,
 121		messageQueue:         csync.NewMap[string, []SessionAgentCall](),
 122		activeRequests:       csync.NewMap[string, context.CancelFunc](),
 123	}
 124}
 125
 126func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
 127	if call.Prompt == "" {
 128		return nil, ErrEmptyPrompt
 129	}
 130	if call.SessionID == "" {
 131		return nil, ErrSessionMissing
 132	}
 133
 134	// Queue the message if busy
 135	if a.IsSessionBusy(call.SessionID) {
 136		existing, ok := a.messageQueue.Get(call.SessionID)
 137		if !ok {
 138			existing = []SessionAgentCall{}
 139		}
 140		existing = append(existing, call)
 141		a.messageQueue.Set(call.SessionID, existing)
 142		return nil, nil
 143	}
 144
 145	if len(a.tools) > 0 {
 146		// Add Anthropic caching to the last tool.
 147		a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
 148	}
 149
 150	agent := fantasy.NewAgent(
 151		a.largeModel.Model,
 152		fantasy.WithSystemPrompt(a.systemPrompt),
 153		fantasy.WithTools(a.tools...),
 154	)
 155
 156	sessionLock := sync.Mutex{}
 157	currentSession, err := a.sessions.Get(ctx, call.SessionID)
 158	if err != nil {
 159		return nil, fmt.Errorf("failed to get session: %w", err)
 160	}
 161
 162	msgs, err := a.getSessionMessages(ctx, currentSession)
 163	if err != nil {
 164		return nil, fmt.Errorf("failed to get session messages: %w", err)
 165	}
 166
 167	var wg sync.WaitGroup
 168	// Generate title if first message.
 169	if len(msgs) == 0 {
 170		wg.Go(func() {
 171			sessionLock.Lock()
 172			a.generateTitle(ctx, &currentSession, call.Prompt)
 173			sessionLock.Unlock()
 174		})
 175	}
 176
 177	// Add the user message to the session.
 178	_, err = a.createUserMessage(ctx, call)
 179	if err != nil {
 180		return nil, err
 181	}
 182
 183	// Add the session to the context.
 184	ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)
 185
 186	genCtx, cancel := context.WithCancel(ctx)
 187	a.activeRequests.Set(call.SessionID, cancel)
 188
 189	defer cancel()
 190	defer a.activeRequests.Del(call.SessionID)
 191
 192	history, files := a.preparePrompt(msgs, call.Attachments...)
 193
 194	startTime := time.Now()
 195	a.eventPromptSent(call.SessionID)
 196
 197	var currentAssistant *message.Message
 198	var shouldSummarize bool
 199	result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
 200		Prompt:           call.Prompt,
 201		Files:            files,
 202		Messages:         history,
 203		ProviderOptions:  call.ProviderOptions,
 204		MaxOutputTokens:  &call.MaxOutputTokens,
 205		TopP:             call.TopP,
 206		Temperature:      call.Temperature,
 207		PresencePenalty:  call.PresencePenalty,
 208		TopK:             call.TopK,
 209		FrequencyPenalty: call.FrequencyPenalty,
 210		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
 211			prepared.Messages = options.Messages
 212			for i := range prepared.Messages {
 213				prepared.Messages[i].ProviderOptions = nil
 214			}
 215
 216			queuedCalls, _ := a.messageQueue.Get(call.SessionID)
 217			a.messageQueue.Del(call.SessionID)
 218			for _, queued := range queuedCalls {
 219				userMessage, createErr := a.createUserMessage(callContext, queued)
 220				if createErr != nil {
 221					return callContext, prepared, createErr
 222				}
 223				prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
 224			}
 225
 226			prepared.Messages = a.workaroundProviderMediaLimitations(prepared.Messages)
 227
 228			lastSystemRoleInx := 0
 229			systemMessageUpdated := false
 230			for i, msg := range prepared.Messages {
 231				// Only add cache control to the last message.
 232				if msg.Role == fantasy.MessageRoleSystem {
 233					lastSystemRoleInx = i
 234				} else if !systemMessageUpdated {
 235					prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
 236					systemMessageUpdated = true
 237				}
 238				// Than add cache control to the last 2 messages.
 239				if i > len(prepared.Messages)-3 {
 240					prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
 241				}
 242			}
 243
 244			if promptPrefix := a.promptPrefix(); promptPrefix != "" {
 245				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(promptPrefix)}, prepared.Messages...)
 246			}
 247
 248			var assistantMsg message.Message
 249			assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
 250				Role:     message.Assistant,
 251				Parts:    []message.ContentPart{},
 252				Model:    a.largeModel.ModelCfg.Model,
 253				Provider: a.largeModel.ModelCfg.Provider,
 254			})
 255			if err != nil {
 256				return callContext, prepared, err
 257			}
 258			callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
 259			callContext = context.WithValue(callContext, tools.SupportsImagesContextKey, a.largeModel.CatwalkCfg.SupportsImages)
 260			callContext = context.WithValue(callContext, tools.ModelNameContextKey, a.largeModel.CatwalkCfg.Name)
 261			currentAssistant = &assistantMsg
 262			return callContext, prepared, err
 263		},
 264		OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
 265			currentAssistant.AppendReasoningContent(reasoning.Text)
 266			return a.messages.Update(genCtx, *currentAssistant)
 267		},
 268		OnReasoningDelta: func(id string, text string) error {
 269			currentAssistant.AppendReasoningContent(text)
 270			return a.messages.Update(genCtx, *currentAssistant)
 271		},
 272		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
 273			// handle anthropic signature
 274			if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
 275				if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
 276					currentAssistant.AppendReasoningSignature(reasoning.Signature)
 277				}
 278			}
 279			if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
 280				if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
 281					currentAssistant.AppendThoughtSignature(reasoning.Signature, reasoning.ToolID)
 282				}
 283			}
 284			if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
 285				if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
 286					currentAssistant.SetReasoningResponsesData(reasoning)
 287				}
 288			}
 289			currentAssistant.FinishThinking()
 290			return a.messages.Update(genCtx, *currentAssistant)
 291		},
 292		OnTextDelta: func(id string, text string) error {
 293			// Strip leading newline from initial text content. This is is
 294			// particularly important in non-interactive mode where leading
 295			// newlines are very visible.
 296			if len(currentAssistant.Parts) == 0 {
 297				text = strings.TrimPrefix(text, "\n")
 298			}
 299
 300			currentAssistant.AppendContent(text)
 301			return a.messages.Update(genCtx, *currentAssistant)
 302		},
 303		OnToolInputStart: func(id string, toolName string) error {
 304			toolCall := message.ToolCall{
 305				ID:               id,
 306				Name:             toolName,
 307				ProviderExecuted: false,
 308				Finished:         false,
 309			}
 310			currentAssistant.AddToolCall(toolCall)
 311			return a.messages.Update(genCtx, *currentAssistant)
 312		},
 313		OnRetry: func(err *fantasy.ProviderError, delay time.Duration) {
 314			// TODO: implement
 315		},
 316		OnToolCall: func(tc fantasy.ToolCallContent) error {
 317			toolCall := message.ToolCall{
 318				ID:               tc.ToolCallID,
 319				Name:             tc.ToolName,
 320				Input:            tc.Input,
 321				ProviderExecuted: false,
 322				Finished:         true,
 323			}
 324			currentAssistant.AddToolCall(toolCall)
 325			return a.messages.Update(genCtx, *currentAssistant)
 326		},
 327		OnToolResult: func(result fantasy.ToolResultContent) error {
 328			toolResult := a.convertToToolResult(result)
 329			_, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
 330				Role: message.Tool,
 331				Parts: []message.ContentPart{
 332					toolResult,
 333				},
 334			})
 335			return createMsgErr
 336		},
 337		OnStepFinish: func(stepResult fantasy.StepResult) error {
 338			finishReason := message.FinishReasonUnknown
 339			switch stepResult.FinishReason {
 340			case fantasy.FinishReasonLength:
 341				finishReason = message.FinishReasonMaxTokens
 342			case fantasy.FinishReasonStop:
 343				finishReason = message.FinishReasonEndTurn
 344			case fantasy.FinishReasonToolCalls:
 345				finishReason = message.FinishReasonToolUse
 346			}
 347			currentAssistant.AddFinish(finishReason, "", "")
 348			a.updateSessionUsage(a.largeModel, &currentSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
 349			sessionLock.Lock()
 350			_, sessionErr := a.sessions.Save(genCtx, currentSession)
 351			sessionLock.Unlock()
 352			if sessionErr != nil {
 353				return sessionErr
 354			}
 355			return a.messages.Update(genCtx, *currentAssistant)
 356		},
 357		StopWhen: []fantasy.StopCondition{
 358			func(_ []fantasy.StepResult) bool {
 359				cw := int64(a.largeModel.CatwalkCfg.ContextWindow)
 360				tokens := currentSession.CompletionTokens + currentSession.PromptTokens
 361				remaining := cw - tokens
 362				var threshold int64
 363				if cw > 200_000 {
 364					threshold = 20_000
 365				} else {
 366					threshold = int64(float64(cw) * 0.2)
 367				}
 368				if (remaining <= threshold) && !a.disableAutoSummarize {
 369					shouldSummarize = true
 370					return true
 371				}
 372				return false
 373			},
 374		},
 375	})
 376
 377	a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))
 378
 379	if err != nil {
 380		isCancelErr := errors.Is(err, context.Canceled)
 381		isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
 382		if currentAssistant == nil {
 383			return result, err
 384		}
 385		// Ensure we finish thinking on error to close the reasoning state.
 386		currentAssistant.FinishThinking()
 387		toolCalls := currentAssistant.ToolCalls()
 388		// INFO: we use the parent context here because the genCtx has been cancelled.
 389		msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
 390		if createErr != nil {
 391			return nil, createErr
 392		}
 393		for _, tc := range toolCalls {
 394			if !tc.Finished {
 395				tc.Finished = true
 396				tc.Input = "{}"
 397				currentAssistant.AddToolCall(tc)
 398				updateErr := a.messages.Update(ctx, *currentAssistant)
 399				if updateErr != nil {
 400					return nil, updateErr
 401				}
 402			}
 403
 404			found := false
 405			for _, msg := range msgs {
 406				if msg.Role == message.Tool {
 407					for _, tr := range msg.ToolResults() {
 408						if tr.ToolCallID == tc.ID {
 409							found = true
 410							break
 411						}
 412					}
 413				}
 414				if found {
 415					break
 416				}
 417			}
 418			if found {
 419				continue
 420			}
 421			content := "There was an error while executing the tool"
 422			if isCancelErr {
 423				content = "Tool execution canceled by user"
 424			} else if isPermissionErr {
 425				content = "User denied permission"
 426			}
 427			toolResult := message.ToolResult{
 428				ToolCallID: tc.ID,
 429				Name:       tc.Name,
 430				Content:    content,
 431				IsError:    true,
 432			}
 433			_, createErr = a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
 434				Role: message.Tool,
 435				Parts: []message.ContentPart{
 436					toolResult,
 437				},
 438			})
 439			if createErr != nil {
 440				return nil, createErr
 441			}
 442		}
 443		var fantasyErr *fantasy.Error
 444		var providerErr *fantasy.ProviderError
 445		const defaultTitle = "Provider Error"
 446		if isCancelErr {
 447			currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
 448		} else if isPermissionErr {
 449			currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "")
 450		} else if errors.As(err, &providerErr) {
 451			currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message)
 452		} else if errors.As(err, &fantasyErr) {
 453			currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(fantasyErr.Title), defaultTitle), fantasyErr.Message)
 454		} else {
 455			currentAssistant.AddFinish(message.FinishReasonError, defaultTitle, err.Error())
 456		}
 457		// Note: we use the parent context here because the genCtx has been
 458		// cancelled.
 459		updateErr := a.messages.Update(ctx, *currentAssistant)
 460		if updateErr != nil {
 461			return nil, updateErr
 462		}
 463		return nil, err
 464	}
 465	wg.Wait()
 466
 467	// Send notification that agent has finished its turn (skip for nested/non-interactive sessions).
 468	if !call.NonInteractive {
 469		notifBody := fmt.Sprintf("Agent's turn completed in \"%s\"", currentSession.Title)
 470		_ = notification.Send("Crush is waiting...", notifBody)
 471	}
 472
 473	if shouldSummarize {
 474		a.activeRequests.Del(call.SessionID)
 475		if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
 476			return nil, summarizeErr
 477		}
 478		// If the agent wasn't done...
 479		if len(currentAssistant.ToolCalls()) > 0 {
 480			existing, ok := a.messageQueue.Get(call.SessionID)
 481			if !ok {
 482				existing = []SessionAgentCall{}
 483			}
 484			call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
 485			existing = append(existing, call)
 486			a.messageQueue.Set(call.SessionID, existing)
 487		}
 488	}
 489
 490	// Release active request before processing queued messages.
 491	a.activeRequests.Del(call.SessionID)
 492	cancel()
 493
 494	queuedMessages, ok := a.messageQueue.Get(call.SessionID)
 495	if !ok || len(queuedMessages) == 0 {
 496		return result, err
 497	}
 498	// There are queued messages restart the loop.
 499	firstQueuedMessage := queuedMessages[0]
 500	a.messageQueue.Set(call.SessionID, queuedMessages[1:])
 501	return a.Run(ctx, firstQueuedMessage)
 502}
 503
 504func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
 505	if a.IsSessionBusy(sessionID) {
 506		return ErrSessionBusy
 507	}
 508
 509	currentSession, err := a.sessions.Get(ctx, sessionID)
 510	if err != nil {
 511		return fmt.Errorf("failed to get session: %w", err)
 512	}
 513	msgs, err := a.getSessionMessages(ctx, currentSession)
 514	if err != nil {
 515		return err
 516	}
 517	if len(msgs) == 0 {
 518		// Nothing to summarize.
 519		return nil
 520	}
 521
 522	aiMsgs, _ := a.preparePrompt(msgs)
 523
 524	genCtx, cancel := context.WithCancel(ctx)
 525	a.activeRequests.Set(sessionID, cancel)
 526	defer a.activeRequests.Del(sessionID)
 527	defer cancel()
 528
 529	agent := fantasy.NewAgent(a.largeModel.Model,
 530		fantasy.WithSystemPrompt(string(summaryPrompt)),
 531	)
 532	summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
 533		Role:             message.Assistant,
 534		Model:            a.largeModel.Model.Model(),
 535		Provider:         a.largeModel.Model.Provider(),
 536		IsSummaryMessage: true,
 537	})
 538	if err != nil {
 539		return err
 540	}
 541
 542	resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
 543		Prompt:          "Provide a detailed summary of our conversation above.",
 544		Messages:        aiMsgs,
 545		ProviderOptions: opts,
 546		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
 547			prepared.Messages = options.Messages
 548			if a.systemPromptPrefix != "" {
 549				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
 550			}
 551			return callContext, prepared, nil
 552		},
 553		OnReasoningDelta: func(id string, text string) error {
 554			summaryMessage.AppendReasoningContent(text)
 555			return a.messages.Update(genCtx, summaryMessage)
 556		},
 557		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
 558			// Handle anthropic signature.
 559			if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
 560				if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
 561					summaryMessage.AppendReasoningSignature(signature.Signature)
 562				}
 563			}
 564			summaryMessage.FinishThinking()
 565			return a.messages.Update(genCtx, summaryMessage)
 566		},
 567		OnTextDelta: func(id, text string) error {
 568			summaryMessage.AppendContent(text)
 569			return a.messages.Update(genCtx, summaryMessage)
 570		},
 571	})
 572	if err != nil {
 573		isCancelErr := errors.Is(err, context.Canceled)
 574		if isCancelErr {
 575			// User cancelled summarize we need to remove the summary message.
 576			deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
 577			return deleteErr
 578		}
 579		return err
 580	}
 581
 582	summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
 583	err = a.messages.Update(genCtx, summaryMessage)
 584	if err != nil {
 585		return err
 586	}
 587
 588	var openrouterCost *float64
 589	for _, step := range resp.Steps {
 590		stepCost := a.openrouterCost(step.ProviderMetadata)
 591		if stepCost != nil {
 592			newCost := *stepCost
 593			if openrouterCost != nil {
 594				newCost += *openrouterCost
 595			}
 596			openrouterCost = &newCost
 597		}
 598	}
 599
 600	a.updateSessionUsage(a.largeModel, &currentSession, resp.TotalUsage, openrouterCost)
 601
 602	// Just in case, get just the last usage info.
 603	usage := resp.Response.Usage
 604	currentSession.SummaryMessageID = summaryMessage.ID
 605	currentSession.CompletionTokens = usage.OutputTokens
 606	currentSession.PromptTokens = 0
 607	_, err = a.sessions.Save(genCtx, currentSession)
 608	return err
 609}
 610
 611func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
 612	if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
 613		return fantasy.ProviderOptions{}
 614	}
 615	return fantasy.ProviderOptions{
 616		anthropic.Name: &anthropic.ProviderCacheControlOptions{
 617			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
 618		},
 619		bedrock.Name: &anthropic.ProviderCacheControlOptions{
 620			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
 621		},
 622	}
 623}
 624
 625func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
 626	var attachmentParts []message.ContentPart
 627	for _, attachment := range call.Attachments {
 628		attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
 629	}
 630	parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
 631	parts = append(parts, attachmentParts...)
 632	msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
 633		Role:  message.User,
 634		Parts: parts,
 635	})
 636	if err != nil {
 637		return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
 638	}
 639	return msg, nil
 640}
 641
 642func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
 643	var history []fantasy.Message
 644	for _, m := range msgs {
 645		if len(m.Parts) == 0 {
 646			continue
 647		}
 648		// Assistant message without content or tool calls (cancelled before it
 649		// returned anything).
 650		if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
 651			continue
 652		}
 653		history = append(history, m.ToAIMessage()...)
 654	}
 655
 656	var files []fantasy.FilePart
 657	for _, attachment := range attachments {
 658		files = append(files, fantasy.FilePart{
 659			Filename:  attachment.FileName,
 660			Data:      attachment.Content,
 661			MediaType: attachment.MimeType,
 662		})
 663	}
 664
 665	return history, files
 666}
 667
 668func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
 669	msgs, err := a.messages.List(ctx, session.ID)
 670	if err != nil {
 671		return nil, fmt.Errorf("failed to list messages: %w", err)
 672	}
 673
 674	if session.SummaryMessageID != "" {
 675		summaryMsgInex := -1
 676		for i, msg := range msgs {
 677			if msg.ID == session.SummaryMessageID {
 678				summaryMsgInex = i
 679				break
 680			}
 681		}
 682		if summaryMsgInex != -1 {
 683			msgs = msgs[summaryMsgInex:]
 684			msgs[0].Role = message.User
 685		}
 686	}
 687	return msgs, nil
 688}
 689
 690func (a *sessionAgent) generateTitle(ctx context.Context, session *session.Session, prompt string) {
 691	if prompt == "" {
 692		return
 693	}
 694
 695	var maxOutput int64 = 40
 696	if a.smallModel.CatwalkCfg.CanReason {
 697		maxOutput = a.smallModel.CatwalkCfg.DefaultMaxTokens
 698	}
 699
 700	agent := fantasy.NewAgent(a.smallModel.Model,
 701		fantasy.WithSystemPrompt(string(titlePrompt)+"\n /no_think"),
 702		fantasy.WithMaxOutputTokens(maxOutput),
 703	)
 704
 705	resp, err := agent.Stream(ctx, fantasy.AgentStreamCall{
 706		Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", prompt),
 707		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
 708			prepared.Messages = options.Messages
 709			if a.systemPromptPrefix != "" {
 710				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
 711			}
 712			return callContext, prepared, nil
 713		},
 714	})
 715	if err != nil {
 716		slog.Error("error generating title", "err", err)
 717		return
 718	}
 719
 720	title := resp.Response.Content.Text()
 721
 722	title = strings.ReplaceAll(title, "\n", " ")
 723
 724	// Remove thinking tags if present.
 725	if idx := strings.Index(title, "</think>"); idx > 0 {
 726		title = title[idx+len("</think>"):]
 727	}
 728
 729	title = strings.TrimSpace(title)
 730	if title == "" {
 731		slog.Warn("failed to generate title", "warn", "empty title")
 732		return
 733	}
 734
 735	session.Title = title
 736
 737	var openrouterCost *float64
 738	for _, step := range resp.Steps {
 739		stepCost := a.openrouterCost(step.ProviderMetadata)
 740		if stepCost != nil {
 741			newCost := *stepCost
 742			if openrouterCost != nil {
 743				newCost += *openrouterCost
 744			}
 745			openrouterCost = &newCost
 746		}
 747	}
 748
 749	a.updateSessionUsage(a.smallModel, session, resp.TotalUsage, openrouterCost)
 750	_, saveErr := a.sessions.Save(ctx, *session)
 751	if saveErr != nil {
 752		slog.Error("failed to save session title & usage", "error", saveErr)
 753		return
 754	}
 755}
 756
 757func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
 758	openrouterMetadata, ok := metadata[openrouter.Name]
 759	if !ok {
 760		return nil
 761	}
 762
 763	opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
 764	if !ok {
 765		return nil
 766	}
 767	return &opts.Usage.Cost
 768}
 769
 770func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
 771	modelConfig := model.CatwalkCfg
 772	cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
 773		modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
 774		modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
 775		modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
 776
 777	if a.isClaudeCode() {
 778		cost = 0
 779	}
 780
 781	a.eventTokensUsed(session.ID, model, usage, cost)
 782
 783	if overrideCost != nil {
 784		session.Cost += *overrideCost
 785	} else {
 786		session.Cost += cost
 787	}
 788
 789	session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
 790	session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
 791}
 792
 793func (a *sessionAgent) Cancel(sessionID string) {
 794	// Cancel regular requests.
 795	if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
 796		slog.Info("Request cancellation initiated", "session_id", sessionID)
 797		cancel()
 798	}
 799
 800	// Also check for summarize requests.
 801	if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
 802		slog.Info("Summarize cancellation initiated", "session_id", sessionID)
 803		cancel()
 804	}
 805
 806	if a.QueuedPrompts(sessionID) > 0 {
 807		slog.Info("Clearing queued prompts", "session_id", sessionID)
 808		a.messageQueue.Del(sessionID)
 809	}
 810}
 811
 812func (a *sessionAgent) ClearQueue(sessionID string) {
 813	if a.QueuedPrompts(sessionID) > 0 {
 814		slog.Info("Clearing queued prompts", "session_id", sessionID)
 815		a.messageQueue.Del(sessionID)
 816	}
 817}
 818
 819func (a *sessionAgent) CancelAll() {
 820	if !a.IsBusy() {
 821		return
 822	}
 823	for key := range a.activeRequests.Seq2() {
 824		a.Cancel(key) // key is sessionID
 825	}
 826
 827	timeout := time.After(5 * time.Second)
 828	for a.IsBusy() {
 829		select {
 830		case <-timeout:
 831			return
 832		default:
 833			time.Sleep(200 * time.Millisecond)
 834		}
 835	}
 836}
 837
 838func (a *sessionAgent) IsBusy() bool {
 839	var busy bool
 840	for cancelFunc := range a.activeRequests.Seq() {
 841		if cancelFunc != nil {
 842			busy = true
 843			break
 844		}
 845	}
 846	return busy
 847}
 848
 849func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
 850	_, busy := a.activeRequests.Get(sessionID)
 851	return busy
 852}
 853
 854func (a *sessionAgent) QueuedPrompts(sessionID string) int {
 855	l, ok := a.messageQueue.Get(sessionID)
 856	if !ok {
 857		return 0
 858	}
 859	return len(l)
 860}
 861
 862func (a *sessionAgent) SetModels(large Model, small Model) {
 863	a.largeModel = large
 864	a.smallModel = small
 865}
 866
 867func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
 868	a.tools = tools
 869}
 870
 871func (a *sessionAgent) Model() Model {
 872	return a.largeModel
 873}
 874
 875func (a *sessionAgent) promptPrefix() string {
 876	if a.isClaudeCode() {
 877		return "You are Claude Code, Anthropic's official CLI for Claude."
 878	}
 879	return a.systemPromptPrefix
 880}
 881
 882func (a *sessionAgent) isClaudeCode() bool {
 883	cfg := config.Get()
 884	pc, ok := cfg.Providers.Get(a.largeModel.ModelCfg.Provider)
 885	return ok && pc.ID == string(catwalk.InferenceProviderAnthropic) && pc.OAuthToken != nil
 886}
 887
 888// convertToToolResult converts a fantasy tool result to a message tool result.
 889func (a *sessionAgent) convertToToolResult(result fantasy.ToolResultContent) message.ToolResult {
 890	baseResult := message.ToolResult{
 891		ToolCallID: result.ToolCallID,
 892		Name:       result.ToolName,
 893		Metadata:   result.ClientMetadata,
 894	}
 895
 896	switch result.Result.GetType() {
 897	case fantasy.ToolResultContentTypeText:
 898		if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result); ok {
 899			baseResult.Content = r.Text
 900		}
 901	case fantasy.ToolResultContentTypeError:
 902		if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result); ok {
 903			baseResult.Content = r.Error.Error()
 904			baseResult.IsError = true
 905		}
 906	case fantasy.ToolResultContentTypeMedia:
 907		if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](result.Result); ok {
 908			content := r.Text
 909			if content == "" {
 910				content = fmt.Sprintf("Loaded %s content", r.MediaType)
 911			}
 912			baseResult.Content = content
 913			baseResult.Data = r.Data
 914			baseResult.MIMEType = r.MediaType
 915		}
 916	}
 917
 918	return baseResult
 919}
 920
 921// workaroundProviderMediaLimitations converts media content in tool results to
 922// user messages for providers that don't natively support images in tool results.
 923//
 924// Problem: OpenAI, Google, OpenRouter, and other OpenAI-compatible providers
 925// don't support sending images/media in tool result messages - they only accept
 926// text in tool results. However, they DO support images in user messages.
 927//
 928// If we send media in tool results to these providers, the API returns an error.
 929//
 930// Solution: For these providers, we:
 931//  1. Replace the media in the tool result with a text placeholder
 932//  2. Inject a user message immediately after with the image as a file attachment
 933//  3. This maintains the tool execution flow while working around API limitations
 934//
 935// Anthropic and Bedrock support images natively in tool results, so we skip
 936// this workaround for them.
 937//
 938// Example transformation:
 939//
 940//	BEFORE: [tool result: image data]
 941//	AFTER:  [tool result: "Image loaded - see attached"], [user: image attachment]
 942func (a *sessionAgent) workaroundProviderMediaLimitations(messages []fantasy.Message) []fantasy.Message {
 943	providerSupportsMedia := a.largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderAnthropic) ||
 944		a.largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderBedrock)
 945
 946	if providerSupportsMedia {
 947		return messages
 948	}
 949
 950	convertedMessages := make([]fantasy.Message, 0, len(messages))
 951
 952	for _, msg := range messages {
 953		if msg.Role != fantasy.MessageRoleTool {
 954			convertedMessages = append(convertedMessages, msg)
 955			continue
 956		}
 957
 958		textParts := make([]fantasy.MessagePart, 0, len(msg.Content))
 959		var mediaFiles []fantasy.FilePart
 960
 961		for _, part := range msg.Content {
 962			toolResult, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
 963			if !ok {
 964				textParts = append(textParts, part)
 965				continue
 966			}
 967
 968			if media, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](toolResult.Output); ok {
 969				decoded, err := base64.StdEncoding.DecodeString(media.Data)
 970				if err != nil {
 971					slog.Warn("failed to decode media data", "error", err)
 972					textParts = append(textParts, part)
 973					continue
 974				}
 975
 976				mediaFiles = append(mediaFiles, fantasy.FilePart{
 977					Data:      decoded,
 978					MediaType: media.MediaType,
 979					Filename:  fmt.Sprintf("tool-result-%s", toolResult.ToolCallID),
 980				})
 981
 982				textParts = append(textParts, fantasy.ToolResultPart{
 983					ToolCallID: toolResult.ToolCallID,
 984					Output: fantasy.ToolResultOutputContentText{
 985						Text: "[Image/media content loaded - see attached file]",
 986					},
 987					ProviderOptions: toolResult.ProviderOptions,
 988				})
 989			} else {
 990				textParts = append(textParts, part)
 991			}
 992		}
 993
 994		convertedMessages = append(convertedMessages, fantasy.Message{
 995			Role:    fantasy.MessageRoleTool,
 996			Content: textParts,
 997		})
 998
 999		if len(mediaFiles) > 0 {
1000			convertedMessages = append(convertedMessages, fantasy.NewUserMessage(
1001				"Here is the media content from the tool result:",
1002				mediaFiles...,
1003			))
1004		}
1005	}
1006
1007	return convertedMessages
1008}