coordinator.go

   1package agent
   2
   3import (
   4	"bytes"
   5	"cmp"
   6	"context"
   7	"encoding/json"
   8	"errors"
   9	"fmt"
  10	"io"
  11	"log/slog"
  12	"maps"
  13	"net/http"
  14	"os"
  15	"slices"
  16	"strings"
  17
  18	"charm.land/catwalk/pkg/catwalk"
  19	"charm.land/fantasy"
  20	"github.com/charmbracelet/crush/internal/agent/hyper"
  21	"github.com/charmbracelet/crush/internal/agent/notify"
  22	"github.com/charmbracelet/crush/internal/agent/prompt"
  23	"github.com/charmbracelet/crush/internal/agent/tools"
  24	"github.com/charmbracelet/crush/internal/config"
  25	"github.com/charmbracelet/crush/internal/filetracker"
  26	"github.com/charmbracelet/crush/internal/history"
  27	"github.com/charmbracelet/crush/internal/log"
  28	"github.com/charmbracelet/crush/internal/lsp"
  29	"github.com/charmbracelet/crush/internal/message"
  30	"github.com/charmbracelet/crush/internal/oauth/copilot"
  31	"github.com/charmbracelet/crush/internal/permission"
  32	"github.com/charmbracelet/crush/internal/pubsub"
  33	"github.com/charmbracelet/crush/internal/session"
  34	"golang.org/x/sync/errgroup"
  35
  36	"charm.land/fantasy/providers/anthropic"
  37	"charm.land/fantasy/providers/azure"
  38	"charm.land/fantasy/providers/bedrock"
  39	"charm.land/fantasy/providers/google"
  40	"charm.land/fantasy/providers/openai"
  41	"charm.land/fantasy/providers/openaicompat"
  42	"charm.land/fantasy/providers/openrouter"
  43	"charm.land/fantasy/providers/vercel"
  44	openaisdk "github.com/openai/openai-go/v2/option"
  45	"github.com/qjebbs/go-jsons"
  46)
  47
  48// Coordinator errors.
  49var (
  50	errCoderAgentNotConfigured         = errors.New("coder agent not configured")
  51	errModelProviderNotConfigured      = errors.New("model provider not configured")
  52	errLargeModelNotSelected           = errors.New("large model not selected")
  53	errSmallModelNotSelected           = errors.New("small model not selected")
  54	errLargeModelProviderNotConfigured = errors.New("large model provider not configured")
  55	errSmallModelProviderNotConfigured = errors.New("small model provider not configured")
  56	errLargeModelNotFound              = errors.New("large model not found in provider config")
  57	errSmallModelNotFound              = errors.New("small model not found in provider config")
  58)
  59
  60type Coordinator interface {
  61	// INFO: (kujtim) this is not used yet we will use this when we have multiple agents
  62	// SetMainAgent(string)
  63	Run(ctx context.Context, sessionID, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error)
  64	Cancel(sessionID string)
  65	CancelAll()
  66	IsSessionBusy(sessionID string) bool
  67	IsBusy() bool
  68	QueuedPrompts(sessionID string) int
  69	QueuedPromptsList(sessionID string) []string
  70	ClearQueue(sessionID string)
  71	Summarize(context.Context, string) error
  72	Model() Model
  73	UpdateModels(ctx context.Context) error
  74	RefreshTools(ctx context.Context) error
  75}
  76
