coordinator.go

   1package agent
   2
   3import (
   4	"bytes"
   5	"cmp"
   6	"context"
   7	"encoding/json"
   8	"errors"
   9	"fmt"
  10	"io"
  11	"log/slog"
  12	"maps"
  13	"net/http"
  14	"os"
  15	"slices"
  16	"strings"
  17
  18	"charm.land/catwalk/pkg/catwalk"
  19	"charm.land/fantasy"
  20	"github.com/charmbracelet/crush/internal/agent/hyper"
  21	"github.com/charmbracelet/crush/internal/agent/notify"
  22	"github.com/charmbracelet/crush/internal/agent/prompt"
  23	"github.com/charmbracelet/crush/internal/agent/tools"
  24	"github.com/charmbracelet/crush/internal/config"
  25	"github.com/charmbracelet/crush/internal/filetracker"
  26	"github.com/charmbracelet/crush/internal/history"
  27	"github.com/charmbracelet/crush/internal/log"
  28	"github.com/charmbracelet/crush/internal/lsp"
  29	"github.com/charmbracelet/crush/internal/message"
  30	"github.com/charmbracelet/crush/internal/oauth/copilot"
  31	"github.com/charmbracelet/crush/internal/permission"
  32	"github.com/charmbracelet/crush/internal/pubsub"
  33	"github.com/charmbracelet/crush/internal/session"
  34	"golang.org/x/sync/errgroup"
  35
  36	"charm.land/fantasy/providers/anthropic"
  37	"charm.land/fantasy/providers/azure"
  38	"charm.land/fantasy/providers/bedrock"
  39	"charm.land/fantasy/providers/google"
  40	"charm.land/fantasy/providers/openai"
  41	"charm.land/fantasy/providers/openaicompat"
  42	"charm.land/fantasy/providers/openrouter"
  43	"charm.land/fantasy/providers/vercel"
  44	openaisdk "github.com/openai/openai-go/v3/option"
  45	"github.com/qjebbs/go-jsons"
  46)
  47
  48// Coordinator errors.
  49var (
  50	errCoderAgentNotConfigured         = errors.New("coder agent not configured")
  51	errModelProviderNotConfigured      = errors.New("model provider not configured")
  52	errLargeModelNotSelected           = errors.New("large model not selected")
  53	errSmallModelNotSelected           = errors.New("small model not selected")
  54	errLargeModelProviderNotConfigured = errors.New("large model provider not configured")
  55	errSmallModelProviderNotConfigured = errors.New("small model provider not configured")
  56	errLargeModelNotFound              = errors.New("large model not found in provider config")
  57	errSmallModelNotFound              = errors.New("small model not found in provider config")
  58)
  59
  60type Coordinator interface {
  61	// INFO: (kujtim) this is not used yet we will use this when we have multiple agents
  62	// SetMainAgent(string)
  63	Run(ctx context.Context, sessionID, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error)
  64	Cancel(sessionID string)
  65	CancelAll()
  66	IsSessionBusy(sessionID string) bool
  67	IsBusy() bool
  68	QueuedPrompts(sessionID string) int
  69	QueuedPromptsList(sessionID string) []string
  70	ClearQueue(sessionID string)
  71	Summarize(context.Context, string) error
  72	Model() Model
  73	UpdateModels(ctx context.Context) error
  74	RefreshTools(ctx context.Context) error
  75}
  76
