diff --git a/internal/agent/agent_tool.go b/internal/agent/agent_tool.go index 1a7286e342d245c7e7ac1161111d8c205300018b..0d7677dee702b813e0a0d6f02e67837f084d5c29 100644 --- a/internal/agent/agent_tool.go +++ b/internal/agent/agent_tool.go @@ -24,7 +24,7 @@ const ( ) func (c *coordinator) agentTool(ctx context.Context) (fantasy.AgentTool, error) { - agentCfg, ok := c.cfg.Agents[config.AgentTask] + agentCfg, ok := c.cfg.Config().Agents[config.AgentTask] if !ok { return nil, errors.New("task agent not configured") } diff --git a/internal/agent/agentic_fetch_tool.go b/internal/agent/agentic_fetch_tool.go index 0bd942e013b706389fb90352c891a4f2ea014f30..ffbe0f49e45c259db3f0bba9f07fda771ad3ecd4 100644 --- a/internal/agent/agentic_fetch_tool.go +++ b/internal/agent/agentic_fetch_tool.go @@ -98,7 +98,7 @@ func (c *coordinator) agenticFetchTool(_ context.Context, client *http.Client) ( return fantasy.ToolResponse{}, permission.ErrorPermissionDenied } - tmpDir, err := os.MkdirTemp(c.cfg.Options.DataDirectory, "crush-fetch-*") + tmpDir, err := os.MkdirTemp(c.cfg.Config().Options.DataDirectory, "crush-fetch-*") if err != nil { return fantasy.NewTextErrorResponse(fmt.Sprintf("Failed to create temporary directory: %s", err)), nil } @@ -151,12 +151,12 @@ func (c *coordinator) agenticFetchTool(_ context.Context, client *http.Client) ( return fantasy.ToolResponse{}, fmt.Errorf("error building models: %s", err) } - systemPrompt, err := promptTemplate.Build(ctx, small.Model.Provider(), small.Model.Model(), *c.cfg) + systemPrompt, err := promptTemplate.Build(ctx, small.Model.Provider(), small.Model.Model(), c.cfg) if err != nil { return fantasy.ToolResponse{}, fmt.Errorf("error building system prompt: %s", err) } - smallProviderCfg, ok := c.cfg.Providers.Get(small.ModelCfg.Provider) + smallProviderCfg, ok := c.cfg.Config().Providers.Get(small.ModelCfg.Provider) if !ok { return fantasy.ToolResponse{}, errors.New("small model provider not configured") } @@ -167,7 +167,7 @@ func (c *coordinator) agenticFetchTool(_ context.Context, client *http.Client) ( webFetchTool, webSearchTool, tools.NewGlobTool(tmpDir), - tools.NewGrepTool(tmpDir, c.cfg.Tools.Grep), + tools.NewGrepTool(tmpDir, c.cfg.Config().Tools.Grep), tools.NewSourcegraphTool(client), tools.NewViewTool(c.lspManager, c.permissions, c.filetracker, tmpDir), } @@ -177,7 +177,7 @@ func (c *coordinator) agenticFetchTool(_ context.Context, client *http.Client) ( SmallModel: small, SystemPromptPrefix: smallProviderCfg.SystemPromptPrefix, SystemPrompt: systemPrompt, - DisableAutoSummarize: c.cfg.Options.DisableAutoSummarize, + DisableAutoSummarize: c.cfg.Config().Options.DisableAutoSummarize, IsYolo: c.permissions.SkipRequests(), Sessions: c.sessions, Messages: c.messages, diff --git a/internal/agent/common_test.go b/internal/agent/common_test.go index 89fc6ff3d29d27c60a8091f17ebe0fad057dc44a..132c27d21aee81bd3930c469963f1d73885d58a7 100644 --- a/internal/agent/common_test.go +++ b/internal/agent/common_test.go @@ -185,36 +185,36 @@ func coderAgent(r *vcr.Recorder, env fakeEnv, large, small fantasy.LanguageModel // NOTE(@andreynering): Set a fixed config to ensure cassettes match // independently of user config on `$HOME/.config/crush/crush.json`. - cfg.Options.Attribution = &config.Attribution{ + cfg.Config().Options.Attribution = &config.Attribution{ TrailerStyle: "co-authored-by", GeneratedWith: true, } // Clear some fields to avoid issues with VCR cassette matching. - cfg.Options.SkillsPaths = nil - cfg.Options.ContextPaths = nil - cfg.LSP = nil + cfg.Config().Options.SkillsPaths = nil + cfg.Config().Options.ContextPaths = nil + cfg.Config().LSP = nil - systemPrompt, err := prompt.Build(context.TODO(), large.Provider(), large.Model(), *cfg) + systemPrompt, err := prompt.Build(context.TODO(), large.Provider(), large.Model(), cfg) if err != nil { return nil, err } // Get the model name for the bash tool modelName := large.Model() // fallback to ID if Name not available - if model := cfg.GetModel(large.Provider(), large.Model()); model != nil { + if model := cfg.Config().GetModel(large.Provider(), large.Model()); model != nil { modelName = model.Name } allTools := []fantasy.AgentTool{ - tools.NewBashTool(env.permissions, env.workingDir, cfg.Options.Attribution, modelName), + tools.NewBashTool(env.permissions, env.workingDir, cfg.Config().Options.Attribution, modelName), tools.NewDownloadTool(env.permissions, env.workingDir, r.GetDefaultClient()), tools.NewEditTool(nil, env.permissions, env.history, *env.filetracker, env.workingDir), tools.NewMultiEditTool(nil, env.permissions, env.history, *env.filetracker, env.workingDir), tools.NewFetchTool(env.permissions, env.workingDir, r.GetDefaultClient()), tools.NewGlobTool(env.workingDir), - tools.NewGrepTool(env.workingDir, cfg.Tools.Grep), - tools.NewLsTool(env.permissions, env.workingDir, cfg.Tools.Ls), + tools.NewGrepTool(env.workingDir, cfg.Config().Tools.Grep), + tools.NewLsTool(env.permissions, env.workingDir, cfg.Config().Tools.Ls), tools.NewSourcegraphTool(r.GetDefaultClient()), tools.NewViewTool(nil, env.permissions, *env.filetracker, env.workingDir), tools.NewWriteTool(nil, env.permissions, env.history, *env.filetracker, env.workingDir), diff --git a/internal/agent/coordinator.go b/internal/agent/coordinator.go index 3968952ae4e10bd59e596d02797a845d943bd378..4bca96d5946630423fc532ac6ccb5be833638dd2 100644 --- a/internal/agent/coordinator.go +++ b/internal/agent/coordinator.go @@ -74,7 +74,7 @@ type Coordinator interface { } type coordinator struct { - cfg *config.Config + cfg *config.ConfigStore sessions session.Service messages message.Service permissions permission.Service @@ -91,7 +91,7 @@ type coordinator struct { func NewCoordinator( ctx context.Context, - cfg *config.Config, + cfg *config.ConfigStore, sessions session.Service, messages message.Service, permissions permission.Service, @@ -112,7 +112,7 @@ func NewCoordinator( agents: make(map[string]SessionAgent), } - agentCfg, ok := cfg.Agents[config.AgentCoder] + agentCfg, ok := cfg.Config().Agents[config.AgentCoder] if !ok { return nil, errCoderAgentNotConfigured } @@ -160,7 +160,7 @@ func (c *coordinator) Run(ctx context.Context, sessionID string, prompt string, attachments = filteredAttachments } - providerCfg, ok := c.cfg.Providers.Get(model.ModelCfg.Provider) + providerCfg, ok := c.cfg.Config().Providers.Get(model.ModelCfg.Provider) if !ok { return nil, errModelProviderNotConfigured } @@ -383,14 +383,14 @@ func (c *coordinator) buildAgent(ctx context.Context, prompt *prompt.Prompt, age return nil, err } - largeProviderCfg, _ := c.cfg.Providers.Get(large.ModelCfg.Provider) + largeProviderCfg, _ := c.cfg.Config().Providers.Get(large.ModelCfg.Provider) result := NewSessionAgent(SessionAgentOptions{ LargeModel: large, SmallModel: small, SystemPromptPrefix: largeProviderCfg.SystemPromptPrefix, SystemPrompt: "", IsSubAgent: isSubAgent, - DisableAutoSummarize: c.cfg.Options.DisableAutoSummarize, + DisableAutoSummarize: c.cfg.Config().Options.DisableAutoSummarize, IsYolo: c.permissions.SkipRequests(), Sessions: c.sessions, Messages: c.messages, @@ -399,7 +399,7 @@ func (c *coordinator) buildAgent(ctx context.Context, prompt *prompt.Prompt, age }) c.readyWg.Go(func() error { - systemPrompt, err := prompt.Build(ctx, large.Model.Provider(), large.Model.Model(), *c.cfg) + systemPrompt, err := prompt.Build(ctx, large.Model.Provider(), large.Model.Model(), c.cfg) if err != nil { return err } @@ -439,14 +439,14 @@ func (c *coordinator) buildTools(ctx context.Context, agent config.Agent) ([]fan // Get the model name for the agent modelName := "" - if modelCfg, ok := c.cfg.Models[agent.Model]; ok { - if model := c.cfg.GetModel(modelCfg.Provider, modelCfg.Model); model != nil { + if modelCfg, ok := c.cfg.Config().Models[agent.Model]; ok { + if model := c.cfg.Config().GetModel(modelCfg.Provider, modelCfg.Model); model != nil { modelName = model.Name } } allTools = append(allTools, - tools.NewBashTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Options.Attribution, modelName), + tools.NewBashTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Config().Options.Attribution, modelName), tools.NewJobOutputTool(), tools.NewJobKillTool(), tools.NewDownloadTool(c.permissions, c.cfg.WorkingDir(), nil), @@ -454,20 +454,20 @@ func (c *coordinator) buildTools(ctx context.Context, agent config.Agent) ([]fan tools.NewMultiEditTool(c.lspManager, c.permissions, c.history, c.filetracker, c.cfg.WorkingDir()), tools.NewFetchTool(c.permissions, c.cfg.WorkingDir(), nil), tools.NewGlobTool(c.cfg.WorkingDir()), - tools.NewGrepTool(c.cfg.WorkingDir(), c.cfg.Tools.Grep), - tools.NewLsTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Tools.Ls), + tools.NewGrepTool(c.cfg.WorkingDir(), c.cfg.Config().Tools.Grep), + tools.NewLsTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Config().Tools.Ls), tools.NewSourcegraphTool(nil), tools.NewTodosTool(c.sessions), - tools.NewViewTool(c.lspManager, c.permissions, c.filetracker, c.cfg.WorkingDir(), c.cfg.Options.SkillsPaths...), + tools.NewViewTool(c.lspManager, c.permissions, c.filetracker, c.cfg.WorkingDir(), c.cfg.Config().Options.SkillsPaths...), tools.NewWriteTool(c.lspManager, c.permissions, c.history, c.filetracker, c.cfg.WorkingDir()), ) // Add LSP tools if user has configured LSPs or auto_lsp is enabled (nil or true). - if len(c.cfg.LSP) > 0 || c.cfg.Options.AutoLSP == nil || *c.cfg.Options.AutoLSP { + if len(c.cfg.Config().LSP) > 0 || c.cfg.Config().Options.AutoLSP == nil || *c.cfg.Config().Options.AutoLSP { allTools = append(allTools, tools.NewDiagnosticsTool(c.lspManager), tools.NewReferencesTool(c.lspManager), tools.NewLSPRestartTool(c.lspManager)) } - if len(c.cfg.MCP) > 0 { + if len(c.cfg.Config().MCP) > 0 { allTools = append( allTools, tools.NewListMCPResourcesTool(c.cfg, c.permissions), @@ -513,16 +513,16 @@ func (c *coordinator) buildTools(ctx context.Context, agent config.Agent) ([]fan // TODO: when we support multiple agents we need to change this so that we pass in the agent specific model config func (c *coordinator) buildAgentModels(ctx context.Context, isSubAgent bool) (Model, Model, error) { - largeModelCfg, ok := c.cfg.Models[config.SelectedModelTypeLarge] + largeModelCfg, ok := c.cfg.Config().Models[config.SelectedModelTypeLarge] if !ok { return Model{}, Model{}, errLargeModelNotSelected } - smallModelCfg, ok := c.cfg.Models[config.SelectedModelTypeSmall] + smallModelCfg, ok := c.cfg.Config().Models[config.SelectedModelTypeSmall] if !ok { return Model{}, Model{}, errSmallModelNotSelected } - largeProviderCfg, ok := c.cfg.Providers.Get(largeModelCfg.Provider) + largeProviderCfg, ok := c.cfg.Config().Providers.Get(largeModelCfg.Provider) if !ok { return Model{}, Model{}, errLargeModelProviderNotConfigured } @@ -532,7 +532,7 @@ func (c *coordinator) buildAgentModels(ctx context.Context, isSubAgent bool) (Mo return Model{}, Model{}, err } - smallProviderCfg, ok := c.cfg.Providers.Get(smallModelCfg.Provider) + smallProviderCfg, ok := c.cfg.Config().Providers.Get(smallModelCfg.Provider) if !ok { return Model{}, Model{}, errSmallModelProviderNotConfigured } @@ -620,7 +620,7 @@ func (c *coordinator) buildAnthropicProvider(baseURL, apiKey string, headers map opts = append(opts, anthropic.WithBaseURL(baseURL)) } - if c.cfg.Options.Debug { + if c.cfg.Config().Options.Debug { httpClient := log.NewHTTPClient() opts = append(opts, anthropic.WithHTTPClient(httpClient)) } @@ -632,7 +632,7 @@ func (c *coordinator) buildOpenaiProvider(baseURL, apiKey string, headers map[st openai.WithAPIKey(apiKey), openai.WithUseResponsesAPI(), } - if c.cfg.Options.Debug { + if c.cfg.Config().Options.Debug { httpClient := log.NewHTTPClient() opts = append(opts, openai.WithHTTPClient(httpClient)) } @@ -649,7 +649,7 @@ func (c *coordinator) buildOpenrouterProvider(_, apiKey string, headers map[stri opts := []openrouter.Option{ openrouter.WithAPIKey(apiKey), } - if c.cfg.Options.Debug { + if c.cfg.Config().Options.Debug { httpClient := log.NewHTTPClient() opts = append(opts, openrouter.WithHTTPClient(httpClient)) } @@ -663,7 +663,7 @@ func (c *coordinator) buildVercelProvider(_, apiKey string, headers map[string]s opts := []vercel.Option{ vercel.WithAPIKey(apiKey), } - if c.cfg.Options.Debug { + if c.cfg.Config().Options.Debug { httpClient := log.NewHTTPClient() opts = append(opts, vercel.WithHTTPClient(httpClient)) } @@ -683,8 +683,8 @@ func (c *coordinator) buildOpenaiCompatProvider(baseURL, apiKey string, headers var httpClient *http.Client if providerID == string(catwalk.InferenceProviderCopilot) { opts = append(opts, openaicompat.WithUseResponsesAPI()) - httpClient = copilot.NewClient(isSubAgent, c.cfg.Options.Debug) - } else if c.cfg.Options.Debug { + httpClient = copilot.NewClient(isSubAgent, c.cfg.Config().Options.Debug) + } else if c.cfg.Config().Options.Debug { httpClient = log.NewHTTPClient() } if httpClient != nil { @@ -708,7 +708,7 @@ func (c *coordinator) buildAzureProvider(baseURL, apiKey string, headers map[str azure.WithAPIKey(apiKey), azure.WithUseResponsesAPI(), } - if c.cfg.Options.Debug { + if c.cfg.Config().Options.Debug { httpClient := log.NewHTTPClient() opts = append(opts, azure.WithHTTPClient(httpClient)) } @@ -727,7 +727,7 @@ func (c *coordinator) buildAzureProvider(baseURL, apiKey string, headers map[str func (c *coordinator) buildBedrockProvider(headers map[string]string) (fantasy.Provider, error) { var opts []bedrock.Option - if c.cfg.Options.Debug { + if c.cfg.Config().Options.Debug { httpClient := log.NewHTTPClient() opts = append(opts, bedrock.WithHTTPClient(httpClient)) } @@ -746,7 +746,7 @@ func (c *coordinator) buildGoogleProvider(baseURL, apiKey string, headers map[st google.WithBaseURL(baseURL), google.WithGeminiAPIKey(apiKey), } - if c.cfg.Options.Debug { + if c.cfg.Config().Options.Debug { httpClient := log.NewHTTPClient() opts = append(opts, google.WithHTTPClient(httpClient)) } @@ -758,7 +758,7 @@ func (c *coordinator) buildGoogleProvider(baseURL, apiKey string, headers map[st func (c *coordinator) buildGoogleVertexProvider(headers map[string]string, options map[string]string) (fantasy.Provider, error) { opts := []google.Option{} - if c.cfg.Options.Debug { + if c.cfg.Config().Options.Debug { httpClient := log.NewHTTPClient() opts = append(opts, google.WithHTTPClient(httpClient)) } @@ -779,7 +779,7 @@ func (c *coordinator) buildHyperProvider(baseURL, apiKey string) (fantasy.Provid hyper.WithBaseURL(baseURL), hyper.WithAPIKey(apiKey), } - if c.cfg.Options.Debug { + if c.cfg.Config().Options.Debug { httpClient := log.NewHTTPClient() opts = append(opts, hyper.WithHTTPClient(httpClient)) } @@ -887,7 +887,7 @@ func (c *coordinator) UpdateModels(ctx context.Context) error { } c.currentAgent.SetModels(large, small) - agentCfg, ok := c.cfg.Agents[config.AgentCoder] + agentCfg, ok := c.cfg.Config().Agents[config.AgentCoder] if !ok { return errCoderAgentNotConfigured } @@ -909,7 +909,7 @@ func (c *coordinator) QueuedPromptsList(sessionID string) []string { } func (c *coordinator) Summarize(ctx context.Context, sessionID string) error { - providerCfg, ok := c.cfg.Providers.Get(c.currentAgent.Model().ModelCfg.Provider) + providerCfg, ok := c.cfg.Config().Providers.Get(c.currentAgent.Model().ModelCfg.Provider) if !ok { return errModelProviderNotConfigured } @@ -922,7 +922,7 @@ func (c *coordinator) isUnauthorized(err error) bool { } func (c *coordinator) refreshOAuth2Token(ctx context.Context, providerCfg config.ProviderConfig) error { - if err := c.cfg.RefreshOAuthToken(ctx, providerCfg.ID); err != nil { + if err := c.cfg.RefreshOAuthToken(ctx, config.ScopeGlobal, providerCfg.ID); err != nil { slog.Error("Failed to refresh OAuth token after 401 error", "provider", providerCfg.ID, "error", err) return err } @@ -940,7 +940,7 @@ func (c *coordinator) refreshApiKeyTemplate(ctx context.Context, providerCfg con } providerCfg.APIKey = newAPIKey - c.cfg.Providers.Set(providerCfg.ID, providerCfg) + c.cfg.Config().Providers.Set(providerCfg.ID, providerCfg) if err := c.UpdateModels(ctx); err != nil { return err @@ -984,7 +984,7 @@ func (c *coordinator) runSubAgent(ctx context.Context, params subAgentParams) (f maxTokens = model.ModelCfg.MaxTokens } - providerCfg, ok := c.cfg.Providers.Get(model.ModelCfg.Provider) + providerCfg, ok := c.cfg.Config().Providers.Get(model.ModelCfg.Provider) if !ok { return fantasy.ToolResponse{}, errModelProviderNotConfigured } diff --git a/internal/agent/coordinator_test.go b/internal/agent/coordinator_test.go index 3c270394cba9c1758e4a9029a149027af6bf36c2..657575b6458d7fb815c7a9646a9d605c8b89ec42 100644 --- a/internal/agent/coordinator_test.go +++ b/internal/agent/coordinator_test.go @@ -44,7 +44,7 @@ func (m *mockSessionAgent) Summarize(context.Context, string, fantasy.ProviderOp func newTestCoordinator(t *testing.T, env fakeEnv, providerID string, providerCfg config.ProviderConfig) *coordinator { cfg, err := config.Init(env.workingDir, "", false) require.NoError(t, err) - cfg.Providers.Set(providerID, providerCfg) + cfg.Config().Providers.Set(providerID, providerCfg) return &coordinator{ cfg: cfg, sessions: env.sessions, diff --git a/internal/agent/prompt/prompt.go b/internal/agent/prompt/prompt.go index d68c7c132116c49cd004bee52169be7487133efa..c8f488319f04238c476aae4719728fb94521695e 100644 --- a/internal/agent/prompt/prompt.go +++ b/internal/agent/prompt/prompt.go @@ -76,13 +76,13 @@ func NewPrompt(name, promptTemplate string, opts ...Option) (*Prompt, error) { return p, nil } -func (p *Prompt) Build(ctx context.Context, provider, model string, cfg config.Config) (string, error) { +func (p *Prompt) Build(ctx context.Context, provider, model string, store *config.ConfigStore) (string, error) { t, err := template.New(p.name).Parse(p.template) if err != nil { return "", fmt.Errorf("parsing template: %w", err) } var sb strings.Builder - d, err := p.promptData(ctx, provider, model, cfg) + d, err := p.promptData(ctx, provider, model, store) if err != nil { return "", err } @@ -104,11 +104,11 @@ func processFile(filePath string) *ContextFile { } } -func processContextPath(p string, cfg config.Config) []ContextFile { +func processContextPath(p string, store *config.ConfigStore) []ContextFile { var contexts []ContextFile fullPath := p if !filepath.IsAbs(p) { - fullPath = filepath.Join(cfg.WorkingDir(), p) + fullPath = filepath.Join(store.WorkingDir(), p) } info, err := os.Stat(fullPath) if err != nil { @@ -136,11 +136,11 @@ func processContextPath(p string, cfg config.Config) []ContextFile { } // expandPath expands ~ and environment variables in file paths -func expandPath(path string, cfg config.Config) string { +func expandPath(path string, store *config.ConfigStore) string { path = home.Long(path) // Handle environment variable expansion using the same pattern as config if strings.HasPrefix(path, "$") { - if expanded, err := cfg.Resolver().ResolveValue(path); err == nil { + if expanded, err := store.Resolver().ResolveValue(path); err == nil { path = expanded } } @@ -148,19 +148,20 @@ func expandPath(path string, cfg config.Config) string { return path } -func (p *Prompt) promptData(ctx context.Context, provider, model string, cfg config.Config) (PromptDat, error) { - workingDir := cmp.Or(p.workingDir, cfg.WorkingDir()) +func (p *Prompt) promptData(ctx context.Context, provider, model string, store *config.ConfigStore) (PromptDat, error) { + workingDir := cmp.Or(p.workingDir, store.WorkingDir()) platform := cmp.Or(p.platform, runtime.GOOS) files := map[string][]ContextFile{} + cfg := store.Config() for _, pth := range cfg.Options.ContextPaths { - expanded := expandPath(pth, cfg) + expanded := expandPath(pth, store) pathKey := strings.ToLower(expanded) if _, ok := files[pathKey]; ok { continue } - content := processContextPath(expanded, cfg) + content := processContextPath(expanded, store) files[pathKey] = content } @@ -169,18 +170,18 @@ func (p *Prompt) promptData(ctx context.Context, provider, model string, cfg con if len(cfg.Options.SkillsPaths) > 0 { expandedPaths := make([]string, 0, len(cfg.Options.SkillsPaths)) for _, pth := range cfg.Options.SkillsPaths { - expandedPaths = append(expandedPaths, expandPath(pth, cfg)) + expandedPaths = append(expandedPaths, expandPath(pth, store)) } if discoveredSkills := skills.Discover(expandedPaths); len(discoveredSkills) > 0 { availSkillXML = skills.ToPromptXML(discoveredSkills) } } - isGit := isGitRepo(cfg.WorkingDir()) + isGit := isGitRepo(store.WorkingDir()) data := PromptDat{ Provider: provider, Model: model, - Config: cfg, + Config: *cfg, WorkingDir: filepath.ToSlash(workingDir), IsGitRepo: isGit, Platform: platform, @@ -189,7 +190,7 @@ func (p *Prompt) promptData(ctx context.Context, provider, model string, cfg con } if isGit { var err error - data.GitStatus, err = getGitStatus(ctx, cfg.WorkingDir()) + data.GitStatus, err = getGitStatus(ctx, store.WorkingDir()) if err != nil { return PromptDat{}, err } diff --git a/internal/agent/prompts.go b/internal/agent/prompts.go index 577d32e4e274d9cb8274bd862af583208a613f08..448fe0425c3b700b1d6edafc842c4815ad3d5760 100644 --- a/internal/agent/prompts.go +++ b/internal/agent/prompts.go @@ -33,7 +33,7 @@ func taskPrompt(opts ...prompt.Option) (*prompt.Prompt, error) { return systemPrompt, nil } -func InitializePrompt(cfg config.Config) (string, error) { +func InitializePrompt(cfg *config.ConfigStore) (string, error) { systemPrompt, err := prompt.NewPrompt("initialize", string(initializePromptTmpl)) if err != nil { return "", err diff --git a/internal/agent/tools/list_mcp_resources.go b/internal/agent/tools/list_mcp_resources.go index 032d1eb1888a65e9a14daecc3b503698a6fa60d4..7ea8998a1dc80955b2a5b0a79d4aef7d19fb9011 100644 --- a/internal/agent/tools/list_mcp_resources.go +++ b/internal/agent/tools/list_mcp_resources.go @@ -28,7 +28,7 @@ const ListMCPResourcesToolName = "list_mcp_resources" //go:embed list_mcp_resources.md var listMCPResourcesDescription []byte -func NewListMCPResourcesTool(cfg *config.Config, permissions permission.Service) fantasy.AgentTool { +func NewListMCPResourcesTool(cfg *config.ConfigStore, permissions permission.Service) fantasy.AgentTool { return fantasy.NewParallelAgentTool( ListMCPResourcesToolName, string(listMCPResourcesDescription), diff --git a/internal/agent/tools/mcp-tools.go b/internal/agent/tools/mcp-tools.go index 429cadaf6b686b83e170ef35976881d839b07e17..e1184118552ee62e75f60c6943f59ecca2868563 100644 --- a/internal/agent/tools/mcp-tools.go +++ b/internal/agent/tools/mcp-tools.go @@ -11,7 +11,7 @@ import ( ) // GetMCPTools gets all the currently available MCP tools. -func GetMCPTools(permissions permission.Service, cfg *config.Config, wd string) []*Tool { +func GetMCPTools(permissions permission.Service, cfg *config.ConfigStore, wd string) []*Tool { var result []*Tool for mcpName, tools := range mcp.Tools() { for _, tool := range tools { @@ -31,7 +31,7 @@ func GetMCPTools(permissions permission.Service, cfg *config.Config, wd string) type Tool struct { mcpName string tool *mcp.Tool - cfg *config.Config + cfg *config.ConfigStore permissions permission.Service workingDir string providerOptions fantasy.ProviderOptions diff --git a/internal/agent/tools/mcp/init.go b/internal/agent/tools/mcp/init.go index f8cfe0ce84bf7b1987496607d42753b8ca72263f..cba9a51c717b1866b823762f85bfadf90e1a7a10 100644 --- a/internal/agent/tools/mcp/init.go +++ b/internal/agent/tools/mcp/init.go @@ -163,11 +163,11 @@ func Close(ctx context.Context) error { } // Initialize initializes MCP clients based on the provided configuration. -func Initialize(ctx context.Context, permissions permission.Service, cfg *config.Config) { +func Initialize(ctx context.Context, permissions permission.Service, cfg *config.ConfigStore) { slog.Info("Initializing MCP clients") var wg sync.WaitGroup // Initialize states for all configured MCPs - for name, m := range cfg.MCP { + for name, m := range cfg.Config().MCP { if m.Disabled { updateState(name, StateDisabled, nil, nil, Counts{}) slog.Debug("Skipping disabled MCP", "name", name) @@ -253,13 +253,13 @@ func WaitForInit(ctx context.Context) error { } } -func getOrRenewClient(ctx context.Context, cfg *config.Config, name string) (*ClientSession, error) { +func getOrRenewClient(ctx context.Context, cfg *config.ConfigStore, name string) (*ClientSession, error) { sess, ok := sessions.Get(name) if !ok { return nil, fmt.Errorf("mcp '%s' not available", name) } - m := cfg.MCP[name] + m := cfg.Config().MCP[name] state, _ := states.Get(name) timeout := mcpTimeout(m) diff --git a/internal/agent/tools/mcp/prompts.go b/internal/agent/tools/mcp/prompts.go index 2b39d5dc2db43aff418c3dd7561edbcebd6af865..d84be303ecb103d4fdd37423b7b6d088374d2c70 100644 --- a/internal/agent/tools/mcp/prompts.go +++ b/internal/agent/tools/mcp/prompts.go @@ -20,7 +20,7 @@ func Prompts() iter.Seq2[string, []*Prompt] { } // GetPromptMessages retrieves the content of an MCP prompt with the given arguments. -func GetPromptMessages(ctx context.Context, cfg *config.Config, clientName, promptName string, args map[string]string) ([]string, error) { +func GetPromptMessages(ctx context.Context, cfg *config.ConfigStore, clientName, promptName string, args map[string]string) ([]string, error) { c, err := getOrRenewClient(ctx, cfg, clientName) if err != nil { return nil, err diff --git a/internal/agent/tools/mcp/resources.go b/internal/agent/tools/mcp/resources.go index 8e2bcc796b28c698481dd90b0c70511273f7c98d..21616761e81212960f4d6ad59da1505049abbffb 100644 --- a/internal/agent/tools/mcp/resources.go +++ b/internal/agent/tools/mcp/resources.go @@ -24,7 +24,7 @@ func Resources() iter.Seq2[string, []*Resource] { } // ListResources returns the current resources for an MCP server. -func ListResources(ctx context.Context, cfg *config.Config, name string) ([]*Resource, error) { +func ListResources(ctx context.Context, cfg *config.ConfigStore, name string) ([]*Resource, error) { session, err := getOrRenewClient(ctx, cfg, name) if err != nil { return nil, err @@ -43,7 +43,7 @@ func ListResources(ctx context.Context, cfg *config.Config, name string) ([]*Res } // ReadResource reads the contents of a resource from an MCP server. -func ReadResource(ctx context.Context, cfg *config.Config, name, uri string) ([]*ResourceContents, error) { +func ReadResource(ctx context.Context, cfg *config.ConfigStore, name, uri string) ([]*ResourceContents, error) { session, err := getOrRenewClient(ctx, cfg, name) if err != nil { return nil, err diff --git a/internal/agent/tools/mcp/tools.go b/internal/agent/tools/mcp/tools.go index b6e208f7ccb3363bee0a0b60ef56c103ad9cd41b..8d1d2649ba4381e14fa8d99933f1dfb3b42d27ae 100644 --- a/internal/agent/tools/mcp/tools.go +++ b/internal/agent/tools/mcp/tools.go @@ -32,7 +32,7 @@ func Tools() iter.Seq2[string, []*Tool] { } // RunTool runs an MCP tool with the given input parameters. -func RunTool(ctx context.Context, cfg *config.Config, name, toolName string, input string) (ToolResult, error) { +func RunTool(ctx context.Context, cfg *config.ConfigStore, name, toolName string, input string) (ToolResult, error) { var args map[string]any if err := json.Unmarshal([]byte(input), &args); err != nil { return ToolResult{}, fmt.Errorf("error parsing parameters: %s", err) @@ -108,7 +108,7 @@ func RunTool(ctx context.Context, cfg *config.Config, name, toolName string, inp // RefreshTools gets the updated list of tools from the MCP and updates the // global state. -func RefreshTools(ctx context.Context, cfg *config.Config, name string) { +func RefreshTools(ctx context.Context, cfg *config.ConfigStore, name string) { session, ok := sessions.Get(name) if !ok { slog.Warn("Refresh tools: no session", "name", name) @@ -139,7 +139,7 @@ func getTools(ctx context.Context, session *ClientSession) ([]*Tool, error) { return result.Tools, nil } -func updateTools(cfg *config.Config, name string, tools []*Tool) int { +func updateTools(cfg *config.ConfigStore, name string, tools []*Tool) int { tools = filterDisabledTools(cfg, name, tools) if len(tools) == 0 { allTools.Del(name) @@ -150,8 +150,8 @@ func updateTools(cfg *config.Config, name string, tools []*Tool) int { } // filterDisabledTools removes tools that are disabled via config. -func filterDisabledTools(cfg *config.Config, mcpName string, tools []*Tool) []*Tool { - mcpCfg, ok := cfg.MCP[mcpName] +func filterDisabledTools(cfg *config.ConfigStore, mcpName string, tools []*Tool) []*Tool { + mcpCfg, ok := cfg.Config().MCP[mcpName] if !ok || len(mcpCfg.DisabledTools) == 0 { return tools } diff --git a/internal/agent/tools/read_mcp_resource.go b/internal/agent/tools/read_mcp_resource.go index cc0450d63aa94574e45e4264906c77fc2b7a1127..c96b00194b92b05e40953a49a90c3453fe6b16b2 100644 --- a/internal/agent/tools/read_mcp_resource.go +++ b/internal/agent/tools/read_mcp_resource.go @@ -30,7 +30,7 @@ const ReadMCPResourceToolName = "read_mcp_resource" //go:embed read_mcp_resource.md var readMCPResourceDescription []byte -func NewReadMCPResourceTool(cfg *config.Config, permissions permission.Service) fantasy.AgentTool { +func NewReadMCPResourceTool(cfg *config.ConfigStore, permissions permission.Service) fantasy.AgentTool { return fantasy.NewParallelAgentTool( ReadMCPResourceToolName, string(readMCPResourceDescription), diff --git a/internal/app/app.go b/internal/app/app.go index 7d87bd1231000cb2a1c88c1fa7a0ceae5b4316a9..8ed3e2e41cb2b235771eba24c3b59945f73cdfda 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -61,7 +61,7 @@ type App struct { LSPManager *lsp.Manager - config *config.Config + config *config.ConfigStore serviceEventsWG *sync.WaitGroup eventsCtx context.Context @@ -75,11 +75,12 @@ type App struct { } // New initializes a new application instance. -func New(ctx context.Context, conn *sql.DB, cfg *config.Config) (*App, error) { +func New(ctx context.Context, conn *sql.DB, store *config.ConfigStore) (*App, error) { q := db.New(conn) sessions := session.NewService(q, conn) messages := message.NewService(q) files := history.NewService(q, conn) + cfg := store.Config() skipPermissionsRequests := cfg.Permissions != nil && cfg.Permissions.SkipRequests var allowedTools []string if cfg.Permissions != nil && cfg.Permissions.AllowedTools != nil { @@ -90,13 +91,13 @@ func New(ctx context.Context, conn *sql.DB, cfg *config.Config) (*App, error) { Sessions: sessions, Messages: messages, History: files, - Permissions: permission.NewPermissionService(cfg.WorkingDir(), skipPermissionsRequests, allowedTools), + Permissions: permission.NewPermissionService(store.WorkingDir(), skipPermissionsRequests, allowedTools), FileTracker: filetracker.NewService(q), - LSPManager: lsp.NewManager(cfg), + LSPManager: lsp.NewManager(store), globalCtx: ctx, - config: cfg, + config: store, events: make(chan tea.Msg, 100), serviceEventsWG: &sync.WaitGroup{}, @@ -109,7 +110,7 @@ func New(ctx context.Context, conn *sql.DB, cfg *config.Config) (*App, error) { // Check for updates in the background. go app.checkForUpdates(ctx) - go mcp.Initialize(ctx, app.Permissions, cfg) + go mcp.Initialize(ctx, app.Permissions, store) // cleanup database upon app shutdown app.cleanupFuncs = append( @@ -141,8 +142,13 @@ func New(ctx context.Context, conn *sql.DB, cfg *config.Config) (*App, error) { return app, nil } -// Config returns the application configuration. +// Config returns the pure-data configuration. func (app *App) Config() *config.Config { + return app.config.Config() +} + +// Store returns the config store. +func (app *App) Store() *config.ConfigStore { return app.config } @@ -178,7 +184,7 @@ func (app *App) RunNonInteractive(ctx context.Context, output io.Writer, prompt, } stderrTTY = term.IsTerminal(os.Stderr.Fd()) stdinTTY = term.IsTerminal(os.Stdin.Fd()) - progress = app.config.Options.Progress == nil || *app.config.Options.Progress + progress = app.config.Config().Options.Progress == nil || *app.config.Config().Options.Progress if !hideSpinner && stderrTTY { t := styles.DefaultStyles() @@ -331,7 +337,7 @@ func (app *App) UpdateAgentModel(ctx context.Context) error { // If largeModel is provided but smallModel is not, the small model defaults to // the provider's default small model. func (app *App) overrideModelsForNonInteractive(ctx context.Context, largeModel, smallModel string) error { - providers := app.config.Providers.Copy() + providers := app.config.Config().Providers.Copy() largeMatches, smallMatches, err := findModels(providers, largeModel, smallModel) if err != nil { @@ -348,7 +354,7 @@ func (app *App) overrideModelsForNonInteractive(ctx context.Context, largeModel, } largeProviderID = found.provider slog.Info("Overriding large model for non-interactive run", "provider", found.provider, "model", found.modelID) - app.config.Models[config.SelectedModelTypeLarge] = config.SelectedModel{ + app.config.Config().Models[config.SelectedModelTypeLarge] = config.SelectedModel{ Provider: found.provider, Model: found.modelID, } @@ -362,7 +368,7 @@ func (app *App) overrideModelsForNonInteractive(ctx context.Context, largeModel, return err } slog.Info("Overriding small model for non-interactive run", "provider", found.provider, "model", found.modelID) - app.config.Models[config.SelectedModelTypeSmall] = config.SelectedModel{ + app.config.Config().Models[config.SelectedModelTypeSmall] = config.SelectedModel{ Provider: found.provider, Model: found.modelID, } @@ -370,7 +376,7 @@ func (app *App) overrideModelsForNonInteractive(ctx context.Context, largeModel, case largeModel != "": // No small model specified, but large model was - use provider's default. smallCfg := app.GetDefaultSmallModel(largeProviderID) - app.config.Models[config.SelectedModelTypeSmall] = smallCfg + app.config.Config().Models[config.SelectedModelTypeSmall] = smallCfg } return app.AgentCoordinator.UpdateModels(ctx) @@ -379,7 +385,7 @@ func (app *App) overrideModelsForNonInteractive(ctx context.Context, largeModel, // GetDefaultSmallModel returns the default small model for the given // provider. Falls back to the large model if no default is found. func (app *App) GetDefaultSmallModel(providerID string) config.SelectedModel { - cfg := app.config + cfg := app.config.Config() largeModelCfg := cfg.Models[config.SelectedModelTypeLarge] // Find the provider in the known providers list to get its default small model. @@ -481,7 +487,7 @@ func setupSubscriber[T any]( } func (app *App) InitCoderAgent(ctx context.Context) error { - coderAgentCfg := app.config.Agents[config.AgentCoder] + coderAgentCfg := app.config.Config().Agents[config.AgentCoder] if coderAgentCfg.ID == "" { return fmt.Errorf("coder agent configuration is missing") } diff --git a/internal/cmd/login.go b/internal/cmd/login.go index bdad4547d6f583b5ae7e5a97bbbbd88a1421e6ee..c9acb12df19875f48b242bee96e377bf5548aacb 100644 --- a/internal/cmd/login.go +++ b/internal/cmd/login.go @@ -52,16 +52,16 @@ crush login copilot } switch provider { case "hyper": - return loginHyper(app.Config()) + return loginHyper(app.Store()) case "copilot", "github", "github-copilot": - return loginCopilot(app.Config()) + return loginCopilot(app.Store()) default: return fmt.Errorf("unknown platform: %s", args[0]) } }, } -func loginHyper(cfg *config.Config) error { +func loginHyper(cfg *config.ConfigStore) error { if !hyperp.Enabled() { return fmt.Errorf("hyper not enabled") } @@ -112,8 +112,8 @@ func loginHyper(cfg *config.Config) error { } if err := cmp.Or( - cfg.SetConfigField("providers.hyper.api_key", token.AccessToken), - cfg.SetConfigField("providers.hyper.oauth", token), + cfg.SetConfigField(config.ScopeGlobal, "providers.hyper.api_key", token.AccessToken), + cfg.SetConfigField(config.ScopeGlobal, "providers.hyper.oauth", token), ); err != nil { return err } @@ -123,10 +123,10 @@ func loginHyper(cfg *config.Config) error { return nil } -func loginCopilot(cfg *config.Config) error { +func loginCopilot(cfg *config.ConfigStore) error { ctx := getLoginContext() - if cfg.HasConfigField("providers.copilot.oauth") { + if cfg.HasConfigField(config.ScopeGlobal, "providers.copilot.oauth") { fmt.Println("You are already logged in to GitHub Copilot.") return nil } @@ -177,8 +177,8 @@ func loginCopilot(cfg *config.Config) error { } if err := cmp.Or( - cfg.SetConfigField("providers.copilot.api_key", token.AccessToken), - cfg.SetConfigField("providers.copilot.oauth", token), + cfg.SetConfigField(config.ScopeGlobal, "providers.copilot.api_key", token.AccessToken), + cfg.SetConfigField(config.ScopeGlobal, "providers.copilot.oauth", token), ); err != nil { return err } diff --git a/internal/cmd/logs.go b/internal/cmd/logs.go index 804b23310fa1e3fb86e4b32983bfcdd571df47aa..87e106feb7cc934567b183d454ef1537970dca88 100644 --- a/internal/cmd/logs.go +++ b/internal/cmd/logs.go @@ -55,7 +55,7 @@ var logsCmd = &cobra.Command{ if err != nil { return fmt.Errorf("failed to load configuration: %v", err) } - logsFile := filepath.Join(cfg.Options.DataDirectory, "logs", "crush.log") + logsFile := filepath.Join(cfg.Config().Options.DataDirectory, "logs", "crush.log") _, err = os.Stat(logsFile) if os.IsNotExist(err) { log.Warn("Looks like you are not in a crush project. No logs found.") diff --git a/internal/cmd/models.go b/internal/cmd/models.go index e2aa5c991d5cf49ba78dbff9d3f79c4f6493523d..f4fa559ebe41d93bee54ed5e2272f8fb0b8dc9ad 100644 --- a/internal/cmd/models.go +++ b/internal/cmd/models.go @@ -38,7 +38,7 @@ crush models gpt5`, return err } - if !cfg.IsConfigured() { + if !cfg.Config().IsConfigured() { return fmt.Errorf("no providers configured - please run 'crush' to set up a provider interactively") } @@ -55,7 +55,7 @@ crush models gpt5`, var providerIDs []string providerModels := make(map[string][]string) - for providerID, provider := range cfg.Providers.Seq2() { + for providerID, provider := range cfg.Config().Providers.Seq2() { if provider.Disable { continue } diff --git a/internal/cmd/root.go b/internal/cmd/root.go index 52ffda3fb09a0e6fdfb88084b80f7bdd261fb3c2..6e1bc08f2f14e8af3d65b5dca7826b95d890b116 100644 --- a/internal/cmd/root.go +++ b/internal/cmd/root.go @@ -189,11 +189,12 @@ func setupApp(cmd *cobra.Command) (*app.App, error) { return nil, err } - cfg, err := config.Init(cwd, dataDir, debug) + store, err := config.Init(cwd, dataDir, debug) if err != nil { return nil, err } + cfg := store.Config() if cfg.Permissions == nil { cfg.Permissions = &config.Permissions{} } @@ -215,7 +216,7 @@ func setupApp(cmd *cobra.Command) (*app.App, error) { return nil, err } - appInstance, err := app.New(ctx, conn, cfg) + appInstance, err := app.New(ctx, conn, store) if err != nil { slog.Error("Failed to create app instance", "error", err) return nil, err diff --git a/internal/cmd/stats.go b/internal/cmd/stats.go index 8831c2a647a283bfe6d6edff15c5eff4dafb3377..3900acadec059869b1896c8adeb49f93155f17fa 100644 --- a/internal/cmd/stats.go +++ b/internal/cmd/stats.go @@ -131,7 +131,7 @@ func runStats(cmd *cobra.Command, _ []string) error { if err != nil { return fmt.Errorf("failed to initialize config: %w", err) } - dataDir = cfg.Options.DataDirectory + dataDir = cfg.Config().Options.DataDirectory } conn, err := db.Connect(ctx, dataDir) diff --git a/internal/commands/commands.go b/internal/commands/commands.go index aeb2ca305dc984c2c450d249d51028858e4e9802..96302bde1281adebfe74a009e4f76443f5368afe 100644 --- a/internal/commands/commands.go +++ b/internal/commands/commands.go @@ -227,7 +227,7 @@ func isMarkdownFile(name string) bool { return strings.HasSuffix(strings.ToLower(name), ".md") } -func GetMCPPrompt(cfg *config.Config, clientID, promptID string, args map[string]string) (string, error) { +func GetMCPPrompt(cfg *config.ConfigStore, clientID, promptID string, args map[string]string) (string, error) { // TODO: we should pass the context down result, err := mcp.GetPromptMessages(context.Background(), cfg, clientID, promptID, args) if err != nil { diff --git a/internal/config/config.go b/internal/config/config.go index 118afef344f8a022add7a13db406ccce27a1391e..8e9b3f0fb7349f4b911c9a6c41fc3e3890f3f19e 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -8,22 +8,16 @@ import ( "maps" "net/http" "net/url" - "os" - "path/filepath" "slices" "strings" "time" "charm.land/catwalk/pkg/catwalk" - hyperp "github.com/charmbracelet/crush/internal/agent/hyper" "github.com/charmbracelet/crush/internal/csync" "github.com/charmbracelet/crush/internal/env" "github.com/charmbracelet/crush/internal/oauth" "github.com/charmbracelet/crush/internal/oauth/copilot" - "github.com/charmbracelet/crush/internal/oauth/hyper" "github.com/invopop/jsonschema" - "github.com/tidwall/gjson" - "github.com/tidwall/sjson" ) const ( @@ -398,17 +392,6 @@ type Config struct { Tools Tools `json:"tools,omitzero" jsonschema:"description=Tool configurations"` Agents map[string]Agent `json:"-"` - - // Internal - workingDir string `json:"-"` - // TODO: find a better way to do this this should probably not be part of the config - resolver VariableResolver - dataConfigDir string `json:"-"` - knownProviders []catwalk.Provider `json:"-"` -} - -func (c *Config) WorkingDir() string { - return c.workingDir } func (c *Config) EnabledProviders() []ProviderConfig { @@ -472,235 +455,8 @@ func (c *Config) SmallModel() *catwalk.Model { return c.GetModel(model.Provider, model.Model) } -func (c *Config) SetCompactMode(enabled bool) error { - if c.Options == nil { - c.Options = &Options{} - } - c.Options.TUI.CompactMode = enabled - return c.SetConfigField("options.tui.compact_mode", enabled) -} - -func (c *Config) Resolve(key string) (string, error) { - if c.resolver == nil { - return "", fmt.Errorf("no variable resolver configured") - } - return c.resolver.ResolveValue(key) -} - -func (c *Config) UpdatePreferredModel(modelType SelectedModelType, model SelectedModel) error { - c.Models[modelType] = model - if err := c.SetConfigField(fmt.Sprintf("models.%s", modelType), model); err != nil { - return fmt.Errorf("failed to update preferred model: %w", err) - } - if err := c.recordRecentModel(modelType, model); err != nil { - return err - } - return nil -} - -func (c *Config) HasConfigField(key string) bool { - data, err := os.ReadFile(c.dataConfigDir) - if err != nil { - return false - } - return gjson.Get(string(data), key).Exists() -} - -func (c *Config) SetConfigField(key string, value any) error { - data, err := os.ReadFile(c.dataConfigDir) - if err != nil { - if os.IsNotExist(err) { - data = []byte("{}") - } else { - return fmt.Errorf("failed to read config file: %w", err) - } - } - - newValue, err := sjson.Set(string(data), key, value) - if err != nil { - return fmt.Errorf("failed to set config field %s: %w", key, err) - } - if err := os.MkdirAll(filepath.Dir(c.dataConfigDir), 0o755); err != nil { - return fmt.Errorf("failed to create config directory %q: %w", c.dataConfigDir, err) - } - if err := os.WriteFile(c.dataConfigDir, []byte(newValue), 0o600); err != nil { - return fmt.Errorf("failed to write config file: %w", err) - } - return nil -} - -func (c *Config) RemoveConfigField(key string) error { - data, err := os.ReadFile(c.dataConfigDir) - if err != nil { - return fmt.Errorf("failed to read config file: %w", err) - } - - newValue, err := sjson.Delete(string(data), key) - if err != nil { - return fmt.Errorf("failed to delete config field %s: %w", key, err) - } - if err := os.MkdirAll(filepath.Dir(c.dataConfigDir), 0o755); err != nil { - return fmt.Errorf("failed to create config directory %q: %w", c.dataConfigDir, err) - } - if err := os.WriteFile(c.dataConfigDir, []byte(newValue), 0o600); err != nil { - return fmt.Errorf("failed to write config file: %w", err) - } - return nil -} - -// RefreshOAuthToken refreshes the OAuth token for the given provider. -func (c *Config) RefreshOAuthToken(ctx context.Context, providerID string) error { - providerConfig, exists := c.Providers.Get(providerID) - if !exists { - return fmt.Errorf("provider %s not found", providerID) - } - - if providerConfig.OAuthToken == nil { - return fmt.Errorf("provider %s does not have an OAuth token", providerID) - } - - var newToken *oauth.Token - var refreshErr error - switch providerID { - case string(catwalk.InferenceProviderCopilot): - newToken, refreshErr = copilot.RefreshToken(ctx, providerConfig.OAuthToken.RefreshToken) - case hyperp.Name: - newToken, refreshErr = hyper.ExchangeToken(ctx, providerConfig.OAuthToken.RefreshToken) - default: - return fmt.Errorf("OAuth refresh not supported for provider %s", providerID) - } - if refreshErr != nil { - return fmt.Errorf("failed to refresh OAuth token for provider %s: %w", providerID, refreshErr) - } - - slog.Info("Successfully refreshed OAuth token", "provider", providerID) - providerConfig.OAuthToken = newToken - providerConfig.APIKey = newToken.AccessToken - - switch providerID { - case string(catwalk.InferenceProviderCopilot): - providerConfig.SetupGitHubCopilot() - } - - c.Providers.Set(providerID, providerConfig) - - if err := cmp.Or( - c.SetConfigField(fmt.Sprintf("providers.%s.api_key", providerID), newToken.AccessToken), - c.SetConfigField(fmt.Sprintf("providers.%s.oauth", providerID), newToken), - ); err != nil { - return fmt.Errorf("failed to persist refreshed token: %w", err) - } - - return nil -} - -func (c *Config) SetProviderAPIKey(providerID string, apiKey any) error { - var providerConfig ProviderConfig - var exists bool - var setKeyOrToken func() - - switch v := apiKey.(type) { - case string: - if err := c.SetConfigField(fmt.Sprintf("providers.%s.api_key", providerID), v); err != nil { - return fmt.Errorf("failed to save api key to config file: %w", err) - } - setKeyOrToken = func() { providerConfig.APIKey = v } - case *oauth.Token: - if err := cmp.Or( - c.SetConfigField(fmt.Sprintf("providers.%s.api_key", providerID), v.AccessToken), - c.SetConfigField(fmt.Sprintf("providers.%s.oauth", providerID), v), - ); err != nil { - return err - } - setKeyOrToken = func() { - providerConfig.APIKey = v.AccessToken - providerConfig.OAuthToken = v - switch providerID { - case string(catwalk.InferenceProviderCopilot): - providerConfig.SetupGitHubCopilot() - } - } - } - - providerConfig, exists = c.Providers.Get(providerID) - if exists { - setKeyOrToken() - c.Providers.Set(providerID, providerConfig) - return nil - } - - var foundProvider *catwalk.Provider - for _, p := range c.knownProviders { - if string(p.ID) == providerID { - foundProvider = &p - break - } - } - - if foundProvider != nil { - // Create new provider config based on known provider - providerConfig = ProviderConfig{ - ID: providerID, - Name: foundProvider.Name, - BaseURL: foundProvider.APIEndpoint, - Type: foundProvider.Type, - Disable: false, - ExtraHeaders: make(map[string]string), - ExtraParams: make(map[string]string), - Models: foundProvider.Models, - } - setKeyOrToken() - } else { - return fmt.Errorf("provider with ID %s not found in known providers", providerID) - } - // Store the updated provider config - c.Providers.Set(providerID, providerConfig) - return nil -} - const maxRecentModelsPerType = 5 -func (c *Config) recordRecentModel(modelType SelectedModelType, model SelectedModel) error { - if model.Provider == "" || model.Model == "" { - return nil - } - - if c.RecentModels == nil { - c.RecentModels = make(map[SelectedModelType][]SelectedModel) - } - - eq := func(a, b SelectedModel) bool { - return a.Provider == b.Provider && a.Model == b.Model - } - - entry := SelectedModel{ - Provider: model.Provider, - Model: model.Model, - } - - current := c.RecentModels[modelType] - withoutCurrent := slices.DeleteFunc(slices.Clone(current), func(existing SelectedModel) bool { - return eq(existing, entry) - }) - - updated := append([]SelectedModel{entry}, withoutCurrent...) - if len(updated) > maxRecentModelsPerType { - updated = updated[:maxRecentModelsPerType] - } - - if slices.EqualFunc(current, updated, eq) { - return nil - } - - c.RecentModels[modelType] = updated - - if err := c.SetConfigField(fmt.Sprintf("recent_models.%s", modelType), updated); err != nil { - return fmt.Errorf("failed to persist recent models: %w", err) - } - - return nil -} - func allToolNames() []string { return []string{ "agent", @@ -780,10 +536,6 @@ func (c *Config) SetupAgents() { c.Agents = agents } -func (c *Config) Resolver() VariableResolver { - return c.resolver -} - func (c *ProviderConfig) TestConnection(resolver VariableResolver) error { var ( providerID = catwalk.InferenceProvider(c.ID) diff --git a/internal/config/copilot.go b/internal/config/copilot.go index d72e7d5048ba4d31c88d7f7152a6b3a9510960a2..d912156bec00a9f00850ab2ec3a3baf1016c2141 100644 --- a/internal/config/copilot.go +++ b/internal/config/copilot.go @@ -1,48 +1 @@ package config - -import ( - "cmp" - "context" - "log/slog" - "testing" - - "charm.land/catwalk/pkg/catwalk" - "github.com/charmbracelet/crush/internal/oauth" - "github.com/charmbracelet/crush/internal/oauth/copilot" -) - -func (c *Config) ImportCopilot() (*oauth.Token, bool) { - if testing.Testing() { - return nil, false - } - - if c.HasConfigField("providers.copilot.api_key") || c.HasConfigField("providers.copilot.oauth") { - return nil, false - } - - diskToken, hasDiskToken := copilot.RefreshTokenFromDisk() - if !hasDiskToken { - return nil, false - } - - slog.Info("Found existing GitHub Copilot token on disk. Authenticating...") - token, err := copilot.RefreshToken(context.TODO(), diskToken) - if err != nil { - slog.Error("Unable to import GitHub Copilot token", "error", err) - return nil, false - } - - if err := c.SetProviderAPIKey(string(catwalk.InferenceProviderCopilot), token); err != nil { - return token, false - } - - if err := cmp.Or( - c.SetConfigField("providers.copilot.api_key", token.AccessToken), - c.SetConfigField("providers.copilot.oauth", token), - ); err != nil { - slog.Error("Unable to save GitHub Copilot token to disk", "error", err) - } - - slog.Info("GitHub Copilot successfully imported") - return token, true -} diff --git a/internal/config/init.go b/internal/config/init.go index 5a4683f77485f54409d4372a33d1933b47abd33f..6138c49d496b16d054ea9ff3f6c49906b3f433ca 100644 --- a/internal/config/init.go +++ b/internal/config/init.go @@ -18,19 +18,20 @@ type ProjectInitFlag struct { Initialized bool `json:"initialized"` } -func Init(workingDir, dataDir string, debug bool) (*Config, error) { - cfg, err := Load(workingDir, dataDir, debug) +func Init(workingDir, dataDir string, debug bool) (*ConfigStore, error) { + store, err := Load(workingDir, dataDir, debug) if err != nil { return nil, err } - return cfg, nil + return store, nil } -func ProjectNeedsInitialization(cfg *Config) (bool, error) { - if cfg == nil { +func ProjectNeedsInitialization(store *ConfigStore) (bool, error) { + if store == nil { return false, fmt.Errorf("config not loaded") } + cfg := store.Config() flagFilePath := filepath.Join(cfg.Options.DataDirectory, InitFlagFilename) _, err := os.Stat(flagFilePath) @@ -42,7 +43,7 @@ func ProjectNeedsInitialization(cfg *Config) (bool, error) { return false, fmt.Errorf("failed to check init flag file: %w", err) } - someContextFileExists, err := contextPathsExist(cfg.WorkingDir()) + someContextFileExists, err := contextPathsExist(store.WorkingDir()) if err != nil { return false, fmt.Errorf("failed to check for context files: %w", err) } @@ -51,7 +52,7 @@ func ProjectNeedsInitialization(cfg *Config) (bool, error) { } // If the working directory has no non-ignored files, skip initialization step - empty, err := dirHasNoVisibleFiles(cfg.WorkingDir()) + empty, err := dirHasNoVisibleFiles(store.WorkingDir()) if err != nil { return false, fmt.Errorf("failed to check if directory is empty: %w", err) } @@ -90,7 +91,7 @@ func contextPathsExist(dir string) (bool, error) { return false, nil } -// dirHasNoVisibleFiles returns true if the directory has no files/dirs after applying ignore rules +// dirHasNoVisibleFiles returns true if the directory has no files/dirs after applying ignore rules. func dirHasNoVisibleFiles(dir string) (bool, error) { files, _, err := fsext.ListDirectory(dir, nil, 1, 1) if err != nil { @@ -99,11 +100,11 @@ func dirHasNoVisibleFiles(dir string) (bool, error) { return len(files) == 0, nil } -func MarkProjectInitialized(cfg *Config) error { - if cfg == nil { +func MarkProjectInitialized(store *ConfigStore) error { + if store == nil { return fmt.Errorf("config not loaded") } - flagFilePath := filepath.Join(cfg.Options.DataDirectory, InitFlagFilename) + flagFilePath := filepath.Join(store.Config().Options.DataDirectory, InitFlagFilename) file, err := os.Create(flagFilePath) if err != nil { @@ -114,13 +115,13 @@ func MarkProjectInitialized(cfg *Config) error { return nil } -func HasInitialDataConfig(cfg *Config) bool { - if cfg == nil { +func HasInitialDataConfig(store *ConfigStore) bool { + if store == nil { return false } cfgPath := GlobalConfigData() if _, err := os.Stat(cfgPath); err != nil { return false } - return cfg.IsConfigured() + return store.Config().IsConfigured() } diff --git a/internal/config/load.go b/internal/config/load.go index 3fba44aa9142c52b8966b1dbe994cef0ae654c48..0c63950e84434d7bebd20e39c25734541027a4d9 100644 --- a/internal/config/load.go +++ b/internal/config/load.go @@ -29,8 +29,9 @@ import ( const defaultCatwalkURL = "https://catwalk.charm.sh" -// Load loads the configuration from the default paths. -func Load(workingDir, dataDir string, debug bool) (*Config, error) { +// Load loads the configuration from the default paths and returns a +// ConfigStore that owns both the pure-data Config and all runtime state. +func Load(workingDir, dataDir string, debug bool) (*ConfigStore, error) { configPaths := lookupConfigs(workingDir) cfg, err := loadFromConfigPaths(configPaths) @@ -38,10 +39,15 @@ func Load(workingDir, dataDir string, debug bool) (*Config, error) { return nil, fmt.Errorf("failed to load config from paths %v: %w", configPaths, err) } - cfg.dataConfigDir = GlobalConfigData() - cfg.setDefaults(workingDir, dataDir) + store := &ConfigStore{ + config: cfg, + workingDir: workingDir, + globalDataPath: GlobalConfigData(), + workspacePath: filepath.Join(cfg.Options.DataDirectory, fmt.Sprintf("%s.json", appName)), + } + if debug { cfg.Options.Debug = true } @@ -52,6 +58,18 @@ func Load(workingDir, dataDir string, debug bool) (*Config, error) { cfg.Options.Debug, ) + // Load workspace config last so it has highest priority. + if wsData, err := os.ReadFile(store.workspacePath); err == nil && len(wsData) > 0 { + merged, mergeErr := loadFromBytes(append([][]byte{mustMarshalConfig(cfg)}, wsData)) + if mergeErr == nil { + // Preserve defaults that setDefaults already applied. + dataDir := cfg.Options.DataDirectory + *cfg = *merged + cfg.setDefaults(workingDir, dataDir) + store.config = cfg + } + } + if !isInsideWorktree() { const depth = 2 const items = 100 @@ -72,26 +90,36 @@ func Load(workingDir, dataDir string, debug bool) (*Config, error) { if err != nil { return nil, err } - cfg.knownProviders = providers + store.knownProviders = providers env := env.New() // Configure providers valueResolver := NewShellVariableResolver(env) - cfg.resolver = valueResolver - if err := cfg.configureProviders(env, valueResolver, cfg.knownProviders); err != nil { + store.resolver = valueResolver + if err := cfg.configureProviders(store, env, valueResolver, store.knownProviders); err != nil { return nil, fmt.Errorf("failed to configure providers: %w", err) } if !cfg.IsConfigured() { slog.Warn("No providers configured") - return cfg, nil + return store, nil } - if err := cfg.configureSelectedModels(cfg.knownProviders); err != nil { + if err := configureSelectedModels(store, store.knownProviders); err != nil { return nil, fmt.Errorf("failed to configure selected models: %w", err) } - cfg.SetupAgents() - return cfg, nil + store.SetupAgents() + return store, nil +} + +// mustMarshalConfig marshals the config to JSON bytes, returning empty JSON on +// error. +func mustMarshalConfig(cfg *Config) []byte { + data, err := json.Marshal(cfg) + if err != nil { + return []byte("{}") + } + return data } func PushPopCrushEnv() func() { @@ -122,7 +150,7 @@ func PushPopCrushEnv() func() { return restore } -func (c *Config) configureProviders(env env.Env, resolver VariableResolver, knownProviders []catwalk.Provider) error { +func (c *Config) configureProviders(store *ConfigStore, env env.Env, resolver VariableResolver, knownProviders []catwalk.Provider) error { knownProviderNames := make(map[string]bool) restore := PushPopCrushEnv() defer restore() @@ -209,7 +237,7 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know switch { case p.ID == catwalk.InferenceProviderAnthropic && config.OAuthToken != nil: // Claude Code subscription is not supported anymore. Remove to show onboarding. - c.RemoveConfigField("providers.anthropic") + store.RemoveConfigField(ScopeGlobal, "providers.anthropic") c.Providers.Del(string(p.ID)) continue case p.ID == catwalk.InferenceProviderCopilot && config.OAuthToken != nil: @@ -340,7 +368,6 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know } func (c *Config) setDefaults(workingDir, dataDir string) { - c.workingDir = workingDir if c.Options == nil { c.Options = &Options{} } @@ -524,7 +551,8 @@ func (c *Config) defaultModelSelection(knownProviders []catwalk.Provider) (large return largeModel, smallModel, err } -func (c *Config) configureSelectedModels(knownProviders []catwalk.Provider) error { +func configureSelectedModels(store *ConfigStore, knownProviders []catwalk.Provider) error { + c := store.config defaultLarge, defaultSmall, err := c.defaultModelSelection(knownProviders) if err != nil { return fmt.Errorf("failed to select default models: %w", err) @@ -543,7 +571,7 @@ func (c *Config) configureSelectedModels(knownProviders []catwalk.Provider) erro if model == nil { large = defaultLarge // override the model type to large - err := c.UpdatePreferredModel(SelectedModelTypeLarge, large) + err := store.UpdatePreferredModel(ScopeGlobal, SelectedModelTypeLarge, large) if err != nil { return fmt.Errorf("failed to update preferred large model: %w", err) } @@ -587,7 +615,7 @@ func (c *Config) configureSelectedModels(knownProviders []catwalk.Provider) erro if model == nil { small = defaultSmall // override the model type to small - err := c.UpdatePreferredModel(SelectedModelTypeSmall, small) + err := store.UpdatePreferredModel(ScopeGlobal, SelectedModelTypeSmall, small) if err != nil { return fmt.Errorf("failed to update preferred small model: %w", err) } diff --git a/internal/config/load_test.go b/internal/config/load_test.go index 93d2245193463e2a6539e23aeb0e16ac14c0ccef..62b1eaa2437116b0a051fe10183689650d388472 100644 --- a/internal/config/load_test.go +++ b/internal/config/load_test.go @@ -36,6 +36,11 @@ func TestConfig_LoadFromBytes(t *testing.T) { require.Equal(t, "https://api.openai.com/v2", pc.BaseURL) } +// testStore wraps a Config in a minimal ConfigStore for testing. +func testStore(cfg *Config) *ConfigStore { + return &ConfigStore{config: cfg} +} + func TestConfig_setDefaults(t *testing.T) { cfg := &Config{} @@ -53,7 +58,6 @@ func TestConfig_setDefaults(t *testing.T) { for _, path := range defaultContextPaths { require.Contains(t, cfg.Options.ContextPaths, path) } - require.Equal(t, "/tmp", cfg.workingDir) } func TestConfig_configureProviders(t *testing.T) { @@ -74,7 +78,7 @@ func TestConfig_configureProviders(t *testing.T) { "OPENAI_API_KEY": "test-key", }) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, knownProviders) + err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders) require.NoError(t, err) require.Equal(t, 1, cfg.Providers.Len()) @@ -117,7 +121,7 @@ func TestConfig_configureProvidersWithOverride(t *testing.T) { "OPENAI_API_KEY": "test-key", }) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, knownProviders) + err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders) require.NoError(t, err) require.Equal(t, 1, cfg.Providers.Len()) @@ -159,7 +163,7 @@ func TestConfig_configureProvidersWithNewProvider(t *testing.T) { "OPENAI_API_KEY": "test-key", }) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, knownProviders) + err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders) require.NoError(t, err) // Should be to because of the env variable require.Equal(t, cfg.Providers.Len(), 2) @@ -195,7 +199,7 @@ func TestConfig_configureProvidersBedrockWithCredentials(t *testing.T) { "AWS_SECRET_ACCESS_KEY": "test-secret-key", }) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, knownProviders) + err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders) require.NoError(t, err) require.Equal(t, cfg.Providers.Len(), 1) @@ -221,7 +225,7 @@ func TestConfig_configureProvidersBedrockWithoutCredentials(t *testing.T) { cfg.setDefaults("/tmp", "") env := env.NewFromMap(map[string]string{}) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, knownProviders) + err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders) require.NoError(t, err) // Provider should not be configured without credentials require.Equal(t, cfg.Providers.Len(), 0) @@ -246,7 +250,7 @@ func TestConfig_configureProvidersBedrockWithoutUnsupportedModel(t *testing.T) { "AWS_SECRET_ACCESS_KEY": "test-secret-key", }) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, knownProviders) + err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders) require.Error(t, err) } @@ -269,7 +273,7 @@ func TestConfig_configureProvidersVertexAIWithCredentials(t *testing.T) { "VERTEXAI_LOCATION": "us-central1", }) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, knownProviders) + err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders) require.NoError(t, err) require.Equal(t, cfg.Providers.Len(), 1) @@ -301,7 +305,7 @@ func TestConfig_configureProvidersVertexAIWithoutCredentials(t *testing.T) { "GOOGLE_CLOUD_LOCATION": "us-central1", }) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, knownProviders) + err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders) require.NoError(t, err) // Provider should not be configured without proper credentials require.Equal(t, cfg.Providers.Len(), 0) @@ -326,7 +330,7 @@ func TestConfig_configureProvidersVertexAIMissingProject(t *testing.T) { "GOOGLE_CLOUD_LOCATION": "us-central1", }) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, knownProviders) + err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders) require.NoError(t, err) // Provider should not be configured without project require.Equal(t, cfg.Providers.Len(), 0) @@ -350,7 +354,7 @@ func TestConfig_configureProvidersSetProviderID(t *testing.T) { "OPENAI_API_KEY": "test-key", }) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, knownProviders) + err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders) require.NoError(t, err) require.Equal(t, cfg.Providers.Len(), 1) @@ -541,7 +545,7 @@ func TestConfig_configureProvidersWithDisabledProvider(t *testing.T) { "OPENAI_API_KEY": "test-key", }) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, knownProviders) + err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders) require.NoError(t, err) require.Equal(t, cfg.Providers.Len(), 1) @@ -569,7 +573,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { env := env.NewFromMap(map[string]string{}) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, []catwalk.Provider{}) + err := cfg.configureProviders(testStore(cfg), env, resolver, []catwalk.Provider{}) require.NoError(t, err) require.Equal(t, cfg.Providers.Len(), 1) @@ -592,7 +596,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { env := env.NewFromMap(map[string]string{}) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, []catwalk.Provider{}) + err := cfg.configureProviders(testStore(cfg), env, resolver, []catwalk.Provider{}) require.NoError(t, err) require.Equal(t, cfg.Providers.Len(), 0) @@ -614,7 +618,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { env := env.NewFromMap(map[string]string{}) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, []catwalk.Provider{}) + err := cfg.configureProviders(testStore(cfg), env, resolver, []catwalk.Provider{}) require.NoError(t, err) require.Equal(t, cfg.Providers.Len(), 0) @@ -639,7 +643,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { env := env.NewFromMap(map[string]string{}) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, []catwalk.Provider{}) + err := cfg.configureProviders(testStore(cfg), env, resolver, []catwalk.Provider{}) require.NoError(t, err) require.Equal(t, cfg.Providers.Len(), 0) @@ -664,7 +668,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { env := env.NewFromMap(map[string]string{}) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, []catwalk.Provider{}) + err := cfg.configureProviders(testStore(cfg), env, resolver, []catwalk.Provider{}) require.NoError(t, err) require.Equal(t, cfg.Providers.Len(), 1) @@ -692,7 +696,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { env := env.NewFromMap(map[string]string{}) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, []catwalk.Provider{}) + err := cfg.configureProviders(testStore(cfg), env, resolver, []catwalk.Provider{}) require.NoError(t, err) require.Equal(t, cfg.Providers.Len(), 1) @@ -722,7 +726,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { env := env.NewFromMap(map[string]string{}) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, []catwalk.Provider{}) + err := cfg.configureProviders(testStore(cfg), env, resolver, []catwalk.Provider{}) require.NoError(t, err) require.Equal(t, cfg.Providers.Len(), 0) @@ -757,7 +761,7 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) { "GOOGLE_GENAI_USE_VERTEXAI": "false", }) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, knownProviders) + err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders) require.NoError(t, err) require.Equal(t, cfg.Providers.Len(), 0) @@ -788,7 +792,7 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) { env := env.NewFromMap(map[string]string{}) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, knownProviders) + err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders) require.NoError(t, err) require.Equal(t, cfg.Providers.Len(), 0) @@ -819,7 +823,7 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) { env := env.NewFromMap(map[string]string{}) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, knownProviders) + err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders) require.NoError(t, err) require.Equal(t, cfg.Providers.Len(), 0) @@ -852,7 +856,7 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) { "OPENAI_API_KEY": "test-key", }) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, knownProviders) + err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders) require.NoError(t, err) require.Equal(t, cfg.Providers.Len(), 1) @@ -886,7 +890,7 @@ func TestConfig_defaultModelSelection(t *testing.T) { cfg.setDefaults("/tmp", "") env := env.NewFromMap(map[string]string{}) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, knownProviders) + err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders) require.NoError(t, err) large, small, err := cfg.defaultModelSelection(knownProviders) @@ -922,7 +926,7 @@ func TestConfig_defaultModelSelection(t *testing.T) { cfg.setDefaults("/tmp", "") env := env.NewFromMap(map[string]string{}) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, knownProviders) + err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders) require.NoError(t, err) _, _, err = cfg.defaultModelSelection(knownProviders) @@ -952,7 +956,7 @@ func TestConfig_defaultModelSelection(t *testing.T) { cfg.setDefaults("/tmp", "") env := env.NewFromMap(map[string]string{}) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, knownProviders) + err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders) require.NoError(t, err) _, _, err = cfg.defaultModelSelection(knownProviders) require.Error(t, err) @@ -995,7 +999,7 @@ func TestConfig_defaultModelSelection(t *testing.T) { cfg.setDefaults("/tmp", "") env := env.NewFromMap(map[string]string{}) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, knownProviders) + err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders) require.NoError(t, err) large, small, err := cfg.defaultModelSelection(knownProviders) require.NoError(t, err) @@ -1039,7 +1043,7 @@ func TestConfig_defaultModelSelection(t *testing.T) { cfg.setDefaults("/tmp", "") env := env.NewFromMap(map[string]string{}) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, knownProviders) + err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders) require.NoError(t, err) _, _, err = cfg.defaultModelSelection(knownProviders) require.Error(t, err) @@ -1081,7 +1085,7 @@ func TestConfig_defaultModelSelection(t *testing.T) { cfg.setDefaults("/tmp", "") env := env.NewFromMap(map[string]string{}) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, knownProviders) + err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders) require.NoError(t, err) large, small, err := cfg.defaultModelSelection(knownProviders) require.NoError(t, err) @@ -1126,7 +1130,7 @@ func TestConfig_configureProvidersDisableDefaultProviders(t *testing.T) { "OPENAI_API_KEY": "test-key", }) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, knownProviders) + err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders) require.ErrorContains(t, err, "no custom providers") // openai should NOT be present because it lacks base_url and models. @@ -1169,7 +1173,7 @@ func TestConfig_configureProvidersDisableDefaultProviders(t *testing.T) { "OPENAI_API_KEY": "test-key", }) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, knownProviders) + err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders) require.NoError(t, err) // Only fully specified provider should be present. @@ -1223,7 +1227,7 @@ func TestConfig_configureProvidersDisableDefaultProviders(t *testing.T) { "ANTHROPIC_API_KEY": "test-key", }) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, knownProviders) + err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders) require.NoError(t, err) // Both providers should be present. @@ -1251,7 +1255,7 @@ func TestConfig_configureProvidersDisableDefaultProviders(t *testing.T) { env := env.NewFromMap(map[string]string{}) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, []catwalk.Provider{}) + err := cfg.configureProviders(testStore(cfg), env, resolver, []catwalk.Provider{}) require.ErrorContains(t, err, "no custom providers") // Provider should be rejected for missing models. @@ -1275,7 +1279,7 @@ func TestConfig_configureProvidersDisableDefaultProviders(t *testing.T) { env := env.NewFromMap(map[string]string{}) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, []catwalk.Provider{}) + err := cfg.configureProviders(testStore(cfg), env, resolver, []catwalk.Provider{}) require.ErrorContains(t, err, "no custom providers") // Provider should be rejected for missing base_url. @@ -1340,10 +1344,10 @@ func TestConfig_configureSelectedModels(t *testing.T) { cfg.setDefaults("/tmp", "") env := env.NewFromMap(map[string]string{}) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, knownProviders) + err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders) require.NoError(t, err) - err = cfg.configureSelectedModels(knownProviders) + err = configureSelectedModels(testStore(cfg), knownProviders) require.NoError(t, err) large := cfg.Models[SelectedModelTypeLarge] small := cfg.Models[SelectedModelTypeSmall] @@ -1402,10 +1406,10 @@ func TestConfig_configureSelectedModels(t *testing.T) { cfg.setDefaults("/tmp", "") env := env.NewFromMap(map[string]string{}) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, knownProviders) + err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders) require.NoError(t, err) - err = cfg.configureSelectedModels(knownProviders) + err = configureSelectedModels(testStore(cfg), knownProviders) require.NoError(t, err) large := cfg.Models[SelectedModelTypeLarge] small := cfg.Models[SelectedModelTypeSmall] @@ -1447,10 +1451,10 @@ func TestConfig_configureSelectedModels(t *testing.T) { cfg.setDefaults("/tmp", "") env := env.NewFromMap(map[string]string{}) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, knownProviders) + err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders) require.NoError(t, err) - err = cfg.configureSelectedModels(knownProviders) + err = configureSelectedModels(testStore(cfg), knownProviders) require.NoError(t, err) large := cfg.Models[SelectedModelTypeLarge] require.Equal(t, "large-model", large.Model) diff --git a/internal/config/recent_models_test.go b/internal/config/recent_models_test.go index 739ddc0031a65cab261723772c3f38658dcd1561..7c46d5d5202927932ed154a4da8b0719ce9e114e 100644 --- a/internal/config/recent_models_test.go +++ b/internal/config/recent_models_test.go @@ -31,15 +31,23 @@ func readRecentModels(t *testing.T, path string) map[string]any { return rm } +// testStoreWithPath creates a ConfigStore backed by a Config for recent model tests. +func testStoreWithPath(cfg *Config, dir string) *ConfigStore { + return &ConfigStore{ + config: cfg, + globalDataPath: filepath.Join(dir, "config.json"), + } +} + func TestRecordRecentModel_AddsAndPersists(t *testing.T) { t.Parallel() dir := t.TempDir() cfg := &Config{} cfg.setDefaults(dir, "") - cfg.dataConfigDir = filepath.Join(dir, "config.json") + store := testStoreWithPath(cfg, dir) - err := cfg.recordRecentModel(SelectedModelTypeLarge, SelectedModel{Provider: "openai", Model: "gpt-4o"}) + err := store.recordRecentModel(ScopeGlobal, SelectedModelTypeLarge, SelectedModel{Provider: "openai", Model: "gpt-4o"}) require.NoError(t, err) // in-memory state @@ -48,7 +56,7 @@ func TestRecordRecentModel_AddsAndPersists(t *testing.T) { require.Equal(t, "gpt-4o", cfg.RecentModels[SelectedModelTypeLarge][0].Model) // persisted state - rm := readRecentModels(t, cfg.dataConfigDir) + rm := readRecentModels(t, store.globalDataPath) large, ok := rm[string(SelectedModelTypeLarge)].([]any) require.True(t, ok) require.Len(t, large, 1) @@ -64,13 +72,13 @@ func TestRecordRecentModel_DedupeAndMoveToFront(t *testing.T) { dir := t.TempDir() cfg := &Config{} cfg.setDefaults(dir, "") - cfg.dataConfigDir = filepath.Join(dir, "config.json") + store := testStoreWithPath(cfg, dir) // Add two entries - require.NoError(t, cfg.recordRecentModel(SelectedModelTypeLarge, SelectedModel{Provider: "openai", Model: "gpt-4o"})) - require.NoError(t, cfg.recordRecentModel(SelectedModelTypeLarge, SelectedModel{Provider: "anthropic", Model: "claude"})) + require.NoError(t, store.recordRecentModel(ScopeGlobal, SelectedModelTypeLarge, SelectedModel{Provider: "openai", Model: "gpt-4o"})) + require.NoError(t, store.recordRecentModel(ScopeGlobal, SelectedModelTypeLarge, SelectedModel{Provider: "anthropic", Model: "claude"})) // Re-add first; should move to front and not duplicate - require.NoError(t, cfg.recordRecentModel(SelectedModelTypeLarge, SelectedModel{Provider: "openai", Model: "gpt-4o"})) + require.NoError(t, store.recordRecentModel(ScopeGlobal, SelectedModelTypeLarge, SelectedModel{Provider: "openai", Model: "gpt-4o"})) got := cfg.RecentModels[SelectedModelTypeLarge] require.Len(t, got, 2) @@ -84,7 +92,7 @@ func TestRecordRecentModel_TrimsToMax(t *testing.T) { dir := t.TempDir() cfg := &Config{} cfg.setDefaults(dir, "") - cfg.dataConfigDir = filepath.Join(dir, "config.json") + store := testStoreWithPath(cfg, dir) // Insert 6 unique models; max is 5 entries := []SelectedModel{ @@ -96,7 +104,7 @@ func TestRecordRecentModel_TrimsToMax(t *testing.T) { {Provider: "p6", Model: "m6"}, } for _, e := range entries { - require.NoError(t, cfg.recordRecentModel(SelectedModelTypeLarge, e)) + require.NoError(t, store.recordRecentModel(ScopeGlobal, SelectedModelTypeLarge, e)) } // in-memory state @@ -110,7 +118,7 @@ func TestRecordRecentModel_TrimsToMax(t *testing.T) { require.Equal(t, SelectedModel{Provider: "p2", Model: "m2"}, got[4]) // persisted state: verify trimmed to 5 and newest-first order - rm := readRecentModels(t, cfg.dataConfigDir) + rm := readRecentModels(t, store.globalDataPath) large, ok := rm[string(SelectedModelTypeLarge)].([]any) require.True(t, ok) require.Len(t, large, 5) @@ -129,12 +137,12 @@ func TestRecordRecentModel_SkipsEmptyValues(t *testing.T) { dir := t.TempDir() cfg := &Config{} cfg.setDefaults(dir, "") - cfg.dataConfigDir = filepath.Join(dir, "config.json") + store := testStoreWithPath(cfg, dir) // Missing provider - require.NoError(t, cfg.recordRecentModel(SelectedModelTypeLarge, SelectedModel{Provider: "", Model: "m"})) + require.NoError(t, store.recordRecentModel(ScopeGlobal, SelectedModelTypeLarge, SelectedModel{Provider: "", Model: "m"})) // Missing model - require.NoError(t, cfg.recordRecentModel(SelectedModelTypeLarge, SelectedModel{Provider: "p", Model: ""})) + require.NoError(t, store.recordRecentModel(ScopeGlobal, SelectedModelTypeLarge, SelectedModel{Provider: "p", Model: ""})) _, ok := cfg.RecentModels[SelectedModelTypeLarge] // Map may be initialized, but should have no entries @@ -142,8 +150,8 @@ func TestRecordRecentModel_SkipsEmptyValues(t *testing.T) { require.Len(t, cfg.RecentModels[SelectedModelTypeLarge], 0) } // No file should be written (stat via fs.FS) - baseDir := filepath.Dir(cfg.dataConfigDir) - fileName := filepath.Base(cfg.dataConfigDir) + baseDir := filepath.Dir(store.globalDataPath) + fileName := filepath.Base(store.globalDataPath) _, err := fs.Stat(os.DirFS(baseDir), fileName) require.True(t, os.IsNotExist(err)) } @@ -154,13 +162,13 @@ func TestRecordRecentModel_NoPersistOnNoop(t *testing.T) { dir := t.TempDir() cfg := &Config{} cfg.setDefaults(dir, "") - cfg.dataConfigDir = filepath.Join(dir, "config.json") + store := testStoreWithPath(cfg, dir) entry := SelectedModel{Provider: "openai", Model: "gpt-4o"} - require.NoError(t, cfg.recordRecentModel(SelectedModelTypeLarge, entry)) + require.NoError(t, store.recordRecentModel(ScopeGlobal, SelectedModelTypeLarge, entry)) - baseDir := filepath.Dir(cfg.dataConfigDir) - fileName := filepath.Base(cfg.dataConfigDir) + baseDir := filepath.Dir(store.globalDataPath) + fileName := filepath.Base(store.globalDataPath) before, err := fs.ReadFile(os.DirFS(baseDir), fileName) require.NoError(t, err) @@ -170,7 +178,7 @@ func TestRecordRecentModel_NoPersistOnNoop(t *testing.T) { beforeMod := stBefore.ModTime() // Re-record same entry should be a no-op (no write) - require.NoError(t, cfg.recordRecentModel(SelectedModelTypeLarge, entry)) + require.NoError(t, store.recordRecentModel(ScopeGlobal, SelectedModelTypeLarge, entry)) after, err := fs.ReadFile(os.DirFS(baseDir), fileName) require.NoError(t, err) @@ -188,17 +196,17 @@ func TestUpdatePreferredModel_UpdatesRecents(t *testing.T) { dir := t.TempDir() cfg := &Config{} cfg.setDefaults(dir, "") - cfg.dataConfigDir = filepath.Join(dir, "config.json") + store := testStoreWithPath(cfg, dir) sel := SelectedModel{Provider: "openai", Model: "gpt-4o"} - require.NoError(t, cfg.UpdatePreferredModel(SelectedModelTypeSmall, sel)) + require.NoError(t, store.UpdatePreferredModel(ScopeGlobal, SelectedModelTypeSmall, sel)) // in-memory require.Equal(t, sel, cfg.Models[SelectedModelTypeSmall]) require.Len(t, cfg.RecentModels[SelectedModelTypeSmall], 1) // persisted (read via fs.FS) - rm := readRecentModels(t, cfg.dataConfigDir) + rm := readRecentModels(t, store.globalDataPath) small, ok := rm[string(SelectedModelTypeSmall)].([]any) require.True(t, ok) require.Len(t, small, 1) @@ -210,14 +218,14 @@ func TestRecordRecentModel_TypeIsolation(t *testing.T) { dir := t.TempDir() cfg := &Config{} cfg.setDefaults(dir, "") - cfg.dataConfigDir = filepath.Join(dir, "config.json") + store := testStoreWithPath(cfg, dir) // Add models to both large and small types largeModel := SelectedModel{Provider: "openai", Model: "gpt-4o"} smallModel := SelectedModel{Provider: "anthropic", Model: "claude"} - require.NoError(t, cfg.recordRecentModel(SelectedModelTypeLarge, largeModel)) - require.NoError(t, cfg.recordRecentModel(SelectedModelTypeSmall, smallModel)) + require.NoError(t, store.recordRecentModel(ScopeGlobal, SelectedModelTypeLarge, largeModel)) + require.NoError(t, store.recordRecentModel(ScopeGlobal, SelectedModelTypeSmall, smallModel)) // in-memory: verify types maintain separate histories require.Len(t, cfg.RecentModels[SelectedModelTypeLarge], 1) @@ -227,14 +235,14 @@ func TestRecordRecentModel_TypeIsolation(t *testing.T) { // Add another to large, verify small unchanged anotherLarge := SelectedModel{Provider: "google", Model: "gemini"} - require.NoError(t, cfg.recordRecentModel(SelectedModelTypeLarge, anotherLarge)) + require.NoError(t, store.recordRecentModel(ScopeGlobal, SelectedModelTypeLarge, anotherLarge)) require.Len(t, cfg.RecentModels[SelectedModelTypeLarge], 2) require.Len(t, cfg.RecentModels[SelectedModelTypeSmall], 1) require.Equal(t, smallModel, cfg.RecentModels[SelectedModelTypeSmall][0]) // persisted state: verify both types exist with correct lengths and contents - rm := readRecentModels(t, cfg.dataConfigDir) + rm := readRecentModels(t, store.globalDataPath) large, ok := rm[string(SelectedModelTypeLarge)].([]any) require.True(t, ok) diff --git a/internal/config/scope.go b/internal/config/scope.go new file mode 100644 index 0000000000000000000000000000000000000000..971ce32c3ed662dd0d0627c4f1c858372f3b4514 --- /dev/null +++ b/internal/config/scope.go @@ -0,0 +1,11 @@ +package config + +// Scope determines which config file is targeted for read/write operations. +type Scope int + +const ( + // ScopeGlobal targets the global data config (~/.local/share/crush/crush.json). + ScopeGlobal Scope = iota + // ScopeWorkspace targets the workspace config (.crush/crush.json). + ScopeWorkspace +) diff --git a/internal/config/store.go b/internal/config/store.go new file mode 100644 index 0000000000000000000000000000000000000000..4dfe6130bc23007ec3df12e5a88cb53bc3ad5a2d --- /dev/null +++ b/internal/config/store.go @@ -0,0 +1,336 @@ +package config + +import ( + "cmp" + "context" + "fmt" + "log/slog" + "os" + "path/filepath" + "slices" + + "charm.land/catwalk/pkg/catwalk" + hyperp "github.com/charmbracelet/crush/internal/agent/hyper" + "github.com/charmbracelet/crush/internal/oauth" + "github.com/charmbracelet/crush/internal/oauth/copilot" + "github.com/charmbracelet/crush/internal/oauth/hyper" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" +) + +// ConfigStore is the single entry point for all config access. It owns the +// pure-data Config, runtime state (working directory, resolver, known +// providers), and persistence to both global and workspace config files. +type ConfigStore struct { + config *Config + workingDir string + resolver VariableResolver + globalDataPath string // ~/.local/share/crush/crush.json + workspacePath string // .crush/crush.json + knownProviders []catwalk.Provider +} + +// Config returns the pure-data config struct (read-only after load). +func (s *ConfigStore) Config() *Config { + return s.config +} + +// WorkingDir returns the current working directory. +func (s *ConfigStore) WorkingDir() string { + return s.workingDir +} + +// Resolver returns the variable resolver. +func (s *ConfigStore) Resolver() VariableResolver { + return s.resolver +} + +// Resolve resolves a variable reference using the configured resolver. +func (s *ConfigStore) Resolve(key string) (string, error) { + if s.resolver == nil { + return "", fmt.Errorf("no variable resolver configured") + } + return s.resolver.ResolveValue(key) +} + +// KnownProviders returns the list of known providers. +func (s *ConfigStore) KnownProviders() []catwalk.Provider { + return s.knownProviders +} + +// SetupAgents configures the coder and task agents on the config. +func (s *ConfigStore) SetupAgents() { + s.config.SetupAgents() +} + +// configPath returns the file path for the given scope. +func (s *ConfigStore) configPath(scope Scope) string { + switch scope { + case ScopeWorkspace: + return s.workspacePath + default: + return s.globalDataPath + } +} + +// HasConfigField checks whether a key exists in the config file for the given +// scope. +func (s *ConfigStore) HasConfigField(scope Scope, key string) bool { + data, err := os.ReadFile(s.configPath(scope)) + if err != nil { + return false + } + return gjson.Get(string(data), key).Exists() +} + +// SetConfigField sets a key/value pair in the config file for the given scope. +func (s *ConfigStore) SetConfigField(scope Scope, key string, value any) error { + path := s.configPath(scope) + data, err := os.ReadFile(path) + if err != nil { + if os.IsNotExist(err) { + data = []byte("{}") + } else { + return fmt.Errorf("failed to read config file: %w", err) + } + } + + newValue, err := sjson.Set(string(data), key, value) + if err != nil { + return fmt.Errorf("failed to set config field %s: %w", key, err) + } + if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { + return fmt.Errorf("failed to create config directory %q: %w", path, err) + } + if err := os.WriteFile(path, []byte(newValue), 0o600); err != nil { + return fmt.Errorf("failed to write config file: %w", err) + } + return nil +} + +// RemoveConfigField removes a key from the config file for the given scope. +func (s *ConfigStore) RemoveConfigField(scope Scope, key string) error { + path := s.configPath(scope) + data, err := os.ReadFile(path) + if err != nil { + return fmt.Errorf("failed to read config file: %w", err) + } + + newValue, err := sjson.Delete(string(data), key) + if err != nil { + return fmt.Errorf("failed to delete config field %s: %w", key, err) + } + if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { + return fmt.Errorf("failed to create config directory %q: %w", path, err) + } + if err := os.WriteFile(path, []byte(newValue), 0o600); err != nil { + return fmt.Errorf("failed to write config file: %w", err) + } + return nil +} + +// UpdatePreferredModel updates the preferred model for the given type and +// persists it to the config file at the given scope. +func (s *ConfigStore) UpdatePreferredModel(scope Scope, modelType SelectedModelType, model SelectedModel) error { + s.config.Models[modelType] = model + if err := s.SetConfigField(scope, fmt.Sprintf("models.%s", modelType), model); err != nil { + return fmt.Errorf("failed to update preferred model: %w", err) + } + if err := s.recordRecentModel(scope, modelType, model); err != nil { + return err + } + return nil +} + +// SetCompactMode sets the compact mode setting and persists it. +func (s *ConfigStore) SetCompactMode(scope Scope, enabled bool) error { + if s.config.Options == nil { + s.config.Options = &Options{} + } + s.config.Options.TUI.CompactMode = enabled + return s.SetConfigField(scope, "options.tui.compact_mode", enabled) +} + +// SetProviderAPIKey sets the API key for a provider and persists it. +func (s *ConfigStore) SetProviderAPIKey(scope Scope, providerID string, apiKey any) error { + var providerConfig ProviderConfig + var exists bool + var setKeyOrToken func() + + switch v := apiKey.(type) { + case string: + if err := s.SetConfigField(scope, fmt.Sprintf("providers.%s.api_key", providerID), v); err != nil { + return fmt.Errorf("failed to save api key to config file: %w", err) + } + setKeyOrToken = func() { providerConfig.APIKey = v } + case *oauth.Token: + if err := cmp.Or( + s.SetConfigField(scope, fmt.Sprintf("providers.%s.api_key", providerID), v.AccessToken), + s.SetConfigField(scope, fmt.Sprintf("providers.%s.oauth", providerID), v), + ); err != nil { + return err + } + setKeyOrToken = func() { + providerConfig.APIKey = v.AccessToken + providerConfig.OAuthToken = v + switch providerID { + case string(catwalk.InferenceProviderCopilot): + providerConfig.SetupGitHubCopilot() + } + } + } + + providerConfig, exists = s.config.Providers.Get(providerID) + if exists { + setKeyOrToken() + s.config.Providers.Set(providerID, providerConfig) + return nil + } + + var foundProvider *catwalk.Provider + for _, p := range s.knownProviders { + if string(p.ID) == providerID { + foundProvider = &p + break + } + } + + if foundProvider != nil { + providerConfig = ProviderConfig{ + ID: providerID, + Name: foundProvider.Name, + BaseURL: foundProvider.APIEndpoint, + Type: foundProvider.Type, + Disable: false, + ExtraHeaders: make(map[string]string), + ExtraParams: make(map[string]string), + Models: foundProvider.Models, + } + setKeyOrToken() + } else { + return fmt.Errorf("provider with ID %s not found in known providers", providerID) + } + s.config.Providers.Set(providerID, providerConfig) + return nil +} + +// RefreshOAuthToken refreshes the OAuth token for the given provider. +func (s *ConfigStore) RefreshOAuthToken(ctx context.Context, scope Scope, providerID string) error { + providerConfig, exists := s.config.Providers.Get(providerID) + if !exists { + return fmt.Errorf("provider %s not found", providerID) + } + + if providerConfig.OAuthToken == nil { + return fmt.Errorf("provider %s does not have an OAuth token", providerID) + } + + var newToken *oauth.Token + var refreshErr error + switch providerID { + case string(catwalk.InferenceProviderCopilot): + newToken, refreshErr = copilot.RefreshToken(ctx, providerConfig.OAuthToken.RefreshToken) + case hyperp.Name: + newToken, refreshErr = hyper.ExchangeToken(ctx, providerConfig.OAuthToken.RefreshToken) + default: + return fmt.Errorf("OAuth refresh not supported for provider %s", providerID) + } + if refreshErr != nil { + return fmt.Errorf("failed to refresh OAuth token for provider %s: %w", providerID, refreshErr) + } + + slog.Info("Successfully refreshed OAuth token", "provider", providerID) + providerConfig.OAuthToken = newToken + providerConfig.APIKey = newToken.AccessToken + + switch providerID { + case string(catwalk.InferenceProviderCopilot): + providerConfig.SetupGitHubCopilot() + } + + s.config.Providers.Set(providerID, providerConfig) + + if err := cmp.Or( + s.SetConfigField(scope, fmt.Sprintf("providers.%s.api_key", providerID), newToken.AccessToken), + s.SetConfigField(scope, fmt.Sprintf("providers.%s.oauth", providerID), newToken), + ); err != nil { + return fmt.Errorf("failed to persist refreshed token: %w", err) + } + + return nil +} + +// recordRecentModel records a model in the recent models list. +func (s *ConfigStore) recordRecentModel(scope Scope, modelType SelectedModelType, model SelectedModel) error { + if model.Provider == "" || model.Model == "" { + return nil + } + + if s.config.RecentModels == nil { + s.config.RecentModels = make(map[SelectedModelType][]SelectedModel) + } + + eq := func(a, b SelectedModel) bool { + return a.Provider == b.Provider && a.Model == b.Model + } + + entry := SelectedModel{ + Provider: model.Provider, + Model: model.Model, + } + + current := s.config.RecentModels[modelType] + withoutCurrent := slices.DeleteFunc(slices.Clone(current), func(existing SelectedModel) bool { + return eq(existing, entry) + }) + + updated := append([]SelectedModel{entry}, withoutCurrent...) + if len(updated) > maxRecentModelsPerType { + updated = updated[:maxRecentModelsPerType] + } + + if slices.EqualFunc(current, updated, eq) { + return nil + } + + s.config.RecentModels[modelType] = updated + + if err := s.SetConfigField(scope, fmt.Sprintf("recent_models.%s", modelType), updated); err != nil { + return fmt.Errorf("failed to persist recent models: %w", err) + } + + return nil +} + +// ImportCopilot attempts to import a GitHub Copilot token from disk. +func (s *ConfigStore) ImportCopilot() (*oauth.Token, bool) { + if s.HasConfigField(ScopeGlobal, "providers.copilot.api_key") || s.HasConfigField(ScopeGlobal, "providers.copilot.oauth") { + return nil, false + } + + diskToken, hasDiskToken := copilot.RefreshTokenFromDisk() + if !hasDiskToken { + return nil, false + } + + slog.Info("Found existing GitHub Copilot token on disk. Authenticating...") + token, err := copilot.RefreshToken(context.TODO(), diskToken) + if err != nil { + slog.Error("Unable to import GitHub Copilot token", "error", err) + return nil, false + } + + if err := s.SetProviderAPIKey(ScopeGlobal, string(catwalk.InferenceProviderCopilot), token); err != nil { + return token, false + } + + if err := cmp.Or( + s.SetConfigField(ScopeGlobal, "providers.copilot.api_key", token.AccessToken), + s.SetConfigField(ScopeGlobal, "providers.copilot.oauth", token), + ); err != nil { + slog.Error("Unable to save GitHub Copilot token to disk", "error", err) + } + + slog.Info("GitHub Copilot successfully imported") + return token, true +} diff --git a/internal/lsp/manager.go b/internal/lsp/manager.go index 13a78cef2a471a71c1e741e32e08e8d7edcb7484..b564c0e602c0234462a32cfaae67c8f8179551c4 100644 --- a/internal/lsp/manager.go +++ b/internal/lsp/manager.go @@ -26,18 +26,18 @@ var unavailable = csync.NewMap[string, struct{}]() // Manager handles lazy initialization of LSP clients based on file types. type Manager struct { clients *csync.Map[string, *Client] - cfg *config.Config + cfg *config.ConfigStore manager *powernapconfig.Manager callback func(name string, client *Client) } // NewManager creates a new LSP manager service. -func NewManager(cfg *config.Config) *Manager { +func NewManager(cfg *config.ConfigStore) *Manager { manager := powernapconfig.NewManager() manager.LoadDefaults() // Merge user-configured LSPs into the manager. - for name, clientConfig := range cfg.LSP { + for name, clientConfig := range cfg.Config().LSP { if clientConfig.Disabled { slog.Debug("LSP disabled by user config", "name", name) manager.RemoveServer(name) @@ -194,7 +194,7 @@ func (s *Manager) startServer(ctx context.Context, name, filepath string, server cfg, s.cfg.Resolver(), s.cfg.WorkingDir(), - s.cfg.Options.DebugLSP, + s.cfg.Config().Options.DebugLSP, ) if err != nil { slog.Error("Failed to create LSP client", "name", name, "error", err) @@ -244,7 +244,7 @@ func (s *Manager) startServer(ctx context.Context, name, filepath string, server } func (s *Manager) isUserConfigured(name string) bool { - cfg, ok := s.cfg.LSP[name] + cfg, ok := s.cfg.Config().LSP[name] return ok && !cfg.Disabled } @@ -258,7 +258,7 @@ func (s *Manager) buildConfig(name string, server *powernapconfig.ServerConfig) InitOptions: server.InitOptions, Options: server.Settings, } - if userCfg, ok := s.cfg.LSP[name]; ok { + if userCfg, ok := s.cfg.Config().LSP[name]; ok { cfg.Timeout = userCfg.Timeout } return cfg diff --git a/internal/ui/common/common.go b/internal/ui/common/common.go index 6e7c632474389aa5455295e4132818941bc18244..143b20305464da33d2f350a36176bab0e45b85aa 100644 --- a/internal/ui/common/common.go +++ b/internal/ui/common/common.go @@ -26,11 +26,16 @@ type Common struct { Styles *styles.Styles } -// Config returns the configuration associated with this [Common] instance. +// Config returns the pure-data configuration associated with this [Common] instance. func (c *Common) Config() *config.Config { return c.App.Config() } +// Store returns the config store associated with this [Common] instance. +func (c *Common) Store() *config.ConfigStore { + return c.App.Store() +} + // DefaultCommon returns the default common UI configurations. func DefaultCommon(app *app.App) *Common { s := styles.DefaultStyles() diff --git a/internal/ui/dialog/api_key_input.go b/internal/ui/dialog/api_key_input.go index 9677763b2f4f2436376f5bf16ab58aed79140c68..cc37d742903d5a80bbcffcf1ff24fb24596dfccd 100644 --- a/internal/ui/dialog/api_key_input.go +++ b/internal/ui/dialog/api_key_input.go @@ -296,7 +296,7 @@ func (m *APIKeyInput) verifyAPIKey() tea.Msg { Type: m.provider.Type, BaseURL: m.provider.APIEndpoint, } - err := providerConfig.TestConnection(m.com.Config().Resolver()) + err := providerConfig.TestConnection(m.com.Store().Resolver()) // intentionally wait for at least 750ms to make sure the user sees the spinner elapsed := time.Since(start) @@ -312,9 +312,9 @@ func (m *APIKeyInput) verifyAPIKey() tea.Msg { } func (m *APIKeyInput) saveKeyAndContinue() Action { - cfg := m.com.Config() + store := m.com.Store() - err := cfg.SetProviderAPIKey(string(m.provider.ID), m.input.Value()) + err := store.SetProviderAPIKey(config.ScopeGlobal, string(m.provider.ID), m.input.Value()) if err != nil { return ActionCmd{util.ReportError(fmt.Errorf("failed to save API key: %w", err))} } diff --git a/internal/ui/dialog/filepicker.go b/internal/ui/dialog/filepicker.go index 4b0b844e4ed869a4347af10e9d0b1b3c70a7d2f0..78f82a05f7e2e0db7a9bb561fb1b6248d8045513 100644 --- a/internal/ui/dialog/filepicker.go +++ b/internal/ui/dialog/filepicker.go @@ -123,7 +123,7 @@ func (f *FilePicker) SetImageCapabilities(caps *common.Capabilities) { // WorkingDir returns the current working directory of the [FilePicker]. func (f *FilePicker) WorkingDir() string { - wd := f.com.Config().WorkingDir() + wd := f.com.Store().WorkingDir() if len(wd) > 0 { return wd } diff --git a/internal/ui/dialog/models.go b/internal/ui/dialog/models.go index 977f04a61e98f79adb9bb35777fac905508f47d5..434f699e91b4c227c4e54f6ff553affff76a1c43 100644 --- a/internal/ui/dialog/models.go +++ b/internal/ui/dialog/models.go @@ -490,7 +490,7 @@ func (m *Models) setProviderItems() error { if len(validRecentItems) != len(recentItems) { // FIXME: Does this need to be here? Is it mutating the config during a read? - if err := cfg.SetConfigField(fmt.Sprintf("recent_models.%s", selectedType), validRecentItems); err != nil { + if err := m.com.Store().SetConfigField(config.ScopeGlobal, fmt.Sprintf("recent_models.%s", selectedType), validRecentItems); err != nil { return fmt.Errorf("failed to update recent models: %w", err) } } diff --git a/internal/ui/dialog/oauth.go b/internal/ui/dialog/oauth.go index 93d5fe052db11d036d29d7790810807d5630bb57..2803070381e65bd0380a8ddab5f256481c117c15 100644 --- a/internal/ui/dialog/oauth.go +++ b/internal/ui/dialog/oauth.go @@ -373,9 +373,9 @@ func (d *OAuth) copyCodeAndOpenURL() tea.Cmd { } func (m *OAuth) saveKeyAndContinue() Action { - cfg := m.com.Config() + store := m.com.Store() - err := cfg.SetProviderAPIKey(string(m.provider.ID), m.token) + err := store.SetProviderAPIKey(config.ScopeGlobal, string(m.provider.ID), m.token) if err != nil { return ActionCmd{util.ReportError(fmt.Errorf("failed to save API key: %w", err))} } diff --git a/internal/ui/model/header.go b/internal/ui/model/header.go index 24254a0f69e5803e4bcbe89274f21db5b04ef541..06bb4ff92981b28625efb11683081e29fc55a21e 100644 --- a/internal/ui/model/header.go +++ b/internal/ui/model/header.go @@ -143,8 +143,7 @@ func renderHeaderDetails( metadata = dot + metadata const dirTrimLimit = 4 - cfg := com.Config() - cwd := fsext.DirTrim(fsext.PrettyPath(cfg.WorkingDir()), dirTrimLimit) + cwd := fsext.DirTrim(fsext.PrettyPath(com.Store().WorkingDir()), dirTrimLimit) cwd = t.Header.WorkingDir.Render(cwd) result := cwd + metadata diff --git a/internal/ui/model/landing.go b/internal/ui/model/landing.go index 45d376ff5ddc691b978e438ddef04a702af100f9..72c2671ccd297f4bade087f6b2cb960f6c6a92a9 100644 --- a/internal/ui/model/landing.go +++ b/internal/ui/model/landing.go @@ -22,7 +22,7 @@ func (m *UI) selectedLargeModel() *agent.Model { func (m *UI) landingView() string { t := m.com.Styles width := m.layout.main.Dx() - cwd := common.PrettyPath(t, m.com.Config().WorkingDir(), width) + cwd := common.PrettyPath(t, m.com.Store().WorkingDir(), width) parts := []string{ cwd, diff --git a/internal/ui/model/onboarding.go b/internal/ui/model/onboarding.go index 075067d75333fc539152f0041b4e5a3c2eed1c5e..5bba37ea8599944df77602014aa8c8d61dd73e80 100644 --- a/internal/ui/model/onboarding.go +++ b/internal/ui/model/onboarding.go @@ -19,7 +19,7 @@ import ( // markProjectInitialized marks the current project as initialized in the config. func (m *UI) markProjectInitialized() tea.Msg { // TODO: handle error so we show it in the tui footer - err := config.MarkProjectInitialized(m.com.Config()) + err := config.MarkProjectInitialized(m.com.Store()) if err != nil { slog.Error(err.Error()) } @@ -52,10 +52,10 @@ func (m *UI) initializeProject() tea.Cmd { if cmd := m.newSession(); cmd != nil { cmds = append(cmds, cmd) } - cfg := m.com.Config() + cfg := m.com.Store() initialize := func() tea.Msg { - initPrompt, err := agent.InitializePrompt(*cfg) + initPrompt, err := agent.InitializePrompt(cfg) if err != nil { return util.InfoMsg{Type: util.InfoTypeError, Msg: err.Error()} } @@ -77,10 +77,9 @@ func (m *UI) skipInitializeProject() tea.Cmd { // initializeView renders the project initialization prompt with Yes/No buttons. func (m *UI) initializeView() string { - cfg := m.com.Config() s := m.com.Styles.Initialize - cwd := home.Short(cfg.WorkingDir()) - initFile := cfg.Options.InitializeAs + cwd := home.Short(m.com.Store().WorkingDir()) + initFile := m.com.Config().Options.InitializeAs header := s.Header.Render("Would you like to initialize this project?") path := s.Accent.PaddingLeft(2).Render(cwd) diff --git a/internal/ui/model/sidebar.go b/internal/ui/model/sidebar.go index 88113a593034b09ed8d2859bc7628a103f5728b1..8849d86a8e1c8bda02092e3f165e85b8e32a8b1d 100644 --- a/internal/ui/model/sidebar.go +++ b/internal/ui/model/sidebar.go @@ -112,7 +112,7 @@ func (m *UI) drawSidebar(scr uv.Screen, area uv.Rectangle) { height := area.Dy() title := t.Muted.Width(width).MaxHeight(2).Render(m.session.Title) - cwd := common.PrettyPath(t, m.com.Config().WorkingDir(), width) + cwd := common.PrettyPath(t, m.com.Store().WorkingDir(), width) sidebarLogo := m.sidebarLogo if height < logoHeightBreakpoint { sidebarLogo = logo.SmallRender(m.com.Styles, width) @@ -138,7 +138,7 @@ func (m *UI) drawSidebar(scr uv.Screen, area uv.Rectangle) { lspSection := m.lspInfo(width, maxLSPs, true) mcpSection := m.mcpInfo(width, maxMCPs, true) - filesSection := m.filesInfo(m.com.Config().WorkingDir(), width, maxFiles, true) + filesSection := m.filesInfo(m.com.Store().WorkingDir(), width, maxFiles, true) uv.NewStyledString( lipgloss.NewStyle(). diff --git a/internal/ui/model/ui.go b/internal/ui/model/ui.go index 89b3b37608500f1a02eea98d4ebfabeba262bcd1..66d57321824833e91b78539357d55013e4322e87 100644 --- a/internal/ui/model/ui.go +++ b/internal/ui/model/ui.go @@ -317,7 +317,7 @@ func New(com *common.Common) *UI { desiredFocus := uiFocusEditor if !com.Config().IsConfigured() { desiredState = uiOnboarding - } else if n, _ := config.ProjectNeedsInitialization(com.Config()); n { + } else if n, _ := config.ProjectNeedsInitialization(com.Store()); n { desiredState = uiInitialize } @@ -579,7 +579,7 @@ func (m *UI) Update(msg tea.Msg) (tea.Model, tea.Cmd) { case mcp.EventPromptsListChanged: return m, handleMCPPromptsEvent(msg.Payload.Name) case mcp.EventToolsListChanged: - return m, handleMCPToolsEvent(m.com.Config(), msg.Payload.Name) + return m, handleMCPToolsEvent(m.com.Store(), msg.Payload.Name) case mcp.EventResourcesListChanged: return m, handleMCPResourcesEvent(msg.Payload.Name) } @@ -1301,7 +1301,7 @@ func (m *UI) handleDialogMsg(msg tea.Msg) tea.Cmd { currentModel := cfg.Models[agentCfg.Model] currentModel.Think = !currentModel.Think - if err := cfg.UpdatePreferredModel(agentCfg.Model, currentModel); err != nil { + if err := m.com.Store().UpdatePreferredModel(config.ScopeGlobal, agentCfg.Model, currentModel); err != nil { return util.ReportError(err)() } m.com.App.UpdateAgentModel(context.TODO()) @@ -1342,7 +1342,7 @@ func (m *UI) handleDialogMsg(msg tea.Msg) tea.Cmd { // Attempt to import GitHub Copilot tokens from VSCode if available. if isCopilot && !isConfigured() && !msg.ReAuthenticate { - m.com.Config().ImportCopilot() + m.com.Store().ImportCopilot() } if !isConfigured() || msg.ReAuthenticate { @@ -1353,12 +1353,12 @@ func (m *UI) handleDialogMsg(msg tea.Msg) tea.Cmd { break } - if err := cfg.UpdatePreferredModel(msg.ModelType, msg.Model); err != nil { + if err := m.com.Store().UpdatePreferredModel(config.ScopeGlobal, msg.ModelType, msg.Model); err != nil { cmds = append(cmds, util.ReportError(err)) } else if _, ok := cfg.Models[config.SelectedModelTypeSmall]; !ok { // Ensure small model is set is unset. smallModel := m.com.App.GetDefaultSmallModel(providerID) - if err := cfg.UpdatePreferredModel(config.SelectedModelTypeSmall, smallModel); err != nil { + if err := m.com.Store().UpdatePreferredModel(config.ScopeGlobal, config.SelectedModelTypeSmall, smallModel); err != nil { cmds = append(cmds, util.ReportError(err)) } } @@ -1404,7 +1404,7 @@ func (m *UI) handleDialogMsg(msg tea.Msg) tea.Cmd { currentModel := cfg.Models[agentCfg.Model] currentModel.ReasoningEffort = msg.Effort - if err := cfg.UpdatePreferredModel(agentCfg.Model, currentModel); err != nil { + if err := m.com.Store().UpdatePreferredModel(config.ScopeGlobal, agentCfg.Model, currentModel); err != nil { cmds = append(cmds, util.ReportError(err)) break } @@ -2016,7 +2016,7 @@ func (m *UI) View() tea.View { } v.MouseMode = tea.MouseModeCellMotion v.ReportFocus = m.caps.ReportFocusEvents - v.WindowTitle = "crush " + home.Short(m.com.Config().WorkingDir()) + v.WindowTitle = "crush " + home.Short(m.com.Store().WorkingDir()) canvas := uv.NewScreenBuffer(m.width, m.height) v.Cursor = m.Draw(canvas, canvas.Bounds()) @@ -2255,7 +2255,7 @@ func (m *UI) FullHelp() [][]key.Binding { func (m *UI) toggleCompactMode() tea.Cmd { m.forceCompactMode = !m.forceCompactMode - err := m.com.Config().SetCompactMode(m.forceCompactMode) + err := m.com.Store().SetCompactMode(config.ScopeGlobal, m.forceCompactMode) if err != nil { return util.ReportError(err) } @@ -2637,7 +2637,7 @@ func (m *UI) insertMCPResourceCompletion(item completions.ResourceCompletionValu return func() tea.Msg { contents, err := mcp.ReadResource( context.Background(), - m.com.Config(), + m.com.Store(), item.MCPName, item.URI, ) @@ -3299,7 +3299,7 @@ func (m *UI) drawSessionDetails(scr uv.Screen, area uv.Rectangle) { lspSection := m.lspInfo(sectionWidth, maxItemsPerSection, false) mcpSection := m.mcpInfo(sectionWidth, maxItemsPerSection, false) - filesSection := m.filesInfo(m.com.Config().WorkingDir(), sectionWidth, maxItemsPerSection, false) + filesSection := m.filesInfo(m.com.Store().WorkingDir(), sectionWidth, maxItemsPerSection, false) sections := lipgloss.JoinHorizontal(lipgloss.Top, filesSection, " ", lspSection, " ", mcpSection) uv.NewStyledString( s.CompactDetails.View. @@ -3317,7 +3317,7 @@ func (m *UI) drawSessionDetails(scr uv.Screen, area uv.Rectangle) { func (m *UI) runMCPPrompt(clientID, promptID string, arguments map[string]string) tea.Cmd { load := func() tea.Msg { - prompt, err := commands.GetMCPPrompt(m.com.Config(), clientID, promptID, arguments) + prompt, err := commands.GetMCPPrompt(m.com.Store(), clientID, promptID, arguments) if err != nil { // TODO: make this better return util.ReportError(err)() @@ -3358,7 +3358,7 @@ func handleMCPPromptsEvent(name string) tea.Cmd { } } -func handleMCPToolsEvent(cfg *config.Config, name string) tea.Cmd { +func handleMCPToolsEvent(cfg *config.ConfigStore, name string) tea.Cmd { return func() tea.Msg { mcp.RefreshTools( context.Background(),