coordinator.go

  1package agent
  2
  3import (
  4	"cmp"
  5	"context"
  6	"encoding/json"
  7	"errors"
  8	"log/slog"
  9	"slices"
 10	"strings"
 11
 12	"github.com/charmbracelet/catwalk/pkg/catwalk"
 13	"github.com/charmbracelet/crush/internal/agent/prompt"
 14	"github.com/charmbracelet/crush/internal/agent/tools"
 15	"github.com/charmbracelet/crush/internal/config"
 16	"github.com/charmbracelet/crush/internal/csync"
 17	"github.com/charmbracelet/crush/internal/history"
 18	"github.com/charmbracelet/crush/internal/log"
 19	"github.com/charmbracelet/crush/internal/lsp"
 20	"github.com/charmbracelet/crush/internal/message"
 21	"github.com/charmbracelet/crush/internal/permission"
 22	"github.com/charmbracelet/crush/internal/session"
 23	"github.com/charmbracelet/fantasy/ai"
 24	"github.com/charmbracelet/fantasy/anthropic"
 25	"github.com/charmbracelet/fantasy/azure"
 26	"github.com/charmbracelet/fantasy/google"
 27	"github.com/charmbracelet/fantasy/openai"
 28	"github.com/charmbracelet/fantasy/openaicompat"
 29	"github.com/charmbracelet/fantasy/openrouter"
 30	"github.com/qjebbs/go-jsons"
 31)
 32
// Coordinator is the interface callers use to drive the agent: running
// prompts, cancelling work, inspecting busy/queue state, summarizing sessions
// and refreshing the configured models.
type Coordinator interface {
	// INFO: (kujtim) this is not used yet we will use this when we have multiple agents
	// SetMainAgent(string)
	// Run sends prompt, plus any attachments, to the agent for the given
	// session and returns the agent's result.
	Run(ctx context.Context, sessionID, prompt string, attachments ...message.Attachment) (*ai.AgentResult, error)
	// Cancel cancels the work associated with the given session.
	Cancel(sessionID string)
	// CancelAll cancels the work of every session.
	CancelAll()
	// IsSessionBusy reports whether the given session is busy.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether the agent is busy.
	IsBusy() bool
	// QueuedPrompts returns the number of prompts queued for the given session.
	QueuedPrompts(sessionID string) int
	// ClearQueue clears the queue of the given session.
	ClearQueue(sessionID string)
	// Summarize summarizes the session identified by the string argument.
	Summarize(context.Context, string) error
	// Model returns the model of the agent in use.
	Model() Model
	// UpdateModels rebuilds the models (and tools) from the current config.
	UpdateModels() error
}
 47
