coordinator.go

  1package agent
  2
  3import (
  4	"bytes"
  5	"cmp"
  6	"context"
  7	"encoding/json"
  8	"errors"
  9	"io"
 10	"log/slog"
 11	"slices"
 12	"strings"
 13
 14	"github.com/charmbracelet/catwalk/pkg/catwalk"
 15	"github.com/charmbracelet/crush/internal/agent/prompt"
 16	"github.com/charmbracelet/crush/internal/agent/tools"
 17	"github.com/charmbracelet/crush/internal/config"
 18	"github.com/charmbracelet/crush/internal/csync"
 19	"github.com/charmbracelet/crush/internal/history"
 20	"github.com/charmbracelet/crush/internal/log"
 21	"github.com/charmbracelet/crush/internal/lsp"
 22	"github.com/charmbracelet/crush/internal/message"
 23	"github.com/charmbracelet/crush/internal/permission"
 24	"github.com/charmbracelet/crush/internal/session"
 25	"github.com/charmbracelet/fantasy/ai"
 26	"github.com/charmbracelet/fantasy/anthropic"
 27	"github.com/charmbracelet/fantasy/azure"
 28	"github.com/charmbracelet/fantasy/google"
 29	"github.com/charmbracelet/fantasy/openai"
 30	"github.com/charmbracelet/fantasy/openaicompat"
 31	"github.com/charmbracelet/fantasy/openrouter"
 32	"github.com/qjebbs/go-jsons"
 33)
 34
 35type Coordinator interface {
 36	// INFO: (kujtim) this is not used yet we will use this when we have multiple agents
 37	// SetMainAgent(string)
 38	Run(ctx context.Context, sessionID, prompt string, attachments ...message.Attachment) (*ai.AgentResult, error)
 39	Cancel(sessionID string)
 40	CancelAll()
 41	IsSessionBusy(sessionID string) bool
 42	IsBusy() bool
 43	QueuedPrompts(sessionID string) int
 44	ClearQueue(sessionID string)
 45	Summarize(context.Context, string) error
 46	Model() Model
 47	UpdateModels() error
 48}
 49
 50type coordinator struct {
 51	cfg         *config.Config
 52	sessions    session.Service
 53	messages    message.Service
 54	permissions permission.Service
 55	history     history.Service
 56	lspClients  *csync.Map[string, *lsp.Client]
 57
 58	currentAgent SessionAgent
 59	agents       map[string]SessionAgent
 60}
 61
 62func NewCoordinator(
 63	cfg *config.Config,
 64	sessions session.Service,
 65	messages message.Service,
 66	permissions permission.Service,
 67	history history.Service,
 68	lspClients *csync.Map[string, *lsp.Client],
 69) (Coordinator, error) {
 70	c := &coordinator{
 71		cfg:         cfg,
 72		sessions:    sessions,
 73		messages:    messages,
 74		permissions: permissions,
 75		history:     history,
 76		lspClients:  lspClients,
 77		agents:      make(map[string]SessionAgent),
 78	}
 79
 80	agentCfg, ok := cfg.Agents[config.AgentCoder]
 81	if !ok {
 82		return nil, errors.New("coder agent not configured")
 83	}
 84
 85	// TODO: make this dynamic when we support multiple agents
 86	prompt, err := coderPrompt(prompt.WithWorkingDir(c.cfg.WorkingDir()))
 87	if err != nil {
 88		return nil, err
 89	}
 90
 91	agent, err := c.buildAgent(prompt, agentCfg)
 92	if err != nil {
 93		return nil, err
 94	}
 95	c.currentAgent = agent
 96	c.agents[config.AgentCoder] = agent
 97	return c, nil
 98}
 99
