coordinator.go

  1package agent
  2
  3import (
  4	"bytes"
  5	"cmp"
  6	"context"
  7	"encoding/json"
  8	"errors"
  9	"io"
 10	"log/slog"
 11	"slices"
 12	"strings"
 13
 14	"charm.land/fantasy"
 15	"github.com/charmbracelet/catwalk/pkg/catwalk"
 16	"github.com/charmbracelet/crush/internal/agent/prompt"
 17	"github.com/charmbracelet/crush/internal/agent/tools"
 18	"github.com/charmbracelet/crush/internal/config"
 19	"github.com/charmbracelet/crush/internal/csync"
 20	"github.com/charmbracelet/crush/internal/history"
 21	"github.com/charmbracelet/crush/internal/log"
 22	"github.com/charmbracelet/crush/internal/lsp"
 23	"github.com/charmbracelet/crush/internal/message"
 24	"github.com/charmbracelet/crush/internal/permission"
 25	"github.com/charmbracelet/crush/internal/session"
 26
 27	"charm.land/fantasy/providers/anthropic"
 28	"charm.land/fantasy/providers/azure"
 29	"charm.land/fantasy/providers/google"
 30	"charm.land/fantasy/providers/openai"
 31	"charm.land/fantasy/providers/openaicompat"
 32	"charm.land/fantasy/providers/openrouter"
 33	"github.com/qjebbs/go-jsons"
 34)
 35
 36type Coordinator interface {
 37	// INFO: (kujtim) this is not used yet we will use this when we have multiple agents
 38	// SetMainAgent(string)
 39	Run(ctx context.Context, sessionID, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error)
 40	Cancel(sessionID string)
 41	CancelAll()
 42	IsSessionBusy(sessionID string) bool
 43	IsBusy() bool
 44	QueuedPrompts(sessionID string) int
 45	ClearQueue(sessionID string)
 46	Summarize(context.Context, string) error
 47	Model() Model
 48	UpdateModels() error
 49}
 50
 51type coordinator struct {
 52	cfg         *config.Config
 53	sessions    session.Service
 54	messages    message.Service
 55	permissions permission.Service
 56	history     history.Service
 57	lspClients  *csync.Map[string, *lsp.Client]
 58
 59	currentAgent SessionAgent
 60	agents       map[string]SessionAgent
 61}
 62
 63func NewCoordinator(
 64	cfg *config.Config,
 65	sessions session.Service,
 66	messages message.Service,
 67	permissions permission.Service,
 68	history history.Service,
 69	lspClients *csync.Map[string, *lsp.Client],
 70) (Coordinator, error) {
 71	c := &coordinator{
 72		cfg:         cfg,
 73		sessions:    sessions,
 74		messages:    messages,
 75		permissions: permissions,
 76		history:     history,
 77		lspClients:  lspClients,
 78		agents:      make(map[string]SessionAgent),
 79	}
 80
 81	agentCfg, ok := cfg.Agents[config.AgentCoder]
 82	if !ok {
 83		return nil, errors.New("coder agent not configured")
 84	}
 85
 86	// TODO: make this dynamic when we support multiple agents
 87	prompt, err := coderPrompt(prompt.WithWorkingDir(c.cfg.WorkingDir()))
 88	if err != nil {
 89		return nil, err
 90	}
 91
 92	agent, err := c.buildAgent(prompt, agentCfg)
 93	if err != nil {
 94		return nil, err
 95	}
 96	c.currentAgent = agent
 97	c.agents[config.AgentCoder] = agent
 98	return c, nil
 99}
