coordinator.go

  1package agent
  2
  3import (
  4	"bytes"
  5	"cmp"
  6	"context"
  7	"encoding/json"
  8	"errors"
  9	"io"
 10	"log/slog"
 11	"slices"
 12	"strings"
 13
 14	"charm.land/fantasy"
 15	"github.com/charmbracelet/catwalk/pkg/catwalk"
 16	"github.com/charmbracelet/crush/internal/agent/prompt"
 17	"github.com/charmbracelet/crush/internal/agent/tools"
 18	"github.com/charmbracelet/crush/internal/config"
 19	"github.com/charmbracelet/crush/internal/csync"
 20	"github.com/charmbracelet/crush/internal/history"
 21	"github.com/charmbracelet/crush/internal/log"
 22	"github.com/charmbracelet/crush/internal/lsp"
 23	"github.com/charmbracelet/crush/internal/message"
 24	"github.com/charmbracelet/crush/internal/permission"
 25	"github.com/charmbracelet/crush/internal/session"
 26
 27	"charm.land/fantasy/providers/anthropic"
 28	"charm.land/fantasy/providers/azure"
 29	"charm.land/fantasy/providers/google"
 30	"charm.land/fantasy/providers/openai"
 31	"charm.land/fantasy/providers/openaicompat"
 32	"charm.land/fantasy/providers/openrouter"
 33	"github.com/qjebbs/go-jsons"
 34)
 35
 36type Coordinator interface {
 37	// INFO: (kujtim) this is not used yet we will use this when we have multiple agents
 38	// SetMainAgent(string)
 39	Run(ctx context.Context, sessionID, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error)
 40	Cancel(sessionID string)
 41	CancelAll()
 42	IsSessionBusy(sessionID string) bool
 43	IsBusy() bool
 44	QueuedPrompts(sessionID string) int
 45	ClearQueue(sessionID string)
 46	Summarize(context.Context, string) error
 47	Model() Model
 48	UpdateModels() error
 49}
 50
 51type coordinator struct {
 52	cfg         *config.Config
 53	sessions    session.Service
 54	messages    message.Service
 55	permissions permission.Service
 56	history     history.Service
 57	lspClients  *csync.Map[string, *lsp.Client]
 58
 59	currentAgent SessionAgent
 60	agents       map[string]SessionAgent
 61}
 62
 63func NewCoordinator(
 64	cfg *config.Config,
 65	sessions session.Service,
 66	messages message.Service,
 67	permissions permission.Service,
 68	history history.Service,
 69	lspClients *csync.Map[string, *lsp.Client],
 70) (Coordinator, error) {
 71	c := &coordinator{
 72		cfg:         cfg,
 73		sessions:    sessions,
 74		messages:    messages,
 75		permissions: permissions,
 76		history:     history,
 77		lspClients:  lspClients,
 78		agents:      make(map[string]SessionAgent),
 79	}
 80
 81	agentCfg, ok := cfg.Agents[config.AgentCoder]
 82	if !ok {
 83		return nil, errors.New("coder agent not configured")
 84	}
 85
 86	// TODO: make this dynamic when we support multiple agents
 87	prompt, err := coderPrompt(prompt.WithWorkingDir(c.cfg.WorkingDir()))
 88	if err != nil {
 89		return nil, err
 90	}
 91
 92	agent, err := c.buildAgent(prompt, agentCfg)
 93	if err != nil {
 94		return nil, err
 95	}
 96	c.currentAgent = agent
 97	c.agents[config.AgentCoder] = agent
 98	return c, nil
 99}
