coordinator.go

   1package agent
   2
   3import (
   4	"bytes"
   5	"cmp"
   6	"context"
   7	"encoding/json"
   8	"errors"
   9	"fmt"
  10	"io"
  11	"log/slog"
  12	"maps"
  13	"net/http"
  14	"os"
  15	"slices"
  16	"strings"
  17
  18	"charm.land/catwalk/pkg/catwalk"
  19	"charm.land/fantasy"
  20	"github.com/charmbracelet/crush/internal/agent/hyper"
  21	"github.com/charmbracelet/crush/internal/agent/notify"
  22	"github.com/charmbracelet/crush/internal/agent/prompt"
  23	"github.com/charmbracelet/crush/internal/agent/tools"
  24	"github.com/charmbracelet/crush/internal/config"
  25	"github.com/charmbracelet/crush/internal/filetracker"
  26	"github.com/charmbracelet/crush/internal/history"
  27	"github.com/charmbracelet/crush/internal/log"
  28	"github.com/charmbracelet/crush/internal/lsp"
  29	"github.com/charmbracelet/crush/internal/message"
  30	"github.com/charmbracelet/crush/internal/oauth/copilot"
  31	"github.com/charmbracelet/crush/internal/permission"
  32	"github.com/charmbracelet/crush/internal/pubsub"
  33	"github.com/charmbracelet/crush/internal/session"
  34	"golang.org/x/sync/errgroup"
  35
  36	"charm.land/fantasy/providers/anthropic"
  37	"charm.land/fantasy/providers/azure"
  38	"charm.land/fantasy/providers/bedrock"
  39	"charm.land/fantasy/providers/google"
  40	"charm.land/fantasy/providers/openai"
  41	"charm.land/fantasy/providers/openaicompat"
  42	"charm.land/fantasy/providers/openrouter"
  43	"charm.land/fantasy/providers/vercel"
  44	openaisdk "github.com/charmbracelet/openai-go/option"
  45	"github.com/qjebbs/go-jsons"
  46)
  47
// Coordinator errors.
//
// Sentinel errors returned while constructing or refreshing the coordinator:
//   - errCoderAgentNotConfigured: the agents config has no config.AgentCoder entry.
//   - errModelProviderNotConfigured: the current model's provider is missing
//     from the providers config.
//   - errLargeModelNotSelected / errSmallModelNotSelected: no selected model
//     of that size exists in the models config.
//   - errLargeModelProviderNotConfigured / errSmallModelProviderNotConfigured:
//     the selected model names a provider that is not configured.
//   - errLargeModelNotFound / errSmallModelNotFound: the provider's model list
//     does not contain the selected model ID.
var (
	errCoderAgentNotConfigured         = errors.New("coder agent not configured")
	errModelProviderNotConfigured      = errors.New("model provider not configured")
	errLargeModelNotSelected           = errors.New("large model not selected")
	errSmallModelNotSelected           = errors.New("small model not selected")
	errLargeModelProviderNotConfigured = errors.New("large model provider not configured")
	errSmallModelProviderNotConfigured = errors.New("small model provider not configured")
	errLargeModelNotFound              = errors.New("large model not found in provider config")
	errSmallModelNotFound              = errors.New("small model not found in provider config")
)
  59
// Coordinator is the entry point used to drive an agent: running prompts,
// inspecting and clearing queues, cancelling work, summarizing sessions and
// refreshing models.
//
// The implementation in this file (coordinator) forwards Cancel, CancelAll,
// IsSessionBusy, IsBusy, QueuedPrompts, QueuedPromptsList, ClearQueue and
// Model directly to its current SessionAgent.
type Coordinator interface {
	// INFO: (kujtim) this is not used yet we will use this when we have multiple agents
	// SetMainAgent(string)

	// Run sends prompt, with optional attachments, to the agent for the
	// session identified by sessionID.
	Run(ctx context.Context, sessionID, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error)
	Cancel(sessionID string)
	CancelAll()
	IsSessionBusy(sessionID string) bool
	IsBusy() bool
	QueuedPrompts(sessionID string) int
	QueuedPromptsList(sessionID string) []string
	ClearQueue(sessionID string)
	// Summarize asks the agent to summarize the session with the given ID.
	Summarize(context.Context, string) error
	// Model reports the model of the current agent.
	Model() Model
	// UpdateModels rebuilds the models and tools from the current configuration.
	UpdateModels(ctx context.Context) error
}
  75