// coordinator is the Coordinator implementation. It currently drives a
// single coder agent (currentAgent); agents is keyed by agent name so more
// agents can be registered once multiple agents are supported.
type coordinator struct {
	// Services and managers handed to agents and tools.
	cfg         *config.ConfigStore
	sessions    session.Service
	messages    message.Service
	permissions permission.Service
	history     history.Service
	filetracker filetracker.Service
	lspManager  *lsp.Manager
	notify      pubsub.Publisher[notify.Notification]

	// currentAgent receives every delegated call; agents holds all built
	// agents by name (only config.AgentCoder today).
	currentAgent SessionAgent
	agents       map[string]SessionAgent

	// readyWg tracks the background system-prompt and tool construction
	// started by buildAgent; Run waits on it before calling the agent.
	readyWg errgroup.Group
}
  92
  93func NewCoordinator(
  94	ctx context.Context,
  95	cfg *config.ConfigStore,
  96	sessions session.Service,
  97	messages message.Service,
  98	permissions permission.Service,
  99	history history.Service,
 100	filetracker filetracker.Service,
 101	lspManager *lsp.Manager,
 102	notify pubsub.Publisher[notify.Notification],
 103) (Coordinator, error) {
 104	c := &coordinator{
 105		cfg:         cfg,
 106		sessions:    sessions,
 107		messages:    messages,
 108		permissions: permissions,
 109		history:     history,
 110		filetracker: filetracker,
 111		lspManager:  lspManager,
 112		notify:      notify,
 113		agents:      make(map[string]SessionAgent),
 114	}
 115
 116	agentCfg, ok := cfg.Config().Agents[config.AgentCoder]
 117	if !ok {
 118		return nil, errCoderAgentNotConfigured
 119	}
 120
 121	// TODO: make this dynamic when we support multiple agents
 122	prompt, err := coderPrompt(prompt.WithWorkingDir(c.cfg.WorkingDir()))
 123	if err != nil {
 124		return nil, err
 125	}
 126
 127	agent, err := c.buildAgent(ctx, prompt, agentCfg, false)
 128	if err != nil {
 129		return nil, err
 130	}
 131	c.currentAgent = agent
 132	c.agents[config.AgentCoder] = agent
 133	return c, nil
 134}
 135
 136// Run implements Coordinator.
 137func (c *coordinator) Run(ctx context.Context, sessionID string, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) {
 138	if err := c.readyWg.Wait(); err != nil {
 139		return nil, err
 140	}
 141
 142	// refresh models before each run
 143	if err := c.UpdateModels(ctx); err != nil {
 144		return nil, fmt.Errorf("failed to update models: %w", err)
 145	}
 146
 147	model := c.currentAgent.Model()
 148	maxTokens := model.CatwalkCfg.DefaultMaxTokens
 149	if model.ModelCfg.MaxTokens != 0 {
 150		maxTokens = model.ModelCfg.MaxTokens
 151	}
 152
 153	if !model.CatwalkCfg.SupportsImages && attachments != nil {
 154		// filter out image attachments
 155		filteredAttachments := make([]message.Attachment, 0, len(attachments))
 156		for _, att := range attachments {
 157			if att.IsText() {
 158				filteredAttachments = append(filteredAttachments, att)
 159			}
 160		}
 161		attachments = filteredAttachments
 162	}
 163
 164	providerCfg, ok := c.cfg.Config().Providers.Get(model.ModelCfg.Provider)
 165	if !ok {
 166		return nil, errModelProviderNotConfigured
 167	}
 168
 169	mergedOptions, temp, topP, topK, freqPenalty, presPenalty := mergeCallOptions(model, providerCfg)
 170
 171	if providerCfg.OAuthToken != nil && providerCfg.OAuthToken.IsExpired() {
 172		slog.Debug("Token needs to be refreshed", "provider", providerCfg.ID)
 173		if err := c.refreshOAuth2Token(ctx, providerCfg); err != nil {
 174			return nil, err
 175		}
 176	}
 177
 178	run := func() (*fantasy.AgentResult, error) {
 179		return c.currentAgent.Run(ctx, SessionAgentCall{
 180			SessionID:        sessionID,
 181			Prompt:           prompt,
 182			Attachments:      attachments,
 183			MaxOutputTokens:  maxTokens,
 184			ProviderOptions:  mergedOptions,
 185			Temperature:      temp,
 186			TopP:             topP,
 187			TopK:             topK,
 188			FrequencyPenalty: freqPenalty,
 189			PresencePenalty:  presPenalty,
 190		})
 191	}
 192	result, originalErr := run()
 193
 194	if c.isUnauthorized(originalErr) {
 195		switch {
 196		case providerCfg.OAuthToken != nil:
 197			slog.Debug("Received 401. Refreshing token and retrying", "provider", providerCfg.ID)
 198			if err := c.refreshOAuth2Token(ctx, providerCfg); err != nil {
 199				return nil, originalErr
 200			}
 201			slog.Debug("Retrying request with refreshed OAuth token", "provider", providerCfg.ID)
 202			return run()
 203		case strings.Contains(providerCfg.APIKeyTemplate, "$"):
 204			slog.Debug("Received 401. Refreshing API Key template and retrying", "provider", providerCfg.ID)
 205			if err := c.refreshApiKeyTemplate(ctx, providerCfg); err != nil {
 206				return nil, originalErr
 207			}
 208			slog.Debug("Retrying request with refreshed API key", "provider", providerCfg.ID)
 209			return run()
 210		}
 211	}
 212
 213	return result, originalErr
 214}
 215
 216func getProviderOptions(model Model, providerCfg config.ProviderConfig) fantasy.ProviderOptions {
 217	options := fantasy.ProviderOptions{}
 218
 219	cfgOpts := []byte("{}")
 220	providerCfgOpts := []byte("{}")
 221	catwalkOpts := []byte("{}")
 222
 223	if model.ModelCfg.ProviderOptions != nil {
 224		data, err := json.Marshal(model.ModelCfg.ProviderOptions)
 225		if err == nil {
 226			cfgOpts = data
 227		}
 228	}
 229
 230	if providerCfg.ProviderOptions != nil {
 231		data, err := json.Marshal(providerCfg.ProviderOptions)
 232		if err == nil {
 233			providerCfgOpts = data
 234		}
 235	}
 236
 237	if model.CatwalkCfg.Options.ProviderOptions != nil {
 238		data, err := json.Marshal(model.CatwalkCfg.Options.ProviderOptions)
 239		if err == nil {
 240			catwalkOpts = data
 241		}
 242	}
 243
 244	readers := []io.Reader{
 245		bytes.NewReader(catwalkOpts),
 246		bytes.NewReader(providerCfgOpts),
 247		bytes.NewReader(cfgOpts),
 248	}
 249
 250	got, err := jsons.Merge(readers)
 251	if err != nil {
 252		slog.Error("Could not merge call config", "err", err)
 253		return options
 254	}
 255
 256	mergedOptions := make(map[string]any)
 257
 258	err = json.Unmarshal([]byte(got), &mergedOptions)
 259	if err != nil {
 260		slog.Error("Could not create config for call", "err", err)
 261		return options
 262	}
 263
 264	providerType := providerCfg.Type
 265	if providerType == "hyper" {
 266		if strings.Contains(model.CatwalkCfg.ID, "claude") {
 267			providerType = anthropic.Name
 268		} else if strings.Contains(model.CatwalkCfg.ID, "gpt") {
 269			providerType = openai.Name
 270		} else if strings.Contains(model.CatwalkCfg.ID, "gemini") {
 271			providerType = google.Name
 272		} else {
 273			providerType = openaicompat.Name
 274		}
 275	}
 276
 277	switch providerType {
 278	case openai.Name, azure.Name:
 279		_, hasReasoningEffort := mergedOptions["reasoning_effort"]
 280		if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
 281			mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort
 282		}
 283		if openai.IsResponsesModel(model.CatwalkCfg.ID) {
 284			if openai.IsResponsesReasoningModel(model.CatwalkCfg.ID) {
 285				mergedOptions["reasoning_summary"] = "auto"
 286				mergedOptions["include"] = []openai.IncludeType{openai.IncludeReasoningEncryptedContent}
 287			}
 288			parsed, err := openai.ParseResponsesOptions(mergedOptions)
 289			if err == nil {
 290				options[openai.Name] = parsed
 291			}
 292		} else {
 293			parsed, err := openai.ParseOptions(mergedOptions)
 294			if err == nil {
 295				options[openai.Name] = parsed
 296			}
 297		}
 298	case anthropic.Name:
 299		var (
 300			_, hasEffort = mergedOptions["effort"]
 301			_, hasThink  = mergedOptions["thinking"]
 302		)
 303		switch {
 304		case !hasEffort && model.ModelCfg.ReasoningEffort != "":
 305			mergedOptions["effort"] = model.ModelCfg.ReasoningEffort
 306		case !hasThink && model.ModelCfg.Think:
 307			mergedOptions["thinking"] = map[string]any{"budget_tokens": 2000}
 308		}
 309		parsed, err := anthropic.ParseOptions(mergedOptions)
 310		if err == nil {
 311			options[anthropic.Name] = parsed
 312		}
 313
 314	case openrouter.Name:
 315		_, hasReasoning := mergedOptions["reasoning"]
 316		if !hasReasoning && model.ModelCfg.ReasoningEffort != "" {
 317			mergedOptions["reasoning"] = map[string]any{
 318				"enabled": true,
 319				"effort":  model.ModelCfg.ReasoningEffort,
 320			}
 321		}
 322		parsed, err := openrouter.ParseOptions(mergedOptions)
 323		if err == nil {
 324			options[openrouter.Name] = parsed
 325		}
 326	case vercel.Name:
 327		_, hasReasoning := mergedOptions["reasoning"]
 328		if !hasReasoning && model.ModelCfg.ReasoningEffort != "" {
 329			mergedOptions["reasoning"] = map[string]any{
 330				"enabled": true,
 331				"effort":  model.ModelCfg.ReasoningEffort,
 332			}
 333		}
 334		parsed, err := vercel.ParseOptions(mergedOptions)
 335		if err == nil {
 336			options[vercel.Name] = parsed
 337		}
 338	case google.Name:
 339		_, hasReasoning := mergedOptions["thinking_config"]
 340		if !hasReasoning {
 341			if strings.HasPrefix(model.CatwalkCfg.ID, "gemini-2") {
 342				mergedOptions["thinking_config"] = map[string]any{
 343					"thinking_budget":  2000,
 344					"include_thoughts": true,
 345				}
 346			} else {
 347				mergedOptions["thinking_config"] = map[string]any{
 348					"thinking_level":   model.ModelCfg.ReasoningEffort,
 349					"include_thoughts": true,
 350				}
 351			}
 352		}
 353		parsed, err := google.ParseOptions(mergedOptions)
 354		if err == nil {
 355			options[google.Name] = parsed
 356		}
 357	case openaicompat.Name:
 358		_, hasReasoningEffort := mergedOptions["reasoning_effort"]
 359		if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
 360			mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort
 361		}
 362		parsed, err := openaicompat.ParseOptions(mergedOptions)
 363		if err == nil {
 364			options[openaicompat.Name] = parsed
 365		}
 366	}
 367
 368	return options
 369}
 370
 371func mergeCallOptions(model Model, cfg config.ProviderConfig) (fantasy.ProviderOptions, *float64, *float64, *int64, *float64, *float64) {
 372	modelOptions := getProviderOptions(model, cfg)
 373	temp := cmp.Or(model.ModelCfg.Temperature, model.CatwalkCfg.Options.Temperature)
 374	topP := cmp.Or(model.ModelCfg.TopP, model.CatwalkCfg.Options.TopP)
 375	topK := cmp.Or(model.ModelCfg.TopK, model.CatwalkCfg.Options.TopK)
 376	freqPenalty := cmp.Or(model.ModelCfg.FrequencyPenalty, model.CatwalkCfg.Options.FrequencyPenalty)
 377	presPenalty := cmp.Or(model.ModelCfg.PresencePenalty, model.CatwalkCfg.Options.PresencePenalty)
 378	return modelOptions, temp, topP, topK, freqPenalty, presPenalty
 379}
 380
 381func (c *coordinator) buildAgent(ctx context.Context, prompt *prompt.Prompt, agent config.Agent, isSubAgent bool) (SessionAgent, error) {
 382	large, small, err := c.buildAgentModels(ctx, isSubAgent)
 383	if err != nil {
 384		return nil, err
 385	}
 386
 387	largeProviderCfg, _ := c.cfg.Config().Providers.Get(large.ModelCfg.Provider)
 388	result := NewSessionAgent(SessionAgentOptions{
 389		LargeModel:           large,
 390		SmallModel:           small,
 391		SystemPromptPrefix:   largeProviderCfg.SystemPromptPrefix,
 392		SystemPrompt:         "",
 393		IsSubAgent:           isSubAgent,
 394		DisableAutoSummarize: c.cfg.Config().Options.DisableAutoSummarize,
 395		IsYolo:               c.permissions.SkipRequests(),
 396		Sessions:             c.sessions,
 397		Messages:             c.messages,
 398		Tools:                nil,
 399		Notify:               c.notify,
 400	})
 401
 402	c.readyWg.Go(func() error {
 403		systemPrompt, err := prompt.Build(ctx, large.Model.Provider(), large.Model.Model(), c.cfg)
 404		if err != nil {
 405			return err
 406		}
 407		result.SetSystemPrompt(systemPrompt)
 408		return nil
 409	})
 410
 411	c.readyWg.Go(func() error {
 412		tools, err := c.buildTools(ctx, agent)
 413		if err != nil {
 414			return err
 415		}
 416		result.SetTools(tools)
 417		return nil
 418	})
 419
 420	return result, nil
 421}
 422
 423func (c *coordinator) buildTools(ctx context.Context, agent config.Agent) ([]fantasy.AgentTool, error) {
 424	var allTools []fantasy.AgentTool
 425	if slices.Contains(agent.AllowedTools, AgentToolName) {
 426		agentTool, err := c.agentTool(ctx)
 427		if err != nil {
 428			return nil, err
 429		}
 430		allTools = append(allTools, agentTool)
 431	}
 432
 433	if slices.Contains(agent.AllowedTools, tools.AgenticFetchToolName) {
 434		agenticFetchTool, err := c.agenticFetchTool(ctx, nil)
 435		if err != nil {
 436			return nil, err
 437		}
 438		allTools = append(allTools, agenticFetchTool)
 439	}
 440
 441	// Get the model name for the agent
 442	modelName := ""
 443	if modelCfg, ok := c.cfg.Config().Models[agent.Model]; ok {
 444		if model := c.cfg.Config().GetModel(modelCfg.Provider, modelCfg.Model); model != nil {
 445			modelName = model.Name
 446		}
 447	}
 448
 449	allTools = append(allTools,
 450		tools.NewBashTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Config().Options.Attribution, modelName),
 451		tools.NewJobOutputTool(),
 452		tools.NewJobKillTool(),
 453		tools.NewDownloadTool(c.permissions, c.cfg.WorkingDir(), nil),
 454		tools.NewEditTool(c.lspManager, c.permissions, c.history, c.filetracker, c.cfg.WorkingDir()),
 455		tools.NewMultiEditTool(c.lspManager, c.permissions, c.history, c.filetracker, c.cfg.WorkingDir()),
 456		tools.NewFetchTool(c.permissions, c.cfg.WorkingDir(), nil),
 457		tools.NewGlobTool(c.cfg.WorkingDir()),
 458		tools.NewGrepTool(c.cfg.WorkingDir(), c.cfg.Config().Tools.Grep),
 459		tools.NewLsTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Config().Tools.Ls),
 460		tools.NewSourcegraphTool(nil),
 461		tools.NewTodosTool(c.sessions),
 462		tools.NewViewTool(c.lspManager, c.permissions, c.filetracker, c.cfg.WorkingDir(), c.cfg.Config().Options.SkillsPaths...),
 463		tools.NewWriteTool(c.lspManager, c.permissions, c.history, c.filetracker, c.cfg.WorkingDir()),
 464	)
 465
 466	// Add LSP tools if user has configured LSPs or auto_lsp is enabled (nil or true).
 467	if len(c.cfg.Config().LSP) > 0 || c.cfg.Config().Options.AutoLSP == nil || *c.cfg.Config().Options.AutoLSP {
 468		allTools = append(allTools, tools.NewDiagnosticsTool(c.lspManager), tools.NewReferencesTool(c.lspManager), tools.NewLSPRestartTool(c.lspManager))
 469	}
 470
 471	if len(c.cfg.Config().MCP) > 0 {
 472		allTools = append(
 473			allTools,
 474			tools.NewListMCPResourcesTool(c.cfg, c.permissions),
 475			tools.NewReadMCPResourceTool(c.cfg, c.permissions),
 476		)
 477	}
 478
 479	var filteredTools []fantasy.AgentTool
 480	for _, tool := range allTools {
 481		if slices.Contains(agent.AllowedTools, tool.Info().Name) {
 482			filteredTools = append(filteredTools, tool)
 483		}
 484	}
 485
 486	for _, tool := range tools.GetMCPTools(c.permissions, c.cfg, c.cfg.WorkingDir()) {
 487		if agent.AllowedMCP == nil {
 488			// No MCP restrictions
 489			filteredTools = append(filteredTools, tool)
 490			continue
 491		}
 492		if len(agent.AllowedMCP) == 0 {
 493			// No MCPs allowed
 494			slog.Debug("No MCPs allowed", "tool", tool.Name(), "agent", agent.Name)
 495			break
 496		}
 497
 498		for mcp, tools := range agent.AllowedMCP {
 499			if mcp != tool.MCP() {
 500				continue
 501			}
 502			if len(tools) == 0 || slices.Contains(tools, tool.MCPToolName()) {
 503				filteredTools = append(filteredTools, tool)
 504				break
 505			}
 506			slog.Debug("MCP not allowed", "tool", tool.Name(), "agent", agent.Name)
 507		}
 508	}
 509	slices.SortFunc(filteredTools, func(a, b fantasy.AgentTool) int {
 510		return strings.Compare(a.Info().Name, b.Info().Name)
 511	})
 512	return filteredTools, nil
 513}
 514
 515// TODO: when we support multiple agents we need to change this so that we pass in the agent specific model config
 516func (c *coordinator) buildAgentModels(ctx context.Context, isSubAgent bool) (Model, Model, error) {
 517	largeModelCfg, ok := c.cfg.Config().Models[config.SelectedModelTypeLarge]
 518	if !ok {
 519		return Model{}, Model{}, errLargeModelNotSelected
 520	}
 521	smallModelCfg, ok := c.cfg.Config().Models[config.SelectedModelTypeSmall]
 522	if !ok {
 523		return Model{}, Model{}, errSmallModelNotSelected
 524	}
 525
 526	largeProviderCfg, ok := c.cfg.Config().Providers.Get(largeModelCfg.Provider)
 527	if !ok {
 528		return Model{}, Model{}, errLargeModelProviderNotConfigured
 529	}
 530
 531	largeProvider, err := c.buildProvider(largeProviderCfg, largeModelCfg, isSubAgent)
 532	if err != nil {
 533		return Model{}, Model{}, err
 534	}
 535
 536	smallProviderCfg, ok := c.cfg.Config().Providers.Get(smallModelCfg.Provider)
 537	if !ok {
 538		return Model{}, Model{}, errSmallModelProviderNotConfigured
 539	}
 540
 541	smallProvider, err := c.buildProvider(smallProviderCfg, smallModelCfg, true)
 542	if err != nil {
 543		return Model{}, Model{}, err
 544	}
 545
 546	var largeCatwalkModel *catwalk.Model
 547	var smallCatwalkModel *catwalk.Model
 548
 549	for _, m := range largeProviderCfg.Models {
 550		if m.ID == largeModelCfg.Model {
 551			largeCatwalkModel = &m
 552		}
 553	}
 554	for _, m := range smallProviderCfg.Models {
 555		if m.ID == smallModelCfg.Model {
 556			smallCatwalkModel = &m
 557		}
 558	}
 559
 560	if largeCatwalkModel == nil {
 561		return Model{}, Model{}, errLargeModelNotFound
 562	}
 563
 564	if smallCatwalkModel == nil {
 565		return Model{}, Model{}, errSmallModelNotFound
 566	}
 567
 568	largeModelID := largeModelCfg.Model
 569	smallModelID := smallModelCfg.Model
 570
 571	if largeModelCfg.Provider == openrouter.Name && isExactoSupported(largeModelID) {
 572		largeModelID += ":exacto"
 573	}
 574
 575	if smallModelCfg.Provider == openrouter.Name && isExactoSupported(smallModelID) {
 576		smallModelID += ":exacto"
 577	}
 578
 579	largeModel, err := largeProvider.LanguageModel(ctx, largeModelID)
 580	if err != nil {
 581		return Model{}, Model{}, err
 582	}
 583	smallModel, err := smallProvider.LanguageModel(ctx, smallModelID)
 584	if err != nil {
 585		return Model{}, Model{}, err
 586	}
 587
 588	return Model{
 589			Model:      largeModel,
 590			CatwalkCfg: *largeCatwalkModel,
 591			ModelCfg:   largeModelCfg,
 592		}, Model{
 593			Model:      smallModel,
 594			CatwalkCfg: *smallCatwalkModel,
 595			ModelCfg:   smallModelCfg,
 596		}, nil
 597}
 598
 599func (c *coordinator) buildAnthropicProvider(baseURL, apiKey string, headers map[string]string, providerID string) (fantasy.Provider, error) {
 600	var opts []anthropic.Option
 601
 602	switch {
 603	case strings.HasPrefix(apiKey, "Bearer "):
 604		// NOTE: Prevent the SDK from picking up the API key from env.
 605		os.Setenv("ANTHROPIC_API_KEY", "")
 606		headers["Authorization"] = apiKey
 607	case providerID == string(catwalk.InferenceProviderMiniMax) || providerID == string(catwalk.InferenceProviderMiniMaxChina):
 608		// NOTE: Prevent the SDK from picking up the API key from env.
 609		os.Setenv("ANTHROPIC_API_KEY", "")
 610		headers["Authorization"] = "Bearer " + apiKey
 611	case apiKey != "":
 612		// X-Api-Key header
 613		opts = append(opts, anthropic.WithAPIKey(apiKey))
 614	}
 615
 616	if len(headers) > 0 {
 617		opts = append(opts, anthropic.WithHeaders(headers))
 618	}
 619
 620	if baseURL != "" {
 621		opts = append(opts, anthropic.WithBaseURL(baseURL))
 622	}
 623
 624	if c.cfg.Config().Options.Debug {
 625		httpClient := log.NewHTTPClient()
 626		opts = append(opts, anthropic.WithHTTPClient(httpClient))
 627	}
 628	return anthropic.New(opts...)
 629}
 630
 631func (c *coordinator) buildOpenaiProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
 632	opts := []openai.Option{
 633		openai.WithAPIKey(apiKey),
 634		openai.WithUseResponsesAPI(),
 635	}
 636	if c.cfg.Config().Options.Debug {
 637		httpClient := log.NewHTTPClient()
 638		opts = append(opts, openai.WithHTTPClient(httpClient))
 639	}
 640	if len(headers) > 0 {
 641		opts = append(opts, openai.WithHeaders(headers))
 642	}
 643	if baseURL != "" {
 644		opts = append(opts, openai.WithBaseURL(baseURL))
 645	}
 646	return openai.New(opts...)
 647}
 648
 649func (c *coordinator) buildOpenrouterProvider(_, apiKey string, headers map[string]string) (fantasy.Provider, error) {
 650	opts := []openrouter.Option{
 651		openrouter.WithAPIKey(apiKey),
 652	}
 653	if c.cfg.Config().Options.Debug {
 654		httpClient := log.NewHTTPClient()
 655		opts = append(opts, openrouter.WithHTTPClient(httpClient))
 656	}
 657	if len(headers) > 0 {
 658		opts = append(opts, openrouter.WithHeaders(headers))
 659	}
 660	return openrouter.New(opts...)
 661}
 662
 663func (c *coordinator) buildVercelProvider(_, apiKey string, headers map[string]string) (fantasy.Provider, error) {
 664	opts := []vercel.Option{
 665		vercel.WithAPIKey(apiKey),
 666	}
 667	if c.cfg.Config().Options.Debug {
 668		httpClient := log.NewHTTPClient()
 669		opts = append(opts, vercel.WithHTTPClient(httpClient))
 670	}
 671	if len(headers) > 0 {
 672		opts = append(opts, vercel.WithHeaders(headers))
 673	}
 674	return vercel.New(opts...)
 675}
 676
 677func (c *coordinator) buildOpenaiCompatProvider(baseURL, apiKey string, headers map[string]string, extraBody map[string]any, providerID string, isSubAgent bool) (fantasy.Provider, error) {
 678	opts := []openaicompat.Option{
 679		openaicompat.WithBaseURL(baseURL),
 680		openaicompat.WithAPIKey(apiKey),
 681	}
 682
 683	// Set HTTP client based on provider and debug mode.
 684	var httpClient *http.Client
 685	if providerID == string(catwalk.InferenceProviderCopilot) {
 686		opts = append(opts, openaicompat.WithUseResponsesAPI())
 687		httpClient = copilot.NewClient(isSubAgent, c.cfg.Config().Options.Debug)
 688	} else if c.cfg.Config().Options.Debug {
 689		httpClient = log.NewHTTPClient()
 690	}
 691	if httpClient != nil {
 692		opts = append(opts, openaicompat.WithHTTPClient(httpClient))
 693	}
 694
 695	if len(headers) > 0 {
 696		opts = append(opts, openaicompat.WithHeaders(headers))
 697	}
 698
 699	for extraKey, extraValue := range extraBody {
 700		opts = append(opts, openaicompat.WithSDKOptions(openaisdk.WithJSONSet(extraKey, extraValue)))
 701	}
 702
 703	return openaicompat.New(opts...)
 704}
 705
 706func (c *coordinator) buildAzureProvider(baseURL, apiKey string, headers map[string]string, options map[string]string) (fantasy.Provider, error) {
 707	opts := []azure.Option{
 708		azure.WithBaseURL(baseURL),
 709		azure.WithAPIKey(apiKey),
 710		azure.WithUseResponsesAPI(),
 711	}
 712	if c.cfg.Config().Options.Debug {
 713		httpClient := log.NewHTTPClient()
 714		opts = append(opts, azure.WithHTTPClient(httpClient))
 715	}
 716	if options == nil {
 717		options = make(map[string]string)
 718	}
 719	if apiVersion, ok := options["apiVersion"]; ok {
 720		opts = append(opts, azure.WithAPIVersion(apiVersion))
 721	}
 722	if len(headers) > 0 {
 723		opts = append(opts, azure.WithHeaders(headers))
 724	}
 725
 726	return azure.New(opts...)
 727}
 728
 729func (c *coordinator) buildBedrockProvider(apiKey string, headers map[string]string) (fantasy.Provider, error) {
 730	var opts []bedrock.Option
 731	if c.cfg.Config().Options.Debug {
 732		httpClient := log.NewHTTPClient()
 733		opts = append(opts, bedrock.WithHTTPClient(httpClient))
 734	}
 735	if len(headers) > 0 {
 736		opts = append(opts, bedrock.WithHeaders(headers))
 737	}
 738	switch {
 739	case apiKey != "":
 740		opts = append(opts, bedrock.WithAPIKey(apiKey))
 741	case os.Getenv("AWS_BEARER_TOKEN_BEDROCK") != "":
 742		opts = append(opts, bedrock.WithAPIKey(os.Getenv("AWS_BEARER_TOKEN_BEDROCK")))
 743	default:
 744		// Skip, let the SDK do authentication.
 745	}
 746	return bedrock.New(opts...)
 747}
 748
 749func (c *coordinator) buildGoogleProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
 750	opts := []google.Option{
 751		google.WithBaseURL(baseURL),
 752		google.WithGeminiAPIKey(apiKey),
 753	}
 754	if c.cfg.Config().Options.Debug {
 755		httpClient := log.NewHTTPClient()
 756		opts = append(opts, google.WithHTTPClient(httpClient))
 757	}
 758	if len(headers) > 0 {
 759		opts = append(opts, google.WithHeaders(headers))
 760	}
 761	return google.New(opts...)
 762}
 763
 764func (c *coordinator) buildGoogleVertexProvider(headers map[string]string, options map[string]string) (fantasy.Provider, error) {
 765	opts := []google.Option{}
 766	if c.cfg.Config().Options.Debug {
 767		httpClient := log.NewHTTPClient()
 768		opts = append(opts, google.WithHTTPClient(httpClient))
 769	}
 770	if len(headers) > 0 {
 771		opts = append(opts, google.WithHeaders(headers))
 772	}
 773
 774	project := options["project"]
 775	location := options["location"]
 776
 777	opts = append(opts, google.WithVertex(project, location))
 778
 779	return google.New(opts...)
 780}
 781
 782func (c *coordinator) buildHyperProvider(baseURL, apiKey string) (fantasy.Provider, error) {
 783	opts := []hyper.Option{
 784		hyper.WithBaseURL(baseURL),
 785		hyper.WithAPIKey(apiKey),
 786	}
 787	if c.cfg.Config().Options.Debug {
 788		httpClient := log.NewHTTPClient()
 789		opts = append(opts, hyper.WithHTTPClient(httpClient))
 790	}
 791	return hyper.New(opts...)
 792}
 793
 794func (c *coordinator) isAnthropicThinking(model config.SelectedModel) bool {
 795	if model.Think {
 796		return true
 797	}
 798	opts, err := anthropic.ParseOptions(model.ProviderOptions)
 799	return err == nil && opts.Thinking != nil
 800}
 801
 802func (c *coordinator) buildProvider(providerCfg config.ProviderConfig, model config.SelectedModel, isSubAgent bool) (fantasy.Provider, error) {
 803	headers := maps.Clone(providerCfg.ExtraHeaders)
 804	if headers == nil {
 805		headers = make(map[string]string)
 806	}
 807
 808	// handle special headers for anthropic
 809	if providerCfg.Type == anthropic.Name && c.isAnthropicThinking(model) {
 810		if v, ok := headers["anthropic-beta"]; ok {
 811			headers["anthropic-beta"] = v + ",interleaved-thinking-2025-05-14"
 812		} else {
 813			headers["anthropic-beta"] = "interleaved-thinking-2025-05-14"
 814		}
 815	}
 816
 817	apiKey, _ := c.cfg.Resolve(providerCfg.APIKey)
 818	baseURL, _ := c.cfg.Resolve(providerCfg.BaseURL)
 819
 820	switch providerCfg.Type {
 821	case openai.Name:
 822		return c.buildOpenaiProvider(baseURL, apiKey, headers)
 823	case anthropic.Name:
 824		return c.buildAnthropicProvider(baseURL, apiKey, headers, providerCfg.ID)
 825	case openrouter.Name:
 826		return c.buildOpenrouterProvider(baseURL, apiKey, headers)
 827	case vercel.Name:
 828		return c.buildVercelProvider(baseURL, apiKey, headers)
 829	case azure.Name:
 830		return c.buildAzureProvider(baseURL, apiKey, headers, providerCfg.ExtraParams)
 831	case bedrock.Name:
 832		return c.buildBedrockProvider(apiKey, headers)
 833	case google.Name:
 834		return c.buildGoogleProvider(baseURL, apiKey, headers)
 835	case "google-vertex":
 836		return c.buildGoogleVertexProvider(headers, providerCfg.ExtraParams)
 837	case openaicompat.Name:
 838		if providerCfg.ID == string(catwalk.InferenceProviderZAI) {
 839			if providerCfg.ExtraBody == nil {
 840				providerCfg.ExtraBody = map[string]any{}
 841			}
 842			providerCfg.ExtraBody["tool_stream"] = true
 843		}
 844		return c.buildOpenaiCompatProvider(baseURL, apiKey, headers, providerCfg.ExtraBody, providerCfg.ID, isSubAgent)
 845	case hyper.Name:
 846		return c.buildHyperProvider(baseURL, apiKey)
 847	default:
 848		return nil, fmt.Errorf("provider type not supported: %q", providerCfg.Type)
 849	}
 850}
 851
 852func isExactoSupported(modelID string) bool {
 853	supportedModels := []string{
 854		"moonshotai/kimi-k2-0905",
 855		"deepseek/deepseek-v3.1-terminus",
 856		"z-ai/glm-4.6",
 857		"openai/gpt-oss-120b",
 858		"qwen/qwen3-coder",
 859	}
 860	return slices.Contains(supportedModels, modelID)
 861}
 862
 863func (c *coordinator) Cancel(sessionID string) {
 864	c.currentAgent.Cancel(sessionID)
 865}
 866
 867func (c *coordinator) CancelAll() {
 868	c.currentAgent.CancelAll()
 869}
 870
 871func (c *coordinator) ClearQueue(sessionID string) {
 872	c.currentAgent.ClearQueue(sessionID)
 873}
 874
 875func (c *coordinator) IsBusy() bool {
 876	return c.currentAgent.IsBusy()
 877}
 878
 879func (c *coordinator) IsSessionBusy(sessionID string) bool {
 880	return c.currentAgent.IsSessionBusy(sessionID)
 881}
 882
 883func (c *coordinator) Model() Model {
 884	return c.currentAgent.Model()
 885}
 886
 887func (c *coordinator) UpdateModels(ctx context.Context) error {
 888	// build the models again so we make sure we get the latest config
 889	large, small, err := c.buildAgentModels(ctx, false)
 890	if err != nil {
 891		return err
 892	}
 893	c.currentAgent.SetModels(large, small)
 894
 895	agentCfg, ok := c.cfg.Config().Agents[config.AgentCoder]
 896	if !ok {
 897		return errCoderAgentNotConfigured
 898	}
 899
 900	tools, err := c.buildTools(ctx, agentCfg)
 901	if err != nil {
 902		return err
 903	}
 904	c.currentAgent.SetTools(tools)
 905	return nil
 906}
 907
 908func (c *coordinator) RefreshTools(ctx context.Context) error {
 909	agentCfg, ok := c.cfg.Config().Agents[config.AgentCoder]
 910	if !ok {
 911		return errors.New("coder agent not configured")
 912	}
 913
 914	tools, err := c.buildTools(ctx, agentCfg)
 915	if err != nil {
 916		return err
 917	}
 918	c.currentAgent.SetTools(tools)
 919	slog.Debug("refreshed agent tools", "count", len(tools))
 920	return nil
 921}
 922
 923func (c *coordinator) QueuedPrompts(sessionID string) int {
 924	return c.currentAgent.QueuedPrompts(sessionID)
 925}
 926
 927func (c *coordinator) QueuedPromptsList(sessionID string) []string {
 928	return c.currentAgent.QueuedPromptsList(sessionID)
 929}
 930
 931func (c *coordinator) Summarize(ctx context.Context, sessionID string) error {
 932	providerCfg, ok := c.cfg.Config().Providers.Get(c.currentAgent.Model().ModelCfg.Provider)
 933	if !ok {
 934		return errModelProviderNotConfigured
 935	}
 936	return c.currentAgent.Summarize(ctx, sessionID, getProviderOptions(c.currentAgent.Model(), providerCfg))
 937}
 938
 939func (c *coordinator) isUnauthorized(err error) bool {
 940	var providerErr *fantasy.ProviderError
 941	return errors.As(err, &providerErr) && providerErr.StatusCode == http.StatusUnauthorized
 942}
 943
 944func (c *coordinator) refreshOAuth2Token(ctx context.Context, providerCfg config.ProviderConfig) error {
 945	if err := c.cfg.RefreshOAuthToken(ctx, config.ScopeGlobal, providerCfg.ID); err != nil {
 946		slog.Error("Failed to refresh OAuth token after 401 error", "provider", providerCfg.ID, "error", err)
 947		return err
 948	}
 949	if err := c.UpdateModels(ctx); err != nil {
 950		return err
 951	}
 952	return nil
 953}
 954
 955func (c *coordinator) refreshApiKeyTemplate(ctx context.Context, providerCfg config.ProviderConfig) error {
 956	newAPIKey, err := c.cfg.Resolve(providerCfg.APIKeyTemplate)
 957	if err != nil {
 958		slog.Error("Failed to re-resolve API key after 401 error", "provider", providerCfg.ID, "error", err)
 959		return err
 960	}
 961
 962	providerCfg.APIKey = newAPIKey
 963	c.cfg.Config().Providers.Set(providerCfg.ID, providerCfg)
 964
 965	if err := c.UpdateModels(ctx); err != nil {
 966		return err
 967	}
 968	return nil
 969}
 970
