coordinator.go

  1package agent
  2
  3import (
  4	"bytes"
  5	"cmp"
  6	"context"
  7	"encoding/json"
  8	"errors"
  9	"io"
 10	"log/slog"
 11	"slices"
 12	"strings"
 13
 14	"github.com/charmbracelet/catwalk/pkg/catwalk"
 15	"github.com/charmbracelet/crush/internal/agent/prompt"
 16	"github.com/charmbracelet/crush/internal/agent/tools"
 17	"github.com/charmbracelet/crush/internal/config"
 18	"github.com/charmbracelet/crush/internal/csync"
 19	"github.com/charmbracelet/crush/internal/history"
 20	"github.com/charmbracelet/crush/internal/log"
 21	"github.com/charmbracelet/crush/internal/lsp"
 22	"github.com/charmbracelet/crush/internal/message"
 23	"github.com/charmbracelet/crush/internal/permission"
 24	"github.com/charmbracelet/crush/internal/session"
 25	"github.com/charmbracelet/fantasy/ai"
 26	"github.com/charmbracelet/fantasy/anthropic"
 27	"github.com/charmbracelet/fantasy/azure"
 28	"github.com/charmbracelet/fantasy/google"
 29	"github.com/charmbracelet/fantasy/openai"
 30	"github.com/charmbracelet/fantasy/openaicompat"
 31	"github.com/charmbracelet/fantasy/openrouter"
 32	"github.com/qjebbs/go-jsons"
 33)
 34
 35type Coordinator interface {
 36	// INFO: (kujtim) this is not used yet we will use this when we have multiple agents
 37	// SetMainAgent(string)
 38	Run(ctx context.Context, sessionID, prompt string, attachments ...message.Attachment) (*ai.AgentResult, error)
 39	Cancel(sessionID string)
 40	CancelAll()
 41	IsSessionBusy(sessionID string) bool
 42	IsBusy() bool
 43	QueuedPrompts(sessionID string) int
 44	ClearQueue(sessionID string)
 45	Summarize(context.Context, string) error
 46	Model() Model
 47	UpdateModels() error
 48}
 49
 50type coordinator struct {
 51	cfg         *config.Config
 52	sessions    session.Service
 53	messages    message.Service
 54	permissions permission.Service
 55	history     history.Service
 56	lspClients  *csync.Map[string, *lsp.Client]
 57
 58	currentAgent SessionAgent
 59	agents       map[string]SessionAgent
 60}
 61
 62func NewCoordinator(
 63	cfg *config.Config,
 64	sessions session.Service,
 65	messages message.Service,
 66	permissions permission.Service,
 67	history history.Service,
 68	lspClients *csync.Map[string, *lsp.Client],
 69) (Coordinator, error) {
 70	c := &coordinator{
 71		cfg:         cfg,
 72		sessions:    sessions,
 73		messages:    messages,
 74		permissions: permissions,
 75		history:     history,
 76		lspClients:  lspClients,
 77		agents:      make(map[string]SessionAgent),
 78	}
 79
 80	agentCfg, ok := cfg.Agents[config.AgentCoder]
 81	if !ok {
 82		return nil, errors.New("coder agent not configured")
 83	}
 84
 85	// TODO: make this dynamic when we support multiple agents
 86	prompt, err := coderPrompt(prompt.WithWorkingDir(c.cfg.WorkingDir()))
 87	if err != nil {
 88		return nil, err
 89	}
 90
 91	agent, err := c.buildAgent(prompt, agentCfg)
 92	if err != nil {
 93		return nil, err
 94	}
 95	c.currentAgent = agent
 96	c.agents[config.AgentCoder] = agent
 97	return c, nil
 98}
 99
