1package agent
2
3import (
4 "bytes"
5 "cmp"
6 "context"
7 "encoding/json"
8 "errors"
9 "fmt"
10 "io"
11 "log/slog"
12 "maps"
13 "net/http"
14 "os"
15 "path/filepath"
16 "slices"
17 "strings"
18
19 "charm.land/catwalk/pkg/catwalk"
20 "charm.land/fantasy"
21 "github.com/charmbracelet/crush/internal/agent/hyper"
22 "github.com/charmbracelet/crush/internal/agent/notify"
23 "github.com/charmbracelet/crush/internal/agent/prompt"
24 "github.com/charmbracelet/crush/internal/agent/tools"
25 "github.com/charmbracelet/crush/internal/config"
26 "github.com/charmbracelet/crush/internal/event"
27 "github.com/charmbracelet/crush/internal/filetracker"
28 "github.com/charmbracelet/crush/internal/history"
29 "github.com/charmbracelet/crush/internal/home"
30 "github.com/charmbracelet/crush/internal/hooks"
31 "github.com/charmbracelet/crush/internal/log"
32 "github.com/charmbracelet/crush/internal/lsp"
33 "github.com/charmbracelet/crush/internal/message"
34 "github.com/charmbracelet/crush/internal/oauth/copilot"
35 "github.com/charmbracelet/crush/internal/permission"
36 "github.com/charmbracelet/crush/internal/pubsub"
37 "github.com/charmbracelet/crush/internal/session"
38 "github.com/charmbracelet/crush/internal/skills"
39 "golang.org/x/sync/errgroup"
40
41 "charm.land/fantasy/providers/anthropic"
42 "charm.land/fantasy/providers/azure"
43 "charm.land/fantasy/providers/bedrock"
44 "charm.land/fantasy/providers/google"
45 "charm.land/fantasy/providers/openai"
46 "charm.land/fantasy/providers/openaicompat"
47 "charm.land/fantasy/providers/openrouter"
48 "charm.land/fantasy/providers/vercel"
49 openaisdk "github.com/charmbracelet/openai-go/option"
50 "github.com/qjebbs/go-jsons"
51)
52
// Coordinator errors.
var (
	// errCoderAgentNotConfigured is returned when the config has no entry for
	// config.AgentCoder in its Agents map.
	errCoderAgentNotConfigured = errors.New("coder agent not configured")
	// errModelProviderNotConfigured is returned when the provider of the
	// current agent's model is not present in the configured providers.
	errModelProviderNotConfigured = errors.New("model provider not configured")
	// errLargeModelNotSelected is returned when no large model is selected
	// (config.SelectedModelTypeLarge is missing from Models).
	errLargeModelNotSelected = errors.New("large model not selected")
	// errSmallModelNotSelected is returned when no small model is selected
	// (config.SelectedModelTypeSmall is missing from Models).
	errSmallModelNotSelected = errors.New("small model not selected")
	// errLargeModelProviderNotConfigured is returned when the selected large
	// model names a provider that is not configured.
	errLargeModelProviderNotConfigured = errors.New("large model provider not configured")
	// errSmallModelProviderNotConfigured is returned when the selected small
	// model names a provider that is not configured.
	errSmallModelProviderNotConfigured = errors.New("small model provider not configured")
	// errLargeModelNotFound is returned when the selected large model ID is not
	// listed in its provider's Models.
	errLargeModelNotFound = errors.New("large model not found in provider config")
	// errSmallModelNotFound is returned when the selected small model ID is not
	// listed in its provider's Models.
	errSmallModelNotFound = errors.New("small model not found in provider config")
)
64
// Coordinator drives the agent behind chat sessions: it runs prompts,
// exposes cancellation and prompt-queue controls, and keeps the agent's
// models in sync with the configuration.
type Coordinator interface {
	// INFO: (kujtim) this is not used yet we will use this when we have multiple agents
	// SetMainAgent(string)

	// Run sends prompt (plus optional attachments) to the agent for
	// sessionID and returns the agent's result.
	Run(ctx context.Context, sessionID, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error)
	// Cancel cancels the agent's work for sessionID.
	Cancel(sessionID string)
	// CancelAll cancels the agent's work for all sessions.
	CancelAll()
	// IsSessionBusy reports whether the agent is busy with sessionID.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether the agent is busy with any session.
	IsBusy() bool
	// QueuedPrompts returns the number of prompts queued for sessionID.
	QueuedPrompts(sessionID string) int
	// QueuedPromptsList returns the prompts queued for sessionID.
	QueuedPromptsList(sessionID string) []string
	// ClearQueue drops the prompts queued for sessionID.
	ClearQueue(sessionID string)
	// Summarize summarizes a session; the string argument is the session ID.
	Summarize(context.Context, string) error
	// Model returns the model the agent currently uses.
	Model() Model
	// UpdateModels rebuilds the agent's models from the current configuration.
	UpdateModels(ctx context.Context) error
}
80
// coordinator is the default Coordinator implementation. It currently owns a
// single coder agent (currentAgent) and forwards most calls to it.
type coordinator struct {
	// Services and configuration shared by every agent and tool built here.
	cfg *config.ConfigStore
	sessions session.Service
	messages message.Service
	permissions permission.Service
	history history.Service
	filetracker filetracker.Service
	lspManager *lsp.Manager
	notify pubsub.Publisher[notify.Notification]

	// currentAgent receives Run/Cancel/queue calls; agents indexes the built
	// agents by config agent name (only config.AgentCoder is registered today).
	currentAgent SessionAgent
	agents map[string]SessionAgent

	// Skills discovery results (session-start snapshot).
	allSkills []*skills.Skill // Pre-filter: all discovered after dedup.
	activeSkills []*skills.Skill // Post-filter: active skills only.
	skillTracker *skills.Tracker

	// readyWg collects the background setup goroutines started by buildAgent
	// (system prompt and tool construction); Run waits on it before running.
	readyWg errgroup.Group
}
101
102func NewCoordinator(
103 ctx context.Context,
104 cfg *config.ConfigStore,
105 sessions session.Service,
106 messages message.Service,
107 permissions permission.Service,
108 history history.Service,
109 filetracker filetracker.Service,
110 lspManager *lsp.Manager,
111 notify pubsub.Publisher[notify.Notification],
112) (Coordinator, error) {
113 // Discover skills once at session start.
114 allSkills, activeSkills := discoverSkills(cfg)
115 skillTracker := skills.NewTracker(activeSkills)
116
117 c := &coordinator{
118 cfg: cfg,
119 sessions: sessions,
120 messages: messages,
121 permissions: permissions,
122 history: history,
123 filetracker: filetracker,
124 lspManager: lspManager,
125 notify: notify,
126 agents: make(map[string]SessionAgent),
127 allSkills: allSkills,
128 activeSkills: activeSkills,
129 skillTracker: skillTracker,
130 }
131
132 agentCfg, ok := cfg.Config().Agents[config.AgentCoder]
133 if !ok {
134 return nil, errCoderAgentNotConfigured
135 }
136
137 // TODO: make this dynamic when we support multiple agents
138 prompt, err := coderPrompt(prompt.WithWorkingDir(c.cfg.WorkingDir()))
139 if err != nil {
140 return nil, err
141 }
142
143 agent, err := c.buildAgent(ctx, prompt, agentCfg, false)
144 if err != nil {
145 return nil, err
146 }
147 c.currentAgent = agent
148 c.agents[config.AgentCoder] = agent
149 return c, nil
150}
151
152// Run implements Coordinator.
153func (c *coordinator) Run(ctx context.Context, sessionID string, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) {
154 if err := c.readyWg.Wait(); err != nil {
155 return nil, err
156 }
157
158 // refresh models before each run
159 if err := c.UpdateModels(ctx); err != nil {
160 return nil, fmt.Errorf("failed to update models: %w", err)
161 }
162
163 model := c.currentAgent.Model()
164 maxTokens := model.CatwalkCfg.DefaultMaxTokens
165 if model.ModelCfg.MaxTokens != 0 {
166 maxTokens = model.ModelCfg.MaxTokens
167 }
168
169 if !model.CatwalkCfg.SupportsImages && attachments != nil {
170 // filter out image attachments
171 filteredAttachments := make([]message.Attachment, 0, len(attachments))
172 for _, att := range attachments {
173 if att.IsText() {
174 filteredAttachments = append(filteredAttachments, att)
175 }
176 }
177 attachments = filteredAttachments
178 }
179
180 providerCfg, ok := c.cfg.Config().Providers.Get(model.ModelCfg.Provider)
181 if !ok {
182 return nil, errModelProviderNotConfigured
183 }
184
185 mergedOptions, temp, topP, topK, freqPenalty, presPenalty := mergeCallOptions(model, providerCfg)
186
187 if providerCfg.OAuthToken != nil && providerCfg.OAuthToken.IsExpired() {
188 slog.Debug("Token needs to be refreshed", "provider", providerCfg.ID)
189 if err := c.refreshOAuth2Token(ctx, providerCfg); err != nil {
190 // NOTE(@andreynering): We don't return here because the event handling to ask the user to reauthenticate
191 // depends on the flow below. If refresh fails, proceed with the token we have.
192 slog.Error("Failed to refresh OAuth2 token. Proceeding with existing token.", "error", err)
193 }
194 }
195
196 run := func() (*fantasy.AgentResult, error) {
197 return c.currentAgent.Run(ctx, SessionAgentCall{
198 SessionID: sessionID,
199 Prompt: prompt,
200 Attachments: attachments,
201 MaxOutputTokens: maxTokens,
202 ProviderOptions: mergedOptions,
203 Temperature: temp,
204 TopP: topP,
205 TopK: topK,
206 FrequencyPenalty: freqPenalty,
207 PresencePenalty: presPenalty,
208 })
209 }
210 beforeLoaded := c.skillTracker.LoadedNames()
211 result, originalErr := run()
212 logTurnSkillUsage(sessionID, prompt, c.activeSkills, c.skillTracker, beforeLoaded)
213
214 if c.isUnauthorized(originalErr) {
215 switch {
216 case providerCfg.OAuthToken != nil:
217 slog.Debug("Received 401. Refreshing token and retrying", "provider", providerCfg.ID)
218 if err := c.refreshOAuth2Token(ctx, providerCfg); err != nil {
219 return nil, originalErr
220 }
221 slog.Debug("Retrying request with refreshed OAuth token", "provider", providerCfg.ID)
222 return run()
223 case strings.Contains(providerCfg.APIKeyTemplate, "$"):
224 slog.Debug("Received 401. Refreshing API Key template and retrying", "provider", providerCfg.ID)
225 if err := c.refreshApiKeyTemplate(ctx, providerCfg); err != nil {
226 return nil, originalErr
227 }
228 slog.Debug("Retrying request with refreshed API key", "provider", providerCfg.ID)
229 return run()
230 }
231 }
232
233 return result, originalErr
234}
235
236func getProviderOptions(model Model, providerCfg config.ProviderConfig) fantasy.ProviderOptions {
237 options := fantasy.ProviderOptions{}
238
239 cfgOpts := []byte("{}")
240 providerCfgOpts := []byte("{}")
241 catwalkOpts := []byte("{}")
242
243 if model.ModelCfg.ProviderOptions != nil {
244 data, err := json.Marshal(model.ModelCfg.ProviderOptions)
245 if err == nil {
246 cfgOpts = data
247 }
248 }
249
250 if providerCfg.ProviderOptions != nil {
251 data, err := json.Marshal(providerCfg.ProviderOptions)
252 if err == nil {
253 providerCfgOpts = data
254 }
255 }
256
257 if model.CatwalkCfg.Options.ProviderOptions != nil {
258 data, err := json.Marshal(model.CatwalkCfg.Options.ProviderOptions)
259 if err == nil {
260 catwalkOpts = data
261 }
262 }
263
264 readers := []io.Reader{
265 bytes.NewReader(catwalkOpts),
266 bytes.NewReader(providerCfgOpts),
267 bytes.NewReader(cfgOpts),
268 }
269
270 got, err := jsons.Merge(readers)
271 if err != nil {
272 slog.Error("Could not merge call config", "err", err)
273 return options
274 }
275
276 mergedOptions := make(map[string]any)
277
278 err = json.Unmarshal([]byte(got), &mergedOptions)
279 if err != nil {
280 slog.Error("Could not create config for call", "err", err)
281 return options
282 }
283
284 switch providerCfg.Type {
285 case openai.Name, azure.Name:
286 _, hasReasoningEffort := mergedOptions["reasoning_effort"]
287 if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
288 mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort
289 }
290 if openai.IsResponsesModel(model.CatwalkCfg.ID) {
291 if openai.IsResponsesReasoningModel(model.CatwalkCfg.ID) {
292 mergedOptions["reasoning_summary"] = "auto"
293 mergedOptions["include"] = []openai.IncludeType{openai.IncludeReasoningEncryptedContent}
294 }
295 parsed, err := openai.ParseResponsesOptions(mergedOptions)
296 if err == nil {
297 options[openai.Name] = parsed
298 }
299 } else {
300 parsed, err := openai.ParseOptions(mergedOptions)
301 if err == nil {
302 options[openai.Name] = parsed
303 }
304 }
305 case anthropic.Name:
306 var (
307 _, hasEffort = mergedOptions["effort"]
308 _, hasThink = mergedOptions["thinking"]
309 )
310 switch {
311 case !hasEffort && model.ModelCfg.ReasoningEffort != "":
312 mergedOptions["effort"] = model.ModelCfg.ReasoningEffort
313 case !hasThink && model.ModelCfg.Think:
314 mergedOptions["thinking"] = map[string]any{"budget_tokens": 2000}
315 }
316 parsed, err := anthropic.ParseOptions(mergedOptions)
317 if err == nil {
318 options[anthropic.Name] = parsed
319 }
320
321 case openrouter.Name:
322 _, hasReasoning := mergedOptions["reasoning"]
323 if !hasReasoning && model.ModelCfg.ReasoningEffort != "" {
324 mergedOptions["reasoning"] = map[string]any{
325 "enabled": true,
326 "effort": model.ModelCfg.ReasoningEffort,
327 }
328 }
329 parsed, err := openrouter.ParseOptions(mergedOptions)
330 if err == nil {
331 options[openrouter.Name] = parsed
332 }
333 case vercel.Name:
334 _, hasReasoning := mergedOptions["reasoning"]
335 if !hasReasoning && model.ModelCfg.ReasoningEffort != "" {
336 mergedOptions["reasoning"] = map[string]any{
337 "enabled": true,
338 "effort": model.ModelCfg.ReasoningEffort,
339 }
340 }
341 parsed, err := vercel.ParseOptions(mergedOptions)
342 if err == nil {
343 options[vercel.Name] = parsed
344 }
345 case google.Name:
346 _, hasReasoning := mergedOptions["thinking_config"]
347 if !hasReasoning {
348 if strings.HasPrefix(model.CatwalkCfg.ID, "gemini-2") {
349 mergedOptions["thinking_config"] = map[string]any{
350 "thinking_budget": 2000,
351 "include_thoughts": true,
352 }
353 } else {
354 mergedOptions["thinking_config"] = map[string]any{
355 "thinking_level": model.ModelCfg.ReasoningEffort,
356 "include_thoughts": true,
357 }
358 }
359 }
360 parsed, err := google.ParseOptions(mergedOptions)
361 if err == nil {
362 options[google.Name] = parsed
363 }
364 case openaicompat.Name, hyper.Name:
365 _, hasReasoningEffort := mergedOptions["reasoning_effort"]
366 if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
367 mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort
368 }
369
370 extraBody := make(map[string]any)
371
372 // "reasoning effort" is a standard OpenAI field, but "thinking" is not.
373 // Setting it in the right way for each provider.
374 // TODO: Abstract this in Fantasy somehow?
375 // TODO: Allow custom providers to specify how to set this?
376 switch providerCfg.ID {
377 case hyper.Name:
378 extraBody["thinking"] = model.ModelCfg.Think
379 case string(catwalk.InferenceProviderIoNet):
380 extraBody["chat_template_kwargs"] = map[string]any{
381 "thinking": model.ModelCfg.Think,
382 }
383 case string(catwalk.InferenceProviderZAI):
384 if model.ModelCfg.Think {
385 extraBody["thinking"] = map[string]any{
386 "type": "enabled",
387 }
388 } else {
389 extraBody["thinking"] = map[string]any{
390 "type": "disabled",
391 }
392 }
393 }
394
395 mergedOptions["extra_body"] = extraBody
396
397 parsed, err := openaicompat.ParseOptions(mergedOptions)
398 if err == nil {
399 options[openaicompat.Name] = parsed
400 }
401 }
402
403 return options
404}
405
406func mergeCallOptions(model Model, cfg config.ProviderConfig) (fantasy.ProviderOptions, *float64, *float64, *int64, *float64, *float64) {
407 modelOptions := getProviderOptions(model, cfg)
408 temp := cmp.Or(model.ModelCfg.Temperature, model.CatwalkCfg.Options.Temperature)
409 topP := cmp.Or(model.ModelCfg.TopP, model.CatwalkCfg.Options.TopP)
410 topK := cmp.Or(model.ModelCfg.TopK, model.CatwalkCfg.Options.TopK)
411 freqPenalty := cmp.Or(model.ModelCfg.FrequencyPenalty, model.CatwalkCfg.Options.FrequencyPenalty)
412 presPenalty := cmp.Or(model.ModelCfg.PresencePenalty, model.CatwalkCfg.Options.PresencePenalty)
413 return modelOptions, temp, topP, topK, freqPenalty, presPenalty
414}
415
416func (c *coordinator) buildAgent(ctx context.Context, prompt *prompt.Prompt, agent config.Agent, isSubAgent bool) (SessionAgent, error) {
417 large, small, err := c.buildAgentModels(ctx, isSubAgent)
418 if err != nil {
419 return nil, err
420 }
421
422 largeProviderCfg, _ := c.cfg.Config().Providers.Get(large.ModelCfg.Provider)
423 result := NewSessionAgent(SessionAgentOptions{
424 LargeModel: large,
425 SmallModel: small,
426 SystemPromptPrefix: largeProviderCfg.SystemPromptPrefix,
427 SystemPrompt: "",
428 IsSubAgent: isSubAgent,
429 DisableAutoSummarize: c.cfg.Config().Options.DisableAutoSummarize,
430 IsYolo: c.permissions.SkipRequests(),
431 Sessions: c.sessions,
432 Messages: c.messages,
433 Tools: nil,
434 Notify: c.notify,
435 })
436
437 c.readyWg.Go(func() error {
438 systemPrompt, err := prompt.Build(ctx, large.Model.Provider(), large.Model.Model(), c.cfg)
439 if err != nil {
440 return err
441 }
442 result.SetSystemPrompt(systemPrompt)
443 return nil
444 })
445
446 c.readyWg.Go(func() error {
447 tools, err := c.buildTools(ctx, agent, isSubAgent)
448 if err != nil {
449 return err
450 }
451 result.SetTools(tools)
452 return nil
453 })
454
455 return result, nil
456}
457
458func (c *coordinator) buildTools(ctx context.Context, agent config.Agent, isSubAgent bool) ([]fantasy.AgentTool, error) {
459 var allTools []fantasy.AgentTool
460 if slices.Contains(agent.AllowedTools, AgentToolName) {
461 agentTool, err := c.agentTool(ctx)
462 if err != nil {
463 return nil, err
464 }
465 allTools = append(allTools, agentTool)
466 }
467
468 if slices.Contains(agent.AllowedTools, tools.AgenticFetchToolName) {
469 agenticFetchTool, err := c.agenticFetchTool(ctx, nil)
470 if err != nil {
471 return nil, err
472 }
473 allTools = append(allTools, agenticFetchTool)
474 }
475
476 // Get the model name for the agent
477 modelName := ""
478 if modelCfg, ok := c.cfg.Config().Models[agent.Model]; ok {
479 if model := c.cfg.Config().GetModel(modelCfg.Provider, modelCfg.Model); model != nil {
480 modelName = model.Name
481 }
482 }
483
484 logFile := filepath.Join(c.cfg.Config().Options.DataDirectory, "logs", "crush.log")
485
486 // Build hook runner if PreToolUse hooks are configured.
487 var hookRunner *hooks.Runner
488 if preToolHooks := c.cfg.Config().Hooks[hooks.EventPreToolUse]; len(preToolHooks) > 0 {
489 hookRunner = hooks.NewRunner(preToolHooks, c.cfg.WorkingDir(), c.cfg.WorkingDir())
490 }
491
492 allTools = append(allTools,
493 tools.NewBashTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Config().Options.Attribution, modelName),
494 tools.NewCrushInfoTool(c.cfg, c.lspManager, c.allSkills, c.activeSkills, c.skillTracker),
495 tools.NewCrushLogsTool(logFile),
496 tools.NewJobOutputTool(),
497 tools.NewJobKillTool(),
498 tools.NewDownloadTool(c.permissions, c.cfg.WorkingDir(), nil),
499 tools.NewEditTool(c.lspManager, c.permissions, c.history, c.filetracker, c.cfg.WorkingDir()),
500 tools.NewMultiEditTool(c.lspManager, c.permissions, c.history, c.filetracker, c.cfg.WorkingDir()),
501 tools.NewFetchTool(c.permissions, c.cfg.WorkingDir(), nil),
502 tools.NewGlobTool(c.cfg.WorkingDir()),
503 tools.NewGrepTool(c.cfg.WorkingDir(), c.cfg.Config().Tools.Grep),
504 tools.NewLsTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Config().Tools.Ls),
505 tools.NewSourcegraphTool(nil),
506 tools.NewTodosTool(c.sessions),
507 tools.NewViewTool(c.lspManager, c.permissions, c.filetracker, c.skillTracker, c.cfg.WorkingDir(), c.cfg.Config().Options.SkillsPaths...),
508 tools.NewWriteTool(c.lspManager, c.permissions, c.history, c.filetracker, c.cfg.WorkingDir()),
509 )
510
511 // Add LSP tools if user has configured LSPs or auto_lsp is enabled (nil or true).
512 if len(c.cfg.Config().LSP) > 0 || c.cfg.Config().Options.AutoLSP == nil || *c.cfg.Config().Options.AutoLSP {
513 allTools = append(allTools, tools.NewDiagnosticsTool(c.lspManager), tools.NewReferencesTool(c.lspManager), tools.NewLSPRestartTool(c.lspManager))
514 }
515
516 if len(c.cfg.Config().MCP) > 0 {
517 allTools = append(
518 allTools,
519 tools.NewListMCPResourcesTool(c.cfg, c.permissions),
520 tools.NewReadMCPResourceTool(c.cfg, c.permissions),
521 )
522 }
523
524 var filteredTools []fantasy.AgentTool
525 for _, tool := range allTools {
526 if slices.Contains(agent.AllowedTools, tool.Info().Name) {
527 filteredTools = append(filteredTools, tool)
528 }
529 }
530
531 for _, tool := range tools.GetMCPTools(c.permissions, c.cfg, c.cfg.WorkingDir()) {
532 if agent.AllowedMCP == nil {
533 // No MCP restrictions
534 filteredTools = append(filteredTools, tool)
535 continue
536 }
537 if len(agent.AllowedMCP) == 0 {
538 // No MCPs allowed
539 slog.Debug("No MCPs allowed", "tool", tool.Name(), "agent", agent.Name)
540 break
541 }
542
543 for mcp, tools := range agent.AllowedMCP {
544 if mcp != tool.MCP() {
545 continue
546 }
547 if len(tools) == 0 || slices.Contains(tools, tool.MCPToolName()) {
548 filteredTools = append(filteredTools, tool)
549 break
550 }
551 slog.Debug("MCP not allowed", "tool", tool.Name(), "agent", agent.Name)
552 }
553 }
554 slices.SortFunc(filteredTools, func(a, b fantasy.AgentTool) int {
555 return strings.Compare(a.Info().Name, b.Info().Name)
556 })
557
558 // Wrap tools with hook interception for the top-level agent only.
559 // Sub-agents (the `agent` task tool, `agentic_fetch`, etc.) run
560 // without hook interception to avoid firing the user's hook N times
561 // per delegated turn. The top-level invocation of the sub-agent tool
562 // itself is still wrapped from the coder's side.
563 filteredTools = wrapToolsWithHooks(filteredTools, hookRunner, isSubAgent)
564
565 return filteredTools, nil
566}
567
568// TODO: when we support multiple agents we need to change this so that we pass in the agent specific model config
569func (c *coordinator) buildAgentModels(ctx context.Context, isSubAgent bool) (Model, Model, error) {
570 largeModelCfg, ok := c.cfg.Config().Models[config.SelectedModelTypeLarge]
571 if !ok {
572 return Model{}, Model{}, errLargeModelNotSelected
573 }
574 smallModelCfg, ok := c.cfg.Config().Models[config.SelectedModelTypeSmall]
575 if !ok {
576 return Model{}, Model{}, errSmallModelNotSelected
577 }
578
579 largeProviderCfg, ok := c.cfg.Config().Providers.Get(largeModelCfg.Provider)
580 if !ok {
581 return Model{}, Model{}, errLargeModelProviderNotConfigured
582 }
583
584 largeProvider, err := c.buildProvider(largeProviderCfg, largeModelCfg, isSubAgent)
585 if err != nil {
586 return Model{}, Model{}, err
587 }
588
589 smallProviderCfg, ok := c.cfg.Config().Providers.Get(smallModelCfg.Provider)
590 if !ok {
591 return Model{}, Model{}, errSmallModelProviderNotConfigured
592 }
593
594 smallProvider, err := c.buildProvider(smallProviderCfg, smallModelCfg, true)
595 if err != nil {
596 return Model{}, Model{}, err
597 }
598
599 var largeCatwalkModel *catwalk.Model
600 var smallCatwalkModel *catwalk.Model
601
602 for _, m := range largeProviderCfg.Models {
603 if m.ID == largeModelCfg.Model {
604 largeCatwalkModel = &m
605 }
606 }
607 for _, m := range smallProviderCfg.Models {
608 if m.ID == smallModelCfg.Model {
609 smallCatwalkModel = &m
610 }
611 }
612
613 if largeCatwalkModel == nil {
614 return Model{}, Model{}, errLargeModelNotFound
615 }
616
617 if smallCatwalkModel == nil {
618 return Model{}, Model{}, errSmallModelNotFound
619 }
620
621 largeModelID := largeModelCfg.Model
622 smallModelID := smallModelCfg.Model
623
624 if largeModelCfg.Provider == openrouter.Name && isExactoSupported(largeModelID) {
625 largeModelID += ":exacto"
626 }
627
628 if smallModelCfg.Provider == openrouter.Name && isExactoSupported(smallModelID) {
629 smallModelID += ":exacto"
630 }
631
632 largeModel, err := largeProvider.LanguageModel(ctx, largeModelID)
633 if err != nil {
634 return Model{}, Model{}, err
635 }
636 smallModel, err := smallProvider.LanguageModel(ctx, smallModelID)
637 if err != nil {
638 return Model{}, Model{}, err
639 }
640
641 return Model{
642 Model: largeModel,
643 CatwalkCfg: *largeCatwalkModel,
644 ModelCfg: largeModelCfg,
645 }, Model{
646 Model: smallModel,
647 CatwalkCfg: *smallCatwalkModel,
648 ModelCfg: smallModelCfg,
649 }, nil
650}
651
652func (c *coordinator) buildAnthropicProvider(baseURL, apiKey string, headers map[string]string, providerID string) (fantasy.Provider, error) {
653 var opts []anthropic.Option
654
655 switch {
656 case strings.HasPrefix(apiKey, "Bearer "):
657 // NOTE: Prevent the SDK from picking up the API key from env.
658 os.Setenv("ANTHROPIC_API_KEY", "")
659 headers["Authorization"] = apiKey
660 case providerID == string(catwalk.InferenceProviderMiniMax) || providerID == string(catwalk.InferenceProviderMiniMaxChina):
661 // NOTE: Prevent the SDK from picking up the API key from env.
662 os.Setenv("ANTHROPIC_API_KEY", "")
663 headers["Authorization"] = "Bearer " + apiKey
664 case apiKey != "":
665 // X-Api-Key header
666 opts = append(opts, anthropic.WithAPIKey(apiKey))
667 }
668
669 if len(headers) > 0 {
670 opts = append(opts, anthropic.WithHeaders(headers))
671 }
672
673 if baseURL != "" {
674 opts = append(opts, anthropic.WithBaseURL(baseURL))
675 }
676
677 if c.cfg.Config().Options.Debug {
678 httpClient := log.NewHTTPClient()
679 opts = append(opts, anthropic.WithHTTPClient(httpClient))
680 }
681 return anthropic.New(opts...)
682}
683
684func (c *coordinator) buildOpenaiProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
685 opts := []openai.Option{
686 openai.WithAPIKey(apiKey),
687 openai.WithUseResponsesAPI(),
688 }
689 if c.cfg.Config().Options.Debug {
690 httpClient := log.NewHTTPClient()
691 opts = append(opts, openai.WithHTTPClient(httpClient))
692 }
693 if len(headers) > 0 {
694 opts = append(opts, openai.WithHeaders(headers))
695 }
696 if baseURL != "" {
697 opts = append(opts, openai.WithBaseURL(baseURL))
698 }
699 return openai.New(opts...)
700}
701
702func (c *coordinator) buildOpenrouterProvider(_, apiKey string, headers map[string]string) (fantasy.Provider, error) {
703 opts := []openrouter.Option{
704 openrouter.WithAPIKey(apiKey),
705 }
706 if c.cfg.Config().Options.Debug {
707 httpClient := log.NewHTTPClient()
708 opts = append(opts, openrouter.WithHTTPClient(httpClient))
709 }
710 if len(headers) > 0 {
711 opts = append(opts, openrouter.WithHeaders(headers))
712 }
713 return openrouter.New(opts...)
714}
715
716func (c *coordinator) buildVercelProvider(_, apiKey string, headers map[string]string) (fantasy.Provider, error) {
717 opts := []vercel.Option{
718 vercel.WithAPIKey(apiKey),
719 }
720 if c.cfg.Config().Options.Debug {
721 httpClient := log.NewHTTPClient()
722 opts = append(opts, vercel.WithHTTPClient(httpClient))
723 }
724 if len(headers) > 0 {
725 opts = append(opts, vercel.WithHeaders(headers))
726 }
727 return vercel.New(opts...)
728}
729
730func (c *coordinator) buildOpenaiCompatProvider(baseURL, apiKey string, headers map[string]string, extraBody map[string]any, providerID string, isSubAgent bool) (fantasy.Provider, error) {
731 opts := []openaicompat.Option{
732 openaicompat.WithBaseURL(baseURL),
733 openaicompat.WithAPIKey(apiKey),
734 }
735
736 // Set HTTP client based on provider and debug mode.
737 var httpClient *http.Client
738 if providerID == string(catwalk.InferenceProviderCopilot) {
739 opts = append(opts, openaicompat.WithUseResponsesAPI())
740 httpClient = copilot.NewClient(isSubAgent, c.cfg.Config().Options.Debug)
741 } else if c.cfg.Config().Options.Debug {
742 httpClient = log.NewHTTPClient()
743 }
744 if httpClient != nil {
745 opts = append(opts, openaicompat.WithHTTPClient(httpClient))
746 }
747
748 if len(headers) > 0 {
749 opts = append(opts, openaicompat.WithHeaders(headers))
750 }
751
752 for extraKey, extraValue := range extraBody {
753 opts = append(opts, openaicompat.WithSDKOptions(openaisdk.WithJSONSet(extraKey, extraValue)))
754 }
755
756 return openaicompat.New(opts...)
757}
758
759func (c *coordinator) buildAzureProvider(baseURL, apiKey string, headers map[string]string, options map[string]string) (fantasy.Provider, error) {
760 opts := []azure.Option{
761 azure.WithBaseURL(baseURL),
762 azure.WithAPIKey(apiKey),
763 azure.WithUseResponsesAPI(),
764 }
765 if c.cfg.Config().Options.Debug {
766 httpClient := log.NewHTTPClient()
767 opts = append(opts, azure.WithHTTPClient(httpClient))
768 }
769 if options == nil {
770 options = make(map[string]string)
771 }
772 if apiVersion, ok := options["apiVersion"]; ok {
773 opts = append(opts, azure.WithAPIVersion(apiVersion))
774 }
775 if len(headers) > 0 {
776 opts = append(opts, azure.WithHeaders(headers))
777 }
778
779 return azure.New(opts...)
780}
781
782func (c *coordinator) buildBedrockProvider(apiKey string, headers map[string]string) (fantasy.Provider, error) {
783 var opts []bedrock.Option
784 if c.cfg.Config().Options.Debug {
785 httpClient := log.NewHTTPClient()
786 opts = append(opts, bedrock.WithHTTPClient(httpClient))
787 }
788 if len(headers) > 0 {
789 opts = append(opts, bedrock.WithHeaders(headers))
790 }
791 switch {
792 case apiKey != "":
793 opts = append(opts, bedrock.WithAPIKey(apiKey))
794 case os.Getenv("AWS_BEARER_TOKEN_BEDROCK") != "":
795 opts = append(opts, bedrock.WithAPIKey(os.Getenv("AWS_BEARER_TOKEN_BEDROCK")))
796 default:
797 // Skip, let the SDK do authentication.
798 }
799 return bedrock.New(opts...)
800}
801
802func (c *coordinator) buildGoogleProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
803 opts := []google.Option{
804 google.WithBaseURL(baseURL),
805 google.WithGeminiAPIKey(apiKey),
806 }
807 if c.cfg.Config().Options.Debug {
808 httpClient := log.NewHTTPClient()
809 opts = append(opts, google.WithHTTPClient(httpClient))
810 }
811 if len(headers) > 0 {
812 opts = append(opts, google.WithHeaders(headers))
813 }
814 return google.New(opts...)
815}
816
817func (c *coordinator) buildGoogleVertexProvider(headers map[string]string, options map[string]string) (fantasy.Provider, error) {
818 opts := []google.Option{}
819 if c.cfg.Config().Options.Debug {
820 httpClient := log.NewHTTPClient()
821 opts = append(opts, google.WithHTTPClient(httpClient))
822 }
823 if len(headers) > 0 {
824 opts = append(opts, google.WithHeaders(headers))
825 }
826
827 project := options["project"]
828 location := options["location"]
829
830 opts = append(opts, google.WithVertex(project, location))
831
832 return google.New(opts...)
833}
834
835func (c *coordinator) isAnthropicThinking(model config.SelectedModel) bool {
836 if model.Think {
837 return true
838 }
839 opts, err := anthropic.ParseOptions(model.ProviderOptions)
840 return err == nil && opts.Thinking != nil
841}
842
843func (c *coordinator) buildProvider(providerCfg config.ProviderConfig, model config.SelectedModel, isSubAgent bool) (fantasy.Provider, error) {
844 headers := maps.Clone(providerCfg.ExtraHeaders)
845 if headers == nil {
846 headers = make(map[string]string)
847 }
848
849 // handle special headers for anthropic
850 if providerCfg.Type == anthropic.Name && c.isAnthropicThinking(model) {
851 if v, ok := headers["anthropic-beta"]; ok {
852 headers["anthropic-beta"] = v + ",interleaved-thinking-2025-05-14"
853 } else {
854 headers["anthropic-beta"] = "interleaved-thinking-2025-05-14"
855 }
856 }
857
858 apiKey, _ := c.cfg.Resolve(providerCfg.APIKey)
859 baseURL, _ := c.cfg.Resolve(providerCfg.BaseURL)
860
861 switch providerCfg.Type {
862 case openai.Name:
863 return c.buildOpenaiProvider(baseURL, apiKey, headers)
864 case anthropic.Name:
865 return c.buildAnthropicProvider(baseURL, apiKey, headers, providerCfg.ID)
866 case openrouter.Name:
867 return c.buildOpenrouterProvider(baseURL, apiKey, headers)
868 case vercel.Name:
869 return c.buildVercelProvider(baseURL, apiKey, headers)
870 case azure.Name:
871 return c.buildAzureProvider(baseURL, apiKey, headers, providerCfg.ExtraParams)
872 case bedrock.Name:
873 return c.buildBedrockProvider(apiKey, headers)
874 case google.Name:
875 return c.buildGoogleProvider(baseURL, apiKey, headers)
876 case "google-vertex":
877 return c.buildGoogleVertexProvider(headers, providerCfg.ExtraParams)
878 case openaicompat.Name, hyper.Name:
879 switch providerCfg.ID {
880 case hyper.Name:
881 baseURL = hyper.BaseURL() + "/v1"
882 headers["x-crush-id"] = event.GetID()
883 case string(catwalk.InferenceProviderZAI):
884 if providerCfg.ExtraBody == nil {
885 providerCfg.ExtraBody = map[string]any{}
886 }
887 providerCfg.ExtraBody["tool_stream"] = true
888 }
889 return c.buildOpenaiCompatProvider(baseURL, apiKey, headers, providerCfg.ExtraBody, providerCfg.ID, isSubAgent)
890 default:
891 return nil, fmt.Errorf("provider type not supported: %q", providerCfg.Type)
892 }
893}
894
// isExactoSupported reports whether modelID is one of the OpenRouter models
// for which the ":exacto" variant is requested. The comparison is exact and
// case-sensitive.
func isExactoSupported(modelID string) bool {
	exactoModels := [...]string{
		"moonshotai/kimi-k2-0905",
		"deepseek/deepseek-v3.1-terminus",
		"z-ai/glm-4.6",
		"openai/gpt-oss-120b",
		"qwen/qwen3-coder",
	}
	return slices.Index(exactoModels[:], modelID) >= 0
}
905
906func (c *coordinator) Cancel(sessionID string) {
907 c.currentAgent.Cancel(sessionID)
908}
909
910func (c *coordinator) CancelAll() {
911 c.currentAgent.CancelAll()
912}
913
914func (c *coordinator) ClearQueue(sessionID string) {
915 c.currentAgent.ClearQueue(sessionID)
916}
917
918func (c *coordinator) IsBusy() bool {
919 return c.currentAgent.IsBusy()
920}
921
922func (c *coordinator) IsSessionBusy(sessionID string) bool {
923 return c.currentAgent.IsSessionBusy(sessionID)
924}
925
926func (c *coordinator) Model() Model {
927 return c.currentAgent.Model()
928}
929
930func (c *coordinator) UpdateModels(ctx context.Context) error {
931 // build the models again so we make sure we get the latest config
932 large, small, err := c.buildAgentModels(ctx, false)
933 if err != nil {
934 return err
935 }
936 c.currentAgent.SetModels(large, small)
937
938 agentCfg, ok := c.cfg.Config().Agents[config.AgentCoder]
939 if !ok {
940 return errCoderAgentNotConfigured
941 }
942
943 tools, err := c.buildTools(ctx, agentCfg, false)
944 if err != nil {
945 return err
946 }
947 c.currentAgent.SetTools(tools)
948 return nil
949}
950
// QueuedPrompts returns the current agent's QueuedPrompts count for
// sessionID.
func (c *coordinator) QueuedPrompts(sessionID string) int {
	return c.currentAgent.QueuedPrompts(sessionID)
}