// coordinator is the Coordinator implementation returned by NewCoordinator.
// It keeps the services needed to build agents and tools and, for now, a
// single active agent (the coder agent) to which every call is forwarded.
type coordinator struct {
	cfg         *config.ConfigStore
	sessions    session.Service
	messages    message.Service
	permissions permission.Service
	history     history.Service
	filetracker filetracker.Service
	lspManager  *lsp.Manager
	notify      pubsub.Publisher[notify.Notification]

	// currentAgent receives every forwarded call (Run, Cancel, Summarize, ...).
	currentAgent SessionAgent
	// agents holds the built agents keyed by agent name (e.g. config.AgentCoder).
	agents       map[string]SessionAgent

	// readyWg tracks the background system-prompt and tool construction
	// started by buildAgent; Run waits on it before calling the agent.
	readyWg errgroup.Group
}
  92
  93func NewCoordinator(
  94	ctx context.Context,
  95	cfg *config.ConfigStore,
  96	sessions session.Service,
  97	messages message.Service,
  98	permissions permission.Service,
  99	history history.Service,
 100	filetracker filetracker.Service,
 101	lspManager *lsp.Manager,
 102	notify pubsub.Publisher[notify.Notification],
 103) (Coordinator, error) {
 104	c := &coordinator{
 105		cfg:         cfg,
 106		sessions:    sessions,
 107		messages:    messages,
 108		permissions: permissions,
 109		history:     history,
 110		filetracker: filetracker,
 111		lspManager:  lspManager,
 112		notify:      notify,
 113		agents:      make(map[string]SessionAgent),
 114	}
 115
 116	agentCfg, ok := cfg.Config().Agents[config.AgentCoder]
 117	if !ok {
 118		return nil, errCoderAgentNotConfigured
 119	}
 120
 121	// TODO: make this dynamic when we support multiple agents
 122	prompt, err := coderPrompt(prompt.WithWorkingDir(c.cfg.WorkingDir()))
 123	if err != nil {
 124		return nil, err
 125	}
 126
 127	agent, err := c.buildAgent(ctx, prompt, agentCfg, false)
 128	if err != nil {
 129		return nil, err
 130	}
 131	c.currentAgent = agent
 132	c.agents[config.AgentCoder] = agent
 133	return c, nil
 134}
 135
 136// Run implements Coordinator.
 137func (c *coordinator) Run(ctx context.Context, sessionID string, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) {
 138	if err := c.readyWg.Wait(); err != nil {
 139		return nil, err
 140	}
 141
 142	// refresh models before each run
 143	if err := c.UpdateModels(ctx); err != nil {
 144		return nil, fmt.Errorf("failed to update models: %w", err)
 145	}
 146
 147	model := c.currentAgent.Model()
 148	maxTokens := model.CatwalkCfg.DefaultMaxTokens
 149	if model.ModelCfg.MaxTokens != 0 {
 150		maxTokens = model.ModelCfg.MaxTokens
 151	}
 152
 153	if !model.CatwalkCfg.SupportsImages && attachments != nil {
 154		// filter out image attachments
 155		filteredAttachments := make([]message.Attachment, 0, len(attachments))
 156		for _, att := range attachments {
 157			if att.IsText() {
 158				filteredAttachments = append(filteredAttachments, att)
 159			}
 160		}
 161		attachments = filteredAttachments
 162	}
 163
 164	providerCfg, ok := c.cfg.Config().Providers.Get(model.ModelCfg.Provider)
 165	if !ok {
 166		return nil, errModelProviderNotConfigured
 167	}
 168
 169	mergedOptions, temp, topP, topK, freqPenalty, presPenalty := mergeCallOptions(model, providerCfg)
 170
 171	if providerCfg.OAuthToken != nil && providerCfg.OAuthToken.IsExpired() {
 172		slog.Debug("Token needs to be refreshed", "provider", providerCfg.ID)
 173		if err := c.refreshOAuth2Token(ctx, providerCfg); err != nil {
 174			return nil, err
 175		}
 176	}
 177
 178	run := func() (*fantasy.AgentResult, error) {
 179		return c.currentAgent.Run(ctx, SessionAgentCall{
 180			SessionID:        sessionID,
 181			Prompt:           prompt,
 182			Attachments:      attachments,
 183			MaxOutputTokens:  maxTokens,
 184			ProviderOptions:  mergedOptions,
 185			Temperature:      temp,
 186			TopP:             topP,
 187			TopK:             topK,
 188			FrequencyPenalty: freqPenalty,
 189			PresencePenalty:  presPenalty,
 190		})
 191	}
 192	result, originalErr := run()
 193
 194	if c.isUnauthorized(originalErr) {
 195		switch {
 196		case providerCfg.OAuthToken != nil:
 197			slog.Debug("Received 401. Refreshing token and retrying", "provider", providerCfg.ID)
 198			if err := c.refreshOAuth2Token(ctx, providerCfg); err != nil {
 199				return nil, originalErr
 200			}
 201			slog.Debug("Retrying request with refreshed OAuth token", "provider", providerCfg.ID)
 202			return run()
 203		case strings.Contains(providerCfg.APIKeyTemplate, "$"):
 204			slog.Debug("Received 401. Refreshing API Key template and retrying", "provider", providerCfg.ID)
 205			if err := c.refreshApiKeyTemplate(ctx, providerCfg); err != nil {
 206				return nil, originalErr
 207			}
 208			slog.Debug("Retrying request with refreshed API key", "provider", providerCfg.ID)
 209			return run()
 210		}
 211	}
 212
 213	return result, originalErr
 214}
 215
 216func getProviderOptions(model Model, providerCfg config.ProviderConfig) fantasy.ProviderOptions {
 217	options := fantasy.ProviderOptions{}
 218
 219	cfgOpts := []byte("{}")
 220	providerCfgOpts := []byte("{}")
 221	catwalkOpts := []byte("{}")
 222
 223	if model.ModelCfg.ProviderOptions != nil {
 224		data, err := json.Marshal(model.ModelCfg.ProviderOptions)
 225		if err == nil {
 226			cfgOpts = data
 227		}
 228	}
 229
 230	if providerCfg.ProviderOptions != nil {
 231		data, err := json.Marshal(providerCfg.ProviderOptions)
 232		if err == nil {
 233			providerCfgOpts = data
 234		}
 235	}
 236
 237	if model.CatwalkCfg.Options.ProviderOptions != nil {
 238		data, err := json.Marshal(model.CatwalkCfg.Options.ProviderOptions)
 239		if err == nil {
 240			catwalkOpts = data
 241		}
 242	}
 243
 244	readers := []io.Reader{
 245		bytes.NewReader(catwalkOpts),
 246		bytes.NewReader(providerCfgOpts),
 247		bytes.NewReader(cfgOpts),
 248	}
 249
 250	got, err := jsons.Merge(readers)
 251	if err != nil {
 252		slog.Error("Could not merge call config", "err", err)
 253		return options
 254	}
 255
 256	mergedOptions := make(map[string]any)
 257
 258	err = json.Unmarshal([]byte(got), &mergedOptions)
 259	if err != nil {
 260		slog.Error("Could not create config for call", "err", err)
 261		return options
 262	}
 263
 264	providerType := providerCfg.Type
 265	if providerType == "hyper" {
 266		if strings.Contains(model.CatwalkCfg.ID, "claude") {
 267			providerType = anthropic.Name
 268		} else if strings.Contains(model.CatwalkCfg.ID, "gpt") {
 269			providerType = openai.Name
 270		} else if strings.Contains(model.CatwalkCfg.ID, "gemini") {
 271			providerType = google.Name
 272		} else {
 273			providerType = openaicompat.Name
 274		}
 275	}
 276
 277	switch providerType {
 278	case openai.Name, azure.Name:
 279		_, hasReasoningEffort := mergedOptions["reasoning_effort"]
 280		if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
 281			mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort
 282		}
 283		if openai.IsResponsesModel(model.CatwalkCfg.ID) {
 284			if openai.IsResponsesReasoningModel(model.CatwalkCfg.ID) {
 285				mergedOptions["reasoning_summary"] = "auto"
 286				mergedOptions["include"] = []openai.IncludeType{openai.IncludeReasoningEncryptedContent}
 287			}
 288			parsed, err := openai.ParseResponsesOptions(mergedOptions)
 289			if err == nil {
 290				options[openai.Name] = parsed
 291			}
 292		} else {
 293			parsed, err := openai.ParseOptions(mergedOptions)
 294			if err == nil {
 295				options[openai.Name] = parsed
 296			}
 297		}
 298	case anthropic.Name:
 299		var (
 300			_, hasEffort = mergedOptions["effort"]
 301			_, hasThink  = mergedOptions["thinking"]
 302		)
 303		switch {
 304		case !hasEffort && model.ModelCfg.ReasoningEffort != "":
 305			mergedOptions["effort"] = model.ModelCfg.ReasoningEffort
 306		case !hasThink && model.ModelCfg.Think:
 307			mergedOptions["thinking"] = map[string]any{"budget_tokens": 2000}
 308		}
 309		parsed, err := anthropic.ParseOptions(mergedOptions)
 310		if err == nil {
 311			options[anthropic.Name] = parsed
 312		}
 313
 314	case openrouter.Name:
 315		_, hasReasoning := mergedOptions["reasoning"]
 316		if !hasReasoning && model.ModelCfg.ReasoningEffort != "" {
 317			mergedOptions["reasoning"] = map[string]any{
 318				"enabled": true,
 319				"effort":  model.ModelCfg.ReasoningEffort,
 320			}
 321		}
 322		parsed, err := openrouter.ParseOptions(mergedOptions)
 323		if err == nil {
 324			options[openrouter.Name] = parsed
 325		}
 326	case vercel.Name:
 327		_, hasReasoning := mergedOptions["reasoning"]
 328		if !hasReasoning && model.ModelCfg.ReasoningEffort != "" {
 329			mergedOptions["reasoning"] = map[string]any{
 330				"enabled": true,
 331				"effort":  model.ModelCfg.ReasoningEffort,
 332			}
 333		}
 334		parsed, err := vercel.ParseOptions(mergedOptions)
 335		if err == nil {
 336			options[vercel.Name] = parsed
 337		}
 338	case google.Name:
 339		_, hasReasoning := mergedOptions["thinking_config"]
 340		if !hasReasoning {
 341			if strings.HasPrefix(model.CatwalkCfg.ID, "gemini-2") {
 342				mergedOptions["thinking_config"] = map[string]any{
 343					"thinking_budget":  2000,
 344					"include_thoughts": true,
 345				}
 346			} else {
 347				mergedOptions["thinking_config"] = map[string]any{
 348					"thinking_level":   model.ModelCfg.ReasoningEffort,
 349					"include_thoughts": true,
 350				}
 351			}
 352		}
 353		parsed, err := google.ParseOptions(mergedOptions)
 354		if err == nil {
 355			options[google.Name] = parsed
 356		}
 357	case openaicompat.Name:
 358		_, hasReasoningEffort := mergedOptions["reasoning_effort"]
 359		if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
 360			mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort
 361		}
 362		parsed, err := openaicompat.ParseOptions(mergedOptions)
 363		if err == nil {
 364			options[openaicompat.Name] = parsed
 365		}
 366	}
 367
 368	return options
 369}
 370
 371func mergeCallOptions(model Model, cfg config.ProviderConfig) (fantasy.ProviderOptions, *float64, *float64, *int64, *float64, *float64) {
 372	modelOptions := getProviderOptions(model, cfg)
 373	temp := cmp.Or(model.ModelCfg.Temperature, model.CatwalkCfg.Options.Temperature)
 374	topP := cmp.Or(model.ModelCfg.TopP, model.CatwalkCfg.Options.TopP)
 375	topK := cmp.Or(model.ModelCfg.TopK, model.CatwalkCfg.Options.TopK)
 376	freqPenalty := cmp.Or(model.ModelCfg.FrequencyPenalty, model.CatwalkCfg.Options.FrequencyPenalty)
 377	presPenalty := cmp.Or(model.ModelCfg.PresencePenalty, model.CatwalkCfg.Options.PresencePenalty)
 378	return modelOptions, temp, topP, topK, freqPenalty, presPenalty
 379}
 380
 381func (c *coordinator) buildAgent(ctx context.Context, prompt *prompt.Prompt, agent config.Agent, isSubAgent bool) (SessionAgent, error) {
 382	large, small, err := c.buildAgentModels(ctx, isSubAgent)
 383	if err != nil {
 384		return nil, err
 385	}
 386
 387	largeProviderCfg, _ := c.cfg.Config().Providers.Get(large.ModelCfg.Provider)
 388	result := NewSessionAgent(SessionAgentOptions{
 389		LargeModel:           large,
 390		SmallModel:           small,
 391		SystemPromptPrefix:   largeProviderCfg.SystemPromptPrefix,
 392		SystemPrompt:         "",
 393		IsSubAgent:           isSubAgent,
 394		DisableAutoSummarize: c.cfg.Config().Options.DisableAutoSummarize,
 395		IsYolo:               c.permissions.SkipRequests(),
 396		Sessions:             c.sessions,
 397		Messages:             c.messages,
 398		Tools:                nil,
 399		Notify:               c.notify,
 400	})
 401
 402	c.readyWg.Go(func() error {
 403		systemPrompt, err := prompt.Build(ctx, large.Model.Provider(), large.Model.Model(), c.cfg)
 404		if err != nil {
 405			return err
 406		}
 407		result.SetSystemPrompt(systemPrompt)
 408		return nil
 409	})
 410
 411	c.readyWg.Go(func() error {
 412		tools, err := c.buildTools(ctx, agent)
 413		if err != nil {
 414			return err
 415		}
 416		result.SetTools(tools)
 417		return nil
 418	})
 419
 420	return result, nil
 421}
 422
 423func (c *coordinator) buildTools(ctx context.Context, agent config.Agent) ([]fantasy.AgentTool, error) {
 424	var allTools []fantasy.AgentTool
 425	if slices.Contains(agent.AllowedTools, AgentToolName) {
 426		agentTool, err := c.agentTool(ctx)
 427		if err != nil {
 428			return nil, err
 429		}
 430		allTools = append(allTools, agentTool)
 431	}
 432
 433	if slices.Contains(agent.AllowedTools, tools.AgenticFetchToolName) {
 434		agenticFetchTool, err := c.agenticFetchTool(ctx, nil)
 435		if err != nil {
 436			return nil, err
 437		}
 438		allTools = append(allTools, agenticFetchTool)
 439	}
 440
 441	// Get the model name for the agent
 442	modelName := ""
 443	if modelCfg, ok := c.cfg.Config().Models[agent.Model]; ok {
 444		if model := c.cfg.Config().GetModel(modelCfg.Provider, modelCfg.Model); model != nil {
 445			modelName = model.Name
 446		}
 447	}
 448
 449	allTools = append(allTools,
 450		tools.NewBashTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Config().Options.Attribution, modelName),
 451		tools.NewJobOutputTool(),
 452		tools.NewJobKillTool(),
 453		tools.NewDownloadTool(c.permissions, c.cfg.WorkingDir(), nil),
 454		tools.NewEditTool(c.lspManager, c.permissions, c.history, c.filetracker, c.cfg.WorkingDir()),
 455		tools.NewMultiEditTool(c.lspManager, c.permissions, c.history, c.filetracker, c.cfg.WorkingDir()),
 456		tools.NewFetchTool(c.permissions, c.cfg.WorkingDir(), nil),
 457		tools.NewGlobTool(c.cfg.WorkingDir()),
 458		tools.NewGrepTool(c.cfg.WorkingDir(), c.cfg.Config().Tools.Grep),
 459		tools.NewLsTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Config().Tools.Ls),
 460		tools.NewSourcegraphTool(nil),
 461		tools.NewTodosTool(c.sessions),
 462		tools.NewViewTool(c.lspManager, c.permissions, c.filetracker, c.cfg.WorkingDir(), c.cfg.Config().Options.SkillsPaths...),
 463		tools.NewWriteTool(c.lspManager, c.permissions, c.history, c.filetracker, c.cfg.WorkingDir()),
 464	)
 465
 466	// Add LSP tools if user has configured LSPs or auto_lsp is enabled (nil or true).
 467	if len(c.cfg.Config().LSP) > 0 || c.cfg.Config().Options.AutoLSP == nil || *c.cfg.Config().Options.AutoLSP {
 468		allTools = append(allTools, tools.NewDiagnosticsTool(c.lspManager), tools.NewReferencesTool(c.lspManager), tools.NewLSPRestartTool(c.lspManager))
 469	}
 470
 471	if len(c.cfg.Config().MCP) > 0 {
 472		allTools = append(
 473			allTools,
 474			tools.NewListMCPResourcesTool(c.cfg, c.permissions),
 475			tools.NewReadMCPResourceTool(c.cfg, c.permissions),
 476		)
 477	}
 478
 479	var filteredTools []fantasy.AgentTool
 480	for _, tool := range allTools {
 481		if slices.Contains(agent.AllowedTools, tool.Info().Name) {
 482			filteredTools = append(filteredTools, tool)
 483		}
 484	}
 485
 486	for _, tool := range tools.GetMCPTools(c.permissions, c.cfg, c.cfg.WorkingDir()) {
 487		if agent.AllowedMCP == nil {
 488			// No MCP restrictions
 489			filteredTools = append(filteredTools, tool)
 490			continue
 491		}
 492		if len(agent.AllowedMCP) == 0 {
 493			// No MCPs allowed
 494			slog.Debug("No MCPs allowed", "tool", tool.Name(), "agent", agent.Name)
 495			break
 496		}
 497
 498		for mcp, tools := range agent.AllowedMCP {
 499			if mcp != tool.MCP() {
 500				continue
 501			}
 502			if len(tools) == 0 || slices.Contains(tools, tool.MCPToolName()) {
 503				filteredTools = append(filteredTools, tool)
 504				break
 505			}
 506			slog.Debug("MCP not allowed", "tool", tool.Name(), "agent", agent.Name)
 507		}
 508	}
 509	slices.SortFunc(filteredTools, func(a, b fantasy.AgentTool) int {
 510		return strings.Compare(a.Info().Name, b.Info().Name)
 511	})
 512	return filteredTools, nil
 513}
 514
 515// TODO: when we support multiple agents we need to change this so that we pass in the agent specific model config
 516func (c *coordinator) buildAgentModels(ctx context.Context, isSubAgent bool) (Model, Model, error) {
 517	largeModelCfg, ok := c.cfg.Config().Models[config.SelectedModelTypeLarge]
 518	if !ok {
 519		return Model{}, Model{}, errLargeModelNotSelected
 520	}
 521	smallModelCfg, ok := c.cfg.Config().Models[config.SelectedModelTypeSmall]
 522	if !ok {
 523		return Model{}, Model{}, errSmallModelNotSelected
 524	}
 525
 526	largeProviderCfg, ok := c.cfg.Config().Providers.Get(largeModelCfg.Provider)
 527	if !ok {
 528		return Model{}, Model{}, errLargeModelProviderNotConfigured
 529	}
 530
 531	largeProvider, err := c.buildProvider(largeProviderCfg, largeModelCfg, isSubAgent)
 532	if err != nil {
 533		return Model{}, Model{}, err
 534	}
 535
 536	smallProviderCfg, ok := c.cfg.Config().Providers.Get(smallModelCfg.Provider)
 537	if !ok {
 538		return Model{}, Model{}, errSmallModelProviderNotConfigured
 539	}
 540
 541	smallProvider, err := c.buildProvider(smallProviderCfg, smallModelCfg, true)
 542	if err != nil {
 543		return Model{}, Model{}, err
 544	}
 545
 546	var largeCatwalkModel *catwalk.Model
 547	var smallCatwalkModel *catwalk.Model
 548
 549	for _, m := range largeProviderCfg.Models {
 550		if m.ID == largeModelCfg.Model {
 551			largeCatwalkModel = &m
 552		}
 553	}
 554	for _, m := range smallProviderCfg.Models {
 555		if m.ID == smallModelCfg.Model {
 556			smallCatwalkModel = &m
 557		}
 558	}
 559
 560	if largeCatwalkModel == nil {
 561		return Model{}, Model{}, errLargeModelNotFound
 562	}
 563
 564	if smallCatwalkModel == nil {
 565		return Model{}, Model{}, errSmallModelNotFound
 566	}
 567
 568	largeModelID := largeModelCfg.Model
 569	smallModelID := smallModelCfg.Model
 570
 571	if largeModelCfg.Provider == openrouter.Name && isExactoSupported(largeModelID) {
 572		largeModelID += ":exacto"
 573	}
 574
 575	if smallModelCfg.Provider == openrouter.Name && isExactoSupported(smallModelID) {
 576		smallModelID += ":exacto"
 577	}
 578
 579	largeModel, err := largeProvider.LanguageModel(ctx, largeModelID)
 580	if err != nil {
 581		return Model{}, Model{}, err
 582	}
 583	smallModel, err := smallProvider.LanguageModel(ctx, smallModelID)
 584	if err != nil {
 585		return Model{}, Model{}, err
 586	}
 587
 588	return Model{
 589			Model:      largeModel,
 590			CatwalkCfg: *largeCatwalkModel,
 591			ModelCfg:   largeModelCfg,
 592		}, Model{
 593			Model:      smallModel,
 594			CatwalkCfg: *smallCatwalkModel,
 595			ModelCfg:   smallModelCfg,
 596		}, nil
 597}
 598
 599func (c *coordinator) buildAnthropicProvider(baseURL, apiKey string, headers map[string]string, providerID string) (fantasy.Provider, error) {
 600	var opts []anthropic.Option
 601
 602	switch {
 603	case strings.HasPrefix(apiKey, "Bearer "):
 604		// NOTE: Prevent the SDK from picking up the API key from env.
 605		os.Setenv("ANTHROPIC_API_KEY", "")
 606		headers["Authorization"] = apiKey
 607	case providerID == string(catwalk.InferenceProviderMiniMax) || providerID == string(catwalk.InferenceProviderMiniMaxChina):
 608		// NOTE: Prevent the SDK from picking up the API key from env.
 609		os.Setenv("ANTHROPIC_API_KEY", "")
 610		headers["Authorization"] = "Bearer " + apiKey
 611	case apiKey != "":
 612		// X-Api-Key header
 613		opts = append(opts, anthropic.WithAPIKey(apiKey))
 614	}
 615
 616	if len(headers) > 0 {
 617		opts = append(opts, anthropic.WithHeaders(headers))
 618	}
 619
 620	if baseURL != "" {
 621		opts = append(opts, anthropic.WithBaseURL(baseURL))
 622	}
 623
 624	if c.cfg.Config().Options.Debug {
 625		httpClient := log.NewHTTPClient()
 626		opts = append(opts, anthropic.WithHTTPClient(httpClient))
 627	}
 628	return anthropic.New(opts...)
 629}
 630
 631func (c *coordinator) buildOpenaiProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
 632	opts := []openai.Option{
 633		openai.WithAPIKey(apiKey),
 634		openai.WithUseResponsesAPI(),
 635	}
 636	if c.cfg.Config().Options.Debug {
 637		httpClient := log.NewHTTPClient()
 638		opts = append(opts, openai.WithHTTPClient(httpClient))
 639	}
 640	if len(headers) > 0 {
 641		opts = append(opts, openai.WithHeaders(headers))
 642	}
 643	if baseURL != "" {
 644		opts = append(opts, openai.WithBaseURL(baseURL))
 645	}
 646	return openai.New(opts...)
 647}
 648
 649func (c *coordinator) buildOpenrouterProvider(_, apiKey string, headers map[string]string) (fantasy.Provider, error) {
 650	opts := []openrouter.Option{
 651		openrouter.WithAPIKey(apiKey),
 652	}
 653	if c.cfg.Config().Options.Debug {
 654		httpClient := log.NewHTTPClient()
 655		opts = append(opts, openrouter.WithHTTPClient(httpClient))
 656	}
 657	if len(headers) > 0 {
 658		opts = append(opts, openrouter.WithHeaders(headers))
 659	}
 660	return openrouter.New(opts...)
 661}
 662
 663func (c *coordinator) buildVercelProvider(_, apiKey string, headers map[string]string) (fantasy.Provider, error) {
 664	opts := []vercel.Option{
 665		vercel.WithAPIKey(apiKey),
 666	}
 667	if c.cfg.Config().Options.Debug {
 668		httpClient := log.NewHTTPClient()
 669		opts = append(opts, vercel.WithHTTPClient(httpClient))
 670	}
 671	if len(headers) > 0 {
 672		opts = append(opts, vercel.WithHeaders(headers))
 673	}
 674	return vercel.New(opts...)
 675}
 676
 677func (c *coordinator) buildOpenaiCompatProvider(baseURL, apiKey string, headers map[string]string, extraBody map[string]any, providerID string, isSubAgent bool) (fantasy.Provider, error) {
 678	opts := []openaicompat.Option{
 679		openaicompat.WithBaseURL(baseURL),
 680		openaicompat.WithAPIKey(apiKey),
 681	}
 682
 683	// Set HTTP client based on provider and debug mode.
 684	var httpClient *http.Client
 685	if providerID == string(catwalk.InferenceProviderCopilot) {
 686		opts = append(opts, openaicompat.WithUseResponsesAPI())
 687		httpClient = copilot.NewClient(isSubAgent, c.cfg.Config().Options.Debug)
 688	} else if c.cfg.Config().Options.Debug {
 689		httpClient = log.NewHTTPClient()
 690	}
 691	if httpClient != nil {
 692		opts = append(opts, openaicompat.WithHTTPClient(httpClient))
 693	}
 694
 695	if len(headers) > 0 {
 696		opts = append(opts, openaicompat.WithHeaders(headers))
 697	}
 698
 699	for extraKey, extraValue := range extraBody {
 700		opts = append(opts, openaicompat.WithSDKOptions(openaisdk.WithJSONSet(extraKey, extraValue)))
 701	}
 702
 703	return openaicompat.New(opts...)
 704}
 705
 706func (c *coordinator) buildAzureProvider(baseURL, apiKey string, headers map[string]string, options map[string]string) (fantasy.Provider, error) {
 707	opts := []azure.Option{
 708		azure.WithBaseURL(baseURL),
 709		azure.WithAPIKey(apiKey),
 710		azure.WithUseResponsesAPI(),
 711	}
 712	if c.cfg.Config().Options.Debug {
 713		httpClient := log.NewHTTPClient()
 714		opts = append(opts, azure.WithHTTPClient(httpClient))
 715	}
 716	if options == nil {
 717		options = make(map[string]string)
 718	}
 719	if apiVersion, ok := options["apiVersion"]; ok {
 720		opts = append(opts, azure.WithAPIVersion(apiVersion))
 721	}
 722	if len(headers) > 0 {
 723		opts = append(opts, azure.WithHeaders(headers))
 724	}
 725
 726	return azure.New(opts...)
 727}
 728
 729func (c *coordinator) buildBedrockProvider(headers map[string]string) (fantasy.Provider, error) {
 730	var opts []bedrock.Option
 731	if c.cfg.Config().Options.Debug {
 732		httpClient := log.NewHTTPClient()
 733		opts = append(opts, bedrock.WithHTTPClient(httpClient))
 734	}
 735	if len(headers) > 0 {
 736		opts = append(opts, bedrock.WithHeaders(headers))
 737	}
 738	bearerToken := os.Getenv("AWS_BEARER_TOKEN_BEDROCK")
 739	if bearerToken != "" {
 740		opts = append(opts, bedrock.WithAPIKey(bearerToken))
 741	}
 742	return bedrock.New(opts...)
 743}
 744
 745func (c *coordinator) buildGoogleProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
 746	opts := []google.Option{
 747		google.WithBaseURL(baseURL),
 748		google.WithGeminiAPIKey(apiKey),
 749	}
 750	if c.cfg.Config().Options.Debug {
 751		httpClient := log.NewHTTPClient()
 752		opts = append(opts, google.WithHTTPClient(httpClient))
 753	}
 754	if len(headers) > 0 {
 755		opts = append(opts, google.WithHeaders(headers))
 756	}
 757	return google.New(opts...)
 758}
 759
 760func (c *coordinator) buildGoogleVertexProvider(headers map[string]string, options map[string]string) (fantasy.Provider, error) {
 761	opts := []google.Option{}
 762	if c.cfg.Config().Options.Debug {
 763		httpClient := log.NewHTTPClient()
 764		opts = append(opts, google.WithHTTPClient(httpClient))
 765	}
 766	if len(headers) > 0 {
 767		opts = append(opts, google.WithHeaders(headers))
 768	}
 769
 770	project := options["project"]
 771	location := options["location"]
 772
 773	opts = append(opts, google.WithVertex(project, location))
 774
 775	return google.New(opts...)
 776}
 777
 778func (c *coordinator) buildHyperProvider(baseURL, apiKey string) (fantasy.Provider, error) {
 779	opts := []hyper.Option{
 780		hyper.WithBaseURL(baseURL),
 781		hyper.WithAPIKey(apiKey),
 782	}
 783	if c.cfg.Config().Options.Debug {
 784		httpClient := log.NewHTTPClient()
 785		opts = append(opts, hyper.WithHTTPClient(httpClient))
 786	}
 787	return hyper.New(opts...)
 788}
 789
 790func (c *coordinator) isAnthropicThinking(model config.SelectedModel) bool {
 791	if model.Think {
 792		return true
 793	}
 794	opts, err := anthropic.ParseOptions(model.ProviderOptions)
 795	return err == nil && opts.Thinking != nil
 796}
 797
 798func (c *coordinator) buildProvider(providerCfg config.ProviderConfig, model config.SelectedModel, isSubAgent bool) (fantasy.Provider, error) {
 799	headers := maps.Clone(providerCfg.ExtraHeaders)
 800	if headers == nil {
 801		headers = make(map[string]string)
 802	}
 803
 804	// handle special headers for anthropic
 805	if providerCfg.Type == anthropic.Name && c.isAnthropicThinking(model) {
 806		if v, ok := headers["anthropic-beta"]; ok {
 807			headers["anthropic-beta"] = v + ",interleaved-thinking-2025-05-14"
 808		} else {
 809			headers["anthropic-beta"] = "interleaved-thinking-2025-05-14"
 810		}
 811	}
 812
 813	apiKey, _ := c.cfg.Resolve(providerCfg.APIKey)
 814	baseURL, _ := c.cfg.Resolve(providerCfg.BaseURL)
 815
 816	switch providerCfg.Type {
 817	case openai.Name:
 818		return c.buildOpenaiProvider(baseURL, apiKey, headers)
 819	case anthropic.Name:
 820		return c.buildAnthropicProvider(baseURL, apiKey, headers, providerCfg.ID)
 821	case openrouter.Name:
 822		return c.buildOpenrouterProvider(baseURL, apiKey, headers)
 823	case vercel.Name:
 824		return c.buildVercelProvider(baseURL, apiKey, headers)
 825	case azure.Name:
 826		return c.buildAzureProvider(baseURL, apiKey, headers, providerCfg.ExtraParams)
 827	case bedrock.Name:
 828		return c.buildBedrockProvider(headers)
 829	case google.Name:
 830		return c.buildGoogleProvider(baseURL, apiKey, headers)
 831	case "google-vertex":
 832		return c.buildGoogleVertexProvider(headers, providerCfg.ExtraParams)
 833	case openaicompat.Name:
 834		if providerCfg.ID == string(catwalk.InferenceProviderZAI) {
 835			if providerCfg.ExtraBody == nil {
 836				providerCfg.ExtraBody = map[string]any{}
 837			}
 838			providerCfg.ExtraBody["tool_stream"] = true
 839		}
 840		return c.buildOpenaiCompatProvider(baseURL, apiKey, headers, providerCfg.ExtraBody, providerCfg.ID, isSubAgent)
 841	case hyper.Name:
 842		return c.buildHyperProvider(baseURL, apiKey)
 843	default:
 844		return nil, fmt.Errorf("provider type not supported: %q", providerCfg.Type)
 845	}
 846}
 847
 848func isExactoSupported(modelID string) bool {
 849	supportedModels := []string{
 850		"moonshotai/kimi-k2-0905",
 851		"deepseek/deepseek-v3.1-terminus",
 852		"z-ai/glm-4.6",
 853		"openai/gpt-oss-120b",
 854		"qwen/qwen3-coder",
 855	}
 856	return slices.Contains(supportedModels, modelID)
 857}
 858
