coordinator.go

   1package agent
   2
   3import (
   4	"bytes"
   5	"cmp"
   6	"context"
   7	"encoding/json"
   8	"errors"
   9	"fmt"
  10	"io"
  11	"log/slog"
  12	"maps"
  13	"net/http"
  14	"os"
  15	"slices"
  16	"strings"
  17
  18	"charm.land/catwalk/pkg/catwalk"
  19	"charm.land/fantasy"
  20	"github.com/charmbracelet/crush/internal/agent/hyper"
  21	"github.com/charmbracelet/crush/internal/agent/notify"
  22	"github.com/charmbracelet/crush/internal/agent/prompt"
  23	"github.com/charmbracelet/crush/internal/agent/tools"
  24	"github.com/charmbracelet/crush/internal/config"
  25	"github.com/charmbracelet/crush/internal/filetracker"
  26	"github.com/charmbracelet/crush/internal/history"
  27	"github.com/charmbracelet/crush/internal/log"
  28	"github.com/charmbracelet/crush/internal/lsp"
  29	"github.com/charmbracelet/crush/internal/message"
  30	"github.com/charmbracelet/crush/internal/oauth/copilot"
  31	"github.com/charmbracelet/crush/internal/permission"
  32	"github.com/charmbracelet/crush/internal/pubsub"
  33	"github.com/charmbracelet/crush/internal/session"
  34	"golang.org/x/sync/errgroup"
  35
  36	"charm.land/fantasy/providers/anthropic"
  37	"charm.land/fantasy/providers/azure"
  38	"charm.land/fantasy/providers/bedrock"
  39	"charm.land/fantasy/providers/google"
  40	"charm.land/fantasy/providers/openai"
  41	"charm.land/fantasy/providers/openaicompat"
  42	"charm.land/fantasy/providers/openrouter"
  43	"charm.land/fantasy/providers/vercel"
  44	openaisdk "github.com/charmbracelet/openai-go/option"
  45	"github.com/qjebbs/go-jsons"
  46)
  47
  48// Coordinator errors.
  49var (
  50	errCoderAgentNotConfigured         = errors.New("coder agent not configured")
  51	errModelProviderNotConfigured      = errors.New("model provider not configured")
  52	errLargeModelNotSelected           = errors.New("large model not selected")
  53	errSmallModelNotSelected           = errors.New("small model not selected")
  54	errLargeModelProviderNotConfigured = errors.New("large model provider not configured")
  55	errSmallModelProviderNotConfigured = errors.New("small model provider not configured")
  56	errLargeModelNotFound              = errors.New("large model not found in provider config")
  57	errSmallModelNotFound              = errors.New("small model not found in provider config")
  58)
  59
  60type Coordinator interface {
  61	// INFO: (kujtim) this is not used yet we will use this when we have multiple agents
  62	// SetMainAgent(string)
  63	Run(ctx context.Context, sessionID, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error)
  64	Cancel(sessionID string)
  65	CancelAll()
  66	IsSessionBusy(sessionID string) bool
  67	IsBusy() bool
  68	QueuedPrompts(sessionID string) int
  69	QueuedPromptsList(sessionID string) []string
  70	ClearQueue(sessionID string)
  71	Summarize(context.Context, string) error
  72	Model() Model
  73	UpdateModels(ctx context.Context) error
  74}
  75
