coordinator.go

   1package agent
   2
   3import (
   4	"bytes"
   5	"cmp"
   6	"context"
   7	"encoding/json"
   8	"errors"
   9	"fmt"
  10	"io"
  11	"log/slog"
  12	"maps"
  13	"net/http"
  14	"os"
  15	"slices"
  16	"strings"
  17
  18	"charm.land/catwalk/pkg/catwalk"
  19	"charm.land/fantasy"
  20	"github.com/charmbracelet/crush/internal/agent/hyper"
  21	"github.com/charmbracelet/crush/internal/agent/notify"
  22	"github.com/charmbracelet/crush/internal/agent/prompt"
  23	"github.com/charmbracelet/crush/internal/agent/tools"
  24	"github.com/charmbracelet/crush/internal/config"
  25	"github.com/charmbracelet/crush/internal/filetracker"
  26	"github.com/charmbracelet/crush/internal/history"
  27	"github.com/charmbracelet/crush/internal/log"
  28	"github.com/charmbracelet/crush/internal/lsp"
  29	"github.com/charmbracelet/crush/internal/message"
  30	"github.com/charmbracelet/crush/internal/oauth/copilot"
  31	"github.com/charmbracelet/crush/internal/permission"
  32	"github.com/charmbracelet/crush/internal/pubsub"
  33	"github.com/charmbracelet/crush/internal/session"
  34	"golang.org/x/sync/errgroup"
  35
  36	"charm.land/fantasy/providers/anthropic"
  37	"charm.land/fantasy/providers/azure"
  38	"charm.land/fantasy/providers/bedrock"
  39	"charm.land/fantasy/providers/google"
  40	"charm.land/fantasy/providers/openai"
  41	"charm.land/fantasy/providers/openaicompat"
  42	"charm.land/fantasy/providers/openrouter"
  43	"charm.land/fantasy/providers/vercel"
  44	openaisdk "github.com/openai/openai-go/v2/option"
  45	"github.com/qjebbs/go-jsons"
  46)
  47
  48// Coordinator errors.
  49var (
  50	errCoderAgentNotConfigured         = errors.New("coder agent not configured")
  51	errModelProviderNotConfigured      = errors.New("model provider not configured")
  52	errLargeModelNotSelected           = errors.New("large model not selected")
  53	errSmallModelNotSelected           = errors.New("small model not selected")
  54	errLargeModelProviderNotConfigured = errors.New("large model provider not configured")
  55	errSmallModelProviderNotConfigured = errors.New("small model provider not configured")
  56	errLargeModelNotFound              = errors.New("large model not found in provider config")
  57	errSmallModelNotFound              = errors.New("small model not found in provider config")
  58)
  59
  60type Coordinator interface {
  61	// INFO: (kujtim) this is not used yet we will use this when we have multiple agents
  62	// SetMainAgent(string)
  63	Run(ctx context.Context, sessionID, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error)
  64	Cancel(sessionID string)
  65	CancelAll()
  66	IsSessionBusy(sessionID string) bool
  67	IsBusy() bool
  68	QueuedPrompts(sessionID string) int
  69	QueuedPromptsList(sessionID string) []string
  70	ClearQueue(sessionID string)
  71	Summarize(context.Context, string) error
  72	Model() Model
  73	UpdateModels(ctx context.Context) error
  74}
  75
  76type coordinator struct {
  77	cfg         *config.ConfigStore
  78	sessions    session.Service
  79	messages    message.Service
  80	permissions permission.Service
  81	history     history.Service
  82	filetracker filetracker.Service
  83	lspManager  *lsp.Manager
  84	notify      pubsub.Publisher[notify.Notification]
  85
  86	currentAgent SessionAgent
  87	agents       map[string]SessionAgent
  88
  89	readyWg errgroup.Group
  90}
  91
  92func NewCoordinator(
  93	ctx context.Context,
  94	cfg *config.ConfigStore,
  95	sessions session.Service,
  96	messages message.Service,
  97	permissions permission.Service,
  98	history history.Service,
  99	filetracker filetracker.Service,
 100	lspManager *lsp.Manager,
 101	notify pubsub.Publisher[notify.Notification],
 102) (Coordinator, error) {
 103	c := &coordinator{
 104		cfg:         cfg,
 105		sessions:    sessions,
 106		messages:    messages,
 107		permissions: permissions,
 108		history:     history,
 109		filetracker: filetracker,
 110		lspManager:  lspManager,
 111		notify:      notify,
 112		agents:      make(map[string]SessionAgent),
 113	}
 114
 115	agentCfg, ok := cfg.Config().Agents[config.AgentCoder]
 116	if !ok {
 117		return nil, errCoderAgentNotConfigured
 118	}
 119
 120	// TODO: make this dynamic when we support multiple agents
 121	prompt, err := coderPrompt(prompt.WithWorkingDir(c.cfg.WorkingDir()))
 122	if err != nil {
 123		return nil, err
 124	}
 125
 126	agent, err := c.buildAgent(ctx, prompt, agentCfg, false)
 127	if err != nil {
 128		return nil, err
 129	}
 130	c.currentAgent = agent
 131	c.agents[config.AgentCoder] = agent
 132	return c, nil
 133}
 134
 135// Run implements Coordinator.
 136func (c *coordinator) Run(ctx context.Context, sessionID string, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) {
 137	if err := c.readyWg.Wait(); err != nil {
 138		return nil, err
 139	}
 140
 141	// refresh models before each run
 142	if err := c.UpdateModels(ctx); err != nil {
 143		return nil, fmt.Errorf("failed to update models: %w", err)
 144	}
 145
 146	model := c.currentAgent.Model()
 147	maxTokens := model.CatwalkCfg.DefaultMaxTokens
 148	if model.ModelCfg.MaxTokens != 0 {
 149		maxTokens = model.ModelCfg.MaxTokens
 150	}
 151
 152	if !model.CatwalkCfg.SupportsImages && attachments != nil {
 153		// filter out image attachments
 154		filteredAttachments := make([]message.Attachment, 0, len(attachments))
 155		for _, att := range attachments {
 156			if att.IsText() {
 157				filteredAttachments = append(filteredAttachments, att)
 158			}
 159		}
 160		attachments = filteredAttachments
 161	}
 162
 163	providerCfg, ok := c.cfg.Config().Providers.Get(model.ModelCfg.Provider)
 164	if !ok {
 165		return nil, errModelProviderNotConfigured
 166	}
 167
 168	mergedOptions, temp, topP, topK, freqPenalty, presPenalty := mergeCallOptions(model, providerCfg)
 169
 170	if providerCfg.OAuthToken != nil && providerCfg.OAuthToken.IsExpired() {
 171		slog.Debug("Token needs to be refreshed", "provider", providerCfg.ID)
 172		if err := c.refreshOAuth2Token(ctx, providerCfg); err != nil {
 173			return nil, err
 174		}
 175	}
 176
 177	run := func() (*fantasy.AgentResult, error) {
 178		return c.currentAgent.Run(ctx, SessionAgentCall{
 179			SessionID:        sessionID,
 180			Prompt:           prompt,
 181			Attachments:      attachments,
 182			MaxOutputTokens:  maxTokens,
 183			ProviderOptions:  mergedOptions,
 184			Temperature:      temp,
 185			TopP:             topP,
 186			TopK:             topK,
 187			FrequencyPenalty: freqPenalty,
 188			PresencePenalty:  presPenalty,
 189		})
 190	}
 191	result, originalErr := run()
 192
 193	if c.isUnauthorized(originalErr) {
 194		switch {
 195		case providerCfg.OAuthToken != nil:
 196			slog.Debug("Received 401. Refreshing token and retrying", "provider", providerCfg.ID)
 197			if err := c.refreshOAuth2Token(ctx, providerCfg); err != nil {
 198				return nil, originalErr
 199			}
 200			slog.Debug("Retrying request with refreshed OAuth token", "provider", providerCfg.ID)
 201			return run()
 202		case strings.Contains(providerCfg.APIKeyTemplate, "$"):
 203			slog.Debug("Received 401. Refreshing API Key template and retrying", "provider", providerCfg.ID)
 204			if err := c.refreshApiKeyTemplate(ctx, providerCfg); err != nil {
 205				return nil, originalErr
 206			}
 207			slog.Debug("Retrying request with refreshed API key", "provider", providerCfg.ID)
 208			return run()
 209		}
 210	}
 211
 212	return result, originalErr
 213}
 214
 215func getProviderOptions(model Model, providerCfg config.ProviderConfig) fantasy.ProviderOptions {
 216	options := fantasy.ProviderOptions{}
 217
 218	cfgOpts := []byte("{}")
 219	providerCfgOpts := []byte("{}")
 220	catwalkOpts := []byte("{}")
 221
 222	if model.ModelCfg.ProviderOptions != nil {
 223		data, err := json.Marshal(model.ModelCfg.ProviderOptions)
 224		if err == nil {
 225			cfgOpts = data
 226		}
 227	}
 228
 229	if providerCfg.ProviderOptions != nil {
 230		data, err := json.Marshal(providerCfg.ProviderOptions)
 231		if err == nil {
 232			providerCfgOpts = data
 233		}
 234	}
 235
 236	if model.CatwalkCfg.Options.ProviderOptions != nil {
 237		data, err := json.Marshal(model.CatwalkCfg.Options.ProviderOptions)
 238		if err == nil {
 239			catwalkOpts = data
 240		}
 241	}
 242
 243	readers := []io.Reader{
 244		bytes.NewReader(catwalkOpts),
 245		bytes.NewReader(providerCfgOpts),
 246		bytes.NewReader(cfgOpts),
 247	}
 248
 249	got, err := jsons.Merge(readers)
 250	if err != nil {
 251		slog.Error("Could not merge call config", "err", err)
 252		return options
 253	}
 254
 255	mergedOptions := make(map[string]any)
 256
 257	err = json.Unmarshal([]byte(got), &mergedOptions)
 258	if err != nil {
 259		slog.Error("Could not create config for call", "err", err)
 260		return options
 261	}
 262
 263	providerType := providerCfg.Type
 264	if providerType == "hyper" {
 265		if strings.Contains(model.CatwalkCfg.ID, "claude") {
 266			providerType = anthropic.Name
 267		} else if strings.Contains(model.CatwalkCfg.ID, "gpt") {
 268			providerType = openai.Name
 269		} else if strings.Contains(model.CatwalkCfg.ID, "gemini") {
 270			providerType = google.Name
 271		} else {
 272			providerType = openaicompat.Name
 273		}
 274	}
 275
 276	switch providerType {
 277	case openai.Name, azure.Name:
 278		_, hasReasoningEffort := mergedOptions["reasoning_effort"]
 279		if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
 280			mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort
 281		}
 282		if openai.IsResponsesModel(model.CatwalkCfg.ID) {
 283			if openai.IsResponsesReasoningModel(model.CatwalkCfg.ID) {
 284				mergedOptions["reasoning_summary"] = "auto"
 285				mergedOptions["include"] = []openai.IncludeType{openai.IncludeReasoningEncryptedContent}
 286			}
 287			parsed, err := openai.ParseResponsesOptions(mergedOptions)
 288			if err == nil {
 289				options[openai.Name] = parsed
 290			}
 291		} else {
 292			parsed, err := openai.ParseOptions(mergedOptions)
 293			if err == nil {
 294				options[openai.Name] = parsed
 295			}
 296		}
 297	case anthropic.Name:
 298		var (
 299			_, hasEffort = mergedOptions["effort"]
 300			_, hasThink  = mergedOptions["thinking"]
 301		)
 302		switch {
 303		case !hasEffort && model.ModelCfg.ReasoningEffort != "":
 304			mergedOptions["effort"] = model.ModelCfg.ReasoningEffort
 305		case !hasThink && model.ModelCfg.Think:
 306			mergedOptions["thinking"] = map[string]any{"budget_tokens": 2000}
 307		}
 308		parsed, err := anthropic.ParseOptions(mergedOptions)
 309		if err == nil {
 310			options[anthropic.Name] = parsed
 311		}
 312
 313	case openrouter.Name:
 314		_, hasReasoning := mergedOptions["reasoning"]
 315		if !hasReasoning && model.ModelCfg.ReasoningEffort != "" {
 316			mergedOptions["reasoning"] = map[string]any{
 317				"enabled": true,
 318				"effort":  model.ModelCfg.ReasoningEffort,
 319			}
 320		}
 321		parsed, err := openrouter.ParseOptions(mergedOptions)
 322		if err == nil {
 323			options[openrouter.Name] = parsed
 324		}
 325	case vercel.Name:
 326		_, hasReasoning := mergedOptions["reasoning"]
 327		if !hasReasoning && model.ModelCfg.ReasoningEffort != "" {
 328			mergedOptions["reasoning"] = map[string]any{
 329				"enabled": true,
 330				"effort":  model.ModelCfg.ReasoningEffort,
 331			}
 332		}
 333		parsed, err := vercel.ParseOptions(mergedOptions)
 334		if err == nil {
 335			options[vercel.Name] = parsed
 336		}
 337	case google.Name:
 338		_, hasReasoning := mergedOptions["thinking_config"]
 339		if !hasReasoning {
 340			if strings.HasPrefix(model.CatwalkCfg.ID, "gemini-2") {
 341				mergedOptions["thinking_config"] = map[string]any{
 342					"thinking_budget":  2000,
 343					"include_thoughts": true,
 344				}
 345			} else {
 346				mergedOptions["thinking_config"] = map[string]any{
 347					"thinking_level":   model.ModelCfg.ReasoningEffort,
 348					"include_thoughts": true,
 349				}
 350			}
 351		}
 352		parsed, err := google.ParseOptions(mergedOptions)
 353		if err == nil {
 354			options[google.Name] = parsed
 355		}
 356	case openaicompat.Name:
 357		_, hasReasoningEffort := mergedOptions["reasoning_effort"]
 358		if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
 359			mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort
 360		}
 361		parsed, err := openaicompat.ParseOptions(mergedOptions)
 362		if err == nil {
 363			options[openaicompat.Name] = parsed
 364		}
 365	}
 366
 367	return options
 368}
 369
 370func mergeCallOptions(model Model, cfg config.ProviderConfig) (fantasy.ProviderOptions, *float64, *float64, *int64, *float64, *float64) {
 371	modelOptions := getProviderOptions(model, cfg)
 372	temp := cmp.Or(model.ModelCfg.Temperature, model.CatwalkCfg.Options.Temperature)
 373	topP := cmp.Or(model.ModelCfg.TopP, model.CatwalkCfg.Options.TopP)
 374	topK := cmp.Or(model.ModelCfg.TopK, model.CatwalkCfg.Options.TopK)
 375	freqPenalty := cmp.Or(model.ModelCfg.FrequencyPenalty, model.CatwalkCfg.Options.FrequencyPenalty)
 376	presPenalty := cmp.Or(model.ModelCfg.PresencePenalty, model.CatwalkCfg.Options.PresencePenalty)
 377	return modelOptions, temp, topP, topK, freqPenalty, presPenalty
 378}
 379
 380func (c *coordinator) buildAgent(ctx context.Context, prompt *prompt.Prompt, agent config.Agent, isSubAgent bool) (SessionAgent, error) {
 381	large, small, err := c.buildAgentModels(ctx, isSubAgent)
 382	if err != nil {
 383		return nil, err
 384	}
 385
 386	largeProviderCfg, _ := c.cfg.Config().Providers.Get(large.ModelCfg.Provider)
 387	result := NewSessionAgent(SessionAgentOptions{
 388		LargeModel:           large,
 389		SmallModel:           small,
 390		SystemPromptPrefix:   largeProviderCfg.SystemPromptPrefix,
 391		SystemPrompt:         "",
 392		IsSubAgent:           isSubAgent,
 393		DisableAutoSummarize: c.cfg.Config().Options.DisableAutoSummarize,
 394		IsYolo:               c.permissions.SkipRequests(),
 395		Sessions:             c.sessions,
 396		Messages:             c.messages,
 397		Tools:                nil,
 398		Notify:               c.notify,
 399	})
 400
 401	c.readyWg.Go(func() error {
 402		systemPrompt, err := prompt.Build(ctx, large.Model.Provider(), large.Model.Model(), c.cfg)
 403		if err != nil {
 404			return err
 405		}
 406		result.SetSystemPrompt(systemPrompt)
 407		return nil
 408	})
 409
 410	c.readyWg.Go(func() error {
 411		tools, err := c.buildTools(ctx, agent)
 412		if err != nil {
 413			return err
 414		}
 415		result.SetTools(tools)
 416		return nil
 417	})
 418
 419	return result, nil
 420}
 421
 422func (c *coordinator) buildTools(ctx context.Context, agent config.Agent) ([]fantasy.AgentTool, error) {
 423	var allTools []fantasy.AgentTool
 424	if slices.Contains(agent.AllowedTools, AgentToolName) {
 425		agentTool, err := c.agentTool(ctx)
 426		if err != nil {
 427			return nil, err
 428		}
 429		allTools = append(allTools, agentTool)
 430	}
 431
 432	if slices.Contains(agent.AllowedTools, tools.AgenticFetchToolName) {
 433		agenticFetchTool, err := c.agenticFetchTool(ctx, nil)
 434		if err != nil {
 435			return nil, err
 436		}
 437		allTools = append(allTools, agenticFetchTool)
 438	}
 439
 440	// Get the model name for the agent
 441	modelName := ""
 442	if modelCfg, ok := c.cfg.Config().Models[agent.Model]; ok {
 443		if model := c.cfg.Config().GetModel(modelCfg.Provider, modelCfg.Model); model != nil {
 444			modelName = model.Name
 445		}
 446	}
 447
 448	allTools = append(allTools,
 449		tools.NewBashTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Config().Options.Attribution, modelName),
 450		tools.NewJobOutputTool(),
 451		tools.NewJobKillTool(),
 452		tools.NewDownloadTool(c.permissions, c.cfg.WorkingDir(), nil),
 453		tools.NewEditTool(c.lspManager, c.permissions, c.history, c.filetracker, c.cfg.WorkingDir()),
 454		tools.NewMultiEditTool(c.lspManager, c.permissions, c.history, c.filetracker, c.cfg.WorkingDir()),
 455		tools.NewFetchTool(c.permissions, c.cfg.WorkingDir(), nil),
 456		tools.NewGlobTool(c.cfg.WorkingDir()),
 457		tools.NewGrepTool(c.cfg.WorkingDir(), c.cfg.Config().Tools.Grep),
 458		tools.NewLsTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Config().Tools.Ls),
 459		tools.NewSourcegraphTool(nil),
 460		tools.NewTodosTool(c.sessions),
 461		tools.NewViewTool(c.lspManager, c.permissions, c.filetracker, c.cfg.WorkingDir(), c.cfg.Config().Options.SkillsPaths...),
 462		tools.NewWriteTool(c.lspManager, c.permissions, c.history, c.filetracker, c.cfg.WorkingDir()),
 463	)
 464
 465	// Add LSP tools if user has configured LSPs or auto_lsp is enabled (nil or true).
 466	if len(c.cfg.Config().LSP) > 0 || c.cfg.Config().Options.AutoLSP == nil || *c.cfg.Config().Options.AutoLSP {
 467		allTools = append(allTools, tools.NewDiagnosticsTool(c.lspManager), tools.NewReferencesTool(c.lspManager), tools.NewLSPRestartTool(c.lspManager))
 468	}
 469
 470	if len(c.cfg.Config().MCP) > 0 {
 471		allTools = append(
 472			allTools,
 473			tools.NewListMCPResourcesTool(c.cfg, c.permissions),
 474			tools.NewReadMCPResourceTool(c.cfg, c.permissions),
 475		)
 476	}
 477
 478	var filteredTools []fantasy.AgentTool
 479	for _, tool := range allTools {
 480		if slices.Contains(agent.AllowedTools, tool.Info().Name) {
 481			filteredTools = append(filteredTools, tool)
 482		}
 483	}
 484
 485	for _, tool := range tools.GetMCPTools(c.permissions, c.cfg, c.cfg.WorkingDir()) {
 486		if agent.AllowedMCP == nil {
 487			// No MCP restrictions
 488			filteredTools = append(filteredTools, tool)
 489			continue
 490		}
 491		if len(agent.AllowedMCP) == 0 {
 492			// No MCPs allowed
 493			slog.Debug("No MCPs allowed", "tool", tool.Name(), "agent", agent.Name)
 494			break
 495		}
 496
 497		for mcp, tools := range agent.AllowedMCP {
 498			if mcp != tool.MCP() {
 499				continue
 500			}
 501			if len(tools) == 0 || slices.Contains(tools, tool.MCPToolName()) {
 502				filteredTools = append(filteredTools, tool)
 503				break
 504			}
 505			slog.Debug("MCP not allowed", "tool", tool.Name(), "agent", agent.Name)
 506		}
 507	}
 508	slices.SortFunc(filteredTools, func(a, b fantasy.AgentTool) int {
 509		return strings.Compare(a.Info().Name, b.Info().Name)
 510	})
 511	return filteredTools, nil
 512}
 513
 514// TODO: when we support multiple agents we need to change this so that we pass in the agent specific model config
 515func (c *coordinator) buildAgentModels(ctx context.Context, isSubAgent bool) (Model, Model, error) {
 516	largeModelCfg, ok := c.cfg.Config().Models[config.SelectedModelTypeLarge]
 517	if !ok {
 518		return Model{}, Model{}, errLargeModelNotSelected
 519	}
 520	smallModelCfg, ok := c.cfg.Config().Models[config.SelectedModelTypeSmall]
 521	if !ok {
 522		return Model{}, Model{}, errSmallModelNotSelected
 523	}
 524
 525	largeProviderCfg, ok := c.cfg.Config().Providers.Get(largeModelCfg.Provider)
 526	if !ok {
 527		return Model{}, Model{}, errLargeModelProviderNotConfigured
 528	}
 529
 530	largeProvider, err := c.buildProvider(largeProviderCfg, largeModelCfg, isSubAgent)
 531	if err != nil {
 532		return Model{}, Model{}, err
 533	}
 534
 535	smallProviderCfg, ok := c.cfg.Config().Providers.Get(smallModelCfg.Provider)
 536	if !ok {
 537		return Model{}, Model{}, errSmallModelProviderNotConfigured
 538	}
 539
 540	smallProvider, err := c.buildProvider(smallProviderCfg, smallModelCfg, true)
 541	if err != nil {
 542		return Model{}, Model{}, err
 543	}
 544
 545	var largeCatwalkModel *catwalk.Model
 546	var smallCatwalkModel *catwalk.Model
 547
 548	for _, m := range largeProviderCfg.Models {
 549		if m.ID == largeModelCfg.Model {
 550			largeCatwalkModel = &m
 551		}
 552	}
 553	for _, m := range smallProviderCfg.Models {
 554		if m.ID == smallModelCfg.Model {
 555			smallCatwalkModel = &m
 556		}
 557	}
 558
 559	if largeCatwalkModel == nil {
 560		return Model{}, Model{}, errLargeModelNotFound
 561	}
 562
 563	if smallCatwalkModel == nil {
 564		return Model{}, Model{}, errSmallModelNotFound
 565	}
 566
 567	largeModelID := largeModelCfg.Model
 568	smallModelID := smallModelCfg.Model
 569
 570	if largeModelCfg.Provider == openrouter.Name && isExactoSupported(largeModelID) {
 571		largeModelID += ":exacto"
 572	}
 573
 574	if smallModelCfg.Provider == openrouter.Name && isExactoSupported(smallModelID) {
 575		smallModelID += ":exacto"
 576	}
 577
 578	largeModel, err := largeProvider.LanguageModel(ctx, largeModelID)
 579	if err != nil {
 580		return Model{}, Model{}, err
 581	}
 582	smallModel, err := smallProvider.LanguageModel(ctx, smallModelID)
 583	if err != nil {
 584		return Model{}, Model{}, err
 585	}
 586
 587	return Model{
 588			Model:      largeModel,
 589			CatwalkCfg: *largeCatwalkModel,
 590			ModelCfg:   largeModelCfg,
 591		}, Model{
 592			Model:      smallModel,
 593			CatwalkCfg: *smallCatwalkModel,
 594			ModelCfg:   smallModelCfg,
 595		}, nil
 596}
 597
 598func (c *coordinator) buildAnthropicProvider(baseURL, apiKey string, headers map[string]string, providerID string) (fantasy.Provider, error) {
 599	var opts []anthropic.Option
 600
 601	switch {
 602	case strings.HasPrefix(apiKey, "Bearer "):
 603		// NOTE: Prevent the SDK from picking up the API key from env.
 604		os.Setenv("ANTHROPIC_API_KEY", "")
 605		headers["Authorization"] = apiKey
 606	case providerID == string(catwalk.InferenceProviderMiniMax) || providerID == string(catwalk.InferenceProviderMiniMaxChina):
 607		// NOTE: Prevent the SDK from picking up the API key from env.
 608		os.Setenv("ANTHROPIC_API_KEY", "")
 609		headers["Authorization"] = "Bearer " + apiKey
 610	case apiKey != "":
 611		// X-Api-Key header
 612		opts = append(opts, anthropic.WithAPIKey(apiKey))
 613	}
 614
 615	if len(headers) > 0 {
 616		opts = append(opts, anthropic.WithHeaders(headers))
 617	}
 618
 619	if baseURL != "" {
 620		opts = append(opts, anthropic.WithBaseURL(baseURL))
 621	}
 622
 623	if c.cfg.Config().Options.Debug {
 624		httpClient := log.NewHTTPClient()
 625		opts = append(opts, anthropic.WithHTTPClient(httpClient))
 626	}
 627	return anthropic.New(opts...)
 628}
 629
 630func (c *coordinator) buildOpenaiProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
 631	opts := []openai.Option{
 632		openai.WithAPIKey(apiKey),
 633		openai.WithUseResponsesAPI(),
 634	}
 635	if c.cfg.Config().Options.Debug {
 636		httpClient := log.NewHTTPClient()
 637		opts = append(opts, openai.WithHTTPClient(httpClient))
 638	}
 639	if len(headers) > 0 {
 640		opts = append(opts, openai.WithHeaders(headers))
 641	}
 642	if baseURL != "" {
 643		opts = append(opts, openai.WithBaseURL(baseURL))
 644	}
 645	return openai.New(opts...)
 646}
 647
 648func (c *coordinator) buildOpenrouterProvider(_, apiKey string, headers map[string]string) (fantasy.Provider, error) {
 649	opts := []openrouter.Option{
 650		openrouter.WithAPIKey(apiKey),
 651	}
 652	if c.cfg.Config().Options.Debug {
 653		httpClient := log.NewHTTPClient()
 654		opts = append(opts, openrouter.WithHTTPClient(httpClient))
 655	}
 656	if len(headers) > 0 {
 657		opts = append(opts, openrouter.WithHeaders(headers))
 658	}
 659	return openrouter.New(opts...)
 660}
 661
 662func (c *coordinator) buildVercelProvider(_, apiKey string, headers map[string]string) (fantasy.Provider, error) {
 663	opts := []vercel.Option{
 664		vercel.WithAPIKey(apiKey),
 665	}
 666	if c.cfg.Config().Options.Debug {
 667		httpClient := log.NewHTTPClient()
 668		opts = append(opts, vercel.WithHTTPClient(httpClient))
 669	}
 670	if len(headers) > 0 {
 671		opts = append(opts, vercel.WithHeaders(headers))
 672	}
 673	return vercel.New(opts...)
 674}
 675
 676func (c *coordinator) buildOpenaiCompatProvider(baseURL, apiKey string, headers map[string]string, extraBody map[string]any, providerID string, isSubAgent bool) (fantasy.Provider, error) {
 677	opts := []openaicompat.Option{
 678		openaicompat.WithBaseURL(baseURL),
 679		openaicompat.WithAPIKey(apiKey),
 680	}
 681
 682	// Set HTTP client based on provider and debug mode.
 683	var httpClient *http.Client
 684	if providerID == string(catwalk.InferenceProviderCopilot) {
 685		opts = append(opts, openaicompat.WithUseResponsesAPI())
 686		httpClient = copilot.NewClient(isSubAgent, c.cfg.Config().Options.Debug)
 687	} else if c.cfg.Config().Options.Debug {
 688		httpClient = log.NewHTTPClient()
 689	}
 690	if httpClient != nil {
 691		opts = append(opts, openaicompat.WithHTTPClient(httpClient))
 692	}
 693
 694	if len(headers) > 0 {
 695		opts = append(opts, openaicompat.WithHeaders(headers))
 696	}
 697
 698	for extraKey, extraValue := range extraBody {
 699		opts = append(opts, openaicompat.WithSDKOptions(openaisdk.WithJSONSet(extraKey, extraValue)))
 700	}
 701
 702	return openaicompat.New(opts...)
 703}
 704
 705func (c *coordinator) buildAzureProvider(baseURL, apiKey string, headers map[string]string, options map[string]string) (fantasy.Provider, error) {
 706	opts := []azure.Option{
 707		azure.WithBaseURL(baseURL),
 708		azure.WithAPIKey(apiKey),
 709		azure.WithUseResponsesAPI(),
 710	}
 711	if c.cfg.Config().Options.Debug {
 712		httpClient := log.NewHTTPClient()
 713		opts = append(opts, azure.WithHTTPClient(httpClient))
 714	}
 715	if options == nil {
 716		options = make(map[string]string)
 717	}
 718	if apiVersion, ok := options["apiVersion"]; ok {
 719		opts = append(opts, azure.WithAPIVersion(apiVersion))
 720	}
 721	if len(headers) > 0 {
 722		opts = append(opts, azure.WithHeaders(headers))
 723	}
 724
 725	return azure.New(opts...)
 726}
 727
 728func (c *coordinator) buildBedrockProvider(headers map[string]string) (fantasy.Provider, error) {
 729	var opts []bedrock.Option
 730	if c.cfg.Config().Options.Debug {
 731		httpClient := log.NewHTTPClient()
 732		opts = append(opts, bedrock.WithHTTPClient(httpClient))
 733	}
 734	if len(headers) > 0 {
 735		opts = append(opts, bedrock.WithHeaders(headers))
 736	}
 737	bearerToken := os.Getenv("AWS_BEARER_TOKEN_BEDROCK")
 738	if bearerToken != "" {
 739		opts = append(opts, bedrock.WithAPIKey(bearerToken))
 740	}
 741	return bedrock.New(opts...)
 742}
 743
 744func (c *coordinator) buildGoogleProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
 745	opts := []google.Option{
 746		google.WithBaseURL(baseURL),
 747		google.WithGeminiAPIKey(apiKey),
 748	}
 749	if c.cfg.Config().Options.Debug {
 750		httpClient := log.NewHTTPClient()
 751		opts = append(opts, google.WithHTTPClient(httpClient))
 752	}
 753	if len(headers) > 0 {
 754		opts = append(opts, google.WithHeaders(headers))
 755	}
 756	return google.New(opts...)
 757}
 758
 759func (c *coordinator) buildGoogleVertexProvider(headers map[string]string, options map[string]string) (fantasy.Provider, error) {
 760	opts := []google.Option{}
 761	if c.cfg.Config().Options.Debug {
 762		httpClient := log.NewHTTPClient()
 763		opts = append(opts, google.WithHTTPClient(httpClient))
 764	}
 765	if len(headers) > 0 {
 766		opts = append(opts, google.WithHeaders(headers))
 767	}
 768
 769	project := options["project"]
 770	location := options["location"]
 771
 772	opts = append(opts, google.WithVertex(project, location))
 773
 774	return google.New(opts...)
 775}
 776
 777func (c *coordinator) buildHyperProvider(baseURL, apiKey string) (fantasy.Provider, error) {
 778	opts := []hyper.Option{
 779		hyper.WithBaseURL(baseURL),
 780		hyper.WithAPIKey(apiKey),
 781	}
 782	if c.cfg.Config().Options.Debug {
 783		httpClient := log.NewHTTPClient()
 784		opts = append(opts, hyper.WithHTTPClient(httpClient))
 785	}
 786	return hyper.New(opts...)
 787}
 788
 789func (c *coordinator) isAnthropicThinking(model config.SelectedModel) bool {
 790	if model.Think {
 791		return true
 792	}
 793	opts, err := anthropic.ParseOptions(model.ProviderOptions)
 794	return err == nil && opts.Thinking != nil
 795}
 796
 797func (c *coordinator) buildProvider(providerCfg config.ProviderConfig, model config.SelectedModel, isSubAgent bool) (fantasy.Provider, error) {
 798	headers := maps.Clone(providerCfg.ExtraHeaders)
 799	if headers == nil {
 800		headers = make(map[string]string)
 801	}
 802
 803	// handle special headers for anthropic
 804	if providerCfg.Type == anthropic.Name && c.isAnthropicThinking(model) {
 805		if v, ok := headers["anthropic-beta"]; ok {
 806			headers["anthropic-beta"] = v + ",interleaved-thinking-2025-05-14"
 807		} else {
 808			headers["anthropic-beta"] = "interleaved-thinking-2025-05-14"
 809		}
 810	}
 811
 812	apiKey, _ := c.cfg.Resolve(providerCfg.APIKey)
 813	baseURL, _ := c.cfg.Resolve(providerCfg.BaseURL)
 814
 815	switch providerCfg.Type {
 816	case openai.Name:
 817		return c.buildOpenaiProvider(baseURL, apiKey, headers)
 818	case anthropic.Name:
 819		return c.buildAnthropicProvider(baseURL, apiKey, headers, providerCfg.ID)
 820	case openrouter.Name:
 821		return c.buildOpenrouterProvider(baseURL, apiKey, headers)
 822	case vercel.Name:
 823		return c.buildVercelProvider(baseURL, apiKey, headers)
 824	case azure.Name:
 825		return c.buildAzureProvider(baseURL, apiKey, headers, providerCfg.ExtraParams)
 826	case bedrock.Name:
 827		return c.buildBedrockProvider(headers)
 828	case google.Name:
 829		return c.buildGoogleProvider(baseURL, apiKey, headers)
 830	case "google-vertex":
 831		return c.buildGoogleVertexProvider(headers, providerCfg.ExtraParams)
 832	case openaicompat.Name:
 833		if providerCfg.ID == string(catwalk.InferenceProviderZAI) {
 834			if providerCfg.ExtraBody == nil {
 835				providerCfg.ExtraBody = map[string]any{}
 836			}
 837			providerCfg.ExtraBody["tool_stream"] = true
 838		}
 839		return c.buildOpenaiCompatProvider(baseURL, apiKey, headers, providerCfg.ExtraBody, providerCfg.ID, isSubAgent)
 840	case hyper.Name:
 841		return c.buildHyperProvider(baseURL, apiKey)
 842	default:
 843		return nil, fmt.Errorf("provider type not supported: %q", providerCfg.Type)
 844	}
 845}
 846
 847func isExactoSupported(modelID string) bool {
 848	supportedModels := []string{
 849		"moonshotai/kimi-k2-0905",
 850		"deepseek/deepseek-v3.1-terminus",
 851		"z-ai/glm-4.6",
 852		"openai/gpt-oss-120b",
 853		"qwen/qwen3-coder",
 854	}
 855	return slices.Contains(supportedModels, modelID)
 856}
 857
 858func (c *coordinator) Cancel(sessionID string) {
 859	c.currentAgent.Cancel(sessionID)
 860}
 861
 862func (c *coordinator) CancelAll() {
 863	c.currentAgent.CancelAll()
 864}
 865
 866func (c *coordinator) ClearQueue(sessionID string) {
 867	c.currentAgent.ClearQueue(sessionID)
 868}
 869
 870func (c *coordinator) IsBusy() bool {
 871	return c.currentAgent.IsBusy()
 872}
 873
 874func (c *coordinator) IsSessionBusy(sessionID string) bool {
 875	return c.currentAgent.IsSessionBusy(sessionID)
 876}
 877
 878func (c *coordinator) Model() Model {
 879	return c.currentAgent.Model()
 880}
 881
 882func (c *coordinator) UpdateModels(ctx context.Context) error {
 883	// build the models again so we make sure we get the latest config
 884	large, small, err := c.buildAgentModels(ctx, false)
 885	if err != nil {
 886		return err
 887	}
 888	c.currentAgent.SetModels(large, small)
 889
 890	agentCfg, ok := c.cfg.Config().Agents[config.AgentCoder]
 891	if !ok {
 892		return errCoderAgentNotConfigured
 893	}
 894
 895	tools, err := c.buildTools(ctx, agentCfg)
 896	if err != nil {
 897		return err
 898	}
 899	c.currentAgent.SetTools(tools)
 900	return nil
 901}
 902
 903func (c *coordinator) QueuedPrompts(sessionID string) int {
 904	return c.currentAgent.QueuedPrompts(sessionID)
 905}
 906
 907func (c *coordinator) QueuedPromptsList(sessionID string) []string {
 908	return c.currentAgent.QueuedPromptsList(sessionID)
 909}
 910
 911func (c *coordinator) Summarize(ctx context.Context, sessionID string) error {
 912	providerCfg, ok := c.cfg.Config().Providers.Get(c.currentAgent.Model().ModelCfg.Provider)
 913	if !ok {
 914		return errModelProviderNotConfigured
 915	}
 916	return c.currentAgent.Summarize(ctx, sessionID, getProviderOptions(c.currentAgent.Model(), providerCfg))
 917}
 918
 919func (c *coordinator) isUnauthorized(err error) bool {
 920	var providerErr *fantasy.ProviderError
 921	return errors.As(err, &providerErr) && providerErr.StatusCode == http.StatusUnauthorized
 922}
 923
 924func (c *coordinator) refreshOAuth2Token(ctx context.Context, providerCfg config.ProviderConfig) error {
 925	if err := c.cfg.RefreshOAuthToken(ctx, config.ScopeGlobal, providerCfg.ID); err != nil {
 926		slog.Error("Failed to refresh OAuth token after 401 error", "provider", providerCfg.ID, "error", err)
 927		return err
 928	}
 929	if err := c.UpdateModels(ctx); err != nil {
 930		return err
 931	}
 932	return nil
 933}
 934
 935func (c *coordinator) refreshApiKeyTemplate(ctx context.Context, providerCfg config.ProviderConfig) error {
 936	newAPIKey, err := c.cfg.Resolve(providerCfg.APIKeyTemplate)
 937	if err != nil {
 938		slog.Error("Failed to re-resolve API key after 401 error", "provider", providerCfg.ID, "error", err)
 939		return err
 940	}
 941
 942	providerCfg.APIKey = newAPIKey
 943	c.cfg.Config().Providers.Set(providerCfg.ID, providerCfg)
 944
 945	if err := c.UpdateModels(ctx); err != nil {
 946		return err
 947	}
 948	return nil
 949}
 950