// QueuedPromptsList returns the current agent's QueuedPromptsList for
// sessionID.
func (c *coordinator) QueuedPromptsList(sessionID string) []string {
	return c.currentAgent.QueuedPromptsList(sessionID)
}
958
959func (c *coordinator) Summarize(ctx context.Context, sessionID string) error {
960 providerCfg, ok := c.cfg.Config().Providers.Get(c.currentAgent.Model().ModelCfg.Provider)
961 if !ok {
962 return errModelProviderNotConfigured
963 }
964 return c.currentAgent.Summarize(ctx, sessionID, getProviderOptions(c.currentAgent.Model(), providerCfg))
965}
966
967func (c *coordinator) isUnauthorized(err error) bool {
968 var providerErr *fantasy.ProviderError
969 return errors.As(err, &providerErr) && providerErr.StatusCode == http.StatusUnauthorized
970}
971
972func (c *coordinator) refreshOAuth2Token(ctx context.Context, providerCfg config.ProviderConfig) error {
973 if err := c.cfg.RefreshOAuthToken(ctx, config.ScopeGlobal, providerCfg.ID); err != nil {
974 slog.Error("Failed to refresh OAuth token after 401 error", "provider", providerCfg.ID, "error", err)
975 return err
976 }
977 if err := c.UpdateModels(ctx); err != nil {
978 return err
979 }
980 return nil
981}
982
983func (c *coordinator) refreshApiKeyTemplate(ctx context.Context, providerCfg config.ProviderConfig) error {
984 newAPIKey, err := c.cfg.Resolve(providerCfg.APIKeyTemplate)
985 if err != nil {
986 slog.Error("Failed to re-resolve API key after 401 error", "provider", providerCfg.ID, "error", err)
987 return err
988 }
989
990 providerCfg.APIKey = newAPIKey
991 c.cfg.Config().Providers.Set(providerCfg.ID, providerCfg)
992
993 if err := c.UpdateModels(ctx); err != nil {
994 return err
995 }
996 return nil
997}
998
// subAgentParams holds the parameters for running a sub-agent via
// runSubAgent.
type subAgentParams struct {
	// Agent is the session agent that is run inside the new sub-session.
	Agent SessionAgent
	// SessionID is the parent session: the sub-session is created under it
	// and its cost is increased by the sub-session's cost.
	SessionID string
	// AgentMessageID and ToolCallID are combined by
	// sessions.CreateAgentToolSessionID to derive the sub-session ID.
	AgentMessageID string
	ToolCallID     string
	// Prompt is the prompt passed to the sub-agent's Run.
	Prompt string
	// SessionTitle is passed to CreateTaskSession when creating the
	// sub-session (presumably used as its title).
	SessionTitle string
	// SessionSetup is an optional callback invoked after session creation
	// but before agent execution, for custom session configuration.
	SessionSetup func(sessionID string)
}
1011
1012// runSubAgent runs a sub-agent and handles session management and cost accumulation.
1013// It creates a sub-session, runs the agent with the given prompt, and propagates
1014// the cost to the parent session.
1015func (c *coordinator) runSubAgent(ctx context.Context, params subAgentParams) (fantasy.ToolResponse, error) {
1016 // Create sub-session
1017 agentToolSessionID := c.sessions.CreateAgentToolSessionID(params.AgentMessageID, params.ToolCallID)
1018 session, err := c.sessions.CreateTaskSession(ctx, agentToolSessionID, params.SessionID, params.SessionTitle)
1019 if err != nil {
1020 return fantasy.ToolResponse{}, fmt.Errorf("create session: %w", err)
1021 }
1022
1023 // Call session setup function if provided
1024 if params.SessionSetup != nil {
1025 params.SessionSetup(session.ID)
1026 }
1027
1028 // Get model configuration
1029 model := params.Agent.Model()
1030 maxTokens := model.CatwalkCfg.DefaultMaxTokens
1031 if model.ModelCfg.MaxTokens != 0 {
1032 maxTokens = model.ModelCfg.MaxTokens
1033 }
1034
1035 providerCfg, ok := c.cfg.Config().Providers.Get(model.ModelCfg.Provider)
1036 if !ok {
1037 return fantasy.ToolResponse{}, errModelProviderNotConfigured
1038 }
1039
1040 // Run the agent
1041 result, err := params.Agent.Run(ctx, SessionAgentCall{
1042 SessionID: session.ID,
1043 Prompt: params.Prompt,
1044 MaxOutputTokens: maxTokens,
1045 ProviderOptions: getProviderOptions(model, providerCfg),
1046 Temperature: model.ModelCfg.Temperature,
1047 TopP: model.ModelCfg.TopP,
1048 TopK: model.ModelCfg.TopK,
1049 FrequencyPenalty: model.ModelCfg.FrequencyPenalty,
1050 PresencePenalty: model.ModelCfg.PresencePenalty,
1051 NonInteractive: true,
1052 })
1053 if err != nil {
1054 return fantasy.NewTextErrorResponse("error generating response"), nil
1055 }
1056
1057 // Update parent session cost
1058 if err := c.updateParentSessionCost(ctx, session.ID, params.SessionID); err != nil {
1059 return fantasy.ToolResponse{}, err
1060 }
1061
1062 return fantasy.NewTextResponse(result.Response.Content.Text()), nil
1063}
1064
1065// updateParentSessionCost accumulates the cost from a child session to its parent session.
1066func (c *coordinator) updateParentSessionCost(ctx context.Context, childSessionID, parentSessionID string) error {
1067 childSession, err := c.sessions.Get(ctx, childSessionID)
1068 if err != nil {
1069 return fmt.Errorf("get child session: %w", err)
1070 }
1071
1072 parentSession, err := c.sessions.Get(ctx, parentSessionID)
1073 if err != nil {
1074 return fmt.Errorf("get parent session: %w", err)
1075 }
1076
1077 parentSession.Cost += childSession.Cost
1078
1079 if _, err := c.sessions.Save(ctx, parentSession); err != nil {
1080 return fmt.Errorf("save parent session: %w", err)
1081 }
1082
1083 return nil
1084}
1085
1086// discoverSkills runs the skill discovery pipeline and returns both the
1087// pre-filter (all discovered, after dedup) and post-filter (active) lists.
1088// It also emits a single diagnostic log line summarising the outcome to
1089// help track skill-loading health over time.
1090func discoverSkills(cfg *config.ConfigStore) (allSkills, activeSkills []*skills.Skill) {
1091 builtin, builtinStates := skills.DiscoverBuiltinWithStates()
1092 discovered := append([]*skills.Skill(nil), builtin...)
1093
1094 var userStates []*skills.SkillState
1095 var userPaths []string
1096
1097 opts := cfg.Config().Options
1098 if opts != nil && len(opts.SkillsPaths) > 0 {
1099 userPaths = make([]string, 0, len(opts.SkillsPaths))
1100 for _, pth := range opts.SkillsPaths {
1101 expanded := home.Long(pth)
1102 if strings.HasPrefix(expanded, "$") {
1103 if resolved, err := cfg.Resolver().ResolveValue(expanded); err == nil {
1104 expanded = resolved
1105 }
1106 }
1107 userPaths = append(userPaths, expanded)
1108 }
1109 var userSkills []*skills.Skill
1110 userSkills, userStates = skills.DiscoverWithStates(userPaths)
1111 discovered = append(discovered, userSkills...)
1112 }
1113
1114 allSkills = skills.Deduplicate(discovered)
1115 var disabledSkills []string
1116 if opts != nil {
1117 disabledSkills = opts.DisabledSkills
1118 }
1119 activeSkills = skills.Filter(allSkills, disabledSkills)
1120
1121 logDiscoveryStats(builtin, builtinStates, userStates, userPaths, allSkills, activeSkills, disabledSkills)
1122 return allSkills, activeSkills
1123}
1124
1125// logTurnSkillUsage emits a per-turn diagnostic line showing which skills
1126// (if any) were loaded during this turn and which looked relevant based on
1127// a cheap keyword match against the user prompt. The goal is to surface
1128// "should-have-loaded but didn't" situations for later analysis.
1129//
1130// Logged at Info level under component=skills; heavy fields are elided when
1131// there is nothing interesting to report.
1132func logTurnSkillUsage(
1133 sessionID string,
1134 prompt string,
1135 activeSkills []*skills.Skill,
1136 tracker *skills.Tracker,
1137 before []string,
1138) {
1139 if tracker == nil || len(activeSkills) == 0 {
1140 return
1141 }
1142
1143 after := tracker.LoadedNames()
1144
1145 beforeSet := make(map[string]bool, len(before))
1146 for _, n := range before {
1147 beforeSet[n] = true
1148 }
1149 var loadedThisTurn []string
1150 for _, n := range after {
1151 if !beforeSet[n] {
1152 loadedThisTurn = append(loadedThisTurn, n)
1153 }
1154 }
1155
1156 slog.Info("Skill turn summary",
1157 "component", "skills",
1158 "session_id", sessionID,
1159 "prompt_len", len(prompt),
1160 "active_total", len(activeSkills),
1161 "loaded_total", len(after),
1162 "loaded_this_turn", loadedThisTurn,
1163 )
1164}
1165
1166// logDiscoveryStats emits a single structured log line summarising skill
1167// discovery for the current session. It is intentionally low-volume: one
1168// line per session start.
1169func logDiscoveryStats(
1170 builtin []*skills.Skill,
1171 builtinStates, userStates []*skills.SkillState,
1172 userPaths []string,
1173 allSkills, activeSkills []*skills.Skill,
1174 disabled []string,
1175) {
1176 countErrors := func(states []*skills.SkillState) int {
1177 n := 0
1178 for _, s := range states {
1179 if s.State == skills.StateError {
1180 n++
1181 }
1182 }
1183 return n
1184 }
1185
1186 userOK := 0
1187 for _, s := range userStates {
1188 if s.State == skills.StateNormal {
1189 userOK++
1190 }
1191 }
1192
1193 activeNames := make([]string, 0, len(activeSkills))
1194 for _, s := range activeSkills {
1195 activeNames = append(activeNames, s.Name)
1196 }
1197
1198 xml := skills.ToPromptXML(activeSkills)
1199
1200 slog.Info("Skill discovery complete",
1201 "component", "skills",
1202 "builtin_ok", len(builtin),
1203 "builtin_errors", countErrors(builtinStates),
1204 "user_ok", userOK,
1205 "user_errors", countErrors(userStates),
1206 "user_paths", len(userPaths),
1207 "deduped_total", len(allSkills),
1208 "active", len(activeSkills),
1209 "disabled", len(disabled),
1210 "prompt_bytes", len(xml),
1211 "prompt_tok_est", skills.ApproxTokenCount(xml),
1212 "active_names", activeNames,
1213 )
1214}