coordinator.go

   1package agent
   2
   3import (
   4	"bytes"
   5	"cmp"
   6	"context"
   7	"encoding/json"
   8	"errors"
   9	"fmt"
  10	"io"
  11	"log/slog"
  12	"maps"
  13	"net/http"
  14	"os"
  15	"slices"
  16	"strings"
  17
  18	"charm.land/catwalk/pkg/catwalk"
  19	"charm.land/fantasy"
  20	"github.com/charmbracelet/crush/internal/agent/hyper"
  21	"github.com/charmbracelet/crush/internal/agent/prompt"
  22	"github.com/charmbracelet/crush/internal/agent/tools"
  23	"github.com/charmbracelet/crush/internal/config"
  24	"github.com/charmbracelet/crush/internal/filetracker"
  25	"github.com/charmbracelet/crush/internal/history"
  26	"github.com/charmbracelet/crush/internal/log"
  27	"github.com/charmbracelet/crush/internal/lsp"
  28	"github.com/charmbracelet/crush/internal/message"
  29	"github.com/charmbracelet/crush/internal/oauth/copilot"
  30	"github.com/charmbracelet/crush/internal/permission"
  31	"github.com/charmbracelet/crush/internal/session"
  32	"golang.org/x/sync/errgroup"
  33
  34	"charm.land/fantasy/providers/anthropic"
  35	"charm.land/fantasy/providers/azure"
  36	"charm.land/fantasy/providers/bedrock"
  37	"charm.land/fantasy/providers/google"
  38	"charm.land/fantasy/providers/openai"
  39	"charm.land/fantasy/providers/openaicompat"
  40	"charm.land/fantasy/providers/openrouter"
  41	"charm.land/fantasy/providers/vercel"
  42	openaisdk "github.com/openai/openai-go/v2/option"
  43	"github.com/qjebbs/go-jsons"
  44)
  45
  46type Coordinator interface {
  47	// INFO: (kujtim) this is not used yet we will use this when we have multiple agents
  48	// SetMainAgent(string)
  49	Run(ctx context.Context, sessionID, prompt string, verbose bool, attachments ...message.Attachment) (*fantasy.AgentResult, error)
  50	Cancel(sessionID string)
  51	CancelAll()
  52	IsSessionBusy(sessionID string) bool
  53	IsBusy() bool
  54	QueuedPrompts(sessionID string) int
  55	QueuedPromptsList(sessionID string) []string
  56	ClearQueue(sessionID string)
  57	Summarize(context.Context, string) error
  58	Model() Model
  59	UpdateModels(ctx context.Context) error
  60}
  61
  62type coordinator struct {
  63	cfg         *config.Config
  64	sessions    session.Service
  65	messages    message.Service
  66	permissions permission.Service
  67	history     history.Service
  68	filetracker filetracker.Service
  69	lspManager  *lsp.Manager
  70
  71	currentAgent SessionAgent
  72	agents       map[string]SessionAgent
  73
  74	readyWg errgroup.Group
  75}
  76
  77func NewCoordinator(
  78	ctx context.Context,
  79	cfg *config.Config,
  80	sessions session.Service,
  81	messages message.Service,
  82	permissions permission.Service,
  83	history history.Service,
  84	filetracker filetracker.Service,
  85	lspManager *lsp.Manager,
  86) (Coordinator, error) {
  87	c := &coordinator{
  88		cfg:         cfg,
  89		sessions:    sessions,
  90		messages:    messages,
  91		permissions: permissions,
  92		history:     history,
  93		filetracker: filetracker,
  94		lspManager:  lspManager,
  95		agents:      make(map[string]SessionAgent),
  96	}
  97
  98	agentCfg, ok := cfg.Agents[config.AgentCoder]
  99	if !ok {
 100		return nil, errors.New("coder agent not configured")
 101	}
 102
 103	// TODO: make this dynamic when we support multiple agents
 104	prompt, err := coderPrompt(prompt.WithWorkingDir(c.cfg.WorkingDir()))
 105	if err != nil {
 106		return nil, err
 107	}
 108
 109	agent, err := c.buildAgent(ctx, prompt, agentCfg, false)
 110	if err != nil {
 111		return nil, err
 112	}
 113	c.currentAgent = agent
 114	c.agents[config.AgentCoder] = agent
 115	return c, nil
 116}
 117
 118// Run implements Coordinator.
 119func (c *coordinator) Run(ctx context.Context, sessionID string, prompt string, verbose bool, attachments ...message.Attachment) (*fantasy.AgentResult, error) {
 120	if err := c.readyWg.Wait(); err != nil {
 121		return nil, err
 122	}
 123
 124	// refresh models before each run
 125	if err := c.UpdateModels(ctx); err != nil {
 126		return nil, fmt.Errorf("failed to update models: %w", err)
 127	}
 128
 129	model := c.currentAgent.Model()
 130	maxTokens := model.CatwalkCfg.DefaultMaxTokens
 131	if model.ModelCfg.MaxTokens != 0 {
 132		maxTokens = model.ModelCfg.MaxTokens
 133	}
 134
 135	if !model.CatwalkCfg.SupportsImages && attachments != nil {
 136		// filter out image attachments
 137		filteredAttachments := make([]message.Attachment, 0, len(attachments))
 138		for _, att := range attachments {
 139			if att.IsText() {
 140				filteredAttachments = append(filteredAttachments, att)
 141			}
 142		}
 143		attachments = filteredAttachments
 144	}
 145
 146	providerCfg, ok := c.cfg.Providers.Get(model.ModelCfg.Provider)
 147	if !ok {
 148		return nil, errors.New("model provider not configured")
 149	}
 150
 151	mergedOptions, temp, topP, topK, freqPenalty, presPenalty := mergeCallOptions(model, providerCfg)
 152
 153	if providerCfg.OAuthToken != nil && providerCfg.OAuthToken.IsExpired() {
 154		slog.Debug("Token needs to be refreshed", "provider", providerCfg.ID)
 155		if err := c.refreshOAuth2Token(ctx, providerCfg); err != nil {
 156			return nil, err
 157		}
 158	}
 159
 160	run := func() (*fantasy.AgentResult, error) {
 161		return c.currentAgent.Run(ctx, SessionAgentCall{
 162			SessionID:        sessionID,
 163			Prompt:           prompt,
 164			Attachments:      attachments,
 165			MaxOutputTokens:  maxTokens,
 166			ShowToolCalls:    verbose,
 167			ProviderOptions:  mergedOptions,
 168			Temperature:      temp,
 169			TopP:             topP,
 170			TopK:             topK,
 171			FrequencyPenalty: freqPenalty,
 172			PresencePenalty:  presPenalty,
 173		})
 174	}
 175	result, originalErr := run()
 176
 177	if c.isUnauthorized(originalErr) {
 178		switch {
 179		case providerCfg.OAuthToken != nil:
 180			slog.Debug("Received 401. Refreshing token and retrying", "provider", providerCfg.ID)
 181			if err := c.refreshOAuth2Token(ctx, providerCfg); err != nil {
 182				return nil, originalErr
 183			}
 184			slog.Debug("Retrying request with refreshed OAuth token", "provider", providerCfg.ID)
 185			return run()
 186		case strings.Contains(providerCfg.APIKeyTemplate, "$"):
 187			slog.Debug("Received 401. Refreshing API Key template and retrying", "provider", providerCfg.ID)
 188			if err := c.refreshApiKeyTemplate(ctx, providerCfg); err != nil {
 189				return nil, originalErr
 190			}
 191			slog.Debug("Retrying request with refreshed API key", "provider", providerCfg.ID)
 192			return run()
 193		}
 194	}
 195
 196	return result, originalErr
 197}
 198
 199func getProviderOptions(model Model, providerCfg config.ProviderConfig) fantasy.ProviderOptions {
 200	options := fantasy.ProviderOptions{}
 201
 202	cfgOpts := []byte("{}")
 203	providerCfgOpts := []byte("{}")
 204	catwalkOpts := []byte("{}")
 205
 206	if model.ModelCfg.ProviderOptions != nil {
 207		data, err := json.Marshal(model.ModelCfg.ProviderOptions)
 208		if err == nil {
 209			cfgOpts = data
 210		}
 211	}
 212
 213	if providerCfg.ProviderOptions != nil {
 214		data, err := json.Marshal(providerCfg.ProviderOptions)
 215		if err == nil {
 216			providerCfgOpts = data
 217		}
 218	}
 219
 220	if model.CatwalkCfg.Options.ProviderOptions != nil {
 221		data, err := json.Marshal(model.CatwalkCfg.Options.ProviderOptions)
 222		if err == nil {
 223			catwalkOpts = data
 224		}
 225	}
 226
 227	readers := []io.Reader{
 228		bytes.NewReader(catwalkOpts),
 229		bytes.NewReader(providerCfgOpts),
 230		bytes.NewReader(cfgOpts),
 231	}
 232
 233	got, err := jsons.Merge(readers)
 234	if err != nil {
 235		slog.Error("Could not merge call config", "err", err)
 236		return options
 237	}
 238
 239	mergedOptions := make(map[string]any)
 240
 241	err = json.Unmarshal([]byte(got), &mergedOptions)
 242	if err != nil {
 243		slog.Error("Could not create config for call", "err", err)
 244		return options
 245	}
 246
 247	providerType := providerCfg.Type
 248	if providerType == "hyper" {
 249		if strings.Contains(model.CatwalkCfg.ID, "claude") {
 250			providerType = anthropic.Name
 251		} else if strings.Contains(model.CatwalkCfg.ID, "gpt") {
 252			providerType = openai.Name
 253		} else if strings.Contains(model.CatwalkCfg.ID, "gemini") {
 254			providerType = google.Name
 255		} else {
 256			providerType = openaicompat.Name
 257		}
 258	}
 259
 260	switch providerType {
 261	case openai.Name, azure.Name:
 262		_, hasReasoningEffort := mergedOptions["reasoning_effort"]
 263		if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
 264			mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort
 265		}
 266		if openai.IsResponsesModel(model.CatwalkCfg.ID) {
 267			if openai.IsResponsesReasoningModel(model.CatwalkCfg.ID) {
 268				mergedOptions["reasoning_summary"] = "auto"
 269				mergedOptions["include"] = []openai.IncludeType{openai.IncludeReasoningEncryptedContent}
 270			}
 271			parsed, err := openai.ParseResponsesOptions(mergedOptions)
 272			if err == nil {
 273				options[openai.Name] = parsed
 274			}
 275		} else {
 276			parsed, err := openai.ParseOptions(mergedOptions)
 277			if err == nil {
 278				options[openai.Name] = parsed
 279			}
 280		}
 281	case anthropic.Name:
 282		var (
 283			_, hasEffort = mergedOptions["effort"]
 284			_, hasThink  = mergedOptions["thinking"]
 285		)
 286		switch {
 287		case !hasEffort && model.ModelCfg.ReasoningEffort != "":
 288			mergedOptions["effort"] = model.ModelCfg.ReasoningEffort
 289		case !hasThink && model.ModelCfg.Think:
 290			mergedOptions["thinking"] = map[string]any{"budget_tokens": 2000}
 291		}
 292		parsed, err := anthropic.ParseOptions(mergedOptions)
 293		if err == nil {
 294			options[anthropic.Name] = parsed
 295		}
 296
 297	case openrouter.Name:
 298		_, hasReasoning := mergedOptions["reasoning"]
 299		if !hasReasoning && model.ModelCfg.ReasoningEffort != "" {
 300			mergedOptions["reasoning"] = map[string]any{
 301				"enabled": true,
 302				"effort":  model.ModelCfg.ReasoningEffort,
 303			}
 304		}
 305		parsed, err := openrouter.ParseOptions(mergedOptions)
 306		if err == nil {
 307			options[openrouter.Name] = parsed
 308		}
 309	case vercel.Name:
 310		_, hasReasoning := mergedOptions["reasoning"]
 311		if !hasReasoning && model.ModelCfg.ReasoningEffort != "" {
 312			mergedOptions["reasoning"] = map[string]any{
 313				"enabled": true,
 314				"effort":  model.ModelCfg.ReasoningEffort,
 315			}
 316		}
 317		parsed, err := vercel.ParseOptions(mergedOptions)
 318		if err == nil {
 319			options[vercel.Name] = parsed
 320		}
 321	case google.Name:
 322		_, hasReasoning := mergedOptions["thinking_config"]
 323		if !hasReasoning {
 324			if strings.HasPrefix(model.CatwalkCfg.ID, "gemini-2") {
 325				mergedOptions["thinking_config"] = map[string]any{
 326					"thinking_budget":  2000,
 327					"include_thoughts": true,
 328				}
 329			} else {
 330				mergedOptions["thinking_config"] = map[string]any{
 331					"thinking_level":   model.ModelCfg.ReasoningEffort,
 332					"include_thoughts": true,
 333				}
 334			}
 335		}
 336		parsed, err := google.ParseOptions(mergedOptions)
 337		if err == nil {
 338			options[google.Name] = parsed
 339		}
 340	case openaicompat.Name:
 341		_, hasReasoningEffort := mergedOptions["reasoning_effort"]
 342		if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
 343			mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort
 344		}
 345		parsed, err := openaicompat.ParseOptions(mergedOptions)
 346		if err == nil {
 347			options[openaicompat.Name] = parsed
 348		}
 349	}
 350
 351	return options
 352}
 353
 354func mergeCallOptions(model Model, cfg config.ProviderConfig) (fantasy.ProviderOptions, *float64, *float64, *int64, *float64, *float64) {
 355	modelOptions := getProviderOptions(model, cfg)
 356	temp := cmp.Or(model.ModelCfg.Temperature, model.CatwalkCfg.Options.Temperature)
 357	topP := cmp.Or(model.ModelCfg.TopP, model.CatwalkCfg.Options.TopP)
 358	topK := cmp.Or(model.ModelCfg.TopK, model.CatwalkCfg.Options.TopK)
 359	freqPenalty := cmp.Or(model.ModelCfg.FrequencyPenalty, model.CatwalkCfg.Options.FrequencyPenalty)
 360	presPenalty := cmp.Or(model.ModelCfg.PresencePenalty, model.CatwalkCfg.Options.PresencePenalty)
 361	return modelOptions, temp, topP, topK, freqPenalty, presPenalty
 362}
 363
 364func (c *coordinator) buildAgent(ctx context.Context, prompt *prompt.Prompt, agent config.Agent, isSubAgent bool) (SessionAgent, error) {
 365	large, small, err := c.buildAgentModels(ctx, isSubAgent)
 366	if err != nil {
 367		return nil, err
 368	}
 369
 370	largeProviderCfg, _ := c.cfg.Providers.Get(large.ModelCfg.Provider)
 371	result := NewSessionAgent(SessionAgentOptions{
 372		large,
 373		small,
 374		largeProviderCfg.SystemPromptPrefix,
 375		"",
 376		isSubAgent,
 377		c.cfg.Options.DisableAutoSummarize,
 378		c.permissions.SkipRequests(),
 379		c.sessions,
 380		c.messages,
 381		nil,
 382	})
 383
 384	c.readyWg.Go(func() error {
 385		systemPrompt, err := prompt.Build(ctx, large.Model.Provider(), large.Model.Model(), *c.cfg)
 386		if err != nil {
 387			return err
 388		}
 389		result.SetSystemPrompt(systemPrompt)
 390		return nil
 391	})
 392
 393	c.readyWg.Go(func() error {
 394		tools, err := c.buildTools(ctx, agent)
 395		if err != nil {
 396			return err
 397		}
 398		result.SetTools(tools)
 399		return nil
 400	})
 401
 402	return result, nil
 403}
 404
 405func (c *coordinator) buildTools(ctx context.Context, agent config.Agent) ([]fantasy.AgentTool, error) {
 406	var allTools []fantasy.AgentTool
 407	if slices.Contains(agent.AllowedTools, AgentToolName) {
 408		agentTool, err := c.agentTool(ctx)
 409		if err != nil {
 410			return nil, err
 411		}
 412		allTools = append(allTools, agentTool)
 413	}
 414
 415	if slices.Contains(agent.AllowedTools, tools.AgenticFetchToolName) {
 416		agenticFetchTool, err := c.agenticFetchTool(ctx, nil)
 417		if err != nil {
 418			return nil, err
 419		}
 420		allTools = append(allTools, agenticFetchTool)
 421	}
 422
 423	// Get the model name for the agent
 424	modelName := ""
 425	if modelCfg, ok := c.cfg.Models[agent.Model]; ok {
 426		if model := c.cfg.GetModel(modelCfg.Provider, modelCfg.Model); model != nil {
 427			modelName = model.Name
 428		}
 429	}
 430
 431	allTools = append(allTools,
 432		tools.NewBashTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Options.Attribution, modelName),
 433		tools.NewJobOutputTool(),
 434		tools.NewJobKillTool(),
 435		tools.NewDownloadTool(c.permissions, c.cfg.WorkingDir(), nil),
 436		tools.NewEditTool(c.lspManager, c.permissions, c.history, c.filetracker, c.cfg.WorkingDir()),
 437		tools.NewMultiEditTool(c.lspManager, c.permissions, c.history, c.filetracker, c.cfg.WorkingDir()),
 438		tools.NewFetchTool(c.permissions, c.cfg.WorkingDir(), nil),
 439		tools.NewGlobTool(c.cfg.WorkingDir()),
 440		tools.NewGrepTool(c.cfg.WorkingDir(), c.cfg.Tools.Grep),
 441		tools.NewLsTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Tools.Ls),
 442		tools.NewSourcegraphTool(nil),
 443		tools.NewTodosTool(c.sessions),
 444		tools.NewViewTool(c.lspManager, c.permissions, c.filetracker, c.cfg.WorkingDir(), c.cfg.Options.SkillsPaths...),
 445		tools.NewWriteTool(c.lspManager, c.permissions, c.history, c.filetracker, c.cfg.WorkingDir()),
 446	)
 447
 448	// Add LSP tools if user has configured LSPs or auto_lsp is enabled (nil or true).
 449	if len(c.cfg.LSP) > 0 || c.cfg.Options.AutoLSP == nil || *c.cfg.Options.AutoLSP {
 450		allTools = append(allTools, tools.NewDiagnosticsTool(c.lspManager), tools.NewReferencesTool(c.lspManager), tools.NewLSPRestartTool(c.lspManager))
 451	}
 452
 453	if len(c.cfg.MCP) > 0 {
 454		allTools = append(
 455			allTools,
 456			tools.NewListMCPResourcesTool(c.cfg, c.permissions),
 457			tools.NewReadMCPResourceTool(c.cfg, c.permissions),
 458		)
 459	}
 460
 461	var filteredTools []fantasy.AgentTool
 462	for _, tool := range allTools {
 463		if slices.Contains(agent.AllowedTools, tool.Info().Name) {
 464			filteredTools = append(filteredTools, tool)
 465		}
 466	}
 467
 468	for _, tool := range tools.GetMCPTools(c.permissions, c.cfg, c.cfg.WorkingDir()) {
 469		if agent.AllowedMCP == nil {
 470			// No MCP restrictions
 471			filteredTools = append(filteredTools, tool)
 472			continue
 473		}
 474		if len(agent.AllowedMCP) == 0 {
 475			// No MCPs allowed
 476			slog.Debug("No MCPs allowed", "tool", tool.Name(), "agent", agent.Name)
 477			break
 478		}
 479
 480		for mcp, tools := range agent.AllowedMCP {
 481			if mcp != tool.MCP() {
 482				continue
 483			}
 484			if len(tools) == 0 || slices.Contains(tools, tool.MCPToolName()) {
 485				filteredTools = append(filteredTools, tool)
 486				break
 487			}
 488			slog.Debug("MCP not allowed", "tool", tool.Name(), "agent", agent.Name)
 489		}
 490	}
 491	slices.SortFunc(filteredTools, func(a, b fantasy.AgentTool) int {
 492		return strings.Compare(a.Info().Name, b.Info().Name)
 493	})
 494	return filteredTools, nil
 495}
 496
 497// TODO: when we support multiple agents we need to change this so that we pass in the agent specific model config
 498func (c *coordinator) buildAgentModels(ctx context.Context, isSubAgent bool) (Model, Model, error) {
 499	largeModelCfg, ok := c.cfg.Models[config.SelectedModelTypeLarge]
 500	if !ok {
 501		return Model{}, Model{}, errors.New("large model not selected")
 502	}
 503	smallModelCfg, ok := c.cfg.Models[config.SelectedModelTypeSmall]
 504	if !ok {
 505		return Model{}, Model{}, errors.New("small model not selected")
 506	}
 507
 508	largeProviderCfg, ok := c.cfg.Providers.Get(largeModelCfg.Provider)
 509	if !ok {
 510		return Model{}, Model{}, errors.New("large model provider not configured")
 511	}
 512
 513	largeProvider, err := c.buildProvider(largeProviderCfg, largeModelCfg, isSubAgent)
 514	if err != nil {
 515		return Model{}, Model{}, err
 516	}
 517
 518	smallProviderCfg, ok := c.cfg.Providers.Get(smallModelCfg.Provider)
 519	if !ok {
 520		return Model{}, Model{}, errors.New("small model provider not configured")
 521	}
 522
 523	smallProvider, err := c.buildProvider(smallProviderCfg, smallModelCfg, true)
 524	if err != nil {
 525		return Model{}, Model{}, err
 526	}
 527
 528	var largeCatwalkModel *catwalk.Model
 529	var smallCatwalkModel *catwalk.Model
 530
 531	for _, m := range largeProviderCfg.Models {
 532		if m.ID == largeModelCfg.Model {
 533			largeCatwalkModel = &m
 534		}
 535	}
 536	for _, m := range smallProviderCfg.Models {
 537		if m.ID == smallModelCfg.Model {
 538			smallCatwalkModel = &m
 539		}
 540	}
 541
 542	if largeCatwalkModel == nil {
 543		return Model{}, Model{}, errors.New("large model not found in provider config")
 544	}
 545
 546	if smallCatwalkModel == nil {
 547		return Model{}, Model{}, errors.New("small model not found in provider config")
 548	}
 549
 550	largeModelID := largeModelCfg.Model
 551	smallModelID := smallModelCfg.Model
 552
 553	if largeModelCfg.Provider == openrouter.Name && isExactoSupported(largeModelID) {
 554		largeModelID += ":exacto"
 555	}
 556
 557	if smallModelCfg.Provider == openrouter.Name && isExactoSupported(smallModelID) {
 558		smallModelID += ":exacto"
 559	}
 560
 561	largeModel, err := largeProvider.LanguageModel(ctx, largeModelID)
 562	if err != nil {
 563		return Model{}, Model{}, err
 564	}
 565	smallModel, err := smallProvider.LanguageModel(ctx, smallModelID)
 566	if err != nil {
 567		return Model{}, Model{}, err
 568	}
 569
 570	return Model{
 571			Model:      largeModel,
 572			CatwalkCfg: *largeCatwalkModel,
 573			ModelCfg:   largeModelCfg,
 574		}, Model{
 575			Model:      smallModel,
 576			CatwalkCfg: *smallCatwalkModel,
 577			ModelCfg:   smallModelCfg,
 578		}, nil
 579}
 580
 581func (c *coordinator) buildAnthropicProvider(baseURL, apiKey string, headers map[string]string, providerID string) (fantasy.Provider, error) {
 582	var opts []anthropic.Option
 583
 584	switch {
 585	case strings.HasPrefix(apiKey, "Bearer "):
 586		// NOTE: Prevent the SDK from picking up the API key from env.
 587		os.Setenv("ANTHROPIC_API_KEY", "")
 588		headers["Authorization"] = apiKey
 589	case providerID == string(catwalk.InferenceProviderMiniMax) || providerID == string(catwalk.InferenceProviderMiniMaxChina):
 590		// NOTE: Prevent the SDK from picking up the API key from env.
 591		os.Setenv("ANTHROPIC_API_KEY", "")
 592		headers["Authorization"] = "Bearer " + apiKey
 593	case apiKey != "":
 594		// X-Api-Key header
 595		opts = append(opts, anthropic.WithAPIKey(apiKey))
 596	}
 597
 598	if len(headers) > 0 {
 599		opts = append(opts, anthropic.WithHeaders(headers))
 600	}
 601
 602	if baseURL != "" {
 603		opts = append(opts, anthropic.WithBaseURL(baseURL))
 604	}
 605
 606	if c.cfg.Options.Debug {
 607		httpClient := log.NewHTTPClient()
 608		opts = append(opts, anthropic.WithHTTPClient(httpClient))
 609	}
 610	return anthropic.New(opts...)
 611}
 612
 613func (c *coordinator) buildOpenaiProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
 614	opts := []openai.Option{
 615		openai.WithAPIKey(apiKey),
 616		openai.WithUseResponsesAPI(),
 617	}
 618	if c.cfg.Options.Debug {
 619		httpClient := log.NewHTTPClient()
 620		opts = append(opts, openai.WithHTTPClient(httpClient))
 621	}
 622	if len(headers) > 0 {
 623		opts = append(opts, openai.WithHeaders(headers))
 624	}
 625	if baseURL != "" {
 626		opts = append(opts, openai.WithBaseURL(baseURL))
 627	}
 628	return openai.New(opts...)
 629}
 630
 631func (c *coordinator) buildOpenrouterProvider(_, apiKey string, headers map[string]string) (fantasy.Provider, error) {
 632	opts := []openrouter.Option{
 633		openrouter.WithAPIKey(apiKey),
 634	}
 635	if c.cfg.Options.Debug {
 636		httpClient := log.NewHTTPClient()
 637		opts = append(opts, openrouter.WithHTTPClient(httpClient))
 638	}
 639	if len(headers) > 0 {
 640		opts = append(opts, openrouter.WithHeaders(headers))
 641	}
 642	return openrouter.New(opts...)
 643}
 644
 645func (c *coordinator) buildVercelProvider(_, apiKey string, headers map[string]string) (fantasy.Provider, error) {
 646	opts := []vercel.Option{
 647		vercel.WithAPIKey(apiKey),
 648	}
 649	if c.cfg.Options.Debug {
 650		httpClient := log.NewHTTPClient()
 651		opts = append(opts, vercel.WithHTTPClient(httpClient))
 652	}
 653	if len(headers) > 0 {
 654		opts = append(opts, vercel.WithHeaders(headers))
 655	}
 656	return vercel.New(opts...)
 657}
 658
 659func (c *coordinator) buildOpenaiCompatProvider(baseURL, apiKey string, headers map[string]string, extraBody map[string]any, providerID string, isSubAgent bool) (fantasy.Provider, error) {
 660	opts := []openaicompat.Option{
 661		openaicompat.WithBaseURL(baseURL),
 662		openaicompat.WithAPIKey(apiKey),
 663	}
 664
 665	// Set HTTP client based on provider and debug mode.
 666	var httpClient *http.Client
 667	if providerID == string(catwalk.InferenceProviderCopilot) {
 668		opts = append(opts, openaicompat.WithUseResponsesAPI())
 669		httpClient = copilot.NewClient(isSubAgent, c.cfg.Options.Debug)
 670	} else if c.cfg.Options.Debug {
 671		httpClient = log.NewHTTPClient()
 672	}
 673	if httpClient != nil {
 674		opts = append(opts, openaicompat.WithHTTPClient(httpClient))
 675	}
 676
 677	if len(headers) > 0 {
 678		opts = append(opts, openaicompat.WithHeaders(headers))
 679	}
 680
 681	for extraKey, extraValue := range extraBody {
 682		opts = append(opts, openaicompat.WithSDKOptions(openaisdk.WithJSONSet(extraKey, extraValue)))
 683	}
 684
 685	return openaicompat.New(opts...)
 686}
 687
 688func (c *coordinator) buildAzureProvider(baseURL, apiKey string, headers map[string]string, options map[string]string) (fantasy.Provider, error) {
 689	opts := []azure.Option{
 690		azure.WithBaseURL(baseURL),
 691		azure.WithAPIKey(apiKey),
 692		azure.WithUseResponsesAPI(),
 693	}
 694	if c.cfg.Options.Debug {
 695		httpClient := log.NewHTTPClient()
 696		opts = append(opts, azure.WithHTTPClient(httpClient))
 697	}
 698	if options == nil {
 699		options = make(map[string]string)
 700	}
 701	if apiVersion, ok := options["apiVersion"]; ok {
 702		opts = append(opts, azure.WithAPIVersion(apiVersion))
 703	}
 704	if len(headers) > 0 {
 705		opts = append(opts, azure.WithHeaders(headers))
 706	}
 707
 708	return azure.New(opts...)
 709}
 710
 711func (c *coordinator) buildBedrockProvider(headers map[string]string) (fantasy.Provider, error) {
 712	var opts []bedrock.Option
 713	if c.cfg.Options.Debug {
 714		httpClient := log.NewHTTPClient()
 715		opts = append(opts, bedrock.WithHTTPClient(httpClient))
 716	}
 717	if len(headers) > 0 {
 718		opts = append(opts, bedrock.WithHeaders(headers))
 719	}
 720	bearerToken := os.Getenv("AWS_BEARER_TOKEN_BEDROCK")
 721	if bearerToken != "" {
 722		opts = append(opts, bedrock.WithAPIKey(bearerToken))
 723	}
 724	return bedrock.New(opts...)
 725}
 726
 727func (c *coordinator) buildGoogleProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
 728	opts := []google.Option{
 729		google.WithBaseURL(baseURL),
 730		google.WithGeminiAPIKey(apiKey),
 731	}
 732	if c.cfg.Options.Debug {
 733		httpClient := log.NewHTTPClient()
 734		opts = append(opts, google.WithHTTPClient(httpClient))
 735	}
 736	if len(headers) > 0 {
 737		opts = append(opts, google.WithHeaders(headers))
 738	}
 739	return google.New(opts...)
 740}
 741
 742func (c *coordinator) buildGoogleVertexProvider(headers map[string]string, options map[string]string) (fantasy.Provider, error) {
 743	opts := []google.Option{}
 744	if c.cfg.Options.Debug {
 745		httpClient := log.NewHTTPClient()
 746		opts = append(opts, google.WithHTTPClient(httpClient))
 747	}
 748	if len(headers) > 0 {
 749		opts = append(opts, google.WithHeaders(headers))
 750	}
 751
 752	project := options["project"]
 753	location := options["location"]
 754
 755	opts = append(opts, google.WithVertex(project, location))
 756
 757	return google.New(opts...)
 758}
 759
 760func (c *coordinator) buildHyperProvider(baseURL, apiKey string) (fantasy.Provider, error) {
 761	opts := []hyper.Option{
 762		hyper.WithBaseURL(baseURL),
 763		hyper.WithAPIKey(apiKey),
 764	}
 765	if c.cfg.Options.Debug {
 766		httpClient := log.NewHTTPClient()
 767		opts = append(opts, hyper.WithHTTPClient(httpClient))
 768	}
 769	return hyper.New(opts...)
 770}
 771
 772func (c *coordinator) isAnthropicThinking(model config.SelectedModel) bool {
 773	if model.Think {
 774		return true
 775	}
 776
 777	if model.ProviderOptions == nil {
 778		return false
 779	}
 780
 781	opts, err := anthropic.ParseOptions(model.ProviderOptions)
 782	if err != nil {
 783		return false
 784	}
 785	if opts.Thinking != nil {
 786		return true
 787	}
 788	return false
 789}
 790
 791func (c *coordinator) buildProvider(providerCfg config.ProviderConfig, model config.SelectedModel, isSubAgent bool) (fantasy.Provider, error) {
 792	headers := maps.Clone(providerCfg.ExtraHeaders)
 793	if headers == nil {
 794		headers = make(map[string]string)
 795	}
 796
 797	// handle special headers for anthropic
 798	if providerCfg.Type == anthropic.Name && c.isAnthropicThinking(model) {
 799		if v, ok := headers["anthropic-beta"]; ok {
 800			headers["anthropic-beta"] = v + ",interleaved-thinking-2025-05-14"
 801		} else {
 802			headers["anthropic-beta"] = "interleaved-thinking-2025-05-14"
 803		}
 804	}
 805
 806	apiKey, _ := c.cfg.Resolve(providerCfg.APIKey)
 807	baseURL, _ := c.cfg.Resolve(providerCfg.BaseURL)
 808
 809	switch providerCfg.Type {
 810	case openai.Name:
 811		return c.buildOpenaiProvider(baseURL, apiKey, headers)
 812	case anthropic.Name:
 813		return c.buildAnthropicProvider(baseURL, apiKey, headers, providerCfg.ID)
 814	case openrouter.Name:
 815		return c.buildOpenrouterProvider(baseURL, apiKey, headers)
 816	case vercel.Name:
 817		return c.buildVercelProvider(baseURL, apiKey, headers)
 818	case azure.Name:
 819		return c.buildAzureProvider(baseURL, apiKey, headers, providerCfg.ExtraParams)
 820	case bedrock.Name:
 821		return c.buildBedrockProvider(headers)
 822	case google.Name:
 823		return c.buildGoogleProvider(baseURL, apiKey, headers)
 824	case "google-vertex":
 825		return c.buildGoogleVertexProvider(headers, providerCfg.ExtraParams)
 826	case openaicompat.Name:
 827		if providerCfg.ID == string(catwalk.InferenceProviderZAI) {
 828			if providerCfg.ExtraBody == nil {
 829				providerCfg.ExtraBody = map[string]any{}
 830			}
 831			providerCfg.ExtraBody["tool_stream"] = true
 832		}
 833		return c.buildOpenaiCompatProvider(baseURL, apiKey, headers, providerCfg.ExtraBody, providerCfg.ID, isSubAgent)
 834	case hyper.Name:
 835		return c.buildHyperProvider(baseURL, apiKey)
 836	default:
 837		return nil, fmt.Errorf("provider type not supported: %q", providerCfg.Type)
 838	}
 839}
 840
 841func isExactoSupported(modelID string) bool {
 842	supportedModels := []string{
 843		"moonshotai/kimi-k2-0905",
 844		"deepseek/deepseek-v3.1-terminus",
 845		"z-ai/glm-4.6",
 846		"openai/gpt-oss-120b",
 847		"qwen/qwen3-coder",
 848	}
 849	return slices.Contains(supportedModels, modelID)
 850}
 851
 852func (c *coordinator) Cancel(sessionID string) {
 853	c.currentAgent.Cancel(sessionID)
 854}
 855
 856func (c *coordinator) CancelAll() {
 857	c.currentAgent.CancelAll()
 858}
 859
 860func (c *coordinator) ClearQueue(sessionID string) {
 861	c.currentAgent.ClearQueue(sessionID)
 862}
 863
 864func (c *coordinator) IsBusy() bool {
 865	return c.currentAgent.IsBusy()
 866}
 867
 868func (c *coordinator) IsSessionBusy(sessionID string) bool {
 869	return c.currentAgent.IsSessionBusy(sessionID)
 870}
 871
 872func (c *coordinator) Model() Model {
 873	return c.currentAgent.Model()
 874}
 875
 876func (c *coordinator) UpdateModels(ctx context.Context) error {
 877	// build the models again so we make sure we get the latest config
 878	large, small, err := c.buildAgentModels(ctx, false)
 879	if err != nil {
 880		return err
 881	}
 882	c.currentAgent.SetModels(large, small)
 883
 884	agentCfg, ok := c.cfg.Agents[config.AgentCoder]
 885	if !ok {
 886		return errors.New("coder agent not configured")
 887	}
 888
 889	tools, err := c.buildTools(ctx, agentCfg)
 890	if err != nil {
 891		return err
 892	}
 893	c.currentAgent.SetTools(tools)
 894	return nil
 895}
 896
 897func (c *coordinator) QueuedPrompts(sessionID string) int {
 898	return c.currentAgent.QueuedPrompts(sessionID)
 899}
 900
 901func (c *coordinator) QueuedPromptsList(sessionID string) []string {
 902	return c.currentAgent.QueuedPromptsList(sessionID)
 903}
 904
 905func (c *coordinator) Summarize(ctx context.Context, sessionID string) error {
 906	providerCfg, ok := c.cfg.Providers.Get(c.currentAgent.Model().ModelCfg.Provider)
 907	if !ok {
 908		return errors.New("model provider not configured")
 909	}
 910	return c.currentAgent.Summarize(ctx, sessionID, getProviderOptions(c.currentAgent.Model(), providerCfg))
 911}
 912
 913func (c *coordinator) isUnauthorized(err error) bool {
 914	var providerErr *fantasy.ProviderError
 915	return errors.As(err, &providerErr) && providerErr.StatusCode == http.StatusUnauthorized
 916}
 917
 918func (c *coordinator) refreshOAuth2Token(ctx context.Context, providerCfg config.ProviderConfig) error {
 919	if err := c.cfg.RefreshOAuthToken(ctx, providerCfg.ID); err != nil {
 920		slog.Error("Failed to refresh OAuth token after 401 error", "provider", providerCfg.ID, "error", err)
 921		return err
 922	}
 923	if err := c.UpdateModels(ctx); err != nil {
 924		return err
 925	}
 926	return nil
 927}
 928
 929func (c *coordinator) refreshApiKeyTemplate(ctx context.Context, providerCfg config.ProviderConfig) error {
 930	newAPIKey, err := c.cfg.Resolve(providerCfg.APIKeyTemplate)
 931	if err != nil {
 932		slog.Error("Failed to re-resolve API key after 401 error", "provider", providerCfg.ID, "error", err)
 933		return err
 934	}
 935
 936	providerCfg.APIKey = newAPIKey
 937	c.cfg.Providers.Set(providerCfg.ID, providerCfg)
 938
 939	if err := c.UpdateModels(ctx); err != nil {
 940		return err
 941	}
 942	return nil
 943}
 944
 945// subAgentParams holds the parameters for running a sub-agent.
 946type subAgentParams struct {
 947	Agent          SessionAgent
 948	SessionID      string
 949	AgentMessageID string
 950	ToolCallID     string
 951	Prompt         string
 952	SessionTitle   string
 953	// SessionSetup is an optional callback invoked after session creation
 954	// but before agent execution, for custom session configuration.
 955	SessionSetup func(sessionID string)
 956}
 957
 958// runSubAgent runs a sub-agent and handles session management and cost accumulation.
 959// It creates a sub-session, runs the agent with the given prompt, and propagates
 960// the cost to the parent session.
 961func (c *coordinator) runSubAgent(ctx context.Context, params subAgentParams) (fantasy.ToolResponse, error) {
 962	// Create sub-session
 963	agentToolSessionID := c.sessions.CreateAgentToolSessionID(params.AgentMessageID, params.ToolCallID)
 964	session, err := c.sessions.CreateTaskSession(ctx, agentToolSessionID, params.SessionID, params.SessionTitle)
 965	if err != nil {
 966		return fantasy.ToolResponse{}, fmt.Errorf("create session: %w", err)
 967	}
 968
 969	// Call session setup function if provided
 970	if params.SessionSetup != nil {
 971		params.SessionSetup(session.ID)
 972	}
 973
 974	// Get model configuration
 975	model := params.Agent.Model()
 976	maxTokens := model.CatwalkCfg.DefaultMaxTokens
 977	if model.ModelCfg.MaxTokens != 0 {
 978		maxTokens = model.ModelCfg.MaxTokens
 979	}
 980
 981	providerCfg, ok := c.cfg.Providers.Get(model.ModelCfg.Provider)
 982	if !ok {
 983		return fantasy.ToolResponse{}, errors.New("model provider not configured")
 984	}
 985
 986	// Run the agent
 987	result, err := params.Agent.Run(ctx, SessionAgentCall{
 988		SessionID:        session.ID,
 989		Prompt:           params.Prompt,
 990		MaxOutputTokens:  maxTokens,
 991		ProviderOptions:  getProviderOptions(model, providerCfg),
 992		Temperature:      model.ModelCfg.Temperature,
 993		TopP:             model.ModelCfg.TopP,
 994		TopK:             model.ModelCfg.TopK,
 995		FrequencyPenalty: model.ModelCfg.FrequencyPenalty,
 996		PresencePenalty:  model.ModelCfg.PresencePenalty,
 997	})
 998	if err != nil {
 999		return fantasy.NewTextErrorResponse("error generating response"), nil
1000	}
1001
1002	// Update parent session cost
1003	if err := c.updateParentSessionCost(ctx, session.ID, params.SessionID); err != nil {
1004		return fantasy.ToolResponse{}, err
1005	}
1006
1007	return fantasy.NewTextResponse(result.Response.Content.Text()), nil
1008}
1009
1010// updateParentSessionCost accumulates the cost from a child session to its parent session.
1011func (c *coordinator) updateParentSessionCost(ctx context.Context, childSessionID, parentSessionID string) error {
1012	childSession, err := c.sessions.Get(ctx, childSessionID)
1013	if err != nil {
1014		return fmt.Errorf("get child session: %w", err)
1015	}
1016
1017	parentSession, err := c.sessions.Get(ctx, parentSessionID)
1018	if err != nil {
1019		return fmt.Errorf("get parent session: %w", err)
1020	}
1021
1022	parentSession.Cost += childSession.Cost
1023
1024	if _, err := c.sessions.Save(ctx, parentSession); err != nil {
1025		return fmt.Errorf("save parent session: %w", err)
1026	}
1027
1028	return nil
1029}