// coordinator is the Coordinator implementation returned by NewCoordinator.
type coordinator struct {
	// Configuration and services supplied to NewCoordinator; they are passed
	// on to the agent and to the tools built in buildTools.
	cfg         *config.ConfigStore
	sessions    session.Service
	messages    message.Service
	permissions permission.Service
	history     history.Service
	filetracker filetracker.Service
	lspManager  *lsp.Manager
	notify      pubsub.Publisher[notify.Notification]

	// currentAgent receives every delegated call; agents indexes the built
	// agents by name (NewCoordinator stores the coder agent under
	// config.AgentCoder).
	currentAgent SessionAgent
	agents       map[string]SessionAgent

	// readyWg tracks the asynchronous system-prompt and tool construction
	// started by buildAgent; Run waits on it before doing anything else.
	readyWg errgroup.Group
}
  91
  92func NewCoordinator(
  93	ctx context.Context,
  94	cfg *config.ConfigStore,
  95	sessions session.Service,
  96	messages message.Service,
  97	permissions permission.Service,
  98	history history.Service,
  99	filetracker filetracker.Service,
 100	lspManager *lsp.Manager,
 101	notify pubsub.Publisher[notify.Notification],
 102) (Coordinator, error) {
 103	c := &coordinator{
 104		cfg:         cfg,
 105		sessions:    sessions,
 106		messages:    messages,
 107		permissions: permissions,
 108		history:     history,
 109		filetracker: filetracker,
 110		lspManager:  lspManager,
 111		notify:      notify,
 112		agents:      make(map[string]SessionAgent),
 113	}
 114
 115	agentCfg, ok := cfg.Config().Agents[config.AgentCoder]
 116	if !ok {
 117		return nil, errCoderAgentNotConfigured
 118	}
 119
 120	// TODO: make this dynamic when we support multiple agents
 121	prompt, err := coderPrompt(prompt.WithWorkingDir(c.cfg.WorkingDir()))
 122	if err != nil {
 123		return nil, err
 124	}
 125
 126	agent, err := c.buildAgent(ctx, prompt, agentCfg, false)
 127	if err != nil {
 128		return nil, err
 129	}
 130	c.currentAgent = agent
 131	c.agents[config.AgentCoder] = agent
 132	return c, nil
 133}
 134
 135// Run implements Coordinator.
 136func (c *coordinator) Run(ctx context.Context, sessionID string, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) {
 137	if err := c.readyWg.Wait(); err != nil {
 138		return nil, err
 139	}
 140
 141	// refresh models before each run
 142	if err := c.UpdateModels(ctx); err != nil {
 143		return nil, fmt.Errorf("failed to update models: %w", err)
 144	}
 145
 146	model := c.currentAgent.Model()
 147	maxTokens := model.CatwalkCfg.DefaultMaxTokens
 148	if model.ModelCfg.MaxTokens != 0 {
 149		maxTokens = model.ModelCfg.MaxTokens
 150	}
 151
 152	if !model.CatwalkCfg.SupportsImages && attachments != nil {
 153		// filter out image attachments
 154		filteredAttachments := make([]message.Attachment, 0, len(attachments))
 155		for _, att := range attachments {
 156			if att.IsText() {
 157				filteredAttachments = append(filteredAttachments, att)
 158			}
 159		}
 160		attachments = filteredAttachments
 161	}
 162
 163	providerCfg, ok := c.cfg.Config().Providers.Get(model.ModelCfg.Provider)
 164	if !ok {
 165		return nil, errModelProviderNotConfigured
 166	}
 167
 168	mergedOptions, temp, topP, topK, freqPenalty, presPenalty := mergeCallOptions(model, providerCfg)
 169
 170	if providerCfg.OAuthToken != nil && providerCfg.OAuthToken.IsExpired() {
 171		slog.Debug("Token needs to be refreshed", "provider", providerCfg.ID)
 172		if err := c.refreshOAuth2Token(ctx, providerCfg); err != nil {
 173			return nil, err
 174		}
 175	}
 176
 177	run := func() (*fantasy.AgentResult, error) {
 178		return c.currentAgent.Run(ctx, SessionAgentCall{
 179			SessionID:        sessionID,
 180			Prompt:           prompt,
 181			Attachments:      attachments,
 182			MaxOutputTokens:  maxTokens,
 183			ProviderOptions:  mergedOptions,
 184			Temperature:      temp,
 185			TopP:             topP,
 186			TopK:             topK,
 187			FrequencyPenalty: freqPenalty,
 188			PresencePenalty:  presPenalty,
 189		})
 190	}
 191	result, originalErr := run()
 192
 193	if c.isUnauthorized(originalErr) {
 194		switch {
 195		case providerCfg.OAuthToken != nil:
 196			slog.Debug("Received 401. Refreshing token and retrying", "provider", providerCfg.ID)
 197			if err := c.refreshOAuth2Token(ctx, providerCfg); err != nil {
 198				return nil, originalErr
 199			}
 200			slog.Debug("Retrying request with refreshed OAuth token", "provider", providerCfg.ID)
 201			return run()
 202		case strings.Contains(providerCfg.APIKeyTemplate, "$"):
 203			slog.Debug("Received 401. Refreshing API Key template and retrying", "provider", providerCfg.ID)
 204			if err := c.refreshApiKeyTemplate(ctx, providerCfg); err != nil {
 205				return nil, originalErr
 206			}
 207			slog.Debug("Retrying request with refreshed API key", "provider", providerCfg.ID)
 208			return run()
 209		}
 210	}
 211
 212	return result, originalErr
 213}
 214
 215func getProviderOptions(model Model, providerCfg config.ProviderConfig) fantasy.ProviderOptions {
 216	options := fantasy.ProviderOptions{}
 217
 218	cfgOpts := []byte("{}")
 219	providerCfgOpts := []byte("{}")
 220	catwalkOpts := []byte("{}")
 221
 222	if model.ModelCfg.ProviderOptions != nil {
 223		data, err := json.Marshal(model.ModelCfg.ProviderOptions)
 224		if err == nil {
 225			cfgOpts = data
 226		}
 227	}
 228
 229	if providerCfg.ProviderOptions != nil {
 230		data, err := json.Marshal(providerCfg.ProviderOptions)
 231		if err == nil {
 232			providerCfgOpts = data
 233		}
 234	}
 235
 236	if model.CatwalkCfg.Options.ProviderOptions != nil {
 237		data, err := json.Marshal(model.CatwalkCfg.Options.ProviderOptions)
 238		if err == nil {
 239			catwalkOpts = data
 240		}
 241	}
 242
 243	readers := []io.Reader{
 244		bytes.NewReader(catwalkOpts),
 245		bytes.NewReader(providerCfgOpts),
 246		bytes.NewReader(cfgOpts),
 247	}
 248
 249	got, err := jsons.Merge(readers)
 250	if err != nil {
 251		slog.Error("Could not merge call config", "err", err)
 252		return options
 253	}
 254
 255	mergedOptions := make(map[string]any)
 256
 257	err = json.Unmarshal([]byte(got), &mergedOptions)
 258	if err != nil {
 259		slog.Error("Could not create config for call", "err", err)
 260		return options
 261	}
 262
 263	providerType := providerCfg.Type
 264	if providerType == "hyper" {
 265		if strings.Contains(model.CatwalkCfg.ID, "claude") {
 266			providerType = anthropic.Name
 267		} else if strings.Contains(model.CatwalkCfg.ID, "gpt") {
 268			providerType = openai.Name
 269		} else if strings.Contains(model.CatwalkCfg.ID, "gemini") {
 270			providerType = google.Name
 271		} else {
 272			providerType = openaicompat.Name
 273		}
 274	}
 275
 276	switch providerType {
 277	case openai.Name, azure.Name:
 278		_, hasReasoningEffort := mergedOptions["reasoning_effort"]
 279		if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
 280			mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort
 281		}
 282		if openai.IsResponsesModel(model.CatwalkCfg.ID) {
 283			if openai.IsResponsesReasoningModel(model.CatwalkCfg.ID) {
 284				mergedOptions["reasoning_summary"] = "auto"
 285				mergedOptions["include"] = []openai.IncludeType{openai.IncludeReasoningEncryptedContent}
 286			}
 287			parsed, err := openai.ParseResponsesOptions(mergedOptions)
 288			if err == nil {
 289				options[openai.Name] = parsed
 290			}
 291		} else {
 292			parsed, err := openai.ParseOptions(mergedOptions)
 293			if err == nil {
 294				options[openai.Name] = parsed
 295			}
 296		}
 297	case anthropic.Name:
 298		var (
 299			_, hasEffort = mergedOptions["effort"]
 300			_, hasThink  = mergedOptions["thinking"]
 301		)
 302		switch {
 303		case !hasEffort && model.ModelCfg.ReasoningEffort != "":
 304			mergedOptions["effort"] = model.ModelCfg.ReasoningEffort
 305		case !hasThink && model.ModelCfg.Think:
 306			mergedOptions["thinking"] = map[string]any{"budget_tokens": 2000}
 307		}
 308		parsed, err := anthropic.ParseOptions(mergedOptions)
 309		if err == nil {
 310			options[anthropic.Name] = parsed
 311		}
 312
 313	case openrouter.Name:
 314		_, hasReasoning := mergedOptions["reasoning"]
 315		if !hasReasoning && model.ModelCfg.ReasoningEffort != "" {
 316			mergedOptions["reasoning"] = map[string]any{
 317				"enabled": true,
 318				"effort":  model.ModelCfg.ReasoningEffort,
 319			}
 320		}
 321		parsed, err := openrouter.ParseOptions(mergedOptions)
 322		if err == nil {
 323			options[openrouter.Name] = parsed
 324		}
 325	case vercel.Name:
 326		_, hasReasoning := mergedOptions["reasoning"]
 327		if !hasReasoning && model.ModelCfg.ReasoningEffort != "" {
 328			mergedOptions["reasoning"] = map[string]any{
 329				"enabled": true,
 330				"effort":  model.ModelCfg.ReasoningEffort,
 331			}
 332		}
 333		parsed, err := vercel.ParseOptions(mergedOptions)
 334		if err == nil {
 335			options[vercel.Name] = parsed
 336		}
 337	case google.Name:
 338		_, hasReasoning := mergedOptions["thinking_config"]
 339		if !hasReasoning {
 340			if strings.HasPrefix(model.CatwalkCfg.ID, "gemini-2") {
 341				mergedOptions["thinking_config"] = map[string]any{
 342					"thinking_budget":  2000,
 343					"include_thoughts": true,
 344				}
 345			} else {
 346				mergedOptions["thinking_config"] = map[string]any{
 347					"thinking_level":   model.ModelCfg.ReasoningEffort,
 348					"include_thoughts": true,
 349				}
 350			}
 351		}
 352		parsed, err := google.ParseOptions(mergedOptions)
 353		if err == nil {
 354			options[google.Name] = parsed
 355		}
 356	case openaicompat.Name:
 357		_, hasReasoningEffort := mergedOptions["reasoning_effort"]
 358		if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
 359			mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort
 360		}
 361		parsed, err := openaicompat.ParseOptions(mergedOptions)
 362		if err == nil {
 363			options[openaicompat.Name] = parsed
 364		}
 365	}
 366
 367	return options
 368}
 369
 370func mergeCallOptions(model Model, cfg config.ProviderConfig) (fantasy.ProviderOptions, *float64, *float64, *int64, *float64, *float64) {
 371	modelOptions := getProviderOptions(model, cfg)
 372	temp := cmp.Or(model.ModelCfg.Temperature, model.CatwalkCfg.Options.Temperature)
 373	topP := cmp.Or(model.ModelCfg.TopP, model.CatwalkCfg.Options.TopP)
 374	topK := cmp.Or(model.ModelCfg.TopK, model.CatwalkCfg.Options.TopK)
 375	freqPenalty := cmp.Or(model.ModelCfg.FrequencyPenalty, model.CatwalkCfg.Options.FrequencyPenalty)
 376	presPenalty := cmp.Or(model.ModelCfg.PresencePenalty, model.CatwalkCfg.Options.PresencePenalty)
 377	return modelOptions, temp, topP, topK, freqPenalty, presPenalty
 378}
 379
 380func (c *coordinator) buildAgent(ctx context.Context, prompt *prompt.Prompt, agent config.Agent, isSubAgent bool) (SessionAgent, error) {
 381	large, small, err := c.buildAgentModels(ctx, isSubAgent)
 382	if err != nil {
 383		return nil, err
 384	}
 385
 386	largeProviderCfg, _ := c.cfg.Config().Providers.Get(large.ModelCfg.Provider)
 387	result := NewSessionAgent(SessionAgentOptions{
 388		LargeModel:           large,
 389		SmallModel:           small,
 390		SystemPromptPrefix:   largeProviderCfg.SystemPromptPrefix,
 391		SystemPrompt:         "",
 392		IsSubAgent:           isSubAgent,
 393		DisableAutoSummarize: c.cfg.Config().Options.DisableAutoSummarize,
 394		IsYolo:               c.permissions.SkipRequests(),
 395		Sessions:             c.sessions,
 396		Messages:             c.messages,
 397		Tools:                nil,
 398		Notify:               c.notify,
 399	})
 400
 401	c.readyWg.Go(func() error {
 402		systemPrompt, err := prompt.Build(ctx, large.Model.Provider(), large.Model.Model(), c.cfg)
 403		if err != nil {
 404			return err
 405		}
 406		result.SetSystemPrompt(systemPrompt)
 407		return nil
 408	})
 409
 410	c.readyWg.Go(func() error {
 411		tools, err := c.buildTools(ctx, agent)
 412		if err != nil {
 413			return err
 414		}
 415		result.SetTools(tools)
 416		return nil
 417	})
 418
 419	return result, nil
 420}
 421
 422func (c *coordinator) buildTools(ctx context.Context, agent config.Agent) ([]fantasy.AgentTool, error) {
 423	var allTools []fantasy.AgentTool
 424	if slices.Contains(agent.AllowedTools, AgentToolName) {
 425		agentTool, err := c.agentTool(ctx)
 426		if err != nil {
 427			return nil, err
 428		}
 429		allTools = append(allTools, agentTool)
 430	}
 431
 432	if slices.Contains(agent.AllowedTools, tools.AgenticFetchToolName) {
 433		agenticFetchTool, err := c.agenticFetchTool(ctx, nil)
 434		if err != nil {
 435			return nil, err
 436		}
 437		allTools = append(allTools, agenticFetchTool)
 438	}
 439
 440	// Get the model name for the agent
 441	modelName := ""
 442	if modelCfg, ok := c.cfg.Config().Models[agent.Model]; ok {
 443		if model := c.cfg.Config().GetModel(modelCfg.Provider, modelCfg.Model); model != nil {
 444			modelName = model.Name
 445		}
 446	}
 447
 448	allTools = append(allTools,
 449		tools.NewBashTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Config().Options.Attribution, modelName),
 450		tools.NewJobOutputTool(),
 451		tools.NewJobKillTool(),
 452		tools.NewDownloadTool(c.permissions, c.cfg.WorkingDir(), nil),
 453		tools.NewEditTool(c.lspManager, c.permissions, c.history, c.filetracker, c.cfg.WorkingDir()),
 454		tools.NewMultiEditTool(c.lspManager, c.permissions, c.history, c.filetracker, c.cfg.WorkingDir()),
 455		tools.NewFetchTool(c.permissions, c.cfg.WorkingDir(), nil),
 456		tools.NewGlobTool(c.cfg.WorkingDir()),
 457		tools.NewGrepTool(c.cfg.WorkingDir(), c.cfg.Config().Tools.Grep),
 458		tools.NewLsTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Config().Tools.Ls),
 459		tools.NewSourcegraphTool(nil),
 460		tools.NewTodosTool(c.sessions),
 461		tools.NewViewTool(c.lspManager, c.permissions, c.filetracker, c.cfg.WorkingDir(), c.cfg.Config().Options.SkillsPaths...),
 462		tools.NewWriteTool(c.lspManager, c.permissions, c.history, c.filetracker, c.cfg.WorkingDir()),
 463	)
 464
 465	// Add LSP tools if user has configured LSPs or auto_lsp is enabled (nil or true).
 466	if len(c.cfg.Config().LSP) > 0 || c.cfg.Config().Options.AutoLSP == nil || *c.cfg.Config().Options.AutoLSP {
 467		allTools = append(allTools, tools.NewDiagnosticsTool(c.lspManager), tools.NewReferencesTool(c.lspManager), tools.NewLSPRestartTool(c.lspManager))
 468	}
 469
 470	if len(c.cfg.Config().MCP) > 0 {
 471		allTools = append(
 472			allTools,
 473			tools.NewListMCPResourcesTool(c.cfg, c.permissions),
 474			tools.NewReadMCPResourceTool(c.cfg, c.permissions),
 475		)
 476	}
 477
 478	var filteredTools []fantasy.AgentTool
 479	for _, tool := range allTools {
 480		if slices.Contains(agent.AllowedTools, tool.Info().Name) {
 481			filteredTools = append(filteredTools, tool)
 482		}
 483	}
 484
 485	for _, tool := range tools.GetMCPTools(c.permissions, c.cfg, c.cfg.WorkingDir()) {
 486		if agent.AllowedMCP == nil {
 487			// No MCP restrictions
 488			filteredTools = append(filteredTools, tool)
 489			continue
 490		}
 491		if len(agent.AllowedMCP) == 0 {
 492			// No MCPs allowed
 493			slog.Debug("No MCPs allowed", "tool", tool.Name(), "agent", agent.Name)
 494			break
 495		}
 496
 497		for mcp, tools := range agent.AllowedMCP {
 498			if mcp != tool.MCP() {
 499				continue
 500			}
 501			if len(tools) == 0 || slices.Contains(tools, tool.MCPToolName()) {
 502				filteredTools = append(filteredTools, tool)
 503				break
 504			}
 505			slog.Debug("MCP not allowed", "tool", tool.Name(), "agent", agent.Name)
 506		}
 507	}
 508	slices.SortFunc(filteredTools, func(a, b fantasy.AgentTool) int {
 509		return strings.Compare(a.Info().Name, b.Info().Name)
 510	})
 511	return filteredTools, nil
 512}
 513
 514// TODO: when we support multiple agents we need to change this so that we pass in the agent specific model config
 515func (c *coordinator) buildAgentModels(ctx context.Context, isSubAgent bool) (Model, Model, error) {
 516	largeModelCfg, ok := c.cfg.Config().Models[config.SelectedModelTypeLarge]
 517	if !ok {
 518		return Model{}, Model{}, errLargeModelNotSelected
 519	}
 520	smallModelCfg, ok := c.cfg.Config().Models[config.SelectedModelTypeSmall]
 521	if !ok {
 522		return Model{}, Model{}, errSmallModelNotSelected
 523	}
 524
 525	largeProviderCfg, ok := c.cfg.Config().Providers.Get(largeModelCfg.Provider)
 526	if !ok {
 527		return Model{}, Model{}, errLargeModelProviderNotConfigured
 528	}
 529
 530	largeProvider, err := c.buildProvider(largeProviderCfg, largeModelCfg, isSubAgent)
 531	if err != nil {
 532		return Model{}, Model{}, err
 533	}
 534
 535	smallProviderCfg, ok := c.cfg.Config().Providers.Get(smallModelCfg.Provider)
 536	if !ok {
 537		return Model{}, Model{}, errSmallModelProviderNotConfigured
 538	}
 539
 540	smallProvider, err := c.buildProvider(smallProviderCfg, smallModelCfg, true)
 541	if err != nil {
 542		return Model{}, Model{}, err
 543	}
 544
 545	var largeCatwalkModel *catwalk.Model
 546	var smallCatwalkModel *catwalk.Model
 547
 548	for _, m := range largeProviderCfg.Models {
 549		if m.ID == largeModelCfg.Model {
 550			largeCatwalkModel = &m
 551		}
 552	}
 553	for _, m := range smallProviderCfg.Models {
 554		if m.ID == smallModelCfg.Model {
 555			smallCatwalkModel = &m
 556		}
 557	}
 558
 559	if largeCatwalkModel == nil {
 560		return Model{}, Model{}, errLargeModelNotFound
 561	}
 562
 563	if smallCatwalkModel == nil {
 564		return Model{}, Model{}, errSmallModelNotFound
 565	}
 566
 567	largeModelID := largeModelCfg.Model
 568	smallModelID := smallModelCfg.Model
 569
 570	if largeModelCfg.Provider == openrouter.Name && isExactoSupported(largeModelID) {
 571		largeModelID += ":exacto"
 572	}
 573
 574	if smallModelCfg.Provider == openrouter.Name && isExactoSupported(smallModelID) {
 575		smallModelID += ":exacto"
 576	}
 577
 578	largeModel, err := largeProvider.LanguageModel(ctx, largeModelID)
 579	if err != nil {
 580		return Model{}, Model{}, err
 581	}
 582	smallModel, err := smallProvider.LanguageModel(ctx, smallModelID)
 583	if err != nil {
 584		return Model{}, Model{}, err
 585	}
 586
 587	return Model{
 588			Model:      largeModel,
 589			CatwalkCfg: *largeCatwalkModel,
 590			ModelCfg:   largeModelCfg,
 591		}, Model{
 592			Model:      smallModel,
 593			CatwalkCfg: *smallCatwalkModel,
 594			ModelCfg:   smallModelCfg,
 595		}, nil
 596}
 597
 598func (c *coordinator) buildAnthropicProvider(baseURL, apiKey string, headers map[string]string, providerID string) (fantasy.Provider, error) {
 599	var opts []anthropic.Option
 600
 601	switch {
 602	case strings.HasPrefix(apiKey, "Bearer "):
 603		// NOTE: Prevent the SDK from picking up the API key from env.
 604		os.Setenv("ANTHROPIC_API_KEY", "")
 605		headers["Authorization"] = apiKey
 606	case providerID == string(catwalk.InferenceProviderMiniMax) || providerID == string(catwalk.InferenceProviderMiniMaxChina):
 607		// NOTE: Prevent the SDK from picking up the API key from env.
 608		os.Setenv("ANTHROPIC_API_KEY", "")
 609		headers["Authorization"] = "Bearer " + apiKey
 610	case apiKey != "":
 611		// X-Api-Key header
 612		opts = append(opts, anthropic.WithAPIKey(apiKey))
 613	}
 614
 615	if len(headers) > 0 {
 616		opts = append(opts, anthropic.WithHeaders(headers))
 617	}
 618
 619	if baseURL != "" {
 620		opts = append(opts, anthropic.WithBaseURL(baseURL))
 621	}
 622
 623	if c.cfg.Config().Options.Debug {
 624		httpClient := log.NewHTTPClient()
 625		opts = append(opts, anthropic.WithHTTPClient(httpClient))
 626	}
 627	return anthropic.New(opts...)
 628}
 629
 630func (c *coordinator) buildOpenaiProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
 631	opts := []openai.Option{
 632		openai.WithAPIKey(apiKey),
 633		openai.WithUseResponsesAPI(),
 634	}
 635	if c.cfg.Config().Options.Debug {
 636		httpClient := log.NewHTTPClient()
 637		opts = append(opts, openai.WithHTTPClient(httpClient))
 638	}
 639	if len(headers) > 0 {
 640		opts = append(opts, openai.WithHeaders(headers))
 641	}
 642	if baseURL != "" {
 643		opts = append(opts, openai.WithBaseURL(baseURL))
 644	}
 645	return openai.New(opts...)
 646}
 647
 648func (c *coordinator) buildOpenrouterProvider(_, apiKey string, headers map[string]string) (fantasy.Provider, error) {
 649	opts := []openrouter.Option{
 650		openrouter.WithAPIKey(apiKey),
 651	}
 652	if c.cfg.Config().Options.Debug {
 653		httpClient := log.NewHTTPClient()
 654		opts = append(opts, openrouter.WithHTTPClient(httpClient))
 655	}
 656	if len(headers) > 0 {
 657		opts = append(opts, openrouter.WithHeaders(headers))
 658	}
 659	return openrouter.New(opts...)
 660}
 661
 662func (c *coordinator) buildVercelProvider(_, apiKey string, headers map[string]string) (fantasy.Provider, error) {
 663	opts := []vercel.Option{
 664		vercel.WithAPIKey(apiKey),
 665	}
 666	if c.cfg.Config().Options.Debug {
 667		httpClient := log.NewHTTPClient()
 668		opts = append(opts, vercel.WithHTTPClient(httpClient))
 669	}
 670	if len(headers) > 0 {
 671		opts = append(opts, vercel.WithHeaders(headers))
 672	}
 673	return vercel.New(opts...)
 674}
 675
 676func (c *coordinator) buildOpenaiCompatProvider(baseURL, apiKey string, headers map[string]string, extraBody map[string]any, providerID string, isSubAgent bool) (fantasy.Provider, error) {
 677	opts := []openaicompat.Option{
 678		openaicompat.WithBaseURL(baseURL),
 679		openaicompat.WithAPIKey(apiKey),
 680	}
 681
 682	// Set HTTP client based on provider and debug mode.
 683	var httpClient *http.Client
 684	if providerID == string(catwalk.InferenceProviderCopilot) {
 685		opts = append(opts, openaicompat.WithUseResponsesAPI())
 686		httpClient = copilot.NewClient(isSubAgent, c.cfg.Config().Options.Debug)
 687	} else if c.cfg.Config().Options.Debug {
 688		httpClient = log.NewHTTPClient()
 689	}
 690	if httpClient != nil {
 691		opts = append(opts, openaicompat.WithHTTPClient(httpClient))
 692	}
 693
 694	if len(headers) > 0 {
 695		opts = append(opts, openaicompat.WithHeaders(headers))
 696	}
 697
 698	for extraKey, extraValue := range extraBody {
 699		opts = append(opts, openaicompat.WithSDKOptions(openaisdk.WithJSONSet(extraKey, extraValue)))
 700	}
 701
 702	return openaicompat.New(opts...)
 703}
 704
 705func (c *coordinator) buildAzureProvider(baseURL, apiKey string, headers map[string]string, options map[string]string) (fantasy.Provider, error) {
 706	opts := []azure.Option{
 707		azure.WithBaseURL(baseURL),
 708		azure.WithAPIKey(apiKey),
 709		azure.WithUseResponsesAPI(),
 710	}
 711	if c.cfg.Config().Options.Debug {
 712		httpClient := log.NewHTTPClient()
 713		opts = append(opts, azure.WithHTTPClient(httpClient))
 714	}
 715	if options == nil {
 716		options = make(map[string]string)
 717	}
 718	if apiVersion, ok := options["apiVersion"]; ok {
 719		opts = append(opts, azure.WithAPIVersion(apiVersion))
 720	}
 721	if len(headers) > 0 {
 722		opts = append(opts, azure.WithHeaders(headers))
 723	}
 724
 725	return azure.New(opts...)
 726}
 727
 728func (c *coordinator) buildBedrockProvider(apiKey string, headers map[string]string) (fantasy.Provider, error) {
 729	var opts []bedrock.Option
 730	if c.cfg.Config().Options.Debug {
 731		httpClient := log.NewHTTPClient()
 732		opts = append(opts, bedrock.WithHTTPClient(httpClient))
 733	}
 734	if len(headers) > 0 {
 735		opts = append(opts, bedrock.WithHeaders(headers))
 736	}
 737	switch {
 738	case apiKey != "":
 739		opts = append(opts, bedrock.WithAPIKey(apiKey))
 740	case os.Getenv("AWS_BEARER_TOKEN_BEDROCK") != "":
 741		opts = append(opts, bedrock.WithAPIKey(os.Getenv("AWS_BEARER_TOKEN_BEDROCK")))
 742	default:
 743		// Skip, let the SDK do authentication.
 744	}
 745	return bedrock.New(opts...)
 746}
 747
 748func (c *coordinator) buildGoogleProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
 749	opts := []google.Option{
 750		google.WithBaseURL(baseURL),
 751		google.WithGeminiAPIKey(apiKey),
 752	}
 753	if c.cfg.Config().Options.Debug {
 754		httpClient := log.NewHTTPClient()
 755		opts = append(opts, google.WithHTTPClient(httpClient))
 756	}
 757	if len(headers) > 0 {
 758		opts = append(opts, google.WithHeaders(headers))
 759	}
 760	return google.New(opts...)
 761}
 762
 763func (c *coordinator) buildGoogleVertexProvider(headers map[string]string, options map[string]string) (fantasy.Provider, error) {
 764	opts := []google.Option{}
 765	if c.cfg.Config().Options.Debug {
 766		httpClient := log.NewHTTPClient()
 767		opts = append(opts, google.WithHTTPClient(httpClient))
 768	}
 769	if len(headers) > 0 {
 770		opts = append(opts, google.WithHeaders(headers))
 771	}
 772
 773	project := options["project"]
 774	location := options["location"]
 775
 776	opts = append(opts, google.WithVertex(project, location))
 777
 778	return google.New(opts...)
 779}
 780
 781func (c *coordinator) buildHyperProvider(baseURL, apiKey string) (fantasy.Provider, error) {
 782	opts := []hyper.Option{
 783		hyper.WithBaseURL(baseURL),
 784		hyper.WithAPIKey(apiKey),
 785	}
 786	if c.cfg.Config().Options.Debug {
 787		httpClient := log.NewHTTPClient()
 788		opts = append(opts, hyper.WithHTTPClient(httpClient))
 789	}
 790	return hyper.New(opts...)
 791}
 792
 793func (c *coordinator) isAnthropicThinking(model config.SelectedModel) bool {
 794	if model.Think {
 795		return true
 796	}
 797	opts, err := anthropic.ParseOptions(model.ProviderOptions)
 798	return err == nil && opts.Thinking != nil
 799}
 800
 801func (c *coordinator) buildProvider(providerCfg config.ProviderConfig, model config.SelectedModel, isSubAgent bool) (fantasy.Provider, error) {
 802	headers := maps.Clone(providerCfg.ExtraHeaders)
 803	if headers == nil {
 804		headers = make(map[string]string)
 805	}
 806
 807	// handle special headers for anthropic
 808	if providerCfg.Type == anthropic.Name && c.isAnthropicThinking(model) {
 809		if v, ok := headers["anthropic-beta"]; ok {
 810			headers["anthropic-beta"] = v + ",interleaved-thinking-2025-05-14"
 811		} else {
 812			headers["anthropic-beta"] = "interleaved-thinking-2025-05-14"
 813		}
 814	}
 815
 816	apiKey, _ := c.cfg.Resolve(providerCfg.APIKey)
 817	baseURL, _ := c.cfg.Resolve(providerCfg.BaseURL)
 818
 819	switch providerCfg.Type {
 820	case openai.Name:
 821		return c.buildOpenaiProvider(baseURL, apiKey, headers)
 822	case anthropic.Name:
 823		return c.buildAnthropicProvider(baseURL, apiKey, headers, providerCfg.ID)
 824	case openrouter.Name:
 825		return c.buildOpenrouterProvider(baseURL, apiKey, headers)
 826	case vercel.Name:
 827		return c.buildVercelProvider(baseURL, apiKey, headers)
 828	case azure.Name:
 829		return c.buildAzureProvider(baseURL, apiKey, headers, providerCfg.ExtraParams)
 830	case bedrock.Name:
 831		return c.buildBedrockProvider(apiKey, headers)
 832	case google.Name:
 833		return c.buildGoogleProvider(baseURL, apiKey, headers)
 834	case "google-vertex":
 835		return c.buildGoogleVertexProvider(headers, providerCfg.ExtraParams)
 836	case openaicompat.Name:
 837		if providerCfg.ID == string(catwalk.InferenceProviderZAI) {
 838			if providerCfg.ExtraBody == nil {
 839				providerCfg.ExtraBody = map[string]any{}
 840			}
 841			providerCfg.ExtraBody["tool_stream"] = true
 842		}
 843		return c.buildOpenaiCompatProvider(baseURL, apiKey, headers, providerCfg.ExtraBody, providerCfg.ID, isSubAgent)
 844	case hyper.Name:
 845		return c.buildHyperProvider(baseURL, apiKey)
 846	default:
 847		return nil, fmt.Errorf("provider type not supported: %q", providerCfg.Type)
 848	}
 849}
 850
 851func isExactoSupported(modelID string) bool {
 852	supportedModels := []string{
 853		"moonshotai/kimi-k2-0905",
 854		"deepseek/deepseek-v3.1-terminus",
 855		"z-ai/glm-4.6",
 856		"openai/gpt-oss-120b",
 857		"qwen/qwen3-coder",
 858	}
 859	return slices.Contains(supportedModels, modelID)
 860}
 861
