coordinator.go

  1package agent
  2
  3import (
  4	"bytes"
  5	"cmp"
  6	"context"
  7	"encoding/json"
  8	"errors"
  9	"io"
 10	"log/slog"
 11	"slices"
 12	"strings"
 13
 14	"charm.land/fantasy"
 15	"github.com/charmbracelet/catwalk/pkg/catwalk"
 16	"github.com/charmbracelet/crush/internal/agent/prompt"
 17	"github.com/charmbracelet/crush/internal/agent/tools"
 18	"github.com/charmbracelet/crush/internal/config"
 19	"github.com/charmbracelet/crush/internal/csync"
 20	"github.com/charmbracelet/crush/internal/history"
 21	"github.com/charmbracelet/crush/internal/log"
 22	"github.com/charmbracelet/crush/internal/lsp"
 23	"github.com/charmbracelet/crush/internal/message"
 24	"github.com/charmbracelet/crush/internal/permission"
 25	"github.com/charmbracelet/crush/internal/session"
 26
 27	"charm.land/fantasy/providers/anthropic"
 28	"charm.land/fantasy/providers/azure"
 29	"charm.land/fantasy/providers/google"
 30	"charm.land/fantasy/providers/openai"
 31	"charm.land/fantasy/providers/openaicompat"
 32	"charm.land/fantasy/providers/openrouter"
 33	"github.com/qjebbs/go-jsons"
 34)
 35
// Coordinator is the interface through which callers drive the agent: it
// accepts prompts, exposes cancellation and queue state per session, and lets
// the model/tool setup be refreshed.
type Coordinator interface {
	// INFO: (kujtim) this is not used yet; we will use this when we have multiple agents.
	// SetMainAgent(string)
	// Run submits prompt, together with any attachments, for the session
	// identified by sessionID and returns the agent's result.
	Run(ctx context.Context, sessionID, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error)
	// Cancel requests cancellation for the given session.
	Cancel(sessionID string)
	// CancelAll requests cancellation for every session.
	CancelAll()
	// IsSessionBusy reports whether the given session is busy.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether the agent is busy.
	IsBusy() bool
	// QueuedPrompts returns the number of prompts queued for the given session.
	QueuedPrompts(sessionID string) int
	// ClearQueue clears the prompt queue of the given session.
	ClearQueue(sessionID string)
	// Summarize summarizes the session whose ID is passed as the second argument.
	Summarize(context.Context, string) error
	// Model returns the model of the current agent.
	Model() Model
	// UpdateModels rebuilds the models and tools from the current configuration.
	UpdateModels(ctx context.Context) error
}
 50