// coordinator is the Coordinator implementation built by NewCoordinator.
//
// Fields:
//   - cfg, sessions, messages, permissions, history, filetracker, lspManager,
//     notify: services handed to the agents and tools this coordinator builds.
//   - currentAgent: the agent that every Coordinator method delegates to.
//   - agents: built agents keyed by agent name; NewCoordinator registers the
//     coder agent under config.AgentCoder.
//   - readyWg: collects the background system-prompt and tool setup started
//     by buildAgent; Run waits on it and fails with its first error.
type coordinator struct {
	cfg         *config.ConfigStore
	sessions    session.Service
	messages    message.Service
	permissions permission.Service
	history     history.Service
	filetracker filetracker.Service
	lspManager  *lsp.Manager
	notify      pubsub.Publisher[notify.Notification]

	currentAgent SessionAgent
	agents       map[string]SessionAgent

	readyWg errgroup.Group
}
  91
  92func NewCoordinator(
  93	ctx context.Context,
  94	cfg *config.ConfigStore,
  95	sessions session.Service,
  96	messages message.Service,
  97	permissions permission.Service,
  98	history history.Service,
  99	filetracker filetracker.Service,
 100	lspManager *lsp.Manager,
 101	notify pubsub.Publisher[notify.Notification],
 102) (Coordinator, error) {
 103	c := &coordinator{
 104		cfg:         cfg,
 105		sessions:    sessions,
 106		messages:    messages,
 107		permissions: permissions,
 108		history:     history,
 109		filetracker: filetracker,
 110		lspManager:  lspManager,
 111		notify:      notify,
 112		agents:      make(map[string]SessionAgent),
 113	}
 114
 115	agentCfg, ok := cfg.Config().Agents[config.AgentCoder]
 116	if !ok {
 117		return nil, errCoderAgentNotConfigured
 118	}
 119
 120	// TODO: make this dynamic when we support multiple agents
 121	prompt, err := coderPrompt(prompt.WithWorkingDir(c.cfg.WorkingDir()))
 122	if err != nil {
 123		return nil, err
 124	}
 125
 126	agent, err := c.buildAgent(ctx, prompt, agentCfg, false)
 127	if err != nil {
 128		return nil, err
 129	}
 130	c.currentAgent = agent
 131	c.agents[config.AgentCoder] = agent
 132	return c, nil
 133}
 134
 135// Run implements Coordinator.
 136func (c *coordinator) Run(ctx context.Context, sessionID string, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) {
 137	if err := c.readyWg.Wait(); err != nil {
 138		return nil, err
 139	}
 140
 141	// refresh models before each run
 142	if err := c.UpdateModels(ctx); err != nil {
 143		return nil, fmt.Errorf("failed to update models: %w", err)
 144	}
 145
 146	model := c.currentAgent.Model()
 147	maxTokens := model.CatwalkCfg.DefaultMaxTokens
 148	if model.ModelCfg.MaxTokens != 0 {
 149		maxTokens = model.ModelCfg.MaxTokens
 150	}
 151
 152	if !model.CatwalkCfg.SupportsImages && attachments != nil {
 153		// filter out image attachments
 154		filteredAttachments := make([]message.Attachment, 0, len(attachments))
 155		for _, att := range attachments {
 156			if att.IsText() {
 157				filteredAttachments = append(filteredAttachments, att)
 158			}
 159		}
 160		attachments = filteredAttachments
 161	}
 162
 163	providerCfg, ok := c.cfg.Config().Providers.Get(model.ModelCfg.Provider)
 164	if !ok {
 165		return nil, errModelProviderNotConfigured
 166	}
 167
 168	mergedOptions, temp, topP, topK, freqPenalty, presPenalty := mergeCallOptions(model, providerCfg)
 169
 170	if providerCfg.OAuthToken != nil && providerCfg.OAuthToken.IsExpired() {
 171		slog.Debug("Token needs to be refreshed", "provider", providerCfg.ID)
 172		if err := c.refreshOAuth2Token(ctx, providerCfg); err != nil {
 173			return nil, err
 174		}
 175	}
 176
 177	run := func() (*fantasy.AgentResult, error) {
 178		return c.currentAgent.Run(ctx, SessionAgentCall{
 179			SessionID:        sessionID,
 180			Prompt:           prompt,
 181			Attachments:      attachments,
 182			MaxOutputTokens:  maxTokens,
 183			ProviderOptions:  mergedOptions,
 184			Temperature:      temp,
 185			TopP:             topP,
 186			TopK:             topK,
 187			FrequencyPenalty: freqPenalty,
 188			PresencePenalty:  presPenalty,
 189		})
 190	}
 191	result, originalErr := run()
 192
 193	if c.isUnauthorized(originalErr) {
 194		switch {
 195		case providerCfg.OAuthToken != nil:
 196			slog.Debug("Received 401. Refreshing token and retrying", "provider", providerCfg.ID)
 197			if err := c.refreshOAuth2Token(ctx, providerCfg); err != nil {
 198				return nil, originalErr
 199			}
 200			slog.Debug("Retrying request with refreshed OAuth token", "provider", providerCfg.ID)
 201			return run()
 202		case strings.Contains(providerCfg.APIKeyTemplate, "$"):
 203			slog.Debug("Received 401. Refreshing API Key template and retrying", "provider", providerCfg.ID)
 204			if err := c.refreshApiKeyTemplate(ctx, providerCfg); err != nil {
 205				return nil, originalErr
 206			}
 207			slog.Debug("Retrying request with refreshed API key", "provider", providerCfg.ID)
 208			return run()
 209		}
 210	}
 211
 212	return result, originalErr
 213}
 214
 215func getProviderOptions(model Model, providerCfg config.ProviderConfig) fantasy.ProviderOptions {
 216	options := fantasy.ProviderOptions{}
 217
 218	cfgOpts := []byte("{}")
 219	providerCfgOpts := []byte("{}")
 220	catwalkOpts := []byte("{}")
 221
 222	if model.ModelCfg.ProviderOptions != nil {
 223		data, err := json.Marshal(model.ModelCfg.ProviderOptions)
 224		if err == nil {
 225			cfgOpts = data
 226		}
 227	}
 228
 229	if providerCfg.ProviderOptions != nil {
 230		data, err := json.Marshal(providerCfg.ProviderOptions)
 231		if err == nil {
 232			providerCfgOpts = data
 233		}
 234	}
 235
 236	if model.CatwalkCfg.Options.ProviderOptions != nil {
 237		data, err := json.Marshal(model.CatwalkCfg.Options.ProviderOptions)
 238		if err == nil {
 239			catwalkOpts = data
 240		}
 241	}
 242
 243	readers := []io.Reader{
 244		bytes.NewReader(catwalkOpts),
 245		bytes.NewReader(providerCfgOpts),
 246		bytes.NewReader(cfgOpts),
 247	}
 248
 249	got, err := jsons.Merge(readers)
 250	if err != nil {
 251		slog.Error("Could not merge call config", "err", err)
 252		return options
 253	}
 254
 255	mergedOptions := make(map[string]any)
 256
 257	err = json.Unmarshal([]byte(got), &mergedOptions)
 258	if err != nil {
 259		slog.Error("Could not create config for call", "err", err)
 260		return options
 261	}
 262
 263	providerType := providerCfg.Type
 264	if providerType == "hyper" {
 265		if strings.Contains(model.CatwalkCfg.ID, "claude") {
 266			providerType = anthropic.Name
 267		} else if strings.Contains(model.CatwalkCfg.ID, "gpt") {
 268			providerType = openai.Name
 269		} else if strings.Contains(model.CatwalkCfg.ID, "gemini") {
 270			providerType = google.Name
 271		} else {
 272			providerType = openaicompat.Name
 273		}
 274	}
 275
 276	switch providerType {
 277	case openai.Name, azure.Name:
 278		_, hasReasoningEffort := mergedOptions["reasoning_effort"]
 279		if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
 280			mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort
 281		}
 282		if openai.IsResponsesModel(model.CatwalkCfg.ID) {
 283			if openai.IsResponsesReasoningModel(model.CatwalkCfg.ID) {
 284				mergedOptions["reasoning_summary"] = "auto"
 285				mergedOptions["include"] = []openai.IncludeType{openai.IncludeReasoningEncryptedContent}
 286			}
 287			parsed, err := openai.ParseResponsesOptions(mergedOptions)
 288			if err == nil {
 289				options[openai.Name] = parsed
 290			}
 291		} else {
 292			parsed, err := openai.ParseOptions(mergedOptions)
 293			if err == nil {
 294				options[openai.Name] = parsed
 295			}
 296		}
 297	case anthropic.Name:
 298		var (
 299			_, hasEffort = mergedOptions["effort"]
 300			_, hasThink  = mergedOptions["thinking"]
 301		)
 302		switch {
 303		case !hasEffort && model.ModelCfg.ReasoningEffort != "":
 304			mergedOptions["effort"] = model.ModelCfg.ReasoningEffort
 305		case !hasThink && model.ModelCfg.Think:
 306			mergedOptions["thinking"] = map[string]any{"budget_tokens": 2000}
 307		}
 308		parsed, err := anthropic.ParseOptions(mergedOptions)
 309		if err == nil {
 310			options[anthropic.Name] = parsed
 311		}
 312
 313	case openrouter.Name:
 314		_, hasReasoning := mergedOptions["reasoning"]
 315		if !hasReasoning && model.ModelCfg.ReasoningEffort != "" {
 316			mergedOptions["reasoning"] = map[string]any{
 317				"enabled": true,
 318				"effort":  model.ModelCfg.ReasoningEffort,
 319			}
 320		}
 321		parsed, err := openrouter.ParseOptions(mergedOptions)
 322		if err == nil {
 323			options[openrouter.Name] = parsed
 324		}
 325	case vercel.Name:
 326		_, hasReasoning := mergedOptions["reasoning"]
 327		if !hasReasoning && model.ModelCfg.ReasoningEffort != "" {
 328			mergedOptions["reasoning"] = map[string]any{
 329				"enabled": true,
 330				"effort":  model.ModelCfg.ReasoningEffort,
 331			}
 332		}
 333		parsed, err := vercel.ParseOptions(mergedOptions)
 334		if err == nil {
 335			options[vercel.Name] = parsed
 336		}
 337	case google.Name:
 338		_, hasReasoning := mergedOptions["thinking_config"]
 339		if !hasReasoning {
 340			if strings.HasPrefix(model.CatwalkCfg.ID, "gemini-2") {
 341				mergedOptions["thinking_config"] = map[string]any{
 342					"thinking_budget":  2000,
 343					"include_thoughts": true,
 344				}
 345			} else {
 346				mergedOptions["thinking_config"] = map[string]any{
 347					"thinking_level":   model.ModelCfg.ReasoningEffort,
 348					"include_thoughts": true,
 349				}
 350			}
 351		}
 352		parsed, err := google.ParseOptions(mergedOptions)
 353		if err == nil {
 354			options[google.Name] = parsed
 355		}
 356	case openaicompat.Name:
 357		_, hasReasoningEffort := mergedOptions["reasoning_effort"]
 358		if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
 359			mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort
 360		}
 361		parsed, err := openaicompat.ParseOptions(mergedOptions)
 362		if err == nil {
 363			options[openaicompat.Name] = parsed
 364		}
 365	}
 366
 367	return options
 368}
 369
 370func mergeCallOptions(model Model, cfg config.ProviderConfig) (fantasy.ProviderOptions, *float64, *float64, *int64, *float64, *float64) {
 371	modelOptions := getProviderOptions(model, cfg)
 372	temp := cmp.Or(model.ModelCfg.Temperature, model.CatwalkCfg.Options.Temperature)
 373	topP := cmp.Or(model.ModelCfg.TopP, model.CatwalkCfg.Options.TopP)
 374	topK := cmp.Or(model.ModelCfg.TopK, model.CatwalkCfg.Options.TopK)
 375	freqPenalty := cmp.Or(model.ModelCfg.FrequencyPenalty, model.CatwalkCfg.Options.FrequencyPenalty)
 376	presPenalty := cmp.Or(model.ModelCfg.PresencePenalty, model.CatwalkCfg.Options.PresencePenalty)
 377	return modelOptions, temp, topP, topK, freqPenalty, presPenalty
 378}
 379
 380func (c *coordinator) buildAgent(ctx context.Context, prompt *prompt.Prompt, agent config.Agent, isSubAgent bool) (SessionAgent, error) {
 381	large, small, err := c.buildAgentModels(ctx, isSubAgent)
 382	if err != nil {
 383		return nil, err
 384	}
 385
 386	largeProviderCfg, _ := c.cfg.Config().Providers.Get(large.ModelCfg.Provider)
 387	result := NewSessionAgent(SessionAgentOptions{
 388		LargeModel:           large,
 389		SmallModel:           small,
 390		SystemPromptPrefix:   largeProviderCfg.SystemPromptPrefix,
 391		SystemPrompt:         "",
 392		IsSubAgent:           isSubAgent,
 393		DisableAutoSummarize: c.cfg.Config().Options.DisableAutoSummarize,
 394		IsYolo:               c.permissions.SkipRequests(),
 395		Sessions:             c.sessions,
 396		Messages:             c.messages,
 397		Tools:                nil,
 398		Notify:               c.notify,
 399	})
 400
 401	c.readyWg.Go(func() error {
 402		systemPrompt, err := prompt.Build(ctx, large.Model.Provider(), large.Model.Model(), c.cfg)
 403		if err != nil {
 404			return err
 405		}
 406		result.SetSystemPrompt(systemPrompt)
 407		return nil
 408	})
 409
 410	c.readyWg.Go(func() error {
 411		tools, err := c.buildTools(ctx, agent)
 412		if err != nil {
 413			return err
 414		}
 415		result.SetTools(tools)
 416		return nil
 417	})
 418
 419	return result, nil
 420}
 421
 422func (c *coordinator) buildTools(ctx context.Context, agent config.Agent) ([]fantasy.AgentTool, error) {
 423	var allTools []fantasy.AgentTool
 424	if slices.Contains(agent.AllowedTools, AgentToolName) {
 425		agentTool, err := c.agentTool(ctx)
 426		if err != nil {
 427			return nil, err
 428		}
 429		allTools = append(allTools, agentTool)
 430	}
 431
 432	if slices.Contains(agent.AllowedTools, tools.AgenticFetchToolName) {
 433		agenticFetchTool, err := c.agenticFetchTool(ctx, nil)
 434		if err != nil {
 435			return nil, err
 436		}
 437		allTools = append(allTools, agenticFetchTool)
 438	}
 439
 440	// Get the model name for the agent
 441	modelName := ""
 442	if modelCfg, ok := c.cfg.Config().Models[agent.Model]; ok {
 443		if model := c.cfg.Config().GetModel(modelCfg.Provider, modelCfg.Model); model != nil {
 444			modelName = model.Name
 445		}
 446	}
 447
 448	allTools = append(allTools,
 449		tools.NewBashTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Config().Options.Attribution, modelName),
 450		tools.NewJobOutputTool(),
 451		tools.NewJobKillTool(),
 452		tools.NewDownloadTool(c.permissions, c.cfg.WorkingDir(), nil),
 453		tools.NewEditTool(c.lspManager, c.permissions, c.history, c.filetracker, c.cfg.WorkingDir()),
 454		tools.NewMultiEditTool(c.lspManager, c.permissions, c.history, c.filetracker, c.cfg.WorkingDir()),
 455		tools.NewFetchTool(c.permissions, c.cfg.WorkingDir(), nil),
 456		tools.NewGlobTool(c.cfg.WorkingDir()),
 457		tools.NewGrepTool(c.cfg.WorkingDir(), c.cfg.Config().Tools.Grep),
 458		tools.NewLsTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Config().Tools.Ls),
 459		tools.NewSourcegraphTool(nil),
 460		tools.NewTodosTool(c.sessions),
 461		tools.NewViewTool(c.lspManager, c.permissions, c.filetracker, c.cfg.WorkingDir(), c.cfg.Config().Options.SkillsPaths...),
 462		tools.NewWriteTool(c.lspManager, c.permissions, c.history, c.filetracker, c.cfg.WorkingDir()),
 463	)
 464
 465	// Add LSP tools if user has configured LSPs or auto_lsp is enabled (nil or true).
 466	if len(c.cfg.Config().LSP) > 0 || c.cfg.Config().Options.AutoLSP == nil || *c.cfg.Config().Options.AutoLSP {
 467		allTools = append(allTools, tools.NewDiagnosticsTool(c.lspManager), tools.NewReferencesTool(c.lspManager), tools.NewLSPRestartTool(c.lspManager))
 468	}
 469
 470	if len(c.cfg.Config().MCP) > 0 {
 471		allTools = append(
 472			allTools,
 473			tools.NewListMCPResourcesTool(c.cfg, c.permissions),
 474			tools.NewReadMCPResourceTool(c.cfg, c.permissions),
 475		)
 476	}
 477
 478	var filteredTools []fantasy.AgentTool
 479	for _, tool := range allTools {
 480		if slices.Contains(agent.AllowedTools, tool.Info().Name) {
 481			filteredTools = append(filteredTools, tool)
 482		}
 483	}
 484
 485	for _, tool := range tools.GetMCPTools(c.permissions, c.cfg, c.cfg.WorkingDir()) {
 486		if agent.AllowedMCP == nil {
 487			// No MCP restrictions
 488			filteredTools = append(filteredTools, tool)
 489			continue
 490		}
 491		if len(agent.AllowedMCP) == 0 {
 492			// No MCPs allowed
 493			slog.Debug("No MCPs allowed", "tool", tool.Name(), "agent", agent.Name)
 494			break
 495		}
 496
 497		for mcp, tools := range agent.AllowedMCP {
 498			if mcp != tool.MCP() {
 499				continue
 500			}
 501			if len(tools) == 0 || slices.Contains(tools, tool.MCPToolName()) {
 502				filteredTools = append(filteredTools, tool)
 503				break
 504			}
 505			slog.Debug("MCP not allowed", "tool", tool.Name(), "agent", agent.Name)
 506		}
 507	}
 508	slices.SortFunc(filteredTools, func(a, b fantasy.AgentTool) int {
 509		return strings.Compare(a.Info().Name, b.Info().Name)
 510	})
 511	return filteredTools, nil
 512}
 513
 514// TODO: when we support multiple agents we need to change this so that we pass in the agent specific model config
 515func (c *coordinator) buildAgentModels(ctx context.Context, isSubAgent bool) (Model, Model, error) {
 516	largeModelCfg, ok := c.cfg.Config().Models[config.SelectedModelTypeLarge]
 517	if !ok {
 518		return Model{}, Model{}, errLargeModelNotSelected
 519	}
 520	smallModelCfg, ok := c.cfg.Config().Models[config.SelectedModelTypeSmall]
 521	if !ok {
 522		return Model{}, Model{}, errSmallModelNotSelected
 523	}
 524
 525	largeProviderCfg, ok := c.cfg.Config().Providers.Get(largeModelCfg.Provider)
 526	if !ok {
 527		return Model{}, Model{}, errLargeModelProviderNotConfigured
 528	}
 529
 530	largeProvider, err := c.buildProvider(largeProviderCfg, largeModelCfg, isSubAgent)
 531	if err != nil {
 532		return Model{}, Model{}, err
 533	}
 534
 535	smallProviderCfg, ok := c.cfg.Config().Providers.Get(smallModelCfg.Provider)
 536	if !ok {
 537		return Model{}, Model{}, errSmallModelProviderNotConfigured
 538	}
 539
 540	smallProvider, err := c.buildProvider(smallProviderCfg, smallModelCfg, true)
 541	if err != nil {
 542		return Model{}, Model{}, err
 543	}
 544
 545	var largeCatwalkModel *catwalk.Model
 546	var smallCatwalkModel *catwalk.Model
 547
 548	for _, m := range largeProviderCfg.Models {
 549		if m.ID == largeModelCfg.Model {
 550			largeCatwalkModel = &m
 551		}
 552	}
 553	for _, m := range smallProviderCfg.Models {
 554		if m.ID == smallModelCfg.Model {
 555			smallCatwalkModel = &m
 556		}
 557	}
 558
 559	if largeCatwalkModel == nil {
 560		return Model{}, Model{}, errLargeModelNotFound
 561	}
 562
 563	if smallCatwalkModel == nil {
 564		return Model{}, Model{}, errSmallModelNotFound
 565	}
 566
 567	largeModelID := largeModelCfg.Model
 568	smallModelID := smallModelCfg.Model
 569
 570	if largeModelCfg.Provider == openrouter.Name && isExactoSupported(largeModelID) {
 571		largeModelID += ":exacto"
 572	}
 573
 574	if smallModelCfg.Provider == openrouter.Name && isExactoSupported(smallModelID) {
 575		smallModelID += ":exacto"
 576	}
 577
 578	largeModel, err := largeProvider.LanguageModel(ctx, largeModelID)
 579	if err != nil {
 580		return Model{}, Model{}, err
 581	}
 582	smallModel, err := smallProvider.LanguageModel(ctx, smallModelID)
 583	if err != nil {
 584		return Model{}, Model{}, err
 585	}
 586
 587	return Model{
 588			Model:      largeModel,
 589			CatwalkCfg: *largeCatwalkModel,
 590			ModelCfg:   largeModelCfg,
 591		}, Model{
 592			Model:      smallModel,
 593			CatwalkCfg: *smallCatwalkModel,
 594			ModelCfg:   smallModelCfg,
 595		}, nil
 596}
 597
 598func (c *coordinator) buildAnthropicProvider(baseURL, apiKey string, headers map[string]string, providerID string) (fantasy.Provider, error) {
 599	var opts []anthropic.Option
 600
 601	switch {
 602	case strings.HasPrefix(apiKey, "Bearer "):
 603		// NOTE: Prevent the SDK from picking up the API key from env.
 604		os.Setenv("ANTHROPIC_API_KEY", "")
 605		headers["Authorization"] = apiKey
 606	case providerID == string(catwalk.InferenceProviderMiniMax) || providerID == string(catwalk.InferenceProviderMiniMaxChina):
 607		// NOTE: Prevent the SDK from picking up the API key from env.
 608		os.Setenv("ANTHROPIC_API_KEY", "")
 609		headers["Authorization"] = "Bearer " + apiKey
 610	case apiKey != "":
 611		// X-Api-Key header
 612		opts = append(opts, anthropic.WithAPIKey(apiKey))
 613	}
 614
 615	if len(headers) > 0 {
 616		opts = append(opts, anthropic.WithHeaders(headers))
 617	}
 618
 619	if baseURL != "" {
 620		opts = append(opts, anthropic.WithBaseURL(baseURL))
 621	}
 622
 623	if c.cfg.Config().Options.Debug {
 624		httpClient := log.NewHTTPClient()
 625		opts = append(opts, anthropic.WithHTTPClient(httpClient))
 626	}
 627	return anthropic.New(opts...)
 628}
 629
 630func (c *coordinator) buildOpenaiProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
 631	opts := []openai.Option{
 632		openai.WithAPIKey(apiKey),
 633		openai.WithUseResponsesAPI(),
 634	}
 635	if c.cfg.Config().Options.Debug {
 636		httpClient := log.NewHTTPClient()
 637		opts = append(opts, openai.WithHTTPClient(httpClient))
 638	}
 639	if len(headers) > 0 {
 640		opts = append(opts, openai.WithHeaders(headers))
 641	}
 642	if baseURL != "" {
 643		opts = append(opts, openai.WithBaseURL(baseURL))
 644	}
 645	return openai.New(opts...)
 646}
 647
 648func (c *coordinator) buildOpenrouterProvider(_, apiKey string, headers map[string]string) (fantasy.Provider, error) {
 649	opts := []openrouter.Option{
 650		openrouter.WithAPIKey(apiKey),
 651	}
 652	if c.cfg.Config().Options.Debug {
 653		httpClient := log.NewHTTPClient()
 654		opts = append(opts, openrouter.WithHTTPClient(httpClient))
 655	}
 656	if len(headers) > 0 {
 657		opts = append(opts, openrouter.WithHeaders(headers))
 658	}
 659	return openrouter.New(opts...)
 660}
 661
 662func (c *coordinator) buildVercelProvider(_, apiKey string, headers map[string]string) (fantasy.Provider, error) {
 663	opts := []vercel.Option{
 664		vercel.WithAPIKey(apiKey),
 665	}
 666	if c.cfg.Config().Options.Debug {
 667		httpClient := log.NewHTTPClient()
 668		opts = append(opts, vercel.WithHTTPClient(httpClient))
 669	}
 670	if len(headers) > 0 {
 671		opts = append(opts, vercel.WithHeaders(headers))
 672	}
 673	return vercel.New(opts...)
 674}
 675
 676func (c *coordinator) buildOpenaiCompatProvider(baseURL, apiKey string, headers map[string]string, extraBody map[string]any, providerID string, isSubAgent bool) (fantasy.Provider, error) {
 677	opts := []openaicompat.Option{
 678		openaicompat.WithBaseURL(baseURL),
 679		openaicompat.WithAPIKey(apiKey),
 680	}
 681
 682	// Set HTTP client based on provider and debug mode.
 683	var httpClient *http.Client
 684	if providerID == string(catwalk.InferenceProviderCopilot) {
 685		opts = append(opts, openaicompat.WithUseResponsesAPI())
 686		httpClient = copilot.NewClient(isSubAgent, c.cfg.Config().Options.Debug)
 687	} else if c.cfg.Config().Options.Debug {
 688		httpClient = log.NewHTTPClient()
 689	}
 690	if httpClient != nil {
 691		opts = append(opts, openaicompat.WithHTTPClient(httpClient))
 692	}
 693
 694	if len(headers) > 0 {
 695		opts = append(opts, openaicompat.WithHeaders(headers))
 696	}
 697
 698	for extraKey, extraValue := range extraBody {
 699		opts = append(opts, openaicompat.WithSDKOptions(openaisdk.WithJSONSet(extraKey, extraValue)))
 700	}
 701
 702	return openaicompat.New(opts...)
 703}
 704
 705func (c *coordinator) buildAzureProvider(baseURL, apiKey string, headers map[string]string, options map[string]string) (fantasy.Provider, error) {
 706	opts := []azure.Option{
 707		azure.WithBaseURL(baseURL),
 708		azure.WithAPIKey(apiKey),
 709		azure.WithUseResponsesAPI(),
 710	}
 711	if c.cfg.Config().Options.Debug {
 712		httpClient := log.NewHTTPClient()
 713		opts = append(opts, azure.WithHTTPClient(httpClient))
 714	}
 715	if options == nil {
 716		options = make(map[string]string)
 717	}
 718	if apiVersion, ok := options["apiVersion"]; ok {
 719		opts = append(opts, azure.WithAPIVersion(apiVersion))
 720	}
 721	if len(headers) > 0 {
 722		opts = append(opts, azure.WithHeaders(headers))
 723	}
 724
 725	return azure.New(opts...)
 726}
 727
 728func (c *coordinator) buildBedrockProvider(apiKey string, headers map[string]string) (fantasy.Provider, error) {
 729	var opts []bedrock.Option
 730	if c.cfg.Config().Options.Debug {
 731		httpClient := log.NewHTTPClient()
 732		opts = append(opts, bedrock.WithHTTPClient(httpClient))
 733	}
 734	if len(headers) > 0 {
 735		opts = append(opts, bedrock.WithHeaders(headers))
 736	}
 737	switch {
 738	case apiKey != "":
 739		opts = append(opts, bedrock.WithAPIKey(apiKey))
 740	case os.Getenv("AWS_BEARER_TOKEN_BEDROCK") != "":
 741		opts = append(opts, bedrock.WithAPIKey(os.Getenv("AWS_BEARER_TOKEN_BEDROCK")))
 742	default:
 743		// Skip, let the SDK do authentication.
 744	}
 745	return bedrock.New(opts...)
 746}
 747
 748func (c *coordinator) buildGoogleProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
 749	opts := []google.Option{
 750		google.WithBaseURL(baseURL),
 751		google.WithGeminiAPIKey(apiKey),
 752	}
 753	if c.cfg.Config().Options.Debug {
 754		httpClient := log.NewHTTPClient()
 755		opts = append(opts, google.WithHTTPClient(httpClient))
 756	}
 757	if len(headers) > 0 {
 758		opts = append(opts, google.WithHeaders(headers))
 759	}
 760	return google.New(opts...)
 761}
 762
 763func (c *coordinator) buildGoogleVertexProvider(headers map[string]string, options map[string]string) (fantasy.Provider, error) {
 764	opts := []google.Option{}
 765	if c.cfg.Config().Options.Debug {
 766		httpClient := log.NewHTTPClient()
 767		opts = append(opts, google.WithHTTPClient(httpClient))
 768	}
 769	if len(headers) > 0 {
 770		opts = append(opts, google.WithHeaders(headers))
 771	}
 772
 773	project := options["project"]
 774	location := options["location"]
 775
 776	opts = append(opts, google.WithVertex(project, location))
 777
 778	return google.New(opts...)
 779}
 780
 781func (c *coordinator) buildHyperProvider(apiKey string) (fantasy.Provider, error) {
 782	opts := []hyper.Option{
 783		hyper.WithAPIKey(apiKey),
 784	}
 785	if c.cfg.Config().Options.Debug {
 786		httpClient := log.NewHTTPClient()
 787		opts = append(opts, hyper.WithHTTPClient(httpClient))
 788	}
 789	return hyper.New(opts...)
 790}
 791
 792func (c *coordinator) isAnthropicThinking(model config.SelectedModel) bool {
 793	if model.Think {
 794		return true
 795	}
 796	opts, err := anthropic.ParseOptions(model.ProviderOptions)
 797	return err == nil && opts.Thinking != nil
 798}
 799
 800func (c *coordinator) buildProvider(providerCfg config.ProviderConfig, model config.SelectedModel, isSubAgent bool) (fantasy.Provider, error) {
 801	headers := maps.Clone(providerCfg.ExtraHeaders)
 802	if headers == nil {
 803		headers = make(map[string]string)
 804	}
 805
 806	// handle special headers for anthropic
 807	if providerCfg.Type == anthropic.Name && c.isAnthropicThinking(model) {
 808		if v, ok := headers["anthropic-beta"]; ok {
 809			headers["anthropic-beta"] = v + ",interleaved-thinking-2025-05-14"
 810		} else {
 811			headers["anthropic-beta"] = "interleaved-thinking-2025-05-14"
 812		}
 813	}
 814
 815	apiKey, _ := c.cfg.Resolve(providerCfg.APIKey)
 816	baseURL, _ := c.cfg.Resolve(providerCfg.BaseURL)
 817
 818	switch providerCfg.Type {
 819	case openai.Name:
 820		return c.buildOpenaiProvider(baseURL, apiKey, headers)
 821	case anthropic.Name:
 822		return c.buildAnthropicProvider(baseURL, apiKey, headers, providerCfg.ID)
 823	case openrouter.Name:
 824		return c.buildOpenrouterProvider(baseURL, apiKey, headers)
 825	case vercel.Name:
 826		return c.buildVercelProvider(baseURL, apiKey, headers)
 827	case azure.Name:
 828		return c.buildAzureProvider(baseURL, apiKey, headers, providerCfg.ExtraParams)
 829	case bedrock.Name:
 830		return c.buildBedrockProvider(apiKey, headers)
 831	case google.Name:
 832		return c.buildGoogleProvider(baseURL, apiKey, headers)
 833	case "google-vertex":
 834		return c.buildGoogleVertexProvider(headers, providerCfg.ExtraParams)
 835	case openaicompat.Name:
 836		if providerCfg.ID == string(catwalk.InferenceProviderZAI) {
 837			if providerCfg.ExtraBody == nil {
 838				providerCfg.ExtraBody = map[string]any{}
 839			}
 840			providerCfg.ExtraBody["tool_stream"] = true
 841		}
 842		return c.buildOpenaiCompatProvider(baseURL, apiKey, headers, providerCfg.ExtraBody, providerCfg.ID, isSubAgent)
 843	case hyper.Name:
 844		return c.buildHyperProvider(apiKey)
 845	default:
 846		return nil, fmt.Errorf("provider type not supported: %q", providerCfg.Type)
 847	}
 848}
 849
 850func isExactoSupported(modelID string) bool {
 851	supportedModels := []string{
 852		"moonshotai/kimi-k2-0905",
 853		"deepseek/deepseek-v3.1-terminus",
 854		"z-ai/glm-4.6",
 855		"openai/gpt-oss-120b",
 856		"qwen/qwen3-coder",
 857	}
 858	return slices.Contains(supportedModels, modelID)
 859}
 860
 861func (c *coordinator) Cancel(sessionID string) {
 862	c.currentAgent.Cancel(sessionID)
 863}
 864
 865func (c *coordinator) CancelAll() {
 866	c.currentAgent.CancelAll()
 867}
 868
 869func (c *coordinator) ClearQueue(sessionID string) {
 870	c.currentAgent.ClearQueue(sessionID)
 871}
 872
 873func (c *coordinator) IsBusy() bool {
 874	return c.currentAgent.IsBusy()
 875}
 876
 877func (c *coordinator) IsSessionBusy(sessionID string) bool {
 878	return c.currentAgent.IsSessionBusy(sessionID)
 879}
 880
 881func (c *coordinator) Model() Model {
 882	return c.currentAgent.Model()
 883}
 884
 885func (c *coordinator) UpdateModels(ctx context.Context) error {
 886	// build the models again so we make sure we get the latest config
 887	large, small, err := c.buildAgentModels(ctx, false)
 888	if err != nil {
 889		return err
 890	}
 891	c.currentAgent.SetModels(large, small)
 892
 893	agentCfg, ok := c.cfg.Config().Agents[config.AgentCoder]
 894	if !ok {
 895		return errCoderAgentNotConfigured
 896	}
 897
 898	tools, err := c.buildTools(ctx, agentCfg)
 899	if err != nil {
 900		return err
 901	}
 902	c.currentAgent.SetTools(tools)
 903	return nil
 904}
 905
 906func (c *coordinator) QueuedPrompts(sessionID string) int {
 907	return c.currentAgent.QueuedPrompts(sessionID)
 908}
 909
 910func (c *coordinator) QueuedPromptsList(sessionID string) []string {
 911	return c.currentAgent.QueuedPromptsList(sessionID)
 912}
 913
 914func (c *coordinator) Summarize(ctx context.Context, sessionID string) error {
 915	providerCfg, ok := c.cfg.Config().Providers.Get(c.currentAgent.Model().ModelCfg.Provider)
 916	if !ok {
 917		return errModelProviderNotConfigured
 918	}
 919	return c.currentAgent.Summarize(ctx, sessionID, getProviderOptions(c.currentAgent.Model(), providerCfg))
 920}
 921
 922func (c *coordinator) isUnauthorized(err error) bool {
 923	var providerErr *fantasy.ProviderError
 924	return errors.As(err, &providerErr) && providerErr.StatusCode == http.StatusUnauthorized
 925}
 926
 927func (c *coordinator) refreshOAuth2Token(ctx context.Context, providerCfg config.ProviderConfig) error {
 928	if err := c.cfg.RefreshOAuthToken(ctx, config.ScopeGlobal, providerCfg.ID); err != nil {
 929		slog.Error("Failed to refresh OAuth token after 401 error", "provider", providerCfg.ID, "error", err)
 930		return err
 931	}
 932	if err := c.UpdateModels(ctx); err != nil {
 933		return err
 934	}
 935	return nil
 936}
 937
 938func (c *coordinator) refreshApiKeyTemplate(ctx context.Context, providerCfg config.ProviderConfig) error {
 939	newAPIKey, err := c.cfg.Resolve(providerCfg.APIKeyTemplate)
 940	if err != nil {
 941		slog.Error("Failed to re-resolve API key after 401 error", "provider", providerCfg.ID, "error", err)
 942		return err
 943	}
 944
 945	providerCfg.APIKey = newAPIKey
 946	c.cfg.Config().Providers.Set(providerCfg.ID, providerCfg)
 947
 948	if err := c.UpdateModels(ctx); err != nil {
 949		return err
 950	}
 951	return nil
 952}
 953