// Cancel implements Coordinator by forwarding sessionID to the current
// agent's Cancel.
func (c *coordinator) Cancel(sessionID string) {
	c.currentAgent.Cancel(sessionID)
}
 865
// CancelAll implements Coordinator by forwarding to the current agent's
// CancelAll.
func (c *coordinator) CancelAll() {
	c.currentAgent.CancelAll()
}
 869
// ClearQueue implements Coordinator by forwarding sessionID to the current
// agent's ClearQueue.
func (c *coordinator) ClearQueue(sessionID string) {
	c.currentAgent.ClearQueue(sessionID)
}
 873
// IsBusy implements Coordinator by returning the current agent's IsBusy
// result.
func (c *coordinator) IsBusy() bool {
	return c.currentAgent.IsBusy()
}
 877
// IsSessionBusy implements Coordinator by returning the current agent's
// IsSessionBusy result for sessionID.
func (c *coordinator) IsSessionBusy(sessionID string) bool {
	return c.currentAgent.IsSessionBusy(sessionID)
}
 881
// Model implements Coordinator by returning the current agent's model.
func (c *coordinator) Model() Model {
	return c.currentAgent.Model()
}
 885
 886func (c *coordinator) UpdateModels(ctx context.Context) error {
 887	// build the models again so we make sure we get the latest config
 888	large, small, err := c.buildAgentModels(ctx, false)
 889	if err != nil {
 890		return err
 891	}
 892	c.currentAgent.SetModels(large, small)
 893
 894	agentCfg, ok := c.cfg.Config().Agents[config.AgentCoder]
 895	if !ok {
 896		return errCoderAgentNotConfigured
 897	}
 898
 899	tools, err := c.buildTools(ctx, agentCfg)
 900	if err != nil {
 901		return err
 902	}
 903	c.currentAgent.SetTools(tools)
 904	return nil
 905}
 906