100// Run implements Coordinator.
101func (c *coordinator) Run(ctx context.Context, sessionID string, prompt string, attachments ...message.Attachment) (*ai.AgentResult, error) {
102	model := c.currentAgent.Model()
103	maxTokens := model.CatwalkCfg.DefaultMaxTokens
104	if model.ModelCfg.MaxTokens != 0 {
105		maxTokens = model.ModelCfg.MaxTokens
106	}
107
108	if !model.CatwalkCfg.SupportsImages && attachments != nil {
109		attachments = nil
110	}
111
112	providerCfg, ok := c.cfg.Providers.Get(model.ModelCfg.Provider)
113	if !ok {
114		return nil, errors.New("model provider not configured")
115	}
116
117	mergedOptions, temp, topP, topK, freqPenalty, presPenalty := mergeCallOptions(model, providerCfg.Type)
118
119	return c.currentAgent.Run(ctx, SessionAgentCall{
120		SessionID:        sessionID,
121		Prompt:           prompt,
122		Attachments:      attachments,
123		MaxOutputTokens:  maxTokens,
124		ProviderOptions:  mergedOptions,
125		Temperature:      temp,
126		TopP:             topP,
127		TopK:             topK,
128		FrequencyPenalty: freqPenalty,
129		PresencePenalty:  presPenalty,
130	})
131}
132
133func getProviderOptions(model Model, tp catwalk.Type) ai.ProviderOptions {
134	options := ai.ProviderOptions{}
135
136	cfgOpts := []byte("{}")
137	catwalkOpts := []byte("{}")
138
139	if model.ModelCfg.ProviderOptions != nil {
140		data, err := json.Marshal(model.ModelCfg.ProviderOptions)
141		if err == nil {
142			cfgOpts = data
143		}
144	}
145
146	if model.CatwalkCfg.Options.ProviderOptions != nil {
147		data, err := json.Marshal(model.CatwalkCfg.Options.ProviderOptions)
148		if err == nil {
149			catwalkOpts = data
150		}
151	}
152
153	readers := []io.Reader{
154		bytes.NewReader(catwalkOpts),
155		bytes.NewReader(cfgOpts),
156	}
157
158	got, err := jsons.Merge(readers)
159	if err != nil {
160		slog.Error("Could not merge call config", "err", err)
161		return options
162	}
163
164	mergedOptions := make(map[string]any)
165
166	err = json.Unmarshal([]byte(got), &mergedOptions)
167	if err != nil {
168		slog.Error("Could not create config for call", "err", err)
169		return options
170	}
171
172	switch tp {
173	case openai.Name:
174		_, hasReasoningEffort := mergedOptions["reasoning_effort"]
175		if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
176			mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort
177		}
178		parsed, err := openai.ParseOptions(mergedOptions)
179		if err == nil {
180			options[openai.Name] = parsed
181		}
182	case anthropic.Name:
183		_, hasThink := mergedOptions["thinking"]
184		if !hasThink && model.ModelCfg.Think {
185			mergedOptions["thinking"] = map[string]any{
186				// TODO: kujtim see if we need to make this dynamic
187				"budget_tokens": 2000,
188			}
189		}
190		parsed, err := anthropic.ParseOptions(mergedOptions)
191		if err == nil {
192			options[anthropic.Name] = parsed
193		}
194
195	case openrouter.Name:
196		_, hasReasoning := mergedOptions["reasoning"]
197		if !hasReasoning && model.ModelCfg.ReasoningEffort != "" {
198			mergedOptions["reasoning"] = map[string]any{
199				"enabled": true,
200				"effort":  model.ModelCfg.ReasoningEffort,
201			}
202		}
203		parsed, err := openrouter.ParseOptions(mergedOptions)
204		if err == nil {
205			options[openrouter.Name] = parsed
206		}
207	case google.Name:
208		parsed, err := google.ParseOptions(mergedOptions)
209		if err == nil {
210			options[google.Name] = parsed
211		}
212	case azure.Name:
213		_, hasReasoningEffort := mergedOptions["reasoning_effort"]
214		if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
215			mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort
216		}
217		// azure uses the same options as openaicompat
218		parsed, err := openaicompat.ParseOptions(mergedOptions)
219		if err == nil {
220			options[azure.Name] = parsed
221		}
222	case openaicompat.Name:
223		_, hasReasoningEffort := mergedOptions["reasoning_effort"]
224		if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
225			mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort
226		}
227		parsed, err := openaicompat.ParseOptions(mergedOptions)
228		if err == nil {
229			options[openaicompat.Name] = parsed
230		}
231	}
232
233	return options
234}
235
236func mergeCallOptions(model Model, tp catwalk.Type) (ai.ProviderOptions, *float64, *float64, *int64, *float64, *float64) {
237	modelOptions := getProviderOptions(model, tp)
238	temp := cmp.Or(model.ModelCfg.Temperature, model.CatwalkCfg.Options.Temperature)
239	topP := cmp.Or(model.ModelCfg.TopP, model.CatwalkCfg.Options.TopP)
240	topK := cmp.Or(model.ModelCfg.TopK, model.CatwalkCfg.Options.TopK)
241	freqPenalty := cmp.Or(model.ModelCfg.FrequencyPenalty, model.CatwalkCfg.Options.FrequencyPenalty)
242	presPenalty := cmp.Or(model.ModelCfg.PresencePenalty, model.CatwalkCfg.Options.PresencePenalty)
243	return modelOptions, temp, topP, topK, freqPenalty, presPenalty
244}
245
246func (c *coordinator) buildAgent(prompt *prompt.Prompt, agent config.Agent) (SessionAgent, error) {
247	large, small, err := c.buildAgentModels()
248	if err != nil {
249		return nil, err
250	}
251
252	systemPrompt, err := prompt.Build(large.Model.Provider(), large.Model.Model(), *c.cfg)
253	if err != nil {
254		return nil, err
255	}
256
257	tools, err := c.buildTools(agent)
258	if err != nil {
259		return nil, err
260	}
261	return NewSessionAgent(SessionAgentOptions{large, small, systemPrompt, c.cfg.Options.DisableAutoSummarize, c.sessions, c.messages, tools}), nil
262}
263
264func (c *coordinator) buildTools(agent config.Agent) ([]ai.AgentTool, error) {
265	var allTools []ai.AgentTool
266	if slices.Contains(agent.AllowedTools, AgentToolName) {
267		agentTool, err := c.agentTool()
268		if err != nil {
269			return nil, err
270		}
271		allTools = append(allTools, agentTool)
272	}
273
274	allTools = append(allTools,
275		tools.NewBashTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Options.Attribution),
276		tools.NewDownloadTool(c.permissions, c.cfg.WorkingDir(), nil),
277		tools.NewEditTool(c.lspClients, c.permissions, c.history, c.cfg.WorkingDir()),
278		tools.NewMultiEditTool(c.lspClients, c.permissions, c.history, c.cfg.WorkingDir()),
279		tools.NewFetchTool(c.permissions, c.cfg.WorkingDir(), nil),
280		tools.NewGlobTool(c.cfg.WorkingDir()),
281		tools.NewGrepTool(c.cfg.WorkingDir()),
282		tools.NewLsTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Tools.Ls),
283		tools.NewSourcegraphTool(nil),
284		tools.NewViewTool(c.lspClients, c.permissions, c.cfg.WorkingDir()),
285		tools.NewWriteTool(c.lspClients, c.permissions, c.history, c.cfg.WorkingDir()),
286	)
287
288	var filteredTools []ai.AgentTool
289	for _, tool := range allTools {
290		if slices.Contains(agent.AllowedTools, tool.Info().Name) {
291			filteredTools = append(filteredTools, tool)
292		}
293	}
294
295	mcpTools := tools.GetMCPTools(context.Background(), c.permissions, c.cfg)
296
297	for _, mcpTool := range mcpTools {
298		if agent.AllowedMCP == nil {
299			// No MCP restrictions
300			filteredTools = append(filteredTools, mcpTool)
301		} else if len(agent.AllowedMCP) == 0 {
302			// no mcps allowed
303			break
304		}
305
306		for mcp, tools := range agent.AllowedMCP {
307			if mcp == mcpTool.MCP() {
308				if len(tools) == 0 {
309					filteredTools = append(filteredTools, mcpTool)
310				}
311				for _, t := range tools {
312					if t == mcpTool.MCPToolName() {
313						filteredTools = append(filteredTools, mcpTool)
314					}
315				}
316				break
317			}
318		}
319	}
320
321	return filteredTools, nil
322}
323
324// TODO: when we support multiple agents we need to change this so that we pass in the agent specific model config
325func (c *coordinator) buildAgentModels() (Model, Model, error) {
326	largeModelCfg, ok := c.cfg.Models[config.SelectedModelTypeLarge]
327	if !ok {
328		return Model{}, Model{}, errors.New("large model not selected")
329	}
330	smallModelCfg, ok := c.cfg.Models[config.SelectedModelTypeSmall]
331	if !ok {
332		return Model{}, Model{}, errors.New("small model not selected")
333	}
334
335	largeProviderCfg, ok := c.cfg.Providers.Get(largeModelCfg.Provider)
336	if !ok {
337		return Model{}, Model{}, errors.New("large model provider not configured")
338	}
339
340	largeProvider, err := c.buildProvider(largeProviderCfg, largeModelCfg)
341	if err != nil {
342		return Model{}, Model{}, err
343	}
344
345	smallProviderCfg, ok := c.cfg.Providers.Get(smallModelCfg.Provider)
346	if !ok {
347		return Model{}, Model{}, errors.New("large model provider not configured")
348	}
349
350	smallProvider, err := c.buildProvider(smallProviderCfg, largeModelCfg)
351	if err != nil {
352		return Model{}, Model{}, err
353	}
354
355	var largeCatwalkModel *catwalk.Model
356	var smallCatwalkModel *catwalk.Model
357
358	for _, m := range largeProviderCfg.Models {
359		if m.ID == largeModelCfg.Model {
360			largeCatwalkModel = &m
361		}
362	}
363	for _, m := range smallProviderCfg.Models {
364		if m.ID == smallModelCfg.Model {
365			smallCatwalkModel = &m
366		}
367	}
368
369	if largeCatwalkModel == nil {
370		return Model{}, Model{}, errors.New("large model not found in provider config")
371	}
372
373	if smallCatwalkModel == nil {
374		return Model{}, Model{}, errors.New("snall model not found in provider config")
375	}
376
377	largeModel, err := largeProvider.LanguageModel(largeModelCfg.Model)
378	if err != nil {
379		return Model{}, Model{}, err
380	}
381	smallModel, err := smallProvider.LanguageModel(smallModelCfg.Model)
382	if err != nil {
383		return Model{}, Model{}, err
384	}
385
386	return Model{
387			Model:      largeModel,
388			CatwalkCfg: *largeCatwalkModel,
389			ModelCfg:   largeModelCfg,
390		}, Model{
391			Model:      smallModel,
392			CatwalkCfg: *smallCatwalkModel,
393			ModelCfg:   smallModelCfg,
394		}, nil
395}
396
397func (c *coordinator) buildAnthropicProvider(baseURL, apiKey string, headers map[string]string) ai.Provider {
398	hasBearerAuth := false
399	for key := range headers {
400		if strings.ToLower(key) == "authorization" {
401			hasBearerAuth = true
402			break
403		}
404	}
405	if hasBearerAuth {
406		apiKey = "" // clear apiKey to avoid using X-Api-Key header
407	}
408
409	var opts []anthropic.Option
410
411	if apiKey != "" {
412		// Use standard X-Api-Key header
413		opts = append(opts, anthropic.WithAPIKey(apiKey))
414	}
415
416	if len(headers) > 0 {
417		opts = append(opts, anthropic.WithHeaders(headers))
418	}
419
420	if baseURL != "" {
421		opts = append(opts, anthropic.WithBaseURL(baseURL))
422	}
423
424	if c.cfg.Options.Debug {
425		httpClient := log.NewHTTPClient()
426		opts = append(opts, anthropic.WithHTTPClient(httpClient))
427	}
428
429	return anthropic.New(opts...)
430}
431
432func (c *coordinator) buildOpenaiProvider(baseURL, apiKey string, headers map[string]string) ai.Provider {
433	opts := []openai.Option{
434		openai.WithAPIKey(apiKey),
435	}
436	if c.cfg.Options.Debug {
437		httpClient := log.NewHTTPClient()
438		opts = append(opts, openai.WithHTTPClient(httpClient))
439	}
440	if len(headers) > 0 {
441		opts = append(opts, openai.WithHeaders(headers))
442	}
443	if baseURL != "" {
444		opts = append(opts, openai.WithBaseURL(baseURL))
445	}
446	return openai.New(opts...)
447}
448
449func (c *coordinator) buildOpenrouterProvider(_, apiKey string, headers map[string]string) ai.Provider {
450	opts := []openrouter.Option{
451		openrouter.WithAPIKey(apiKey),
452	}
453	if c.cfg.Options.Debug {
454		httpClient := log.NewHTTPClient()
455		opts = append(opts, openrouter.WithHTTPClient(httpClient))
456	}
457	if len(headers) > 0 {
458		opts = append(opts, openrouter.WithHeaders(headers))
459	}
460	return openrouter.New(opts...)
461}
462
463func (c *coordinator) buildOpenaiCompatProvider(baseURL, apiKey string, headers map[string]string) ai.Provider {
464	opts := []openaicompat.Option{
465		openaicompat.WithBaseURL(baseURL),
466		openaicompat.WithAPIKey(apiKey),
467	}
468	if c.cfg.Options.Debug {
469		httpClient := log.NewHTTPClient()
470		opts = append(opts, openaicompat.WithHTTPClient(httpClient))
471	}
472	if len(headers) > 0 {
473		opts = append(opts, openaicompat.WithHeaders(headers))
474	}
475
476	return openaicompat.New(opts...)
477}
478
479func (c *coordinator) buildAzureProvider(baseURL, apiKey string, headers map[string]string, options map[string]string) ai.Provider {
480	opts := []azure.Option{
481		azure.WithBaseURL(baseURL),
482		azure.WithAPIKey(apiKey),
483	}
484	if c.cfg.Options.Debug {
485		httpClient := log.NewHTTPClient()
486		opts = append(opts, azure.WithHTTPClient(httpClient))
487	}
488	if options == nil {
489		options = make(map[string]string)
490	}
491	if apiVersion, ok := options["apiVersion"]; ok {
492		opts = append(opts, azure.WithAPIVersion(apiVersion))
493	}
494	if len(headers) > 0 {
495		opts = append(opts, azure.WithHeaders(headers))
496	}
497
498	return azure.New(opts...)
499}
500
501func (c *coordinator) buildGoogleProvider(baseURL, apiKey string, headers map[string]string) ai.Provider {
502	opts := []google.Option{
503		google.WithBaseURL(baseURL),
504		google.WithGeminiAPIKey(apiKey),
505	}
506	if c.cfg.Options.Debug {
507		httpClient := log.NewHTTPClient()
508		opts = append(opts, google.WithHTTPClient(httpClient))
509	}
510	if len(headers) > 0 {
511		opts = append(opts, google.WithHeaders(headers))
512	}
513	return google.New(opts...)
514}
515
516func (c *coordinator) buildGoogleVertexProvider(headers map[string]string, options map[string]string) ai.Provider {
517	opts := []google.Option{}
518	if c.cfg.Options.Debug {
519		httpClient := log.NewHTTPClient()
520		opts = append(opts, google.WithHTTPClient(httpClient))
521	}
522	if len(headers) > 0 {
523		opts = append(opts, google.WithHeaders(headers))
524	}
525
526	project := options["project"]
527	location := options["location"]
528
529	opts = append(opts, google.WithVertex(project, location))
530
531	return google.New(opts...)
532}
533
534func (c *coordinator) isAnthropicThinking(model config.SelectedModel) bool {
535	if model.Think {
536		return true
537	}
538
539	if model.ProviderOptions == nil {
540		return false
541	}
542
543	opts, err := anthropic.ParseOptions(model.ProviderOptions)
544	if err != nil {
545		return false
546	}
547	if opts.Thinking != nil {
548		return true
549	}
550	return false
551}
552
553func (c *coordinator) buildProvider(providerCfg config.ProviderConfig, model config.SelectedModel) (ai.Provider, error) {
554	headers := providerCfg.ExtraHeaders
555
556	// handle special headers for anthropic
557	if providerCfg.Type == anthropic.Name && c.isAnthropicThinking(model) {
558		headers["anthropic-beta"] = "interleaved-thinking-2025-05-14"
559	}
560
561	// TODO: make sure we have
562	apiKey, _ := c.cfg.Resolve(providerCfg.APIKey)
563	baseURL, _ := c.cfg.Resolve(providerCfg.BaseURL)
564	var provider ai.Provider
565	switch providerCfg.Type {
566	case openai.Name:
567		provider = c.buildOpenaiProvider(baseURL, apiKey, headers)
568	case anthropic.Name:
569		provider = c.buildAnthropicProvider(baseURL, apiKey, headers)
570	case openrouter.Name:
571		provider = c.buildOpenrouterProvider(baseURL, apiKey, headers)
572	case azure.Name:
573		provider = c.buildAzureProvider(baseURL, apiKey, headers, providerCfg.ExtraParams)
574	case google.Name:
575		provider = c.buildGoogleProvider(baseURL, apiKey, headers)
576	// this is not in fantasy since its just the google provider with extra stuff
577	case "google-vertex":
578		provider = c.buildGoogleVertexProvider(headers, providerCfg.ExtraParams)
579	case openaicompat.Name:
580		provider = c.buildOpenaiCompatProvider(baseURL, apiKey, headers)
581	default:
582		return nil, errors.New("provider type not supported")
583	}
584	return provider, nil
585}
586
587func (c *coordinator) Cancel(sessionID string) {
588	c.currentAgent.Cancel(sessionID)
589}
590
591func (c *coordinator) CancelAll() {
592	c.currentAgent.CancelAll()
593}
594
595func (c *coordinator) ClearQueue(sessionID string) {
596	c.currentAgent.ClearQueue(sessionID)
597}
598
599func (c *coordinator) IsBusy() bool {
600	return c.currentAgent.IsBusy()
601}
602
603func (c *coordinator) IsSessionBusy(sessionID string) bool {
604	return c.currentAgent.IsSessionBusy(sessionID)
605}
606
607func (c *coordinator) Model() Model {
608	return c.currentAgent.Model()
609}
610
611func (c *coordinator) UpdateModels() error {
612	// build the models again so we make sure we get the latest config
613	large, small, err := c.buildAgentModels()
614	if err != nil {
615		return err
616	}
617	c.currentAgent.SetModels(large, small)
618
619	agentCfg, ok := c.cfg.Agents[config.AgentCoder]
620	if !ok {
621		return errors.New("coder agent not configured")
622	}
623
624	tools, err := c.buildTools(agentCfg)
625	if err != nil {
626		return err
627	}
628	c.currentAgent.SetTools(tools)
629	return nil
630}
631
632func (c *coordinator) QueuedPrompts(sessionID string) int {
633	return c.currentAgent.QueuedPrompts(sessionID)
634}
635
636func (c *coordinator) Summarize(ctx context.Context, sessionID string) error {
637	return c.currentAgent.Summarize(ctx, sessionID)
638}