// coordinator is the Coordinator implementation. It owns the services needed
// to build agents and tools and forwards every request to currentAgent.
type coordinator struct {
	// cfg is the application configuration; agents, models, providers and
	// options are all read from it.
	cfg         *config.Config
	// sessions and messages are handed to the session agent when it is built.
	sessions    session.Service
	messages    message.Service
	// permissions, history and lspClients are handed to the tools when they
	// are built.
	permissions permission.Service
	history     history.Service
	lspClients  *csync.Map[string, *lsp.Client]

	// currentAgent is the agent every Coordinator method delegates to.
	currentAgent SessionAgent
	// agents holds the built agents keyed by agent name.
	agents       map[string]SessionAgent
}
 62
 63func NewCoordinator(
 64	ctx context.Context,
 65	cfg *config.Config,
 66	sessions session.Service,
 67	messages message.Service,
 68	permissions permission.Service,
 69	history history.Service,
 70	lspClients *csync.Map[string, *lsp.Client],
 71) (Coordinator, error) {
 72	c := &coordinator{
 73		cfg:         cfg,
 74		sessions:    sessions,
 75		messages:    messages,
 76		permissions: permissions,
 77		history:     history,
 78		lspClients:  lspClients,
 79		agents:      make(map[string]SessionAgent),
 80	}
 81
 82	agentCfg, ok := cfg.Agents[config.AgentCoder]
 83	if !ok {
 84		return nil, errors.New("coder agent not configured")
 85	}
 86
 87	// TODO: make this dynamic when we support multiple agents
 88	prompt, err := coderPrompt(prompt.WithWorkingDir(c.cfg.WorkingDir()))
 89	if err != nil {
 90		return nil, err
 91	}
 92
 93	agent, err := c.buildAgent(ctx, prompt, agentCfg)
 94	if err != nil {
 95		return nil, err
 96	}
 97	c.currentAgent = agent
 98	c.agents[config.AgentCoder] = agent
 99	return c, nil
100}
101
102// Run implements Coordinator.
103func (c *coordinator) Run(ctx context.Context, sessionID string, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) {
104	model := c.currentAgent.Model()
105	maxTokens := model.CatwalkCfg.DefaultMaxTokens
106	if model.ModelCfg.MaxTokens != 0 {
107		maxTokens = model.ModelCfg.MaxTokens
108	}
109
110	if !model.CatwalkCfg.SupportsImages && attachments != nil {
111		attachments = nil
112	}
113
114	providerCfg, ok := c.cfg.Providers.Get(model.ModelCfg.Provider)
115	if !ok {
116		return nil, errors.New("model provider not configured")
117	}
118
119	mergedOptions, temp, topP, topK, freqPenalty, presPenalty := mergeCallOptions(model, providerCfg.Type)
120
121	return c.currentAgent.Run(ctx, SessionAgentCall{
122		SessionID:        sessionID,
123		Prompt:           prompt,
124		Attachments:      attachments,
125		MaxOutputTokens:  maxTokens,
126		ProviderOptions:  mergedOptions,
127		Temperature:      temp,
128		TopP:             topP,
129		TopK:             topK,
130		FrequencyPenalty: freqPenalty,
131		PresencePenalty:  presPenalty,
132	})
133}
134
135func getProviderOptions(model Model, tp catwalk.Type) fantasy.ProviderOptions {
136	options := fantasy.ProviderOptions{}
137
138	cfgOpts := []byte("{}")
139	catwalkOpts := []byte("{}")
140
141	if model.ModelCfg.ProviderOptions != nil {
142		data, err := json.Marshal(model.ModelCfg.ProviderOptions)
143		if err == nil {
144			cfgOpts = data
145		}
146	}
147
148	if model.CatwalkCfg.Options.ProviderOptions != nil {
149		data, err := json.Marshal(model.CatwalkCfg.Options.ProviderOptions)
150		if err == nil {
151			catwalkOpts = data
152		}
153	}
154
155	readers := []io.Reader{
156		bytes.NewReader(catwalkOpts),
157		bytes.NewReader(cfgOpts),
158	}
159
160	got, err := jsons.Merge(readers)
161	if err != nil {
162		slog.Error("Could not merge call config", "err", err)
163		return options
164	}
165
166	mergedOptions := make(map[string]any)
167
168	err = json.Unmarshal([]byte(got), &mergedOptions)
169	if err != nil {
170		slog.Error("Could not create config for call", "err", err)
171		return options
172	}
173
174	switch tp {
175	case openai.Name:
176		_, hasReasoningEffort := mergedOptions["reasoning_effort"]
177		if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
178			mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort
179		}
180		if openai.IsResponsesModel(model.CatwalkCfg.ID) {
181			if openai.IsResponsesReasoningModel(model.CatwalkCfg.ID) {
182				mergedOptions["reasoning_summary"] = "auto"
183				mergedOptions["include"] = []openai.IncludeType{openai.IncludeReasoningEncryptedContent}
184			}
185			parsed, err := openai.ParseResponsesOptions(mergedOptions)
186			if err == nil {
187				options[openai.Name] = parsed
188			}
189		} else {
190			parsed, err := openai.ParseOptions(mergedOptions)
191			if err == nil {
192				options[openai.Name] = parsed
193			}
194		}
195	case anthropic.Name:
196		_, hasThink := mergedOptions["thinking"]
197		if !hasThink && model.ModelCfg.Think {
198			mergedOptions["thinking"] = map[string]any{
199				// TODO: kujtim see if we need to make this dynamic
200				"budget_tokens": 2000,
201			}
202		}
203		parsed, err := anthropic.ParseOptions(mergedOptions)
204		if err == nil {
205			options[anthropic.Name] = parsed
206		}
207
208	case openrouter.Name:
209		_, hasReasoning := mergedOptions["reasoning"]
210		if !hasReasoning && model.ModelCfg.ReasoningEffort != "" {
211			mergedOptions["reasoning"] = map[string]any{
212				"enabled": true,
213				"effort":  model.ModelCfg.ReasoningEffort,
214			}
215		}
216		parsed, err := openrouter.ParseOptions(mergedOptions)
217		if err == nil {
218			options[openrouter.Name] = parsed
219		}
220	case google.Name:
221		parsed, err := google.ParseOptions(mergedOptions)
222		if err == nil {
223			options[google.Name] = parsed
224		}
225	case azure.Name:
226		_, hasReasoningEffort := mergedOptions["reasoning_effort"]
227		if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
228			mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort
229		}
230		// azure uses the same options as openaicompat
231		parsed, err := openaicompat.ParseOptions(mergedOptions)
232		if err == nil {
233			options[azure.Name] = parsed
234		}
235	case openaicompat.Name:
236		_, hasReasoningEffort := mergedOptions["reasoning_effort"]
237		if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
238			mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort
239		}
240		parsed, err := openaicompat.ParseOptions(mergedOptions)
241		if err == nil {
242			options[openaicompat.Name] = parsed
243		}
244	}
245
246	return options
247}
248
249func mergeCallOptions(model Model, tp catwalk.Type) (fantasy.ProviderOptions, *float64, *float64, *int64, *float64, *float64) {
250	modelOptions := getProviderOptions(model, tp)
251	temp := cmp.Or(model.ModelCfg.Temperature, model.CatwalkCfg.Options.Temperature)
252	topP := cmp.Or(model.ModelCfg.TopP, model.CatwalkCfg.Options.TopP)
253	topK := cmp.Or(model.ModelCfg.TopK, model.CatwalkCfg.Options.TopK)
254	freqPenalty := cmp.Or(model.ModelCfg.FrequencyPenalty, model.CatwalkCfg.Options.FrequencyPenalty)
255	presPenalty := cmp.Or(model.ModelCfg.PresencePenalty, model.CatwalkCfg.Options.PresencePenalty)
256	return modelOptions, temp, topP, topK, freqPenalty, presPenalty
257}
258
259func (c *coordinator) buildAgent(ctx context.Context, prompt *prompt.Prompt, agent config.Agent) (SessionAgent, error) {
260	large, small, err := c.buildAgentModels(ctx)
261	if err != nil {
262		return nil, err
263	}
264
265	systemPrompt, err := prompt.Build(large.Model.Provider(), large.Model.Model(), *c.cfg)
266	if err != nil {
267		return nil, err
268	}
269
270	tools, err := c.buildTools(ctx, agent)
271	if err != nil {
272		return nil, err
273	}
274	return NewSessionAgent(SessionAgentOptions{large, small, systemPrompt, c.cfg.Options.DisableAutoSummarize, c.sessions, c.messages, tools}), nil
275}
276
277func (c *coordinator) buildTools(ctx context.Context, agent config.Agent) ([]fantasy.AgentTool, error) {
278	var allTools []fantasy.AgentTool
279	if slices.Contains(agent.AllowedTools, AgentToolName) {
280		agentTool, err := c.agentTool(ctx)
281		if err != nil {
282			return nil, err
283		}
284		allTools = append(allTools, agentTool)
285	}
286
287	allTools = append(allTools,
288		tools.NewBashTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Options.Attribution),
289		tools.NewDownloadTool(c.permissions, c.cfg.WorkingDir(), nil),
290		tools.NewEditTool(c.lspClients, c.permissions, c.history, c.cfg.WorkingDir()),
291		tools.NewMultiEditTool(c.lspClients, c.permissions, c.history, c.cfg.WorkingDir()),
292		tools.NewFetchTool(c.permissions, c.cfg.WorkingDir(), nil),
293		tools.NewGlobTool(c.cfg.WorkingDir()),
294		tools.NewGrepTool(c.cfg.WorkingDir()),
295		tools.NewLsTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Tools.Ls),
296		tools.NewSourcegraphTool(nil),
297		tools.NewViewTool(c.lspClients, c.permissions, c.cfg.WorkingDir()),
298		tools.NewWriteTool(c.lspClients, c.permissions, c.history, c.cfg.WorkingDir()),
299	)
300
301	var filteredTools []fantasy.AgentTool
302	for _, tool := range allTools {
303		if slices.Contains(agent.AllowedTools, tool.Info().Name) {
304			filteredTools = append(filteredTools, tool)
305		}
306	}
307
308	mcpTools := tools.GetMCPTools(context.Background(), c.permissions, c.cfg)
309
310	for _, mcpTool := range mcpTools {
311		if agent.AllowedMCP == nil {
312			// No MCP restrictions
313			filteredTools = append(filteredTools, mcpTool)
314		} else if len(agent.AllowedMCP) == 0 {
315			// no mcps allowed
316			break
317		}
318
319		for mcp, tools := range agent.AllowedMCP {
320			if mcp == mcpTool.MCP() {
321				if len(tools) == 0 {
322					filteredTools = append(filteredTools, mcpTool)
323				}
324				for _, t := range tools {
325					if t == mcpTool.MCPToolName() {
326						filteredTools = append(filteredTools, mcpTool)
327					}
328				}
329				break
330			}
331		}
332	}
333
334	return filteredTools, nil
335}
336
337// TODO: when we support multiple agents we need to change this so that we pass in the agent specific model config
338func (c *coordinator) buildAgentModels(ctx context.Context) (Model, Model, error) {
339	largeModelCfg, ok := c.cfg.Models[config.SelectedModelTypeLarge]
340	if !ok {
341		return Model{}, Model{}, errors.New("large model not selected")
342	}
343	smallModelCfg, ok := c.cfg.Models[config.SelectedModelTypeSmall]
344	if !ok {
345		return Model{}, Model{}, errors.New("small model not selected")
346	}
347
348	largeProviderCfg, ok := c.cfg.Providers.Get(largeModelCfg.Provider)
349	if !ok {
350		return Model{}, Model{}, errors.New("large model provider not configured")
351	}
352
353	largeProvider, err := c.buildProvider(largeProviderCfg, largeModelCfg)
354	if err != nil {
355		return Model{}, Model{}, err
356	}
357
358	smallProviderCfg, ok := c.cfg.Providers.Get(smallModelCfg.Provider)
359	if !ok {
360		return Model{}, Model{}, errors.New("large model provider not configured")
361	}
362
363	smallProvider, err := c.buildProvider(smallProviderCfg, largeModelCfg)
364	if err != nil {
365		return Model{}, Model{}, err
366	}
367
368	var largeCatwalkModel *catwalk.Model
369	var smallCatwalkModel *catwalk.Model
370
371	for _, m := range largeProviderCfg.Models {
372		if m.ID == largeModelCfg.Model {
373			largeCatwalkModel = &m
374		}
375	}
376	for _, m := range smallProviderCfg.Models {
377		if m.ID == smallModelCfg.Model {
378			smallCatwalkModel = &m
379		}
380	}
381
382	if largeCatwalkModel == nil {
383		return Model{}, Model{}, errors.New("large model not found in provider config")
384	}
385
386	if smallCatwalkModel == nil {
387		return Model{}, Model{}, errors.New("snall model not found in provider config")
388	}
389
390	largeModel, err := largeProvider.LanguageModel(ctx, largeModelCfg.Model)
391	if err != nil {
392		return Model{}, Model{}, err
393	}
394	smallModel, err := smallProvider.LanguageModel(ctx, smallModelCfg.Model)
395	if err != nil {
396		return Model{}, Model{}, err
397	}
398
399	return Model{
400			Model:      largeModel,
401			CatwalkCfg: *largeCatwalkModel,
402			ModelCfg:   largeModelCfg,
403		}, Model{
404			Model:      smallModel,
405			CatwalkCfg: *smallCatwalkModel,
406			ModelCfg:   smallModelCfg,
407		}, nil
408}
409
410func (c *coordinator) buildAnthropicProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
411	hasBearerAuth := false
412	for key := range headers {
413		if strings.ToLower(key) == "authorization" {
414			hasBearerAuth = true
415			break
416		}
417	}
418	if hasBearerAuth {
419		apiKey = "" // clear apiKey to avoid using X-Api-Key header
420	}
421
422	var opts []anthropic.Option
423
424	if apiKey != "" {
425		// Use standard X-Api-Key header
426		opts = append(opts, anthropic.WithAPIKey(apiKey))
427	}
428
429	if len(headers) > 0 {
430		opts = append(opts, anthropic.WithHeaders(headers))
431	}
432
433	if baseURL != "" {
434		opts = append(opts, anthropic.WithBaseURL(baseURL))
435	}
436
437	if c.cfg.Options.Debug {
438		httpClient := log.NewHTTPClient()
439		opts = append(opts, anthropic.WithHTTPClient(httpClient))
440	}
441
442	return anthropic.New(opts...)
443}
444
445func (c *coordinator) buildOpenaiProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
446	opts := []openai.Option{
447		openai.WithAPIKey(apiKey),
448		openai.WithUseResponsesAPI(),
449	}
450	if c.cfg.Options.Debug {
451		httpClient := log.NewHTTPClient()
452		opts = append(opts, openai.WithHTTPClient(httpClient))
453	}
454	if len(headers) > 0 {
455		opts = append(opts, openai.WithHeaders(headers))
456	}
457	if baseURL != "" {
458		opts = append(opts, openai.WithBaseURL(baseURL))
459	}
460	return openai.New(opts...)
461}
462
463func (c *coordinator) buildOpenrouterProvider(_, apiKey string, headers map[string]string) (fantasy.Provider, error) {
464	opts := []openrouter.Option{
465		openrouter.WithAPIKey(apiKey),
466	}
467	if c.cfg.Options.Debug {
468		httpClient := log.NewHTTPClient()
469		opts = append(opts, openrouter.WithHTTPClient(httpClient))
470	}
471	if len(headers) > 0 {
472		opts = append(opts, openrouter.WithHeaders(headers))
473	}
474	return openrouter.New(opts...)
475}
476
477func (c *coordinator) buildOpenaiCompatProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
478	opts := []openaicompat.Option{
479		openaicompat.WithBaseURL(baseURL),
480		openaicompat.WithAPIKey(apiKey),
481	}
482	if c.cfg.Options.Debug {
483		httpClient := log.NewHTTPClient()
484		opts = append(opts, openaicompat.WithHTTPClient(httpClient))
485	}
486	if len(headers) > 0 {
487		opts = append(opts, openaicompat.WithHeaders(headers))
488	}
489
490	return openaicompat.New(opts...)
491}
492
493func (c *coordinator) buildAzureProvider(baseURL, apiKey string, headers map[string]string, options map[string]string) (fantasy.Provider, error) {
494	opts := []azure.Option{
495		azure.WithBaseURL(baseURL),
496		azure.WithAPIKey(apiKey),
497	}
498	if c.cfg.Options.Debug {
499		httpClient := log.NewHTTPClient()
500		opts = append(opts, azure.WithHTTPClient(httpClient))
501	}
502	if options == nil {
503		options = make(map[string]string)
504	}
505	if apiVersion, ok := options["apiVersion"]; ok {
506		opts = append(opts, azure.WithAPIVersion(apiVersion))
507	}
508	if len(headers) > 0 {
509		opts = append(opts, azure.WithHeaders(headers))
510	}
511
512	return azure.New(opts...)
513}
514
515func (c *coordinator) buildGoogleProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
516	opts := []google.Option{
517		google.WithBaseURL(baseURL),
518		google.WithGeminiAPIKey(apiKey),
519	}
520	if c.cfg.Options.Debug {
521		httpClient := log.NewHTTPClient()
522		opts = append(opts, google.WithHTTPClient(httpClient))
523	}
524	if len(headers) > 0 {
525		opts = append(opts, google.WithHeaders(headers))
526	}
527	return google.New(opts...)
528}
529
530func (c *coordinator) buildGoogleVertexProvider(headers map[string]string, options map[string]string) (fantasy.Provider, error) {
531	opts := []google.Option{}
532	if c.cfg.Options.Debug {
533		httpClient := log.NewHTTPClient()
534		opts = append(opts, google.WithHTTPClient(httpClient))
535	}
536	if len(headers) > 0 {
537		opts = append(opts, google.WithHeaders(headers))
538	}
539
540	project := options["project"]
541	location := options["location"]
542
543	opts = append(opts, google.WithVertex(project, location))
544
545	return google.New(opts...)
546}
547
548func (c *coordinator) isAnthropicThinking(model config.SelectedModel) bool {
549	if model.Think {
550		return true
551	}
552
553	if model.ProviderOptions == nil {
554		return false
555	}
556
557	opts, err := anthropic.ParseOptions(model.ProviderOptions)
558	if err != nil {
559		return false
560	}
561	if opts.Thinking != nil {
562		return true
563	}
564	return false
565}
566
567func (c *coordinator) buildProvider(providerCfg config.ProviderConfig, model config.SelectedModel) (fantasy.Provider, error) {
568	headers := providerCfg.ExtraHeaders
569
570	// handle special headers for anthropic
571	if providerCfg.Type == anthropic.Name && c.isAnthropicThinking(model) {
572		headers["anthropic-beta"] = "interleaved-thinking-2025-05-14"
573	}
574
575	// TODO: make sure we have
576	apiKey, _ := c.cfg.Resolve(providerCfg.APIKey)
577	baseURL, _ := c.cfg.Resolve(providerCfg.BaseURL)
578
579	switch providerCfg.Type {
580	case openai.Name:
581		return c.buildOpenaiProvider(baseURL, apiKey, headers)
582	case anthropic.Name:
583		return c.buildAnthropicProvider(baseURL, apiKey, headers)
584	case openrouter.Name:
585		return c.buildOpenrouterProvider(baseURL, apiKey, headers)
586	case azure.Name:
587		return c.buildAzureProvider(baseURL, apiKey, headers, providerCfg.ExtraParams)
588	case google.Name:
589		return c.buildGoogleProvider(baseURL, apiKey, headers)
590	case "vertexai":
591		return c.buildGoogleVertexProvider(headers, providerCfg.ExtraParams)
592	case openaicompat.Name:
593		return c.buildOpenaiCompatProvider(baseURL, apiKey, headers)
594	default:
595		return nil, errors.New("provider type not supported")
596	}
597}
598
// Cancel implements Coordinator by forwarding the cancellation request for
// sessionID to the current agent.
func (c *coordinator) Cancel(sessionID string) {
	c.currentAgent.Cancel(sessionID)
}
602
// CancelAll implements Coordinator by forwarding to the current agent's
// CancelAll.
func (c *coordinator) CancelAll() {
	c.currentAgent.CancelAll()
}
606
// ClearQueue implements Coordinator by asking the current agent to clear the
// queue of sessionID.
func (c *coordinator) ClearQueue(sessionID string) {
	c.currentAgent.ClearQueue(sessionID)
}
610
// IsBusy implements Coordinator by returning the current agent's busy state.
func (c *coordinator) IsBusy() bool {
	return c.currentAgent.IsBusy()
}
614
// IsSessionBusy implements Coordinator by returning the current agent's busy
// state for sessionID.
func (c *coordinator) IsSessionBusy(sessionID string) bool {
	return c.currentAgent.IsSessionBusy(sessionID)
}
618
// Model implements Coordinator by returning the current agent's model.
func (c *coordinator) Model() Model {
	return c.currentAgent.Model()
}
622
623func (c *coordinator) UpdateModels(ctx context.Context) error {
624	// build the models again so we make sure we get the latest config
625	large, small, err := c.buildAgentModels(ctx)
626	if err != nil {
627		return err
628	}
629	c.currentAgent.SetModels(large, small)
630
631	agentCfg, ok := c.cfg.Agents[config.AgentCoder]
632	if !ok {
633		return errors.New("coder agent not configured")
634	}
635
636	tools, err := c.buildTools(ctx, agentCfg)
637	if err != nil {
638		return err
639	}
640	c.currentAgent.SetTools(tools)
641	return nil
642}
643
// QueuedPrompts implements Coordinator by returning the current agent's
// queued prompt count for sessionID.
func (c *coordinator) QueuedPrompts(sessionID string) int {
	return c.currentAgent.QueuedPrompts(sessionID)
}
647
648func (c *coordinator) Summarize(ctx context.Context, sessionID string) error {
649	providerCfg, ok := c.cfg.Providers.Get(c.currentAgent.Model().ModelCfg.Provider)
650	if !ok {
651		return errors.New("model provider not configured")
652	}
653	return c.currentAgent.Summarize(ctx, sessionID, getProviderOptions(c.currentAgent.Model(), providerCfg.Type))
654}