// coordinator is the Coordinator implementation returned by NewCoordinator.
// It holds the services handed to NewCoordinator and the agents built from
// the config.
type coordinator struct {
	cfg         *config.Config
	sessions    session.Service
	messages    message.Service
	permissions permission.Service
	history     history.Service
	lspClients  *csync.Map[string, *lsp.Client]

	// currentAgent is the agent every Coordinator method delegates to.
	currentAgent SessionAgent
	// agents holds the built agents keyed by agent name; NewCoordinator
	// registers the coder agent under config.AgentCoder.
	agents map[string]SessionAgent
}
 59
 60func NewCoordinator(
 61	cfg *config.Config,
 62	sessions session.Service,
 63	messages message.Service,
 64	permissions permission.Service,
 65	history history.Service,
 66	lspClients *csync.Map[string, *lsp.Client],
 67) (Coordinator, error) {
 68	c := &coordinator{
 69		cfg:         cfg,
 70		sessions:    sessions,
 71		messages:    messages,
 72		permissions: permissions,
 73		history:     history,
 74		lspClients:  lspClients,
 75		agents:      make(map[string]SessionAgent),
 76	}
 77
 78	agentCfg, ok := cfg.Agents[config.AgentCoder]
 79	if !ok {
 80		return nil, errors.New("coder agent not configured")
 81	}
 82
 83	// TODO: make this dynamic when we support multiple agents
 84	prompt, err := coderPrompt(prompt.WithWorkingDir(c.cfg.WorkingDir()))
 85	if err != nil {
 86		return nil, err
 87	}
 88
 89	agent, err := c.buildAgent(prompt, agentCfg)
 90	if err != nil {
 91		return nil, err
 92	}
 93	c.currentAgent = agent
 94	c.agents[config.AgentCoder] = agent
 95	return c, nil
 96}
 97
 98// Run implements Coordinator.
 99func (c *coordinator) Run(ctx context.Context, sessionID string, prompt string, attachments ...message.Attachment) (*ai.AgentResult, error) {
100	model := c.currentAgent.Model()
101	maxTokens := model.CatwalkCfg.DefaultMaxTokens
102	if model.ModelCfg.MaxTokens != 0 {
103		maxTokens = model.ModelCfg.MaxTokens
104	}
105
106	if !model.CatwalkCfg.SupportsImages && attachments != nil {
107		attachments = nil
108	}
109
110	mergedOptions, temp, topP, topK, freqPenalty, presPenalty := mergeCallOptions(model)
111
112	return c.currentAgent.Run(ctx, SessionAgentCall{
113		SessionID:        sessionID,
114		Prompt:           prompt,
115		Attachments:      attachments,
116		MaxOutputTokens:  maxTokens,
117		ProviderOptions:  mergedOptions,
118		Temperature:      temp,
119		TopP:             topP,
120		TopK:             topK,
121		FrequencyPenalty: freqPenalty,
122		PresencePenalty:  presPenalty,
123	})
124}
125
126func getProviderOptions(model Model) ai.ProviderOptions {
127	options := ai.ProviderOptions{}
128
129	cfgOpts := "{}"
130	catwalkOpts := "{}"
131
132	if model.ModelCfg.ProviderOptions != nil {
133		data, err := json.Marshal(model.ModelCfg.ProviderOptions)
134		if err == nil {
135			cfgOpts = string(data)
136		}
137	}
138
139	if model.CatwalkCfg.Options.ProviderOptions != nil {
140		data, err := json.Marshal(model.CatwalkCfg.Options.ProviderOptions)
141		if err == nil {
142			catwalkOpts = string(data)
143		}
144	}
145
146	got, err := jsons.Merge([]string{catwalkOpts, cfgOpts})
147	if err != nil {
148		slog.Error("Could not merge call config", "err", err)
149		return options
150	}
151
152	mergedOptions := make(map[string]any)
153
154	err = json.Unmarshal([]byte(got), &mergedOptions)
155	if err != nil {
156		slog.Error("Could not create config for call", "err", err)
157		return options
158	}
159
160	switch model.Model.Provider() {
161	case openai.Name:
162		parsed, err := openai.ParseOptions(mergedOptions)
163		if err == nil {
164			options[openai.Name] = parsed
165		}
166	case anthropic.Name:
167		parsed, err := anthropic.ParseOptions(mergedOptions)
168		if err == nil {
169			options[anthropic.Name] = parsed
170		}
171	case openrouter.Name:
172		parsed, err := openrouter.ParseOptions(mergedOptions)
173		if err == nil {
174			options[openrouter.Name] = parsed
175		}
176	case google.Name:
177		parsed, err := google.ParseOptions(mergedOptions)
178		if err == nil {
179			options[google.Name] = parsed
180		}
181	case openaicompat.Name:
182		parsed, err := openaicompat.ParseOptions(mergedOptions)
183		if err == nil {
184			options[openaicompat.Name] = parsed
185		}
186	}
187
188	return options
189}
190
191func mergeCallOptions(model Model) (ai.ProviderOptions, *float64, *float64, *int64, *float64, *float64) {
192	modelOptions := getProviderOptions(model)
193	temp := cmp.Or(model.ModelCfg.Temperature, model.CatwalkCfg.Options.Temperature)
194	topP := cmp.Or(model.ModelCfg.TopP, model.CatwalkCfg.Options.TopP)
195	topK := cmp.Or(model.ModelCfg.TopK, model.CatwalkCfg.Options.TopK)
196	freqPenalty := cmp.Or(model.ModelCfg.FrequencyPenalty, model.CatwalkCfg.Options.FrequencyPenalty)
197	presPenalty := cmp.Or(model.ModelCfg.PresencePenalty, model.CatwalkCfg.Options.PresencePenalty)
198	return modelOptions, temp, topP, topK, freqPenalty, presPenalty
199}
200
201func (c *coordinator) buildAgent(prompt *prompt.Prompt, agent config.Agent) (SessionAgent, error) {
202	large, small, err := c.buildAgentModels()
203	if err != nil {
204		return nil, err
205	}
206
207	systemPrompt, err := prompt.Build(large.Model.Provider(), large.Model.Model(), *c.cfg)
208	if err != nil {
209		return nil, err
210	}
211
212	tools, err := c.buildTools(agent)
213	if err != nil {
214		return nil, err
215	}
216	return NewSessionAgent(SessionAgentOptions{large, small, systemPrompt, c.cfg.Options.DisableAutoSummarize, c.sessions, c.messages, tools}), nil
217}
218
219func (c *coordinator) buildTools(agent config.Agent) ([]ai.AgentTool, error) {
220	var allTools []ai.AgentTool
221	if slices.Contains(agent.AllowedTools, AgentToolName) {
222		agentTool, err := c.agentTool()
223		if err != nil {
224			return nil, err
225		}
226		allTools = append(allTools, agentTool)
227	}
228
229	allTools = append(allTools,
230		tools.NewBashTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Options.Attribution),
231		tools.NewDownloadTool(c.permissions, c.cfg.WorkingDir(), nil),
232		tools.NewEditTool(c.lspClients, c.permissions, c.history, c.cfg.WorkingDir()),
233		tools.NewMultiEditTool(c.lspClients, c.permissions, c.history, c.cfg.WorkingDir()),
234		tools.NewFetchTool(c.permissions, c.cfg.WorkingDir(), nil),
235		tools.NewGlobTool(c.cfg.WorkingDir()),
236		tools.NewGrepTool(c.cfg.WorkingDir()),
237		tools.NewLsTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Tools.Ls),
238		tools.NewSourcegraphTool(nil),
239		tools.NewViewTool(c.lspClients, c.permissions, c.cfg.WorkingDir()),
240		tools.NewWriteTool(c.lspClients, c.permissions, c.history, c.cfg.WorkingDir()),
241	)
242
243	var filteredTools []ai.AgentTool
244	for _, tool := range allTools {
245		if slices.Contains(agent.AllowedTools, tool.Info().Name) {
246			filteredTools = append(filteredTools, tool)
247		}
248	}
249
250	mcpTools := tools.GetMCPTools(context.Background(), c.permissions, c.cfg)
251
252	for _, mcpTool := range mcpTools {
253		if agent.AllowedMCP == nil {
254			// No MCP restrictions
255			filteredTools = append(filteredTools, mcpTool)
256		} else if len(agent.AllowedMCP) == 0 {
257			// no mcps allowed
258			break
259		}
260
261		for mcp, tools := range agent.AllowedMCP {
262			if mcp == mcpTool.MCP() {
263				if len(tools) == 0 {
264					filteredTools = append(filteredTools, mcpTool)
265				}
266				for _, t := range tools {
267					if t == mcpTool.MCPToolName() {
268						filteredTools = append(filteredTools, mcpTool)
269					}
270				}
271				break
272			}
273		}
274	}
275
276	return filteredTools, nil
277}
278
279// TODO: when we support multiple agents we need to change this so that we pass in the agent specific model config
280func (c *coordinator) buildAgentModels() (Model, Model, error) {
281	largeModelCfg, ok := c.cfg.Models[config.SelectedModelTypeLarge]
282	if !ok {
283		return Model{}, Model{}, errors.New("large model not selected")
284	}
285	smallModelCfg, ok := c.cfg.Models[config.SelectedModelTypeSmall]
286	if !ok {
287		return Model{}, Model{}, errors.New("small model not selected")
288	}
289
290	largeProviderCfg, ok := c.cfg.Providers.Get(largeModelCfg.Provider)
291	if !ok {
292		return Model{}, Model{}, errors.New("large model provider not configured")
293	}
294
295	largeProvider, err := c.buildProvider(largeProviderCfg, largeModelCfg)
296	if err != nil {
297		return Model{}, Model{}, err
298	}
299
300	smallProviderCfg, ok := c.cfg.Providers.Get(smallModelCfg.Provider)
301	if !ok {
302		return Model{}, Model{}, errors.New("large model provider not configured")
303	}
304
305	smallProvider, err := c.buildProvider(smallProviderCfg, largeModelCfg)
306	if err != nil {
307		return Model{}, Model{}, err
308	}
309
310	var largeCatwalkModel *catwalk.Model
311	var smallCatwalkModel *catwalk.Model
312
313	for _, m := range largeProviderCfg.Models {
314		if m.ID == largeModelCfg.Model {
315			largeCatwalkModel = &m
316		}
317	}
318	for _, m := range smallProviderCfg.Models {
319		if m.ID == smallModelCfg.Model {
320			smallCatwalkModel = &m
321		}
322	}
323
324	if largeCatwalkModel == nil {
325		return Model{}, Model{}, errors.New("large model not found in provider config")
326	}
327
328	if smallCatwalkModel == nil {
329		return Model{}, Model{}, errors.New("snall model not found in provider config")
330	}
331
332	largeModel, err := largeProvider.LanguageModel(largeModelCfg.Model)
333	if err != nil {
334		return Model{}, Model{}, err
335	}
336	smallModel, err := smallProvider.LanguageModel(smallModelCfg.Model)
337	if err != nil {
338		return Model{}, Model{}, err
339	}
340
341	return Model{
342			Model:      largeModel,
343			CatwalkCfg: *largeCatwalkModel,
344			ModelCfg:   largeModelCfg,
345		}, Model{
346			Model:      smallModel,
347			CatwalkCfg: *smallCatwalkModel,
348			ModelCfg:   smallModelCfg,
349		}, nil
350}
351
352func (c *coordinator) buildAnthropicProvider(baseURL, apiKey string, headers map[string]string) ai.Provider {
353	hasBearerAuth := false
354	for key := range headers {
355		if strings.ToLower(key) == "authorization" {
356			hasBearerAuth = true
357			break
358		}
359	}
360	if hasBearerAuth {
361		apiKey = "" // clear apiKey to avoid using X-Api-Key header
362	}
363
364	var opts []anthropic.Option
365
366	if apiKey != "" {
367		// Use standard X-Api-Key header
368		opts = append(opts, anthropic.WithAPIKey(apiKey))
369	}
370
371	if len(headers) > 0 {
372		opts = append(opts, anthropic.WithHeaders(headers))
373	}
374
375	if baseURL != "" {
376		opts = append(opts, anthropic.WithBaseURL(baseURL))
377	}
378
379	if c.cfg.Options.Debug {
380		httpClient := log.NewHTTPClient()
381		opts = append(opts, anthropic.WithHTTPClient(httpClient))
382	}
383
384	return anthropic.New(opts...)
385}
386
387func (c *coordinator) buildOpenaiProvider(baseURL, apiKey string, headers map[string]string) ai.Provider {
388	opts := []openai.Option{
389		openai.WithAPIKey(apiKey),
390	}
391	if c.cfg.Options.Debug {
392		httpClient := log.NewHTTPClient()
393		opts = append(opts, openai.WithHTTPClient(httpClient))
394	}
395	if len(headers) > 0 {
396		opts = append(opts, openai.WithHeaders(headers))
397	}
398	if baseURL != "" {
399		opts = append(opts, openai.WithBaseURL(baseURL))
400	}
401	return openai.New(opts...)
402}
403
404func (c *coordinator) buildOpenrouterProvider(_, apiKey string, headers map[string]string) ai.Provider {
405	opts := []openrouter.Option{
406		openrouter.WithAPIKey(apiKey),
407	}
408	if c.cfg.Options.Debug {
409		httpClient := log.NewHTTPClient()
410		opts = append(opts, openrouter.WithHTTPClient(httpClient))
411	}
412	if len(headers) > 0 {
413		opts = append(opts, openrouter.WithHeaders(headers))
414	}
415	return openrouter.New(opts...)
416}
417
418func (c *coordinator) buildOpenaiCompatProvider(baseURL, apiKey string, headers map[string]string) ai.Provider {
419	opts := []openaicompat.Option{
420		openaicompat.WithBaseURL(baseURL),
421		openaicompat.WithAPIKey(apiKey),
422	}
423	if c.cfg.Options.Debug {
424		httpClient := log.NewHTTPClient()
425		opts = append(opts, openaicompat.WithHTTPClient(httpClient))
426	}
427	if len(headers) > 0 {
428		opts = append(opts, openaicompat.WithHeaders(headers))
429	}
430
431	return openaicompat.New(opts...)
432}
433
434func (c *coordinator) buildAzureProvider(baseURL, apiKey string, headers map[string]string, options map[string]string) ai.Provider {
435	opts := []azure.Option{
436		azure.WithBaseURL(baseURL),
437		azure.WithAPIKey(apiKey),
438	}
439	if c.cfg.Options.Debug {
440		httpClient := log.NewHTTPClient()
441		opts = append(opts, azure.WithHTTPClient(httpClient))
442	}
443	if options == nil {
444		options = make(map[string]string)
445	}
446	if apiVersion, ok := options["apiVersion"]; ok {
447		opts = append(opts, azure.WithAPIVersion(apiVersion))
448	}
449	if len(headers) > 0 {
450		opts = append(opts, azure.WithHeaders(headers))
451	}
452
453	return azure.New(opts...)
454}
455
456// TODO: add baseURL for google
457func (c *coordinator) buildGoogleProvider(baseURL, apiKey string, headers map[string]string) ai.Provider {
458	opts := []google.Option{
459		google.WithAPIKey(apiKey),
460	}
461	if c.cfg.Options.Debug {
462		httpClient := log.NewHTTPClient()
463		opts = append(opts, google.WithHTTPClient(httpClient))
464	}
465	if len(headers) > 0 {
466		opts = append(opts, google.WithHeaders(headers))
467	}
468	return google.New(opts...)
469}
470
471func (c *coordinator) isAnthropicThinking(model config.SelectedModel) bool {
472	if model.Think {
473		return true
474	}
475
476	if model.ProviderOptions == nil {
477		return false
478	}
479
480	opts, err := anthropic.ParseOptions(model.ProviderOptions)
481	if err != nil {
482		return false
483	}
484	if opts.Thinking != nil {
485		return true
486	}
487	return false
488}
489
490func (c *coordinator) buildProvider(providerCfg config.ProviderConfig, model config.SelectedModel) (ai.Provider, error) {
491	headers := providerCfg.ExtraHeaders
492
493	// handle special headers for anthropic
494	if providerCfg.Type == anthropic.Name && c.isAnthropicThinking(model) {
495		headers["anthropic-beta"] = "interleaved-thinking-2025-05-14"
496	}
497
498	// TODO: make sure we have
499	apiKey, _ := c.cfg.Resolve(providerCfg.APIKey)
500	baseURL, _ := c.cfg.Resolve(providerCfg.BaseURL)
501	var provider ai.Provider
502	switch providerCfg.Type {
503	case openai.Name:
504		provider = c.buildOpenaiProvider(baseURL, apiKey, headers)
505	case anthropic.Name:
506		provider = c.buildAnthropicProvider(baseURL, apiKey, headers)
507	case openrouter.Name:
508		provider = c.buildOpenrouterProvider(baseURL, apiKey, headers)
509	case azure.Name:
510		provider = c.buildAzureProvider(baseURL, apiKey, headers, providerCfg.ExtraParams)
511	case google.Name:
512		provider = c.buildGoogleProvider(baseURL, apiKey, headers)
513	case openaicompat.Name:
514		provider = c.buildOpenaiCompatProvider(baseURL, apiKey, headers)
515	default:
516		return nil, errors.New("provider type not supported")
517	}
518	return provider, nil
519}
520
521func (c *coordinator) Cancel(sessionID string) {
522	c.currentAgent.Cancel(sessionID)
523}
524
525func (c *coordinator) CancelAll() {
526	c.currentAgent.CancelAll()
527}
528
529func (c *coordinator) ClearQueue(sessionID string) {
530	c.currentAgent.ClearQueue(sessionID)
531}
532
533func (c *coordinator) IsBusy() bool {
534	return c.currentAgent.IsBusy()
535}
536
537func (c *coordinator) IsSessionBusy(sessionID string) bool {
538	return c.currentAgent.IsSessionBusy(sessionID)
539}
540
541func (c *coordinator) Model() Model {
542	return c.currentAgent.Model()
543}
544
545func (c *coordinator) UpdateModels() error {
546	// build the models again so we make sure we get the latest config
547	large, small, err := c.buildAgentModels()
548	if err != nil {
549		return err
550	}
551	c.currentAgent.SetModels(large, small)
552
553	agentCfg, ok := c.cfg.Agents[config.AgentCoder]
554	if !ok {
555		return errors.New("coder agent not configured")
556	}
557
558	tools, err := c.buildTools(agentCfg)
559	if err != nil {
560		return err
561	}
562	c.currentAgent.SetTools(tools)
563	return nil
564}
565
566func (c *coordinator) QueuedPrompts(sessionID string) int {
567	return c.currentAgent.QueuedPrompts(sessionID)
568}
569
570func (c *coordinator) Summarize(ctx context.Context, sessionID string) error {
571	return c.currentAgent.Summarize(ctx, sessionID)
572}