coordinator.go

   1package agent
   2
   3import (
   4	"bytes"
   5	"cmp"
   6	"context"
   7	"encoding/json"
   8	"errors"
   9	"fmt"
  10	"io"
  11	"log/slog"
  12	"maps"
  13	"net/http"
  14	"os"
  15	"path/filepath"
  16	"slices"
  17	"strings"
  18
  19	"charm.land/catwalk/pkg/catwalk"
  20	"charm.land/fantasy"
  21	"github.com/charmbracelet/crush/internal/agent/hyper"
  22	"github.com/charmbracelet/crush/internal/agent/notify"
  23	"github.com/charmbracelet/crush/internal/agent/prompt"
  24	"github.com/charmbracelet/crush/internal/agent/tools"
  25	"github.com/charmbracelet/crush/internal/config"
  26	"github.com/charmbracelet/crush/internal/filetracker"
  27	"github.com/charmbracelet/crush/internal/history"
  28	"github.com/charmbracelet/crush/internal/log"
  29	"github.com/charmbracelet/crush/internal/lsp"
  30	"github.com/charmbracelet/crush/internal/message"
  31	"github.com/charmbracelet/crush/internal/oauth/copilot"
  32	"github.com/charmbracelet/crush/internal/permission"
  33	"github.com/charmbracelet/crush/internal/pubsub"
  34	"github.com/charmbracelet/crush/internal/session"
  35	"golang.org/x/sync/errgroup"
  36
  37	"charm.land/fantasy/providers/anthropic"
  38	"charm.land/fantasy/providers/azure"
  39	"charm.land/fantasy/providers/bedrock"
  40	"charm.land/fantasy/providers/google"
  41	"charm.land/fantasy/providers/openai"
  42	"charm.land/fantasy/providers/openaicompat"
  43	"charm.land/fantasy/providers/openrouter"
  44	"charm.land/fantasy/providers/vercel"
  45	openaisdk "github.com/charmbracelet/openai-go/option"
  46	"github.com/qjebbs/go-jsons"
  47)
  48
  49// Coordinator errors.
  50var (
  51	errCoderAgentNotConfigured         = errors.New("coder agent not configured")
  52	errModelProviderNotConfigured      = errors.New("model provider not configured")
  53	errLargeModelNotSelected           = errors.New("large model not selected")
  54	errSmallModelNotSelected           = errors.New("small model not selected")
  55	errLargeModelProviderNotConfigured = errors.New("large model provider not configured")
  56	errSmallModelProviderNotConfigured = errors.New("small model provider not configured")
  57	errLargeModelNotFound              = errors.New("large model not found in provider config")
  58	errSmallModelNotFound              = errors.New("small model not found in provider config")
  59)
  60
  61type Coordinator interface {
  62	// INFO: (kujtim) this is not used yet we will use this when we have multiple agents
  63	// SetMainAgent(string)
  64	Run(ctx context.Context, sessionID, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error)
  65	Cancel(sessionID string)
  66	CancelAll()
  67	IsSessionBusy(sessionID string) bool
  68	IsBusy() bool
  69	QueuedPrompts(sessionID string) int
  70	QueuedPromptsList(sessionID string) []string
  71	ClearQueue(sessionID string)
  72	Summarize(context.Context, string) error
  73	Model() Model
  74	UpdateModels(ctx context.Context) error
  75}
  76
  77type coordinator struct {
  78	cfg         *config.ConfigStore
  79	sessions    session.Service
  80	messages    message.Service
  81	permissions permission.Service
  82	history     history.Service
  83	filetracker filetracker.Service
  84	lspManager  *lsp.Manager
  85	notify      pubsub.Publisher[notify.Notification]
  86
  87	currentAgent SessionAgent
  88	agents       map[string]SessionAgent
  89
  90	readyWg errgroup.Group
  91}
  92
  93func NewCoordinator(
  94	ctx context.Context,
  95	cfg *config.ConfigStore,
  96	sessions session.Service,
  97	messages message.Service,
  98	permissions permission.Service,
  99	history history.Service,
 100	filetracker filetracker.Service,
 101	lspManager *lsp.Manager,
 102	notify pubsub.Publisher[notify.Notification],
 103) (Coordinator, error) {
 104	c := &coordinator{
 105		cfg:         cfg,
 106		sessions:    sessions,
 107		messages:    messages,
 108		permissions: permissions,
 109		history:     history,
 110		filetracker: filetracker,
 111		lspManager:  lspManager,
 112		notify:      notify,
 113		agents:      make(map[string]SessionAgent),
 114	}
 115
 116	agentCfg, ok := cfg.Config().Agents[config.AgentCoder]
 117	if !ok {
 118		return nil, errCoderAgentNotConfigured
 119	}
 120
 121	// TODO: make this dynamic when we support multiple agents
 122	prompt, err := coderPrompt(prompt.WithWorkingDir(c.cfg.WorkingDir()))
 123	if err != nil {
 124		return nil, err
 125	}
 126
 127	agent, err := c.buildAgent(ctx, prompt, agentCfg, false)
 128	if err != nil {
 129		return nil, err
 130	}
 131	c.currentAgent = agent
 132	c.agents[config.AgentCoder] = agent
 133	return c, nil
 134}
 135
 136// Run implements Coordinator.
 137func (c *coordinator) Run(ctx context.Context, sessionID string, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) {
 138	if err := c.readyWg.Wait(); err != nil {
 139		return nil, err
 140	}
 141
 142	// refresh models before each run
 143	if err := c.UpdateModels(ctx); err != nil {
 144		return nil, fmt.Errorf("failed to update models: %w", err)
 145	}
 146
 147	model := c.currentAgent.Model()
 148	maxTokens := model.CatwalkCfg.DefaultMaxTokens
 149	if model.ModelCfg.MaxTokens != 0 {
 150		maxTokens = model.ModelCfg.MaxTokens
 151	}
 152
 153	if !model.CatwalkCfg.SupportsImages && attachments != nil {
 154		// filter out image attachments
 155		filteredAttachments := make([]message.Attachment, 0, len(attachments))
 156		for _, att := range attachments {
 157			if att.IsText() {
 158				filteredAttachments = append(filteredAttachments, att)
 159			}
 160		}
 161		attachments = filteredAttachments
 162	}
 163
 164	providerCfg, ok := c.cfg.Config().Providers.Get(model.ModelCfg.Provider)
 165	if !ok {
 166		return nil, errModelProviderNotConfigured
 167	}
 168
 169	mergedOptions, temp, topP, topK, freqPenalty, presPenalty := mergeCallOptions(model, providerCfg)
 170
 171	if providerCfg.OAuthToken != nil && providerCfg.OAuthToken.IsExpired() {
 172		slog.Debug("Token needs to be refreshed", "provider", providerCfg.ID)
 173		if err := c.refreshOAuth2Token(ctx, providerCfg); err != nil {
 174			return nil, err
 175		}
 176	}
 177
 178	run := func() (*fantasy.AgentResult, error) {
 179		return c.currentAgent.Run(ctx, SessionAgentCall{
 180			SessionID:        sessionID,
 181			Prompt:           prompt,
 182			Attachments:      attachments,
 183			MaxOutputTokens:  maxTokens,
 184			ProviderOptions:  mergedOptions,
 185			Temperature:      temp,
 186			TopP:             topP,
 187			TopK:             topK,
 188			FrequencyPenalty: freqPenalty,
 189			PresencePenalty:  presPenalty,
 190		})
 191	}
 192	result, originalErr := run()
 193
 194	if c.isUnauthorized(originalErr) {
 195		switch {
 196		case providerCfg.OAuthToken != nil:
 197			slog.Debug("Received 401. Refreshing token and retrying", "provider", providerCfg.ID)
 198			if err := c.refreshOAuth2Token(ctx, providerCfg); err != nil {
 199				return nil, originalErr
 200			}
 201			slog.Debug("Retrying request with refreshed OAuth token", "provider", providerCfg.ID)
 202			return run()
 203		case strings.Contains(providerCfg.APIKeyTemplate, "$"):
 204			slog.Debug("Received 401. Refreshing API Key template and retrying", "provider", providerCfg.ID)
 205			if err := c.refreshApiKeyTemplate(ctx, providerCfg); err != nil {
 206				return nil, originalErr
 207			}
 208			slog.Debug("Retrying request with refreshed API key", "provider", providerCfg.ID)
 209			return run()
 210		}
 211	}
 212
 213	return result, originalErr
 214}
 215
 216func getProviderOptions(model Model, providerCfg config.ProviderConfig) fantasy.ProviderOptions {
 217	options := fantasy.ProviderOptions{}
 218
 219	cfgOpts := []byte("{}")
 220	providerCfgOpts := []byte("{}")
 221	catwalkOpts := []byte("{}")
 222
 223	if model.ModelCfg.ProviderOptions != nil {
 224		data, err := json.Marshal(model.ModelCfg.ProviderOptions)
 225		if err == nil {
 226			cfgOpts = data
 227		}
 228	}
 229
 230	if providerCfg.ProviderOptions != nil {
 231		data, err := json.Marshal(providerCfg.ProviderOptions)
 232		if err == nil {
 233			providerCfgOpts = data
 234		}
 235	}
 236
 237	if model.CatwalkCfg.Options.ProviderOptions != nil {
 238		data, err := json.Marshal(model.CatwalkCfg.Options.ProviderOptions)
 239		if err == nil {
 240			catwalkOpts = data
 241		}
 242	}
 243
 244	readers := []io.Reader{
 245		bytes.NewReader(catwalkOpts),
 246		bytes.NewReader(providerCfgOpts),
 247		bytes.NewReader(cfgOpts),
 248	}
 249
 250	got, err := jsons.Merge(readers)
 251	if err != nil {
 252		slog.Error("Could not merge call config", "err", err)
 253		return options
 254	}
 255
 256	mergedOptions := make(map[string]any)
 257
 258	err = json.Unmarshal([]byte(got), &mergedOptions)
 259	if err != nil {
 260		slog.Error("Could not create config for call", "err", err)
 261		return options
 262	}
 263
 264	providerType := providerCfg.Type
 265	if providerType == "hyper" {
 266		if strings.Contains(model.CatwalkCfg.ID, "claude") {
 267			providerType = anthropic.Name
 268		} else if strings.Contains(model.CatwalkCfg.ID, "gpt") {
 269			providerType = openai.Name
 270		} else if strings.Contains(model.CatwalkCfg.ID, "gemini") {
 271			providerType = google.Name
 272		} else {
 273			providerType = openaicompat.Name
 274		}
 275	}
 276
 277	switch providerType {
 278	case openai.Name, azure.Name:
 279		_, hasReasoningEffort := mergedOptions["reasoning_effort"]
 280		if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
 281			mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort
 282		}
 283		if openai.IsResponsesModel(model.CatwalkCfg.ID) {
 284			if openai.IsResponsesReasoningModel(model.CatwalkCfg.ID) {
 285				mergedOptions["reasoning_summary"] = "auto"
 286				mergedOptions["include"] = []openai.IncludeType{openai.IncludeReasoningEncryptedContent}
 287			}
 288			parsed, err := openai.ParseResponsesOptions(mergedOptions)
 289			if err == nil {
 290				options[openai.Name] = parsed
 291			}
 292		} else {
 293			parsed, err := openai.ParseOptions(mergedOptions)
 294			if err == nil {
 295				options[openai.Name] = parsed
 296			}
 297		}
 298	case anthropic.Name:
 299		var (
 300			_, hasEffort = mergedOptions["effort"]
 301			_, hasThink  = mergedOptions["thinking"]
 302		)
 303		switch {
 304		case !hasEffort && model.ModelCfg.ReasoningEffort != "":
 305			mergedOptions["effort"] = model.ModelCfg.ReasoningEffort
 306		case !hasThink && model.ModelCfg.Think:
 307			mergedOptions["thinking"] = map[string]any{"budget_tokens": 2000}
 308		}
 309		parsed, err := anthropic.ParseOptions(mergedOptions)
 310		if err == nil {
 311			options[anthropic.Name] = parsed
 312		}
 313
 314	case openrouter.Name:
 315		_, hasReasoning := mergedOptions["reasoning"]
 316		if !hasReasoning && model.ModelCfg.ReasoningEffort != "" {
 317			mergedOptions["reasoning"] = map[string]any{
 318				"enabled": true,
 319				"effort":  model.ModelCfg.ReasoningEffort,
 320			}
 321		}
 322		parsed, err := openrouter.ParseOptions(mergedOptions)
 323		if err == nil {
 324			options[openrouter.Name] = parsed
 325		}
 326	case vercel.Name:
 327		_, hasReasoning := mergedOptions["reasoning"]
 328		if !hasReasoning && model.ModelCfg.ReasoningEffort != "" {
 329			mergedOptions["reasoning"] = map[string]any{
 330				"enabled": true,
 331				"effort":  model.ModelCfg.ReasoningEffort,
 332			}
 333		}
 334		parsed, err := vercel.ParseOptions(mergedOptions)
 335		if err == nil {
 336			options[vercel.Name] = parsed
 337		}
 338	case google.Name:
 339		_, hasReasoning := mergedOptions["thinking_config"]
 340		if !hasReasoning {
 341			if strings.HasPrefix(model.CatwalkCfg.ID, "gemini-2") {
 342				mergedOptions["thinking_config"] = map[string]any{
 343					"thinking_budget":  2000,
 344					"include_thoughts": true,
 345				}
 346			} else {
 347				mergedOptions["thinking_config"] = map[string]any{
 348					"thinking_level":   model.ModelCfg.ReasoningEffort,
 349					"include_thoughts": true,
 350				}
 351			}
 352		}
 353		parsed, err := google.ParseOptions(mergedOptions)
 354		if err == nil {
 355			options[google.Name] = parsed
 356		}
 357	case openaicompat.Name:
 358		_, hasReasoningEffort := mergedOptions["reasoning_effort"]
 359		if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
 360			mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort
 361		}
 362		parsed, err := openaicompat.ParseOptions(mergedOptions)
 363		if err == nil {
 364			options[openaicompat.Name] = parsed
 365		}
 366	}
 367
 368	return options
 369}
 370
 371func mergeCallOptions(model Model, cfg config.ProviderConfig) (fantasy.ProviderOptions, *float64, *float64, *int64, *float64, *float64) {
 372	modelOptions := getProviderOptions(model, cfg)
 373	temp := cmp.Or(model.ModelCfg.Temperature, model.CatwalkCfg.Options.Temperature)
 374	topP := cmp.Or(model.ModelCfg.TopP, model.CatwalkCfg.Options.TopP)
 375	topK := cmp.Or(model.ModelCfg.TopK, model.CatwalkCfg.Options.TopK)
 376	freqPenalty := cmp.Or(model.ModelCfg.FrequencyPenalty, model.CatwalkCfg.Options.FrequencyPenalty)
 377	presPenalty := cmp.Or(model.ModelCfg.PresencePenalty, model.CatwalkCfg.Options.PresencePenalty)
 378	return modelOptions, temp, topP, topK, freqPenalty, presPenalty
 379}
 380
 381func (c *coordinator) buildAgent(ctx context.Context, prompt *prompt.Prompt, agent config.Agent, isSubAgent bool) (SessionAgent, error) {
 382	large, small, err := c.buildAgentModels(ctx, isSubAgent)
 383	if err != nil {
 384		return nil, err
 385	}
 386
 387	largeProviderCfg, _ := c.cfg.Config().Providers.Get(large.ModelCfg.Provider)
 388	result := NewSessionAgent(SessionAgentOptions{
 389		LargeModel:           large,
 390		SmallModel:           small,
 391		SystemPromptPrefix:   largeProviderCfg.SystemPromptPrefix,
 392		SystemPrompt:         "",
 393		IsSubAgent:           isSubAgent,
 394		DisableAutoSummarize: c.cfg.Config().Options.DisableAutoSummarize,
 395		IsYolo:               c.permissions.SkipRequests(),
 396		Sessions:             c.sessions,
 397		Messages:             c.messages,
 398		Tools:                nil,
 399		Notify:               c.notify,
 400	})
 401
 402	c.readyWg.Go(func() error {
 403		systemPrompt, err := prompt.Build(ctx, large.Model.Provider(), large.Model.Model(), c.cfg)
 404		if err != nil {
 405			return err
 406		}
 407		result.SetSystemPrompt(systemPrompt)
 408		return nil
 409	})
 410
 411	c.readyWg.Go(func() error {
 412		tools, err := c.buildTools(ctx, agent)
 413		if err != nil {
 414			return err
 415		}
 416		result.SetTools(tools)
 417		return nil
 418	})
 419
 420	return result, nil
 421}
 422
 423func (c *coordinator) buildTools(ctx context.Context, agent config.Agent) ([]fantasy.AgentTool, error) {
 424	var allTools []fantasy.AgentTool
 425	if slices.Contains(agent.AllowedTools, AgentToolName) {
 426		agentTool, err := c.agentTool(ctx)
 427		if err != nil {
 428			return nil, err
 429		}
 430		allTools = append(allTools, agentTool)
 431	}
 432
 433	if slices.Contains(agent.AllowedTools, tools.AgenticFetchToolName) {
 434		agenticFetchTool, err := c.agenticFetchTool(ctx, nil)
 435		if err != nil {
 436			return nil, err
 437		}
 438		allTools = append(allTools, agenticFetchTool)
 439	}
 440
 441	// Get the model name for the agent
 442	modelName := ""
 443	if modelCfg, ok := c.cfg.Config().Models[agent.Model]; ok {
 444		if model := c.cfg.Config().GetModel(modelCfg.Provider, modelCfg.Model); model != nil {
 445			modelName = model.Name
 446		}
 447	}
 448
 449	logFile := filepath.Join(c.cfg.Config().Options.DataDirectory, "logs", "crush.log")
 450
 451	allTools = append(allTools,
 452		tools.NewBashTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Config().Options.Attribution, modelName),
 453		tools.NewCrushInfoTool(c.cfg, c.lspManager),
 454		tools.NewCrushLogsTool(logFile),
 455		tools.NewJobOutputTool(),
 456		tools.NewJobKillTool(),
 457		tools.NewDownloadTool(c.permissions, c.cfg.WorkingDir(), nil),
 458		tools.NewEditTool(c.lspManager, c.permissions, c.history, c.filetracker, c.cfg.WorkingDir()),
 459		tools.NewMultiEditTool(c.lspManager, c.permissions, c.history, c.filetracker, c.cfg.WorkingDir()),
 460		tools.NewFetchTool(c.permissions, c.cfg.WorkingDir(), nil),
 461		tools.NewGlobTool(c.cfg.WorkingDir()),
 462		tools.NewGrepTool(c.cfg.WorkingDir(), c.cfg.Config().Tools.Grep),
 463		tools.NewLsTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Config().Tools.Ls),
 464		tools.NewSourcegraphTool(nil),
 465		tools.NewTodosTool(c.sessions),
 466		tools.NewViewTool(c.lspManager, c.permissions, c.filetracker, c.cfg.WorkingDir(), c.cfg.Config().Options.SkillsPaths...),
 467		tools.NewWriteTool(c.lspManager, c.permissions, c.history, c.filetracker, c.cfg.WorkingDir()),
 468	)
 469
 470	// Add LSP tools if user has configured LSPs or auto_lsp is enabled (nil or true).
 471	if len(c.cfg.Config().LSP) > 0 || c.cfg.Config().Options.AutoLSP == nil || *c.cfg.Config().Options.AutoLSP {
 472		allTools = append(allTools, tools.NewDiagnosticsTool(c.lspManager), tools.NewReferencesTool(c.lspManager), tools.NewLSPRestartTool(c.lspManager))
 473	}
 474
 475	if len(c.cfg.Config().MCP) > 0 {
 476		allTools = append(
 477			allTools,
 478			tools.NewListMCPResourcesTool(c.cfg, c.permissions),
 479			tools.NewReadMCPResourceTool(c.cfg, c.permissions),
 480		)
 481	}
 482
 483	var filteredTools []fantasy.AgentTool
 484	for _, tool := range allTools {
 485		if slices.Contains(agent.AllowedTools, tool.Info().Name) {
 486			filteredTools = append(filteredTools, tool)
 487		}
 488	}
 489
 490	for _, tool := range tools.GetMCPTools(c.permissions, c.cfg, c.cfg.WorkingDir()) {
 491		if agent.AllowedMCP == nil {
 492			// No MCP restrictions
 493			filteredTools = append(filteredTools, tool)
 494			continue
 495		}
 496		if len(agent.AllowedMCP) == 0 {
 497			// No MCPs allowed
 498			slog.Debug("No MCPs allowed", "tool", tool.Name(), "agent", agent.Name)
 499			break
 500		}
 501
 502		for mcp, tools := range agent.AllowedMCP {
 503			if mcp != tool.MCP() {
 504				continue
 505			}
 506			if len(tools) == 0 || slices.Contains(tools, tool.MCPToolName()) {
 507				filteredTools = append(filteredTools, tool)
 508				break
 509			}
 510			slog.Debug("MCP not allowed", "tool", tool.Name(), "agent", agent.Name)
 511		}
 512	}
 513	slices.SortFunc(filteredTools, func(a, b fantasy.AgentTool) int {
 514		return strings.Compare(a.Info().Name, b.Info().Name)
 515	})
 516	return filteredTools, nil
 517}
 518
 519// TODO: when we support multiple agents we need to change this so that we pass in the agent specific model config
 520func (c *coordinator) buildAgentModels(ctx context.Context, isSubAgent bool) (Model, Model, error) {
 521	largeModelCfg, ok := c.cfg.Config().Models[config.SelectedModelTypeLarge]
 522	if !ok {
 523		return Model{}, Model{}, errLargeModelNotSelected
 524	}
 525	smallModelCfg, ok := c.cfg.Config().Models[config.SelectedModelTypeSmall]
 526	if !ok {
 527		return Model{}, Model{}, errSmallModelNotSelected
 528	}
 529
 530	largeProviderCfg, ok := c.cfg.Config().Providers.Get(largeModelCfg.Provider)
 531	if !ok {
 532		return Model{}, Model{}, errLargeModelProviderNotConfigured
 533	}
 534
 535	largeProvider, err := c.buildProvider(largeProviderCfg, largeModelCfg, isSubAgent)
 536	if err != nil {
 537		return Model{}, Model{}, err
 538	}
 539
 540	smallProviderCfg, ok := c.cfg.Config().Providers.Get(smallModelCfg.Provider)
 541	if !ok {
 542		return Model{}, Model{}, errSmallModelProviderNotConfigured
 543	}
 544
 545	smallProvider, err := c.buildProvider(smallProviderCfg, smallModelCfg, true)
 546	if err != nil {
 547		return Model{}, Model{}, err
 548	}
 549
 550	var largeCatwalkModel *catwalk.Model
 551	var smallCatwalkModel *catwalk.Model
 552
 553	for _, m := range largeProviderCfg.Models {
 554		if m.ID == largeModelCfg.Model {
 555			largeCatwalkModel = &m
 556		}
 557	}
 558	for _, m := range smallProviderCfg.Models {
 559		if m.ID == smallModelCfg.Model {
 560			smallCatwalkModel = &m
 561		}
 562	}
 563
 564	if largeCatwalkModel == nil {
 565		return Model{}, Model{}, errLargeModelNotFound
 566	}
 567
 568	if smallCatwalkModel == nil {
 569		return Model{}, Model{}, errSmallModelNotFound
 570	}
 571
 572	largeModelID := largeModelCfg.Model
 573	smallModelID := smallModelCfg.Model
 574
 575	if largeModelCfg.Provider == openrouter.Name && isExactoSupported(largeModelID) {
 576		largeModelID += ":exacto"
 577	}
 578
 579	if smallModelCfg.Provider == openrouter.Name && isExactoSupported(smallModelID) {
 580		smallModelID += ":exacto"
 581	}
 582
 583	largeModel, err := largeProvider.LanguageModel(ctx, largeModelID)
 584	if err != nil {
 585		return Model{}, Model{}, err
 586	}
 587	smallModel, err := smallProvider.LanguageModel(ctx, smallModelID)
 588	if err != nil {
 589		return Model{}, Model{}, err
 590	}
 591
 592	return Model{
 593			Model:      largeModel,
 594			CatwalkCfg: *largeCatwalkModel,
 595			ModelCfg:   largeModelCfg,
 596		}, Model{
 597			Model:      smallModel,
 598			CatwalkCfg: *smallCatwalkModel,
 599			ModelCfg:   smallModelCfg,
 600		}, nil
 601}
 602
 603func (c *coordinator) buildAnthropicProvider(baseURL, apiKey string, headers map[string]string, providerID string) (fantasy.Provider, error) {
 604	var opts []anthropic.Option
 605
 606	switch {
 607	case strings.HasPrefix(apiKey, "Bearer "):
 608		// NOTE: Prevent the SDK from picking up the API key from env.
 609		os.Setenv("ANTHROPIC_API_KEY", "")
 610		headers["Authorization"] = apiKey
 611	case providerID == string(catwalk.InferenceProviderMiniMax) || providerID == string(catwalk.InferenceProviderMiniMaxChina):
 612		// NOTE: Prevent the SDK from picking up the API key from env.
 613		os.Setenv("ANTHROPIC_API_KEY", "")
 614		headers["Authorization"] = "Bearer " + apiKey
 615	case apiKey != "":
 616		// X-Api-Key header
 617		opts = append(opts, anthropic.WithAPIKey(apiKey))
 618	}
 619
 620	if len(headers) > 0 {
 621		opts = append(opts, anthropic.WithHeaders(headers))
 622	}
 623
 624	if baseURL != "" {
 625		opts = append(opts, anthropic.WithBaseURL(baseURL))
 626	}
 627
 628	if c.cfg.Config().Options.Debug {
 629		httpClient := log.NewHTTPClient()
 630		opts = append(opts, anthropic.WithHTTPClient(httpClient))
 631	}
 632	return anthropic.New(opts...)
 633}
 634
 635func (c *coordinator) buildOpenaiProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
 636	opts := []openai.Option{
 637		openai.WithAPIKey(apiKey),
 638		openai.WithUseResponsesAPI(),
 639	}
 640	if c.cfg.Config().Options.Debug {
 641		httpClient := log.NewHTTPClient()
 642		opts = append(opts, openai.WithHTTPClient(httpClient))
 643	}
 644	if len(headers) > 0 {
 645		opts = append(opts, openai.WithHeaders(headers))
 646	}
 647	if baseURL != "" {
 648		opts = append(opts, openai.WithBaseURL(baseURL))
 649	}
 650	return openai.New(opts...)
 651}
 652
 653func (c *coordinator) buildOpenrouterProvider(_, apiKey string, headers map[string]string) (fantasy.Provider, error) {
 654	opts := []openrouter.Option{
 655		openrouter.WithAPIKey(apiKey),
 656	}
 657	if c.cfg.Config().Options.Debug {
 658		httpClient := log.NewHTTPClient()
 659		opts = append(opts, openrouter.WithHTTPClient(httpClient))
 660	}
 661	if len(headers) > 0 {
 662		opts = append(opts, openrouter.WithHeaders(headers))
 663	}
 664	return openrouter.New(opts...)
 665}
 666
 667func (c *coordinator) buildVercelProvider(_, apiKey string, headers map[string]string) (fantasy.Provider, error) {
 668	opts := []vercel.Option{
 669		vercel.WithAPIKey(apiKey),
 670	}
 671	if c.cfg.Config().Options.Debug {
 672		httpClient := log.NewHTTPClient()
 673		opts = append(opts, vercel.WithHTTPClient(httpClient))
 674	}
 675	if len(headers) > 0 {
 676		opts = append(opts, vercel.WithHeaders(headers))
 677	}
 678	return vercel.New(opts...)
 679}
 680
 681func (c *coordinator) buildOpenaiCompatProvider(baseURL, apiKey string, headers map[string]string, extraBody map[string]any, providerID string, isSubAgent bool) (fantasy.Provider, error) {
 682	opts := []openaicompat.Option{
 683		openaicompat.WithBaseURL(baseURL),
 684		openaicompat.WithAPIKey(apiKey),
 685	}
 686
 687	// Set HTTP client based on provider and debug mode.
 688	var httpClient *http.Client
 689	if providerID == string(catwalk.InferenceProviderCopilot) {
 690		opts = append(opts, openaicompat.WithUseResponsesAPI())
 691		httpClient = copilot.NewClient(isSubAgent, c.cfg.Config().Options.Debug)
 692	} else if c.cfg.Config().Options.Debug {
 693		httpClient = log.NewHTTPClient()
 694	}
 695	if httpClient != nil {
 696		opts = append(opts, openaicompat.WithHTTPClient(httpClient))
 697	}
 698
 699	if len(headers) > 0 {
 700		opts = append(opts, openaicompat.WithHeaders(headers))
 701	}
 702
 703	for extraKey, extraValue := range extraBody {
 704		opts = append(opts, openaicompat.WithSDKOptions(openaisdk.WithJSONSet(extraKey, extraValue)))
 705	}
 706
 707	return openaicompat.New(opts...)
 708}
 709
 710func (c *coordinator) buildAzureProvider(baseURL, apiKey string, headers map[string]string, options map[string]string) (fantasy.Provider, error) {
 711	opts := []azure.Option{
 712		azure.WithBaseURL(baseURL),
 713		azure.WithAPIKey(apiKey),
 714		azure.WithUseResponsesAPI(),
 715	}
 716	if c.cfg.Config().Options.Debug {
 717		httpClient := log.NewHTTPClient()
 718		opts = append(opts, azure.WithHTTPClient(httpClient))
 719	}
 720	if options == nil {
 721		options = make(map[string]string)
 722	}
 723	if apiVersion, ok := options["apiVersion"]; ok {
 724		opts = append(opts, azure.WithAPIVersion(apiVersion))
 725	}
 726	if len(headers) > 0 {
 727		opts = append(opts, azure.WithHeaders(headers))
 728	}
 729
 730	return azure.New(opts...)
 731}
 732
 733func (c *coordinator) buildBedrockProvider(apiKey string, headers map[string]string) (fantasy.Provider, error) {
 734	var opts []bedrock.Option
 735	if c.cfg.Config().Options.Debug {
 736		httpClient := log.NewHTTPClient()
 737		opts = append(opts, bedrock.WithHTTPClient(httpClient))
 738	}
 739	if len(headers) > 0 {
 740		opts = append(opts, bedrock.WithHeaders(headers))
 741	}
 742	switch {
 743	case apiKey != "":
 744		opts = append(opts, bedrock.WithAPIKey(apiKey))
 745	case os.Getenv("AWS_BEARER_TOKEN_BEDROCK") != "":
 746		opts = append(opts, bedrock.WithAPIKey(os.Getenv("AWS_BEARER_TOKEN_BEDROCK")))
 747	default:
 748		// Skip, let the SDK do authentication.
 749	}
 750	return bedrock.New(opts...)
 751}
 752
 753func (c *coordinator) buildGoogleProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
 754	opts := []google.Option{
 755		google.WithBaseURL(baseURL),
 756		google.WithGeminiAPIKey(apiKey),
 757	}
 758	if c.cfg.Config().Options.Debug {
 759		httpClient := log.NewHTTPClient()
 760		opts = append(opts, google.WithHTTPClient(httpClient))
 761	}
 762	if len(headers) > 0 {
 763		opts = append(opts, google.WithHeaders(headers))
 764	}
 765	return google.New(opts...)
 766}
 767
 768func (c *coordinator) buildGoogleVertexProvider(headers map[string]string, options map[string]string) (fantasy.Provider, error) {
 769	opts := []google.Option{}
 770	if c.cfg.Config().Options.Debug {
 771		httpClient := log.NewHTTPClient()
 772		opts = append(opts, google.WithHTTPClient(httpClient))
 773	}
 774	if len(headers) > 0 {
 775		opts = append(opts, google.WithHeaders(headers))
 776	}
 777
 778	project := options["project"]
 779	location := options["location"]
 780
 781	opts = append(opts, google.WithVertex(project, location))
 782
 783	return google.New(opts...)
 784}
 785
 786func (c *coordinator) buildHyperProvider(apiKey string) (fantasy.Provider, error) {
 787	opts := []hyper.Option{
 788		hyper.WithAPIKey(apiKey),
 789	}
 790	if c.cfg.Config().Options.Debug {
 791		httpClient := log.NewHTTPClient()
 792		opts = append(opts, hyper.WithHTTPClient(httpClient))
 793	}
 794	return hyper.New(opts...)
 795}
 796
 797func (c *coordinator) isAnthropicThinking(model config.SelectedModel) bool {
 798	if model.Think {
 799		return true
 800	}
 801	opts, err := anthropic.ParseOptions(model.ProviderOptions)
 802	return err == nil && opts.Thinking != nil
 803}
 804
 805func (c *coordinator) buildProvider(providerCfg config.ProviderConfig, model config.SelectedModel, isSubAgent bool) (fantasy.Provider, error) {
 806	headers := maps.Clone(providerCfg.ExtraHeaders)
 807	if headers == nil {
 808		headers = make(map[string]string)
 809	}
 810
 811	// handle special headers for anthropic
 812	if providerCfg.Type == anthropic.Name && c.isAnthropicThinking(model) {
 813		if v, ok := headers["anthropic-beta"]; ok {
 814			headers["anthropic-beta"] = v + ",interleaved-thinking-2025-05-14"
 815		} else {
 816			headers["anthropic-beta"] = "interleaved-thinking-2025-05-14"
 817		}
 818	}
 819
 820	apiKey, _ := c.cfg.Resolve(providerCfg.APIKey)
 821	baseURL, _ := c.cfg.Resolve(providerCfg.BaseURL)
 822
 823	switch providerCfg.Type {
 824	case openai.Name:
 825		return c.buildOpenaiProvider(baseURL, apiKey, headers)
 826	case anthropic.Name:
 827		return c.buildAnthropicProvider(baseURL, apiKey, headers, providerCfg.ID)
 828	case openrouter.Name:
 829		return c.buildOpenrouterProvider(baseURL, apiKey, headers)
 830	case vercel.Name:
 831		return c.buildVercelProvider(baseURL, apiKey, headers)
 832	case azure.Name:
 833		return c.buildAzureProvider(baseURL, apiKey, headers, providerCfg.ExtraParams)
 834	case bedrock.Name:
 835		return c.buildBedrockProvider(apiKey, headers)
 836	case google.Name:
 837		return c.buildGoogleProvider(baseURL, apiKey, headers)
 838	case "google-vertex":
 839		return c.buildGoogleVertexProvider(headers, providerCfg.ExtraParams)
 840	case openaicompat.Name:
 841		if providerCfg.ID == string(catwalk.InferenceProviderZAI) {
 842			if providerCfg.ExtraBody == nil {
 843				providerCfg.ExtraBody = map[string]any{}
 844			}
 845			providerCfg.ExtraBody["tool_stream"] = true
 846		}
 847		return c.buildOpenaiCompatProvider(baseURL, apiKey, headers, providerCfg.ExtraBody, providerCfg.ID, isSubAgent)
 848	case hyper.Name:
 849		return c.buildHyperProvider(apiKey)
 850	default:
 851		return nil, fmt.Errorf("provider type not supported: %q", providerCfg.Type)
 852	}
 853}
 854
 855func isExactoSupported(modelID string) bool {
 856	supportedModels := []string{
 857		"moonshotai/kimi-k2-0905",
 858		"deepseek/deepseek-v3.1-terminus",
 859		"z-ai/glm-4.6",
 860		"openai/gpt-oss-120b",
 861		"qwen/qwen3-coder",
 862	}
 863	return slices.Contains(supportedModels, modelID)
 864}
 865
 866func (c *coordinator) Cancel(sessionID string) {
 867	c.currentAgent.Cancel(sessionID)
 868}
 869
 870func (c *coordinator) CancelAll() {
 871	c.currentAgent.CancelAll()
 872}
 873
 874func (c *coordinator) ClearQueue(sessionID string) {
 875	c.currentAgent.ClearQueue(sessionID)
 876}
 877
 878func (c *coordinator) IsBusy() bool {
 879	return c.currentAgent.IsBusy()
 880}
 881
 882func (c *coordinator) IsSessionBusy(sessionID string) bool {
 883	return c.currentAgent.IsSessionBusy(sessionID)
 884}
 885
 886func (c *coordinator) Model() Model {
 887	return c.currentAgent.Model()
 888}
 889
 890func (c *coordinator) UpdateModels(ctx context.Context) error {
 891	// build the models again so we make sure we get the latest config
 892	large, small, err := c.buildAgentModels(ctx, false)
 893	if err != nil {
 894		return err
 895	}
 896	c.currentAgent.SetModels(large, small)
 897
 898	agentCfg, ok := c.cfg.Config().Agents[config.AgentCoder]
 899	if !ok {
 900		return errCoderAgentNotConfigured
 901	}
 902
 903	tools, err := c.buildTools(ctx, agentCfg)
 904	if err != nil {
 905		return err
 906	}
 907	c.currentAgent.SetTools(tools)
 908	return nil
 909}
 910
 911func (c *coordinator) QueuedPrompts(sessionID string) int {
 912	return c.currentAgent.QueuedPrompts(sessionID)
 913}
 914
 915func (c *coordinator) QueuedPromptsList(sessionID string) []string {
 916	return c.currentAgent.QueuedPromptsList(sessionID)
 917}
 918
 919func (c *coordinator) Summarize(ctx context.Context, sessionID string) error {
 920	providerCfg, ok := c.cfg.Config().Providers.Get(c.currentAgent.Model().ModelCfg.Provider)
 921	if !ok {
 922		return errModelProviderNotConfigured
 923	}
 924	return c.currentAgent.Summarize(ctx, sessionID, getProviderOptions(c.currentAgent.Model(), providerCfg))
 925}
 926
 927func (c *coordinator) isUnauthorized(err error) bool {
 928	var providerErr *fantasy.ProviderError
 929	return (errors.As(err, &providerErr) && providerErr.StatusCode == http.StatusUnauthorized) ||
 930		errors.Is(err, hyper.ErrUnauthorized)
 931}
 932
 933func (c *coordinator) refreshOAuth2Token(ctx context.Context, providerCfg config.ProviderConfig) error {
 934	if err := c.cfg.RefreshOAuthToken(ctx, config.ScopeGlobal, providerCfg.ID); err != nil {
 935		slog.Error("Failed to refresh OAuth token after 401 error", "provider", providerCfg.ID, "error", err)
 936		return err
 937	}
 938	if err := c.UpdateModels(ctx); err != nil {
 939		return err
 940	}
 941	return nil
 942}
 943
 944func (c *coordinator) refreshApiKeyTemplate(ctx context.Context, providerCfg config.ProviderConfig) error {
 945	newAPIKey, err := c.cfg.Resolve(providerCfg.APIKeyTemplate)
 946	if err != nil {
 947		slog.Error("Failed to re-resolve API key after 401 error", "provider", providerCfg.ID, "error", err)
 948		return err
 949	}
 950
 951	providerCfg.APIKey = newAPIKey
 952	c.cfg.Config().Providers.Set(providerCfg.ID, providerCfg)
 953
 954	if err := c.UpdateModels(ctx); err != nil {
 955		return err
 956	}
 957	return nil
 958}
 959