// Cancel implements Coordinator by forwarding to the current agent.
func (c *coordinator) Cancel(sessionID string) {
	c.currentAgent.Cancel(sessionID)
}

// CancelAll implements Coordinator by forwarding to the current agent.
func (c *coordinator) CancelAll() {
	c.currentAgent.CancelAll()
}

// ClearQueue implements Coordinator by forwarding to the current agent.
func (c *coordinator) ClearQueue(sessionID string) {
	c.currentAgent.ClearQueue(sessionID)
}

// IsBusy implements Coordinator by forwarding to the current agent.
func (c *coordinator) IsBusy() bool {
	return c.currentAgent.IsBusy()
}

// IsSessionBusy implements Coordinator by forwarding to the current agent.
func (c *coordinator) IsSessionBusy(sessionID string) bool {
	return c.currentAgent.IsSessionBusy(sessionID)
}

// Model implements Coordinator and returns the current agent's model.
func (c *coordinator) Model() Model {
	return c.currentAgent.Model()
}
 882
 883func (c *coordinator) UpdateModels(ctx context.Context) error {
 884	// build the models again so we make sure we get the latest config
 885	large, small, err := c.buildAgentModels(ctx, false)
 886	if err != nil {
 887		return err
 888	}
 889	c.currentAgent.SetModels(large, small)
 890
 891	agentCfg, ok := c.cfg.Config().Agents[config.AgentCoder]
 892	if !ok {
 893		return errCoderAgentNotConfigured
 894	}
 895
 896	tools, err := c.buildTools(ctx, agentCfg)
 897	if err != nil {
 898		return err
 899	}
 900	c.currentAgent.SetTools(tools)
 901	return nil
 902}
 903
 904func (c *coordinator) RefreshTools(ctx context.Context) error {
 905	agentCfg, ok := c.cfg.Config().Agents[config.AgentCoder]
 906	if !ok {
 907		return errors.New("coder agent not configured")
 908	}
 909
 910	tools, err := c.buildTools(ctx, agentCfg)
 911	if err != nil {
 912		return err
 913	}
 914	c.currentAgent.SetTools(tools)
 915	slog.Debug("refreshed agent tools", "count", len(tools))
 916	return nil
 917}
 918