// subAgentParams holds the parameters for running a sub-agent.
// It is consumed by runSubAgent.
type subAgentParams struct {
	// Agent is the session agent that runs the sub-task; its Model() also
	// supplies the token limit and sampling settings for the run.
	Agent SessionAgent
	// SessionID is the parent session: the sub-session is created under it
	// and the sub-session's cost is added to it after the run.
	SessionID string
	// AgentMessageID is combined with ToolCallID to derive the sub-session ID
	// (presumably the ID of the parent message that issued the tool call —
	// confirm against callers).
	AgentMessageID string
	// ToolCallID is combined with AgentMessageID to derive the sub-session ID.
	ToolCallID string
	// Prompt is the prompt the sub-agent is run with.
	Prompt string
	// SessionTitle is passed to CreateTaskSession when creating the sub-session.
	SessionTitle string
	// SessionSetup is an optional callback invoked after session creation
	// but before agent execution, for custom session configuration.
	SessionSetup func(sessionID string)
}
 966
 967// runSubAgent runs a sub-agent and handles session management and cost accumulation.
 968// It creates a sub-session, runs the agent with the given prompt, and propagates
 969// the cost to the parent session.
 970func (c *coordinator) runSubAgent(ctx context.Context, params subAgentParams) (fantasy.ToolResponse, error) {
 971	// Create sub-session
 972	agentToolSessionID := c.sessions.CreateAgentToolSessionID(params.AgentMessageID, params.ToolCallID)
 973	session, err := c.sessions.CreateTaskSession(ctx, agentToolSessionID, params.SessionID, params.SessionTitle)
 974	if err != nil {
 975		return fantasy.ToolResponse{}, fmt.Errorf("create session: %w", err)
 976	}
 977
 978	// Call session setup function if provided
 979	if params.SessionSetup != nil {
 980		params.SessionSetup(session.ID)
 981	}
 982
 983	// Get model configuration
 984	model := params.Agent.Model()
 985	maxTokens := model.CatwalkCfg.DefaultMaxTokens
 986	if model.ModelCfg.MaxTokens != 0 {
 987		maxTokens = model.ModelCfg.MaxTokens
 988	}
 989
 990	providerCfg, ok := c.cfg.Config().Providers.Get(model.ModelCfg.Provider)
 991	if !ok {
 992		return fantasy.ToolResponse{}, errModelProviderNotConfigured
 993	}
 994
 995	// Run the agent
 996	result, err := params.Agent.Run(ctx, SessionAgentCall{
 997		SessionID:        session.ID,
 998		Prompt:           params.Prompt,
 999		MaxOutputTokens:  maxTokens,
1000		ProviderOptions:  getProviderOptions(model, providerCfg),
1001		Temperature:      model.ModelCfg.Temperature,
1002		TopP:             model.ModelCfg.TopP,
1003		TopK:             model.ModelCfg.TopK,
1004		FrequencyPenalty: model.ModelCfg.FrequencyPenalty,
1005		PresencePenalty:  model.ModelCfg.PresencePenalty,
1006		NonInteractive:   true,
1007	})
1008	if err != nil {
1009		return fantasy.NewTextErrorResponse("error generating response"), nil
1010	}
1011
1012	// Update parent session cost
1013	if err := c.updateParentSessionCost(ctx, session.ID, params.SessionID); err != nil {
1014		return fantasy.ToolResponse{}, err
1015	}
1016
1017	return fantasy.NewTextResponse(result.Response.Content.Text()), nil
1018}
1019
1020// updateParentSessionCost accumulates the cost from a child session to its parent session.
1021func (c *coordinator) updateParentSessionCost(ctx context.Context, childSessionID, parentSessionID string) error {
1022	childSession, err := c.sessions.Get(ctx, childSessionID)
1023	if err != nil {
1024		return fmt.Errorf("get child session: %w", err)
1025	}
1026
1027	parentSession, err := c.sessions.Get(ctx, parentSessionID)
1028	if err != nil {
1029		return fmt.Errorf("get parent session: %w", err)
1030	}
1031
1032	parentSession.Cost += childSession.Cost
1033
1034	if _, err := c.sessions.Save(ctx, parentSession); err != nil {
1035		return fmt.Errorf("save parent session: %w", err)
1036	}
1037
1038	return nil
1039}