100// Run implements Coordinator.
101func (c *coordinator) Run(ctx context.Context, sessionID string, prompt string, attachments ...message.Attachment) (*ai.AgentResult, error) {
102	model := c.currentAgent.Model()
103	maxTokens := model.CatwalkCfg.DefaultMaxTokens
104	if model.ModelCfg.MaxTokens != 0 {
105		maxTokens = model.ModelCfg.MaxTokens
106	}
107
108	if !model.CatwalkCfg.SupportsImages && attachments != nil {
109		attachments = nil
110	}
111
112	mergedOptions, temp, topP, topK, freqPenalty, presPenalty := mergeCallOptions(model)
113
114	return c.currentAgent.Run(ctx, SessionAgentCall{
115		SessionID:        sessionID,
116		Prompt:           prompt,
117		Attachments:      attachments,
118		MaxOutputTokens:  maxTokens,
119		ProviderOptions:  mergedOptions,
120		Temperature:      temp,
121		TopP:             topP,
122		TopK:             topK,
123		FrequencyPenalty: freqPenalty,
124		PresencePenalty:  presPenalty,
125	})
126}
127
128func getProviderOptions(model Model) ai.ProviderOptions {
129	options := ai.ProviderOptions{}
130
131	cfgOpts := []byte("{}")
132	catwalkOpts := []byte("{}")
133
134	if model.ModelCfg.ProviderOptions != nil {
135		data, err := json.Marshal(model.ModelCfg.ProviderOptions)
136		if err == nil {
137			cfgOpts = data
138		}
139	}
140
141	if model.CatwalkCfg.Options.ProviderOptions != nil {
142		data, err := json.Marshal(model.CatwalkCfg.Options.ProviderOptions)
143		if err == nil {
144			catwalkOpts = data
145		}
146	}
147
148	readers := []io.Reader{
149		bytes.NewReader(catwalkOpts),
150		bytes.NewReader(cfgOpts),
151	}
152
153	got, err := jsons.Merge(readers)
154	if err != nil {
155		slog.Error("Could not merge call config", "err", err)
156		return options
157	}
158
159	mergedOptions := make(map[string]any)
160
161	err = json.Unmarshal([]byte(got), &mergedOptions)
162	if err != nil {
163		slog.Error("Could not create config for call", "err", err)
164		return options
165	}
166
167	switch model.Model.Provider() {
168	case openai.Name:
169		parsed, err := openai.ParseOptions(mergedOptions)
170		if err == nil {
171			options[openai.Name] = parsed
172		}
173	case anthropic.Name:
174		parsed, err := anthropic.ParseOptions(mergedOptions)
175		if err == nil {
176			options[anthropic.Name] = parsed
177		}
178	case openrouter.Name:
179		parsed, err := openrouter.ParseOptions(mergedOptions)
180		if err == nil {
181			options[openrouter.Name] = parsed
182		}
183	case google.Name:
184		parsed, err := google.ParseOptions(mergedOptions)
185		if err == nil {
186			options[google.Name] = parsed
187		}
188	case openaicompat.Name:
189		parsed, err := openaicompat.ParseOptions(mergedOptions)
190		if err == nil {
191			options[openaicompat.Name] = parsed
192		}
193	}
194
195	return options
196}
197
198func mergeCallOptions(model Model) (ai.ProviderOptions, *float64, *float64, *int64, *float64, *float64) {
199	modelOptions := getProviderOptions(model)
200	temp := cmp.Or(model.ModelCfg.Temperature, model.CatwalkCfg.Options.Temperature)
201	topP := cmp.Or(model.ModelCfg.TopP, model.CatwalkCfg.Options.TopP)
202	topK := cmp.Or(model.ModelCfg.TopK, model.CatwalkCfg.Options.TopK)
203	freqPenalty := cmp.Or(model.ModelCfg.FrequencyPenalty, model.CatwalkCfg.Options.FrequencyPenalty)
204	presPenalty := cmp.Or(model.ModelCfg.PresencePenalty, model.CatwalkCfg.Options.PresencePenalty)
205	return modelOptions, temp, topP, topK, freqPenalty, presPenalty
206}
207
208func (c *coordinator) buildAgent(prompt *prompt.Prompt, agent config.Agent) (SessionAgent, error) {
209	large, small, err := c.buildAgentModels()
210	if err != nil {
211		return nil, err
212	}
213
214	systemPrompt, err := prompt.Build(large.Model.Provider(), large.Model.Model(), *c.cfg)
215	if err != nil {
216		return nil, err
217	}
218
219	tools, err := c.buildTools(agent)
220	if err != nil {
221		return nil, err
222	}
223	return NewSessionAgent(SessionAgentOptions{large, small, systemPrompt, c.cfg.Options.DisableAutoSummarize, c.sessions, c.messages, tools}), nil
224}
225
226func (c *coordinator) buildTools(agent config.Agent) ([]ai.AgentTool, error) {
227	var allTools []ai.AgentTool
228	if slices.Contains(agent.AllowedTools, AgentToolName) {
229		agentTool, err := c.agentTool()
230		if err != nil {
231			return nil, err
232		}
233		allTools = append(allTools, agentTool)
234	}
235
236	allTools = append(allTools,
237		tools.NewBashTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Options.Attribution),
238		tools.NewDownloadTool(c.permissions, c.cfg.WorkingDir(), nil),
239		tools.NewEditTool(c.lspClients, c.permissions, c.history, c.cfg.WorkingDir()),
240		tools.NewMultiEditTool(c.lspClients, c.permissions, c.history, c.cfg.WorkingDir()),
241		tools.NewFetchTool(c.permissions, c.cfg.WorkingDir(), nil),
242		tools.NewGlobTool(c.cfg.WorkingDir()),
243		tools.NewGrepTool(c.cfg.WorkingDir()),
244		tools.NewLsTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Tools.Ls),
245		tools.NewSourcegraphTool(nil),
246		tools.NewViewTool(c.lspClients, c.permissions, c.cfg.WorkingDir()),
247		tools.NewWriteTool(c.lspClients, c.permissions, c.history, c.cfg.WorkingDir()),
248	)
249
250	var filteredTools []ai.AgentTool
251	for _, tool := range allTools {
252		if slices.Contains(agent.AllowedTools, tool.Info().Name) {
253			filteredTools = append(filteredTools, tool)
254		}
255	}
256
257	mcpTools := tools.GetMCPTools(context.Background(), c.permissions, c.cfg)
258
259	for _, mcpTool := range mcpTools {
260		if agent.AllowedMCP == nil {
261			// No MCP restrictions
262			filteredTools = append(filteredTools, mcpTool)
263		} else if len(agent.AllowedMCP) == 0 {
264			// no mcps allowed
265			break
266		}
267
268		for mcp, tools := range agent.AllowedMCP {
269			if mcp == mcpTool.MCP() {
270				if len(tools) == 0 {
271					filteredTools = append(filteredTools, mcpTool)
272				}
273				for _, t := range tools {
274					if t == mcpTool.MCPToolName() {
275						filteredTools = append(filteredTools, mcpTool)
276					}
277				}
278				break
279			}
280		}
281	}
282
283	return filteredTools, nil
284}
285
286// TODO: when we support multiple agents we need to change this so that we pass in the agent specific model config
287func (c *coordinator) buildAgentModels() (Model, Model, error) {
288	largeModelCfg, ok := c.cfg.Models[config.SelectedModelTypeLarge]
289	if !ok {
290		return Model{}, Model{}, errors.New("large model not selected")
291	}
292	smallModelCfg, ok := c.cfg.Models[config.SelectedModelTypeSmall]
293	if !ok {
294		return Model{}, Model{}, errors.New("small model not selected")
295	}
296
297	largeProviderCfg, ok := c.cfg.Providers.Get(largeModelCfg.Provider)
298	if !ok {
299		return Model{}, Model{}, errors.New("large model provider not configured")
300	}
301
302	largeProvider, err := c.buildProvider(largeProviderCfg, largeModelCfg)
303	if err != nil {
304		return Model{}, Model{}, err
305	}
306
307	smallProviderCfg, ok := c.cfg.Providers.Get(smallModelCfg.Provider)
308	if !ok {
309		return Model{}, Model{}, errors.New("large model provider not configured")
310	}
311
312	smallProvider, err := c.buildProvider(smallProviderCfg, largeModelCfg)
313	if err != nil {
314		return Model{}, Model{}, err
315	}
316
317	var largeCatwalkModel *catwalk.Model
318	var smallCatwalkModel *catwalk.Model
319
320	for _, m := range largeProviderCfg.Models {
321		if m.ID == largeModelCfg.Model {
322			largeCatwalkModel = &m
323		}
324	}
325	for _, m := range smallProviderCfg.Models {
326		if m.ID == smallModelCfg.Model {
327			smallCatwalkModel = &m
328		}
329	}
330
331	if largeCatwalkModel == nil {
332		return Model{}, Model{}, errors.New("large model not found in provider config")
333	}
334
335	if smallCatwalkModel == nil {
336		return Model{}, Model{}, errors.New("snall model not found in provider config")
337	}
338
339	largeModel, err := largeProvider.LanguageModel(largeModelCfg.Model)
340	if err != nil {
341		return Model{}, Model{}, err
342	}
343	smallModel, err := smallProvider.LanguageModel(smallModelCfg.Model)
344	if err != nil {
345		return Model{}, Model{}, err
346	}
347
348	return Model{
349			Model:      largeModel,
350			CatwalkCfg: *largeCatwalkModel,
351			ModelCfg:   largeModelCfg,
352		}, Model{
353			Model:      smallModel,
354			CatwalkCfg: *smallCatwalkModel,
355			ModelCfg:   smallModelCfg,
356		}, nil
357}
358
359func (c *coordinator) buildAnthropicProvider(baseURL, apiKey string, headers map[string]string) ai.Provider {
360	hasBearerAuth := false
361	for key := range headers {
362		if strings.ToLower(key) == "authorization" {
363			hasBearerAuth = true
364			break
365		}
366	}
367	if hasBearerAuth {
368		apiKey = "" // clear apiKey to avoid using X-Api-Key header
369	}
370
371	var opts []anthropic.Option
372
373	if apiKey != "" {
374		// Use standard X-Api-Key header
375		opts = append(opts, anthropic.WithAPIKey(apiKey))
376	}
377
378	if len(headers) > 0 {
379		opts = append(opts, anthropic.WithHeaders(headers))
380	}
381
382	if baseURL != "" {
383		opts = append(opts, anthropic.WithBaseURL(baseURL))
384	}
385
386	if c.cfg.Options.Debug {
387		httpClient := log.NewHTTPClient()
388		opts = append(opts, anthropic.WithHTTPClient(httpClient))
389	}
390
391	return anthropic.New(opts...)
392}
393
394func (c *coordinator) buildOpenaiProvider(baseURL, apiKey string, headers map[string]string) ai.Provider {
395	opts := []openai.Option{
396		openai.WithAPIKey(apiKey),
397	}
398	if c.cfg.Options.Debug {
399		httpClient := log.NewHTTPClient()
400		opts = append(opts, openai.WithHTTPClient(httpClient))
401	}
402	if len(headers) > 0 {
403		opts = append(opts, openai.WithHeaders(headers))
404	}
405	if baseURL != "" {
406		opts = append(opts, openai.WithBaseURL(baseURL))
407	}
408	return openai.New(opts...)
409}
410
411func (c *coordinator) buildOpenrouterProvider(_, apiKey string, headers map[string]string) ai.Provider {
412	opts := []openrouter.Option{
413		openrouter.WithAPIKey(apiKey),
414	}
415	if c.cfg.Options.Debug {
416		httpClient := log.NewHTTPClient()
417		opts = append(opts, openrouter.WithHTTPClient(httpClient))
418	}
419	if len(headers) > 0 {
420		opts = append(opts, openrouter.WithHeaders(headers))
421	}
422	return openrouter.New(opts...)
423}
424
425func (c *coordinator) buildOpenaiCompatProvider(baseURL, apiKey string, headers map[string]string) ai.Provider {
426	opts := []openaicompat.Option{
427		openaicompat.WithBaseURL(baseURL),
428		openaicompat.WithAPIKey(apiKey),
429	}
430	if c.cfg.Options.Debug {
431		httpClient := log.NewHTTPClient()
432		opts = append(opts, openaicompat.WithHTTPClient(httpClient))
433	}
434	if len(headers) > 0 {
435		opts = append(opts, openaicompat.WithHeaders(headers))
436	}
437
438	return openaicompat.New(opts...)
439}
440
441func (c *coordinator) buildAzureProvider(baseURL, apiKey string, headers map[string]string, options map[string]string) ai.Provider {
442	opts := []azure.Option{
443		azure.WithBaseURL(baseURL),
444		azure.WithAPIKey(apiKey),
445	}
446	if c.cfg.Options.Debug {
447		httpClient := log.NewHTTPClient()
448		opts = append(opts, azure.WithHTTPClient(httpClient))
449	}
450	if options == nil {
451		options = make(map[string]string)
452	}
453	if apiVersion, ok := options["apiVersion"]; ok {
454		opts = append(opts, azure.WithAPIVersion(apiVersion))
455	}
456	if len(headers) > 0 {
457		opts = append(opts, azure.WithHeaders(headers))
458	}
459
460	return azure.New(opts...)
461}
462
463// TODO: add baseURL for google
464func (c *coordinator) buildGoogleProvider(baseURL, apiKey string, headers map[string]string) ai.Provider {
465	opts := []google.Option{
466		google.WithAPIKey(apiKey),
467	}
468	if c.cfg.Options.Debug {
469		httpClient := log.NewHTTPClient()
470		opts = append(opts, google.WithHTTPClient(httpClient))
471	}
472	if len(headers) > 0 {
473		opts = append(opts, google.WithHeaders(headers))
474	}
475	return google.New(opts...)
476}
477
478func (c *coordinator) isAnthropicThinking(model config.SelectedModel) bool {
479	if model.Think {
480		return true
481	}
482
483	if model.ProviderOptions == nil {
484		return false
485	}
486
487	opts, err := anthropic.ParseOptions(model.ProviderOptions)
488	if err != nil {
489		return false
490	}
491	if opts.Thinking != nil {
492		return true
493	}
494	return false
495}
496
497func (c *coordinator) buildProvider(providerCfg config.ProviderConfig, model config.SelectedModel) (ai.Provider, error) {
498	headers := providerCfg.ExtraHeaders
499
500	// handle special headers for anthropic
501	if providerCfg.Type == anthropic.Name && c.isAnthropicThinking(model) {
502		headers["anthropic-beta"] = "interleaved-thinking-2025-05-14"
503	}
504
505	// TODO: make sure we have
506	apiKey, _ := c.cfg.Resolve(providerCfg.APIKey)
507	baseURL, _ := c.cfg.Resolve(providerCfg.BaseURL)
508	var provider ai.Provider
509	switch providerCfg.Type {
510	case openai.Name:
511		provider = c.buildOpenaiProvider(baseURL, apiKey, headers)
512	case anthropic.Name:
513		provider = c.buildAnthropicProvider(baseURL, apiKey, headers)
514	case openrouter.Name:
515		provider = c.buildOpenrouterProvider(baseURL, apiKey, headers)
516	case azure.Name:
517		provider = c.buildAzureProvider(baseURL, apiKey, headers, providerCfg.ExtraParams)
518	case google.Name:
519		provider = c.buildGoogleProvider(baseURL, apiKey, headers)
520	case openaicompat.Name:
521		provider = c.buildOpenaiCompatProvider(baseURL, apiKey, headers)
522	default:
523		return nil, errors.New("provider type not supported")
524	}
525	return provider, nil
526}
527
528func (c *coordinator) Cancel(sessionID string) {
529	c.currentAgent.Cancel(sessionID)
530}
531
532func (c *coordinator) CancelAll() {
533	c.currentAgent.CancelAll()
534}
535
536func (c *coordinator) ClearQueue(sessionID string) {
537	c.currentAgent.ClearQueue(sessionID)
538}
539
540func (c *coordinator) IsBusy() bool {
541	return c.currentAgent.IsBusy()
542}
543
544func (c *coordinator) IsSessionBusy(sessionID string) bool {
545	return c.currentAgent.IsSessionBusy(sessionID)
546}
547
548func (c *coordinator) Model() Model {
549	return c.currentAgent.Model()
550}
551
552func (c *coordinator) UpdateModels() error {
553	// build the models again so we make sure we get the latest config
554	large, small, err := c.buildAgentModels()
555	if err != nil {
556		return err
557	}
558	c.currentAgent.SetModels(large, small)
559
560	agentCfg, ok := c.cfg.Agents[config.AgentCoder]
561	if !ok {
562		return errors.New("coder agent not configured")
563	}
564
565	tools, err := c.buildTools(agentCfg)
566	if err != nil {
567		return err
568	}
569	c.currentAgent.SetTools(tools)
570	return nil
571}
572
573func (c *coordinator) QueuedPrompts(sessionID string) int {
574	return c.currentAgent.QueuedPrompts(sessionID)
575}
576
577func (c *coordinator) Summarize(ctx context.Context, sessionID string) error {
578	return c.currentAgent.Summarize(ctx, sessionID)
579}