// subAgentParams holds the parameters for running a sub-agent with
// runSubAgent.
type subAgentParams struct {
	// Agent is the sub-agent to run. Its Model() supplies the provider,
	// max-token and sampling settings for the run.
	Agent          SessionAgent
	// SessionID is the parent session. The sub-session is created under it,
	// and the sub-session's cost is added to it afterwards.
	SessionID      string
	// AgentMessageID and ToolCallID are combined to derive the sub-session ID.
	AgentMessageID string
	ToolCallID     string
	// Prompt is the prompt passed to the sub-agent's Run call.
	Prompt         string
	// SessionTitle is the title given to the created sub-session.
	SessionTitle   string
	// SessionSetup is an optional callback invoked after session creation
	// but before agent execution, for custom session configuration.
	SessionSetup func(sessionID string)
}
 983
 984// runSubAgent runs a sub-agent and handles session management and cost accumulation.
 985// It creates a sub-session, runs the agent with the given prompt, and propagates
 986// the cost to the parent session.
 987func (c *coordinator) runSubAgent(ctx context.Context, params subAgentParams) (fantasy.ToolResponse, error) {
 988	// Create sub-session
 989	agentToolSessionID := c.sessions.CreateAgentToolSessionID(params.AgentMessageID, params.ToolCallID)
 990	session, err := c.sessions.CreateTaskSession(ctx, agentToolSessionID, params.SessionID, params.SessionTitle)
 991	if err != nil {
 992		return fantasy.ToolResponse{}, fmt.Errorf("create session: %w", err)
 993	}
 994
 995	// Call session setup function if provided
 996	if params.SessionSetup != nil {
 997		params.SessionSetup(session.ID)
 998	}
 999
1000	// Get model configuration
1001	model := params.Agent.Model()
1002	maxTokens := model.CatwalkCfg.DefaultMaxTokens
1003	if model.ModelCfg.MaxTokens != 0 {
1004		maxTokens = model.ModelCfg.MaxTokens
1005	}
1006
1007	providerCfg, ok := c.cfg.Config().Providers.Get(model.ModelCfg.Provider)
1008	if !ok {
1009		return fantasy.ToolResponse{}, errModelProviderNotConfigured
1010	}
1011
1012	// Run the agent
1013	result, err := params.Agent.Run(ctx, SessionAgentCall{
1014		SessionID:        session.ID,
1015		Prompt:           params.Prompt,
1016		MaxOutputTokens:  maxTokens,
1017		ProviderOptions:  getProviderOptions(model, providerCfg),
1018		Temperature:      model.ModelCfg.Temperature,
1019		TopP:             model.ModelCfg.TopP,
1020		TopK:             model.ModelCfg.TopK,
1021		FrequencyPenalty: model.ModelCfg.FrequencyPenalty,
1022		PresencePenalty:  model.ModelCfg.PresencePenalty,
1023		NonInteractive:   true,
1024	})
1025	if err != nil {
1026		return fantasy.NewTextErrorResponse("error generating response"), nil
1027	}
1028
1029	// Update parent session cost
1030	if err := c.updateParentSessionCost(ctx, session.ID, params.SessionID); err != nil {
1031		return fantasy.ToolResponse{}, err
1032	}
1033
1034	return fantasy.NewTextResponse(result.Response.Content.Text()), nil
1035}
1036
1037// updateParentSessionCost accumulates the cost from a child session to its parent session.
1038func (c *coordinator) updateParentSessionCost(ctx context.Context, childSessionID, parentSessionID string) error {
1039	childSession, err := c.sessions.Get(ctx, childSessionID)
1040	if err != nil {
1041		return fmt.Errorf("get child session: %w", err)
1042	}
1043
1044	parentSession, err := c.sessions.Get(ctx, parentSessionID)
1045	if err != nil {
1046		return fmt.Errorf("get parent session: %w", err)
1047	}
1048
1049	parentSession.Cost += childSession.Cost
1050
1051	if _, err := c.sessions.Save(ctx, parentSession); err != nil {
1052		return fmt.Errorf("save parent session: %w", err)
1053	}
1054
1055	return nil
1056}