100
101// Run implements Coordinator.
102func (c *coordinator) Run(ctx context.Context, sessionID string, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) {
103	model := c.currentAgent.Model()
104	maxTokens := model.CatwalkCfg.DefaultMaxTokens
105	if model.ModelCfg.MaxTokens != 0 {
106		maxTokens = model.ModelCfg.MaxTokens
107	}
108
109	if !model.CatwalkCfg.SupportsImages && attachments != nil {
110		attachments = nil
111	}
112
113	providerCfg, ok := c.cfg.Providers.Get(model.ModelCfg.Provider)
114	if !ok {
115		return nil, errors.New("model provider not configured")
116	}
117
118	mergedOptions, temp, topP, topK, freqPenalty, presPenalty := mergeCallOptions(model, providerCfg.Type)
119
120	return c.currentAgent.Run(ctx, SessionAgentCall{
121		SessionID:        sessionID,
122		Prompt:           prompt,
123		Attachments:      attachments,
124		MaxOutputTokens:  maxTokens,
125		ProviderOptions:  mergedOptions,
126		Temperature:      temp,
127		TopP:             topP,
128		TopK:             topK,
129		FrequencyPenalty: freqPenalty,
130		PresencePenalty:  presPenalty,
131	})
132}
133
134func getProviderOptions(model Model, tp catwalk.Type) fantasy.ProviderOptions {
135	options := fantasy.ProviderOptions{}
136
137	cfgOpts := []byte("{}")
138	catwalkOpts := []byte("{}")
139
140	if model.ModelCfg.ProviderOptions != nil {
141		data, err := json.Marshal(model.ModelCfg.ProviderOptions)
142		if err == nil {
143			cfgOpts = data
144		}
145	}
146
147	if model.CatwalkCfg.Options.ProviderOptions != nil {
148		data, err := json.Marshal(model.CatwalkCfg.Options.ProviderOptions)
149		if err == nil {
150			catwalkOpts = data
151		}
152	}
153
154	readers := []io.Reader{
155		bytes.NewReader(catwalkOpts),
156		bytes.NewReader(cfgOpts),
157	}
158
159	got, err := jsons.Merge(readers)
160	if err != nil {
161		slog.Error("Could not merge call config", "err", err)
162		return options
163	}
164
165	mergedOptions := make(map[string]any)
166
167	err = json.Unmarshal([]byte(got), &mergedOptions)
168	if err != nil {
169		slog.Error("Could not create config for call", "err", err)
170		return options
171	}
172
173	switch tp {
174	case openai.Name:
175		_, hasReasoningEffort := mergedOptions["reasoning_effort"]
176		if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
177			mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort
178		}
179		parsed, err := openai.ParseOptions(mergedOptions)
180		if err == nil {
181			options[openai.Name] = parsed
182		}
183	case anthropic.Name:
184		_, hasThink := mergedOptions["thinking"]
185		if !hasThink && model.ModelCfg.Think {
186			mergedOptions["thinking"] = map[string]any{
187				// TODO: kujtim see if we need to make this dynamic
188				"budget_tokens": 2000,
189			}
190		}
191		parsed, err := anthropic.ParseOptions(mergedOptions)
192		if err == nil {
193			options[anthropic.Name] = parsed
194		}
195
196	case openrouter.Name:
197		_, hasReasoning := mergedOptions["reasoning"]
198		if !hasReasoning && model.ModelCfg.ReasoningEffort != "" {
199			mergedOptions["reasoning"] = map[string]any{
200				"enabled": true,
201				"effort":  model.ModelCfg.ReasoningEffort,
202			}
203		}
204		parsed, err := openrouter.ParseOptions(mergedOptions)
205		if err == nil {
206			options[openrouter.Name] = parsed
207		}
208	case google.Name:
209		parsed, err := google.ParseOptions(mergedOptions)
210		if err == nil {
211			options[google.Name] = parsed
212		}
213	case azure.Name:
214		_, hasReasoningEffort := mergedOptions["reasoning_effort"]
215		if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
216			mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort
217		}
218		// azure uses the same options as openaicompat
219		parsed, err := openaicompat.ParseOptions(mergedOptions)
220		if err == nil {
221			options[azure.Name] = parsed
222		}
223	case openaicompat.Name:
224		_, hasReasoningEffort := mergedOptions["reasoning_effort"]
225		if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
226			mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort
227		}
228		parsed, err := openaicompat.ParseOptions(mergedOptions)
229		if err == nil {
230			options[openaicompat.Name] = parsed
231		}
232	}
233
234	return options
235}
236
237func mergeCallOptions(model Model, tp catwalk.Type) (fantasy.ProviderOptions, *float64, *float64, *int64, *float64, *float64) {
238	modelOptions := getProviderOptions(model, tp)
239	temp := cmp.Or(model.ModelCfg.Temperature, model.CatwalkCfg.Options.Temperature)
240	topP := cmp.Or(model.ModelCfg.TopP, model.CatwalkCfg.Options.TopP)
241	topK := cmp.Or(model.ModelCfg.TopK, model.CatwalkCfg.Options.TopK)
242	freqPenalty := cmp.Or(model.ModelCfg.FrequencyPenalty, model.CatwalkCfg.Options.FrequencyPenalty)
243	presPenalty := cmp.Or(model.ModelCfg.PresencePenalty, model.CatwalkCfg.Options.PresencePenalty)
244	return modelOptions, temp, topP, topK, freqPenalty, presPenalty
245}
246
247func (c *coordinator) buildAgent(prompt *prompt.Prompt, agent config.Agent) (SessionAgent, error) {
248	large, small, err := c.buildAgentModels()
249	if err != nil {
250		return nil, err
251	}
252
253	systemPrompt, err := prompt.Build(large.Model.Provider(), large.Model.Model(), *c.cfg)
254	if err != nil {
255		return nil, err
256	}
257
258	tools, err := c.buildTools(agent)
259	if err != nil {
260		return nil, err
261	}
262	return NewSessionAgent(SessionAgentOptions{large, small, systemPrompt, c.cfg.Options.DisableAutoSummarize, c.sessions, c.messages, tools}), nil
263}
264
265func (c *coordinator) buildTools(agent config.Agent) ([]fantasy.AgentTool, error) {
266	var allTools []fantasy.AgentTool
267	if slices.Contains(agent.AllowedTools, AgentToolName) {
268		agentTool, err := c.agentTool()
269		if err != nil {
270			return nil, err
271		}
272		allTools = append(allTools, agentTool)
273	}
274
275	allTools = append(allTools,
276		tools.NewBashTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Options.Attribution),
277		tools.NewDownloadTool(c.permissions, c.cfg.WorkingDir(), nil),
278		tools.NewEditTool(c.lspClients, c.permissions, c.history, c.cfg.WorkingDir()),
279		tools.NewMultiEditTool(c.lspClients, c.permissions, c.history, c.cfg.WorkingDir()),
280		tools.NewFetchTool(c.permissions, c.cfg.WorkingDir(), nil),
281		tools.NewGlobTool(c.cfg.WorkingDir()),
282		tools.NewGrepTool(c.cfg.WorkingDir()),
283		tools.NewLsTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Tools.Ls),
284		tools.NewSourcegraphTool(nil),
285		tools.NewViewTool(c.lspClients, c.permissions, c.cfg.WorkingDir()),
286		tools.NewWriteTool(c.lspClients, c.permissions, c.history, c.cfg.WorkingDir()),
287	)
288
289	var filteredTools []fantasy.AgentTool
290	for _, tool := range allTools {
291		if slices.Contains(agent.AllowedTools, tool.Info().Name) {
292			filteredTools = append(filteredTools, tool)
293		}
294	}
295
296	mcpTools := tools.GetMCPTools(context.Background(), c.permissions, c.cfg)
297
298	for _, mcpTool := range mcpTools {
299		if agent.AllowedMCP == nil {
300			// No MCP restrictions
301			filteredTools = append(filteredTools, mcpTool)
302		} else if len(agent.AllowedMCP) == 0 {
303			// no mcps allowed
304			break
305		}
306
307		for mcp, tools := range agent.AllowedMCP {
308			if mcp == mcpTool.MCP() {
309				if len(tools) == 0 {
310					filteredTools = append(filteredTools, mcpTool)
311				}
312				for _, t := range tools {
313					if t == mcpTool.MCPToolName() {
314						filteredTools = append(filteredTools, mcpTool)
315					}
316				}
317				break
318			}
319		}
320	}
321
322	return filteredTools, nil
323}
324
325// TODO: when we support multiple agents we need to change this so that we pass in the agent specific model config
326func (c *coordinator) buildAgentModels() (Model, Model, error) {
327	largeModelCfg, ok := c.cfg.Models[config.SelectedModelTypeLarge]
328	if !ok {
329		return Model{}, Model{}, errors.New("large model not selected")
330	}
331	smallModelCfg, ok := c.cfg.Models[config.SelectedModelTypeSmall]
332	if !ok {
333		return Model{}, Model{}, errors.New("small model not selected")
334	}
335
336	largeProviderCfg, ok := c.cfg.Providers.Get(largeModelCfg.Provider)
337	if !ok {
338		return Model{}, Model{}, errors.New("large model provider not configured")
339	}
340
341	largeProvider, err := c.buildProvider(largeProviderCfg, largeModelCfg)
342	if err != nil {
343		return Model{}, Model{}, err
344	}
345
346	smallProviderCfg, ok := c.cfg.Providers.Get(smallModelCfg.Provider)
347	if !ok {
348		return Model{}, Model{}, errors.New("large model provider not configured")
349	}
350
351	smallProvider, err := c.buildProvider(smallProviderCfg, largeModelCfg)
352	if err != nil {
353		return Model{}, Model{}, err
354	}
355
356	var largeCatwalkModel *catwalk.Model
357	var smallCatwalkModel *catwalk.Model
358
359	for _, m := range largeProviderCfg.Models {
360		if m.ID == largeModelCfg.Model {
361			largeCatwalkModel = &m
362		}
363	}
364	for _, m := range smallProviderCfg.Models {
365		if m.ID == smallModelCfg.Model {
366			smallCatwalkModel = &m
367		}
368	}
369
370	if largeCatwalkModel == nil {
371		return Model{}, Model{}, errors.New("large model not found in provider config")
372	}
373
374	if smallCatwalkModel == nil {
375		return Model{}, Model{}, errors.New("snall model not found in provider config")
376	}
377
378	largeModel, err := largeProvider.LanguageModel(largeModelCfg.Model)
379	if err != nil {
380		return Model{}, Model{}, err
381	}
382	smallModel, err := smallProvider.LanguageModel(smallModelCfg.Model)
383	if err != nil {
384		return Model{}, Model{}, err
385	}
386
387	return Model{
388			Model:      largeModel,
389			CatwalkCfg: *largeCatwalkModel,
390			ModelCfg:   largeModelCfg,
391		}, Model{
392			Model:      smallModel,
393			CatwalkCfg: *smallCatwalkModel,
394			ModelCfg:   smallModelCfg,
395		}, nil
396}
397
398func (c *coordinator) buildAnthropicProvider(baseURL, apiKey string, headers map[string]string) fantasy.Provider {
399	hasBearerAuth := false
400	for key := range headers {
401		if strings.ToLower(key) == "authorization" {
402			hasBearerAuth = true
403			break
404		}
405	}
406	if hasBearerAuth {
407		apiKey = "" // clear apiKey to avoid using X-Api-Key header
408	}
409
410	var opts []anthropic.Option
411
412	if apiKey != "" {
413		// Use standard X-Api-Key header
414		opts = append(opts, anthropic.WithAPIKey(apiKey))
415	}
416
417	if len(headers) > 0 {
418		opts = append(opts, anthropic.WithHeaders(headers))
419	}
420
421	if baseURL != "" {
422		opts = append(opts, anthropic.WithBaseURL(baseURL))
423	}
424
425	if c.cfg.Options.Debug {
426		httpClient := log.NewHTTPClient()
427		opts = append(opts, anthropic.WithHTTPClient(httpClient))
428	}
429
430	return anthropic.New(opts...)
431}
432
433func (c *coordinator) buildOpenaiProvider(baseURL, apiKey string, headers map[string]string) fantasy.Provider {
434	opts := []openai.Option{
435		openai.WithAPIKey(apiKey),
436	}
437	if c.cfg.Options.Debug {
438		httpClient := log.NewHTTPClient()
439		opts = append(opts, openai.WithHTTPClient(httpClient))
440	}
441	if len(headers) > 0 {
442		opts = append(opts, openai.WithHeaders(headers))
443	}
444	if baseURL != "" {
445		opts = append(opts, openai.WithBaseURL(baseURL))
446	}
447	return openai.New(opts...)
448}
449
450func (c *coordinator) buildOpenrouterProvider(_, apiKey string, headers map[string]string) fantasy.Provider {
451	opts := []openrouter.Option{
452		openrouter.WithAPIKey(apiKey),
453	}
454	if c.cfg.Options.Debug {
455		httpClient := log.NewHTTPClient()
456		opts = append(opts, openrouter.WithHTTPClient(httpClient))
457	}
458	if len(headers) > 0 {
459		opts = append(opts, openrouter.WithHeaders(headers))
460	}
461	return openrouter.New(opts...)
462}
463
464func (c *coordinator) buildOpenaiCompatProvider(baseURL, apiKey string, headers map[string]string) fantasy.Provider {
465	opts := []openaicompat.Option{
466		openaicompat.WithBaseURL(baseURL),
467		openaicompat.WithAPIKey(apiKey),
468	}
469	if c.cfg.Options.Debug {
470		httpClient := log.NewHTTPClient()
471		opts = append(opts, openaicompat.WithHTTPClient(httpClient))
472	}
473	if len(headers) > 0 {
474		opts = append(opts, openaicompat.WithHeaders(headers))
475	}
476
477	return openaicompat.New(opts...)
478}
479
480func (c *coordinator) buildAzureProvider(baseURL, apiKey string, headers map[string]string, options map[string]string) fantasy.Provider {
481	opts := []azure.Option{
482		azure.WithBaseURL(baseURL),
483		azure.WithAPIKey(apiKey),
484	}
485	if c.cfg.Options.Debug {
486		httpClient := log.NewHTTPClient()
487		opts = append(opts, azure.WithHTTPClient(httpClient))
488	}
489	if options == nil {
490		options = make(map[string]string)
491	}
492	if apiVersion, ok := options["apiVersion"]; ok {
493		opts = append(opts, azure.WithAPIVersion(apiVersion))
494	}
495	if len(headers) > 0 {
496		opts = append(opts, azure.WithHeaders(headers))
497	}
498
499	return azure.New(opts...)
500}
501
502func (c *coordinator) buildGoogleProvider(baseURL, apiKey string, headers map[string]string) fantasy.Provider {
503	opts := []google.Option{
504		google.WithBaseURL(baseURL),
505		google.WithGeminiAPIKey(apiKey),
506	}
507	if c.cfg.Options.Debug {
508		httpClient := log.NewHTTPClient()
509		opts = append(opts, google.WithHTTPClient(httpClient))
510	}
511	if len(headers) > 0 {
512		opts = append(opts, google.WithHeaders(headers))
513	}
514	return google.New(opts...)
515}
516
517func (c *coordinator) buildGoogleVertexProvider(headers map[string]string, options map[string]string) fantasy.Provider {
518	opts := []google.Option{}
519	if c.cfg.Options.Debug {
520		httpClient := log.NewHTTPClient()
521		opts = append(opts, google.WithHTTPClient(httpClient))
522	}
523	if len(headers) > 0 {
524		opts = append(opts, google.WithHeaders(headers))
525	}
526
527	project := options["project"]
528	location := options["location"]
529
530	opts = append(opts, google.WithVertex(project, location))
531
532	return google.New(opts...)
533}
534
535func (c *coordinator) isAnthropicThinking(model config.SelectedModel) bool {
536	if model.Think {
537		return true
538	}
539
540	if model.ProviderOptions == nil {
541		return false
542	}
543
544	opts, err := anthropic.ParseOptions(model.ProviderOptions)
545	if err != nil {
546		return false
547	}
548	if opts.Thinking != nil {
549		return true
550	}
551	return false
552}
553
554func (c *coordinator) buildProvider(providerCfg config.ProviderConfig, model config.SelectedModel) (fantasy.Provider, error) {
555	headers := providerCfg.ExtraHeaders
556
557	// handle special headers for anthropic
558	if providerCfg.Type == anthropic.Name && c.isAnthropicThinking(model) {
559		headers["anthropic-beta"] = "interleaved-thinking-2025-05-14"
560	}
561
562	// TODO: make sure we have
563	apiKey, _ := c.cfg.Resolve(providerCfg.APIKey)
564	baseURL, _ := c.cfg.Resolve(providerCfg.BaseURL)
565	var provider fantasy.Provider
566	switch providerCfg.Type {
567	case openai.Name:
568		provider = c.buildOpenaiProvider(baseURL, apiKey, headers)
569	case anthropic.Name:
570		provider = c.buildAnthropicProvider(baseURL, apiKey, headers)
571	case openrouter.Name:
572		provider = c.buildOpenrouterProvider(baseURL, apiKey, headers)
573	case azure.Name:
574		provider = c.buildAzureProvider(baseURL, apiKey, headers, providerCfg.ExtraParams)
575	case google.Name:
576		provider = c.buildGoogleProvider(baseURL, apiKey, headers)
577	// this is not in fantasy since its just the google provider with extra stuff
578	case "google-vertex":
579		provider = c.buildGoogleVertexProvider(headers, providerCfg.ExtraParams)
580	case openaicompat.Name:
581		provider = c.buildOpenaiCompatProvider(baseURL, apiKey, headers)
582	default:
583		return nil, errors.New("provider type not supported")
584	}
585	return provider, nil
586}
587
588func (c *coordinator) Cancel(sessionID string) {
589	c.currentAgent.Cancel(sessionID)
590}
591
592func (c *coordinator) CancelAll() {
593	c.currentAgent.CancelAll()
594}
595
596func (c *coordinator) ClearQueue(sessionID string) {
597	c.currentAgent.ClearQueue(sessionID)
598}
599
600func (c *coordinator) IsBusy() bool {
601	return c.currentAgent.IsBusy()
602}
603
604func (c *coordinator) IsSessionBusy(sessionID string) bool {
605	return c.currentAgent.IsSessionBusy(sessionID)
606}
607
608func (c *coordinator) Model() Model {
609	return c.currentAgent.Model()
610}
611
612func (c *coordinator) UpdateModels() error {
613	// build the models again so we make sure we get the latest config
614	large, small, err := c.buildAgentModels()
615	if err != nil {
616		return err
617	}
618	c.currentAgent.SetModels(large, small)
619
620	agentCfg, ok := c.cfg.Agents[config.AgentCoder]
621	if !ok {
622		return errors.New("coder agent not configured")
623	}
624
625	tools, err := c.buildTools(agentCfg)
626	if err != nil {
627		return err
628	}
629	c.currentAgent.SetTools(tools)
630	return nil
631}
632
633func (c *coordinator) QueuedPrompts(sessionID string) int {
634	return c.currentAgent.QueuedPrompts(sessionID)
635}
636
637func (c *coordinator) Summarize(ctx context.Context, sessionID string) error {
638	return c.currentAgent.Summarize(ctx, sessionID)
639}