coordinator.go

  1package agent
  2
  3import (
  4	"context"
  5	"errors"
  6	"slices"
  7	"strings"
  8
  9	"github.com/charmbracelet/catwalk/pkg/catwalk"
 10	"github.com/charmbracelet/crush/internal/agent/prompt"
 11	"github.com/charmbracelet/crush/internal/agent/tools"
 12	"github.com/charmbracelet/crush/internal/config"
 13	"github.com/charmbracelet/crush/internal/csync"
 14	"github.com/charmbracelet/crush/internal/history"
 15	"github.com/charmbracelet/crush/internal/log"
 16	"github.com/charmbracelet/crush/internal/lsp"
 17	"github.com/charmbracelet/crush/internal/message"
 18	"github.com/charmbracelet/crush/internal/permission"
 19	"github.com/charmbracelet/crush/internal/session"
 20	"github.com/charmbracelet/fantasy/ai"
 21	"github.com/charmbracelet/fantasy/anthropic"
 22	"github.com/charmbracelet/fantasy/google"
 23	"github.com/charmbracelet/fantasy/openai"
 24	"github.com/charmbracelet/fantasy/openaicompat"
 25	"github.com/charmbracelet/fantasy/openrouter"
 26)
 27
 28type Coordinator interface {
 29	// INFO: (kujtim) this is not used yet we will use this when we have multiple agents
 30	// SetMainAgent(string)
 31	Run(ctx context.Context, sessionID, prompt string, attachments ...message.Attachment) (*ai.AgentResult, error)
 32	Cancel(sessionID string)
 33	CancelAll()
 34	IsSessionBusy(sessionID string) bool
 35	IsBusy() bool
 36	QueuedPrompts(sessionID string) int
 37	ClearQueue(sessionID string)
 38	Summarize(context.Context, string) error
 39	Model() Model
 40	UpdateModels() error
 41}
 42
