1package agent
2
3import (
4 "bytes"
5 "cmp"
6 "context"
7 "encoding/json"
8 "errors"
9 "fmt"
10 "io"
11 "log/slog"
12 "maps"
13 "net/http"
14 "os"
15 "slices"
16 "strings"
17
18 "charm.land/catwalk/pkg/catwalk"
19 "charm.land/fantasy"
20 "github.com/charmbracelet/crush/internal/agent/hyper"
21 "github.com/charmbracelet/crush/internal/agent/notify"
22 "github.com/charmbracelet/crush/internal/agent/prompt"
23 "github.com/charmbracelet/crush/internal/agent/tools"
24 "github.com/charmbracelet/crush/internal/config"
25 "github.com/charmbracelet/crush/internal/filetracker"
26 "github.com/charmbracelet/crush/internal/history"
27 "github.com/charmbracelet/crush/internal/log"
28 "github.com/charmbracelet/crush/internal/lsp"
29 "github.com/charmbracelet/crush/internal/message"
30 "github.com/charmbracelet/crush/internal/oauth/copilot"
31 "github.com/charmbracelet/crush/internal/permission"
32 "github.com/charmbracelet/crush/internal/pubsub"
33 "github.com/charmbracelet/crush/internal/session"
34 "golang.org/x/sync/errgroup"
35
36 "charm.land/fantasy/providers/anthropic"
37 "charm.land/fantasy/providers/azure"
38 "charm.land/fantasy/providers/bedrock"
39 "charm.land/fantasy/providers/google"
40 "charm.land/fantasy/providers/openai"
41 "charm.land/fantasy/providers/openaicompat"
42 "charm.land/fantasy/providers/openrouter"
43 "charm.land/fantasy/providers/vercel"
44 openaisdk "github.com/openai/openai-go/v3/option"
45 "github.com/qjebbs/go-jsons"
46)
47
// Coordinator errors. They are returned while building agents and resolving
// the selected large/small models and their providers, and are compared by
// identity, so always return these values rather than new errors.
var (
	// Agent and generic provider configuration.
	errCoderAgentNotConfigured    = errors.New("coder agent not configured")
	errModelProviderNotConfigured = errors.New("model provider not configured")

	// Large model resolution.
	errLargeModelNotSelected           = errors.New("large model not selected")
	errLargeModelProviderNotConfigured = errors.New("large model provider not configured")
	errLargeModelNotFound              = errors.New("large model not found in provider config")

	// Small model resolution.
	errSmallModelNotSelected           = errors.New("small model not selected")
	errSmallModelProviderNotConfigured = errors.New("small model provider not configured")
	errSmallModelNotFound              = errors.New("small model not found in provider config")
)
59
60type Coordinator interface {
61 // INFO: (kujtim) this is not used yet we will use this when we have multiple agents
62 // SetMainAgent(string)
63 Run(ctx context.Context, sessionID, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error)
64 Cancel(sessionID string)
65 CancelAll()
66 IsSessionBusy(sessionID string) bool
67 IsBusy() bool
68 QueuedPrompts(sessionID string) int
69 QueuedPromptsList(sessionID string) []string
70 ClearQueue(sessionID string)
71 Summarize(context.Context, string) error
72 Model() Model
73 UpdateModels(ctx context.Context) error
74 RefreshTools(ctx context.Context) error
75}
76
// coordinator is the Coordinator implementation used by NewCoordinator. It
// holds the services that agents and their tools are built from, and forwards
// most Coordinator calls to currentAgent.
type coordinator struct {
	cfg         *config.ConfigStore
	sessions    session.Service
	messages    message.Service
	permissions permission.Service
	history     history.Service
	filetracker filetracker.Service
	lspManager  *lsp.Manager
	notify      pubsub.Publisher[notify.Notification]

	// currentAgent receives the forwarded calls (Run, Cancel, IsBusy, ...).
	currentAgent SessionAgent
	// agents indexes built agents by name; NewCoordinator registers
	// config.AgentCoder.
	agents map[string]SessionAgent

	// readyWg tracks the background work started by buildAgent (system
	// prompt and tool construction); Run waits on it before running.
	readyWg errgroup.Group
}
92
93func NewCoordinator(
94 ctx context.Context,
95 cfg *config.ConfigStore,
96 sessions session.Service,
97 messages message.Service,
98 permissions permission.Service,
99 history history.Service,
100 filetracker filetracker.Service,
101 lspManager *lsp.Manager,
102 notify pubsub.Publisher[notify.Notification],
103) (Coordinator, error) {
104 c := &coordinator{
105 cfg: cfg,
106 sessions: sessions,
107 messages: messages,
108 permissions: permissions,
109 history: history,
110 filetracker: filetracker,
111 lspManager: lspManager,
112 notify: notify,
113 agents: make(map[string]SessionAgent),
114 }
115
116 agentCfg, ok := cfg.Config().Agents[config.AgentCoder]
117 if !ok {
118 return nil, errCoderAgentNotConfigured
119 }
120
121 // TODO: make this dynamic when we support multiple agents
122 prompt, err := coderPrompt(prompt.WithWorkingDir(c.cfg.WorkingDir()))
123 if err != nil {
124 return nil, err
125 }
126
127 agent, err := c.buildAgent(ctx, prompt, agentCfg, false)
128 if err != nil {
129 return nil, err
130 }
131 c.currentAgent = agent
132 c.agents[config.AgentCoder] = agent
133 return c, nil
134}
135
136// Run implements Coordinator.
137func (c *coordinator) Run(ctx context.Context, sessionID string, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) {
138 if err := c.readyWg.Wait(); err != nil {
139 return nil, err
140 }
141
142 // refresh models before each run
143 if err := c.UpdateModels(ctx); err != nil {
144 return nil, fmt.Errorf("failed to update models: %w", err)
145 }
146
147 model := c.currentAgent.Model()
148 maxTokens := model.CatwalkCfg.DefaultMaxTokens
149 if model.ModelCfg.MaxTokens != 0 {
150 maxTokens = model.ModelCfg.MaxTokens
151 }
152
153 if !model.CatwalkCfg.SupportsImages && attachments != nil {
154 // filter out image attachments
155 filteredAttachments := make([]message.Attachment, 0, len(attachments))
156 for _, att := range attachments {
157 if att.IsText() {
158 filteredAttachments = append(filteredAttachments, att)
159 }
160 }
161 attachments = filteredAttachments
162 }
163
164 providerCfg, ok := c.cfg.Config().Providers.Get(model.ModelCfg.Provider)
165 if !ok {
166 return nil, errModelProviderNotConfigured
167 }
168
169 mergedOptions, temp, topP, topK, freqPenalty, presPenalty := mergeCallOptions(model, providerCfg)
170
171 if providerCfg.OAuthToken != nil && providerCfg.OAuthToken.IsExpired() {
172 slog.Debug("Token needs to be refreshed", "provider", providerCfg.ID)
173 if err := c.refreshOAuth2Token(ctx, providerCfg); err != nil {
174 return nil, err
175 }
176 }
177
178 run := func() (*fantasy.AgentResult, error) {
179 return c.currentAgent.Run(ctx, SessionAgentCall{
180 SessionID: sessionID,
181 Prompt: prompt,
182 Attachments: attachments,
183 MaxOutputTokens: maxTokens,
184 ProviderOptions: mergedOptions,
185 Temperature: temp,
186 TopP: topP,
187 TopK: topK,
188 FrequencyPenalty: freqPenalty,
189 PresencePenalty: presPenalty,
190 })
191 }
192 result, originalErr := run()
193
194 if c.isUnauthorized(originalErr) {
195 switch {
196 case providerCfg.OAuthToken != nil:
197 slog.Debug("Received 401. Refreshing token and retrying", "provider", providerCfg.ID)
198 if err := c.refreshOAuth2Token(ctx, providerCfg); err != nil {
199 return nil, originalErr
200 }
201 slog.Debug("Retrying request with refreshed OAuth token", "provider", providerCfg.ID)
202 return run()
203 case strings.Contains(providerCfg.APIKeyTemplate, "$"):
204 slog.Debug("Received 401. Refreshing API Key template and retrying", "provider", providerCfg.ID)
205 if err := c.refreshApiKeyTemplate(ctx, providerCfg); err != nil {
206 return nil, originalErr
207 }
208 slog.Debug("Retrying request with refreshed API key", "provider", providerCfg.ID)
209 return run()
210 }
211 }
212
213 return result, originalErr
214}
215
216func getProviderOptions(model Model, providerCfg config.ProviderConfig) fantasy.ProviderOptions {
217 options := fantasy.ProviderOptions{}
218
219 cfgOpts := []byte("{}")
220 providerCfgOpts := []byte("{}")
221 catwalkOpts := []byte("{}")
222
223 if model.ModelCfg.ProviderOptions != nil {
224 data, err := json.Marshal(model.ModelCfg.ProviderOptions)
225 if err == nil {
226 cfgOpts = data
227 }
228 }
229
230 if providerCfg.ProviderOptions != nil {
231 data, err := json.Marshal(providerCfg.ProviderOptions)
232 if err == nil {
233 providerCfgOpts = data
234 }
235 }
236
237 if model.CatwalkCfg.Options.ProviderOptions != nil {
238 data, err := json.Marshal(model.CatwalkCfg.Options.ProviderOptions)
239 if err == nil {
240 catwalkOpts = data
241 }
242 }
243
244 readers := []io.Reader{
245 bytes.NewReader(catwalkOpts),
246 bytes.NewReader(providerCfgOpts),
247 bytes.NewReader(cfgOpts),
248 }
249
250 got, err := jsons.Merge(readers)
251 if err != nil {
252 slog.Error("Could not merge call config", "err", err)
253 return options
254 }
255
256 mergedOptions := make(map[string]any)
257
258 err = json.Unmarshal([]byte(got), &mergedOptions)
259 if err != nil {
260 slog.Error("Could not create config for call", "err", err)
261 return options
262 }
263
264 providerType := providerCfg.Type
265 if providerType == "hyper" {
266 if strings.Contains(model.CatwalkCfg.ID, "claude") {
267 providerType = anthropic.Name
268 } else if strings.Contains(model.CatwalkCfg.ID, "gpt") {
269 providerType = openai.Name
270 } else if strings.Contains(model.CatwalkCfg.ID, "gemini") {
271 providerType = google.Name
272 } else {
273 providerType = openaicompat.Name
274 }
275 }
276
277 switch providerType {
278 case openai.Name, azure.Name:
279 _, hasReasoningEffort := mergedOptions["reasoning_effort"]
280 if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
281 mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort
282 }
283 if openai.IsResponsesModel(model.CatwalkCfg.ID) {
284 if openai.IsResponsesReasoningModel(model.CatwalkCfg.ID) {
285 mergedOptions["reasoning_summary"] = "auto"
286 mergedOptions["include"] = []openai.IncludeType{openai.IncludeReasoningEncryptedContent}
287 }
288 parsed, err := openai.ParseResponsesOptions(mergedOptions)
289 if err == nil {
290 options[openai.Name] = parsed
291 }
292 } else {
293 parsed, err := openai.ParseOptions(mergedOptions)
294 if err == nil {
295 options[openai.Name] = parsed
296 }
297 }
298 case anthropic.Name:
299 var (
300 _, hasEffort = mergedOptions["effort"]
301 _, hasThink = mergedOptions["thinking"]
302 )
303 switch {
304 case !hasEffort && model.ModelCfg.ReasoningEffort != "":
305 mergedOptions["effort"] = model.ModelCfg.ReasoningEffort
306 case !hasThink && model.ModelCfg.Think:
307 mergedOptions["thinking"] = map[string]any{"budget_tokens": 2000}
308 }
309 parsed, err := anthropic.ParseOptions(mergedOptions)
310 if err == nil {
311 options[anthropic.Name] = parsed
312 }
313
314 case openrouter.Name:
315 _, hasReasoning := mergedOptions["reasoning"]
316 if !hasReasoning && model.ModelCfg.ReasoningEffort != "" {
317 mergedOptions["reasoning"] = map[string]any{
318 "enabled": true,
319 "effort": model.ModelCfg.ReasoningEffort,
320 }
321 }
322 parsed, err := openrouter.ParseOptions(mergedOptions)
323 if err == nil {
324 options[openrouter.Name] = parsed
325 }
326 case vercel.Name:
327 _, hasReasoning := mergedOptions["reasoning"]
328 if !hasReasoning && model.ModelCfg.ReasoningEffort != "" {
329 mergedOptions["reasoning"] = map[string]any{
330 "enabled": true,
331 "effort": model.ModelCfg.ReasoningEffort,
332 }
333 }
334 parsed, err := vercel.ParseOptions(mergedOptions)
335 if err == nil {
336 options[vercel.Name] = parsed
337 }
338 case google.Name:
339 _, hasReasoning := mergedOptions["thinking_config"]
340 if !hasReasoning {
341 if strings.HasPrefix(model.CatwalkCfg.ID, "gemini-2") {
342 mergedOptions["thinking_config"] = map[string]any{
343 "thinking_budget": 2000,
344 "include_thoughts": true,
345 }
346 } else {
347 mergedOptions["thinking_config"] = map[string]any{
348 "thinking_level": model.ModelCfg.ReasoningEffort,
349 "include_thoughts": true,
350 }
351 }
352 }
353 parsed, err := google.ParseOptions(mergedOptions)
354 if err == nil {
355 options[google.Name] = parsed
356 }
357 case openaicompat.Name:
358 _, hasReasoningEffort := mergedOptions["reasoning_effort"]
359 if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
360 mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort
361 }
362 parsed, err := openaicompat.ParseOptions(mergedOptions)
363 if err == nil {
364 options[openaicompat.Name] = parsed
365 }
366 }
367
368 return options
369}
370
371func mergeCallOptions(model Model, cfg config.ProviderConfig) (fantasy.ProviderOptions, *float64, *float64, *int64, *float64, *float64) {
372 modelOptions := getProviderOptions(model, cfg)
373 temp := cmp.Or(model.ModelCfg.Temperature, model.CatwalkCfg.Options.Temperature)
374 topP := cmp.Or(model.ModelCfg.TopP, model.CatwalkCfg.Options.TopP)
375 topK := cmp.Or(model.ModelCfg.TopK, model.CatwalkCfg.Options.TopK)
376 freqPenalty := cmp.Or(model.ModelCfg.FrequencyPenalty, model.CatwalkCfg.Options.FrequencyPenalty)
377 presPenalty := cmp.Or(model.ModelCfg.PresencePenalty, model.CatwalkCfg.Options.PresencePenalty)
378 return modelOptions, temp, topP, topK, freqPenalty, presPenalty
379}
380
381func (c *coordinator) buildAgent(ctx context.Context, prompt *prompt.Prompt, agent config.Agent, isSubAgent bool) (SessionAgent, error) {
382 large, small, err := c.buildAgentModels(ctx, isSubAgent)
383 if err != nil {
384 return nil, err
385 }
386
387 largeProviderCfg, _ := c.cfg.Config().Providers.Get(large.ModelCfg.Provider)
388 result := NewSessionAgent(SessionAgentOptions{
389 LargeModel: large,
390 SmallModel: small,
391 SystemPromptPrefix: largeProviderCfg.SystemPromptPrefix,
392 SystemPrompt: "",
393 IsSubAgent: isSubAgent,
394 DisableAutoSummarize: c.cfg.Config().Options.DisableAutoSummarize,
395 IsYolo: c.permissions.SkipRequests(),
396 Sessions: c.sessions,
397 Messages: c.messages,
398 Tools: nil,
399 Notify: c.notify,
400 })
401
402 c.readyWg.Go(func() error {
403 systemPrompt, err := prompt.Build(ctx, large.Model.Provider(), large.Model.Model(), c.cfg)
404 if err != nil {
405 return err
406 }
407 result.SetSystemPrompt(systemPrompt)
408 return nil
409 })
410
411 c.readyWg.Go(func() error {
412 tools, err := c.buildTools(ctx, agent)
413 if err != nil {
414 return err
415 }
416 result.SetTools(tools)
417 return nil
418 })
419
420 return result, nil
421}
422
423func (c *coordinator) buildTools(ctx context.Context, agent config.Agent) ([]fantasy.AgentTool, error) {
424 var allTools []fantasy.AgentTool
425 if slices.Contains(agent.AllowedTools, AgentToolName) {
426 agentTool, err := c.agentTool(ctx)
427 if err != nil {
428 return nil, err
429 }
430 allTools = append(allTools, agentTool)
431 }
432
433 if slices.Contains(agent.AllowedTools, tools.AgenticFetchToolName) {
434 agenticFetchTool, err := c.agenticFetchTool(ctx, nil)
435 if err != nil {
436 return nil, err
437 }
438 allTools = append(allTools, agenticFetchTool)
439 }
440
441 // Get the model name for the agent
442 modelName := ""
443 if modelCfg, ok := c.cfg.Config().Models[agent.Model]; ok {
444 if model := c.cfg.Config().GetModel(modelCfg.Provider, modelCfg.Model); model != nil {
445 modelName = model.Name
446 }
447 }
448
449 allTools = append(allTools,
450 tools.NewBashTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Config().Options.Attribution, modelName),
451 tools.NewJobOutputTool(),
452 tools.NewJobKillTool(),
453 tools.NewDownloadTool(c.permissions, c.cfg.WorkingDir(), nil),
454 tools.NewEditTool(c.lspManager, c.permissions, c.history, c.filetracker, c.cfg.WorkingDir()),
455 tools.NewMultiEditTool(c.lspManager, c.permissions, c.history, c.filetracker, c.cfg.WorkingDir()),
456 tools.NewFetchTool(c.permissions, c.cfg.WorkingDir(), nil),
457 tools.NewGlobTool(c.cfg.WorkingDir()),
458 tools.NewGrepTool(c.cfg.WorkingDir(), c.cfg.Config().Tools.Grep),
459 tools.NewLsTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Config().Tools.Ls),
460 tools.NewSourcegraphTool(nil),
461 tools.NewTodosTool(c.sessions),
462 tools.NewViewTool(c.lspManager, c.permissions, c.filetracker, c.cfg.WorkingDir(), c.cfg.Config().Options.SkillsPaths...),
463 tools.NewWriteTool(c.lspManager, c.permissions, c.history, c.filetracker, c.cfg.WorkingDir()),
464 )
465
466 // Add LSP tools if user has configured LSPs or auto_lsp is enabled (nil or true).
467 if len(c.cfg.Config().LSP) > 0 || c.cfg.Config().Options.AutoLSP == nil || *c.cfg.Config().Options.AutoLSP {
468 allTools = append(allTools, tools.NewDiagnosticsTool(c.lspManager), tools.NewReferencesTool(c.lspManager), tools.NewLSPRestartTool(c.lspManager))
469 }
470
471 if len(c.cfg.Config().MCP) > 0 {
472 allTools = append(
473 allTools,
474 tools.NewListMCPResourcesTool(c.cfg, c.permissions),
475 tools.NewReadMCPResourceTool(c.cfg, c.permissions),
476 )
477 }
478
479 var filteredTools []fantasy.AgentTool
480 for _, tool := range allTools {
481 if slices.Contains(agent.AllowedTools, tool.Info().Name) {
482 filteredTools = append(filteredTools, tool)
483 }
484 }
485
486 for _, tool := range tools.GetMCPTools(c.permissions, c.cfg, c.cfg.WorkingDir()) {
487 if agent.AllowedMCP == nil {
488 // No MCP restrictions
489 filteredTools = append(filteredTools, tool)
490 continue
491 }
492 if len(agent.AllowedMCP) == 0 {
493 // No MCPs allowed
494 slog.Debug("No MCPs allowed", "tool", tool.Name(), "agent", agent.Name)
495 break
496 }
497
498 for mcp, tools := range agent.AllowedMCP {
499 if mcp != tool.MCP() {
500 continue
501 }
502 if len(tools) == 0 || slices.Contains(tools, tool.MCPToolName()) {
503 filteredTools = append(filteredTools, tool)
504 break
505 }
506 slog.Debug("MCP not allowed", "tool", tool.Name(), "agent", agent.Name)
507 }
508 }
509 slices.SortFunc(filteredTools, func(a, b fantasy.AgentTool) int {
510 return strings.Compare(a.Info().Name, b.Info().Name)
511 })
512 return filteredTools, nil
513}
514
515// TODO: when we support multiple agents we need to change this so that we pass in the agent specific model config
516func (c *coordinator) buildAgentModels(ctx context.Context, isSubAgent bool) (Model, Model, error) {
517 largeModelCfg, ok := c.cfg.Config().Models[config.SelectedModelTypeLarge]
518 if !ok {
519 return Model{}, Model{}, errLargeModelNotSelected
520 }
521 smallModelCfg, ok := c.cfg.Config().Models[config.SelectedModelTypeSmall]
522 if !ok {
523 return Model{}, Model{}, errSmallModelNotSelected
524 }
525
526 largeProviderCfg, ok := c.cfg.Config().Providers.Get(largeModelCfg.Provider)
527 if !ok {
528 return Model{}, Model{}, errLargeModelProviderNotConfigured
529 }
530
531 largeProvider, err := c.buildProvider(largeProviderCfg, largeModelCfg, isSubAgent)
532 if err != nil {
533 return Model{}, Model{}, err
534 }
535
536 smallProviderCfg, ok := c.cfg.Config().Providers.Get(smallModelCfg.Provider)
537 if !ok {
538 return Model{}, Model{}, errSmallModelProviderNotConfigured
539 }
540
541 smallProvider, err := c.buildProvider(smallProviderCfg, smallModelCfg, true)
542 if err != nil {
543 return Model{}, Model{}, err
544 }
545
546 var largeCatwalkModel *catwalk.Model
547 var smallCatwalkModel *catwalk.Model
548
549 for _, m := range largeProviderCfg.Models {
550 if m.ID == largeModelCfg.Model {
551 largeCatwalkModel = &m
552 }
553 }
554 for _, m := range smallProviderCfg.Models {
555 if m.ID == smallModelCfg.Model {
556 smallCatwalkModel = &m
557 }
558 }
559
560 if largeCatwalkModel == nil {
561 return Model{}, Model{}, errLargeModelNotFound
562 }
563
564 if smallCatwalkModel == nil {
565 return Model{}, Model{}, errSmallModelNotFound
566 }
567
568 largeModelID := largeModelCfg.Model
569 smallModelID := smallModelCfg.Model
570
571 if largeModelCfg.Provider == openrouter.Name && isExactoSupported(largeModelID) {
572 largeModelID += ":exacto"
573 }
574
575 if smallModelCfg.Provider == openrouter.Name && isExactoSupported(smallModelID) {
576 smallModelID += ":exacto"
577 }
578
579 largeModel, err := largeProvider.LanguageModel(ctx, largeModelID)
580 if err != nil {
581 return Model{}, Model{}, err
582 }
583 smallModel, err := smallProvider.LanguageModel(ctx, smallModelID)
584 if err != nil {
585 return Model{}, Model{}, err
586 }
587
588 return Model{
589 Model: largeModel,
590 CatwalkCfg: *largeCatwalkModel,
591 ModelCfg: largeModelCfg,
592 }, Model{
593 Model: smallModel,
594 CatwalkCfg: *smallCatwalkModel,
595 ModelCfg: smallModelCfg,
596 }, nil
597}
598
599func (c *coordinator) buildAnthropicProvider(baseURL, apiKey string, headers map[string]string, providerID string) (fantasy.Provider, error) {
600 var opts []anthropic.Option
601
602 switch {
603 case strings.HasPrefix(apiKey, "Bearer "):
604 // NOTE: Prevent the SDK from picking up the API key from env.
605 os.Setenv("ANTHROPIC_API_KEY", "")
606 headers["Authorization"] = apiKey
607 case providerID == string(catwalk.InferenceProviderMiniMax) || providerID == string(catwalk.InferenceProviderMiniMaxChina):
608 // NOTE: Prevent the SDK from picking up the API key from env.
609 os.Setenv("ANTHROPIC_API_KEY", "")
610 headers["Authorization"] = "Bearer " + apiKey
611 case apiKey != "":
612 // X-Api-Key header
613 opts = append(opts, anthropic.WithAPIKey(apiKey))
614 }
615
616 if len(headers) > 0 {
617 opts = append(opts, anthropic.WithHeaders(headers))
618 }
619
620 if baseURL != "" {
621 opts = append(opts, anthropic.WithBaseURL(baseURL))
622 }
623
624 if c.cfg.Config().Options.Debug {
625 httpClient := log.NewHTTPClient()
626 opts = append(opts, anthropic.WithHTTPClient(httpClient))
627 }
628 return anthropic.New(opts...)
629}
630
631func (c *coordinator) buildOpenaiProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
632 opts := []openai.Option{
633 openai.WithAPIKey(apiKey),
634 openai.WithUseResponsesAPI(),
635 }
636 if c.cfg.Config().Options.Debug {
637 httpClient := log.NewHTTPClient()
638 opts = append(opts, openai.WithHTTPClient(httpClient))
639 }
640 if len(headers) > 0 {
641 opts = append(opts, openai.WithHeaders(headers))
642 }
643 if baseURL != "" {
644 opts = append(opts, openai.WithBaseURL(baseURL))
645 }
646 return openai.New(opts...)
647}
648
649func (c *coordinator) buildOpenrouterProvider(_, apiKey string, headers map[string]string) (fantasy.Provider, error) {
650 opts := []openrouter.Option{
651 openrouter.WithAPIKey(apiKey),
652 }
653 if c.cfg.Config().Options.Debug {
654 httpClient := log.NewHTTPClient()
655 opts = append(opts, openrouter.WithHTTPClient(httpClient))
656 }
657 if len(headers) > 0 {
658 opts = append(opts, openrouter.WithHeaders(headers))
659 }
660 return openrouter.New(opts...)
661}
662
663func (c *coordinator) buildVercelProvider(_, apiKey string, headers map[string]string) (fantasy.Provider, error) {
664 opts := []vercel.Option{
665 vercel.WithAPIKey(apiKey),
666 }
667 if c.cfg.Config().Options.Debug {
668 httpClient := log.NewHTTPClient()
669 opts = append(opts, vercel.WithHTTPClient(httpClient))
670 }
671 if len(headers) > 0 {
672 opts = append(opts, vercel.WithHeaders(headers))
673 }
674 return vercel.New(opts...)
675}
676
677func (c *coordinator) buildOpenaiCompatProvider(baseURL, apiKey string, headers map[string]string, extraBody map[string]any, providerID string, isSubAgent bool) (fantasy.Provider, error) {
678 opts := []openaicompat.Option{
679 openaicompat.WithBaseURL(baseURL),
680 openaicompat.WithAPIKey(apiKey),
681 }
682
683 // Set HTTP client based on provider and debug mode.
684 var httpClient *http.Client
685 if providerID == string(catwalk.InferenceProviderCopilot) {
686 opts = append(opts, openaicompat.WithUseResponsesAPI())
687 httpClient = copilot.NewClient(isSubAgent, c.cfg.Config().Options.Debug)
688 } else if c.cfg.Config().Options.Debug {
689 httpClient = log.NewHTTPClient()
690 }
691 if httpClient != nil {
692 opts = append(opts, openaicompat.WithHTTPClient(httpClient))
693 }
694
695 if len(headers) > 0 {
696 opts = append(opts, openaicompat.WithHeaders(headers))
697 }
698
699 for extraKey, extraValue := range extraBody {
700 opts = append(opts, openaicompat.WithSDKOptions(openaisdk.WithJSONSet(extraKey, extraValue)))
701 }
702
703 return openaicompat.New(opts...)
704}
705
706func (c *coordinator) buildAzureProvider(baseURL, apiKey string, headers map[string]string, options map[string]string) (fantasy.Provider, error) {
707 opts := []azure.Option{
708 azure.WithBaseURL(baseURL),
709 azure.WithAPIKey(apiKey),
710 azure.WithUseResponsesAPI(),
711 }
712 if c.cfg.Config().Options.Debug {
713 httpClient := log.NewHTTPClient()
714 opts = append(opts, azure.WithHTTPClient(httpClient))
715 }
716 if options == nil {
717 options = make(map[string]string)
718 }
719 if apiVersion, ok := options["apiVersion"]; ok {
720 opts = append(opts, azure.WithAPIVersion(apiVersion))
721 }
722 if len(headers) > 0 {
723 opts = append(opts, azure.WithHeaders(headers))
724 }
725
726 return azure.New(opts...)
727}
728
729func (c *coordinator) buildBedrockProvider(apiKey string, headers map[string]string) (fantasy.Provider, error) {
730 var opts []bedrock.Option
731 if c.cfg.Config().Options.Debug {
732 httpClient := log.NewHTTPClient()
733 opts = append(opts, bedrock.WithHTTPClient(httpClient))
734 }
735 if len(headers) > 0 {
736 opts = append(opts, bedrock.WithHeaders(headers))
737 }
738 switch {
739 case apiKey != "":
740 opts = append(opts, bedrock.WithAPIKey(apiKey))
741 case os.Getenv("AWS_BEARER_TOKEN_BEDROCK") != "":
742 opts = append(opts, bedrock.WithAPIKey(os.Getenv("AWS_BEARER_TOKEN_BEDROCK")))
743 default:
744 // Skip, let the SDK do authentication.
745 }
746 return bedrock.New(opts...)
747}
748
749func (c *coordinator) buildGoogleProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
750 opts := []google.Option{
751 google.WithBaseURL(baseURL),
752 google.WithGeminiAPIKey(apiKey),
753 }
754 if c.cfg.Config().Options.Debug {
755 httpClient := log.NewHTTPClient()
756 opts = append(opts, google.WithHTTPClient(httpClient))
757 }
758 if len(headers) > 0 {
759 opts = append(opts, google.WithHeaders(headers))
760 }
761 return google.New(opts...)
762}
763
764func (c *coordinator) buildGoogleVertexProvider(headers map[string]string, options map[string]string) (fantasy.Provider, error) {
765 opts := []google.Option{}
766 if c.cfg.Config().Options.Debug {
767 httpClient := log.NewHTTPClient()
768 opts = append(opts, google.WithHTTPClient(httpClient))
769 }
770 if len(headers) > 0 {
771 opts = append(opts, google.WithHeaders(headers))
772 }
773
774 project := options["project"]
775 location := options["location"]
776
777 opts = append(opts, google.WithVertex(project, location))
778
779 return google.New(opts...)
780}
781
782func (c *coordinator) buildHyperProvider(baseURL, apiKey string) (fantasy.Provider, error) {
783 opts := []hyper.Option{
784 hyper.WithBaseURL(baseURL),
785 hyper.WithAPIKey(apiKey),
786 }
787 if c.cfg.Config().Options.Debug {
788 httpClient := log.NewHTTPClient()
789 opts = append(opts, hyper.WithHTTPClient(httpClient))
790 }
791 return hyper.New(opts...)
792}
793
794func (c *coordinator) isAnthropicThinking(model config.SelectedModel) bool {
795 if model.Think {
796 return true
797 }
798 opts, err := anthropic.ParseOptions(model.ProviderOptions)
799 return err == nil && opts.Thinking != nil
800}
801
802func (c *coordinator) buildProvider(providerCfg config.ProviderConfig, model config.SelectedModel, isSubAgent bool) (fantasy.Provider, error) {
803 headers := maps.Clone(providerCfg.ExtraHeaders)
804 if headers == nil {
805 headers = make(map[string]string)
806 }
807
808 // handle special headers for anthropic
809 if providerCfg.Type == anthropic.Name && c.isAnthropicThinking(model) {
810 if v, ok := headers["anthropic-beta"]; ok {
811 headers["anthropic-beta"] = v + ",interleaved-thinking-2025-05-14"
812 } else {
813 headers["anthropic-beta"] = "interleaved-thinking-2025-05-14"
814 }
815 }
816
817 apiKey, _ := c.cfg.Resolve(providerCfg.APIKey)
818 baseURL, _ := c.cfg.Resolve(providerCfg.BaseURL)
819
820 switch providerCfg.Type {
821 case openai.Name:
822 return c.buildOpenaiProvider(baseURL, apiKey, headers)
823 case anthropic.Name:
824 return c.buildAnthropicProvider(baseURL, apiKey, headers, providerCfg.ID)
825 case openrouter.Name:
826 return c.buildOpenrouterProvider(baseURL, apiKey, headers)
827 case vercel.Name:
828 return c.buildVercelProvider(baseURL, apiKey, headers)
829 case azure.Name:
830 return c.buildAzureProvider(baseURL, apiKey, headers, providerCfg.ExtraParams)
831 case bedrock.Name:
832 return c.buildBedrockProvider(apiKey, headers)
833 case google.Name:
834 return c.buildGoogleProvider(baseURL, apiKey, headers)
835 case "google-vertex":
836 return c.buildGoogleVertexProvider(headers, providerCfg.ExtraParams)
837 case openaicompat.Name:
838 if providerCfg.ID == string(catwalk.InferenceProviderZAI) {
839 if providerCfg.ExtraBody == nil {
840 providerCfg.ExtraBody = map[string]any{}
841 }
842 providerCfg.ExtraBody["tool_stream"] = true
843 }
844 return c.buildOpenaiCompatProvider(baseURL, apiKey, headers, providerCfg.ExtraBody, providerCfg.ID, isSubAgent)
845 case hyper.Name:
846 return c.buildHyperProvider(baseURL, apiKey)
847 default:
848 return nil, fmt.Errorf("provider type not supported: %q", providerCfg.Type)
849 }
850}
851
// isExactoSupported reports whether modelID is one of the OpenRouter models
// that buildAgentModels routes to the ":exacto" variant. The match is exact
// and case-sensitive.
func isExactoSupported(modelID string) bool {
	exacto := [...]string{
		"moonshotai/kimi-k2-0905",
		"deepseek/deepseek-v3.1-terminus",
		"z-ai/glm-4.6",
		"openai/gpt-oss-120b",
		"qwen/qwen3-coder",
	}
	return slices.Contains(exacto[:], modelID)
}
862
863func (c *coordinator) Cancel(sessionID string) {
864 c.currentAgent.Cancel(sessionID)
865}
866
867func (c *coordinator) CancelAll() {
868 c.currentAgent.CancelAll()
869}
870
871func (c *coordinator) ClearQueue(sessionID string) {
872 c.currentAgent.ClearQueue(sessionID)
873}
874
875func (c *coordinator) IsBusy() bool {
876 return c.currentAgent.IsBusy()
877}
878
879func (c *coordinator) IsSessionBusy(sessionID string) bool {
880 return c.currentAgent.IsSessionBusy(sessionID)
881}
882
883func (c *coordinator) Model() Model {
884 return c.currentAgent.Model()
885}
886
887func (c *coordinator) UpdateModels(ctx context.Context) error {
888 // build the models again so we make sure we get the latest config
889 large, small, err := c.buildAgentModels(ctx, false)
890 if err != nil {
891 return err
892 }
893 c.currentAgent.SetModels(large, small)
894
895 agentCfg, ok := c.cfg.Config().Agents[config.AgentCoder]
896 if !ok {
897 return errCoderAgentNotConfigured
898 }
899
900 tools, err := c.buildTools(ctx, agentCfg)
901 if err != nil {
902 return err
903 }
904 c.currentAgent.SetTools(tools)
905 return nil
906}
907
908func (c *coordinator) RefreshTools(ctx context.Context) error {
909 agentCfg, ok := c.cfg.Config().Agents[config.AgentCoder]
910 if !ok {
911 return errors.New("coder agent not configured")
912 }
913
914 tools, err := c.buildTools(ctx, agentCfg)
915 if err != nil {
916 return err
917 }
918 c.currentAgent.SetTools(tools)
919 slog.Debug("refreshed agent tools", "count", len(tools))
920 return nil
921}
922
923func (c *coordinator) QueuedPrompts(sessionID string) int {
924 return c.currentAgent.QueuedPrompts(sessionID)
925}
926
927func (c *coordinator) QueuedPromptsList(sessionID string) []string {
928 return c.currentAgent.QueuedPromptsList(sessionID)
929}
930
931func (c *coordinator) Summarize(ctx context.Context, sessionID string) error {
932 providerCfg, ok := c.cfg.Config().Providers.Get(c.currentAgent.Model().ModelCfg.Provider)
933 if !ok {
934 return errModelProviderNotConfigured
935 }
936 return c.currentAgent.Summarize(ctx, sessionID, getProviderOptions(c.currentAgent.Model(), providerCfg))
937}
938
939func (c *coordinator) isUnauthorized(err error) bool {
940 var providerErr *fantasy.ProviderError
941 return errors.As(err, &providerErr) && providerErr.StatusCode == http.StatusUnauthorized
942}
943
944func (c *coordinator) refreshOAuth2Token(ctx context.Context, providerCfg config.ProviderConfig) error {
945 if err := c.cfg.RefreshOAuthToken(ctx, config.ScopeGlobal, providerCfg.ID); err != nil {
946 slog.Error("Failed to refresh OAuth token after 401 error", "provider", providerCfg.ID, "error", err)
947 return err
948 }
949 if err := c.UpdateModels(ctx); err != nil {
950 return err
951 }
952 return nil
953}
954
955func (c *coordinator) refreshApiKeyTemplate(ctx context.Context, providerCfg config.ProviderConfig) error {
956 newAPIKey, err := c.cfg.Resolve(providerCfg.APIKeyTemplate)
957 if err != nil {
958 slog.Error("Failed to re-resolve API key after 401 error", "provider", providerCfg.ID, "error", err)
959 return err
960 }
961
962 providerCfg.APIKey = newAPIKey
963 c.cfg.Config().Providers.Set(providerCfg.ID, providerCfg)
964
965 if err := c.UpdateModels(ctx); err != nil {
966 return err
967 }
968 return nil
969}
970
// subAgentParams holds the parameters for running a sub-agent via
// runSubAgent.
type subAgentParams struct {
	// Agent is the session agent that runs the sub-task.
	Agent SessionAgent
	// SessionID is the parent session. The sub-session is created under it,
	// and the sub-session's cost is added to it.
	SessionID string
	// AgentMessageID and ToolCallID are passed to
	// sessions.CreateAgentToolSessionID to derive the sub-session ID.
	AgentMessageID string
	ToolCallID     string
	// Prompt is the prompt sent to the sub-agent.
	Prompt string
	// SessionTitle is the title given to the created sub-session.
	SessionTitle string
	// SessionSetup is an optional callback invoked after session creation
	// but before agent execution, for custom session configuration.
	SessionSetup func(sessionID string)
}
983
984// runSubAgent runs a sub-agent and handles session management and cost accumulation.
985// It creates a sub-session, runs the agent with the given prompt, and propagates
986// the cost to the parent session.
987func (c *coordinator) runSubAgent(ctx context.Context, params subAgentParams) (fantasy.ToolResponse, error) {
988 // Create sub-session
989 agentToolSessionID := c.sessions.CreateAgentToolSessionID(params.AgentMessageID, params.ToolCallID)
990 session, err := c.sessions.CreateTaskSession(ctx, agentToolSessionID, params.SessionID, params.SessionTitle)
991 if err != nil {
992 return fantasy.ToolResponse{}, fmt.Errorf("create session: %w", err)
993 }
994
995 // Call session setup function if provided
996 if params.SessionSetup != nil {
997 params.SessionSetup(session.ID)
998 }
999
1000 // Get model configuration
1001 model := params.Agent.Model()
1002 maxTokens := model.CatwalkCfg.DefaultMaxTokens
1003 if model.ModelCfg.MaxTokens != 0 {
1004 maxTokens = model.ModelCfg.MaxTokens
1005 }
1006
1007 providerCfg, ok := c.cfg.Config().Providers.Get(model.ModelCfg.Provider)
1008 if !ok {
1009 return fantasy.ToolResponse{}, errModelProviderNotConfigured
1010 }
1011
1012 // Run the agent
1013 result, err := params.Agent.Run(ctx, SessionAgentCall{
1014 SessionID: session.ID,
1015 Prompt: params.Prompt,
1016 MaxOutputTokens: maxTokens,
1017 ProviderOptions: getProviderOptions(model, providerCfg),
1018 Temperature: model.ModelCfg.Temperature,
1019 TopP: model.ModelCfg.TopP,
1020 TopK: model.ModelCfg.TopK,
1021 FrequencyPenalty: model.ModelCfg.FrequencyPenalty,
1022 PresencePenalty: model.ModelCfg.PresencePenalty,
1023 NonInteractive: true,
1024 })
1025 if err != nil {
1026 return fantasy.NewTextErrorResponse("error generating response"), nil
1027 }
1028
1029 // Update parent session cost
1030 if err := c.updateParentSessionCost(ctx, session.ID, params.SessionID); err != nil {
1031 return fantasy.ToolResponse{}, err
1032 }
1033
1034 return fantasy.NewTextResponse(result.Response.Content.Text()), nil
1035}
1036
1037// updateParentSessionCost accumulates the cost from a child session to its parent session.
1038func (c *coordinator) updateParentSessionCost(ctx context.Context, childSessionID, parentSessionID string) error {
1039 childSession, err := c.sessions.Get(ctx, childSessionID)
1040 if err != nil {
1041 return fmt.Errorf("get child session: %w", err)
1042 }
1043
1044 parentSession, err := c.sessions.Get(ctx, parentSessionID)
1045 if err != nil {
1046 return fmt.Errorf("get parent session: %w", err)
1047 }
1048
1049 parentSession.Cost += childSession.Cost
1050
1051 if _, err := c.sessions.Save(ctx, parentSession); err != nil {
1052 return fmt.Errorf("save parent session: %w", err)
1053 }
1054
1055 return nil
1056}