coordinator.go

  1package agent
  2
  3import (
  4	"bytes"
  5	"cmp"
  6	"context"
  7	"encoding/json"
  8	"errors"
  9	"io"
 10	"log/slog"
 11	"slices"
 12	"strings"
 13
 14	"charm.land/fantasy"
 15	"github.com/charmbracelet/catwalk/pkg/catwalk"
 16	"github.com/charmbracelet/crush/internal/agent/prompt"
 17	"github.com/charmbracelet/crush/internal/agent/tools"
 18	"github.com/charmbracelet/crush/internal/config"
 19	"github.com/charmbracelet/crush/internal/csync"
 20	"github.com/charmbracelet/crush/internal/history"
 21	"github.com/charmbracelet/crush/internal/log"
 22	"github.com/charmbracelet/crush/internal/lsp"
 23	"github.com/charmbracelet/crush/internal/message"
 24	"github.com/charmbracelet/crush/internal/permission"
 25	"github.com/charmbracelet/crush/internal/session"
 26
 27	"charm.land/fantasy/providers/anthropic"
 28	"charm.land/fantasy/providers/azure"
 29	"charm.land/fantasy/providers/google"
 30	"charm.land/fantasy/providers/openai"
 31	"charm.land/fantasy/providers/openaicompat"
 32	"charm.land/fantasy/providers/openrouter"
 33	"github.com/qjebbs/go-jsons"
 34)
 35
 36type Coordinator interface {
 37	// INFO: (kujtim) this is not used yet we will use this when we have multiple agents
 38	// SetMainAgent(string)
 39	Run(ctx context.Context, sessionID, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error)
 40	Cancel(sessionID string)
 41	CancelAll()
 42	IsSessionBusy(sessionID string) bool
 43	IsBusy() bool
 44	QueuedPrompts(sessionID string) int
 45	ClearQueue(sessionID string)
 46	Summarize(context.Context, string) error
 47	Model() Model
 48	UpdateModels(ctx context.Context) error
 49}
 50
 51type coordinator struct {
 52	cfg         *config.Config
 53	sessions    session.Service
 54	messages    message.Service
 55	permissions permission.Service
 56	history     history.Service
 57	lspClients  *csync.Map[string, *lsp.Client]
 58
 59	currentAgent SessionAgent
 60	agents       map[string]SessionAgent
 61}
 62
 63func NewCoordinator(
 64	ctx context.Context,
 65	cfg *config.Config,
 66	sessions session.Service,
 67	messages message.Service,
 68	permissions permission.Service,
 69	history history.Service,
 70	lspClients *csync.Map[string, *lsp.Client],
 71) (Coordinator, error) {
 72	c := &coordinator{
 73		cfg:         cfg,
 74		sessions:    sessions,
 75		messages:    messages,
 76		permissions: permissions,
 77		history:     history,
 78		lspClients:  lspClients,
 79		agents:      make(map[string]SessionAgent),
 80	}
 81
 82	agentCfg, ok := cfg.Agents[config.AgentCoder]
 83	if !ok {
 84		return nil, errors.New("coder agent not configured")
 85	}
 86
 87	// TODO: make this dynamic when we support multiple agents
 88	prompt, err := coderPrompt(prompt.WithWorkingDir(c.cfg.WorkingDir()))
 89	if err != nil {
 90		return nil, err
 91	}
 92
 93	agent, err := c.buildAgent(ctx, prompt, agentCfg)
 94	if err != nil {
 95		return nil, err
 96	}
 97	c.currentAgent = agent
 98	c.agents[config.AgentCoder] = agent
 99	return c, nil
100}
101
102// Run implements Coordinator.
103func (c *coordinator) Run(ctx context.Context, sessionID string, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) {
104	model := c.currentAgent.Model()
105	maxTokens := model.CatwalkCfg.DefaultMaxTokens
106	if model.ModelCfg.MaxTokens != 0 {
107		maxTokens = model.ModelCfg.MaxTokens
108	}
109
110	if !model.CatwalkCfg.SupportsImages && attachments != nil {
111		attachments = nil
112	}
113
114	providerCfg, ok := c.cfg.Providers.Get(model.ModelCfg.Provider)
115	if !ok {
116		return nil, errors.New("model provider not configured")
117	}
118
119	mergedOptions, temp, topP, topK, freqPenalty, presPenalty := mergeCallOptions(model, providerCfg.Type)
120
121	return c.currentAgent.Run(ctx, SessionAgentCall{
122		SessionID:        sessionID,
123		Prompt:           prompt,
124		Attachments:      attachments,
125		MaxOutputTokens:  maxTokens,
126		ProviderOptions:  mergedOptions,
127		Temperature:      temp,
128		TopP:             topP,
129		TopK:             topK,
130		FrequencyPenalty: freqPenalty,
131		PresencePenalty:  presPenalty,
132	})
133}
134
135func getProviderOptions(model Model, tp catwalk.Type) fantasy.ProviderOptions {
136	options := fantasy.ProviderOptions{}
137
138	cfgOpts := []byte("{}")
139	catwalkOpts := []byte("{}")
140
141	if model.ModelCfg.ProviderOptions != nil {
142		data, err := json.Marshal(model.ModelCfg.ProviderOptions)
143		if err == nil {
144			cfgOpts = data
145		}
146	}
147
148	if model.CatwalkCfg.Options.ProviderOptions != nil {
149		data, err := json.Marshal(model.CatwalkCfg.Options.ProviderOptions)
150		if err == nil {
151			catwalkOpts = data
152		}
153	}
154
155	readers := []io.Reader{
156		bytes.NewReader(catwalkOpts),
157		bytes.NewReader(cfgOpts),
158	}
159
160	got, err := jsons.Merge(readers)
161	if err != nil {
162		slog.Error("Could not merge call config", "err", err)
163		return options
164	}
165
166	mergedOptions := make(map[string]any)
167
168	err = json.Unmarshal([]byte(got), &mergedOptions)
169	if err != nil {
170		slog.Error("Could not create config for call", "err", err)
171		return options
172	}
173
174	switch tp {
175	case openai.Name:
176		_, hasReasoningEffort := mergedOptions["reasoning_effort"]
177		if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
178			mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort
179		}
180		if openai.IsResponsesModel(model.CatwalkCfg.ID) {
181			if openai.IsResponsesReasoningModel(model.CatwalkCfg.ID) {
182				mergedOptions["reasoning_summary"] = "auto"
183				mergedOptions["include"] = []openai.IncludeType{openai.IncludeReasoningEncryptedContent}
184			}
185			parsed, err := openai.ParseResponsesOptions(mergedOptions)
186			if err == nil {
187				options[openai.Name] = parsed
188			}
189		} else {
190			parsed, err := openai.ParseOptions(mergedOptions)
191			if err == nil {
192				options[openai.Name] = parsed
193			}
194		}
195	case anthropic.Name:
196		_, hasThink := mergedOptions["thinking"]
197		if !hasThink && model.ModelCfg.Think {
198			mergedOptions["thinking"] = map[string]any{
199				// TODO: kujtim see if we need to make this dynamic
200				"budget_tokens": 2000,
201			}
202		}
203		parsed, err := anthropic.ParseOptions(mergedOptions)
204		if err == nil {
205			options[anthropic.Name] = parsed
206		}
207
208	case openrouter.Name:
209		_, hasReasoning := mergedOptions["reasoning"]
210		if !hasReasoning && model.ModelCfg.ReasoningEffort != "" {
211			mergedOptions["reasoning"] = map[string]any{
212				"enabled": true,
213				"effort":  model.ModelCfg.ReasoningEffort,
214			}
215		}
216		parsed, err := openrouter.ParseOptions(mergedOptions)
217		if err == nil {
218			options[openrouter.Name] = parsed
219		}
220	case google.Name:
221		_, hasReasoning := mergedOptions["thinking_config"]
222		if !hasReasoning {
223			mergedOptions["thinking_config"] = map[string]any{
224				"thinking_budget":  2000,
225				"include_thoughts": true,
226			}
227		}
228		parsed, err := google.ParseOptions(mergedOptions)
229		if err == nil {
230			options[google.Name] = parsed
231		}
232	case azure.Name:
233		_, hasReasoningEffort := mergedOptions["reasoning_effort"]
234		if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
235			mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort
236		}
237		// azure uses the same options as openaicompat
238		parsed, err := openaicompat.ParseOptions(mergedOptions)
239		if err == nil {
240			options[azure.Name] = parsed
241		}
242	case openaicompat.Name:
243		_, hasReasoningEffort := mergedOptions["reasoning_effort"]
244		if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
245			mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort
246		}
247		parsed, err := openaicompat.ParseOptions(mergedOptions)
248		if err == nil {
249			options[openaicompat.Name] = parsed
250		}
251	}
252
253	return options
254}
255
256func mergeCallOptions(model Model, tp catwalk.Type) (fantasy.ProviderOptions, *float64, *float64, *int64, *float64, *float64) {
257	modelOptions := getProviderOptions(model, tp)
258	temp := cmp.Or(model.ModelCfg.Temperature, model.CatwalkCfg.Options.Temperature)
259	topP := cmp.Or(model.ModelCfg.TopP, model.CatwalkCfg.Options.TopP)
260	topK := cmp.Or(model.ModelCfg.TopK, model.CatwalkCfg.Options.TopK)
261	freqPenalty := cmp.Or(model.ModelCfg.FrequencyPenalty, model.CatwalkCfg.Options.FrequencyPenalty)
262	presPenalty := cmp.Or(model.ModelCfg.PresencePenalty, model.CatwalkCfg.Options.PresencePenalty)
263	return modelOptions, temp, topP, topK, freqPenalty, presPenalty
264}
265
266func (c *coordinator) buildAgent(ctx context.Context, prompt *prompt.Prompt, agent config.Agent) (SessionAgent, error) {
267	large, small, err := c.buildAgentModels(ctx)
268	if err != nil {
269		return nil, err
270	}
271
272	systemPrompt, err := prompt.Build(large.Model.Provider(), large.Model.Model(), *c.cfg)
273	if err != nil {
274		return nil, err
275	}
276
277	tools, err := c.buildTools(ctx, agent)
278	if err != nil {
279		return nil, err
280	}
281	return NewSessionAgent(SessionAgentOptions{large, small, systemPrompt, c.cfg.Options.DisableAutoSummarize, c.sessions, c.messages, tools}), nil
282}
283
284func (c *coordinator) buildTools(ctx context.Context, agent config.Agent) ([]fantasy.AgentTool, error) {
285	var allTools []fantasy.AgentTool
286	if slices.Contains(agent.AllowedTools, AgentToolName) {
287		agentTool, err := c.agentTool(ctx)
288		if err != nil {
289			return nil, err
290		}
291		allTools = append(allTools, agentTool)
292	}
293
294	allTools = append(allTools,
295		tools.NewBashTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Options.Attribution),
296		tools.NewDownloadTool(c.permissions, c.cfg.WorkingDir(), nil),
297		tools.NewEditTool(c.lspClients, c.permissions, c.history, c.cfg.WorkingDir()),
298		tools.NewMultiEditTool(c.lspClients, c.permissions, c.history, c.cfg.WorkingDir()),
299		tools.NewFetchTool(c.permissions, c.cfg.WorkingDir(), nil),
300		tools.NewGlobTool(c.cfg.WorkingDir()),
301		tools.NewGrepTool(c.cfg.WorkingDir()),
302		tools.NewLsTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Tools.Ls),
303		tools.NewSourcegraphTool(nil),
304		tools.NewViewTool(c.lspClients, c.permissions, c.cfg.WorkingDir()),
305		tools.NewWriteTool(c.lspClients, c.permissions, c.history, c.cfg.WorkingDir()),
306	)
307
308	if len(c.cfg.LSP) > 0 {
309		allTools = append(allTools, tools.NewDiagnosticsTool(c.lspClients), tools.NewReferencesTool(c.lspClients))
310	}
311
312	var filteredTools []fantasy.AgentTool
313	for _, tool := range allTools {
314		if slices.Contains(agent.AllowedTools, tool.Info().Name) {
315			filteredTools = append(filteredTools, tool)
316		}
317	}
318
319	mcpTools := tools.GetMCPTools(context.Background(), c.permissions, c.cfg)
320
321	for _, mcpTool := range mcpTools {
322		if agent.AllowedMCP == nil {
323			// No MCP restrictions
324			filteredTools = append(filteredTools, mcpTool)
325		} else if len(agent.AllowedMCP) == 0 {
326			// no mcps allowed
327			break
328		}
329
330		for mcp, tools := range agent.AllowedMCP {
331			if mcp == mcpTool.MCP() {
332				if len(tools) == 0 {
333					filteredTools = append(filteredTools, mcpTool)
334				}
335				for _, t := range tools {
336					if t == mcpTool.MCPToolName() {
337						filteredTools = append(filteredTools, mcpTool)
338					}
339				}
340				break
341			}
342		}
343	}
344	slices.SortFunc(filteredTools, func(a, b fantasy.AgentTool) int {
345		return strings.Compare(a.Info().Name, b.Info().Name)
346	})
347	return filteredTools, nil
348}
349
350// TODO: when we support multiple agents we need to change this so that we pass in the agent specific model config
351func (c *coordinator) buildAgentModels(ctx context.Context) (Model, Model, error) {
352	largeModelCfg, ok := c.cfg.Models[config.SelectedModelTypeLarge]
353	if !ok {
354		return Model{}, Model{}, errors.New("large model not selected")
355	}
356	smallModelCfg, ok := c.cfg.Models[config.SelectedModelTypeSmall]
357	if !ok {
358		return Model{}, Model{}, errors.New("small model not selected")
359	}
360
361	largeProviderCfg, ok := c.cfg.Providers.Get(largeModelCfg.Provider)
362	if !ok {
363		return Model{}, Model{}, errors.New("large model provider not configured")
364	}
365
366	largeProvider, err := c.buildProvider(largeProviderCfg, largeModelCfg)
367	if err != nil {
368		return Model{}, Model{}, err
369	}
370
371	smallProviderCfg, ok := c.cfg.Providers.Get(smallModelCfg.Provider)
372	if !ok {
373		return Model{}, Model{}, errors.New("large model provider not configured")
374	}
375
376	smallProvider, err := c.buildProvider(smallProviderCfg, largeModelCfg)
377	if err != nil {
378		return Model{}, Model{}, err
379	}
380
381	var largeCatwalkModel *catwalk.Model
382	var smallCatwalkModel *catwalk.Model
383
384	for _, m := range largeProviderCfg.Models {
385		if m.ID == largeModelCfg.Model {
386			largeCatwalkModel = &m
387		}
388	}
389	for _, m := range smallProviderCfg.Models {
390		if m.ID == smallModelCfg.Model {
391			smallCatwalkModel = &m
392		}
393	}
394
395	if largeCatwalkModel == nil {
396		return Model{}, Model{}, errors.New("large model not found in provider config")
397	}
398
399	if smallCatwalkModel == nil {
400		return Model{}, Model{}, errors.New("snall model not found in provider config")
401	}
402
403	largeModel, err := largeProvider.LanguageModel(ctx, largeModelCfg.Model)
404	if err != nil {
405		return Model{}, Model{}, err
406	}
407	smallModel, err := smallProvider.LanguageModel(ctx, smallModelCfg.Model)
408	if err != nil {
409		return Model{}, Model{}, err
410	}
411
412	return Model{
413			Model:      largeModel,
414			CatwalkCfg: *largeCatwalkModel,
415			ModelCfg:   largeModelCfg,
416		}, Model{
417			Model:      smallModel,
418			CatwalkCfg: *smallCatwalkModel,
419			ModelCfg:   smallModelCfg,
420		}, nil
421}
422
423func (c *coordinator) buildAnthropicProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
424	hasBearerAuth := false
425	for key := range headers {
426		if strings.ToLower(key) == "authorization" {
427			hasBearerAuth = true
428			break
429		}
430	}
431	if hasBearerAuth {
432		apiKey = "" // clear apiKey to avoid using X-Api-Key header
433	}
434
435	var opts []anthropic.Option
436
437	if apiKey != "" {
438		// Use standard X-Api-Key header
439		opts = append(opts, anthropic.WithAPIKey(apiKey))
440	}
441
442	if len(headers) > 0 {
443		opts = append(opts, anthropic.WithHeaders(headers))
444	}
445
446	if baseURL != "" {
447		opts = append(opts, anthropic.WithBaseURL(baseURL))
448	}
449
450	if c.cfg.Options.Debug {
451		httpClient := log.NewHTTPClient()
452		opts = append(opts, anthropic.WithHTTPClient(httpClient))
453	}
454
455	return anthropic.New(opts...)
456}
457
458func (c *coordinator) buildOpenaiProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
459	opts := []openai.Option{
460		openai.WithAPIKey(apiKey),
461		openai.WithUseResponsesAPI(),
462	}
463	if c.cfg.Options.Debug {
464		httpClient := log.NewHTTPClient()
465		opts = append(opts, openai.WithHTTPClient(httpClient))
466	}
467	if len(headers) > 0 {
468		opts = append(opts, openai.WithHeaders(headers))
469	}
470	if baseURL != "" {
471		opts = append(opts, openai.WithBaseURL(baseURL))
472	}
473	return openai.New(opts...)
474}
475
476func (c *coordinator) buildOpenrouterProvider(_, apiKey string, headers map[string]string) (fantasy.Provider, error) {
477	opts := []openrouter.Option{
478		openrouter.WithAPIKey(apiKey),
479	}
480	if c.cfg.Options.Debug {
481		httpClient := log.NewHTTPClient()
482		opts = append(opts, openrouter.WithHTTPClient(httpClient))
483	}
484	if len(headers) > 0 {
485		opts = append(opts, openrouter.WithHeaders(headers))
486	}
487	return openrouter.New(opts...)
488}
489
490func (c *coordinator) buildOpenaiCompatProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
491	opts := []openaicompat.Option{
492		openaicompat.WithBaseURL(baseURL),
493		openaicompat.WithAPIKey(apiKey),
494	}
495	if c.cfg.Options.Debug {
496		httpClient := log.NewHTTPClient()
497		opts = append(opts, openaicompat.WithHTTPClient(httpClient))
498	}
499	if len(headers) > 0 {
500		opts = append(opts, openaicompat.WithHeaders(headers))
501	}
502
503	return openaicompat.New(opts...)
504}
505
506func (c *coordinator) buildAzureProvider(baseURL, apiKey string, headers map[string]string, options map[string]string) (fantasy.Provider, error) {
507	opts := []azure.Option{
508		azure.WithBaseURL(baseURL),
509		azure.WithAPIKey(apiKey),
510	}
511	if c.cfg.Options.Debug {
512		httpClient := log.NewHTTPClient()
513		opts = append(opts, azure.WithHTTPClient(httpClient))
514	}
515	if options == nil {
516		options = make(map[string]string)
517	}
518	if apiVersion, ok := options["apiVersion"]; ok {
519		opts = append(opts, azure.WithAPIVersion(apiVersion))
520	}
521	if len(headers) > 0 {
522		opts = append(opts, azure.WithHeaders(headers))
523	}
524
525	return azure.New(opts...)
526}
527
528func (c *coordinator) buildGoogleProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
529	opts := []google.Option{
530		google.WithBaseURL(baseURL),
531		google.WithGeminiAPIKey(apiKey),
532	}
533	if c.cfg.Options.Debug {
534		httpClient := log.NewHTTPClient()
535		opts = append(opts, google.WithHTTPClient(httpClient))
536	}
537	if len(headers) > 0 {
538		opts = append(opts, google.WithHeaders(headers))
539	}
540	return google.New(opts...)
541}
542
543func (c *coordinator) buildGoogleVertexProvider(headers map[string]string, options map[string]string) (fantasy.Provider, error) {
544	opts := []google.Option{}
545	if c.cfg.Options.Debug {
546		httpClient := log.NewHTTPClient()
547		opts = append(opts, google.WithHTTPClient(httpClient))
548	}
549	if len(headers) > 0 {
550		opts = append(opts, google.WithHeaders(headers))
551	}
552
553	project := options["project"]
554	location := options["location"]
555
556	opts = append(opts, google.WithVertex(project, location))
557
558	return google.New(opts...)
559}
560
561func (c *coordinator) isAnthropicThinking(model config.SelectedModel) bool {
562	if model.Think {
563		return true
564	}
565
566	if model.ProviderOptions == nil {
567		return false
568	}
569
570	opts, err := anthropic.ParseOptions(model.ProviderOptions)
571	if err != nil {
572		return false
573	}
574	if opts.Thinking != nil {
575		return true
576	}
577	return false
578}
579
580func (c *coordinator) buildProvider(providerCfg config.ProviderConfig, model config.SelectedModel) (fantasy.Provider, error) {
581	headers := providerCfg.ExtraHeaders
582
583	// handle special headers for anthropic
584	if providerCfg.Type == anthropic.Name && c.isAnthropicThinking(model) {
585		headers["anthropic-beta"] = "interleaved-thinking-2025-05-14"
586	}
587
588	// TODO: make sure we have
589	apiKey, _ := c.cfg.Resolve(providerCfg.APIKey)
590	baseURL, _ := c.cfg.Resolve(providerCfg.BaseURL)
591
592	switch providerCfg.Type {
593	case openai.Name:
594		return c.buildOpenaiProvider(baseURL, apiKey, headers)
595	case anthropic.Name:
596		return c.buildAnthropicProvider(baseURL, apiKey, headers)
597	case openrouter.Name:
598		return c.buildOpenrouterProvider(baseURL, apiKey, headers)
599	case azure.Name:
600		return c.buildAzureProvider(baseURL, apiKey, headers, providerCfg.ExtraParams)
601	case google.Name:
602		return c.buildGoogleProvider(baseURL, apiKey, headers)
603	case "vertexai":
604		return c.buildGoogleVertexProvider(headers, providerCfg.ExtraParams)
605	case openaicompat.Name:
606		return c.buildOpenaiCompatProvider(baseURL, apiKey, headers)
607	default:
608		return nil, errors.New("provider type not supported")
609	}
610}
611
612func (c *coordinator) Cancel(sessionID string) {
613	c.currentAgent.Cancel(sessionID)
614}
615
616func (c *coordinator) CancelAll() {
617	c.currentAgent.CancelAll()
618}
619
620func (c *coordinator) ClearQueue(sessionID string) {
621	c.currentAgent.ClearQueue(sessionID)
622}
623
624func (c *coordinator) IsBusy() bool {
625	return c.currentAgent.IsBusy()
626}
627
628func (c *coordinator) IsSessionBusy(sessionID string) bool {
629	return c.currentAgent.IsSessionBusy(sessionID)
630}
631
632func (c *coordinator) Model() Model {
633	return c.currentAgent.Model()
634}
635
636func (c *coordinator) UpdateModels(ctx context.Context) error {
637	// build the models again so we make sure we get the latest config
638	large, small, err := c.buildAgentModels(ctx)
639	if err != nil {
640		return err
641	}
642	c.currentAgent.SetModels(large, small)
643
644	agentCfg, ok := c.cfg.Agents[config.AgentCoder]
645	if !ok {
646		return errors.New("coder agent not configured")
647	}
648
649	tools, err := c.buildTools(ctx, agentCfg)
650	if err != nil {
651		return err
652	}
653	c.currentAgent.SetTools(tools)
654	return nil
655}
656
657func (c *coordinator) QueuedPrompts(sessionID string) int {
658	return c.currentAgent.QueuedPrompts(sessionID)
659}
660
661func (c *coordinator) Summarize(ctx context.Context, sessionID string) error {
662	providerCfg, ok := c.cfg.Providers.Get(c.currentAgent.Model().ModelCfg.Provider)
663	if !ok {
664		return errors.New("model provider not configured")
665	}
666	return c.currentAgent.Summarize(ctx, sessionID, getProviderOptions(c.currentAgent.Model(), providerCfg.Type))
667}