// subAgentParams holds the parameters for running a sub-agent via runSubAgent.
type subAgentParams struct {
	// Agent is the session agent whose model configuration is read and whose
	// Run method executes the sub-task.
	Agent          SessionAgent
	// SessionID is the parent session: the sub-session is created under it and
	// the sub-session's cost is added to it afterwards.
	SessionID      string
	// AgentMessageID and ToolCallID are combined by
	// sessions.CreateAgentToolSessionID to derive the sub-session ID.
	AgentMessageID string
	ToolCallID     string
	// Prompt is the prompt passed to Agent.Run.
	Prompt         string
	// SessionTitle is passed to sessions.CreateTaskSession, presumably as the
	// sub-session's title — confirm against the session service.
	SessionTitle   string
	// SessionSetup is an optional callback invoked after session creation
	// but before agent execution, for custom session configuration.
	SessionSetup func(sessionID string)
}
 963
 964// runSubAgent runs a sub-agent and handles session management and cost accumulation.
 965// It creates a sub-session, runs the agent with the given prompt, and propagates
 966// the cost to the parent session.
 967func (c *coordinator) runSubAgent(ctx context.Context, params subAgentParams) (fantasy.ToolResponse, error) {
 968	// Create sub-session
 969	agentToolSessionID := c.sessions.CreateAgentToolSessionID(params.AgentMessageID, params.ToolCallID)
 970	session, err := c.sessions.CreateTaskSession(ctx, agentToolSessionID, params.SessionID, params.SessionTitle)
 971	if err != nil {
 972		return fantasy.ToolResponse{}, fmt.Errorf("create session: %w", err)
 973	}
 974
 975	// Call session setup function if provided
 976	if params.SessionSetup != nil {
 977		params.SessionSetup(session.ID)
 978	}
 979
 980	// Get model configuration
 981	model := params.Agent.Model()
 982	maxTokens := model.CatwalkCfg.DefaultMaxTokens
 983	if model.ModelCfg.MaxTokens != 0 {
 984		maxTokens = model.ModelCfg.MaxTokens
 985	}
 986
 987	providerCfg, ok := c.cfg.Config().Providers.Get(model.ModelCfg.Provider)
 988	if !ok {
 989		return fantasy.ToolResponse{}, errModelProviderNotConfigured
 990	}
 991
 992	// Run the agent
 993	result, err := params.Agent.Run(ctx, SessionAgentCall{
 994		SessionID:        session.ID,
 995		Prompt:           params.Prompt,
 996		MaxOutputTokens:  maxTokens,
 997		ProviderOptions:  getProviderOptions(model, providerCfg),
 998		Temperature:      model.ModelCfg.Temperature,
 999		TopP:             model.ModelCfg.TopP,
1000		TopK:             model.ModelCfg.TopK,
1001		FrequencyPenalty: model.ModelCfg.FrequencyPenalty,
1002		PresencePenalty:  model.ModelCfg.PresencePenalty,
1003		NonInteractive:   true,
1004	})
1005	if err != nil {
1006		return fantasy.NewTextErrorResponse("error generating response"), nil
1007	}
1008
1009	// Update parent session cost
1010	if err := c.updateParentSessionCost(ctx, session.ID, params.SessionID); err != nil {
1011		return fantasy.ToolResponse{}, err
1012	}
1013
1014	return fantasy.NewTextResponse(result.Response.Content.Text()), nil
1015}
1016
1017// updateParentSessionCost accumulates the cost from a child session to its parent session.
1018func (c *coordinator) updateParentSessionCost(ctx context.Context, childSessionID, parentSessionID string) error {
1019	childSession, err := c.sessions.Get(ctx, childSessionID)
1020	if err != nil {
1021		return fmt.Errorf("get child session: %w", err)
1022	}
1023
1024	parentSession, err := c.sessions.Get(ctx, parentSessionID)
1025	if err != nil {
1026		return fmt.Errorf("get parent session: %w", err)
1027	}
1028
1029	parentSession.Cost += childSession.Cost
1030
1031	if _, err := c.sessions.Save(ctx, parentSession); err != nil {
1032		return fmt.Errorf("save parent session: %w", err)
1033	}
1034
1035	return nil
1036}