100
101// Run implements Coordinator.
102func (c *coordinator) Run(ctx context.Context, sessionID string, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) {
103	model := c.currentAgent.Model()
104	maxTokens := model.CatwalkCfg.DefaultMaxTokens
105	if model.ModelCfg.MaxTokens != 0 {
106		maxTokens = model.ModelCfg.MaxTokens
107	}
108
109	if !model.CatwalkCfg.SupportsImages && attachments != nil {
110		attachments = nil
111	}
112
113	providerCfg, ok := c.cfg.Providers.Get(model.ModelCfg.Provider)
114	if !ok {
115		return nil, errors.New("model provider not configured")
116	}
117
118	mergedOptions, temp, topP, topK, freqPenalty, presPenalty := mergeCallOptions(model, providerCfg.Type)
119
120	return c.currentAgent.Run(ctx, SessionAgentCall{
121		SessionID:        sessionID,
122		Prompt:           prompt,
123		Attachments:      attachments,
124		MaxOutputTokens:  maxTokens,
125		ProviderOptions:  mergedOptions,
126		Temperature:      temp,
127		TopP:             topP,
128		TopK:             topK,
129		FrequencyPenalty: freqPenalty,
130		PresencePenalty:  presPenalty,
131	})
132}
133
134func getProviderOptions(model Model, tp catwalk.Type) fantasy.ProviderOptions {
135	options := fantasy.ProviderOptions{}
136
137	cfgOpts := []byte("{}")
138	catwalkOpts := []byte("{}")
139
140	if model.ModelCfg.ProviderOptions != nil {
141		data, err := json.Marshal(model.ModelCfg.ProviderOptions)
142		if err == nil {
143			cfgOpts = data
144		}
145	}
146
147	if model.CatwalkCfg.Options.ProviderOptions != nil {
148		data, err := json.Marshal(model.CatwalkCfg.Options.ProviderOptions)
149		if err == nil {
150			catwalkOpts = data
151		}
152	}
153
154	readers := []io.Reader{
155		bytes.NewReader(catwalkOpts),
156		bytes.NewReader(cfgOpts),
157	}
158
159	got, err := jsons.Merge(readers)
160	if err != nil {
161		slog.Error("Could not merge call config", "err", err)
162		return options
163	}
164
165	mergedOptions := make(map[string]any)
166
167	err = json.Unmarshal([]byte(got), &mergedOptions)
168	if err != nil {
169		slog.Error("Could not create config for call", "err", err)
170		return options
171	}
172
173	switch tp {
174	case openai.Name:
175		_, hasReasoningEffort := mergedOptions["reasoning_effort"]
176		if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
177			mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort
178		}
179		if openai.IsResponsesModel(model.CatwalkCfg.ID) {
180			if openai.IsResponsesReasoningModel(model.CatwalkCfg.ID) {
181				mergedOptions["reasoning_summary"] = "auto"
182				mergedOptions["include"] = []openai.IncludeType{openai.IncludeReasoningEncryptedContent}
183			}
184			parsed, err := openai.ParseResponsesOptions(mergedOptions)
185			if err == nil {
186				options[openai.Name] = parsed
187			}
188		} else {
189			parsed, err := openai.ParseOptions(mergedOptions)
190			if err == nil {
191				options[openai.Name] = parsed
192			}
193		}
194	case anthropic.Name:
195		_, hasThink := mergedOptions["thinking"]
196		if !hasThink && model.ModelCfg.Think {
197			mergedOptions["thinking"] = map[string]any{
198				// TODO: kujtim see if we need to make this dynamic
199				"budget_tokens": 2000,
200			}
201		}
202		parsed, err := anthropic.ParseOptions(mergedOptions)
203		if err == nil {
204			options[anthropic.Name] = parsed
205		}
206
207	case openrouter.Name:
208		_, hasReasoning := mergedOptions["reasoning"]
209		if !hasReasoning && model.ModelCfg.ReasoningEffort != "" {
210			mergedOptions["reasoning"] = map[string]any{
211				"enabled": true,
212				"effort":  model.ModelCfg.ReasoningEffort,
213			}
214		}
215		parsed, err := openrouter.ParseOptions(mergedOptions)
216		if err == nil {
217			options[openrouter.Name] = parsed
218		}
219	case google.Name:
220		parsed, err := google.ParseOptions(mergedOptions)
221		if err == nil {
222			options[google.Name] = parsed
223		}
224	case azure.Name:
225		_, hasReasoningEffort := mergedOptions["reasoning_effort"]
226		if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
227			mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort
228		}
229		// azure uses the same options as openaicompat
230		parsed, err := openaicompat.ParseOptions(mergedOptions)
231		if err == nil {
232			options[azure.Name] = parsed
233		}
234	case openaicompat.Name:
235		_, hasReasoningEffort := mergedOptions["reasoning_effort"]
236		if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
237			mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort
238		}
239		parsed, err := openaicompat.ParseOptions(mergedOptions)
240		if err == nil {
241			options[openaicompat.Name] = parsed
242		}
243	}
244
245	return options
246}
247
248func mergeCallOptions(model Model, tp catwalk.Type) (fantasy.ProviderOptions, *float64, *float64, *int64, *float64, *float64) {
249	modelOptions := getProviderOptions(model, tp)
250	temp := cmp.Or(model.ModelCfg.Temperature, model.CatwalkCfg.Options.Temperature)
251	topP := cmp.Or(model.ModelCfg.TopP, model.CatwalkCfg.Options.TopP)
252	topK := cmp.Or(model.ModelCfg.TopK, model.CatwalkCfg.Options.TopK)
253	freqPenalty := cmp.Or(model.ModelCfg.FrequencyPenalty, model.CatwalkCfg.Options.FrequencyPenalty)
254	presPenalty := cmp.Or(model.ModelCfg.PresencePenalty, model.CatwalkCfg.Options.PresencePenalty)
255	return modelOptions, temp, topP, topK, freqPenalty, presPenalty
256}
257
258func (c *coordinator) buildAgent(prompt *prompt.Prompt, agent config.Agent) (SessionAgent, error) {
259	large, small, err := c.buildAgentModels()
260	if err != nil {
261		return nil, err
262	}
263
264	systemPrompt, err := prompt.Build(large.Model.Provider(), large.Model.Model(), *c.cfg)
265	if err != nil {
266		return nil, err
267	}
268
269	tools, err := c.buildTools(agent)
270	if err != nil {
271		return nil, err
272	}
273	return NewSessionAgent(SessionAgentOptions{large, small, systemPrompt, c.cfg.Options.DisableAutoSummarize, c.sessions, c.messages, tools}), nil
274}
275
276func (c *coordinator) buildTools(agent config.Agent) ([]fantasy.AgentTool, error) {
277	var allTools []fantasy.AgentTool
278	if slices.Contains(agent.AllowedTools, AgentToolName) {
279		agentTool, err := c.agentTool()
280		if err != nil {
281			return nil, err
282		}
283		allTools = append(allTools, agentTool)
284	}
285
286	allTools = append(allTools,
287		tools.NewBashTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Options.Attribution),
288		tools.NewDownloadTool(c.permissions, c.cfg.WorkingDir(), nil),
289		tools.NewEditTool(c.lspClients, c.permissions, c.history, c.cfg.WorkingDir()),
290		tools.NewMultiEditTool(c.lspClients, c.permissions, c.history, c.cfg.WorkingDir()),
291		tools.NewFetchTool(c.permissions, c.cfg.WorkingDir(), nil),
292		tools.NewGlobTool(c.cfg.WorkingDir()),
293		tools.NewGrepTool(c.cfg.WorkingDir()),
294		tools.NewLsTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Tools.Ls),
295		tools.NewSourcegraphTool(nil),
296		tools.NewViewTool(c.lspClients, c.permissions, c.cfg.WorkingDir()),
297		tools.NewWriteTool(c.lspClients, c.permissions, c.history, c.cfg.WorkingDir()),
298	)
299
300	var filteredTools []fantasy.AgentTool
301	for _, tool := range allTools {
302		if slices.Contains(agent.AllowedTools, tool.Info().Name) {
303			filteredTools = append(filteredTools, tool)
304		}
305	}
306
307	mcpTools := tools.GetMCPTools(context.Background(), c.permissions, c.cfg)
308
309	for _, mcpTool := range mcpTools {
310		if agent.AllowedMCP == nil {
311			// No MCP restrictions
312			filteredTools = append(filteredTools, mcpTool)
313		} else if len(agent.AllowedMCP) == 0 {
314			// no mcps allowed
315			break
316		}
317
318		for mcp, tools := range agent.AllowedMCP {
319			if mcp == mcpTool.MCP() {
320				if len(tools) == 0 {
321					filteredTools = append(filteredTools, mcpTool)
322				}
323				for _, t := range tools {
324					if t == mcpTool.MCPToolName() {
325						filteredTools = append(filteredTools, mcpTool)
326					}
327				}
328				break
329			}
330		}
331	}
332
333	return filteredTools, nil
334}
335
336// TODO: when we support multiple agents we need to change this so that we pass in the agent specific model config
337func (c *coordinator) buildAgentModels() (Model, Model, error) {
338	largeModelCfg, ok := c.cfg.Models[config.SelectedModelTypeLarge]
339	if !ok {
340		return Model{}, Model{}, errors.New("large model not selected")
341	}
342	smallModelCfg, ok := c.cfg.Models[config.SelectedModelTypeSmall]
343	if !ok {
344		return Model{}, Model{}, errors.New("small model not selected")
345	}
346
347	largeProviderCfg, ok := c.cfg.Providers.Get(largeModelCfg.Provider)
348	if !ok {
349		return Model{}, Model{}, errors.New("large model provider not configured")
350	}
351
352	largeProvider, err := c.buildProvider(largeProviderCfg, largeModelCfg)
353	if err != nil {
354		return Model{}, Model{}, err
355	}
356
357	smallProviderCfg, ok := c.cfg.Providers.Get(smallModelCfg.Provider)
358	if !ok {
359		return Model{}, Model{}, errors.New("large model provider not configured")
360	}
361
362	smallProvider, err := c.buildProvider(smallProviderCfg, largeModelCfg)
363	if err != nil {
364		return Model{}, Model{}, err
365	}
366
367	var largeCatwalkModel *catwalk.Model
368	var smallCatwalkModel *catwalk.Model
369
370	for _, m := range largeProviderCfg.Models {
371		if m.ID == largeModelCfg.Model {
372			largeCatwalkModel = &m
373		}
374	}
375	for _, m := range smallProviderCfg.Models {
376		if m.ID == smallModelCfg.Model {
377			smallCatwalkModel = &m
378		}
379	}
380
381	if largeCatwalkModel == nil {
382		return Model{}, Model{}, errors.New("large model not found in provider config")
383	}
384
385	if smallCatwalkModel == nil {
386		return Model{}, Model{}, errors.New("snall model not found in provider config")
387	}
388
389	largeModel, err := largeProvider.LanguageModel(largeModelCfg.Model)
390	if err != nil {
391		return Model{}, Model{}, err
392	}
393	smallModel, err := smallProvider.LanguageModel(smallModelCfg.Model)
394	if err != nil {
395		return Model{}, Model{}, err
396	}
397
398	return Model{
399			Model:      largeModel,
400			CatwalkCfg: *largeCatwalkModel,
401			ModelCfg:   largeModelCfg,
402		}, Model{
403			Model:      smallModel,
404			CatwalkCfg: *smallCatwalkModel,
405			ModelCfg:   smallModelCfg,
406		}, nil
407}
408
409func (c *coordinator) buildAnthropicProvider(baseURL, apiKey string, headers map[string]string) fantasy.Provider {
410	hasBearerAuth := false
411	for key := range headers {
412		if strings.ToLower(key) == "authorization" {
413			hasBearerAuth = true
414			break
415		}
416	}
417	if hasBearerAuth {
418		apiKey = "" // clear apiKey to avoid using X-Api-Key header
419	}
420
421	var opts []anthropic.Option
422
423	if apiKey != "" {
424		// Use standard X-Api-Key header
425		opts = append(opts, anthropic.WithAPIKey(apiKey))
426	}
427
428	if len(headers) > 0 {
429		opts = append(opts, anthropic.WithHeaders(headers))
430	}
431
432	if baseURL != "" {
433		opts = append(opts, anthropic.WithBaseURL(baseURL))
434	}
435
436	if c.cfg.Options.Debug {
437		httpClient := log.NewHTTPClient()
438		opts = append(opts, anthropic.WithHTTPClient(httpClient))
439	}
440
441	return anthropic.New(opts...)
442}
443
444func (c *coordinator) buildOpenaiProvider(baseURL, apiKey string, headers map[string]string) fantasy.Provider {
445	opts := []openai.Option{
446		openai.WithAPIKey(apiKey),
447		openai.WithUseResponsesAPI(),
448	}
449	if c.cfg.Options.Debug {
450		httpClient := log.NewHTTPClient()
451		opts = append(opts, openai.WithHTTPClient(httpClient))
452	}
453	if len(headers) > 0 {
454		opts = append(opts, openai.WithHeaders(headers))
455	}
456	if baseURL != "" {
457		opts = append(opts, openai.WithBaseURL(baseURL))
458	}
459	return openai.New(opts...)
460}
461
462func (c *coordinator) buildOpenrouterProvider(_, apiKey string, headers map[string]string) fantasy.Provider {
463	opts := []openrouter.Option{
464		openrouter.WithAPIKey(apiKey),
465	}
466	if c.cfg.Options.Debug {
467		httpClient := log.NewHTTPClient()
468		opts = append(opts, openrouter.WithHTTPClient(httpClient))
469	}
470	if len(headers) > 0 {
471		opts = append(opts, openrouter.WithHeaders(headers))
472	}
473	return openrouter.New(opts...)
474}
475
476func (c *coordinator) buildOpenaiCompatProvider(baseURL, apiKey string, headers map[string]string) fantasy.Provider {
477	opts := []openaicompat.Option{
478		openaicompat.WithBaseURL(baseURL),
479		openaicompat.WithAPIKey(apiKey),
480	}
481	if c.cfg.Options.Debug {
482		httpClient := log.NewHTTPClient()
483		opts = append(opts, openaicompat.WithHTTPClient(httpClient))
484	}
485	if len(headers) > 0 {
486		opts = append(opts, openaicompat.WithHeaders(headers))
487	}
488
489	return openaicompat.New(opts...)
490}
491
492func (c *coordinator) buildAzureProvider(baseURL, apiKey string, headers map[string]string, options map[string]string) fantasy.Provider {
493	opts := []azure.Option{
494		azure.WithBaseURL(baseURL),
495		azure.WithAPIKey(apiKey),
496	}
497	if c.cfg.Options.Debug {
498		httpClient := log.NewHTTPClient()
499		opts = append(opts, azure.WithHTTPClient(httpClient))
500	}
501	if options == nil {
502		options = make(map[string]string)
503	}
504	if apiVersion, ok := options["apiVersion"]; ok {
505		opts = append(opts, azure.WithAPIVersion(apiVersion))
506	}
507	if len(headers) > 0 {
508		opts = append(opts, azure.WithHeaders(headers))
509	}
510
511	return azure.New(opts...)
512}
513
514func (c *coordinator) buildGoogleProvider(baseURL, apiKey string, headers map[string]string) fantasy.Provider {
515	opts := []google.Option{
516		google.WithBaseURL(baseURL),
517		google.WithGeminiAPIKey(apiKey),
518	}
519	if c.cfg.Options.Debug {
520		httpClient := log.NewHTTPClient()
521		opts = append(opts, google.WithHTTPClient(httpClient))
522	}
523	if len(headers) > 0 {
524		opts = append(opts, google.WithHeaders(headers))
525	}
526	return google.New(opts...)
527}
528
529func (c *coordinator) buildGoogleVertexProvider(headers map[string]string, options map[string]string) fantasy.Provider {
530	opts := []google.Option{}
531	if c.cfg.Options.Debug {
532		httpClient := log.NewHTTPClient()
533		opts = append(opts, google.WithHTTPClient(httpClient))
534	}
535	if len(headers) > 0 {
536		opts = append(opts, google.WithHeaders(headers))
537	}
538
539	project := options["project"]
540	location := options["location"]
541
542	opts = append(opts, google.WithVertex(project, location))
543
544	return google.New(opts...)
545}
546
547func (c *coordinator) isAnthropicThinking(model config.SelectedModel) bool {
548	if model.Think {
549		return true
550	}
551
552	if model.ProviderOptions == nil {
553		return false
554	}
555
556	opts, err := anthropic.ParseOptions(model.ProviderOptions)
557	if err != nil {
558		return false
559	}
560	if opts.Thinking != nil {
561		return true
562	}
563	return false
564}
565
566func (c *coordinator) buildProvider(providerCfg config.ProviderConfig, model config.SelectedModel) (fantasy.Provider, error) {
567	headers := providerCfg.ExtraHeaders
568
569	// handle special headers for anthropic
570	if providerCfg.Type == anthropic.Name && c.isAnthropicThinking(model) {
571		headers["anthropic-beta"] = "interleaved-thinking-2025-05-14"
572	}
573
574	// TODO: make sure we have
575	apiKey, _ := c.cfg.Resolve(providerCfg.APIKey)
576	baseURL, _ := c.cfg.Resolve(providerCfg.BaseURL)
577	var provider fantasy.Provider
578	switch providerCfg.Type {
579	case openai.Name:
580		provider = c.buildOpenaiProvider(baseURL, apiKey, headers)
581	case anthropic.Name:
582		provider = c.buildAnthropicProvider(baseURL, apiKey, headers)
583	case openrouter.Name:
584		provider = c.buildOpenrouterProvider(baseURL, apiKey, headers)
585	case azure.Name:
586		provider = c.buildAzureProvider(baseURL, apiKey, headers, providerCfg.ExtraParams)
587	case google.Name:
588		provider = c.buildGoogleProvider(baseURL, apiKey, headers)
589	// this is not in fantasy since its just the google provider with extra stuff
590	case "google-vertex":
591		provider = c.buildGoogleVertexProvider(headers, providerCfg.ExtraParams)
592	case openaicompat.Name:
593		provider = c.buildOpenaiCompatProvider(baseURL, apiKey, headers)
594	default:
595		return nil, errors.New("provider type not supported")
596	}
597	return provider, nil
598}
599
600func (c *coordinator) Cancel(sessionID string) {
601	c.currentAgent.Cancel(sessionID)
602}
603
604func (c *coordinator) CancelAll() {
605	c.currentAgent.CancelAll()
606}
607
608func (c *coordinator) ClearQueue(sessionID string) {
609	c.currentAgent.ClearQueue(sessionID)
610}
611
612func (c *coordinator) IsBusy() bool {
613	return c.currentAgent.IsBusy()
614}
615
616func (c *coordinator) IsSessionBusy(sessionID string) bool {
617	return c.currentAgent.IsSessionBusy(sessionID)
618}
619
620func (c *coordinator) Model() Model {
621	return c.currentAgent.Model()
622}
623
624func (c *coordinator) UpdateModels() error {
625	// build the models again so we make sure we get the latest config
626	large, small, err := c.buildAgentModels()
627	if err != nil {
628		return err
629	}
630	c.currentAgent.SetModels(large, small)
631
632	agentCfg, ok := c.cfg.Agents[config.AgentCoder]
633	if !ok {
634		return errors.New("coder agent not configured")
635	}
636
637	tools, err := c.buildTools(agentCfg)
638	if err != nil {
639		return err
640	}
641	c.currentAgent.SetTools(tools)
642	return nil
643}
644
645func (c *coordinator) QueuedPrompts(sessionID string) int {
646	return c.currentAgent.QueuedPrompts(sessionID)
647}
648
649func (c *coordinator) Summarize(ctx context.Context, sessionID string) error {
650	providerCfg, ok := c.cfg.Providers.Get(c.currentAgent.Model().ModelCfg.Provider)
651	if !ok {
652		return errors.New("model provider not configured")
653	}
654	return c.currentAgent.Summarize(ctx, sessionID, getProviderOptions(c.currentAgent.Model(), providerCfg.Type))
655}