coordinator.go

  1package agent
  2
import (
	"bytes"
	"cmp"
	"context"
	"encoding/json"
	"errors"
	"io"
	"log/slog"
	"maps"
	"slices"
	"strings"

	"charm.land/fantasy"
	"github.com/charmbracelet/catwalk/pkg/catwalk"
	"github.com/charmbracelet/crush/internal/agent/prompt"
	"github.com/charmbracelet/crush/internal/agent/tools"
	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/csync"
	"github.com/charmbracelet/crush/internal/history"
	"github.com/charmbracelet/crush/internal/log"
	"github.com/charmbracelet/crush/internal/lsp"
	"github.com/charmbracelet/crush/internal/message"
	"github.com/charmbracelet/crush/internal/permission"
	"github.com/charmbracelet/crush/internal/session"

	"charm.land/fantasy/providers/anthropic"
	"charm.land/fantasy/providers/azure"
	"charm.land/fantasy/providers/google"
	"charm.land/fantasy/providers/openai"
	"charm.land/fantasy/providers/openaicompat"
	"charm.land/fantasy/providers/openrouter"
	"github.com/qjebbs/go-jsons"
)
 35
 36type Coordinator interface {
 37	// INFO: (kujtim) this is not used yet we will use this when we have multiple agents
 38	// SetMainAgent(string)
 39	Run(ctx context.Context, sessionID, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error)
 40	Cancel(sessionID string)
 41	CancelAll()
 42	IsSessionBusy(sessionID string) bool
 43	IsBusy() bool
 44	QueuedPrompts(sessionID string) int
 45	ClearQueue(sessionID string)
 46	Summarize(context.Context, string) error
 47	Model() Model
 48	UpdateModels(ctx context.Context) error
 49}
 50
 51type coordinator struct {
 52	cfg         *config.Config
 53	sessions    session.Service
 54	messages    message.Service
 55	permissions permission.Service
 56	history     history.Service
 57	lspClients  *csync.Map[string, *lsp.Client]
 58
 59	currentAgent SessionAgent
 60	agents       map[string]SessionAgent
 61}
 62
 63func NewCoordinator(
 64	ctx context.Context,
 65	cfg *config.Config,
 66	sessions session.Service,
 67	messages message.Service,
 68	permissions permission.Service,
 69	history history.Service,
 70	lspClients *csync.Map[string, *lsp.Client],
 71) (Coordinator, error) {
 72	c := &coordinator{
 73		cfg:         cfg,
 74		sessions:    sessions,
 75		messages:    messages,
 76		permissions: permissions,
 77		history:     history,
 78		lspClients:  lspClients,
 79		agents:      make(map[string]SessionAgent),
 80	}
 81
 82	agentCfg, ok := cfg.Agents[config.AgentCoder]
 83	if !ok {
 84		return nil, errors.New("coder agent not configured")
 85	}
 86
 87	// TODO: make this dynamic when we support multiple agents
 88	prompt, err := coderPrompt(prompt.WithWorkingDir(c.cfg.WorkingDir()))
 89	if err != nil {
 90		return nil, err
 91	}
 92
 93	agent, err := c.buildAgent(ctx, prompt, agentCfg)
 94	if err != nil {
 95		return nil, err
 96	}
 97	c.currentAgent = agent
 98	c.agents[config.AgentCoder] = agent
 99	return c, nil
100}
101
102// Run implements Coordinator.
103func (c *coordinator) Run(ctx context.Context, sessionID string, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) {
104	model := c.currentAgent.Model()
105	maxTokens := model.CatwalkCfg.DefaultMaxTokens
106	if model.ModelCfg.MaxTokens != 0 {
107		maxTokens = model.ModelCfg.MaxTokens
108	}
109
110	if !model.CatwalkCfg.SupportsImages && attachments != nil {
111		attachments = nil
112	}
113
114	providerCfg, ok := c.cfg.Providers.Get(model.ModelCfg.Provider)
115	if !ok {
116		return nil, errors.New("model provider not configured")
117	}
118
119	mergedOptions, temp, topP, topK, freqPenalty, presPenalty := mergeCallOptions(model, providerCfg.Type)
120
121	return c.currentAgent.Run(ctx, SessionAgentCall{
122		SessionID:        sessionID,
123		Prompt:           prompt,
124		Attachments:      attachments,
125		MaxOutputTokens:  maxTokens,
126		ProviderOptions:  mergedOptions,
127		Temperature:      temp,
128		TopP:             topP,
129		TopK:             topK,
130		FrequencyPenalty: freqPenalty,
131		PresencePenalty:  presPenalty,
132	})
133}
134
135func getProviderOptions(model Model, tp catwalk.Type) fantasy.ProviderOptions {
136	options := fantasy.ProviderOptions{}
137
138	cfgOpts := []byte("{}")
139	catwalkOpts := []byte("{}")
140
141	if model.ModelCfg.ProviderOptions != nil {
142		data, err := json.Marshal(model.ModelCfg.ProviderOptions)
143		if err == nil {
144			cfgOpts = data
145		}
146	}
147
148	if model.CatwalkCfg.Options.ProviderOptions != nil {
149		data, err := json.Marshal(model.CatwalkCfg.Options.ProviderOptions)
150		if err == nil {
151			catwalkOpts = data
152		}
153	}
154
155	readers := []io.Reader{
156		bytes.NewReader(catwalkOpts),
157		bytes.NewReader(cfgOpts),
158	}
159
160	got, err := jsons.Merge(readers)
161	if err != nil {
162		slog.Error("Could not merge call config", "err", err)
163		return options
164	}
165
166	mergedOptions := make(map[string]any)
167
168	err = json.Unmarshal([]byte(got), &mergedOptions)
169	if err != nil {
170		slog.Error("Could not create config for call", "err", err)
171		return options
172	}
173
174	switch tp {
175	case openai.Name:
176		_, hasReasoningEffort := mergedOptions["reasoning_effort"]
177		if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
178			mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort
179		}
180		if openai.IsResponsesModel(model.CatwalkCfg.ID) {
181			if openai.IsResponsesReasoningModel(model.CatwalkCfg.ID) {
182				mergedOptions["reasoning_summary"] = "auto"
183				mergedOptions["include"] = []openai.IncludeType{openai.IncludeReasoningEncryptedContent}
184			}
185			parsed, err := openai.ParseResponsesOptions(mergedOptions)
186			if err == nil {
187				options[openai.Name] = parsed
188			}
189		} else {
190			parsed, err := openai.ParseOptions(mergedOptions)
191			if err == nil {
192				options[openai.Name] = parsed
193			}
194		}
195	case anthropic.Name:
196		_, hasThink := mergedOptions["thinking"]
197		if !hasThink && model.ModelCfg.Think {
198			mergedOptions["thinking"] = map[string]any{
199				// TODO: kujtim see if we need to make this dynamic
200				"budget_tokens": 2000,
201			}
202		}
203		parsed, err := anthropic.ParseOptions(mergedOptions)
204		if err == nil {
205			options[anthropic.Name] = parsed
206		}
207
208	case openrouter.Name:
209		_, hasReasoning := mergedOptions["reasoning"]
210		if !hasReasoning && model.ModelCfg.ReasoningEffort != "" {
211			mergedOptions["reasoning"] = map[string]any{
212				"enabled": true,
213				"effort":  model.ModelCfg.ReasoningEffort,
214			}
215		}
216		parsed, err := openrouter.ParseOptions(mergedOptions)
217		if err == nil {
218			options[openrouter.Name] = parsed
219		}
220	case google.Name:
221		_, hasReasoning := mergedOptions["thinking_config"]
222		if !hasReasoning {
223			mergedOptions["thinking_config"] = map[string]any{
224				"thinking_budget":  2000,
225				"include_thoughts": true,
226			}
227		}
228		parsed, err := google.ParseOptions(mergedOptions)
229		if err == nil {
230			options[google.Name] = parsed
231		}
232	case azure.Name:
233		_, hasReasoningEffort := mergedOptions["reasoning_effort"]
234		if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
235			mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort
236		}
237		// azure uses the same options as openaicompat
238		parsed, err := openaicompat.ParseOptions(mergedOptions)
239		if err == nil {
240			options[azure.Name] = parsed
241		}
242	case openaicompat.Name:
243		_, hasReasoningEffort := mergedOptions["reasoning_effort"]
244		if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
245			mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort
246		}
247		parsed, err := openaicompat.ParseOptions(mergedOptions)
248		if err == nil {
249			options[openaicompat.Name] = parsed
250		}
251	}
252
253	return options
254}
255
256func mergeCallOptions(model Model, tp catwalk.Type) (fantasy.ProviderOptions, *float64, *float64, *int64, *float64, *float64) {
257	modelOptions := getProviderOptions(model, tp)
258	temp := cmp.Or(model.ModelCfg.Temperature, model.CatwalkCfg.Options.Temperature)
259	topP := cmp.Or(model.ModelCfg.TopP, model.CatwalkCfg.Options.TopP)
260	topK := cmp.Or(model.ModelCfg.TopK, model.CatwalkCfg.Options.TopK)
261	freqPenalty := cmp.Or(model.ModelCfg.FrequencyPenalty, model.CatwalkCfg.Options.FrequencyPenalty)
262	presPenalty := cmp.Or(model.ModelCfg.PresencePenalty, model.CatwalkCfg.Options.PresencePenalty)
263	return modelOptions, temp, topP, topK, freqPenalty, presPenalty
264}
265
266func (c *coordinator) buildAgent(ctx context.Context, prompt *prompt.Prompt, agent config.Agent) (SessionAgent, error) {
267	large, small, err := c.buildAgentModels(ctx)
268	if err != nil {
269		return nil, err
270	}
271
272	systemPrompt, err := prompt.Build(large.Model.Provider(), large.Model.Model(), *c.cfg)
273	if err != nil {
274		return nil, err
275	}
276
277	tools, err := c.buildTools(ctx, agent)
278	if err != nil {
279		return nil, err
280	}
281	return NewSessionAgent(SessionAgentOptions{large, small, systemPrompt, c.cfg.Options.DisableAutoSummarize, c.sessions, c.messages, tools}), nil
282}
283
284func (c *coordinator) buildTools(ctx context.Context, agent config.Agent) ([]fantasy.AgentTool, error) {
285	var allTools []fantasy.AgentTool
286	if slices.Contains(agent.AllowedTools, AgentToolName) {
287		agentTool, err := c.agentTool(ctx)
288		if err != nil {
289			return nil, err
290		}
291		allTools = append(allTools, agentTool)
292	}
293
294	allTools = append(allTools,
295		tools.NewBashTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Options.Attribution),
296		tools.NewDownloadTool(c.permissions, c.cfg.WorkingDir(), nil),
297		tools.NewEditTool(c.lspClients, c.permissions, c.history, c.cfg.WorkingDir()),
298		tools.NewMultiEditTool(c.lspClients, c.permissions, c.history, c.cfg.WorkingDir()),
299		tools.NewFetchTool(c.permissions, c.cfg.WorkingDir(), nil),
300		tools.NewGlobTool(c.cfg.WorkingDir()),
301		tools.NewGrepTool(c.cfg.WorkingDir()),
302		tools.NewLsTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Tools.Ls),
303		tools.NewSourcegraphTool(nil),
304		tools.NewViewTool(c.lspClients, c.permissions, c.cfg.WorkingDir()),
305		tools.NewWriteTool(c.lspClients, c.permissions, c.history, c.cfg.WorkingDir()),
306	)
307
308	var filteredTools []fantasy.AgentTool
309	for _, tool := range allTools {
310		if slices.Contains(agent.AllowedTools, tool.Info().Name) {
311			filteredTools = append(filteredTools, tool)
312		}
313	}
314
315	mcpTools := tools.GetMCPTools(context.Background(), c.permissions, c.cfg)
316
317	for _, mcpTool := range mcpTools {
318		if agent.AllowedMCP == nil {
319			// No MCP restrictions
320			filteredTools = append(filteredTools, mcpTool)
321		} else if len(agent.AllowedMCP) == 0 {
322			// no mcps allowed
323			break
324		}
325
326		for mcp, tools := range agent.AllowedMCP {
327			if mcp == mcpTool.MCP() {
328				if len(tools) == 0 {
329					filteredTools = append(filteredTools, mcpTool)
330				}
331				for _, t := range tools {
332					if t == mcpTool.MCPToolName() {
333						filteredTools = append(filteredTools, mcpTool)
334					}
335				}
336				break
337			}
338		}
339	}
340
341	return filteredTools, nil
342}
343
344// TODO: when we support multiple agents we need to change this so that we pass in the agent specific model config
345func (c *coordinator) buildAgentModels(ctx context.Context) (Model, Model, error) {
346	largeModelCfg, ok := c.cfg.Models[config.SelectedModelTypeLarge]
347	if !ok {
348		return Model{}, Model{}, errors.New("large model not selected")
349	}
350	smallModelCfg, ok := c.cfg.Models[config.SelectedModelTypeSmall]
351	if !ok {
352		return Model{}, Model{}, errors.New("small model not selected")
353	}
354
355	largeProviderCfg, ok := c.cfg.Providers.Get(largeModelCfg.Provider)
356	if !ok {
357		return Model{}, Model{}, errors.New("large model provider not configured")
358	}
359
360	largeProvider, err := c.buildProvider(largeProviderCfg, largeModelCfg)
361	if err != nil {
362		return Model{}, Model{}, err
363	}
364
365	smallProviderCfg, ok := c.cfg.Providers.Get(smallModelCfg.Provider)
366	if !ok {
367		return Model{}, Model{}, errors.New("large model provider not configured")
368	}
369
370	smallProvider, err := c.buildProvider(smallProviderCfg, largeModelCfg)
371	if err != nil {
372		return Model{}, Model{}, err
373	}
374
375	var largeCatwalkModel *catwalk.Model
376	var smallCatwalkModel *catwalk.Model
377
378	for _, m := range largeProviderCfg.Models {
379		if m.ID == largeModelCfg.Model {
380			largeCatwalkModel = &m
381		}
382	}
383	for _, m := range smallProviderCfg.Models {
384		if m.ID == smallModelCfg.Model {
385			smallCatwalkModel = &m
386		}
387	}
388
389	if largeCatwalkModel == nil {
390		return Model{}, Model{}, errors.New("large model not found in provider config")
391	}
392
393	if smallCatwalkModel == nil {
394		return Model{}, Model{}, errors.New("snall model not found in provider config")
395	}
396
397	largeModel, err := largeProvider.LanguageModel(ctx, largeModelCfg.Model)
398	if err != nil {
399		return Model{}, Model{}, err
400	}
401	smallModel, err := smallProvider.LanguageModel(ctx, smallModelCfg.Model)
402	if err != nil {
403		return Model{}, Model{}, err
404	}
405
406	return Model{
407			Model:      largeModel,
408			CatwalkCfg: *largeCatwalkModel,
409			ModelCfg:   largeModelCfg,
410		}, Model{
411			Model:      smallModel,
412			CatwalkCfg: *smallCatwalkModel,
413			ModelCfg:   smallModelCfg,
414		}, nil
415}
416
417func (c *coordinator) buildAnthropicProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
418	hasBearerAuth := false
419	for key := range headers {
420		if strings.ToLower(key) == "authorization" {
421			hasBearerAuth = true
422			break
423		}
424	}
425	if hasBearerAuth {
426		apiKey = "" // clear apiKey to avoid using X-Api-Key header
427	}
428
429	var opts []anthropic.Option
430
431	if apiKey != "" {
432		// Use standard X-Api-Key header
433		opts = append(opts, anthropic.WithAPIKey(apiKey))
434	}
435
436	if len(headers) > 0 {
437		opts = append(opts, anthropic.WithHeaders(headers))
438	}
439
440	if baseURL != "" {
441		opts = append(opts, anthropic.WithBaseURL(baseURL))
442	}
443
444	if c.cfg.Options.Debug {
445		httpClient := log.NewHTTPClient()
446		opts = append(opts, anthropic.WithHTTPClient(httpClient))
447	}
448
449	return anthropic.New(opts...)
450}
451
452func (c *coordinator) buildOpenaiProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
453	opts := []openai.Option{
454		openai.WithAPIKey(apiKey),
455		openai.WithUseResponsesAPI(),
456	}
457	if c.cfg.Options.Debug {
458		httpClient := log.NewHTTPClient()
459		opts = append(opts, openai.WithHTTPClient(httpClient))
460	}
461	if len(headers) > 0 {
462		opts = append(opts, openai.WithHeaders(headers))
463	}
464	if baseURL != "" {
465		opts = append(opts, openai.WithBaseURL(baseURL))
466	}
467	return openai.New(opts...)
468}
469
470func (c *coordinator) buildOpenrouterProvider(_, apiKey string, headers map[string]string) (fantasy.Provider, error) {
471	opts := []openrouter.Option{
472		openrouter.WithAPIKey(apiKey),
473	}
474	if c.cfg.Options.Debug {
475		httpClient := log.NewHTTPClient()
476		opts = append(opts, openrouter.WithHTTPClient(httpClient))
477	}
478	if len(headers) > 0 {
479		opts = append(opts, openrouter.WithHeaders(headers))
480	}
481	return openrouter.New(opts...)
482}
483
484func (c *coordinator) buildOpenaiCompatProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
485	opts := []openaicompat.Option{
486		openaicompat.WithBaseURL(baseURL),
487		openaicompat.WithAPIKey(apiKey),
488	}
489	if c.cfg.Options.Debug {
490		httpClient := log.NewHTTPClient()
491		opts = append(opts, openaicompat.WithHTTPClient(httpClient))
492	}
493	if len(headers) > 0 {
494		opts = append(opts, openaicompat.WithHeaders(headers))
495	}
496
497	return openaicompat.New(opts...)
498}
499
500func (c *coordinator) buildAzureProvider(baseURL, apiKey string, headers map[string]string, options map[string]string) (fantasy.Provider, error) {
501	opts := []azure.Option{
502		azure.WithBaseURL(baseURL),
503		azure.WithAPIKey(apiKey),
504	}
505	if c.cfg.Options.Debug {
506		httpClient := log.NewHTTPClient()
507		opts = append(opts, azure.WithHTTPClient(httpClient))
508	}
509	if options == nil {
510		options = make(map[string]string)
511	}
512	if apiVersion, ok := options["apiVersion"]; ok {
513		opts = append(opts, azure.WithAPIVersion(apiVersion))
514	}
515	if len(headers) > 0 {
516		opts = append(opts, azure.WithHeaders(headers))
517	}
518
519	return azure.New(opts...)
520}
521
522func (c *coordinator) buildGoogleProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
523	opts := []google.Option{
524		google.WithBaseURL(baseURL),
525		google.WithGeminiAPIKey(apiKey),
526	}
527	if c.cfg.Options.Debug {
528		httpClient := log.NewHTTPClient()
529		opts = append(opts, google.WithHTTPClient(httpClient))
530	}
531	if len(headers) > 0 {
532		opts = append(opts, google.WithHeaders(headers))
533	}
534	return google.New(opts...)
535}
536
537func (c *coordinator) buildGoogleVertexProvider(headers map[string]string, options map[string]string) (fantasy.Provider, error) {
538	opts := []google.Option{}
539	if c.cfg.Options.Debug {
540		httpClient := log.NewHTTPClient()
541		opts = append(opts, google.WithHTTPClient(httpClient))
542	}
543	if len(headers) > 0 {
544		opts = append(opts, google.WithHeaders(headers))
545	}
546
547	project := options["project"]
548	location := options["location"]
549
550	opts = append(opts, google.WithVertex(project, location))
551
552	return google.New(opts...)
553}
554
555func (c *coordinator) isAnthropicThinking(model config.SelectedModel) bool {
556	if model.Think {
557		return true
558	}
559
560	if model.ProviderOptions == nil {
561		return false
562	}
563
564	opts, err := anthropic.ParseOptions(model.ProviderOptions)
565	if err != nil {
566		return false
567	}
568	if opts.Thinking != nil {
569		return true
570	}
571	return false
572}
573
574func (c *coordinator) buildProvider(providerCfg config.ProviderConfig, model config.SelectedModel) (fantasy.Provider, error) {
575	headers := providerCfg.ExtraHeaders
576
577	// handle special headers for anthropic
578	if providerCfg.Type == anthropic.Name && c.isAnthropicThinking(model) {
579		headers["anthropic-beta"] = "interleaved-thinking-2025-05-14"
580	}
581
582	// TODO: make sure we have
583	apiKey, _ := c.cfg.Resolve(providerCfg.APIKey)
584	baseURL, _ := c.cfg.Resolve(providerCfg.BaseURL)
585
586	switch providerCfg.Type {
587	case openai.Name:
588		return c.buildOpenaiProvider(baseURL, apiKey, headers)
589	case anthropic.Name:
590		return c.buildAnthropicProvider(baseURL, apiKey, headers)
591	case openrouter.Name:
592		return c.buildOpenrouterProvider(baseURL, apiKey, headers)
593	case azure.Name:
594		return c.buildAzureProvider(baseURL, apiKey, headers, providerCfg.ExtraParams)
595	case google.Name:
596		return c.buildGoogleProvider(baseURL, apiKey, headers)
597	case "vertexai":
598		return c.buildGoogleVertexProvider(headers, providerCfg.ExtraParams)
599	case openaicompat.Name:
600		return c.buildOpenaiCompatProvider(baseURL, apiKey, headers)
601	default:
602		return nil, errors.New("provider type not supported")
603	}
604}
605
606func (c *coordinator) Cancel(sessionID string) {
607	c.currentAgent.Cancel(sessionID)
608}
609
610func (c *coordinator) CancelAll() {
611	c.currentAgent.CancelAll()
612}
613
614func (c *coordinator) ClearQueue(sessionID string) {
615	c.currentAgent.ClearQueue(sessionID)
616}
617
618func (c *coordinator) IsBusy() bool {
619	return c.currentAgent.IsBusy()
620}
621
622func (c *coordinator) IsSessionBusy(sessionID string) bool {
623	return c.currentAgent.IsSessionBusy(sessionID)
624}
625
626func (c *coordinator) Model() Model {
627	return c.currentAgent.Model()
628}
629
630func (c *coordinator) UpdateModels(ctx context.Context) error {
631	// build the models again so we make sure we get the latest config
632	large, small, err := c.buildAgentModels(ctx)
633	if err != nil {
634		return err
635	}
636	c.currentAgent.SetModels(large, small)
637
638	agentCfg, ok := c.cfg.Agents[config.AgentCoder]
639	if !ok {
640		return errors.New("coder agent not configured")
641	}
642
643	tools, err := c.buildTools(ctx, agentCfg)
644	if err != nil {
645		return err
646	}
647	c.currentAgent.SetTools(tools)
648	return nil
649}
650
651func (c *coordinator) QueuedPrompts(sessionID string) int {
652	return c.currentAgent.QueuedPrompts(sessionID)
653}
654
655func (c *coordinator) Summarize(ctx context.Context, sessionID string) error {
656	providerCfg, ok := c.cfg.Providers.Get(c.currentAgent.Model().ModelCfg.Provider)
657	if !ok {
658		return errors.New("model provider not configured")
659	}
660	return c.currentAgent.Summarize(ctx, sessionID, getProviderOptions(c.currentAgent.Model(), providerCfg.Type))
661}