// QueuedPrompts implements Coordinator by returning the current agent's
// QueuedPrompts result for sessionID.
func (c *coordinator) QueuedPrompts(sessionID string) int {
	return c.currentAgent.QueuedPrompts(sessionID)
}
 910
// QueuedPromptsList implements Coordinator by returning the current agent's
// QueuedPromptsList result for sessionID.
func (c *coordinator) QueuedPromptsList(sessionID string) []string {
	return c.currentAgent.QueuedPromptsList(sessionID)
}
 914
 915func (c *coordinator) Summarize(ctx context.Context, sessionID string) error {
 916	providerCfg, ok := c.cfg.Config().Providers.Get(c.currentAgent.Model().ModelCfg.Provider)
 917	if !ok {
 918		return errModelProviderNotConfigured
 919	}
 920	return c.currentAgent.Summarize(ctx, sessionID, getProviderOptions(c.currentAgent.Model(), providerCfg))
 921}
 922
 923func (c *coordinator) isUnauthorized(err error) bool {
 924	var providerErr *fantasy.ProviderError
 925	return errors.As(err, &providerErr) && providerErr.StatusCode == http.StatusUnauthorized
 926}
 927
 928func (c *coordinator) refreshOAuth2Token(ctx context.Context, providerCfg config.ProviderConfig) error {
 929	if err := c.cfg.RefreshOAuthToken(ctx, config.ScopeGlobal, providerCfg.ID); err != nil {
 930		slog.Error("Failed to refresh OAuth token after 401 error", "provider", providerCfg.ID, "error", err)
 931		return err
 932	}
 933	if err := c.UpdateModels(ctx); err != nil {
 934		return err
 935	}
 936	return nil
 937}
 938
 939func (c *coordinator) refreshApiKeyTemplate(ctx context.Context, providerCfg config.ProviderConfig) error {
 940	newAPIKey, err := c.cfg.Resolve(providerCfg.APIKeyTemplate)
 941	if err != nil {
 942		slog.Error("Failed to re-resolve API key after 401 error", "provider", providerCfg.ID, "error", err)
 943		return err
 944	}
 945
 946	providerCfg.APIKey = newAPIKey
 947	c.cfg.Config().Providers.Set(providerCfg.ID, providerCfg)
 948
 949	if err := c.UpdateModels(ctx); err != nil {
 950		return err
 951	}
 952	return nil
 953}
 954