// subAgentParams holds the parameters for running a sub-agent via runSubAgent.
//
// Fields:
//   - Agent: the session agent to run. Its Model() supplies the max output
//     tokens, the sampling settings, and the provider whose config feeds the
//     provider options.
//   - SessionID: the parent session. The sub-session is created as a task
//     session under it, and the sub-session's cost is added to it afterwards.
//   - AgentMessageID, ToolCallID: together these derive the sub-session ID
//     (via CreateAgentToolSessionID).
//   - Prompt: the prompt sent to the sub-agent.
//   - SessionTitle: passed to CreateTaskSession, presumably as the
//     sub-session's title (confirm against the session service).
//   - SessionSetup: optional hook; see the field comment below.
type subAgentParams struct {
	Agent          SessionAgent
	SessionID      string
	AgentMessageID string
	ToolCallID     string
	Prompt         string
	SessionTitle   string
	// SessionSetup is an optional callback invoked after session creation
	// but before agent execution, for custom session configuration.
	SessionSetup func(sessionID string)
}
 972
 973// runSubAgent runs a sub-agent and handles session management and cost accumulation.
 974// It creates a sub-session, runs the agent with the given prompt, and propagates
 975// the cost to the parent session.
 976func (c *coordinator) runSubAgent(ctx context.Context, params subAgentParams) (fantasy.ToolResponse, error) {
 977	// Create sub-session
 978	agentToolSessionID := c.sessions.CreateAgentToolSessionID(params.AgentMessageID, params.ToolCallID)
 979	session, err := c.sessions.CreateTaskSession(ctx, agentToolSessionID, params.SessionID, params.SessionTitle)
 980	if err != nil {
 981		return fantasy.ToolResponse{}, fmt.Errorf("create session: %w", err)
 982	}
 983
 984	// Call session setup function if provided
 985	if params.SessionSetup != nil {
 986		params.SessionSetup(session.ID)
 987	}
 988
 989	// Get model configuration
 990	model := params.Agent.Model()
 991	maxTokens := model.CatwalkCfg.DefaultMaxTokens
 992	if model.ModelCfg.MaxTokens != 0 {
 993		maxTokens = model.ModelCfg.MaxTokens
 994	}
 995
 996	providerCfg, ok := c.cfg.Config().Providers.Get(model.ModelCfg.Provider)
 997	if !ok {
 998		return fantasy.ToolResponse{}, errModelProviderNotConfigured
 999	}
1000
1001	// Run the agent
1002	result, err := params.Agent.Run(ctx, SessionAgentCall{
1003		SessionID:        session.ID,
1004		Prompt:           params.Prompt,
1005		MaxOutputTokens:  maxTokens,
1006		ProviderOptions:  getProviderOptions(model, providerCfg),
1007		Temperature:      model.ModelCfg.Temperature,
1008		TopP:             model.ModelCfg.TopP,
1009		TopK:             model.ModelCfg.TopK,
1010		FrequencyPenalty: model.ModelCfg.FrequencyPenalty,
1011		PresencePenalty:  model.ModelCfg.PresencePenalty,
1012		NonInteractive:   true,
1013	})
1014	if err != nil {
1015		return fantasy.NewTextErrorResponse("error generating response"), nil
1016	}
1017
1018	// Update parent session cost
1019	if err := c.updateParentSessionCost(ctx, session.ID, params.SessionID); err != nil {
1020		return fantasy.ToolResponse{}, err
1021	}
1022
1023	return fantasy.NewTextResponse(result.Response.Content.Text()), nil
1024}
1025
1026// updateParentSessionCost accumulates the cost from a child session to its parent session.
1027func (c *coordinator) updateParentSessionCost(ctx context.Context, childSessionID, parentSessionID string) error {
1028	childSession, err := c.sessions.Get(ctx, childSessionID)
1029	if err != nil {
1030		return fmt.Errorf("get child session: %w", err)
1031	}
1032
1033	parentSession, err := c.sessions.Get(ctx, parentSessionID)
1034	if err != nil {
1035		return fmt.Errorf("get parent session: %w", err)
1036	}
1037
1038	parentSession.Cost += childSession.Cost
1039
1040	if _, err := c.sessions.Save(ctx, parentSession); err != nil {
1041		return fmt.Errorf("save parent session: %w", err)
1042	}
1043
1044	return nil
1045}