coordinator.go

   1package agent
   2
   3import (
   4	"bytes"
   5	"cmp"
   6	"context"
   7	"encoding/json"
   8	"errors"
   9	"fmt"
  10	"io"
  11	"log/slog"
  12	"maps"
  13	"net/http"
  14	"os"
  15	"slices"
  16	"strings"
  17
  18	"charm.land/catwalk/pkg/catwalk"
  19	"charm.land/fantasy"
  20	"github.com/charmbracelet/crush/internal/agent/hyper"
  21	"github.com/charmbracelet/crush/internal/agent/prompt"
  22	"github.com/charmbracelet/crush/internal/agent/tools"
  23	"github.com/charmbracelet/crush/internal/config"
  24	"github.com/charmbracelet/crush/internal/filetracker"
  25	"github.com/charmbracelet/crush/internal/history"
  26	"github.com/charmbracelet/crush/internal/log"
  27	"github.com/charmbracelet/crush/internal/lsp"
  28	"github.com/charmbracelet/crush/internal/message"
  29	"github.com/charmbracelet/crush/internal/oauth/copilot"
  30	"github.com/charmbracelet/crush/internal/permission"
  31	"github.com/charmbracelet/crush/internal/session"
  32	"golang.org/x/sync/errgroup"
  33
  34	"charm.land/fantasy/providers/anthropic"
  35	"charm.land/fantasy/providers/azure"
  36	"charm.land/fantasy/providers/bedrock"
  37	"charm.land/fantasy/providers/google"
  38	"charm.land/fantasy/providers/openai"
  39	"charm.land/fantasy/providers/openaicompat"
  40	"charm.land/fantasy/providers/openrouter"
  41	"charm.land/fantasy/providers/vercel"
  42	openaisdk "github.com/openai/openai-go/v2/option"
  43	"github.com/qjebbs/go-jsons"
  44)
  45
  46// Coordinator errors.
  47var (
  48	errCoderAgentNotConfigured         = errors.New("coder agent not configured")
  49	errModelProviderNotConfigured      = errors.New("model provider not configured")
  50	errLargeModelNotSelected           = errors.New("large model not selected")
  51	errSmallModelNotSelected           = errors.New("small model not selected")
  52	errLargeModelProviderNotConfigured = errors.New("large model provider not configured")
  53	errSmallModelProviderNotConfigured = errors.New("small model provider not configured")
  54	errLargeModelNotFound              = errors.New("large model not found in provider config")
  55	errSmallModelNotFound              = errors.New("small model not found in provider config")
  56)
  57
  58type Coordinator interface {
  59	// INFO: (kujtim) this is not used yet we will use this when we have multiple agents
  60	// SetMainAgent(string)
  61	Run(ctx context.Context, sessionID, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error)
  62	Cancel(sessionID string)
  63	CancelAll()
  64	IsSessionBusy(sessionID string) bool
  65	IsBusy() bool
  66	QueuedPrompts(sessionID string) int
  67	QueuedPromptsList(sessionID string) []string
  68	ClearQueue(sessionID string)
  69	Summarize(context.Context, string) error
  70	Model() Model
  71	UpdateModels(ctx context.Context) error
  72}
  73