// subAgentParams holds the parameters for running a sub-agent.
//
// As used by runSubAgent:
//   - Agent is the agent that is run; its Model supplies the token limit and
//     sampling settings.
//   - SessionID is the parent session: the sub-session is created under it and
//     the sub-session's cost is added to it.
//   - AgentMessageID and ToolCallID are combined into the sub-session ID.
//   - Prompt is the prompt passed to the agent.
//   - SessionTitle is the title given to the sub-session.
type subAgentParams struct {
	Agent          SessionAgent
	SessionID      string
	AgentMessageID string
	ToolCallID     string
	Prompt         string
	SessionTitle   string
	// SessionSetup is an optional callback invoked after session creation
	// but before agent execution, for custom session configuration.
	SessionSetup func(sessionID string)
}
 967
 968// runSubAgent runs a sub-agent and handles session management and cost accumulation.
 969// It creates a sub-session, runs the agent with the given prompt, and propagates
 970// the cost to the parent session.
 971func (c *coordinator) runSubAgent(ctx context.Context, params subAgentParams) (fantasy.ToolResponse, error) {
 972	// Create sub-session
 973	agentToolSessionID := c.sessions.CreateAgentToolSessionID(params.AgentMessageID, params.ToolCallID)
 974	session, err := c.sessions.CreateTaskSession(ctx, agentToolSessionID, params.SessionID, params.SessionTitle)
 975	if err != nil {
 976		return fantasy.ToolResponse{}, fmt.Errorf("create session: %w", err)
 977	}
 978
 979	// Call session setup function if provided
 980	if params.SessionSetup != nil {
 981		params.SessionSetup(session.ID)
 982	}
 983
 984	// Get model configuration
 985	model := params.Agent.Model()
 986	maxTokens := model.CatwalkCfg.DefaultMaxTokens
 987	if model.ModelCfg.MaxTokens != 0 {
 988		maxTokens = model.ModelCfg.MaxTokens
 989	}
 990
 991	providerCfg, ok := c.cfg.Config().Providers.Get(model.ModelCfg.Provider)
 992	if !ok {
 993		return fantasy.ToolResponse{}, errModelProviderNotConfigured
 994	}
 995
 996	// Run the agent
 997	result, err := params.Agent.Run(ctx, SessionAgentCall{
 998		SessionID:        session.ID,
 999		Prompt:           params.Prompt,
1000		MaxOutputTokens:  maxTokens,
1001		ProviderOptions:  getProviderOptions(model, providerCfg),
1002		Temperature:      model.ModelCfg.Temperature,
1003		TopP:             model.ModelCfg.TopP,
1004		TopK:             model.ModelCfg.TopK,
1005		FrequencyPenalty: model.ModelCfg.FrequencyPenalty,
1006		PresencePenalty:  model.ModelCfg.PresencePenalty,
1007		NonInteractive:   true,
1008	})
1009	if err != nil {
1010		return fantasy.NewTextErrorResponse("error generating response"), nil
1011	}
1012
1013	// Update parent session cost
1014	if err := c.updateParentSessionCost(ctx, session.ID, params.SessionID); err != nil {
1015		return fantasy.ToolResponse{}, err
1016	}
1017
1018	return fantasy.NewTextResponse(result.Response.Content.Text()), nil
1019}
1020
1021// updateParentSessionCost accumulates the cost from a child session to its parent session.
1022func (c *coordinator) updateParentSessionCost(ctx context.Context, childSessionID, parentSessionID string) error {
1023	childSession, err := c.sessions.Get(ctx, childSessionID)
1024	if err != nil {
1025		return fmt.Errorf("get child session: %w", err)
1026	}
1027
1028	parentSession, err := c.sessions.Get(ctx, parentSessionID)
1029	if err != nil {
1030		return fmt.Errorf("get parent session: %w", err)
1031	}
1032
1033	parentSession.Cost += childSession.Cost
1034
1035	if _, err := c.sessions.Save(ctx, parentSession); err != nil {
1036		return fmt.Errorf("save parent session: %w", err)
1037	}
1038
1039	return nil
1040}