coordinator.go

  1package agent
  2
  3import (
  4	"context"
  5	"errors"
  6	"slices"
  7	"strings"
  8
  9	"github.com/charmbracelet/catwalk/pkg/catwalk"
 10	"github.com/charmbracelet/crush/internal/agent/prompt"
 11	"github.com/charmbracelet/crush/internal/agent/tools"
 12	"github.com/charmbracelet/crush/internal/config"
 13	"github.com/charmbracelet/crush/internal/csync"
 14	"github.com/charmbracelet/crush/internal/history"
 15	"github.com/charmbracelet/crush/internal/log"
 16	"github.com/charmbracelet/crush/internal/lsp"
 17	"github.com/charmbracelet/crush/internal/message"
 18	"github.com/charmbracelet/crush/internal/permission"
 19	"github.com/charmbracelet/crush/internal/session"
 20	"github.com/charmbracelet/fantasy/ai"
 21	"github.com/charmbracelet/fantasy/anthropic"
 22	"github.com/charmbracelet/fantasy/google"
 23	"github.com/charmbracelet/fantasy/openai"
 24	"github.com/charmbracelet/fantasy/openaicompat"
 25	"github.com/charmbracelet/fantasy/openrouter"
 26)
 27
// Coordinator is the interface callers use to drive the agent layer: running
// prompts, cancelling work, inspecting busy/queue state, and refreshing the
// models in use. The implementation in this file forwards most calls to a
// single "current" SessionAgent.
type Coordinator interface {
	// INFO: (kujtim) this is not used yet we will use this when we have multiple agents
	// SetMainAgent(string)

	// Run sends prompt (plus optional attachments) to the agent for the given
	// session and returns the agent's result.
	Run(ctx context.Context, sessionID, prompt string, attachments ...message.Attachment) (*ai.AgentResult, error)
	// Cancel requests cancellation of the work associated with sessionID.
	Cancel(sessionID string)
	// CancelAll requests cancellation of all work; the exact scope is defined
	// by the underlying agent.
	CancelAll()
	// IsSessionBusy reports whether the agent considers the given session busy.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether the agent considers itself busy.
	IsBusy() bool
	// QueuedPrompts returns the agent's count of queued prompts for the session.
	QueuedPrompts(sessionID string) int
	// ClearQueue asks the agent to clear the queue of the given session.
	ClearQueue(sessionID string)
	// Summarize asks the agent to summarize the session identified by the
	// string argument (a session ID in this file's implementation).
	Summarize(context.Context, string) error
	// Model returns the model the current agent reports as its own.
	Model() Model
	// UpdateModels rebuilds the models and tools from the current
	// configuration and installs them on the agent.
	UpdateModels() error
}
 42
