coordinator.go

  1package agent
  2
  3import (
  4	"context"
  5	"errors"
  6	"slices"
  7	"strings"
  8
  9	"github.com/charmbracelet/catwalk/pkg/catwalk"
 10	"github.com/charmbracelet/crush/internal/agent/prompt"
 11	"github.com/charmbracelet/crush/internal/agent/tools"
 12	"github.com/charmbracelet/crush/internal/config"
 13	"github.com/charmbracelet/crush/internal/csync"
 14	"github.com/charmbracelet/crush/internal/history"
 15	"github.com/charmbracelet/crush/internal/log"
 16	"github.com/charmbracelet/crush/internal/lsp"
 17	"github.com/charmbracelet/crush/internal/message"
 18	"github.com/charmbracelet/crush/internal/permission"
 19	"github.com/charmbracelet/crush/internal/session"
 20	"github.com/charmbracelet/fantasy/ai"
 21	"github.com/charmbracelet/fantasy/anthropic"
 22	"github.com/charmbracelet/fantasy/azure"
 23	"github.com/charmbracelet/fantasy/google"
 24	"github.com/charmbracelet/fantasy/openai"
 25	"github.com/charmbracelet/fantasy/openaicompat"
 26	"github.com/charmbracelet/fantasy/openrouter"
 27)
 28
 29type Coordinator interface {
 30	// INFO: (kujtim) this is not used yet we will use this when we have multiple agents
 31	// SetMainAgent(string)
 32	Run(ctx context.Context, sessionID, prompt string, attachments ...message.Attachment) (*ai.AgentResult, error)
 33	Cancel(sessionID string)
 34	CancelAll()
 35	IsSessionBusy(sessionID string) bool
 36	IsBusy() bool
 37	QueuedPrompts(sessionID string) int
 38	ClearQueue(sessionID string)
 39	Summarize(context.Context, string) error
 40	Model() Model
 41	UpdateModels() error
 42}
 43
 44type coordinator struct {
 45	cfg         *config.Config
 46	sessions    session.Service
 47	messages    message.Service
 48	permissions permission.Service
 49	history     history.Service
 50	lspClients  *csync.Map[string, *lsp.Client]
 51
 52	currentAgent SessionAgent
 53	agents       map[string]SessionAgent
 54}
 55
 56func NewCoordinator(
 57	cfg *config.Config,
 58	sessions session.Service,
 59	messages message.Service,
 60	permissions permission.Service,
 61	history history.Service,
 62	lspClients *csync.Map[string, *lsp.Client],
 63) (Coordinator, error) {
 64	c := &coordinator{
 65		cfg:         cfg,
 66		sessions:    sessions,
 67		messages:    messages,
 68		permissions: permissions,
 69		history:     history,
 70		lspClients:  lspClients,
 71		agents:      make(map[string]SessionAgent),
 72	}
 73
 74	agentCfg, ok := cfg.Agents[config.AgentCoder]
 75	if !ok {
 76		return nil, errors.New("coder agent not configured")
 77	}
 78
 79	// TODO: make this dynamic when we support multiple agents
 80	prompt, err := coderPrompt(prompt.WithWorkingDir(c.cfg.WorkingDir()))
 81	if err != nil {
 82		return nil, err
 83	}
 84
 85	agent, err := c.buildAgent(prompt, agentCfg)
 86	if err != nil {
 87		return nil, err
 88	}
 89	c.currentAgent = agent
 90	c.agents[config.AgentCoder] = agent
 91	return c, nil
 92}
 93
 94// Run implements Coordinator.
 95func (c *coordinator) Run(ctx context.Context, sessionID string, prompt string, attachments ...message.Attachment) (*ai.AgentResult, error) {
 96	model := c.currentAgent.Model()
 97	maxTokens := model.CatwalkCfg.DefaultMaxTokens
 98	if model.ModelCfg.MaxTokens != 0 {
 99		maxTokens = model.ModelCfg.MaxTokens
100	}
101
102	if !model.CatwalkCfg.SupportsImages && attachments != nil {
103		attachments = nil
104	}
105
106	return c.currentAgent.Run(ctx, SessionAgentCall{
107		SessionID:        sessionID,
108		Prompt:           prompt,
109		Attachments:      attachments,
110		MaxOutputTokens:  maxTokens,
111		ProviderOptions:  c.getProviderOptions(model),
112		Temperature:      model.ModelCfg.Temperature,
113		TopP:             model.ModelCfg.TopP,
114		TopK:             model.ModelCfg.TopK,
115		FrequencyPenalty: model.ModelCfg.FrequencyPenalty,
116		PresencePenalty:  model.ModelCfg.PresencePenalty,
117	})
118}
119
120func (c *coordinator) getProviderOptions(model Model) ai.ProviderOptions {
121	options := ai.ProviderOptions{}
122
123	switch model.Model.Provider() {
124	case openai.Name:
125		parsed, err := openai.ParseOptions(model.ModelCfg.ProviderOptions)
126		if err == nil {
127			options[openai.Name] = parsed
128		}
129	case anthropic.Name:
130		parsed, err := anthropic.ParseOptions(model.ModelCfg.ProviderOptions)
131		if err == nil {
132			options[anthropic.Name] = parsed
133		}
134	case openrouter.Name:
135		parsed, err := openrouter.ParseOptions(model.ModelCfg.ProviderOptions)
136		if err == nil {
137			options[openrouter.Name] = parsed
138		}
139	case google.Name:
140		parsed, err := google.ParseOptions(model.ModelCfg.ProviderOptions)
141		if err == nil {
142			options[google.Name] = parsed
143		}
144	case openaicompat.Name:
145		parsed, err := openaicompat.ParseOptions(model.ModelCfg.ProviderOptions)
146		if err == nil {
147			options[openaicompat.Name] = parsed
148		}
149	}
150
151	return options
152}
153
154func (c *coordinator) buildAgent(prompt *prompt.Prompt, agent config.Agent) (SessionAgent, error) {
155	large, small, err := c.buildAgentModels()
156	if err != nil {
157		return nil, err
158	}
159
160	systemPrompt, err := prompt.Build(large.Model.Provider(), large.Model.Model(), *c.cfg)
161	if err != nil {
162		return nil, err
163	}
164
165	tools, err := c.buildTools(agent)
166	if err != nil {
167		return nil, err
168	}
169	return NewSessionAgent(SessionAgentOptions{large, small, systemPrompt, c.cfg.Options.DisableAutoSummarize, c.sessions, c.messages, tools}), nil
170}
171
// buildTools returns the tools the given agent may use.
//
// Built-in tools (and the agent tool, when AgentToolName is allowed) are kept
// only if their name appears in agent.AllowedTools. MCP tools are filtered by
// agent.AllowedMCP:
//   - nil map: every MCP tool is allowed;
//   - empty, non-nil map: no MCP tool is allowed;
//   - otherwise: a tool is allowed when its MCP server is a key and either the
//     server's list is empty (all of its tools) or lists the tool's name.
//
// Each tool is added at most once, even if an allow-list repeats a name.
func (c *coordinator) buildTools(agent config.Agent) ([]ai.AgentTool, error) {
	var allTools []ai.AgentTool
	if slices.Contains(agent.AllowedTools, AgentToolName) {
		agentTool, err := c.agentTool()
		if err != nil {
			return nil, err
		}
		allTools = append(allTools, agentTool)
	}

	allTools = append(allTools,
		tools.NewBashTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Options.Attribution),
		tools.NewDownloadTool(c.permissions, c.cfg.WorkingDir(), nil),
		tools.NewEditTool(c.lspClients, c.permissions, c.history, c.cfg.WorkingDir()),
		tools.NewMultiEditTool(c.lspClients, c.permissions, c.history, c.cfg.WorkingDir()),
		tools.NewFetchTool(c.permissions, c.cfg.WorkingDir(), nil),
		tools.NewGlobTool(c.cfg.WorkingDir()),
		tools.NewGrepTool(c.cfg.WorkingDir()),
		tools.NewLsTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Tools.Ls),
		tools.NewSourcegraphTool(nil),
		tools.NewViewTool(c.lspClients, c.permissions, c.cfg.WorkingDir()),
		tools.NewWriteTool(c.lspClients, c.permissions, c.history, c.cfg.WorkingDir()),
	)

	var filteredTools []ai.AgentTool
	for _, tool := range allTools {
		if slices.Contains(agent.AllowedTools, tool.Info().Name) {
			filteredTools = append(filteredTools, tool)
		}
	}

	// MCP tools are fetched unconditionally; the filter below decides which
	// ones the agent actually gets.
	mcpTools := tools.GetMCPTools(context.Background(), c.permissions, c.cfg)

	switch {
	case agent.AllowedMCP == nil:
		// No MCP restrictions.
		for _, mcpTool := range mcpTools {
			filteredTools = append(filteredTools, mcpTool)
		}
	case len(agent.AllowedMCP) == 0:
		// Explicitly empty: no MCPs allowed.
	default:
		for _, mcpTool := range mcpTools {
			allowed, ok := agent.AllowedMCP[mcpTool.MCP()]
			if !ok {
				continue
			}
			// An empty list means every tool from that MCP server is allowed.
			if len(allowed) == 0 || slices.Contains(allowed, mcpTool.MCPToolName()) {
				filteredTools = append(filteredTools, mcpTool)
			}
		}
	}

	return filteredTools, nil
}
231
232// TODO: when we support multiple agents we need to change this so that we pass in the agent specific model config
233func (c *coordinator) buildAgentModels() (Model, Model, error) {
234	largeModelCfg, ok := c.cfg.Models[config.SelectedModelTypeLarge]
235	if !ok {
236		return Model{}, Model{}, errors.New("large model not selected")
237	}
238	smallModelCfg, ok := c.cfg.Models[config.SelectedModelTypeSmall]
239	if !ok {
240		return Model{}, Model{}, errors.New("small model not selected")
241	}
242
243	largeProviderCfg, ok := c.cfg.Providers.Get(largeModelCfg.Provider)
244	if !ok {
245		return Model{}, Model{}, errors.New("large model provider not configured")
246	}
247
248	largeProvider, err := c.buildProvider(largeProviderCfg, largeModelCfg)
249	if err != nil {
250		return Model{}, Model{}, err
251	}
252
253	smallProviderCfg, ok := c.cfg.Providers.Get(smallModelCfg.Provider)
254	if !ok {
255		return Model{}, Model{}, errors.New("large model provider not configured")
256	}
257
258	smallProvider, err := c.buildProvider(smallProviderCfg, largeModelCfg)
259	if err != nil {
260		return Model{}, Model{}, err
261	}
262
263	var largeCatwalkModel *catwalk.Model
264	var smallCatwalkModel *catwalk.Model
265
266	for _, m := range largeProviderCfg.Models {
267		if m.ID == largeModelCfg.Model {
268			largeCatwalkModel = &m
269		}
270	}
271	for _, m := range smallProviderCfg.Models {
272		if m.ID == smallModelCfg.Model {
273			smallCatwalkModel = &m
274		}
275	}
276
277	if largeCatwalkModel == nil {
278		return Model{}, Model{}, errors.New("large model not found in provider config")
279	}
280
281	if smallCatwalkModel == nil {
282		return Model{}, Model{}, errors.New("snall model not found in provider config")
283	}
284
285	largeModel, err := largeProvider.LanguageModel(largeModelCfg.Model)
286	if err != nil {
287		return Model{}, Model{}, err
288	}
289	smallModel, err := smallProvider.LanguageModel(smallModelCfg.Model)
290	if err != nil {
291		return Model{}, Model{}, err
292	}
293
294	return Model{
295			Model:      largeModel,
296			CatwalkCfg: *largeCatwalkModel,
297			ModelCfg:   largeModelCfg,
298		}, Model{
299			Model:      smallModel,
300			CatwalkCfg: *smallCatwalkModel,
301			ModelCfg:   smallModelCfg,
302		}, nil
303}
304
// buildAnthropicProvider creates an Anthropic provider.
//
// When headers already contain an Authorization header (compared after
// lower-casing the key), apiKey is dropped so the key is not also sent via the
// X-Api-Key header. Empty apiKey/baseURL values and an empty header map add no
// option. In debug mode the HTTP client from log.NewHTTPClient is used.
func (c *coordinator) buildAnthropicProvider(baseURL, apiKey string, headers map[string]string) ai.Provider {
	for name := range headers {
		if strings.ToLower(name) == "authorization" {
			// Bearer auth supplied by the caller wins over X-Api-Key.
			apiKey = ""
			break
		}
	}

	var opts []anthropic.Option
	if apiKey != "" {
		opts = append(opts, anthropic.WithAPIKey(apiKey))
	}
	if len(headers) != 0 {
		opts = append(opts, anthropic.WithHeaders(headers))
	}
	if baseURL != "" {
		opts = append(opts, anthropic.WithBaseURL(baseURL))
	}
	if c.cfg.Options.Debug {
		opts = append(opts, anthropic.WithHTTPClient(log.NewHTTPClient()))
	}

	return anthropic.New(opts...)
}
339
340func (c *coordinator) buildOpenaiProvider(baseURL, apiKey string, headers map[string]string) ai.Provider {
341	opts := []openai.Option{
342		openai.WithAPIKey(apiKey),
343	}
344	if c.cfg.Options.Debug {
345		httpClient := log.NewHTTPClient()
346		opts = append(opts, openai.WithHTTPClient(httpClient))
347	}
348	if len(headers) > 0 {
349		opts = append(opts, openai.WithHeaders(headers))
350	}
351	if baseURL != "" {
352		opts = append(opts, openai.WithBaseURL(baseURL))
353	}
354	return openai.New(opts...)
355}
356
357func (c *coordinator) buildOpenrouterProvider(_, apiKey string, headers map[string]string) ai.Provider {
358	opts := []openrouter.Option{
359		openrouter.WithAPIKey(apiKey),
360	}
361	if c.cfg.Options.Debug {
362		httpClient := log.NewHTTPClient()
363		opts = append(opts, openrouter.WithHTTPClient(httpClient))
364	}
365	if len(headers) > 0 {
366		opts = append(opts, openrouter.WithHeaders(headers))
367	}
368	return openrouter.New(opts...)
369}
370
371func (c *coordinator) buildOpenaiCompatProvider(baseURL, apiKey string, headers map[string]string) ai.Provider {
372	opts := []openaicompat.Option{
373		openaicompat.WithBaseURL(baseURL),
374		openaicompat.WithAPIKey(apiKey),
375	}
376	if c.cfg.Options.Debug {
377		httpClient := log.NewHTTPClient()
378		opts = append(opts, openaicompat.WithHTTPClient(httpClient))
379	}
380	if len(headers) > 0 {
381		opts = append(opts, openaicompat.WithHeaders(headers))
382	}
383
384	return openaicompat.New(opts...)
385}
386
387func (c *coordinator) buildAzureProvider(baseURL, apiKey string, headers map[string]string, options map[string]string) ai.Provider {
388	opts := []azure.Option{
389		azure.WithBaseURL(baseURL),
390		azure.WithAPIKey(apiKey),
391	}
392	if c.cfg.Options.Debug {
393		httpClient := log.NewHTTPClient()
394		opts = append(opts, azure.WithHTTPClient(httpClient))
395	}
396	if options == nil {
397		options = make(map[string]string)
398	}
399	if apiVersion, ok := options["apiVersion"]; ok {
400		opts = append(opts, azure.WithAPIVersion(apiVersion))
401	}
402	if len(headers) > 0 {
403		opts = append(opts, azure.WithHeaders(headers))
404	}
405
406	return azure.New(opts...)
407}
408
409// TODO: add baseURL for google
410func (c *coordinator) buildGoogleProvider(baseURL, apiKey string, headers map[string]string) ai.Provider {
411	opts := []google.Option{
412		google.WithAPIKey(apiKey),
413	}
414	if c.cfg.Options.Debug {
415		httpClient := log.NewHTTPClient()
416		opts = append(opts, google.WithHTTPClient(httpClient))
417	}
418	if len(headers) > 0 {
419		opts = append(opts, google.WithHeaders(headers))
420	}
421	return google.New(opts...)
422}
423
424func (c *coordinator) isAnthropicThinking(model config.SelectedModel) bool {
425	if model.Think {
426		return true
427	}
428
429	if model.ProviderOptions == nil {
430		return false
431	}
432
433	opts, err := anthropic.ParseOptions(model.ProviderOptions)
434	if err != nil {
435		return false
436	}
437	if opts.Thinking != nil {
438		return true
439	}
440	return false
441}
442
// buildProvider constructs the ai.Provider described by providerCfg for the
// selected model.
//
// For Anthropic providers whose model asks for thinking, an "anthropic-beta"
// header is added. It goes on a private copy of ExtraHeaders, never on the
// config's own map. Writing into providerCfg.ExtraHeaders would mutate the
// shared configuration: the header would stick for every later build, even
// after switching to a non-thinking model. It would also panic when the map
// is nil.
//
// The API key and base URL are resolved through the config. Resolution errors
// are currently ignored.
//
// It returns an error when providerCfg.Type is not a supported provider type.
func (c *coordinator) buildProvider(providerCfg config.ProviderConfig, model config.SelectedModel) (ai.Provider, error) {
	headers := providerCfg.ExtraHeaders

	// handle special headers for anthropic
	if providerCfg.Type == anthropic.Name && c.isAnthropicThinking(model) {
		// Copy-on-write: leave the shared config map untouched.
		withBeta := make(map[string]string, len(headers)+1)
		for k, v := range headers {
			withBeta[k] = v
		}
		withBeta["anthropic-beta"] = "interleaved-thinking-2025-05-14"
		headers = withBeta
	}

	// TODO: surface Resolve errors instead of silently using whatever value
	// Resolve returned alongside them.
	apiKey, _ := c.cfg.Resolve(providerCfg.APIKey)
	baseURL, _ := c.cfg.Resolve(providerCfg.BaseURL)

	var provider ai.Provider
	switch providerCfg.Type {
	case openai.Name:
		provider = c.buildOpenaiProvider(baseURL, apiKey, headers)
	case anthropic.Name:
		provider = c.buildAnthropicProvider(baseURL, apiKey, headers)
	case openrouter.Name:
		provider = c.buildOpenrouterProvider(baseURL, apiKey, headers)
	case azure.Name:
		provider = c.buildAzureProvider(baseURL, apiKey, headers, providerCfg.ExtraParams)
	case google.Name:
		provider = c.buildGoogleProvider(baseURL, apiKey, headers)
	case openaicompat.Name:
		provider = c.buildOpenaiCompatProvider(baseURL, apiKey, headers)
	default:
		return nil, errors.New("provider type not supported")
	}
	return provider, nil
}
473
474func (c *coordinator) Cancel(sessionID string) {
475	c.currentAgent.Cancel(sessionID)
476}
477
478func (c *coordinator) CancelAll() {
479	c.currentAgent.CancelAll()
480}
481
482func (c *coordinator) ClearQueue(sessionID string) {
483	c.currentAgent.ClearQueue(sessionID)
484}
485
486func (c *coordinator) IsBusy() bool {
487	return c.currentAgent.IsBusy()
488}
489
490func (c *coordinator) IsSessionBusy(sessionID string) bool {
491	return c.currentAgent.IsSessionBusy(sessionID)
492}
493
494func (c *coordinator) Model() Model {
495	return c.currentAgent.Model()
496}
497
// UpdateModels rebuilds the coder agent's models and tools from the current
// configuration and installs them on the current agent.
//
// Everything is built and validated before anything is applied. A failure
// (models not resolvable, coder agent not configured, tools not buildable)
// returns an error and leaves the agent with its previous models and tools,
// rather than new models paired with stale tools.
//
// NOTE(review): the system prompt, which buildAgent renders for the large
// model's provider and model ID, is not rebuilt here. Confirm whether a model
// switch should also refresh it.
func (c *coordinator) UpdateModels() error {
	// build the models again so we make sure we get the latest config
	large, small, err := c.buildAgentModels()
	if err != nil {
		return err
	}

	agentCfg, ok := c.cfg.Agents[config.AgentCoder]
	if !ok {
		return errors.New("coder agent not configured")
	}

	agentTools, err := c.buildTools(agentCfg)
	if err != nil {
		return err
	}

	// Apply only once every piece has been built successfully.
	c.currentAgent.SetModels(large, small)
	c.currentAgent.SetTools(agentTools)
	return nil
}
518
519func (c *coordinator) QueuedPrompts(sessionID string) int {
520	return c.currentAgent.QueuedPrompts(sessionID)
521}
522
// Summarize forwards a summarization request for sessionID to the current
// agent and returns its error unchanged.
func (c *coordinator) Summarize(ctx context.Context, sessionID string) error {
	active := c.currentAgent
	return active.Summarize(ctx, sessionID)
}