From 5ff8d6876005ba48686a58fa71417c3a3f17bebe Mon Sep 17 00:00:00 2001 From: Ayman Bagabas Date: Wed, 11 Mar 2026 20:12:02 -0400 Subject: [PATCH] =?UTF-8?q?refactor(config):=20introduce=20ConfigStore=20a?= =?UTF-8?q?nd=20Scope=20for=20better=20config=20m=E2=80=A6=20(#2395)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * refactor(config): introduce ConfigStore and Scope for better config management This makes config.Config immutable and introduces a ConfigStore that manages the config and provides helper methods for accessing config values with proper scoping (global, workspace). This allows us to avoid passing around mutable config objects and ensures that all parts of the code are accessing the most up-to-date config values. It also lays the groundwork for future features like per-workspace config overrides. * fixt: lint --- internal/agent/agent_tool.go | 2 +- internal/agent/agentic_fetch_tool.go | 10 +- internal/agent/common_test.go | 18 +- internal/agent/coordinator.go | 70 ++--- internal/agent/coordinator_test.go | 2 +- internal/agent/prompt/prompt.go | 29 +- internal/agent/prompts.go | 2 +- internal/agent/tools/list_mcp_resources.go | 2 +- internal/agent/tools/mcp-tools.go | 4 +- internal/agent/tools/mcp/init.go | 8 +- internal/agent/tools/mcp/prompts.go | 2 +- internal/agent/tools/mcp/resources.go | 4 +- internal/agent/tools/mcp/tools.go | 10 +- internal/agent/tools/read_mcp_resource.go | 2 +- internal/app/app.go | 34 ++- internal/cmd/login.go | 18 +- internal/cmd/logs.go | 2 +- internal/cmd/models.go | 4 +- internal/cmd/root.go | 5 +- internal/cmd/stats.go | 2 +- internal/commands/commands.go | 2 +- internal/config/config.go | 248 --------------- internal/config/copilot.go | 47 --- internal/config/init.go | 29 +- internal/config/load.go | 62 ++-- internal/config/load_test.go | 84 +++--- internal/config/recent_models_test.go | 64 ++-- internal/config/scope.go | 11 + internal/config/store.go | 336 +++++++++++++++++++++ internal/lsp/manager.go | 12 +- internal/ui/common/common.go | 7 +- internal/ui/dialog/api_key_input.go | 6 +- internal/ui/dialog/filepicker.go | 2 +- internal/ui/dialog/models.go | 2 +- internal/ui/dialog/oauth.go | 4 +- internal/ui/model/header.go | 3 +- internal/ui/model/landing.go | 2 +- internal/ui/model/onboarding.go | 11 +- internal/ui/model/sidebar.go | 4 +- internal/ui/model/ui.go | 26 +- 40 files changed, 648 insertions(+), 544 deletions(-) create mode 100644 internal/config/scope.go create mode 100644 internal/config/store.go diff --git a/internal/agent/agent_tool.go b/internal/agent/agent_tool.go index 1a7286e342d245c7e7ac1161111d8c205300018b..0d7677dee702b813e0a0d6f02e67837f084d5c29 100644 --- a/internal/agent/agent_tool.go +++ b/internal/agent/agent_tool.go @@ -24,7 +24,7 @@ const ( ) func (c *coordinator) agentTool(ctx context.Context) (fantasy.AgentTool, error) { - agentCfg, ok := c.cfg.Agents[config.AgentTask] + agentCfg, ok := c.cfg.Config().Agents[config.AgentTask] if !ok { return nil, errors.New("task agent not configured") } diff --git a/internal/agent/agentic_fetch_tool.go b/internal/agent/agentic_fetch_tool.go index 0bd942e013b706389fb90352c891a4f2ea014f30..ffbe0f49e45c259db3f0bba9f07fda771ad3ecd4 100644 --- a/internal/agent/agentic_fetch_tool.go +++ b/internal/agent/agentic_fetch_tool.go @@ -98,7 +98,7 @@ func (c *coordinator) agenticFetchTool(_ context.Context, client *http.Client) ( return fantasy.ToolResponse{}, permission.ErrorPermissionDenied } - tmpDir, err := os.MkdirTemp(c.cfg.Options.DataDirectory, "crush-fetch-*") + tmpDir, err := os.MkdirTemp(c.cfg.Config().Options.DataDirectory, "crush-fetch-*") if err != nil { return fantasy.NewTextErrorResponse(fmt.Sprintf("Failed to create temporary directory: %s", err)), nil } @@ -151,12 +151,12 @@ func (c *coordinator) agenticFetchTool(_ context.Context, client *http.Client) ( return fantasy.ToolResponse{}, fmt.Errorf("error building models: %s", err) } - systemPrompt, err := promptTemplate.Build(ctx, small.Model.Provider(), small.Model.Model(), *c.cfg) + systemPrompt, err := promptTemplate.Build(ctx, small.Model.Provider(), small.Model.Model(), c.cfg) if err != nil { return fantasy.ToolResponse{}, fmt.Errorf("error building system prompt: %s", err) } - smallProviderCfg, ok := c.cfg.Providers.Get(small.ModelCfg.Provider) + smallProviderCfg, ok := c.cfg.Config().Providers.Get(small.ModelCfg.Provider) if !ok { return fantasy.ToolResponse{}, errors.New("small model provider not configured") } @@ -167,7 +167,7 @@ func (c *coordinator) agenticFetchTool(_ context.Context, client *http.Client) ( webFetchTool, webSearchTool, tools.NewGlobTool(tmpDir), - tools.NewGrepTool(tmpDir, c.cfg.Tools.Grep), + tools.NewGrepTool(tmpDir, c.cfg.Config().Tools.Grep), tools.NewSourcegraphTool(client), tools.NewViewTool(c.lspManager, c.permissions, c.filetracker, tmpDir), } @@ -177,7 +177,7 @@ func (c *coordinator) agenticFetchTool(_ context.Context, client *http.Client) ( SmallModel: small, SystemPromptPrefix: smallProviderCfg.SystemPromptPrefix, SystemPrompt: systemPrompt, - DisableAutoSummarize: c.cfg.Options.DisableAutoSummarize, + DisableAutoSummarize: c.cfg.Config().Options.DisableAutoSummarize, IsYolo: c.permissions.SkipRequests(), Sessions: c.sessions, Messages: c.messages, diff --git a/internal/agent/common_test.go b/internal/agent/common_test.go index 89fc6ff3d29d27c60a8091f17ebe0fad057dc44a..132c27d21aee81bd3930c469963f1d73885d58a7 100644 --- a/internal/agent/common_test.go +++ b/internal/agent/common_test.go @@ -185,36 +185,36 @@ func coderAgent(r *vcr.Recorder, env fakeEnv, large, small fantasy.LanguageModel // NOTE(@andreynering): Set a fixed config to ensure cassettes match // independently of user config on `$HOME/.config/crush/crush.json`. - cfg.Options.Attribution = &config.Attribution{ + cfg.Config().Options.Attribution = &config.Attribution{ TrailerStyle: "co-authored-by", GeneratedWith: true, } // Clear some fields to avoid issues with VCR cassette matching. - cfg.Options.SkillsPaths = nil - cfg.Options.ContextPaths = nil - cfg.LSP = nil + cfg.Config().Options.SkillsPaths = nil + cfg.Config().Options.ContextPaths = nil + cfg.Config().LSP = nil - systemPrompt, err := prompt.Build(context.TODO(), large.Provider(), large.Model(), *cfg) + systemPrompt, err := prompt.Build(context.TODO(), large.Provider(), large.Model(), cfg) if err != nil { return nil, err } // Get the model name for the bash tool modelName := large.Model() // fallback to ID if Name not available - if model := cfg.GetModel(large.Provider(), large.Model()); model != nil { + if model := cfg.Config().GetModel(large.Provider(), large.Model()); model != nil { modelName = model.Name } allTools := []fantasy.AgentTool{ - tools.NewBashTool(env.permissions, env.workingDir, cfg.Options.Attribution, modelName), + tools.NewBashTool(env.permissions, env.workingDir, cfg.Config().Options.Attribution, modelName), tools.NewDownloadTool(env.permissions, env.workingDir, r.GetDefaultClient()), tools.NewEditTool(nil, env.permissions, env.history, *env.filetracker, env.workingDir), tools.NewMultiEditTool(nil, env.permissions, env.history, *env.filetracker, env.workingDir), tools.NewFetchTool(env.permissions, env.workingDir, r.GetDefaultClient()), tools.NewGlobTool(env.workingDir), - tools.NewGrepTool(env.workingDir, cfg.Tools.Grep), - tools.NewLsTool(env.permissions, env.workingDir, cfg.Tools.Ls), + tools.NewGrepTool(env.workingDir, cfg.Config().Tools.Grep), + tools.NewLsTool(env.permissions, env.workingDir, cfg.Config().Tools.Ls), tools.NewSourcegraphTool(r.GetDefaultClient()), tools.NewViewTool(nil, env.permissions, *env.filetracker, env.workingDir), tools.NewWriteTool(nil, env.permissions, env.history, *env.filetracker, env.workingDir), diff --git a/internal/agent/coordinator.go b/internal/agent/coordinator.go index 3968952ae4e10bd59e596d02797a845d943bd378..4bca96d5946630423fc532ac6ccb5be833638dd2 100644 --- a/internal/agent/coordinator.go +++ b/internal/agent/coordinator.go @@ -74,7 +74,7 @@ type Coordinator interface { } type coordinator struct { - cfg *config.Config + cfg *config.ConfigStore sessions session.Service messages message.Service permissions permission.Service @@ -91,7 +91,7 @@ type coordinator struct { func NewCoordinator( ctx context.Context, - cfg *config.Config, + cfg *config.ConfigStore, sessions session.Service, messages message.Service, permissions permission.Service, @@ -112,7 +112,7 @@ func NewCoordinator( agents: make(map[string]SessionAgent), } - agentCfg, ok := cfg.Agents[config.AgentCoder] + agentCfg, ok := cfg.Config().Agents[config.AgentCoder] if !ok { return nil, errCoderAgentNotConfigured } @@ -160,7 +160,7 @@ func (c *coordinator) Run(ctx context.Context, sessionID string, prompt string, attachments = filteredAttachments } - providerCfg, ok := c.cfg.Providers.Get(model.ModelCfg.Provider) + providerCfg, ok := c.cfg.Config().Providers.Get(model.ModelCfg.Provider) if !ok { return nil, errModelProviderNotConfigured } @@ -383,14 +383,14 @@ func (c *coordinator) buildAgent(ctx context.Context, prompt *prompt.Prompt, age return nil, err } - largeProviderCfg, _ := c.cfg.Providers.Get(large.ModelCfg.Provider) + largeProviderCfg, _ := c.cfg.Config().Providers.Get(large.ModelCfg.Provider) result := NewSessionAgent(SessionAgentOptions{ LargeModel: large, SmallModel: small, SystemPromptPrefix: largeProviderCfg.SystemPromptPrefix, SystemPrompt: "", IsSubAgent: isSubAgent, - DisableAutoSummarize: c.cfg.Options.DisableAutoSummarize, + DisableAutoSummarize: c.cfg.Config().Options.DisableAutoSummarize, IsYolo: c.permissions.SkipRequests(), Sessions: c.sessions, Messages: c.messages, @@ -399,7 +399,7 @@ func (c *coordinator) buildAgent(ctx context.Context, prompt *prompt.Prompt, age }) c.readyWg.Go(func() error { - systemPrompt, err := prompt.Build(ctx, large.Model.Provider(), large.Model.Model(), *c.cfg) + systemPrompt, err := prompt.Build(ctx, large.Model.Provider(), large.Model.Model(), c.cfg) if err != nil { return err } @@ -439,14 +439,14 @@ func (c *coordinator) buildTools(ctx context.Context, agent config.Agent) ([]fan // Get the model name for the agent modelName := "" - if modelCfg, ok := c.cfg.Models[agent.Model]; ok { - if model := c.cfg.GetModel(modelCfg.Provider, modelCfg.Model); model != nil { + if modelCfg, ok := c.cfg.Config().Models[agent.Model]; ok { + if model := c.cfg.Config().GetModel(modelCfg.Provider, modelCfg.Model); model != nil { modelName = model.Name } } allTools = append(allTools, - tools.NewBashTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Options.Attribution, modelName), + tools.NewBashTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Config().Options.Attribution, modelName), tools.NewJobOutputTool(), tools.NewJobKillTool(), tools.NewDownloadTool(c.permissions, c.cfg.WorkingDir(), nil), @@ -454,20 +454,20 @@ func (c *coordinator) buildTools(ctx context.Context, agent config.Agent) ([]fan tools.NewMultiEditTool(c.lspManager, c.permissions, c.history, c.filetracker, c.cfg.WorkingDir()), tools.NewFetchTool(c.permissions, c.cfg.WorkingDir(), nil), tools.NewGlobTool(c.cfg.WorkingDir()), - tools.NewGrepTool(c.cfg.WorkingDir(), c.cfg.Tools.Grep), - tools.NewLsTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Tools.Ls), + tools.NewGrepTool(c.cfg.WorkingDir(), c.cfg.Config().Tools.Grep), + tools.NewLsTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Config().Tools.Ls), tools.NewSourcegraphTool(nil), tools.NewTodosTool(c.sessions), - tools.NewViewTool(c.lspManager, c.permissions, c.filetracker, c.cfg.WorkingDir(), c.cfg.Options.SkillsPaths...), + tools.NewViewTool(c.lspManager, c.permissions, c.filetracker, c.cfg.WorkingDir(), c.cfg.Config().Options.SkillsPaths...), tools.NewWriteTool(c.lspManager, c.permissions, c.history, c.filetracker, c.cfg.WorkingDir()), ) // Add LSP tools if user has configured LSPs or auto_lsp is enabled (nil or true). - if len(c.cfg.LSP) > 0 || c.cfg.Options.AutoLSP == nil || *c.cfg.Options.AutoLSP { + if len(c.cfg.Config().LSP) > 0 || c.cfg.Config().Options.AutoLSP == nil || *c.cfg.Config().Options.AutoLSP { allTools = append(allTools, tools.NewDiagnosticsTool(c.lspManager), tools.NewReferencesTool(c.lspManager), tools.NewLSPRestartTool(c.lspManager)) } - if len(c.cfg.MCP) > 0 { + if len(c.cfg.Config().MCP) > 0 { allTools = append( allTools, tools.NewListMCPResourcesTool(c.cfg, c.permissions), @@ -513,16 +513,16 @@ func (c *coordinator) buildTools(ctx context.Context, agent config.Agent) ([]fan // TODO: when we support multiple agents we need to change this so that we pass in the agent specific model config func (c *coordinator) buildAgentModels(ctx context.Context, isSubAgent bool) (Model, Model, error) { - largeModelCfg, ok := c.cfg.Models[config.SelectedModelTypeLarge] + largeModelCfg, ok := c.cfg.Config().Models[config.SelectedModelTypeLarge] if !ok { return Model{}, Model{}, errLargeModelNotSelected } - smallModelCfg, ok := c.cfg.Models[config.SelectedModelTypeSmall] + smallModelCfg, ok := c.cfg.Config().Models[config.SelectedModelTypeSmall] if !ok { return Model{}, Model{}, errSmallModelNotSelected } - largeProviderCfg, ok := c.cfg.Providers.Get(largeModelCfg.Provider) + largeProviderCfg, ok := c.cfg.Config().Providers.Get(largeModelCfg.Provider) if !ok { return Model{}, Model{}, errLargeModelProviderNotConfigured } @@ -532,7 +532,7 @@ func (c *coordinator) buildAgentModels(ctx context.Context, isSubAgent bool) (Mo return Model{}, Model{}, err } - smallProviderCfg, ok := c.cfg.Providers.Get(smallModelCfg.Provider) + smallProviderCfg, ok := c.cfg.Config().Providers.Get(smallModelCfg.Provider) if !ok { return Model{}, Model{}, errSmallModelProviderNotConfigured } @@ -620,7 +620,7 @@ func (c *coordinator) buildAnthropicProvider(baseURL, apiKey string, headers map opts = append(opts, anthropic.WithBaseURL(baseURL)) } - if c.cfg.Options.Debug { + if c.cfg.Config().Options.Debug { httpClient := log.NewHTTPClient() opts = append(opts, anthropic.WithHTTPClient(httpClient)) } @@ -632,7 +632,7 @@ func (c *coordinator) buildOpenaiProvider(baseURL, apiKey string, headers map[st openai.WithAPIKey(apiKey), openai.WithUseResponsesAPI(), } - if c.cfg.Options.Debug { + if c.cfg.Config().Options.Debug { httpClient := log.NewHTTPClient() opts = append(opts, openai.WithHTTPClient(httpClient)) } @@ -649,7 +649,7 @@ func (c *coordinator) buildOpenrouterProvider(_, apiKey string, headers map[stri opts := []openrouter.Option{ openrouter.WithAPIKey(apiKey), } - if c.cfg.Options.Debug { + if c.cfg.Config().Options.Debug { httpClient := log.NewHTTPClient() opts = append(opts, openrouter.WithHTTPClient(httpClient)) } @@ -663,7 +663,7 @@ func (c *coordinator) buildVercelProvider(_, apiKey string, headers map[string]s opts := []vercel.Option{ vercel.WithAPIKey(apiKey), } - if c.cfg.Options.Debug { + if c.cfg.Config().Options.Debug { httpClient := log.NewHTTPClient() opts = append(opts, vercel.WithHTTPClient(httpClient)) } @@ -683,8 +683,8 @@ func (c *coordinator) buildOpenaiCompatProvider(baseURL, apiKey string, headers var httpClient *http.Client if providerID == string(catwalk.InferenceProviderCopilot) { opts = append(opts, openaicompat.WithUseResponsesAPI()) - httpClient = copilot.NewClient(isSubAgent, c.cfg.Options.Debug) - } else if c.cfg.Options.Debug { + httpClient = copilot.NewClient(isSubAgent, c.cfg.Config().Options.Debug) + } else if c.cfg.Config().Options.Debug { httpClient = log.NewHTTPClient() } if httpClient != nil { @@ -708,7 +708,7 @@ func (c *coordinator) buildAzureProvider(baseURL, apiKey string, headers map[str azure.WithAPIKey(apiKey), azure.WithUseResponsesAPI(), } - if c.cfg.Options.Debug { + if c.cfg.Config().Options.Debug { httpClient := log.NewHTTPClient() opts = append(opts, azure.WithHTTPClient(httpClient)) } @@ -727,7 +727,7 @@ func (c *coordinator) buildAzureProvider(baseURL, apiKey string, headers map[str func (c *coordinator) buildBedrockProvider(headers map[string]string) (fantasy.Provider, error) { var opts []bedrock.Option - if c.cfg.Options.Debug { + if c.cfg.Config().Options.Debug { httpClient := log.NewHTTPClient() opts = append(opts, bedrock.WithHTTPClient(httpClient)) } @@ -746,7 +746,7 @@ func (c *coordinator) buildGoogleProvider(baseURL, apiKey string, headers map[st google.WithBaseURL(baseURL), google.WithGeminiAPIKey(apiKey), } - if c.cfg.Options.Debug { + if c.cfg.Config().Options.Debug { httpClient := log.NewHTTPClient() opts = append(opts, google.WithHTTPClient(httpClient)) } @@ -758,7 +758,7 @@ func (c *coordinator) buildGoogleProvider(baseURL, apiKey string, headers map[st func (c *coordinator) buildGoogleVertexProvider(headers map[string]string, options map[string]string) (fantasy.Provider, error) { opts := []google.Option{} - if c.cfg.Options.Debug { + if c.cfg.Config().Options.Debug { httpClient := log.NewHTTPClient() opts = append(opts, google.WithHTTPClient(httpClient)) } @@ -779,7 +779,7 @@ func (c *coordinator) buildHyperProvider(baseURL, apiKey string) (fantasy.Provid hyper.WithBaseURL(baseURL), hyper.WithAPIKey(apiKey), } - if c.cfg.Options.Debug { + if c.cfg.Config().Options.Debug { httpClient := log.NewHTTPClient() opts = append(opts, hyper.WithHTTPClient(httpClient)) } @@ -887,7 +887,7 @@ func (c *coordinator) UpdateModels(ctx context.Context) error { } c.currentAgent.SetModels(large, small) - agentCfg, ok := c.cfg.Agents[config.AgentCoder] + agentCfg, ok := c.cfg.Config().Agents[config.AgentCoder] if !ok { return errCoderAgentNotConfigured } @@ -909,7 +909,7 @@ func (c *coordinator) QueuedPromptsList(sessionID string) []string { } func (c *coordinator) Summarize(ctx context.Context, sessionID string) error { - providerCfg, ok := c.cfg.Providers.Get(c.currentAgent.Model().ModelCfg.Provider) + providerCfg, ok := c.cfg.Config().Providers.Get(c.currentAgent.Model().ModelCfg.Provider) if !ok { return errModelProviderNotConfigured } @@ -922,7 +922,7 @@ func (c *coordinator) isUnauthorized(err error) bool { } func (c *coordinator) refreshOAuth2Token(ctx context.Context, providerCfg config.ProviderConfig) error { - if err := c.cfg.RefreshOAuthToken(ctx, providerCfg.ID); err != nil { + if err := c.cfg.RefreshOAuthToken(ctx, config.ScopeGlobal, providerCfg.ID); err != nil { slog.Error("Failed to refresh OAuth token after 401 error", "provider", providerCfg.ID, "error", err) return err } @@ -940,7 +940,7 @@ func (c *coordinator) refreshApiKeyTemplate(ctx context.Context, providerCfg con } providerCfg.APIKey = newAPIKey - c.cfg.Providers.Set(providerCfg.ID, providerCfg) + c.cfg.Config().Providers.Set(providerCfg.ID, providerCfg) if err := c.UpdateModels(ctx); err != nil { return err @@ -984,7 +984,7 @@ func (c *coordinator) runSubAgent(ctx context.Context, params subAgentParams) (f maxTokens = model.ModelCfg.MaxTokens } - providerCfg, ok := c.cfg.Providers.Get(model.ModelCfg.Provider) + providerCfg, ok := c.cfg.Config().Providers.Get(model.ModelCfg.Provider) if !ok { return fantasy.ToolResponse{}, errModelProviderNotConfigured } diff --git a/internal/agent/coordinator_test.go b/internal/agent/coordinator_test.go index 3c270394cba9c1758e4a9029a149027af6bf36c2..657575b6458d7fb815c7a9646a9d605c8b89ec42 100644 --- a/internal/agent/coordinator_test.go +++ b/internal/agent/coordinator_test.go @@ -44,7 +44,7 @@ func (m *mockSessionAgent) Summarize(context.Context, string, fantasy.ProviderOp func newTestCoordinator(t *testing.T, env fakeEnv, providerID string, providerCfg config.ProviderConfig) *coordinator { cfg, err := config.Init(env.workingDir, "", false) require.NoError(t, err) - cfg.Providers.Set(providerID, providerCfg) + cfg.Config().Providers.Set(providerID, providerCfg) return &coordinator{ cfg: cfg, sessions: env.sessions, diff --git a/internal/agent/prompt/prompt.go b/internal/agent/prompt/prompt.go index d68c7c132116c49cd004bee52169be7487133efa..c8f488319f04238c476aae4719728fb94521695e 100644 --- a/internal/agent/prompt/prompt.go +++ b/internal/agent/prompt/prompt.go @@ -76,13 +76,13 @@ func NewPrompt(name, promptTemplate string, opts ...Option) (*Prompt, error) { return p, nil } -func (p *Prompt) Build(ctx context.Context, provider, model string, cfg config.Config) (string, error) { +func (p *Prompt) Build(ctx context.Context, provider, model string, store *config.ConfigStore) (string, error) { t, err := template.New(p.name).Parse(p.template) if err != nil { return "", fmt.Errorf("parsing template: %w", err) } var sb strings.Builder - d, err := p.promptData(ctx, provider, model, cfg) + d, err := p.promptData(ctx, provider, model, store) if err != nil { return "", err } @@ -104,11 +104,11 @@ func processFile(filePath string) *ContextFile { } } -func processContextPath(p string, cfg config.Config) []ContextFile { +func processContextPath(p string, store *config.ConfigStore) []ContextFile { var contexts []ContextFile fullPath := p if !filepath.IsAbs(p) { - fullPath = filepath.Join(cfg.WorkingDir(), p) + fullPath = filepath.Join(store.WorkingDir(), p) } info, err := os.Stat(fullPath) if err != nil { @@ -136,11 +136,11 @@ func processContextPath(p string, cfg config.Config) []ContextFile { } // expandPath expands ~ and environment variables in file paths -func expandPath(path string, cfg config.Config) string { +func expandPath(path string, store *config.ConfigStore) string { path = home.Long(path) // Handle environment variable expansion using the same pattern as config if strings.HasPrefix(path, "$") { - if expanded, err := cfg.Resolver().ResolveValue(path); err == nil { + if expanded, err := store.Resolver().ResolveValue(path); err == nil { path = expanded } } @@ -148,19 +148,20 @@ func expandPath(path string, cfg config.Config) string { return path } -func (p *Prompt) promptData(ctx context.Context, provider, model string, cfg config.Config) (PromptDat, error) { - workingDir := cmp.Or(p.workingDir, cfg.WorkingDir()) +func (p *Prompt) promptData(ctx context.Context, provider, model string, store *config.ConfigStore) (PromptDat, error) { + workingDir := cmp.Or(p.workingDir, store.WorkingDir()) platform := cmp.Or(p.platform, runtime.GOOS) files := map[string][]ContextFile{} + cfg := store.Config() for _, pth := range cfg.Options.ContextPaths { - expanded := expandPath(pth, cfg) + expanded := expandPath(pth, store) pathKey := strings.ToLower(expanded) if _, ok := files[pathKey]; ok { continue } - content := processContextPath(expanded, cfg) + content := processContextPath(expanded, store) files[pathKey] = content } @@ -169,18 +170,18 @@ func (p *Prompt) promptData(ctx context.Context, provider, model string, cfg con if len(cfg.Options.SkillsPaths) > 0 { expandedPaths := make([]string, 0, len(cfg.Options.SkillsPaths)) for _, pth := range cfg.Options.SkillsPaths { - expandedPaths = append(expandedPaths, expandPath(pth, cfg)) + expandedPaths = append(expandedPaths, expandPath(pth, store)) } if discoveredSkills := skills.Discover(expandedPaths); len(discoveredSkills) > 0 { availSkillXML = skills.ToPromptXML(discoveredSkills) } } - isGit := isGitRepo(cfg.WorkingDir()) + isGit := isGitRepo(store.WorkingDir()) data := PromptDat{ Provider: provider, Model: model, - Config: cfg, + Config: *cfg, WorkingDir: filepath.ToSlash(workingDir), IsGitRepo: isGit, Platform: platform, @@ -189,7 +190,7 @@ func (p *Prompt) promptData(ctx context.Context, provider, model string, cfg con } if isGit { var err error - data.GitStatus, err = getGitStatus(ctx, cfg.WorkingDir()) + data.GitStatus, err = getGitStatus(ctx, store.WorkingDir()) if err != nil { return PromptDat{}, err } diff --git a/internal/agent/prompts.go b/internal/agent/prompts.go index 577d32e4e274d9cb8274bd862af583208a613f08..448fe0425c3b700b1d6edafc842c4815ad3d5760 100644 --- a/internal/agent/prompts.go +++ b/internal/agent/prompts.go @@ -33,7 +33,7 @@ func taskPrompt(opts ...prompt.Option) (*prompt.Prompt, error) { return systemPrompt, nil } -func InitializePrompt(cfg config.Config) (string, error) { +func InitializePrompt(cfg *config.ConfigStore) (string, error) { systemPrompt, err := prompt.NewPrompt("initialize", string(initializePromptTmpl)) if err != nil { return "", err diff --git a/internal/agent/tools/list_mcp_resources.go b/internal/agent/tools/list_mcp_resources.go index 032d1eb1888a65e9a14daecc3b503698a6fa60d4..7ea8998a1dc80955b2a5b0a79d4aef7d19fb9011 100644 --- a/internal/agent/tools/list_mcp_resources.go +++ b/internal/agent/tools/list_mcp_resources.go @@ -28,7 +28,7 @@ const ListMCPResourcesToolName = "list_mcp_resources" //go:embed list_mcp_resources.md var listMCPResourcesDescription []byte -func NewListMCPResourcesTool(cfg *config.Config, permissions permission.Service) fantasy.AgentTool { +func NewListMCPResourcesTool(cfg *config.ConfigStore, permissions permission.Service) fantasy.AgentTool { return fantasy.NewParallelAgentTool( ListMCPResourcesToolName, string(listMCPResourcesDescription), diff --git a/internal/agent/tools/mcp-tools.go b/internal/agent/tools/mcp-tools.go index 429cadaf6b686b83e170ef35976881d839b07e17..e1184118552ee62e75f60c6943f59ecca2868563 100644 --- a/internal/agent/tools/mcp-tools.go +++ b/internal/agent/tools/mcp-tools.go @@ -11,7 +11,7 @@ import ( ) // GetMCPTools gets all the currently available MCP tools. -func GetMCPTools(permissions permission.Service, cfg *config.Config, wd string) []*Tool { +func GetMCPTools(permissions permission.Service, cfg *config.ConfigStore, wd string) []*Tool { var result []*Tool for mcpName, tools := range mcp.Tools() { for _, tool := range tools { @@ -31,7 +31,7 @@ func GetMCPTools(permissions permission.Service, cfg *config.Config, wd string) type Tool struct { mcpName string tool *mcp.Tool - cfg *config.Config + cfg *config.ConfigStore permissions permission.Service workingDir string providerOptions fantasy.ProviderOptions diff --git a/internal/agent/tools/mcp/init.go b/internal/agent/tools/mcp/init.go index f8cfe0ce84bf7b1987496607d42753b8ca72263f..cba9a51c717b1866b823762f85bfadf90e1a7a10 100644 --- a/internal/agent/tools/mcp/init.go +++ b/internal/agent/tools/mcp/init.go @@ -163,11 +163,11 @@ func Close(ctx context.Context) error { } // Initialize initializes MCP clients based on the provided configuration. -func Initialize(ctx context.Context, permissions permission.Service, cfg *config.Config) { +func Initialize(ctx context.Context, permissions permission.Service, cfg *config.ConfigStore) { slog.Info("Initializing MCP clients") var wg sync.WaitGroup // Initialize states for all configured MCPs - for name, m := range cfg.MCP { + for name, m := range cfg.Config().MCP { if m.Disabled { updateState(name, StateDisabled, nil, nil, Counts{}) slog.Debug("Skipping disabled MCP", "name", name) @@ -253,13 +253,13 @@ func WaitForInit(ctx context.Context) error { } } -func getOrRenewClient(ctx context.Context, cfg *config.Config, name string) (*ClientSession, error) { +func getOrRenewClient(ctx context.Context, cfg *config.ConfigStore, name string) (*ClientSession, error) { sess, ok := sessions.Get(name) if !ok { return nil, fmt.Errorf("mcp '%s' not available", name) } - m := cfg.MCP[name] + m := cfg.Config().MCP[name] state, _ := states.Get(name) timeout := mcpTimeout(m) diff --git a/internal/agent/tools/mcp/prompts.go b/internal/agent/tools/mcp/prompts.go index 2b39d5dc2db43aff418c3dd7561edbcebd6af865..d84be303ecb103d4fdd37423b7b6d088374d2c70 100644 --- a/internal/agent/tools/mcp/prompts.go +++ b/internal/agent/tools/mcp/prompts.go @@ -20,7 +20,7 @@ func Prompts() iter.Seq2[string, []*Prompt] { } // GetPromptMessages retrieves the content of an MCP prompt with the given arguments. -func GetPromptMessages(ctx context.Context, cfg *config.Config, clientName, promptName string, args map[string]string) ([]string, error) { +func GetPromptMessages(ctx context.Context, cfg *config.ConfigStore, clientName, promptName string, args map[string]string) ([]string, error) { c, err := getOrRenewClient(ctx, cfg, clientName) if err != nil { return nil, err diff --git a/internal/agent/tools/mcp/resources.go b/internal/agent/tools/mcp/resources.go index 8e2bcc796b28c698481dd90b0c70511273f7c98d..21616761e81212960f4d6ad59da1505049abbffb 100644 --- a/internal/agent/tools/mcp/resources.go +++ b/internal/agent/tools/mcp/resources.go @@ -24,7 +24,7 @@ func Resources() iter.Seq2[string, []*Resource] { } // ListResources returns the current resources for an MCP server. -func ListResources(ctx context.Context, cfg *config.Config, name string) ([]*Resource, error) { +func ListResources(ctx context.Context, cfg *config.ConfigStore, name string) ([]*Resource, error) { session, err := getOrRenewClient(ctx, cfg, name) if err != nil { return nil, err @@ -43,7 +43,7 @@ func ListResources(ctx context.Context, cfg *config.Config, name string) ([]*Res } // ReadResource reads the contents of a resource from an MCP server. -func ReadResource(ctx context.Context, cfg *config.Config, name, uri string) ([]*ResourceContents, error) { +func ReadResource(ctx context.Context, cfg *config.ConfigStore, name, uri string) ([]*ResourceContents, error) { session, err := getOrRenewClient(ctx, cfg, name) if err != nil { return nil, err diff --git a/internal/agent/tools/mcp/tools.go b/internal/agent/tools/mcp/tools.go index b6e208f7ccb3363bee0a0b60ef56c103ad9cd41b..8d1d2649ba4381e14fa8d99933f1dfb3b42d27ae 100644 --- a/internal/agent/tools/mcp/tools.go +++ b/internal/agent/tools/mcp/tools.go @@ -32,7 +32,7 @@ func Tools() iter.Seq2[string, []*Tool] { } // RunTool runs an MCP tool with the given input parameters. -func RunTool(ctx context.Context, cfg *config.Config, name, toolName string, input string) (ToolResult, error) { +func RunTool(ctx context.Context, cfg *config.ConfigStore, name, toolName string, input string) (ToolResult, error) { var args map[string]any if err := json.Unmarshal([]byte(input), &args); err != nil { return ToolResult{}, fmt.Errorf("error parsing parameters: %s", err) @@ -108,7 +108,7 @@ func RunTool(ctx context.Context, cfg *config.Config, name, toolName string, inp // RefreshTools gets the updated list of tools from the MCP and updates the // global state. -func RefreshTools(ctx context.Context, cfg *config.Config, name string) { +func RefreshTools(ctx context.Context, cfg *config.ConfigStore, name string) { session, ok := sessions.Get(name) if !ok { slog.Warn("Refresh tools: no session", "name", name) @@ -139,7 +139,7 @@ func getTools(ctx context.Context, session *ClientSession) ([]*Tool, error) { return result.Tools, nil } -func updateTools(cfg *config.Config, name string, tools []*Tool) int { +func updateTools(cfg *config.ConfigStore, name string, tools []*Tool) int { tools = filterDisabledTools(cfg, name, tools) if len(tools) == 0 { allTools.Del(name) @@ -150,8 +150,8 @@ func updateTools(cfg *config.Config, name string, tools []*Tool) int { } // filterDisabledTools removes tools that are disabled via config. -func filterDisabledTools(cfg *config.Config, mcpName string, tools []*Tool) []*Tool { - mcpCfg, ok := cfg.MCP[mcpName] +func filterDisabledTools(cfg *config.ConfigStore, mcpName string, tools []*Tool) []*Tool { + mcpCfg, ok := cfg.Config().MCP[mcpName] if !ok || len(mcpCfg.DisabledTools) == 0 { return tools } diff --git a/internal/agent/tools/read_mcp_resource.go b/internal/agent/tools/read_mcp_resource.go index cc0450d63aa94574e45e4264906c77fc2b7a1127..c96b00194b92b05e40953a49a90c3453fe6b16b2 100644 --- a/internal/agent/tools/read_mcp_resource.go +++ b/internal/agent/tools/read_mcp_resource.go @@ -30,7 +30,7 @@ const ReadMCPResourceToolName = "read_mcp_resource" //go:embed read_mcp_resource.md var readMCPResourceDescription []byte -func NewReadMCPResourceTool(cfg *config.Config, permissions permission.Service) fantasy.AgentTool { +func NewReadMCPResourceTool(cfg *config.ConfigStore, permissions permission.Service) fantasy.AgentTool { return fantasy.NewParallelAgentTool( ReadMCPResourceToolName, string(readMCPResourceDescription), diff --git a/internal/app/app.go b/internal/app/app.go index 7d87bd1231000cb2a1c88c1fa7a0ceae5b4316a9..8ed3e2e41cb2b235771eba24c3b59945f73cdfda 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -61,7 +61,7 @@ type App struct { LSPManager *lsp.Manager - config *config.Config + config *config.ConfigStore serviceEventsWG *sync.WaitGroup eventsCtx context.Context @@ -75,11 +75,12 @@ type App struct { } // New initializes a new application instance. -func New(ctx context.Context, conn *sql.DB, cfg *config.Config) (*App, error) { +func New(ctx context.Context, conn *sql.DB, store *config.ConfigStore) (*App, error) { q := db.New(conn) sessions := session.NewService(q, conn) messages := message.NewService(q) files := history.NewService(q, conn) + cfg := store.Config() skipPermissionsRequests := cfg.Permissions != nil && cfg.Permissions.SkipRequests var allowedTools []string if cfg.Permissions != nil && cfg.Permissions.AllowedTools != nil { @@ -90,13 +91,13 @@ func New(ctx context.Context, conn *sql.DB, cfg *config.Config) (*App, error) { Sessions: sessions, Messages: messages, History: files, - Permissions: permission.NewPermissionService(cfg.WorkingDir(), skipPermissionsRequests, allowedTools), + Permissions: permission.NewPermissionService(store.WorkingDir(), skipPermissionsRequests, allowedTools), FileTracker: filetracker.NewService(q), - LSPManager: lsp.NewManager(cfg), + LSPManager: lsp.NewManager(store), globalCtx: ctx, - config: cfg, + config: store, events: make(chan tea.Msg, 100), serviceEventsWG: &sync.WaitGroup{}, @@ -109,7 +110,7 @@ func New(ctx context.Context, conn *sql.DB, cfg *config.Config) (*App, error) { // Check for updates in the background. go app.checkForUpdates(ctx) - go mcp.Initialize(ctx, app.Permissions, cfg) + go mcp.Initialize(ctx, app.Permissions, store) // cleanup database upon app shutdown app.cleanupFuncs = append( @@ -141,8 +142,13 @@ func New(ctx context.Context, conn *sql.DB, cfg *config.Config) (*App, error) { return app, nil } -// Config returns the application configuration. +// Config returns the pure-data configuration. func (app *App) Config() *config.Config { + return app.config.Config() +} + +// Store returns the config store. +func (app *App) Store() *config.ConfigStore { return app.config } @@ -178,7 +184,7 @@ func (app *App) RunNonInteractive(ctx context.Context, output io.Writer, prompt, } stderrTTY = term.IsTerminal(os.Stderr.Fd()) stdinTTY = term.IsTerminal(os.Stdin.Fd()) - progress = app.config.Options.Progress == nil || *app.config.Options.Progress + progress = app.config.Config().Options.Progress == nil || *app.config.Config().Options.Progress if !hideSpinner && stderrTTY { t := styles.DefaultStyles() @@ -331,7 +337,7 @@ func (app *App) UpdateAgentModel(ctx context.Context) error { // If largeModel is provided but smallModel is not, the small model defaults to // the provider's default small model. func (app *App) overrideModelsForNonInteractive(ctx context.Context, largeModel, smallModel string) error { - providers := app.config.Providers.Copy() + providers := app.config.Config().Providers.Copy() largeMatches, smallMatches, err := findModels(providers, largeModel, smallModel) if err != nil { @@ -348,7 +354,7 @@ func (app *App) overrideModelsForNonInteractive(ctx context.Context, largeModel, } largeProviderID = found.provider slog.Info("Overriding large model for non-interactive run", "provider", found.provider, "model", found.modelID) - app.config.Models[config.SelectedModelTypeLarge] = config.SelectedModel{ + app.config.Config().Models[config.SelectedModelTypeLarge] = config.SelectedModel{ Provider: found.provider, Model: found.modelID, } @@ -362,7 +368,7 @@ func (app *App) overrideModelsForNonInteractive(ctx context.Context, largeModel, return err } slog.Info("Overriding small model for non-interactive run", "provider", found.provider, "model", found.modelID) - app.config.Models[config.SelectedModelTypeSmall] = config.SelectedModel{ + app.config.Config().Models[config.SelectedModelTypeSmall] = config.SelectedModel{ Provider: found.provider, Model: found.modelID, } @@ -370,7 +376,7 @@ func (app *App) overrideModelsForNonInteractive(ctx context.Context, largeModel, case largeModel != "": // No small model specified, but large model was - use provider's default. smallCfg := app.GetDefaultSmallModel(largeProviderID) - app.config.Models[config.SelectedModelTypeSmall] = smallCfg + app.config.Config().Models[config.SelectedModelTypeSmall] = smallCfg } return app.AgentCoordinator.UpdateModels(ctx) @@ -379,7 +385,7 @@ func (app *App) overrideModelsForNonInteractive(ctx context.Context, largeModel, // GetDefaultSmallModel returns the default small model for the given // provider. Falls back to the large model if no default is found. func (app *App) GetDefaultSmallModel(providerID string) config.SelectedModel { - cfg := app.config + cfg := app.config.Config() largeModelCfg := cfg.Models[config.SelectedModelTypeLarge] // Find the provider in the known providers list to get its default small model. @@ -481,7 +487,7 @@ func setupSubscriber[T any]( } func (app *App) InitCoderAgent(ctx context.Context) error { - coderAgentCfg := app.config.Agents[config.AgentCoder] + coderAgentCfg := app.config.Config().Agents[config.AgentCoder] if coderAgentCfg.ID == "" { return fmt.Errorf("coder agent configuration is missing") } diff --git a/internal/cmd/login.go b/internal/cmd/login.go index bdad4547d6f583b5ae7e5a97bbbbd88a1421e6ee..c9acb12df19875f48b242bee96e377bf5548aacb 100644 --- a/internal/cmd/login.go +++ b/internal/cmd/login.go @@ -52,16 +52,16 @@ crush login copilot } switch provider { case "hyper": - return loginHyper(app.Config()) + return loginHyper(app.Store()) case "copilot", "github", "github-copilot": - return loginCopilot(app.Config()) + return loginCopilot(app.Store()) default: return fmt.Errorf("unknown platform: %s", args[0]) } }, } -func loginHyper(cfg *config.Config) error { +func loginHyper(cfg *config.ConfigStore) error { if !hyperp.Enabled() { return fmt.Errorf("hyper not enabled") } @@ -112,8 +112,8 @@ func loginHyper(cfg *config.Config) error { } if err := cmp.Or( - cfg.SetConfigField("providers.hyper.api_key", token.AccessToken), - cfg.SetConfigField("providers.hyper.oauth", token), + cfg.SetConfigField(config.ScopeGlobal, "providers.hyper.api_key", token.AccessToken), + cfg.SetConfigField(config.ScopeGlobal, "providers.hyper.oauth", token), ); err != nil { return err } @@ -123,10 +123,10 @@ func loginHyper(cfg *config.Config) error { return nil } -func loginCopilot(cfg *config.Config) error { +func loginCopilot(cfg *config.ConfigStore) error { ctx := getLoginContext() - if cfg.HasConfigField("providers.copilot.oauth") { + if cfg.HasConfigField(config.ScopeGlobal, "providers.copilot.oauth") { fmt.Println("You are already logged in to GitHub Copilot.") return nil } @@ -177,8 +177,8 @@ func loginCopilot(cfg *config.Config) error { } if err := cmp.Or( - cfg.SetConfigField("providers.copilot.api_key", token.AccessToken), - cfg.SetConfigField("providers.copilot.oauth", token), + cfg.SetConfigField(config.ScopeGlobal, "providers.copilot.api_key", token.AccessToken), + cfg.SetConfigField(config.ScopeGlobal, "providers.copilot.oauth", token), ); err != nil { return err } diff --git a/internal/cmd/logs.go b/internal/cmd/logs.go index 804b23310fa1e3fb86e4b32983bfcdd571df47aa..87e106feb7cc934567b183d454ef1537970dca88 100644 --- a/internal/cmd/logs.go +++ b/internal/cmd/logs.go @@ -55,7 +55,7 @@ var logsCmd = &cobra.Command{ if err != nil { return fmt.Errorf("failed to load configuration: %v", err) } - logsFile := filepath.Join(cfg.Options.DataDirectory, "logs", "crush.log") + logsFile := filepath.Join(cfg.Config().Options.DataDirectory, "logs", "crush.log") _, err = os.Stat(logsFile) if os.IsNotExist(err) { log.Warn("Looks like you are not in a crush project. No logs found.") diff --git a/internal/cmd/models.go b/internal/cmd/models.go index e2aa5c991d5cf49ba78dbff9d3f79c4f6493523d..f4fa559ebe41d93bee54ed5e2272f8fb0b8dc9ad 100644 --- a/internal/cmd/models.go +++ b/internal/cmd/models.go @@ -38,7 +38,7 @@ crush models gpt5`, return err } - if !cfg.IsConfigured() { + if !cfg.Config().IsConfigured() { return fmt.Errorf("no providers configured - please run 'crush' to set up a provider interactively") } @@ -55,7 +55,7 @@ crush models gpt5`, var providerIDs []string providerModels := make(map[string][]string) - for providerID, provider := range cfg.Providers.Seq2() { + for providerID, provider := range cfg.Config().Providers.Seq2() { if provider.Disable { continue } diff --git a/internal/cmd/root.go b/internal/cmd/root.go index 52ffda3fb09a0e6fdfb88084b80f7bdd261fb3c2..6e1bc08f2f14e8af3d65b5dca7826b95d890b116 100644 --- a/internal/cmd/root.go +++ b/internal/cmd/root.go @@ -189,11 +189,12 @@ func setupApp(cmd *cobra.Command) (*app.App, error) { return nil, err } - cfg, err := config.Init(cwd, dataDir, debug) + store, err := config.Init(cwd, dataDir, debug) if err != nil { return nil, err } + cfg := store.Config() if cfg.Permissions == nil { cfg.Permissions = &config.Permissions{} } @@ -215,7 +216,7 @@ func setupApp(cmd *cobra.Command) (*app.App, error) { return nil, err } - appInstance, err := app.New(ctx, conn, cfg) + appInstance, err := app.New(ctx, conn, store) if err != nil { slog.Error("Failed to create app instance", "error", err) return nil, err diff --git a/internal/cmd/stats.go b/internal/cmd/stats.go index 8831c2a647a283bfe6d6edff15c5eff4dafb3377..3900acadec059869b1896c8adeb49f93155f17fa 100644 --- a/internal/cmd/stats.go +++ b/internal/cmd/stats.go @@ -131,7 +131,7 @@ func runStats(cmd *cobra.Command, _ []string) error { if err != nil { return fmt.Errorf("failed to initialize config: %w", err) } - dataDir = cfg.Options.DataDirectory + dataDir = cfg.Config().Options.DataDirectory } conn, err := db.Connect(ctx, dataDir) diff --git a/internal/commands/commands.go b/internal/commands/commands.go index aeb2ca305dc984c2c450d249d51028858e4e9802..96302bde1281adebfe74a009e4f76443f5368afe 100644 --- a/internal/commands/commands.go +++ b/internal/commands/commands.go @@ -227,7 +227,7 @@ func isMarkdownFile(name string) bool { return strings.HasSuffix(strings.ToLower(name), ".md") } -func GetMCPPrompt(cfg *config.Config, clientID, promptID string, args map[string]string) (string, error) { +func GetMCPPrompt(cfg *config.ConfigStore, clientID, promptID string, args map[string]string) (string, error) { // TODO: we should pass the context down result, err := mcp.GetPromptMessages(context.Background(), cfg, clientID, promptID, args) if err != nil { diff --git a/internal/config/config.go b/internal/config/config.go index 118afef344f8a022add7a13db406ccce27a1391e..8e9b3f0fb7349f4b911c9a6c41fc3e3890f3f19e 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -8,22 +8,16 @@ import ( "maps" "net/http" "net/url" - "os" - "path/filepath" "slices" "strings" "time" "charm.land/catwalk/pkg/catwalk" - hyperp "github.com/charmbracelet/crush/internal/agent/hyper" "github.com/charmbracelet/crush/internal/csync" "github.com/charmbracelet/crush/internal/env" "github.com/charmbracelet/crush/internal/oauth" "github.com/charmbracelet/crush/internal/oauth/copilot" - "github.com/charmbracelet/crush/internal/oauth/hyper" "github.com/invopop/jsonschema" - "github.com/tidwall/gjson" - "github.com/tidwall/sjson" ) const ( @@ -398,17 +392,6 @@ type Config struct { Tools Tools `json:"tools,omitzero" jsonschema:"description=Tool configurations"` Agents map[string]Agent `json:"-"` - - // Internal - workingDir string `json:"-"` - // TODO: find a better way to do this this should probably not be part of the config - resolver VariableResolver - dataConfigDir string `json:"-"` - knownProviders []catwalk.Provider `json:"-"` -} - -func (c *Config) WorkingDir() string { - return c.workingDir } func (c *Config) EnabledProviders() []ProviderConfig { @@ -472,235 +455,8 @@ func (c *Config) SmallModel() *catwalk.Model { return c.GetModel(model.Provider, model.Model) } -func (c *Config) SetCompactMode(enabled bool) error { - if c.Options == nil { - c.Options = &Options{} - } - c.Options.TUI.CompactMode = enabled - return c.SetConfigField("options.tui.compact_mode", enabled) -} - -func (c *Config) Resolve(key string) (string, error) { - if c.resolver == nil { - return "", fmt.Errorf("no variable resolver configured") - } - return c.resolver.ResolveValue(key) -} - -func (c *Config) UpdatePreferredModel(modelType SelectedModelType, model SelectedModel) error { - c.Models[modelType] = model - if err := c.SetConfigField(fmt.Sprintf("models.%s", modelType), model); err != nil { - return fmt.Errorf("failed to update preferred model: %w", err) - } - if err := c.recordRecentModel(modelType, model); err != nil { - return err - } - return nil -} - -func (c *Config) HasConfigField(key string) bool { - data, err := os.ReadFile(c.dataConfigDir) - if err != nil { - return false - } - return gjson.Get(string(data), key).Exists() -} - -func (c *Config) SetConfigField(key string, value any) error { - data, err := os.ReadFile(c.dataConfigDir) - if err != nil { - if os.IsNotExist(err) { - data = []byte("{}") - } else { - return fmt.Errorf("failed to read config file: %w", err) - } - } - - newValue, err := sjson.Set(string(data), key, value) - if err != nil { - return fmt.Errorf("failed to set config field %s: %w", key, err) - } - if err := os.MkdirAll(filepath.Dir(c.dataConfigDir), 0o755); err != nil { - return fmt.Errorf("failed to create config directory %q: %w", c.dataConfigDir, err) - } - if err := os.WriteFile(c.dataConfigDir, []byte(newValue), 0o600); err != nil { - return fmt.Errorf("failed to write config file: %w", err) - } - return nil -} - -func (c *Config) RemoveConfigField(key string) error { - data, err := os.ReadFile(c.dataConfigDir) - if err != nil { - return fmt.Errorf("failed to read config file: %w", err) - } - - newValue, err := sjson.Delete(string(data), key) - if err != nil { - return fmt.Errorf("failed to delete config field %s: %w", key, err) - } - if err := os.MkdirAll(filepath.Dir(c.dataConfigDir), 0o755); err != nil { - return fmt.Errorf("failed to create config directory %q: %w", c.dataConfigDir, err) - } - if err := os.WriteFile(c.dataConfigDir, []byte(newValue), 0o600); err != nil { - return fmt.Errorf("failed to write config file: %w", err) - } - return nil -} - -// RefreshOAuthToken refreshes the OAuth token for the given provider. -func (c *Config) RefreshOAuthToken(ctx context.Context, providerID string) error { - providerConfig, exists := c.Providers.Get(providerID) - if !exists { - return fmt.Errorf("provider %s not found", providerID) - } - - if providerConfig.OAuthToken == nil { - return fmt.Errorf("provider %s does not have an OAuth token", providerID) - } - - var newToken *oauth.Token - var refreshErr error - switch providerID { - case string(catwalk.InferenceProviderCopilot): - newToken, refreshErr = copilot.RefreshToken(ctx, providerConfig.OAuthToken.RefreshToken) - case hyperp.Name: - newToken, refreshErr = hyper.ExchangeToken(ctx, providerConfig.OAuthToken.RefreshToken) - default: - return fmt.Errorf("OAuth refresh not supported for provider %s", providerID) - } - if refreshErr != nil { - return fmt.Errorf("failed to refresh OAuth token for provider %s: %w", providerID, refreshErr) - } - - slog.Info("Successfully refreshed OAuth token", "provider", providerID) - providerConfig.OAuthToken = newToken - providerConfig.APIKey = newToken.AccessToken - - switch providerID { - case string(catwalk.InferenceProviderCopilot): - providerConfig.SetupGitHubCopilot() - } - - c.Providers.Set(providerID, providerConfig) - - if err := cmp.Or( - c.SetConfigField(fmt.Sprintf("providers.%s.api_key", providerID), newToken.AccessToken), - c.SetConfigField(fmt.Sprintf("providers.%s.oauth", providerID), newToken), - ); err != nil { - return fmt.Errorf("failed to persist refreshed token: %w", err) - } - - return nil -} - -func (c *Config) SetProviderAPIKey(providerID string, apiKey any) error { - var providerConfig ProviderConfig - var exists bool - var setKeyOrToken func() - - switch v := apiKey.(type) { - case string: - if err := c.SetConfigField(fmt.Sprintf("providers.%s.api_key", providerID), v); err != nil { - return fmt.Errorf("failed to save api key to config file: %w", err) - } - setKeyOrToken = func() { providerConfig.APIKey = v } - case *oauth.Token: - if err := cmp.Or( - c.SetConfigField(fmt.Sprintf("providers.%s.api_key", providerID), v.AccessToken), - c.SetConfigField(fmt.Sprintf("providers.%s.oauth", providerID), v), - ); err != nil { - return err - } - setKeyOrToken = func() { - providerConfig.APIKey = v.AccessToken - providerConfig.OAuthToken = v - switch providerID { - case string(catwalk.InferenceProviderCopilot): - providerConfig.SetupGitHubCopilot() - } - } - } - - providerConfig, exists = c.Providers.Get(providerID) - if exists { - setKeyOrToken() - c.Providers.Set(providerID, providerConfig) - return nil - } - - var foundProvider *catwalk.Provider - for _, p := range c.knownProviders { - if string(p.ID) == providerID { - foundProvider = &p - break - } - } - - if foundProvider != nil { - // Create new provider config based on known provider - providerConfig = ProviderConfig{ - ID: providerID, - Name: foundProvider.Name, - BaseURL: foundProvider.APIEndpoint, - Type: foundProvider.Type, - Disable: false, - ExtraHeaders: make(map[string]string), - ExtraParams: make(map[string]string), - Models: foundProvider.Models, - } - setKeyOrToken() - } else { - return fmt.Errorf("provider with ID %s not found in known providers", providerID) - } - // Store the updated provider config - c.Providers.Set(providerID, providerConfig) - return nil -} - const maxRecentModelsPerType = 5 -func (c *Config) recordRecentModel(modelType SelectedModelType, model SelectedModel) error { - if model.Provider == "" || model.Model == "" { - return nil - } - - if c.RecentModels == nil { - c.RecentModels = make(map[SelectedModelType][]SelectedModel) - } - - eq := func(a, b SelectedModel) bool { - return a.Provider == b.Provider && a.Model == b.Model - } - - entry := SelectedModel{ - Provider: model.Provider, - Model: model.Model, - } - - current := c.RecentModels[modelType] - withoutCurrent := slices.DeleteFunc(slices.Clone(current), func(existing SelectedModel) bool { - return eq(existing, entry) - }) - - updated := append([]SelectedModel{entry}, withoutCurrent...) - if len(updated) > maxRecentModelsPerType { - updated = updated[:maxRecentModelsPerType] - } - - if slices.EqualFunc(current, updated, eq) { - return nil - } - - c.RecentModels[modelType] = updated - - if err := c.SetConfigField(fmt.Sprintf("recent_models.%s", modelType), updated); err != nil { - return fmt.Errorf("failed to persist recent models: %w", err) - } - - return nil -} - func allToolNames() []string { return []string{ "agent", @@ -780,10 +536,6 @@ func (c *Config) SetupAgents() { c.Agents = agents } -func (c *Config) Resolver() VariableResolver { - return c.resolver -} - func (c *ProviderConfig) TestConnection(resolver VariableResolver) error { var ( providerID = catwalk.InferenceProvider(c.ID) diff --git a/internal/config/copilot.go b/internal/config/copilot.go index d72e7d5048ba4d31c88d7f7152a6b3a9510960a2..d912156bec00a9f00850ab2ec3a3baf1016c2141 100644 --- a/internal/config/copilot.go +++ b/internal/config/copilot.go @@ -1,48 +1 @@ package config - -import ( - "cmp" - "context" - "log/slog" - "testing" - - "charm.land/catwalk/pkg/catwalk" - "github.com/charmbracelet/crush/internal/oauth" - "github.com/charmbracelet/crush/internal/oauth/copilot" -) - -func (c *Config) ImportCopilot() (*oauth.Token, bool) { - if testing.Testing() { - return nil, false - } - - if c.HasConfigField("providers.copilot.api_key") || c.HasConfigField("providers.copilot.oauth") { - return nil, false - } - - diskToken, hasDiskToken := copilot.RefreshTokenFromDisk() - if !hasDiskToken { - return nil, false - } - - slog.Info("Found existing GitHub Copilot token on disk. Authenticating...") - token, err := copilot.RefreshToken(context.TODO(), diskToken) - if err != nil { - slog.Error("Unable to import GitHub Copilot token", "error", err) - return nil, false - } - - if err := c.SetProviderAPIKey(string(catwalk.InferenceProviderCopilot), token); err != nil { - return token, false - } - - if err := cmp.Or( - c.SetConfigField("providers.copilot.api_key", token.AccessToken), - c.SetConfigField("providers.copilot.oauth", token), - ); err != nil { - slog.Error("Unable to save GitHub Copilot token to disk", "error", err) - } - - slog.Info("GitHub Copilot successfully imported") - return token, true -} diff --git a/internal/config/init.go b/internal/config/init.go index 5a4683f77485f54409d4372a33d1933b47abd33f..6138c49d496b16d054ea9ff3f6c49906b3f433ca 100644 --- a/internal/config/init.go +++ b/internal/config/init.go @@ -18,19 +18,20 @@ type ProjectInitFlag struct { Initialized bool `json:"initialized"` } -func Init(workingDir, dataDir string, debug bool) (*Config, error) { - cfg, err := Load(workingDir, dataDir, debug) +func Init(workingDir, dataDir string, debug bool) (*ConfigStore, error) { + store, err := Load(workingDir, dataDir, debug) if err != nil { return nil, err } - return cfg, nil + return store, nil } -func ProjectNeedsInitialization(cfg *Config) (bool, error) { - if cfg == nil { +func ProjectNeedsInitialization(store *ConfigStore) (bool, error) { + if store == nil { return false, fmt.Errorf("config not loaded") } + cfg := store.Config() flagFilePath := filepath.Join(cfg.Options.DataDirectory, InitFlagFilename) _, err := os.Stat(flagFilePath) @@ -42,7 +43,7 @@ func ProjectNeedsInitialization(cfg *Config) (bool, error) { return false, fmt.Errorf("failed to check init flag file: %w", err) } - someContextFileExists, err := contextPathsExist(cfg.WorkingDir()) + someContextFileExists, err := contextPathsExist(store.WorkingDir()) if err != nil { return false, fmt.Errorf("failed to check for context files: %w", err) } @@ -51,7 +52,7 @@ func ProjectNeedsInitialization(cfg *Config) (bool, error) { } // If the working directory has no non-ignored files, skip initialization step - empty, err := dirHasNoVisibleFiles(cfg.WorkingDir()) + empty, err := dirHasNoVisibleFiles(store.WorkingDir()) if err != nil { return false, fmt.Errorf("failed to check if directory is empty: %w", err) } @@ -90,7 +91,7 @@ func contextPathsExist(dir string) (bool, error) { return false, nil } -// dirHasNoVisibleFiles returns true if the directory has no files/dirs after applying ignore rules +// dirHasNoVisibleFiles returns true if the directory has no files/dirs after applying ignore rules. func dirHasNoVisibleFiles(dir string) (bool, error) { files, _, err := fsext.ListDirectory(dir, nil, 1, 1) if err != nil { @@ -99,11 +100,11 @@ func dirHasNoVisibleFiles(dir string) (bool, error) { return len(files) == 0, nil } -func MarkProjectInitialized(cfg *Config) error { - if cfg == nil { +func MarkProjectInitialized(store *ConfigStore) error { + if store == nil { return fmt.Errorf("config not loaded") } - flagFilePath := filepath.Join(cfg.Options.DataDirectory, InitFlagFilename) + flagFilePath := filepath.Join(store.Config().Options.DataDirectory, InitFlagFilename) file, err := os.Create(flagFilePath) if err != nil { @@ -114,13 +115,13 @@ func MarkProjectInitialized(cfg *Config) error { return nil } -func HasInitialDataConfig(cfg *Config) bool { - if cfg == nil { +func HasInitialDataConfig(store *ConfigStore) bool { + if store == nil { return false } cfgPath := GlobalConfigData() if _, err := os.Stat(cfgPath); err != nil { return false } - return cfg.IsConfigured() + return store.Config().IsConfigured() } diff --git a/internal/config/load.go b/internal/config/load.go index 3fba44aa9142c52b8966b1dbe994cef0ae654c48..0c63950e84434d7bebd20e39c25734541027a4d9 100644 --- a/internal/config/load.go +++ b/internal/config/load.go @@ -29,8 +29,9 @@ import ( const defaultCatwalkURL = "https://catwalk.charm.sh" -// Load loads the configuration from the default paths. -func Load(workingDir, dataDir string, debug bool) (*Config, error) { +// Load loads the configuration from the default paths and returns a +// ConfigStore that owns both the pure-data Config and all runtime state. +func Load(workingDir, dataDir string, debug bool) (*ConfigStore, error) { configPaths := lookupConfigs(workingDir) cfg, err := loadFromConfigPaths(configPaths) @@ -38,10 +39,15 @@ func Load(workingDir, dataDir string, debug bool) (*Config, error) { return nil, fmt.Errorf("failed to load config from paths %v: %w", configPaths, err) } - cfg.dataConfigDir = GlobalConfigData() - cfg.setDefaults(workingDir, dataDir) + store := &ConfigStore{ + config: cfg, + workingDir: workingDir, + globalDataPath: GlobalConfigData(), + workspacePath: filepath.Join(cfg.Options.DataDirectory, fmt.Sprintf("%s.json", appName)), + } + if debug { cfg.Options.Debug = true } @@ -52,6 +58,18 @@ func Load(workingDir, dataDir string, debug bool) (*Config, error) { cfg.Options.Debug, ) + // Load workspace config last so it has highest priority. + if wsData, err := os.ReadFile(store.workspacePath); err == nil && len(wsData) > 0 { + merged, mergeErr := loadFromBytes(append([][]byte{mustMarshalConfig(cfg)}, wsData)) + if mergeErr == nil { + // Preserve defaults that setDefaults already applied. + dataDir := cfg.Options.DataDirectory + *cfg = *merged + cfg.setDefaults(workingDir, dataDir) + store.config = cfg + } + } + if !isInsideWorktree() { const depth = 2 const items = 100 @@ -72,26 +90,36 @@ func Load(workingDir, dataDir string, debug bool) (*Config, error) { if err != nil { return nil, err } - cfg.knownProviders = providers + store.knownProviders = providers env := env.New() // Configure providers valueResolver := NewShellVariableResolver(env) - cfg.resolver = valueResolver - if err := cfg.configureProviders(env, valueResolver, cfg.knownProviders); err != nil { + store.resolver = valueResolver + if err := cfg.configureProviders(store, env, valueResolver, store.knownProviders); err != nil { return nil, fmt.Errorf("failed to configure providers: %w", err) } if !cfg.IsConfigured() { slog.Warn("No providers configured") - return cfg, nil + return store, nil } - if err := cfg.configureSelectedModels(cfg.knownProviders); err != nil { + if err := configureSelectedModels(store, store.knownProviders); err != nil { return nil, fmt.Errorf("failed to configure selected models: %w", err) } - cfg.SetupAgents() - return cfg, nil + store.SetupAgents() + return store, nil +} + +// mustMarshalConfig marshals the config to JSON bytes, returning empty JSON on +// error. +func mustMarshalConfig(cfg *Config) []byte { + data, err := json.Marshal(cfg) + if err != nil { + return []byte("{}") + } + return data } func PushPopCrushEnv() func() { @@ -122,7 +150,7 @@ func PushPopCrushEnv() func() { return restore } -func (c *Config) configureProviders(env env.Env, resolver VariableResolver, knownProviders []catwalk.Provider) error { +func (c *Config) configureProviders(store *ConfigStore, env env.Env, resolver VariableResolver, knownProviders []catwalk.Provider) error { knownProviderNames := make(map[string]bool) restore := PushPopCrushEnv() defer restore() @@ -209,7 +237,7 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know switch { case p.ID == catwalk.InferenceProviderAnthropic && config.OAuthToken != nil: // Claude Code subscription is not supported anymore. Remove to show onboarding. - c.RemoveConfigField("providers.anthropic") + store.RemoveConfigField(ScopeGlobal, "providers.anthropic") c.Providers.Del(string(p.ID)) continue case p.ID == catwalk.InferenceProviderCopilot && config.OAuthToken != nil: @@ -340,7 +368,6 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know } func (c *Config) setDefaults(workingDir, dataDir string) { - c.workingDir = workingDir if c.Options == nil { c.Options = &Options{} } @@ -524,7 +551,8 @@ func (c *Config) defaultModelSelection(knownProviders []catwalk.Provider) (large return largeModel, smallModel, err } -func (c *Config) configureSelectedModels(knownProviders []catwalk.Provider) error { +func configureSelectedModels(store *ConfigStore, knownProviders []catwalk.Provider) error { + c := store.config defaultLarge, defaultSmall, err := c.defaultModelSelection(knownProviders) if err != nil { return fmt.Errorf("failed to select default models: %w", err) @@ -543,7 +571,7 @@ func (c *Config) configureSelectedModels(knownProviders []catwalk.Provider) erro if model == nil { large = defaultLarge // override the model type to large - err := c.UpdatePreferredModel(SelectedModelTypeLarge, large) + err := store.UpdatePreferredModel(ScopeGlobal, SelectedModelTypeLarge, large) if err != nil { return fmt.Errorf("failed to update preferred large model: %w", err) } @@ -587,7 +615,7 @@ func (c *Config) configureSelectedModels(knownProviders []catwalk.Provider) erro if model == nil { small = defaultSmall // override the model type to small - err := c.UpdatePreferredModel(SelectedModelTypeSmall, small) + err := store.UpdatePreferredModel(ScopeGlobal, SelectedModelTypeSmall, small) if err != nil { return fmt.Errorf("failed to update preferred small model: %w", err) } diff --git a/internal/config/load_test.go b/internal/config/load_test.go index 93d2245193463e2a6539e23aeb0e16ac14c0ccef..62b1eaa2437116b0a051fe10183689650d388472 100644 --- a/internal/config/load_test.go +++ b/internal/config/load_test.go @@ -36,6 +36,11 @@ func TestConfig_LoadFromBytes(t *testing.T) { require.Equal(t, "https://api.openai.com/v2", pc.BaseURL) } +// testStore wraps a Config in a minimal ConfigStore for testing. +func testStore(cfg *Config) *ConfigStore { + return &ConfigStore{config: cfg} +} + func TestConfig_setDefaults(t *testing.T) { cfg := &Config{} @@ -53,7 +58,6 @@ func TestConfig_setDefaults(t *testing.T) { for _, path := range defaultContextPaths { require.Contains(t, cfg.Options.ContextPaths, path) } - require.Equal(t, "/tmp", cfg.workingDir) } func TestConfig_configureProviders(t *testing.T) { @@ -74,7 +78,7 @@ func TestConfig_configureProviders(t *testing.T) { "OPENAI_API_KEY": "test-key", }) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, knownProviders) + err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders) require.NoError(t, err) require.Equal(t, 1, cfg.Providers.Len()) @@ -117,7 +121,7 @@ func TestConfig_configureProvidersWithOverride(t *testing.T) { "OPENAI_API_KEY": "test-key", }) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, knownProviders) + err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders) require.NoError(t, err) require.Equal(t, 1, cfg.Providers.Len()) @@ -159,7 +163,7 @@ func TestConfig_configureProvidersWithNewProvider(t *testing.T) { "OPENAI_API_KEY": "test-key", }) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, knownProviders) + err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders) require.NoError(t, err) // Should be to because of the env variable require.Equal(t, cfg.Providers.Len(), 2) @@ -195,7 +199,7 @@ func TestConfig_configureProvidersBedrockWithCredentials(t *testing.T) { "AWS_SECRET_ACCESS_KEY": "test-secret-key", }) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, knownProviders) + err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders) require.NoError(t, err) require.Equal(t, cfg.Providers.Len(), 1) @@ -221,7 +225,7 @@ func TestConfig_configureProvidersBedrockWithoutCredentials(t *testing.T) { cfg.setDefaults("/tmp", "") env := env.NewFromMap(map[string]string{}) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, knownProviders) + err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders) require.NoError(t, err) // Provider should not be configured without credentials require.Equal(t, cfg.Providers.Len(), 0) @@ -246,7 +250,7 @@ func TestConfig_configureProvidersBedrockWithoutUnsupportedModel(t *testing.T) { "AWS_SECRET_ACCESS_KEY": "test-secret-key", }) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, knownProviders) + err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders) require.Error(t, err) } @@ -269,7 +273,7 @@ func TestConfig_configureProvidersVertexAIWithCredentials(t *testing.T) { "VERTEXAI_LOCATION": "us-central1", }) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, knownProviders) + err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders) require.NoError(t, err) require.Equal(t, cfg.Providers.Len(), 1) @@ -301,7 +305,7 @@ func TestConfig_configureProvidersVertexAIWithoutCredentials(t *testing.T) { "GOOGLE_CLOUD_LOCATION": "us-central1", }) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, knownProviders) + err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders) require.NoError(t, err) // Provider should not be configured without proper credentials require.Equal(t, cfg.Providers.Len(), 0) @@ -326,7 +330,7 @@ func TestConfig_configureProvidersVertexAIMissingProject(t *testing.T) { "GOOGLE_CLOUD_LOCATION": "us-central1", }) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, knownProviders) + err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders) require.NoError(t, err) // Provider should not be configured without project require.Equal(t, cfg.Providers.Len(), 0) @@ -350,7 +354,7 @@ func TestConfig_configureProvidersSetProviderID(t *testing.T) { "OPENAI_API_KEY": "test-key", }) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, knownProviders) + err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders) require.NoError(t, err) require.Equal(t, cfg.Providers.Len(), 1) @@ -541,7 +545,7 @@ func TestConfig_configureProvidersWithDisabledProvider(t *testing.T) { "OPENAI_API_KEY": "test-key", }) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, knownProviders) + err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders) require.NoError(t, err) require.Equal(t, cfg.Providers.Len(), 1) @@ -569,7 +573,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { env := env.NewFromMap(map[string]string{}) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, []catwalk.Provider{}) + err := cfg.configureProviders(testStore(cfg), env, resolver, []catwalk.Provider{}) require.NoError(t, err) require.Equal(t, cfg.Providers.Len(), 1) @@ -592,7 +596,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { env := env.NewFromMap(map[string]string{}) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, []catwalk.Provider{}) + err := cfg.configureProviders(testStore(cfg), env, resolver, []catwalk.Provider{}) require.NoError(t, err) require.Equal(t, cfg.Providers.Len(), 0) @@ -614,7 +618,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { env := env.NewFromMap(map[string]string{}) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, []catwalk.Provider{}) + err := cfg.configureProviders(testStore(cfg), env, resolver, []catwalk.Provider{}) require.NoError(t, err) require.Equal(t, cfg.Providers.Len(), 0) @@ -639,7 +643,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { env := env.NewFromMap(map[string]string{}) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, []catwalk.Provider{}) + err := cfg.configureProviders(testStore(cfg), env, resolver, []catwalk.Provider{}) require.NoError(t, err) require.Equal(t, cfg.Providers.Len(), 0) @@ -664,7 +668,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { env := env.NewFromMap(map[string]string{}) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, []catwalk.Provider{}) + err := cfg.configureProviders(testStore(cfg), env, resolver, []catwalk.Provider{}) require.NoError(t, err) require.Equal(t, cfg.Providers.Len(), 1) @@ -692,7 +696,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { env := env.NewFromMap(map[string]string{}) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, []catwalk.Provider{}) + err := cfg.configureProviders(testStore(cfg), env, resolver, []catwalk.Provider{}) require.NoError(t, err) require.Equal(t, cfg.Providers.Len(), 1) @@ -722,7 +726,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { env := env.NewFromMap(map[string]string{}) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, []catwalk.Provider{}) + err := cfg.configureProviders(testStore(cfg), env, resolver, []catwalk.Provider{}) require.NoError(t, err) require.Equal(t, cfg.Providers.Len(), 0) @@ -757,7 +761,7 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) { "GOOGLE_GENAI_USE_VERTEXAI": "false", }) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, knownProviders) + err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders) require.NoError(t, err) require.Equal(t, cfg.Providers.Len(), 0) @@ -788,7 +792,7 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) { env := env.NewFromMap(map[string]string{}) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, knownProviders) + err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders) require.NoError(t, err) require.Equal(t, cfg.Providers.Len(), 0) @@ -819,7 +823,7 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) { env := env.NewFromMap(map[string]string{}) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, knownProviders) + err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders) require.NoError(t, err) require.Equal(t, cfg.Providers.Len(), 0) @@ -852,7 +856,7 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) { "OPENAI_API_KEY": "test-key", }) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, knownProviders) + err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders) require.NoError(t, err) require.Equal(t, cfg.Providers.Len(), 1) @@ -886,7 +890,7 @@ func TestConfig_defaultModelSelection(t *testing.T) { cfg.setDefaults("/tmp", "") env := env.NewFromMap(map[string]string{}) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, knownProviders) + err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders) require.NoError(t, err) large, small, err := cfg.defaultModelSelection(knownProviders) @@ -922,7 +926,7 @@ func TestConfig_defaultModelSelection(t *testing.T) { cfg.setDefaults("/tmp", "") env := env.NewFromMap(map[string]string{}) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, knownProviders) + err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders) require.NoError(t, err) _, _, err = cfg.defaultModelSelection(knownProviders) @@ -952,7 +956,7 @@ func TestConfig_defaultModelSelection(t *testing.T) { cfg.setDefaults("/tmp", "") env := env.NewFromMap(map[string]string{}) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, knownProviders) + err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders) require.NoError(t, err) _, _, err = cfg.defaultModelSelection(knownProviders) require.Error(t, err) @@ -995,7 +999,7 @@ func TestConfig_defaultModelSelection(t *testing.T) { cfg.setDefaults("/tmp", "") env := env.NewFromMap(map[string]string{}) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, knownProviders) + err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders) require.NoError(t, err) large, small, err := cfg.defaultModelSelection(knownProviders) require.NoError(t, err) @@ -1039,7 +1043,7 @@ func TestConfig_defaultModelSelection(t *testing.T) { cfg.setDefaults("/tmp", "") env := env.NewFromMap(map[string]string{}) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, knownProviders) + err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders) require.NoError(t, err) _, _, err = cfg.defaultModelSelection(knownProviders) require.Error(t, err) @@ -1081,7 +1085,7 @@ func TestConfig_defaultModelSelection(t *testing.T) { cfg.setDefaults("/tmp", "") env := env.NewFromMap(map[string]string{}) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, knownProviders) + err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders) require.NoError(t, err) large, small, err := cfg.defaultModelSelection(knownProviders) require.NoError(t, err) @@ -1126,7 +1130,7 @@ func TestConfig_configureProvidersDisableDefaultProviders(t *testing.T) { "OPENAI_API_KEY": "test-key", }) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, knownProviders) + err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders) require.ErrorContains(t, err, "no custom providers") // openai should NOT be present because it lacks base_url and models. @@ -1169,7 +1173,7 @@ func TestConfig_configureProvidersDisableDefaultProviders(t *testing.T) { "OPENAI_API_KEY": "test-key", }) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, knownProviders) + err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders) require.NoError(t, err) // Only fully specified provider should be present. @@ -1223,7 +1227,7 @@ func TestConfig_configureProvidersDisableDefaultProviders(t *testing.T) { "ANTHROPIC_API_KEY": "test-key", }) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, knownProviders) + err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders) require.NoError(t, err) // Both providers should be present. @@ -1251,7 +1255,7 @@ func TestConfig_configureProvidersDisableDefaultProviders(t *testing.T) { env := env.NewFromMap(map[string]string{}) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, []catwalk.Provider{}) + err := cfg.configureProviders(testStore(cfg), env, resolver, []catwalk.Provider{}) require.ErrorContains(t, err, "no custom providers") // Provider should be rejected for missing models. @@ -1275,7 +1279,7 @@ func TestConfig_configureProvidersDisableDefaultProviders(t *testing.T) { env := env.NewFromMap(map[string]string{}) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, []catwalk.Provider{}) + err := cfg.configureProviders(testStore(cfg), env, resolver, []catwalk.Provider{}) require.ErrorContains(t, err, "no custom providers") // Provider should be rejected for missing base_url. @@ -1340,10 +1344,10 @@ func TestConfig_configureSelectedModels(t *testing.T) { cfg.setDefaults("/tmp", "") env := env.NewFromMap(map[string]string{}) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, knownProviders) + err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders) require.NoError(t, err) - err = cfg.configureSelectedModels(knownProviders) + err = configureSelectedModels(testStore(cfg), knownProviders) require.NoError(t, err) large := cfg.Models[SelectedModelTypeLarge] small := cfg.Models[SelectedModelTypeSmall] @@ -1402,10 +1406,10 @@ func TestConfig_configureSelectedModels(t *testing.T) { cfg.setDefaults("/tmp", "") env := env.NewFromMap(map[string]string{}) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, knownProviders) + err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders) require.NoError(t, err) - err = cfg.configureSelectedModels(knownProviders) + err = configureSelectedModels(testStore(cfg), knownProviders) require.NoError(t, err) large := cfg.Models[SelectedModelTypeLarge] small := cfg.Models[SelectedModelTypeSmall] @@ -1447,10 +1451,10 @@ func TestConfig_configureSelectedModels(t *testing.T) { cfg.setDefaults("/tmp", "") env := env.NewFromMap(map[string]string{}) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, knownProviders) + err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders) require.NoError(t, err) - err = cfg.configureSelectedModels(knownProviders) + err = configureSelectedModels(testStore(cfg), knownProviders) require.NoError(t, err) large := cfg.Models[SelectedModelTypeLarge] require.Equal(t, "large-model", large.Model) diff --git a/internal/config/recent_models_test.go b/internal/config/recent_models_test.go index 739ddc0031a65cab261723772c3f38658dcd1561..7c46d5d5202927932ed154a4da8b0719ce9e114e 100644 --- a/internal/config/recent_models_test.go +++ b/internal/config/recent_models_test.go @@ -31,15 +31,23 @@ func readRecentModels(t *testing.T, path string) map[string]any { return rm } +// testStoreWithPath creates a ConfigStore backed by a Config for recent model tests. +func testStoreWithPath(cfg *Config, dir string) *ConfigStore { + return &ConfigStore{ + config: cfg, + globalDataPath: filepath.Join(dir, "config.json"), + } +} + func TestRecordRecentModel_AddsAndPersists(t *testing.T) { t.Parallel() dir := t.TempDir() cfg := &Config{} cfg.setDefaults(dir, "") - cfg.dataConfigDir = filepath.Join(dir, "config.json") + store := testStoreWithPath(cfg, dir) - err := cfg.recordRecentModel(SelectedModelTypeLarge, SelectedModel{Provider: "openai", Model: "gpt-4o"}) + err := store.recordRecentModel(ScopeGlobal, SelectedModelTypeLarge, SelectedModel{Provider: "openai", Model: "gpt-4o"}) require.NoError(t, err) // in-memory state @@ -48,7 +56,7 @@ func TestRecordRecentModel_AddsAndPersists(t *testing.T) { require.Equal(t, "gpt-4o", cfg.RecentModels[SelectedModelTypeLarge][0].Model) // persisted state - rm := readRecentModels(t, cfg.dataConfigDir) + rm := readRecentModels(t, store.globalDataPath) large, ok := rm[string(SelectedModelTypeLarge)].([]any) require.True(t, ok) require.Len(t, large, 1) @@ -64,13 +72,13 @@ func TestRecordRecentModel_DedupeAndMoveToFront(t *testing.T) { dir := t.TempDir() cfg := &Config{} cfg.setDefaults(dir, "") - cfg.dataConfigDir = filepath.Join(dir, "config.json") + store := testStoreWithPath(cfg, dir) // Add two entries - require.NoError(t, cfg.recordRecentModel(SelectedModelTypeLarge, SelectedModel{Provider: "openai", Model: "gpt-4o"})) - require.NoError(t, cfg.recordRecentModel(SelectedModelTypeLarge, SelectedModel{Provider: "anthropic", Model: "claude"})) + require.NoError(t, store.recordRecentModel(ScopeGlobal, SelectedModelTypeLarge, SelectedModel{Provider: "openai", Model: "gpt-4o"})) + require.NoError(t, store.recordRecentModel(ScopeGlobal, SelectedModelTypeLarge, SelectedModel{Provider: "anthropic", Model: "claude"})) // Re-add first; should move to front and not duplicate - require.NoError(t, cfg.recordRecentModel(SelectedModelTypeLarge, SelectedModel{Provider: "openai", Model: "gpt-4o"})) + require.NoError(t, store.recordRecentModel(ScopeGlobal, SelectedModelTypeLarge, SelectedModel{Provider: "openai", Model: "gpt-4o"})) got := cfg.RecentModels[SelectedModelTypeLarge] require.Len(t, got, 2) @@ -84,7 +92,7 @@ func TestRecordRecentModel_TrimsToMax(t *testing.T) { dir := t.TempDir() cfg := &Config{} cfg.setDefaults(dir, "") - cfg.dataConfigDir = filepath.Join(dir, "config.json") + store := testStoreWithPath(cfg, dir) // Insert 6 unique models; max is 5 entries := []SelectedModel{ @@ -96,7 +104,7 @@ func TestRecordRecentModel_TrimsToMax(t *testing.T) { {Provider: "p6", Model: "m6"}, } for _, e := range entries { - require.NoError(t, cfg.recordRecentModel(SelectedModelTypeLarge, e)) + require.NoError(t, store.recordRecentModel(ScopeGlobal, SelectedModelTypeLarge, e)) } // in-memory state @@ -110,7 +118,7 @@ func TestRecordRecentModel_TrimsToMax(t *testing.T) { require.Equal(t, SelectedModel{Provider: "p2", Model: "m2"}, got[4]) // persisted state: verify trimmed to 5 and newest-first order - rm := readRecentModels(t, cfg.dataConfigDir) + rm := readRecentModels(t, store.globalDataPath) large, ok := rm[string(SelectedModelTypeLarge)].([]any) require.True(t, ok) require.Len(t, large, 5) @@ -129,12 +137,12 @@ func TestRecordRecentModel_SkipsEmptyValues(t *testing.T) { dir := t.TempDir() cfg := &Config{} cfg.setDefaults(dir, "") - cfg.dataConfigDir = filepath.Join(dir, "config.json") + store := testStoreWithPath(cfg, dir) // Missing provider - require.NoError(t, cfg.recordRecentModel(SelectedModelTypeLarge, SelectedModel{Provider: "", Model: "m"})) + require.NoError(t, store.recordRecentModel(ScopeGlobal, SelectedModelTypeLarge, SelectedModel{Provider: "", Model: "m"})) // Missing model - require.NoError(t, cfg.recordRecentModel(SelectedModelTypeLarge, SelectedModel{Provider: "p", Model: ""})) + require.NoError(t, store.recordRecentModel(ScopeGlobal, SelectedModelTypeLarge, SelectedModel{Provider: "p", Model: ""})) _, ok := cfg.RecentModels[SelectedModelTypeLarge] // Map may be initialized, but should have no entries @@ -142,8 +150,8 @@ func TestRecordRecentModel_SkipsEmptyValues(t *testing.T) { require.Len(t, cfg.RecentModels[SelectedModelTypeLarge], 0) } // No file should be written (stat via fs.FS) - baseDir := filepath.Dir(cfg.dataConfigDir) - fileName := filepath.Base(cfg.dataConfigDir) + baseDir := filepath.Dir(store.globalDataPath) + fileName := filepath.Base(store.globalDataPath) _, err := fs.Stat(os.DirFS(baseDir), fileName) require.True(t, os.IsNotExist(err)) } @@ -154,13 +162,13 @@ func TestRecordRecentModel_NoPersistOnNoop(t *testing.T) { dir := t.TempDir() cfg := &Config{} cfg.setDefaults(dir, "") - cfg.dataConfigDir = filepath.Join(dir, "config.json") + store := testStoreWithPath(cfg, dir) entry := SelectedModel{Provider: "openai", Model: "gpt-4o"} - require.NoError(t, cfg.recordRecentModel(SelectedModelTypeLarge, entry)) + require.NoError(t, store.recordRecentModel(ScopeGlobal, SelectedModelTypeLarge, entry)) - baseDir := filepath.Dir(cfg.dataConfigDir) - fileName := filepath.Base(cfg.dataConfigDir) + baseDir := filepath.Dir(store.globalDataPath) + fileName := filepath.Base(store.globalDataPath) before, err := fs.ReadFile(os.DirFS(baseDir), fileName) require.NoError(t, err) @@ -170,7 +178,7 @@ func TestRecordRecentModel_NoPersistOnNoop(t *testing.T) { beforeMod := stBefore.ModTime() // Re-record same entry should be a no-op (no write) - require.NoError(t, cfg.recordRecentModel(SelectedModelTypeLarge, entry)) + require.NoError(t, store.recordRecentModel(ScopeGlobal, SelectedModelTypeLarge, entry)) after, err := fs.ReadFile(os.DirFS(baseDir), fileName) require.NoError(t, err) @@ -188,17 +196,17 @@ func TestUpdatePreferredModel_UpdatesRecents(t *testing.T) { dir := t.TempDir() cfg := &Config{} cfg.setDefaults(dir, "") - cfg.dataConfigDir = filepath.Join(dir, "config.json") + store := testStoreWithPath(cfg, dir) sel := SelectedModel{Provider: "openai", Model: "gpt-4o"} - require.NoError(t, cfg.UpdatePreferredModel(SelectedModelTypeSmall, sel)) + require.NoError(t, store.UpdatePreferredModel(ScopeGlobal, SelectedModelTypeSmall, sel)) // in-memory require.Equal(t, sel, cfg.Models[SelectedModelTypeSmall]) require.Len(t, cfg.RecentModels[SelectedModelTypeSmall], 1) // persisted (read via fs.FS) - rm := readRecentModels(t, cfg.dataConfigDir) + rm := readRecentModels(t, store.globalDataPath) small, ok := rm[string(SelectedModelTypeSmall)].([]any) require.True(t, ok) require.Len(t, small, 1) @@ -210,14 +218,14 @@ func TestRecordRecentModel_TypeIsolation(t *testing.T) { dir := t.TempDir() cfg := &Config{} cfg.setDefaults(dir, "") - cfg.dataConfigDir = filepath.Join(dir, "config.json") + store := testStoreWithPath(cfg, dir) // Add models to both large and small types largeModel := SelectedModel{Provider: "openai", Model: "gpt-4o"} smallModel := SelectedModel{Provider: "anthropic", Model: "claude"} - require.NoError(t, cfg.recordRecentModel(SelectedModelTypeLarge, largeModel)) - require.NoError(t, cfg.recordRecentModel(SelectedModelTypeSmall, smallModel)) + require.NoError(t, store.recordRecentModel(ScopeGlobal, SelectedModelTypeLarge, largeModel)) + require.NoError(t, store.recordRecentModel(ScopeGlobal, SelectedModelTypeSmall, smallModel)) // in-memory: verify types maintain separate histories require.Len(t, cfg.RecentModels[SelectedModelTypeLarge], 1) @@ -227,14 +235,14 @@ func TestRecordRecentModel_TypeIsolation(t *testing.T) { // Add another to large, verify small unchanged anotherLarge := SelectedModel{Provider: "google", Model: "gemini"} - require.NoError(t, cfg.recordRecentModel(SelectedModelTypeLarge, anotherLarge)) + require.NoError(t, store.recordRecentModel(ScopeGlobal, SelectedModelTypeLarge, anotherLarge)) require.Len(t, cfg.RecentModels[SelectedModelTypeLarge], 2) require.Len(t, cfg.RecentModels[SelectedModelTypeSmall], 1) require.Equal(t, smallModel, cfg.RecentModels[SelectedModelTypeSmall][0]) // persisted state: verify both types exist with correct lengths and contents - rm := readRecentModels(t, cfg.dataConfigDir) + rm := readRecentModels(t, store.globalDataPath) large, ok := rm[string(SelectedModelTypeLarge)].([]any) require.True(t, ok) diff --git a/internal/config/scope.go b/internal/config/scope.go new file mode 100644 index 0000000000000000000000000000000000000000..971ce32c3ed662dd0d0627c4f1c858372f3b4514 --- /dev/null +++ b/internal/config/scope.go @@ -0,0 +1,11 @@ +package config + +// Scope determines which config file is targeted for read/write operations. +type Scope int + +const ( + // ScopeGlobal targets the global data config (~/.local/share/crush/crush.json). + ScopeGlobal Scope = iota + // ScopeWorkspace targets the workspace config (.crush/crush.json). + ScopeWorkspace +) diff --git a/internal/config/store.go b/internal/config/store.go new file mode 100644 index 0000000000000000000000000000000000000000..4dfe6130bc23007ec3df12e5a88cb53bc3ad5a2d --- /dev/null +++ b/internal/config/store.go @@ -0,0 +1,336 @@ +package config + +import ( + "cmp" + "context" + "fmt" + "log/slog" + "os" + "path/filepath" + "slices" + + "charm.land/catwalk/pkg/catwalk" + hyperp "github.com/charmbracelet/crush/internal/agent/hyper" + "github.com/charmbracelet/crush/internal/oauth" + "github.com/charmbracelet/crush/internal/oauth/copilot" + "github.com/charmbracelet/crush/internal/oauth/hyper" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" +) + +// ConfigStore is the single entry point for all config access. It owns the +// pure-data Config, runtime state (working directory, resolver, known +// providers), and persistence to both global and workspace config files. +type ConfigStore struct { + config *Config + workingDir string + resolver VariableResolver + globalDataPath string // ~/.local/share/crush/crush.json + workspacePath string // .crush/crush.json + knownProviders []catwalk.Provider +} + +// Config returns the pure-data config struct (read-only after load). +func (s *ConfigStore) Config() *Config { + return s.config +} + +// WorkingDir returns the current working directory. +func (s *ConfigStore) WorkingDir() string { + return s.workingDir +} + +// Resolver returns the variable resolver. +func (s *ConfigStore) Resolver() VariableResolver { + return s.resolver +} + +// Resolve resolves a variable reference using the configured resolver. +func (s *ConfigStore) Resolve(key string) (string, error) { + if s.resolver == nil { + return "", fmt.Errorf("no variable resolver configured") + } + return s.resolver.ResolveValue(key) +} + +// KnownProviders returns the list of known providers. +func (s *ConfigStore) KnownProviders() []catwalk.Provider { + return s.knownProviders +} + +// SetupAgents configures the coder and task agents on the config. +func (s *ConfigStore) SetupAgents() { + s.config.SetupAgents() +} + +// configPath returns the file path for the given scope. +func (s *ConfigStore) configPath(scope Scope) string { + switch scope { + case ScopeWorkspace: + return s.workspacePath + default: + return s.globalDataPath + } +} + +// HasConfigField checks whether a key exists in the config file for the given +// scope. +func (s *ConfigStore) HasConfigField(scope Scope, key string) bool { + data, err := os.ReadFile(s.configPath(scope)) + if err != nil { + return false + } + return gjson.Get(string(data), key).Exists() +} + +// SetConfigField sets a key/value pair in the config file for the given scope. +func (s *ConfigStore) SetConfigField(scope Scope, key string, value any) error { + path := s.configPath(scope) + data, err := os.ReadFile(path) + if err != nil { + if os.IsNotExist(err) { + data = []byte("{}") + } else { + return fmt.Errorf("failed to read config file: %w", err) + } + } + + newValue, err := sjson.Set(string(data), key, value) + if err != nil { + return fmt.Errorf("failed to set config field %s: %w", key, err) + } + if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { + return fmt.Errorf("failed to create config directory %q: %w", path, err) + } + if err := os.WriteFile(path, []byte(newValue), 0o600); err != nil { + return fmt.Errorf("failed to write config file: %w", err) + } + return nil +} + +// RemoveConfigField removes a key from the config file for the given scope. +func (s *ConfigStore) RemoveConfigField(scope Scope, key string) error { + path := s.configPath(scope) + data, err := os.ReadFile(path) + if err != nil { + return fmt.Errorf("failed to read config file: %w", err) + } + + newValue, err := sjson.Delete(string(data), key) + if err != nil { + return fmt.Errorf("failed to delete config field %s: %w", key, err) + } + if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { + return fmt.Errorf("failed to create config directory %q: %w", path, err) + } + if err := os.WriteFile(path, []byte(newValue), 0o600); err != nil { + return fmt.Errorf("failed to write config file: %w", err) + } + return nil +} + +// UpdatePreferredModel updates the preferred model for the given type and +// persists it to the config file at the given scope. +func (s *ConfigStore) UpdatePreferredModel(scope Scope, modelType SelectedModelType, model SelectedModel) error { + s.config.Models[modelType] = model + if err := s.SetConfigField(scope, fmt.Sprintf("models.%s", modelType), model); err != nil { + return fmt.Errorf("failed to update preferred model: %w", err) + } + if err := s.recordRecentModel(scope, modelType, model); err != nil { + return err + } + return nil +} + +// SetCompactMode sets the compact mode setting and persists it. +func (s *ConfigStore) SetCompactMode(scope Scope, enabled bool) error { + if s.config.Options == nil { + s.config.Options = &Options{} + } + s.config.Options.TUI.CompactMode = enabled + return s.SetConfigField(scope, "options.tui.compact_mode", enabled) +} + +// SetProviderAPIKey sets the API key for a provider and persists it. +func (s *ConfigStore) SetProviderAPIKey(scope Scope, providerID string, apiKey any) error { + var providerConfig ProviderConfig + var exists bool + var setKeyOrToken func() + + switch v := apiKey.(type) { + case string: + if err := s.SetConfigField(scope, fmt.Sprintf("providers.%s.api_key", providerID), v); err != nil { + return fmt.Errorf("failed to save api key to config file: %w", err) + } + setKeyOrToken = func() { providerConfig.APIKey = v } + case *oauth.Token: + if err := cmp.Or( + s.SetConfigField(scope, fmt.Sprintf("providers.%s.api_key", providerID), v.AccessToken), + s.SetConfigField(scope, fmt.Sprintf("providers.%s.oauth", providerID), v), + ); err != nil { + return err + } + setKeyOrToken = func() { + providerConfig.APIKey = v.AccessToken + providerConfig.OAuthToken = v + switch providerID { + case string(catwalk.InferenceProviderCopilot): + providerConfig.SetupGitHubCopilot() + } + } + } + + providerConfig, exists = s.config.Providers.Get(providerID) + if exists { + setKeyOrToken() + s.config.Providers.Set(providerID, providerConfig) + return nil + } + + var foundProvider *catwalk.Provider + for _, p := range s.knownProviders { + if string(p.ID) == providerID { + foundProvider = &p + break + } + } + + if foundProvider != nil { + providerConfig = ProviderConfig{ + ID: providerID, + Name: foundProvider.Name, + BaseURL: foundProvider.APIEndpoint, + Type: foundProvider.Type, + Disable: false, + ExtraHeaders: make(map[string]string), + ExtraParams: make(map[string]string), + Models: foundProvider.Models, + } + setKeyOrToken() + } else { + return fmt.Errorf("provider with ID %s not found in known providers", providerID) + } + s.config.Providers.Set(providerID, providerConfig) + return nil +} + +// RefreshOAuthToken refreshes the OAuth token for the given provider. +func (s *ConfigStore) RefreshOAuthToken(ctx context.Context, scope Scope, providerID string) error { + providerConfig, exists := s.config.Providers.Get(providerID) + if !exists { + return fmt.Errorf("provider %s not found", providerID) + } + + if providerConfig.OAuthToken == nil { + return fmt.Errorf("provider %s does not have an OAuth token", providerID) + } + + var newToken *oauth.Token + var refreshErr error + switch providerID { + case string(catwalk.InferenceProviderCopilot): + newToken, refreshErr = copilot.RefreshToken(ctx, providerConfig.OAuthToken.RefreshToken) + case hyperp.Name: + newToken, refreshErr = hyper.ExchangeToken(ctx, providerConfig.OAuthToken.RefreshToken) + default: + return fmt.Errorf("OAuth refresh not supported for provider %s", providerID) + } + if refreshErr != nil { + return fmt.Errorf("failed to refresh OAuth token for provider %s: %w", providerID, refreshErr) + } + + slog.Info("Successfully refreshed OAuth token", "provider", providerID) + providerConfig.OAuthToken = newToken + providerConfig.APIKey = newToken.AccessToken + + switch providerID { + case string(catwalk.InferenceProviderCopilot): + providerConfig.SetupGitHubCopilot() + } + + s.config.Providers.Set(providerID, providerConfig) + + if err := cmp.Or( + s.SetConfigField(scope, fmt.Sprintf("providers.%s.api_key", providerID), newToken.AccessToken), + s.SetConfigField(scope, fmt.Sprintf("providers.%s.oauth", providerID), newToken), + ); err != nil { + return fmt.Errorf("failed to persist refreshed token: %w", err) + } + + return nil +} + +// recordRecentModel records a model in the recent models list. +func (s *ConfigStore) recordRecentModel(scope Scope, modelType SelectedModelType, model SelectedModel) error { + if model.Provider == "" || model.Model == "" { + return nil + } + + if s.config.RecentModels == nil { + s.config.RecentModels = make(map[SelectedModelType][]SelectedModel) + } + + eq := func(a, b SelectedModel) bool { + return a.Provider == b.Provider && a.Model == b.Model + } + + entry := SelectedModel{ + Provider: model.Provider, + Model: model.Model, + } + + current := s.config.RecentModels[modelType] + withoutCurrent := slices.DeleteFunc(slices.Clone(current), func(existing SelectedModel) bool { + return eq(existing, entry) + }) + + updated := append([]SelectedModel{entry}, withoutCurrent...) + if len(updated) > maxRecentModelsPerType { + updated = updated[:maxRecentModelsPerType] + } + + if slices.EqualFunc(current, updated, eq) { + return nil + } + + s.config.RecentModels[modelType] = updated + + if err := s.SetConfigField(scope, fmt.Sprintf("recent_models.%s", modelType), updated); err != nil { + return fmt.Errorf("failed to persist recent models: %w", err) + } + + return nil +} + +// ImportCopilot attempts to import a GitHub Copilot token from disk. +func (s *ConfigStore) ImportCopilot() (*oauth.Token, bool) { + if s.HasConfigField(ScopeGlobal, "providers.copilot.api_key") || s.HasConfigField(ScopeGlobal, "providers.copilot.oauth") { + return nil, false + } + + diskToken, hasDiskToken := copilot.RefreshTokenFromDisk() + if !hasDiskToken { + return nil, false + } + + slog.Info("Found existing GitHub Copilot token on disk. Authenticating...") + token, err := copilot.RefreshToken(context.TODO(), diskToken) + if err != nil { + slog.Error("Unable to import GitHub Copilot token", "error", err) + return nil, false + } + + if err := s.SetProviderAPIKey(ScopeGlobal, string(catwalk.InferenceProviderCopilot), token); err != nil { + return token, false + } + + if err := cmp.Or( + s.SetConfigField(ScopeGlobal, "providers.copilot.api_key", token.AccessToken), + s.SetConfigField(ScopeGlobal, "providers.copilot.oauth", token), + ); err != nil { + slog.Error("Unable to save GitHub Copilot token to disk", "error", err) + } + + slog.Info("GitHub Copilot successfully imported") + return token, true +} diff --git a/internal/lsp/manager.go b/internal/lsp/manager.go index 13a78cef2a471a71c1e741e32e08e8d7edcb7484..b564c0e602c0234462a32cfaae67c8f8179551c4 100644 --- a/internal/lsp/manager.go +++ b/internal/lsp/manager.go @@ -26,18 +26,18 @@ var unavailable = csync.NewMap[string, struct{}]() // Manager handles lazy initialization of LSP clients based on file types. type Manager struct { clients *csync.Map[string, *Client] - cfg *config.Config + cfg *config.ConfigStore manager *powernapconfig.Manager callback func(name string, client *Client) } // NewManager creates a new LSP manager service. -func NewManager(cfg *config.Config) *Manager { +func NewManager(cfg *config.ConfigStore) *Manager { manager := powernapconfig.NewManager() manager.LoadDefaults() // Merge user-configured LSPs into the manager. - for name, clientConfig := range cfg.LSP { + for name, clientConfig := range cfg.Config().LSP { if clientConfig.Disabled { slog.Debug("LSP disabled by user config", "name", name) manager.RemoveServer(name) @@ -194,7 +194,7 @@ func (s *Manager) startServer(ctx context.Context, name, filepath string, server cfg, s.cfg.Resolver(), s.cfg.WorkingDir(), - s.cfg.Options.DebugLSP, + s.cfg.Config().Options.DebugLSP, ) if err != nil { slog.Error("Failed to create LSP client", "name", name, "error", err) @@ -244,7 +244,7 @@ func (s *Manager) startServer(ctx context.Context, name, filepath string, server } func (s *Manager) isUserConfigured(name string) bool { - cfg, ok := s.cfg.LSP[name] + cfg, ok := s.cfg.Config().LSP[name] return ok && !cfg.Disabled } @@ -258,7 +258,7 @@ func (s *Manager) buildConfig(name string, server *powernapconfig.ServerConfig) InitOptions: server.InitOptions, Options: server.Settings, } - if userCfg, ok := s.cfg.LSP[name]; ok { + if userCfg, ok := s.cfg.Config().LSP[name]; ok { cfg.Timeout = userCfg.Timeout } return cfg diff --git a/internal/ui/common/common.go b/internal/ui/common/common.go index 6e7c632474389aa5455295e4132818941bc18244..143b20305464da33d2f350a36176bab0e45b85aa 100644 --- a/internal/ui/common/common.go +++ b/internal/ui/common/common.go @@ -26,11 +26,16 @@ type Common struct { Styles *styles.Styles } -// Config returns the configuration associated with this [Common] instance. +// Config returns the pure-data configuration associated with this [Common] instance. func (c *Common) Config() *config.Config { return c.App.Config() } +// Store returns the config store associated with this [Common] instance. +func (c *Common) Store() *config.ConfigStore { + return c.App.Store() +} + // DefaultCommon returns the default common UI configurations. func DefaultCommon(app *app.App) *Common { s := styles.DefaultStyles() diff --git a/internal/ui/dialog/api_key_input.go b/internal/ui/dialog/api_key_input.go index 9677763b2f4f2436376f5bf16ab58aed79140c68..cc37d742903d5a80bbcffcf1ff24fb24596dfccd 100644 --- a/internal/ui/dialog/api_key_input.go +++ b/internal/ui/dialog/api_key_input.go @@ -296,7 +296,7 @@ func (m *APIKeyInput) verifyAPIKey() tea.Msg { Type: m.provider.Type, BaseURL: m.provider.APIEndpoint, } - err := providerConfig.TestConnection(m.com.Config().Resolver()) + err := providerConfig.TestConnection(m.com.Store().Resolver()) // intentionally wait for at least 750ms to make sure the user sees the spinner elapsed := time.Since(start) @@ -312,9 +312,9 @@ func (m *APIKeyInput) verifyAPIKey() tea.Msg { } func (m *APIKeyInput) saveKeyAndContinue() Action { - cfg := m.com.Config() + store := m.com.Store() - err := cfg.SetProviderAPIKey(string(m.provider.ID), m.input.Value()) + err := store.SetProviderAPIKey(config.ScopeGlobal, string(m.provider.ID), m.input.Value()) if err != nil { return ActionCmd{util.ReportError(fmt.Errorf("failed to save API key: %w", err))} } diff --git a/internal/ui/dialog/filepicker.go b/internal/ui/dialog/filepicker.go index 4b0b844e4ed869a4347af10e9d0b1b3c70a7d2f0..78f82a05f7e2e0db7a9bb561fb1b6248d8045513 100644 --- a/internal/ui/dialog/filepicker.go +++ b/internal/ui/dialog/filepicker.go @@ -123,7 +123,7 @@ func (f *FilePicker) SetImageCapabilities(caps *common.Capabilities) { // WorkingDir returns the current working directory of the [FilePicker]. func (f *FilePicker) WorkingDir() string { - wd := f.com.Config().WorkingDir() + wd := f.com.Store().WorkingDir() if len(wd) > 0 { return wd } diff --git a/internal/ui/dialog/models.go b/internal/ui/dialog/models.go index 977f04a61e98f79adb9bb35777fac905508f47d5..434f699e91b4c227c4e54f6ff553affff76a1c43 100644 --- a/internal/ui/dialog/models.go +++ b/internal/ui/dialog/models.go @@ -490,7 +490,7 @@ func (m *Models) setProviderItems() error { if len(validRecentItems) != len(recentItems) { // FIXME: Does this need to be here? Is it mutating the config during a read? - if err := cfg.SetConfigField(fmt.Sprintf("recent_models.%s", selectedType), validRecentItems); err != nil { + if err := m.com.Store().SetConfigField(config.ScopeGlobal, fmt.Sprintf("recent_models.%s", selectedType), validRecentItems); err != nil { return fmt.Errorf("failed to update recent models: %w", err) } } diff --git a/internal/ui/dialog/oauth.go b/internal/ui/dialog/oauth.go index 93d5fe052db11d036d29d7790810807d5630bb57..2803070381e65bd0380a8ddab5f256481c117c15 100644 --- a/internal/ui/dialog/oauth.go +++ b/internal/ui/dialog/oauth.go @@ -373,9 +373,9 @@ func (d *OAuth) copyCodeAndOpenURL() tea.Cmd { } func (m *OAuth) saveKeyAndContinue() Action { - cfg := m.com.Config() + store := m.com.Store() - err := cfg.SetProviderAPIKey(string(m.provider.ID), m.token) + err := store.SetProviderAPIKey(config.ScopeGlobal, string(m.provider.ID), m.token) if err != nil { return ActionCmd{util.ReportError(fmt.Errorf("failed to save API key: %w", err))} } diff --git a/internal/ui/model/header.go b/internal/ui/model/header.go index 24254a0f69e5803e4bcbe89274f21db5b04ef541..06bb4ff92981b28625efb11683081e29fc55a21e 100644 --- a/internal/ui/model/header.go +++ b/internal/ui/model/header.go @@ -143,8 +143,7 @@ func renderHeaderDetails( metadata = dot + metadata const dirTrimLimit = 4 - cfg := com.Config() - cwd := fsext.DirTrim(fsext.PrettyPath(cfg.WorkingDir()), dirTrimLimit) + cwd := fsext.DirTrim(fsext.PrettyPath(com.Store().WorkingDir()), dirTrimLimit) cwd = t.Header.WorkingDir.Render(cwd) result := cwd + metadata diff --git a/internal/ui/model/landing.go b/internal/ui/model/landing.go index 45d376ff5ddc691b978e438ddef04a702af100f9..72c2671ccd297f4bade087f6b2cb960f6c6a92a9 100644 --- a/internal/ui/model/landing.go +++ b/internal/ui/model/landing.go @@ -22,7 +22,7 @@ func (m *UI) selectedLargeModel() *agent.Model { func (m *UI) landingView() string { t := m.com.Styles width := m.layout.main.Dx() - cwd := common.PrettyPath(t, m.com.Config().WorkingDir(), width) + cwd := common.PrettyPath(t, m.com.Store().WorkingDir(), width) parts := []string{ cwd, diff --git a/internal/ui/model/onboarding.go b/internal/ui/model/onboarding.go index 075067d75333fc539152f0041b4e5a3c2eed1c5e..5bba37ea8599944df77602014aa8c8d61dd73e80 100644 --- a/internal/ui/model/onboarding.go +++ b/internal/ui/model/onboarding.go @@ -19,7 +19,7 @@ import ( // markProjectInitialized marks the current project as initialized in the config. func (m *UI) markProjectInitialized() tea.Msg { // TODO: handle error so we show it in the tui footer - err := config.MarkProjectInitialized(m.com.Config()) + err := config.MarkProjectInitialized(m.com.Store()) if err != nil { slog.Error(err.Error()) } @@ -52,10 +52,10 @@ func (m *UI) initializeProject() tea.Cmd { if cmd := m.newSession(); cmd != nil { cmds = append(cmds, cmd) } - cfg := m.com.Config() + cfg := m.com.Store() initialize := func() tea.Msg { - initPrompt, err := agent.InitializePrompt(*cfg) + initPrompt, err := agent.InitializePrompt(cfg) if err != nil { return util.InfoMsg{Type: util.InfoTypeError, Msg: err.Error()} } @@ -77,10 +77,9 @@ func (m *UI) skipInitializeProject() tea.Cmd { // initializeView renders the project initialization prompt with Yes/No buttons. func (m *UI) initializeView() string { - cfg := m.com.Config() s := m.com.Styles.Initialize - cwd := home.Short(cfg.WorkingDir()) - initFile := cfg.Options.InitializeAs + cwd := home.Short(m.com.Store().WorkingDir()) + initFile := m.com.Config().Options.InitializeAs header := s.Header.Render("Would you like to initialize this project?") path := s.Accent.PaddingLeft(2).Render(cwd) diff --git a/internal/ui/model/sidebar.go b/internal/ui/model/sidebar.go index 88113a593034b09ed8d2859bc7628a103f5728b1..8849d86a8e1c8bda02092e3f165e85b8e32a8b1d 100644 --- a/internal/ui/model/sidebar.go +++ b/internal/ui/model/sidebar.go @@ -112,7 +112,7 @@ func (m *UI) drawSidebar(scr uv.Screen, area uv.Rectangle) { height := area.Dy() title := t.Muted.Width(width).MaxHeight(2).Render(m.session.Title) - cwd := common.PrettyPath(t, m.com.Config().WorkingDir(), width) + cwd := common.PrettyPath(t, m.com.Store().WorkingDir(), width) sidebarLogo := m.sidebarLogo if height < logoHeightBreakpoint { sidebarLogo = logo.SmallRender(m.com.Styles, width) @@ -138,7 +138,7 @@ func (m *UI) drawSidebar(scr uv.Screen, area uv.Rectangle) { lspSection := m.lspInfo(width, maxLSPs, true) mcpSection := m.mcpInfo(width, maxMCPs, true) - filesSection := m.filesInfo(m.com.Config().WorkingDir(), width, maxFiles, true) + filesSection := m.filesInfo(m.com.Store().WorkingDir(), width, maxFiles, true) uv.NewStyledString( lipgloss.NewStyle(). diff --git a/internal/ui/model/ui.go b/internal/ui/model/ui.go index 89b3b37608500f1a02eea98d4ebfabeba262bcd1..66d57321824833e91b78539357d55013e4322e87 100644 --- a/internal/ui/model/ui.go +++ b/internal/ui/model/ui.go @@ -317,7 +317,7 @@ func New(com *common.Common) *UI { desiredFocus := uiFocusEditor if !com.Config().IsConfigured() { desiredState = uiOnboarding - } else if n, _ := config.ProjectNeedsInitialization(com.Config()); n { + } else if n, _ := config.ProjectNeedsInitialization(com.Store()); n { desiredState = uiInitialize } @@ -579,7 +579,7 @@ func (m *UI) Update(msg tea.Msg) (tea.Model, tea.Cmd) { case mcp.EventPromptsListChanged: return m, handleMCPPromptsEvent(msg.Payload.Name) case mcp.EventToolsListChanged: - return m, handleMCPToolsEvent(m.com.Config(), msg.Payload.Name) + return m, handleMCPToolsEvent(m.com.Store(), msg.Payload.Name) case mcp.EventResourcesListChanged: return m, handleMCPResourcesEvent(msg.Payload.Name) } @@ -1301,7 +1301,7 @@ func (m *UI) handleDialogMsg(msg tea.Msg) tea.Cmd { currentModel := cfg.Models[agentCfg.Model] currentModel.Think = !currentModel.Think - if err := cfg.UpdatePreferredModel(agentCfg.Model, currentModel); err != nil { + if err := m.com.Store().UpdatePreferredModel(config.ScopeGlobal, agentCfg.Model, currentModel); err != nil { return util.ReportError(err)() } m.com.App.UpdateAgentModel(context.TODO()) @@ -1342,7 +1342,7 @@ func (m *UI) handleDialogMsg(msg tea.Msg) tea.Cmd { // Attempt to import GitHub Copilot tokens from VSCode if available. if isCopilot && !isConfigured() && !msg.ReAuthenticate { - m.com.Config().ImportCopilot() + m.com.Store().ImportCopilot() } if !isConfigured() || msg.ReAuthenticate { @@ -1353,12 +1353,12 @@ func (m *UI) handleDialogMsg(msg tea.Msg) tea.Cmd { break } - if err := cfg.UpdatePreferredModel(msg.ModelType, msg.Model); err != nil { + if err := m.com.Store().UpdatePreferredModel(config.ScopeGlobal, msg.ModelType, msg.Model); err != nil { cmds = append(cmds, util.ReportError(err)) } else if _, ok := cfg.Models[config.SelectedModelTypeSmall]; !ok { // Ensure small model is set is unset. smallModel := m.com.App.GetDefaultSmallModel(providerID) - if err := cfg.UpdatePreferredModel(config.SelectedModelTypeSmall, smallModel); err != nil { + if err := m.com.Store().UpdatePreferredModel(config.ScopeGlobal, config.SelectedModelTypeSmall, smallModel); err != nil { cmds = append(cmds, util.ReportError(err)) } } @@ -1404,7 +1404,7 @@ func (m *UI) handleDialogMsg(msg tea.Msg) tea.Cmd { currentModel := cfg.Models[agentCfg.Model] currentModel.ReasoningEffort = msg.Effort - if err := cfg.UpdatePreferredModel(agentCfg.Model, currentModel); err != nil { + if err := m.com.Store().UpdatePreferredModel(config.ScopeGlobal, agentCfg.Model, currentModel); err != nil { cmds = append(cmds, util.ReportError(err)) break } @@ -2016,7 +2016,7 @@ func (m *UI) View() tea.View { } v.MouseMode = tea.MouseModeCellMotion v.ReportFocus = m.caps.ReportFocusEvents - v.WindowTitle = "crush " + home.Short(m.com.Config().WorkingDir()) + v.WindowTitle = "crush " + home.Short(m.com.Store().WorkingDir()) canvas := uv.NewScreenBuffer(m.width, m.height) v.Cursor = m.Draw(canvas, canvas.Bounds()) @@ -2255,7 +2255,7 @@ func (m *UI) FullHelp() [][]key.Binding { func (m *UI) toggleCompactMode() tea.Cmd { m.forceCompactMode = !m.forceCompactMode - err := m.com.Config().SetCompactMode(m.forceCompactMode) + err := m.com.Store().SetCompactMode(config.ScopeGlobal, m.forceCompactMode) if err != nil { return util.ReportError(err) } @@ -2637,7 +2637,7 @@ func (m *UI) insertMCPResourceCompletion(item completions.ResourceCompletionValu return func() tea.Msg { contents, err := mcp.ReadResource( context.Background(), - m.com.Config(), + m.com.Store(), item.MCPName, item.URI, ) @@ -3299,7 +3299,7 @@ func (m *UI) drawSessionDetails(scr uv.Screen, area uv.Rectangle) { lspSection := m.lspInfo(sectionWidth, maxItemsPerSection, false) mcpSection := m.mcpInfo(sectionWidth, maxItemsPerSection, false) - filesSection := m.filesInfo(m.com.Config().WorkingDir(), sectionWidth, maxItemsPerSection, false) + filesSection := m.filesInfo(m.com.Store().WorkingDir(), sectionWidth, maxItemsPerSection, false) sections := lipgloss.JoinHorizontal(lipgloss.Top, filesSection, " ", lspSection, " ", mcpSection) uv.NewStyledString( s.CompactDetails.View. @@ -3317,7 +3317,7 @@ func (m *UI) drawSessionDetails(scr uv.Screen, area uv.Rectangle) { func (m *UI) runMCPPrompt(clientID, promptID string, arguments map[string]string) tea.Cmd { load := func() tea.Msg { - prompt, err := commands.GetMCPPrompt(m.com.Config(), clientID, promptID, arguments) + prompt, err := commands.GetMCPPrompt(m.com.Store(), clientID, promptID, arguments) if err != nil { // TODO: make this better return util.ReportError(err)() @@ -3358,7 +3358,7 @@ func handleMCPPromptsEvent(name string) tea.Cmd { } } -func handleMCPToolsEvent(cfg *config.Config, name string) tea.Cmd { +func handleMCPToolsEvent(cfg *config.ConfigStore, name string) tea.Cmd { return func() tea.Msg { mcp.RefreshTools( context.Background(),