// QueuedPrompts implements Coordinator by forwarding to the current agent.
func (c *coordinator) QueuedPrompts(sessionID string) int {
	return c.currentAgent.QueuedPrompts(sessionID)
}

// QueuedPromptsList implements Coordinator by forwarding to the current agent.
func (c *coordinator) QueuedPromptsList(sessionID string) []string {
	return c.currentAgent.QueuedPromptsList(sessionID)
}
 926
 927func (c *coordinator) Summarize(ctx context.Context, sessionID string) error {
 928	providerCfg, ok := c.cfg.Config().Providers.Get(c.currentAgent.Model().ModelCfg.Provider)
 929	if !ok {
 930		return errModelProviderNotConfigured
 931	}
 932	return c.currentAgent.Summarize(ctx, sessionID, getProviderOptions(c.currentAgent.Model(), providerCfg))
 933}
 934
 935func (c *coordinator) isUnauthorized(err error) bool {
 936	var providerErr *fantasy.ProviderError
 937	return errors.As(err, &providerErr) && providerErr.StatusCode == http.StatusUnauthorized
 938}
 939
 940func (c *coordinator) refreshOAuth2Token(ctx context.Context, providerCfg config.ProviderConfig) error {
 941	if err := c.cfg.RefreshOAuthToken(ctx, config.ScopeGlobal, providerCfg.ID); err != nil {
 942		slog.Error("Failed to refresh OAuth token after 401 error", "provider", providerCfg.ID, "error", err)
 943		return err
 944	}
 945	if err := c.UpdateModels(ctx); err != nil {
 946		return err
 947	}
 948	return nil
 949}
 950
 951func (c *coordinator) refreshApiKeyTemplate(ctx context.Context, providerCfg config.ProviderConfig) error {
 952	newAPIKey, err := c.cfg.Resolve(providerCfg.APIKeyTemplate)
 953	if err != nil {
 954		slog.Error("Failed to re-resolve API key after 401 error", "provider", providerCfg.ID, "error", err)
 955		return err
 956	}
 957
 958	providerCfg.APIKey = newAPIKey
 959	c.cfg.Config().Providers.Set(providerCfg.ID, providerCfg)
 960
 961	if err := c.UpdateModels(ctx); err != nil {
 962		return err
 963	}
 964	return nil
 965}
 966