// coordinator is the Coordinator implementation in this package. It keeps
// the configuration and services needed to build agents (models, system
// prompt and tools) and routes every Coordinator call to currentAgent.
type coordinator struct {
	// cfg supplies the agent, model, provider and tool configuration.
	cfg *config.Config
	// sessions and messages are handed to NewSessionAgent.
	sessions session.Service
	messages message.Service
	// permissions, history and lspClients are handed to the built-in tools
	// (and permissions also to the MCP tools).
	permissions permission.Service
	history     history.Service
	lspClients  *csync.Map[string, *lsp.Client]

	// currentAgent receives every forwarded Coordinator call.
	currentAgent SessionAgent
	// agents holds built agents keyed by agent name; NewCoordinator registers
	// the coder agent under config.AgentCoder.
	agents map[string]SessionAgent
}
 54
 55func NewCoordinator(
 56	cfg *config.Config,
 57	sessions session.Service,
 58	messages message.Service,
 59	permissions permission.Service,
 60	history history.Service,
 61	lspClients *csync.Map[string, *lsp.Client],
 62) (Coordinator, error) {
 63	c := &coordinator{
 64		cfg:         cfg,
 65		sessions:    sessions,
 66		messages:    messages,
 67		permissions: permissions,
 68		history:     history,
 69		lspClients:  lspClients,
 70		agents:      make(map[string]SessionAgent),
 71	}
 72
 73	agentCfg, ok := cfg.Agents[config.AgentCoder]
 74	if !ok {
 75		return nil, errors.New("coder agent not configured")
 76	}
 77
 78	// TODO: make this dynamic when we support multiple agents
 79	prompt, err := coderPrompt(prompt.WithWorkingDir(c.cfg.WorkingDir()))
 80	if err != nil {
 81		return nil, err
 82	}
 83
 84	agent, err := c.buildAgent(prompt, agentCfg)
 85	if err != nil {
 86		return nil, err
 87	}
 88	c.currentAgent = agent
 89	c.agents[config.AgentCoder] = agent
 90	return c, nil
 91}
 92
 93// Run implements Coordinator.
 94func (c *coordinator) Run(ctx context.Context, sessionID string, prompt string, attachments ...message.Attachment) (*ai.AgentResult, error) {
 95	model := c.currentAgent.Model()
 96	maxTokens := model.CatwalkCfg.DefaultMaxTokens
 97	if model.ModelCfg.MaxTokens != 0 {
 98		maxTokens = model.ModelCfg.MaxTokens
 99	}
100
101	if !model.CatwalkCfg.SupportsImages && attachments != nil {
102		attachments = nil
103	}
104
105	return c.currentAgent.Run(ctx, SessionAgentCall{
106		SessionID:        sessionID,
107		Prompt:           prompt,
108		Attachments:      attachments,
109		MaxOutputTokens:  maxTokens,
110		ProviderOptions:  c.getProviderOptions(model),
111		Temperature:      model.ModelCfg.Temperature,
112		TopP:             model.ModelCfg.TopP,
113		TopK:             model.ModelCfg.TopK,
114		FrequencyPenalty: model.ModelCfg.FrequencyPenalty,
115		PresencePenalty:  model.ModelCfg.PresencePenalty,
116	})
117}
118
119func (c *coordinator) getProviderOptions(model Model) ai.ProviderOptions {
120	options := ai.ProviderOptions{}
121
122	switch model.Model.Provider() {
123	case openai.Name:
124		parsed, err := openai.ParseOptions(model.ModelCfg.ProviderOptions)
125		if err == nil {
126			options[openai.Name] = parsed
127		}
128	case anthropic.Name:
129		parsed, err := anthropic.ParseOptions(model.ModelCfg.ProviderOptions)
130		if err == nil {
131			options[anthropic.Name] = parsed
132		}
133	case openrouter.Name:
134		parsed, err := openrouter.ParseOptions(model.ModelCfg.ProviderOptions)
135		if err == nil {
136			options[openrouter.Name] = parsed
137		}
138	case google.Name:
139		parsed, err := google.ParseOptions(model.ModelCfg.ProviderOptions)
140		if err == nil {
141			options[google.Name] = parsed
142		}
143	case openaicompat.Name:
144		parsed, err := openaicompat.ParseOptions(model.ModelCfg.ProviderOptions)
145		if err == nil {
146			options[openaicompat.Name] = parsed
147		}
148	}
149
150	return options
151}
152
153func (c *coordinator) buildAgent(prompt *prompt.Prompt, agent config.Agent) (SessionAgent, error) {
154	large, small, err := c.buildAgentModels()
155	if err != nil {
156		return nil, err
157	}
158
159	systemPrompt, err := prompt.Build(large.Model.Provider(), large.Model.Model(), *c.cfg)
160	if err != nil {
161		return nil, err
162	}
163
164	tools, err := c.buildTools(agent)
165	if err != nil {
166		return nil, err
167	}
168	return NewSessionAgent(large, small, systemPrompt, c.sessions, c.messages, tools...), nil
169}
170
171func (c *coordinator) buildTools(agent config.Agent) ([]ai.AgentTool, error) {
172	var allTools []ai.AgentTool
173	if slices.Contains(agent.AllowedTools, AgentToolName) {
174		agentTool, err := c.agentTool()
175		if err != nil {
176			return nil, err
177		}
178		allTools = append(allTools, agentTool)
179	}
180
181	allTools = append(allTools,
182		tools.NewBashTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Options.Attribution),
183		tools.NewDownloadTool(c.permissions, c.cfg.WorkingDir(), nil),
184		tools.NewEditTool(c.lspClients, c.permissions, c.history, c.cfg.WorkingDir()),
185		tools.NewMultiEditTool(c.lspClients, c.permissions, c.history, c.cfg.WorkingDir()),
186		tools.NewFetchTool(c.permissions, c.cfg.WorkingDir(), nil),
187		tools.NewGlobTool(c.cfg.WorkingDir()),
188		tools.NewGrepTool(c.cfg.WorkingDir()),
189		tools.NewLsTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Tools.Ls),
190		tools.NewSourcegraphTool(nil),
191		tools.NewViewTool(c.lspClients, c.permissions, c.cfg.WorkingDir()),
192		tools.NewWriteTool(c.lspClients, c.permissions, c.history, c.cfg.WorkingDir()),
193	)
194
195	var filteredTools []ai.AgentTool
196	for _, tool := range allTools {
197		if slices.Contains(agent.AllowedTools, tool.Info().Name) {
198			filteredTools = append(filteredTools, tool)
199		}
200	}
201
202	mcpTools := tools.GetMCPTools(context.Background(), c.permissions, c.cfg)
203
204	for _, mcpTool := range mcpTools {
205		if agent.AllowedMCP == nil {
206			// No MCP restrictions
207			filteredTools = append(filteredTools, mcpTool)
208		} else if len(agent.AllowedMCP) == 0 {
209			// no mcps allowed
210			break
211		}
212
213		for mcp, tools := range agent.AllowedMCP {
214			if mcp == mcpTool.MCP() {
215				if len(tools) == 0 {
216					filteredTools = append(filteredTools, mcpTool)
217				}
218				for _, t := range tools {
219					if t == mcpTool.MCPToolName() {
220						filteredTools = append(filteredTools, mcpTool)
221					}
222				}
223				break
224			}
225		}
226	}
227
228	return filteredTools, nil
229}
230
231// TODO: when we support multiple agents we need to change this so that we pass in the agent specific model config
232func (c *coordinator) buildAgentModels() (Model, Model, error) {
233	largeModelCfg, ok := c.cfg.Models[config.SelectedModelTypeLarge]
234	if !ok {
235		return Model{}, Model{}, errors.New("large model not selected")
236	}
237	smallModelCfg, ok := c.cfg.Models[config.SelectedModelTypeSmall]
238	if !ok {
239		return Model{}, Model{}, errors.New("small model not selected")
240	}
241
242	largeProviderCfg, ok := c.cfg.Providers.Get(largeModelCfg.Provider)
243	if !ok {
244		return Model{}, Model{}, errors.New("large model provider not configured")
245	}
246
247	largeProvider, err := c.buildProvider(largeProviderCfg, largeModelCfg)
248	if err != nil {
249		return Model{}, Model{}, err
250	}
251
252	smallProviderCfg, ok := c.cfg.Providers.Get(smallModelCfg.Provider)
253	if !ok {
254		return Model{}, Model{}, errors.New("large model provider not configured")
255	}
256
257	smallProvider, err := c.buildProvider(smallProviderCfg, largeModelCfg)
258	if err != nil {
259		return Model{}, Model{}, err
260	}
261
262	var largeCatwalkModel *catwalk.Model
263	var smallCatwalkModel *catwalk.Model
264
265	for _, m := range largeProviderCfg.Models {
266		if m.ID == largeModelCfg.Model {
267			largeCatwalkModel = &m
268		}
269	}
270	for _, m := range smallProviderCfg.Models {
271		if m.ID == smallModelCfg.Model {
272			smallCatwalkModel = &m
273		}
274	}
275
276	if largeCatwalkModel == nil {
277		return Model{}, Model{}, errors.New("large model not found in provider config")
278	}
279
280	if smallCatwalkModel == nil {
281		return Model{}, Model{}, errors.New("snall model not found in provider config")
282	}
283
284	largeModel, err := largeProvider.LanguageModel(largeModelCfg.Model)
285	if err != nil {
286		return Model{}, Model{}, err
287	}
288	smallModel, err := smallProvider.LanguageModel(smallModelCfg.Model)
289	if err != nil {
290		return Model{}, Model{}, err
291	}
292
293	return Model{
294			Model:      largeModel,
295			CatwalkCfg: *largeCatwalkModel,
296			ModelCfg:   largeModelCfg,
297		}, Model{
298			Model:      smallModel,
299			CatwalkCfg: *smallCatwalkModel,
300			ModelCfg:   smallModelCfg,
301		}, nil
302}
303
304func (c *coordinator) buildAnthropicProvider(baseURL, apiKey string, headers map[string]string) ai.Provider {
305	hasBearerAuth := false
306	for key := range headers {
307		if strings.ToLower(key) == "authorization" {
308			hasBearerAuth = true
309			break
310		}
311	}
312	if hasBearerAuth {
313		apiKey = "" // clear apiKey to avoid using X-Api-Key header
314	}
315
316	var opts []anthropic.Option
317
318	if apiKey != "" {
319		// Use standard X-Api-Key header
320		opts = append(opts, anthropic.WithAPIKey(apiKey))
321	}
322
323	if len(headers) > 0 {
324		opts = append(opts, anthropic.WithHeaders(headers))
325	}
326
327	if baseURL != "" {
328		opts = append(opts, anthropic.WithBaseURL(baseURL))
329	}
330
331	if c.cfg.Options.Debug {
332		httpClient := log.NewHTTPClient()
333		opts = append(opts, anthropic.WithHTTPClient(httpClient))
334	}
335
336	return anthropic.New(opts...)
337}
338
339func (c *coordinator) buildOpenaiProvider(baseURL, apiKey string, headers map[string]string) ai.Provider {
340	opts := []openai.Option{
341		openai.WithAPIKey(apiKey),
342	}
343	if c.cfg.Options.Debug {
344		httpClient := log.NewHTTPClient()
345		opts = append(opts, openai.WithHTTPClient(httpClient))
346	}
347	if len(headers) > 0 {
348		opts = append(opts, openai.WithHeaders(headers))
349	}
350	if baseURL != "" {
351		opts = append(opts, openai.WithBaseURL(baseURL))
352	}
353	return openai.New(opts...)
354}
355
356func (c *coordinator) buildOpenrouterProvider(_, apiKey string, headers map[string]string) ai.Provider {
357	opts := []openrouter.Option{
358		openrouter.WithAPIKey(apiKey),
359	}
360	if c.cfg.Options.Debug {
361		httpClient := log.NewHTTPClient()
362		opts = append(opts, openrouter.WithHTTPClient(httpClient))
363	}
364	if len(headers) > 0 {
365		opts = append(opts, openrouter.WithHeaders(headers))
366	}
367	return openrouter.New(opts...)
368}
369
370func (c *coordinator) buildOpenaiCompatProvider(baseURL, apiKey string, headers map[string]string) ai.Provider {
371	opts := []openaicompat.Option{
372		openaicompat.WithAPIKey(apiKey),
373	}
374	if c.cfg.Options.Debug {
375		httpClient := log.NewHTTPClient()
376		opts = append(opts, openaicompat.WithHTTPClient(httpClient))
377	}
378	if len(headers) > 0 {
379		opts = append(opts, openaicompat.WithHeaders(headers))
380	}
381
382	return openaicompat.New(baseURL, opts...)
383}
384
385// TODO: add baseURL for google
386func (c *coordinator) buildGoogleProvider(baseURL, apiKey string, headers map[string]string) ai.Provider {
387	opts := []google.Option{
388		google.WithAPIKey(apiKey),
389	}
390	if c.cfg.Options.Debug {
391		httpClient := log.NewHTTPClient()
392		opts = append(opts, google.WithHTTPClient(httpClient))
393	}
394	if len(headers) > 0 {
395		opts = append(opts, google.WithHeaders(headers))
396	}
397	return google.New(opts...)
398}
399
400func (c *coordinator) isAnthropicThinking(model config.SelectedModel) bool {
401	if model.Think {
402		return true
403	}
404
405	if model.ProviderOptions == nil {
406		return false
407	}
408
409	opts, err := anthropic.ParseOptions(model.ProviderOptions)
410	if err != nil {
411		return false
412	}
413	if opts.Thinking != nil {
414		return true
415	}
416	return false
417}
418
419func (c *coordinator) buildProvider(providerCfg config.ProviderConfig, model config.SelectedModel) (ai.Provider, error) {
420	headers := providerCfg.ExtraHeaders
421
422	// handle special headers for anthropic
423	if providerCfg.Type == anthropic.Name && c.isAnthropicThinking(model) {
424		headers["anthropic-beta"] = "interleaved-thinking-2025-05-14"
425	}
426
427	// TODO: make sure we have
428	apiKey, _ := c.cfg.Resolve(providerCfg.APIKey)
429	baseURL, _ := c.cfg.Resolve(providerCfg.BaseURL)
430	var provider ai.Provider
431	switch providerCfg.Type {
432	case openai.Name:
433		provider = c.buildOpenaiProvider(baseURL, apiKey, headers)
434	case anthropic.Name:
435		provider = c.buildAnthropicProvider(baseURL, apiKey, headers)
436	case openrouter.Name:
437		provider = c.buildOpenrouterProvider(baseURL, apiKey, headers)
438	case google.Name:
439		provider = c.buildGoogleProvider(baseURL, apiKey, headers)
440	case openaicompat.Name:
441		provider = c.buildOpenaiCompatProvider(baseURL, apiKey, headers)
442	default:
443		return nil, errors.New("provider type not supported")
444	}
445	return provider, nil
446}
447
// Cancel implements Coordinator by forwarding sessionID to the current
// agent's Cancel.
func (c *coordinator) Cancel(sessionID string) {
	c.currentAgent.Cancel(sessionID)
}
451
// CancelAll implements Coordinator by delegating to the current agent's
// CancelAll.
func (c *coordinator) CancelAll() {
	c.currentAgent.CancelAll()
}
455
// ClearQueue implements Coordinator by forwarding sessionID to the current
// agent's ClearQueue.
func (c *coordinator) ClearQueue(sessionID string) {
	c.currentAgent.ClearQueue(sessionID)
}
459
// IsBusy implements Coordinator by returning the current agent's IsBusy.
func (c *coordinator) IsBusy() bool {
	return c.currentAgent.IsBusy()
}
463
// IsSessionBusy implements Coordinator by returning the current agent's
// IsSessionBusy for sessionID.
func (c *coordinator) IsSessionBusy(sessionID string) bool {
	return c.currentAgent.IsSessionBusy(sessionID)
}
467
// Model implements Coordinator by returning the current agent's Model.
func (c *coordinator) Model() Model {
	return c.currentAgent.Model()
}
471
472func (c *coordinator) UpdateModels() error {
473	// build the models again so we make sure we get the latest config
474	large, small, err := c.buildAgentModels()
475	if err != nil {
476		return err
477	}
478	c.currentAgent.SetModels(large, small)
479
480	agentCfg, ok := c.cfg.Agents[config.AgentCoder]
481	if !ok {
482		return errors.New("coder agent not configured")
483	}
484
485	tools, err := c.buildTools(agentCfg)
486	if err != nil {
487		return err
488	}
489	c.currentAgent.SetTools(tools)
490	return nil
491}
492
// QueuedPrompts implements Coordinator by returning the current agent's
// QueuedPrompts for sessionID.
func (c *coordinator) QueuedPrompts(sessionID string) int {
	return c.currentAgent.QueuedPrompts(sessionID)
}
496
// Summarize implements Coordinator by forwarding ctx and sessionID to the
// current agent's Summarize and returning its error.
func (c *coordinator) Summarize(ctx context.Context, sessionID string) error {
	return c.currentAgent.Summarize(ctx, sessionID)
}