// coordinator is the Coordinator implementation in this package. It keeps the
// services handed to NewCoordinator (used when building agents and tools) and
// the agents built from the configuration.
type coordinator struct {
	cfg         *config.Config                  // application configuration (models, providers, agents, options)
	sessions    session.Service                 // passed to NewSessionAgent
	messages    message.Service                 // passed to NewSessionAgent
	permissions permission.Service              // passed to the tools that take a permission service
	history     history.Service                 // passed to the edit/multi-edit/write tools
	lspClients  *csync.Map[string, *lsp.Client] // passed to the edit/view/write tools; key meaning not visible here

	currentAgent SessionAgent            // agent every Coordinator method delegates to
	agents       map[string]SessionAgent // built agents by name; only config.AgentCoder is populated in this file
}
 54
 55func NewCoordinator(
 56	cfg *config.Config,
 57	sessions session.Service,
 58	messages message.Service,
 59	permissions permission.Service,
 60	history history.Service,
 61	lspClients *csync.Map[string, *lsp.Client],
 62) (Coordinator, error) {
 63	c := &coordinator{
 64		cfg:         cfg,
 65		sessions:    sessions,
 66		messages:    messages,
 67		permissions: permissions,
 68		history:     history,
 69		lspClients:  lspClients,
 70		agents:      make(map[string]SessionAgent),
 71	}
 72
 73	agentCfg, ok := cfg.Agents[config.AgentCoder]
 74	if !ok {
 75		return nil, errors.New("coder agent not configured")
 76	}
 77
 78	// TODO: make this dynamic when we support multiple agents
 79	prompt, err := coderPrompt(prompt.WithWorkingDir(c.cfg.WorkingDir()))
 80	if err != nil {
 81		return nil, err
 82	}
 83
 84	agent, err := c.buildAgent(prompt, agentCfg)
 85	if err != nil {
 86		return nil, err
 87	}
 88	c.currentAgent = agent
 89	c.agents[config.AgentCoder] = agent
 90	return c, nil
 91}
 92
 93// Run implements Coordinator.
 94func (c *coordinator) Run(ctx context.Context, sessionID string, prompt string, attachments ...message.Attachment) (*ai.AgentResult, error) {
 95	model := c.currentAgent.Model()
 96	maxTokens := model.CatwalkCfg.DefaultMaxTokens
 97	if model.ModelCfg.MaxTokens != 0 {
 98		maxTokens = model.ModelCfg.MaxTokens
 99	}
100
101	if !model.CatwalkCfg.SupportsImages && attachments != nil {
102		attachments = nil
103	}
104
105	return c.currentAgent.Run(ctx, SessionAgentCall{
106		SessionID:        sessionID,
107		Prompt:           prompt,
108		Attachments:      attachments,
109		MaxOutputTokens:  maxTokens,
110		ProviderOptions:  c.getProviderOptions(model),
111		Temperature:      model.ModelCfg.Temperature,
112		TopP:             model.ModelCfg.TopP,
113		TopK:             model.ModelCfg.TopK,
114		FrequencyPenalty: model.ModelCfg.FrequencyPenalty,
115		PresencePenalty:  model.ModelCfg.PresencePenalty,
116	})
117}
118
119func (c *coordinator) getProviderOptions(model Model) ai.ProviderOptions {
120	options := ai.ProviderOptions{}
121
122	switch model.Model.Provider() {
123	case openai.Name:
124		parsed, err := openai.ParseOptions(model.ModelCfg.ProviderOptions)
125		if err == nil {
126			options[openai.Name] = parsed
127		}
128	case anthropic.Name:
129		parsed, err := anthropic.ParseOptions(model.ModelCfg.ProviderOptions)
130		if err == nil {
131			options[anthropic.Name] = parsed
132		}
133	case openrouter.Name:
134		parsed, err := openrouter.ParseOptions(model.ModelCfg.ProviderOptions)
135		if err == nil {
136			options[openrouter.Name] = parsed
137		}
138	case google.Name:
139		parsed, err := google.ParseOptions(model.ModelCfg.ProviderOptions)
140		if err == nil {
141			options[google.Name] = parsed
142		}
143	case openaicompat.Name:
144		parsed, err := openaicompat.ParseOptions(model.ModelCfg.ProviderOptions)
145		if err == nil {
146			options[openaicompat.Name] = parsed
147		}
148	}
149
150	return options
151}
152
153func (c *coordinator) buildAgent(prompt *prompt.Prompt, agent config.Agent) (SessionAgent, error) {
154	large, small, err := c.buildAgentModels()
155	if err != nil {
156		return nil, err
157	}
158
159	systemPrompt, err := prompt.Build(large.Model.Provider(), large.Model.Model(), *c.cfg)
160	if err != nil {
161		return nil, err
162	}
163
164	tools, err := c.buildTools(agent)
165	if err != nil {
166		return nil, err
167	}
168	return NewSessionAgent(large, small, systemPrompt, c.sessions, c.messages, tools...), nil
169}
170
171func (c *coordinator) buildTools(agent config.Agent) ([]ai.AgentTool, error) {
172	var allTools []ai.AgentTool
173	if slices.Contains(agent.AllowedTools, AgentToolName) {
174		agentTool, err := c.agentTool()
175		if err != nil {
176			return nil, err
177		}
178		allTools = append(allTools, agentTool)
179	}
180
181	allTools = append(allTools,
182		tools.NewBashTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Options.Attribution),
183		tools.NewDownloadTool(c.permissions, c.cfg.WorkingDir(), nil),
184		tools.NewEditTool(c.lspClients, c.permissions, c.history, c.cfg.WorkingDir()),
185		tools.NewMultiEditTool(c.lspClients, c.permissions, c.history, c.cfg.WorkingDir()),
186		tools.NewFetchTool(c.permissions, c.cfg.WorkingDir(), nil),
187		tools.NewGlobTool(c.cfg.WorkingDir()),
188		tools.NewGrepTool(c.cfg.WorkingDir()),
189		tools.NewLsTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Tools.Ls),
190		tools.NewSourcegraphTool(nil),
191		tools.NewViewTool(c.lspClients, c.permissions, c.cfg.WorkingDir()),
192		tools.NewWriteTool(c.lspClients, c.permissions, c.history, c.cfg.WorkingDir()),
193	)
194
195	var filteredTools []ai.AgentTool
196	for _, tool := range allTools {
197		if slices.Contains(agent.AllowedTools, tool.Info().Name) {
198			filteredTools = append(filteredTools, tool)
199		}
200	}
201
202	mcpTools := tools.GetMCPTools(context.Background(), c.permissions, c.cfg)
203
204	for _, mcpTool := range mcpTools {
205		if agent.AllowedMCP == nil {
206			// No MCP restrictions
207			filteredTools = append(filteredTools, mcpTool)
208		} else if len(agent.AllowedMCP) == 0 {
209			// no mcps allowed
210			break
211		}
212
213		for mcp, tools := range agent.AllowedMCP {
214			if mcp == mcpTool.MCP() {
215				if len(tools) == 0 {
216					filteredTools = append(filteredTools, mcpTool)
217				}
218				for _, t := range tools {
219					if t == mcpTool.MCPToolName() {
220						filteredTools = append(filteredTools, mcpTool)
221					}
222				}
223				break
224			}
225		}
226	}
227
228	return filteredTools, nil
229}
230
231// TODO: when we support multiple agents we need to change this so that we pass in the agent specific model config
232func (c *coordinator) buildAgentModels() (Model, Model, error) {
233	largeModelCfg, ok := c.cfg.Models[config.SelectedModelTypeLarge]
234	if !ok {
235		return Model{}, Model{}, errors.New("large model not selected")
236	}
237	smallModelCfg, ok := c.cfg.Models[config.SelectedModelTypeSmall]
238	if !ok {
239		return Model{}, Model{}, errors.New("small model not selected")
240	}
241
242	largeProviderCfg, ok := c.cfg.Providers.Get(largeModelCfg.Provider)
243	if !ok {
244		return Model{}, Model{}, errors.New("large model provider not configured")
245	}
246
247	largeProvider, err := c.buildProvider(largeProviderCfg, largeModelCfg)
248	if err != nil {
249		return Model{}, Model{}, err
250	}
251
252	smallProviderCfg, ok := c.cfg.Providers.Get(smallModelCfg.Provider)
253	if !ok {
254		return Model{}, Model{}, errors.New("large model provider not configured")
255	}
256
257	smallProvider, err := c.buildProvider(smallProviderCfg, largeModelCfg)
258	if err != nil {
259		return Model{}, Model{}, err
260	}
261
262	var largeCatwalkModel *catwalk.Model
263	var smallCatwalkModel *catwalk.Model
264
265	for _, m := range largeProviderCfg.Models {
266		if m.ID == largeModelCfg.Model {
267			largeCatwalkModel = &m
268		}
269	}
270	for _, m := range smallProviderCfg.Models {
271		if m.ID == smallModelCfg.Model {
272			smallCatwalkModel = &m
273		}
274	}
275
276	if largeCatwalkModel == nil {
277		return Model{}, Model{}, errors.New("large model not found in provider config")
278	}
279
280	if smallCatwalkModel == nil {
281		return Model{}, Model{}, errors.New("snall model not found in provider config")
282	}
283
284	largeModel, err := largeProvider.LanguageModel(largeModelCfg.Model)
285	if err != nil {
286		return Model{}, Model{}, err
287	}
288	smallModel, err := smallProvider.LanguageModel(smallModelCfg.Model)
289	if err != nil {
290		return Model{}, Model{}, err
291	}
292
293	return Model{
294			Model:      largeModel,
295			CatwalkCfg: *largeCatwalkModel,
296			ModelCfg:   largeModelCfg,
297		}, Model{
298			Model:      smallModel,
299			CatwalkCfg: *smallCatwalkModel,
300			ModelCfg:   smallModelCfg,
301		}, nil
302}
303
304func (c *coordinator) buildAnthropicProvider(baseURL, apiKey string, headers map[string]string) ai.Provider {
305	hasBearerAuth := false
306	for key := range headers {
307		if strings.ToLower(key) == "authorization" {
308			hasBearerAuth = true
309			break
310		}
311	}
312	if hasBearerAuth {
313		apiKey = "" // clear apiKey to avoid using X-Api-Key header
314	}
315
316	var opts []anthropic.Option
317
318	if apiKey != "" {
319		// Use standard X-Api-Key header
320		opts = append(opts, anthropic.WithAPIKey(apiKey))
321	}
322
323	if len(headers) > 0 {
324		opts = append(opts, anthropic.WithHeaders(headers))
325	}
326
327	if baseURL != "" {
328		opts = append(opts, anthropic.WithBaseURL(baseURL))
329	}
330
331	if c.cfg.Options.Debug {
332		httpClient := log.NewHTTPClient()
333		opts = append(opts, anthropic.WithHTTPClient(httpClient))
334	}
335
336	return anthropic.New(opts...)
337}
338
339func (c *coordinator) buildOpenaiProvider(baseURL, apiKey string, headers map[string]string) ai.Provider {
340	opts := []openai.Option{
341		openai.WithAPIKey(apiKey),
342	}
343	if c.cfg.Options.Debug {
344		httpClient := log.NewHTTPClient()
345		opts = append(opts, openai.WithHTTPClient(httpClient))
346	}
347	if len(headers) > 0 {
348		opts = append(opts, openai.WithHeaders(headers))
349	}
350	if baseURL != "" {
351		opts = append(opts, openai.WithBaseURL(baseURL))
352	}
353	return openai.New(opts...)
354}
355
356func (c *coordinator) buildOpenrouterProvider(_, apiKey string, headers map[string]string) ai.Provider {
357	opts := []openrouter.Option{
358		openrouter.WithAPIKey(apiKey),
359		openrouter.WithLanguageUniqueToolCallIds(),
360	}
361	if c.cfg.Options.Debug {
362		httpClient := log.NewHTTPClient()
363		opts = append(opts, openrouter.WithHTTPClient(httpClient))
364	}
365	if len(headers) > 0 {
366		opts = append(opts, openrouter.WithHeaders(headers))
367	}
368	return openrouter.New(opts...)
369}
370
371func (c *coordinator) buildOpenaiCompatProvider(baseURL, apiKey string, headers map[string]string) ai.Provider {
372	opts := []openaicompat.Option{
373		openaicompat.WithAPIKey(apiKey),
374	}
375	if c.cfg.Options.Debug {
376		httpClient := log.NewHTTPClient()
377		opts = append(opts, openaicompat.WithHTTPClient(httpClient))
378	}
379	if len(headers) > 0 {
380		opts = append(opts, openaicompat.WithHeaders(headers))
381	}
382
383	return openaicompat.New(baseURL, opts...)
384}
385
386// TODO: add baseURL for google
387func (c *coordinator) buildGoogleProvider(baseURL, apiKey string, headers map[string]string) ai.Provider {
388	opts := []google.Option{
389		google.WithAPIKey(apiKey),
390	}
391	if c.cfg.Options.Debug {
392		httpClient := log.NewHTTPClient()
393		opts = append(opts, google.WithHTTPClient(httpClient))
394	}
395	if len(headers) > 0 {
396		opts = append(opts, google.WithHeaders(headers))
397	}
398	return google.New(opts...)
399}
400
401func (c *coordinator) isAnthropicThinking(model config.SelectedModel) bool {
402	if model.Think {
403		return true
404	}
405
406	if model.ProviderOptions == nil {
407		return false
408	}
409
410	opts, err := anthropic.ParseOptions(model.ProviderOptions)
411	if err != nil {
412		return false
413	}
414	if opts.Thinking != nil {
415		return true
416	}
417	return false
418}
419
420func (c *coordinator) buildProvider(providerCfg config.ProviderConfig, model config.SelectedModel) (ai.Provider, error) {
421	headers := providerCfg.ExtraHeaders
422
423	// handle special headers for anthropic
424	if providerCfg.Type == anthropic.Name && c.isAnthropicThinking(model) {
425		headers["anthropic-beta"] = "interleaved-thinking-2025-05-14"
426	}
427
428	// TODO: make sure we have
429	apiKey, _ := c.cfg.Resolve(providerCfg.APIKey)
430	baseURL, _ := c.cfg.Resolve(providerCfg.BaseURL)
431	var provider ai.Provider
432	switch providerCfg.Type {
433	case openai.Name:
434		provider = c.buildOpenaiProvider(baseURL, apiKey, headers)
435	case anthropic.Name:
436		provider = c.buildAnthropicProvider(baseURL, apiKey, headers)
437	case openrouter.Name:
438		provider = c.buildOpenrouterProvider(baseURL, apiKey, headers)
439	case google.Name:
440		provider = c.buildGoogleProvider(baseURL, apiKey, headers)
441	case openaicompat.Name:
442		provider = c.buildOpenaiCompatProvider(baseURL, apiKey, headers)
443	default:
444		return nil, errors.New("provider type not supported")
445	}
446	return provider, nil
447}
448
// Cancel implements Coordinator by forwarding the cancellation request for
// sessionID to the current agent.
func (c *coordinator) Cancel(sessionID string) {
	c.currentAgent.Cancel(sessionID)
}
452
// CancelAll implements Coordinator by forwarding the request to the current
// agent.
func (c *coordinator) CancelAll() {
	c.currentAgent.CancelAll()
}
456
// ClearQueue implements Coordinator by asking the current agent to clear the
// queue of the given session.
func (c *coordinator) ClearQueue(sessionID string) {
	c.currentAgent.ClearQueue(sessionID)
}
460
// IsBusy implements Coordinator and returns the current agent's busy state.
func (c *coordinator) IsBusy() bool {
	return c.currentAgent.IsBusy()
}
464
// IsSessionBusy implements Coordinator and returns the current agent's busy
// state for the given session.
func (c *coordinator) IsSessionBusy(sessionID string) bool {
	return c.currentAgent.IsSessionBusy(sessionID)
}
468
// Model implements Coordinator and returns the model reported by the current
// agent.
func (c *coordinator) Model() Model {
	return c.currentAgent.Model()
}
472
473func (c *coordinator) UpdateModels() error {
474	// build the models again so we make sure we get the latest config
475	large, small, err := c.buildAgentModels()
476	if err != nil {
477		return err
478	}
479	c.currentAgent.SetModels(large, small)
480
481	agentCfg, ok := c.cfg.Agents[config.AgentCoder]
482	if !ok {
483		return errors.New("coder agent not configured")
484	}
485
486	tools, err := c.buildTools(agentCfg)
487	if err != nil {
488		return err
489	}
490	c.currentAgent.SetTools(tools)
491	return nil
492}
493
// QueuedPrompts implements Coordinator and returns the current agent's count
// of queued prompts for the given session.
func (c *coordinator) QueuedPrompts(sessionID string) int {
	return c.currentAgent.QueuedPrompts(sessionID)
}
497
// Summarize implements Coordinator by asking the current agent to summarize
// the given session, returning whatever error the agent reports.
func (c *coordinator) Summarize(ctx context.Context, sessionID string) error {
	return c.currentAgent.Summarize(ctx, sessionID)
}