// subAgentParams holds the parameters for running a sub-agent.
type subAgentParams struct {
	// Agent is the agent whose Model() configures the call and whose Run
	// executes Prompt.
	Agent SessionAgent
	// SessionID is the parent session: the sub-session is created under it,
	// and the sub-session's cost is added to it after a successful run.
	SessionID string
	// AgentMessageID is combined with ToolCallID to derive the sub-session ID.
	// NOTE(review): presumably the message that issued the tool call; confirm
	// with callers.
	AgentMessageID string
	// ToolCallID is combined with AgentMessageID to derive the sub-session ID.
	ToolCallID string
	// Prompt is the prompt passed to Agent.Run.
	Prompt string
	// SessionTitle is the title given to the newly created sub-session.
	SessionTitle string
	// SessionSetup is an optional callback invoked after session creation
	// but before agent execution, for custom session configuration.
	// It receives the ID of the newly created sub-session.
	SessionSetup func(sessionID string)
}
 979
 980// runSubAgent runs a sub-agent and handles session management and cost accumulation.
 981// It creates a sub-session, runs the agent with the given prompt, and propagates
 982// the cost to the parent session.
 983func (c *coordinator) runSubAgent(ctx context.Context, params subAgentParams) (fantasy.ToolResponse, error) {
 984	// Create sub-session
 985	agentToolSessionID := c.sessions.CreateAgentToolSessionID(params.AgentMessageID, params.ToolCallID)
 986	session, err := c.sessions.CreateTaskSession(ctx, agentToolSessionID, params.SessionID, params.SessionTitle)
 987	if err != nil {
 988		return fantasy.ToolResponse{}, fmt.Errorf("create session: %w", err)
 989	}
 990
 991	// Call session setup function if provided
 992	if params.SessionSetup != nil {
 993		params.SessionSetup(session.ID)
 994	}
 995
 996	// Get model configuration
 997	model := params.Agent.Model()
 998	maxTokens := model.CatwalkCfg.DefaultMaxTokens
 999	if model.ModelCfg.MaxTokens != 0 {
1000		maxTokens = model.ModelCfg.MaxTokens
1001	}
1002
1003	providerCfg, ok := c.cfg.Config().Providers.Get(model.ModelCfg.Provider)
1004	if !ok {
1005		return fantasy.ToolResponse{}, errModelProviderNotConfigured
1006	}
1007
1008	// Run the agent
1009	result, err := params.Agent.Run(ctx, SessionAgentCall{
1010		SessionID:        session.ID,
1011		Prompt:           params.Prompt,
1012		MaxOutputTokens:  maxTokens,
1013		ProviderOptions:  getProviderOptions(model, providerCfg),
1014		Temperature:      model.ModelCfg.Temperature,
1015		TopP:             model.ModelCfg.TopP,
1016		TopK:             model.ModelCfg.TopK,
1017		FrequencyPenalty: model.ModelCfg.FrequencyPenalty,
1018		PresencePenalty:  model.ModelCfg.PresencePenalty,
1019		NonInteractive:   true,
1020	})
1021	if err != nil {
1022		return fantasy.NewTextErrorResponse("error generating response"), nil
1023	}
1024
1025	// Update parent session cost
1026	if err := c.updateParentSessionCost(ctx, session.ID, params.SessionID); err != nil {
1027		return fantasy.ToolResponse{}, err
1028	}
1029
1030	return fantasy.NewTextResponse(result.Response.Content.Text()), nil
1031}
1032
1033// updateParentSessionCost accumulates the cost from a child session to its parent session.
1034func (c *coordinator) updateParentSessionCost(ctx context.Context, childSessionID, parentSessionID string) error {
1035	childSession, err := c.sessions.Get(ctx, childSessionID)
1036	if err != nil {
1037		return fmt.Errorf("get child session: %w", err)
1038	}
1039
1040	parentSession, err := c.sessions.Get(ctx, parentSessionID)
1041	if err != nil {
1042		return fmt.Errorf("get parent session: %w", err)
1043	}
1044
1045	parentSession.Cost += childSession.Cost
1046
1047	if _, err := c.sessions.Save(ctx, parentSession); err != nil {
1048		return fmt.Errorf("save parent session: %w", err)
1049	}
1050
1051	return nil
1052}