// coordinator is the Coordinator implementation in this package. It keeps the
// configuration and services needed to build agents and forwards the
// Coordinator calls to currentAgent.
type coordinator struct {
	// Configuration and the services/managers handed to agents and tools.
	cfg         *config.Config
	sessions    session.Service
	messages    message.Service
	permissions permission.Service
	history     history.Service
	filetracker filetracker.Service
	lspManager  *lsp.Manager

	// currentAgent is the agent every Coordinator method operates on; agents
	// holds the built agents keyed by their configuration name.
	currentAgent SessionAgent
	agents       map[string]SessionAgent

	// readyWg tracks the background setup started by buildAgent; Run waits on
	// it before doing any work.
	readyWg errgroup.Group
}
  88
  89func NewCoordinator(
  90	ctx context.Context,
  91	cfg *config.Config,
  92	sessions session.Service,
  93	messages message.Service,
  94	permissions permission.Service,
  95	history history.Service,
  96	filetracker filetracker.Service,
  97	lspManager *lsp.Manager,
  98) (Coordinator, error) {
  99	c := &coordinator{
 100		cfg:         cfg,
 101		sessions:    sessions,
 102		messages:    messages,
 103		permissions: permissions,
 104		history:     history,
 105		filetracker: filetracker,
 106		lspManager:  lspManager,
 107		agents:      make(map[string]SessionAgent),
 108	}
 109
 110	agentCfg, ok := cfg.Agents[config.AgentCoder]
 111	if !ok {
 112		return nil, errCoderAgentNotConfigured
 113	}
 114
 115	// TODO: make this dynamic when we support multiple agents
 116	prompt, err := coderPrompt(prompt.WithWorkingDir(c.cfg.WorkingDir()))
 117	if err != nil {
 118		return nil, err
 119	}
 120
 121	agent, err := c.buildAgent(ctx, prompt, agentCfg, false)
 122	if err != nil {
 123		return nil, err
 124	}
 125	c.currentAgent = agent
 126	c.agents[config.AgentCoder] = agent
 127	return c, nil
 128}
 129
 130// Run implements Coordinator.
 131func (c *coordinator) Run(ctx context.Context, sessionID string, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) {
 132	if err := c.readyWg.Wait(); err != nil {
 133		return nil, err
 134	}
 135
 136	// refresh models before each run
 137	if err := c.UpdateModels(ctx); err != nil {
 138		return nil, fmt.Errorf("failed to update models: %w", err)
 139	}
 140
 141	model := c.currentAgent.Model()
 142	maxTokens := model.CatwalkCfg.DefaultMaxTokens
 143	if model.ModelCfg.MaxTokens != 0 {
 144		maxTokens = model.ModelCfg.MaxTokens
 145	}
 146
 147	if !model.CatwalkCfg.SupportsImages && attachments != nil {
 148		// filter out image attachments
 149		filteredAttachments := make([]message.Attachment, 0, len(attachments))
 150		for _, att := range attachments {
 151			if att.IsText() {
 152				filteredAttachments = append(filteredAttachments, att)
 153			}
 154		}
 155		attachments = filteredAttachments
 156	}
 157
 158	providerCfg, ok := c.cfg.Providers.Get(model.ModelCfg.Provider)
 159	if !ok {
 160		return nil, errModelProviderNotConfigured
 161	}
 162
 163	mergedOptions, temp, topP, topK, freqPenalty, presPenalty := mergeCallOptions(model, providerCfg)
 164
 165	if providerCfg.OAuthToken != nil && providerCfg.OAuthToken.IsExpired() {
 166		slog.Debug("Token needs to be refreshed", "provider", providerCfg.ID)
 167		if err := c.refreshOAuth2Token(ctx, providerCfg); err != nil {
 168			return nil, err
 169		}
 170	}
 171
 172	run := func() (*fantasy.AgentResult, error) {
 173		return c.currentAgent.Run(ctx, SessionAgentCall{
 174			SessionID:        sessionID,
 175			Prompt:           prompt,
 176			Attachments:      attachments,
 177			MaxOutputTokens:  maxTokens,
 178			ProviderOptions:  mergedOptions,
 179			Temperature:      temp,
 180			TopP:             topP,
 181			TopK:             topK,
 182			FrequencyPenalty: freqPenalty,
 183			PresencePenalty:  presPenalty,
 184		})
 185	}
 186	result, originalErr := run()
 187
 188	if c.isUnauthorized(originalErr) {
 189		switch {
 190		case providerCfg.OAuthToken != nil:
 191			slog.Debug("Received 401. Refreshing token and retrying", "provider", providerCfg.ID)
 192			if err := c.refreshOAuth2Token(ctx, providerCfg); err != nil {
 193				return nil, originalErr
 194			}
 195			slog.Debug("Retrying request with refreshed OAuth token", "provider", providerCfg.ID)
 196			return run()
 197		case strings.Contains(providerCfg.APIKeyTemplate, "$"):
 198			slog.Debug("Received 401. Refreshing API Key template and retrying", "provider", providerCfg.ID)
 199			if err := c.refreshApiKeyTemplate(ctx, providerCfg); err != nil {
 200				return nil, originalErr
 201			}
 202			slog.Debug("Retrying request with refreshed API key", "provider", providerCfg.ID)
 203			return run()
 204		}
 205	}
 206
 207	return result, originalErr
 208}
 209
 210func getProviderOptions(model Model, providerCfg config.ProviderConfig) fantasy.ProviderOptions {
 211	options := fantasy.ProviderOptions{}
 212
 213	cfgOpts := []byte("{}")
 214	providerCfgOpts := []byte("{}")
 215	catwalkOpts := []byte("{}")
 216
 217	if model.ModelCfg.ProviderOptions != nil {
 218		data, err := json.Marshal(model.ModelCfg.ProviderOptions)
 219		if err == nil {
 220			cfgOpts = data
 221		}
 222	}
 223
 224	if providerCfg.ProviderOptions != nil {
 225		data, err := json.Marshal(providerCfg.ProviderOptions)
 226		if err == nil {
 227			providerCfgOpts = data
 228		}
 229	}
 230
 231	if model.CatwalkCfg.Options.ProviderOptions != nil {
 232		data, err := json.Marshal(model.CatwalkCfg.Options.ProviderOptions)
 233		if err == nil {
 234			catwalkOpts = data
 235		}
 236	}
 237
 238	readers := []io.Reader{
 239		bytes.NewReader(catwalkOpts),
 240		bytes.NewReader(providerCfgOpts),
 241		bytes.NewReader(cfgOpts),
 242	}
 243
 244	got, err := jsons.Merge(readers)
 245	if err != nil {
 246		slog.Error("Could not merge call config", "err", err)
 247		return options
 248	}
 249
 250	mergedOptions := make(map[string]any)
 251
 252	err = json.Unmarshal([]byte(got), &mergedOptions)
 253	if err != nil {
 254		slog.Error("Could not create config for call", "err", err)
 255		return options
 256	}
 257
 258	providerType := providerCfg.Type
 259	if providerType == "hyper" {
 260		if strings.Contains(model.CatwalkCfg.ID, "claude") {
 261			providerType = anthropic.Name
 262		} else if strings.Contains(model.CatwalkCfg.ID, "gpt") {
 263			providerType = openai.Name
 264		} else if strings.Contains(model.CatwalkCfg.ID, "gemini") {
 265			providerType = google.Name
 266		} else {
 267			providerType = openaicompat.Name
 268		}
 269	}
 270
 271	switch providerType {
 272	case openai.Name, azure.Name:
 273		_, hasReasoningEffort := mergedOptions["reasoning_effort"]
 274		if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
 275			mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort
 276		}
 277		if openai.IsResponsesModel(model.CatwalkCfg.ID) {
 278			if openai.IsResponsesReasoningModel(model.CatwalkCfg.ID) {
 279				mergedOptions["reasoning_summary"] = "auto"
 280				mergedOptions["include"] = []openai.IncludeType{openai.IncludeReasoningEncryptedContent}
 281			}
 282			parsed, err := openai.ParseResponsesOptions(mergedOptions)
 283			if err == nil {
 284				options[openai.Name] = parsed
 285			}
 286		} else {
 287			parsed, err := openai.ParseOptions(mergedOptions)
 288			if err == nil {
 289				options[openai.Name] = parsed
 290			}
 291		}
 292	case anthropic.Name:
 293		var (
 294			_, hasEffort = mergedOptions["effort"]
 295			_, hasThink  = mergedOptions["thinking"]
 296		)
 297		switch {
 298		case !hasEffort && model.ModelCfg.ReasoningEffort != "":
 299			mergedOptions["effort"] = model.ModelCfg.ReasoningEffort
 300		case !hasThink && model.ModelCfg.Think:
 301			mergedOptions["thinking"] = map[string]any{"budget_tokens": 2000}
 302		}
 303		parsed, err := anthropic.ParseOptions(mergedOptions)
 304		if err == nil {
 305			options[anthropic.Name] = parsed
 306		}
 307
 308	case openrouter.Name:
 309		_, hasReasoning := mergedOptions["reasoning"]
 310		if !hasReasoning && model.ModelCfg.ReasoningEffort != "" {
 311			mergedOptions["reasoning"] = map[string]any{
 312				"enabled": true,
 313				"effort":  model.ModelCfg.ReasoningEffort,
 314			}
 315		}
 316		parsed, err := openrouter.ParseOptions(mergedOptions)
 317		if err == nil {
 318			options[openrouter.Name] = parsed
 319		}
 320	case vercel.Name:
 321		_, hasReasoning := mergedOptions["reasoning"]
 322		if !hasReasoning && model.ModelCfg.ReasoningEffort != "" {
 323			mergedOptions["reasoning"] = map[string]any{
 324				"enabled": true,
 325				"effort":  model.ModelCfg.ReasoningEffort,
 326			}
 327		}
 328		parsed, err := vercel.ParseOptions(mergedOptions)
 329		if err == nil {
 330			options[vercel.Name] = parsed
 331		}
 332	case google.Name:
 333		_, hasReasoning := mergedOptions["thinking_config"]
 334		if !hasReasoning {
 335			if strings.HasPrefix(model.CatwalkCfg.ID, "gemini-2") {
 336				mergedOptions["thinking_config"] = map[string]any{
 337					"thinking_budget":  2000,
 338					"include_thoughts": true,
 339				}
 340			} else {
 341				mergedOptions["thinking_config"] = map[string]any{
 342					"thinking_level":   model.ModelCfg.ReasoningEffort,
 343					"include_thoughts": true,
 344				}
 345			}
 346		}
 347		parsed, err := google.ParseOptions(mergedOptions)
 348		if err == nil {
 349			options[google.Name] = parsed
 350		}
 351	case openaicompat.Name:
 352		_, hasReasoningEffort := mergedOptions["reasoning_effort"]
 353		if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
 354			mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort
 355		}
 356		parsed, err := openaicompat.ParseOptions(mergedOptions)
 357		if err == nil {
 358			options[openaicompat.Name] = parsed
 359		}
 360	}
 361
 362	return options
 363}
 364
 365func mergeCallOptions(model Model, cfg config.ProviderConfig) (fantasy.ProviderOptions, *float64, *float64, *int64, *float64, *float64) {
 366	modelOptions := getProviderOptions(model, cfg)
 367	temp := cmp.Or(model.ModelCfg.Temperature, model.CatwalkCfg.Options.Temperature)
 368	topP := cmp.Or(model.ModelCfg.TopP, model.CatwalkCfg.Options.TopP)
 369	topK := cmp.Or(model.ModelCfg.TopK, model.CatwalkCfg.Options.TopK)
 370	freqPenalty := cmp.Or(model.ModelCfg.FrequencyPenalty, model.CatwalkCfg.Options.FrequencyPenalty)
 371	presPenalty := cmp.Or(model.ModelCfg.PresencePenalty, model.CatwalkCfg.Options.PresencePenalty)
 372	return modelOptions, temp, topP, topK, freqPenalty, presPenalty
 373}
 374
 375func (c *coordinator) buildAgent(ctx context.Context, prompt *prompt.Prompt, agent config.Agent, isSubAgent bool) (SessionAgent, error) {
 376	large, small, err := c.buildAgentModels(ctx, isSubAgent)
 377	if err != nil {
 378		return nil, err
 379	}
 380
 381	largeProviderCfg, _ := c.cfg.Providers.Get(large.ModelCfg.Provider)
 382	result := NewSessionAgent(SessionAgentOptions{
 383		large,
 384		small,
 385		largeProviderCfg.SystemPromptPrefix,
 386		"",
 387		isSubAgent,
 388		c.cfg.Options.DisableAutoSummarize,
 389		c.permissions.SkipRequests(),
 390		c.sessions,
 391		c.messages,
 392		nil,
 393	})
 394
 395	c.readyWg.Go(func() error {
 396		systemPrompt, err := prompt.Build(ctx, large.Model.Provider(), large.Model.Model(), *c.cfg)
 397		if err != nil {
 398			return err
 399		}
 400		result.SetSystemPrompt(systemPrompt)
 401		return nil
 402	})
 403
 404	c.readyWg.Go(func() error {
 405		tools, err := c.buildTools(ctx, agent)
 406		if err != nil {
 407			return err
 408		}
 409		result.SetTools(tools)
 410		return nil
 411	})
 412
 413	return result, nil
 414}
 415
 416func (c *coordinator) buildTools(ctx context.Context, agent config.Agent) ([]fantasy.AgentTool, error) {
 417	var allTools []fantasy.AgentTool
 418	if slices.Contains(agent.AllowedTools, AgentToolName) {
 419		agentTool, err := c.agentTool(ctx)
 420		if err != nil {
 421			return nil, err
 422		}
 423		allTools = append(allTools, agentTool)
 424	}
 425
 426	if slices.Contains(agent.AllowedTools, tools.AgenticFetchToolName) {
 427		agenticFetchTool, err := c.agenticFetchTool(ctx, nil)
 428		if err != nil {
 429			return nil, err
 430		}
 431		allTools = append(allTools, agenticFetchTool)
 432	}
 433
 434	// Get the model name for the agent
 435	modelName := ""
 436	if modelCfg, ok := c.cfg.Models[agent.Model]; ok {
 437		if model := c.cfg.GetModel(modelCfg.Provider, modelCfg.Model); model != nil {
 438			modelName = model.Name
 439		}
 440	}
 441
 442	allTools = append(allTools,
 443		tools.NewBashTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Options.Attribution, modelName),
 444		tools.NewJobOutputTool(),
 445		tools.NewJobKillTool(),
 446		tools.NewDownloadTool(c.permissions, c.cfg.WorkingDir(), nil),
 447		tools.NewEditTool(c.lspManager, c.permissions, c.history, c.filetracker, c.cfg.WorkingDir()),
 448		tools.NewMultiEditTool(c.lspManager, c.permissions, c.history, c.filetracker, c.cfg.WorkingDir()),
 449		tools.NewFetchTool(c.permissions, c.cfg.WorkingDir(), nil),
 450		tools.NewGlobTool(c.cfg.WorkingDir()),
 451		tools.NewGrepTool(c.cfg.WorkingDir(), c.cfg.Tools.Grep),
 452		tools.NewLsTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Tools.Ls),
 453		tools.NewSourcegraphTool(nil),
 454		tools.NewTodosTool(c.sessions),
 455		tools.NewViewTool(c.lspManager, c.permissions, c.filetracker, c.cfg.WorkingDir(), c.cfg.Options.SkillsPaths...),
 456		tools.NewWriteTool(c.lspManager, c.permissions, c.history, c.filetracker, c.cfg.WorkingDir()),
 457	)
 458
 459	// Add LSP tools if user has configured LSPs or auto_lsp is enabled (nil or true).
 460	if len(c.cfg.LSP) > 0 || c.cfg.Options.AutoLSP == nil || *c.cfg.Options.AutoLSP {
 461		allTools = append(allTools, tools.NewDiagnosticsTool(c.lspManager), tools.NewReferencesTool(c.lspManager), tools.NewLSPRestartTool(c.lspManager))
 462	}
 463
 464	if len(c.cfg.MCP) > 0 {
 465		allTools = append(
 466			allTools,
 467			tools.NewListMCPResourcesTool(c.cfg, c.permissions),
 468			tools.NewReadMCPResourceTool(c.cfg, c.permissions),
 469		)
 470	}
 471
 472	var filteredTools []fantasy.AgentTool
 473	for _, tool := range allTools {
 474		if slices.Contains(agent.AllowedTools, tool.Info().Name) {
 475			filteredTools = append(filteredTools, tool)
 476		}
 477	}
 478
 479	for _, tool := range tools.GetMCPTools(c.permissions, c.cfg, c.cfg.WorkingDir()) {
 480		if agent.AllowedMCP == nil {
 481			// No MCP restrictions
 482			filteredTools = append(filteredTools, tool)
 483			continue
 484		}
 485		if len(agent.AllowedMCP) == 0 {
 486			// No MCPs allowed
 487			slog.Debug("No MCPs allowed", "tool", tool.Name(), "agent", agent.Name)
 488			break
 489		}
 490
 491		for mcp, tools := range agent.AllowedMCP {
 492			if mcp != tool.MCP() {
 493				continue
 494			}
 495			if len(tools) == 0 || slices.Contains(tools, tool.MCPToolName()) {
 496				filteredTools = append(filteredTools, tool)
 497				break
 498			}
 499			slog.Debug("MCP not allowed", "tool", tool.Name(), "agent", agent.Name)
 500		}
 501	}
 502	slices.SortFunc(filteredTools, func(a, b fantasy.AgentTool) int {
 503		return strings.Compare(a.Info().Name, b.Info().Name)
 504	})
 505	return filteredTools, nil
 506}
 507
 508// TODO: when we support multiple agents we need to change this so that we pass in the agent specific model config
 509func (c *coordinator) buildAgentModels(ctx context.Context, isSubAgent bool) (Model, Model, error) {
 510	largeModelCfg, ok := c.cfg.Models[config.SelectedModelTypeLarge]
 511	if !ok {
 512		return Model{}, Model{}, errLargeModelNotSelected
 513	}
 514	smallModelCfg, ok := c.cfg.Models[config.SelectedModelTypeSmall]
 515	if !ok {
 516		return Model{}, Model{}, errSmallModelNotSelected
 517	}
 518
 519	largeProviderCfg, ok := c.cfg.Providers.Get(largeModelCfg.Provider)
 520	if !ok {
 521		return Model{}, Model{}, errLargeModelProviderNotConfigured
 522	}
 523
 524	largeProvider, err := c.buildProvider(largeProviderCfg, largeModelCfg, isSubAgent)
 525	if err != nil {
 526		return Model{}, Model{}, err
 527	}
 528
 529	smallProviderCfg, ok := c.cfg.Providers.Get(smallModelCfg.Provider)
 530	if !ok {
 531		return Model{}, Model{}, errSmallModelProviderNotConfigured
 532	}
 533
 534	smallProvider, err := c.buildProvider(smallProviderCfg, smallModelCfg, true)
 535	if err != nil {
 536		return Model{}, Model{}, err
 537	}
 538
 539	var largeCatwalkModel *catwalk.Model
 540	var smallCatwalkModel *catwalk.Model
 541
 542	for _, m := range largeProviderCfg.Models {
 543		if m.ID == largeModelCfg.Model {
 544			largeCatwalkModel = &m
 545		}
 546	}
 547	for _, m := range smallProviderCfg.Models {
 548		if m.ID == smallModelCfg.Model {
 549			smallCatwalkModel = &m
 550		}
 551	}
 552
 553	if largeCatwalkModel == nil {
 554		return Model{}, Model{}, errLargeModelNotFound
 555	}
 556
 557	if smallCatwalkModel == nil {
 558		return Model{}, Model{}, errSmallModelNotFound
 559	}
 560
 561	largeModelID := largeModelCfg.Model
 562	smallModelID := smallModelCfg.Model
 563
 564	if largeModelCfg.Provider == openrouter.Name && isExactoSupported(largeModelID) {
 565		largeModelID += ":exacto"
 566	}
 567
 568	if smallModelCfg.Provider == openrouter.Name && isExactoSupported(smallModelID) {
 569		smallModelID += ":exacto"
 570	}
 571
 572	largeModel, err := largeProvider.LanguageModel(ctx, largeModelID)
 573	if err != nil {
 574		return Model{}, Model{}, err
 575	}
 576	smallModel, err := smallProvider.LanguageModel(ctx, smallModelID)
 577	if err != nil {
 578		return Model{}, Model{}, err
 579	}
 580
 581	return Model{
 582			Model:      largeModel,
 583			CatwalkCfg: *largeCatwalkModel,
 584			ModelCfg:   largeModelCfg,
 585		}, Model{
 586			Model:      smallModel,
 587			CatwalkCfg: *smallCatwalkModel,
 588			ModelCfg:   smallModelCfg,
 589		}, nil
 590}
 591
 592func (c *coordinator) buildAnthropicProvider(baseURL, apiKey string, headers map[string]string, providerID string) (fantasy.Provider, error) {
 593	var opts []anthropic.Option
 594
 595	switch {
 596	case strings.HasPrefix(apiKey, "Bearer "):
 597		// NOTE: Prevent the SDK from picking up the API key from env.
 598		os.Setenv("ANTHROPIC_API_KEY", "")
 599		headers["Authorization"] = apiKey
 600	case providerID == string(catwalk.InferenceProviderMiniMax) || providerID == string(catwalk.InferenceProviderMiniMaxChina):
 601		// NOTE: Prevent the SDK from picking up the API key from env.
 602		os.Setenv("ANTHROPIC_API_KEY", "")
 603		headers["Authorization"] = "Bearer " + apiKey
 604	case apiKey != "":
 605		// X-Api-Key header
 606		opts = append(opts, anthropic.WithAPIKey(apiKey))
 607	}
 608
 609	if len(headers) > 0 {
 610		opts = append(opts, anthropic.WithHeaders(headers))
 611	}
 612
 613	if baseURL != "" {
 614		opts = append(opts, anthropic.WithBaseURL(baseURL))
 615	}
 616
 617	if c.cfg.Options.Debug {
 618		httpClient := log.NewHTTPClient()
 619		opts = append(opts, anthropic.WithHTTPClient(httpClient))
 620	}
 621	return anthropic.New(opts...)
 622}
 623
 624func (c *coordinator) buildOpenaiProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
 625	opts := []openai.Option{
 626		openai.WithAPIKey(apiKey),
 627		openai.WithUseResponsesAPI(),
 628	}
 629	if c.cfg.Options.Debug {
 630		httpClient := log.NewHTTPClient()
 631		opts = append(opts, openai.WithHTTPClient(httpClient))
 632	}
 633	if len(headers) > 0 {
 634		opts = append(opts, openai.WithHeaders(headers))
 635	}
 636	if baseURL != "" {
 637		opts = append(opts, openai.WithBaseURL(baseURL))
 638	}
 639	return openai.New(opts...)
 640}
 641
 642func (c *coordinator) buildOpenrouterProvider(_, apiKey string, headers map[string]string) (fantasy.Provider, error) {
 643	opts := []openrouter.Option{
 644		openrouter.WithAPIKey(apiKey),
 645	}
 646	if c.cfg.Options.Debug {
 647		httpClient := log.NewHTTPClient()
 648		opts = append(opts, openrouter.WithHTTPClient(httpClient))
 649	}
 650	if len(headers) > 0 {
 651		opts = append(opts, openrouter.WithHeaders(headers))
 652	}
 653	return openrouter.New(opts...)
 654}
 655
 656func (c *coordinator) buildVercelProvider(_, apiKey string, headers map[string]string) (fantasy.Provider, error) {
 657	opts := []vercel.Option{
 658		vercel.WithAPIKey(apiKey),
 659	}
 660	if c.cfg.Options.Debug {
 661		httpClient := log.NewHTTPClient()
 662		opts = append(opts, vercel.WithHTTPClient(httpClient))
 663	}
 664	if len(headers) > 0 {
 665		opts = append(opts, vercel.WithHeaders(headers))
 666	}
 667	return vercel.New(opts...)
 668}
 669
 670func (c *coordinator) buildOpenaiCompatProvider(baseURL, apiKey string, headers map[string]string, extraBody map[string]any, providerID string, isSubAgent bool) (fantasy.Provider, error) {
 671	opts := []openaicompat.Option{
 672		openaicompat.WithBaseURL(baseURL),
 673		openaicompat.WithAPIKey(apiKey),
 674	}
 675
 676	// Set HTTP client based on provider and debug mode.
 677	var httpClient *http.Client
 678	if providerID == string(catwalk.InferenceProviderCopilot) {
 679		opts = append(opts, openaicompat.WithUseResponsesAPI())
 680		httpClient = copilot.NewClient(isSubAgent, c.cfg.Options.Debug)
 681	} else if c.cfg.Options.Debug {
 682		httpClient = log.NewHTTPClient()
 683	}
 684	if httpClient != nil {
 685		opts = append(opts, openaicompat.WithHTTPClient(httpClient))
 686	}
 687
 688	if len(headers) > 0 {
 689		opts = append(opts, openaicompat.WithHeaders(headers))
 690	}
 691
 692	for extraKey, extraValue := range extraBody {
 693		opts = append(opts, openaicompat.WithSDKOptions(openaisdk.WithJSONSet(extraKey, extraValue)))
 694	}
 695
 696	return openaicompat.New(opts...)
 697}
 698
 699func (c *coordinator) buildAzureProvider(baseURL, apiKey string, headers map[string]string, options map[string]string) (fantasy.Provider, error) {
 700	opts := []azure.Option{
 701		azure.WithBaseURL(baseURL),
 702		azure.WithAPIKey(apiKey),
 703		azure.WithUseResponsesAPI(),
 704	}
 705	if c.cfg.Options.Debug {
 706		httpClient := log.NewHTTPClient()
 707		opts = append(opts, azure.WithHTTPClient(httpClient))
 708	}
 709	if options == nil {
 710		options = make(map[string]string)
 711	}
 712	if apiVersion, ok := options["apiVersion"]; ok {
 713		opts = append(opts, azure.WithAPIVersion(apiVersion))
 714	}
 715	if len(headers) > 0 {
 716		opts = append(opts, azure.WithHeaders(headers))
 717	}
 718
 719	return azure.New(opts...)
 720}
 721
 722func (c *coordinator) buildBedrockProvider(headers map[string]string) (fantasy.Provider, error) {
 723	var opts []bedrock.Option
 724	if c.cfg.Options.Debug {
 725		httpClient := log.NewHTTPClient()
 726		opts = append(opts, bedrock.WithHTTPClient(httpClient))
 727	}
 728	if len(headers) > 0 {
 729		opts = append(opts, bedrock.WithHeaders(headers))
 730	}
 731	bearerToken := os.Getenv("AWS_BEARER_TOKEN_BEDROCK")
 732	if bearerToken != "" {
 733		opts = append(opts, bedrock.WithAPIKey(bearerToken))
 734	}
 735	return bedrock.New(opts...)
 736}
 737
 738func (c *coordinator) buildGoogleProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
 739	opts := []google.Option{
 740		google.WithBaseURL(baseURL),
 741		google.WithGeminiAPIKey(apiKey),
 742	}
 743	if c.cfg.Options.Debug {
 744		httpClient := log.NewHTTPClient()
 745		opts = append(opts, google.WithHTTPClient(httpClient))
 746	}
 747	if len(headers) > 0 {
 748		opts = append(opts, google.WithHeaders(headers))
 749	}
 750	return google.New(opts...)
 751}
 752
 753func (c *coordinator) buildGoogleVertexProvider(headers map[string]string, options map[string]string) (fantasy.Provider, error) {
 754	opts := []google.Option{}
 755	if c.cfg.Options.Debug {
 756		httpClient := log.NewHTTPClient()
 757		opts = append(opts, google.WithHTTPClient(httpClient))
 758	}
 759	if len(headers) > 0 {
 760		opts = append(opts, google.WithHeaders(headers))
 761	}
 762
 763	project := options["project"]
 764	location := options["location"]
 765
 766	opts = append(opts, google.WithVertex(project, location))
 767
 768	return google.New(opts...)
 769}
 770
 771func (c *coordinator) buildHyperProvider(baseURL, apiKey string) (fantasy.Provider, error) {
 772	opts := []hyper.Option{
 773		hyper.WithBaseURL(baseURL),
 774		hyper.WithAPIKey(apiKey),
 775	}
 776	if c.cfg.Options.Debug {
 777		httpClient := log.NewHTTPClient()
 778		opts = append(opts, hyper.WithHTTPClient(httpClient))
 779	}
 780	return hyper.New(opts...)
 781}
 782
 783func (c *coordinator) isAnthropicThinking(model config.SelectedModel) bool {
 784	if model.Think {
 785		return true
 786	}
 787	opts, err := anthropic.ParseOptions(model.ProviderOptions)
 788	return err == nil && opts.Thinking != nil
 789}
 790
 791func (c *coordinator) buildProvider(providerCfg config.ProviderConfig, model config.SelectedModel, isSubAgent bool) (fantasy.Provider, error) {
 792	headers := maps.Clone(providerCfg.ExtraHeaders)
 793	if headers == nil {
 794		headers = make(map[string]string)
 795	}
 796
 797	// handle special headers for anthropic
 798	if providerCfg.Type == anthropic.Name && c.isAnthropicThinking(model) {
 799		if v, ok := headers["anthropic-beta"]; ok {
 800			headers["anthropic-beta"] = v + ",interleaved-thinking-2025-05-14"
 801		} else {
 802			headers["anthropic-beta"] = "interleaved-thinking-2025-05-14"
 803		}
 804	}
 805
 806	apiKey, _ := c.cfg.Resolve(providerCfg.APIKey)
 807	baseURL, _ := c.cfg.Resolve(providerCfg.BaseURL)
 808
 809	switch providerCfg.Type {
 810	case openai.Name:
 811		return c.buildOpenaiProvider(baseURL, apiKey, headers)
 812	case anthropic.Name:
 813		return c.buildAnthropicProvider(baseURL, apiKey, headers, providerCfg.ID)
 814	case openrouter.Name:
 815		return c.buildOpenrouterProvider(baseURL, apiKey, headers)
 816	case vercel.Name:
 817		return c.buildVercelProvider(baseURL, apiKey, headers)
 818	case azure.Name:
 819		return c.buildAzureProvider(baseURL, apiKey, headers, providerCfg.ExtraParams)
 820	case bedrock.Name:
 821		return c.buildBedrockProvider(headers)
 822	case google.Name:
 823		return c.buildGoogleProvider(baseURL, apiKey, headers)
 824	case "google-vertex":
 825		return c.buildGoogleVertexProvider(headers, providerCfg.ExtraParams)
 826	case openaicompat.Name:
 827		if providerCfg.ID == string(catwalk.InferenceProviderZAI) {
 828			if providerCfg.ExtraBody == nil {
 829				providerCfg.ExtraBody = map[string]any{}
 830			}
 831			providerCfg.ExtraBody["tool_stream"] = true
 832		}
 833		return c.buildOpenaiCompatProvider(baseURL, apiKey, headers, providerCfg.ExtraBody, providerCfg.ID, isSubAgent)
 834	case hyper.Name:
 835		return c.buildHyperProvider(baseURL, apiKey)
 836	default:
 837		return nil, fmt.Errorf("provider type not supported: %q", providerCfg.Type)
 838	}
 839}
 840
 841func isExactoSupported(modelID string) bool {
 842	supportedModels := []string{
 843		"moonshotai/kimi-k2-0905",
 844		"deepseek/deepseek-v3.1-terminus",
 845		"z-ai/glm-4.6",
 846		"openai/gpt-oss-120b",
 847		"qwen/qwen3-coder",
 848	}
 849	return slices.Contains(supportedModels, modelID)
 850}
 851
 852func (c *coordinator) Cancel(sessionID string) {
 853	c.currentAgent.Cancel(sessionID)
 854}
 855
 856func (c *coordinator) CancelAll() {
 857	c.currentAgent.CancelAll()
 858}
 859
 860func (c *coordinator) ClearQueue(sessionID string) {
 861	c.currentAgent.ClearQueue(sessionID)
 862}
 863
 864func (c *coordinator) IsBusy() bool {
 865	return c.currentAgent.IsBusy()
 866}
 867
 868func (c *coordinator) IsSessionBusy(sessionID string) bool {
 869	return c.currentAgent.IsSessionBusy(sessionID)
 870}
 871
 872func (c *coordinator) Model() Model {
 873	return c.currentAgent.Model()
 874}
 875
 876func (c *coordinator) UpdateModels(ctx context.Context) error {
 877	// build the models again so we make sure we get the latest config
 878	large, small, err := c.buildAgentModels(ctx, false)
 879	if err != nil {
 880		return err
 881	}
 882	c.currentAgent.SetModels(large, small)
 883
 884	agentCfg, ok := c.cfg.Agents[config.AgentCoder]
 885	if !ok {
 886		return errCoderAgentNotConfigured
 887	}
 888
 889	tools, err := c.buildTools(ctx, agentCfg)
 890	if err != nil {
 891		return err
 892	}
 893	c.currentAgent.SetTools(tools)
 894	return nil
 895}
 896
 897func (c *coordinator) QueuedPrompts(sessionID string) int {
 898	return c.currentAgent.QueuedPrompts(sessionID)
 899}
 900
 901func (c *coordinator) QueuedPromptsList(sessionID string) []string {
 902	return c.currentAgent.QueuedPromptsList(sessionID)
 903}
 904
 905func (c *coordinator) Summarize(ctx context.Context, sessionID string) error {
 906	providerCfg, ok := c.cfg.Providers.Get(c.currentAgent.Model().ModelCfg.Provider)
 907	if !ok {
 908		return errModelProviderNotConfigured
 909	}
 910	return c.currentAgent.Summarize(ctx, sessionID, getProviderOptions(c.currentAgent.Model(), providerCfg))
 911}
 912
 913func (c *coordinator) isUnauthorized(err error) bool {
 914	var providerErr *fantasy.ProviderError
 915	return errors.As(err, &providerErr) && providerErr.StatusCode == http.StatusUnauthorized
 916}
 917
 918func (c *coordinator) refreshOAuth2Token(ctx context.Context, providerCfg config.ProviderConfig) error {
 919	if err := c.cfg.RefreshOAuthToken(ctx, providerCfg.ID); err != nil {
 920		slog.Error("Failed to refresh OAuth token after 401 error", "provider", providerCfg.ID, "error", err)
 921		return err
 922	}
 923	if err := c.UpdateModels(ctx); err != nil {
 924		return err
 925	}
 926	return nil
 927}
 928
 929func (c *coordinator) refreshApiKeyTemplate(ctx context.Context, providerCfg config.ProviderConfig) error {
 930	newAPIKey, err := c.cfg.Resolve(providerCfg.APIKeyTemplate)
 931	if err != nil {
 932		slog.Error("Failed to re-resolve API key after 401 error", "provider", providerCfg.ID, "error", err)
 933		return err
 934	}
 935
 936	providerCfg.APIKey = newAPIKey
 937	c.cfg.Providers.Set(providerCfg.ID, providerCfg)
 938
 939	if err := c.UpdateModels(ctx); err != nil {
 940		return err
 941	}
 942	return nil
 943}
 944
// subAgentParams holds the parameters for running a sub-agent.
//
// Agent is the session agent that executes Prompt. SessionID is the parent
// session; together with SessionTitle it is used to create the sub-session.
// AgentMessageID and ToolCallID are combined into the sub-session's ID.
type subAgentParams struct {
	Agent          SessionAgent
	SessionID      string
	AgentMessageID string
	ToolCallID     string
	Prompt         string
	SessionTitle   string
	// SessionSetup is an optional callback invoked after session creation
	// but before agent execution, for custom session configuration.
	SessionSetup func(sessionID string)
}
 957
 958// runSubAgent runs a sub-agent and handles session management and cost accumulation.
 959// It creates a sub-session, runs the agent with the given prompt, and propagates
 960// the cost to the parent session.
 961func (c *coordinator) runSubAgent(ctx context.Context, params subAgentParams) (fantasy.ToolResponse, error) {
 962	// Create sub-session
 963	agentToolSessionID := c.sessions.CreateAgentToolSessionID(params.AgentMessageID, params.ToolCallID)
 964	session, err := c.sessions.CreateTaskSession(ctx, agentToolSessionID, params.SessionID, params.SessionTitle)
 965	if err != nil {
 966		return fantasy.ToolResponse{}, fmt.Errorf("create session: %w", err)
 967	}
 968
 969	// Call session setup function if provided
 970	if params.SessionSetup != nil {
 971		params.SessionSetup(session.ID)
 972	}
 973
 974	// Get model configuration
 975	model := params.Agent.Model()
 976	maxTokens := model.CatwalkCfg.DefaultMaxTokens
 977	if model.ModelCfg.MaxTokens != 0 {
 978		maxTokens = model.ModelCfg.MaxTokens
 979	}
 980
 981	providerCfg, ok := c.cfg.Providers.Get(model.ModelCfg.Provider)
 982	if !ok {
 983		return fantasy.ToolResponse{}, errModelProviderNotConfigured
 984	}
 985
 986	// Run the agent
 987	result, err := params.Agent.Run(ctx, SessionAgentCall{
 988		SessionID:        session.ID,
 989		Prompt:           params.Prompt,
 990		MaxOutputTokens:  maxTokens,
 991		ProviderOptions:  getProviderOptions(model, providerCfg),
 992		Temperature:      model.ModelCfg.Temperature,
 993		TopP:             model.ModelCfg.TopP,
 994		TopK:             model.ModelCfg.TopK,
 995		FrequencyPenalty: model.ModelCfg.FrequencyPenalty,
 996		PresencePenalty:  model.ModelCfg.PresencePenalty,
 997	})
 998	if err != nil {
 999		return fantasy.NewTextErrorResponse("error generating response"), nil
1000	}
1001
1002	// Update parent session cost
1003	if err := c.updateParentSessionCost(ctx, session.ID, params.SessionID); err != nil {
1004		return fantasy.ToolResponse{}, err
1005	}
1006
1007	return fantasy.NewTextResponse(result.Response.Content.Text()), nil
1008}
1009
1010// updateParentSessionCost accumulates the cost from a child session to its parent session.
1011func (c *coordinator) updateParentSessionCost(ctx context.Context, childSessionID, parentSessionID string) error {
1012	childSession, err := c.sessions.Get(ctx, childSessionID)
1013	if err != nil {
1014		return fmt.Errorf("get child session: %w", err)
1015	}
1016
1017	parentSession, err := c.sessions.Get(ctx, parentSessionID)
1018	if err != nil {
1019		return fmt.Errorf("get parent session: %w", err)
1020	}
1021
1022	parentSession.Cost += childSession.Cost
1023
1024	if _, err := c.sessions.Save(ctx, parentSession); err != nil {
1025		return fmt.Errorf("save parent session: %w", err)
1026	}
1027
1028	return nil
1029}