Detailed changes
@@ -24,7 +24,7 @@ const (
)
func (c *coordinator) agentTool(ctx context.Context) (fantasy.AgentTool, error) {
- agentCfg, ok := c.cfg.Agents[config.AgentTask]
+ agentCfg, ok := c.cfg.Config().Agents[config.AgentTask]
if !ok {
return nil, errors.New("task agent not configured")
}
@@ -98,7 +98,7 @@ func (c *coordinator) agenticFetchTool(_ context.Context, client *http.Client) (
return fantasy.ToolResponse{}, permission.ErrorPermissionDenied
}
- tmpDir, err := os.MkdirTemp(c.cfg.Options.DataDirectory, "crush-fetch-*")
+ tmpDir, err := os.MkdirTemp(c.cfg.Config().Options.DataDirectory, "crush-fetch-*")
if err != nil {
return fantasy.NewTextErrorResponse(fmt.Sprintf("Failed to create temporary directory: %s", err)), nil
}
@@ -151,12 +151,12 @@ func (c *coordinator) agenticFetchTool(_ context.Context, client *http.Client) (
return fantasy.ToolResponse{}, fmt.Errorf("error building models: %s", err)
}
- systemPrompt, err := promptTemplate.Build(ctx, small.Model.Provider(), small.Model.Model(), *c.cfg)
+ systemPrompt, err := promptTemplate.Build(ctx, small.Model.Provider(), small.Model.Model(), c.cfg)
if err != nil {
return fantasy.ToolResponse{}, fmt.Errorf("error building system prompt: %s", err)
}
- smallProviderCfg, ok := c.cfg.Providers.Get(small.ModelCfg.Provider)
+ smallProviderCfg, ok := c.cfg.Config().Providers.Get(small.ModelCfg.Provider)
if !ok {
return fantasy.ToolResponse{}, errors.New("small model provider not configured")
}
@@ -167,7 +167,7 @@ func (c *coordinator) agenticFetchTool(_ context.Context, client *http.Client) (
webFetchTool,
webSearchTool,
tools.NewGlobTool(tmpDir),
- tools.NewGrepTool(tmpDir, c.cfg.Tools.Grep),
+ tools.NewGrepTool(tmpDir, c.cfg.Config().Tools.Grep),
tools.NewSourcegraphTool(client),
tools.NewViewTool(c.lspManager, c.permissions, c.filetracker, tmpDir),
}
@@ -177,7 +177,7 @@ func (c *coordinator) agenticFetchTool(_ context.Context, client *http.Client) (
SmallModel: small,
SystemPromptPrefix: smallProviderCfg.SystemPromptPrefix,
SystemPrompt: systemPrompt,
- DisableAutoSummarize: c.cfg.Options.DisableAutoSummarize,
+ DisableAutoSummarize: c.cfg.Config().Options.DisableAutoSummarize,
IsYolo: c.permissions.SkipRequests(),
Sessions: c.sessions,
Messages: c.messages,
@@ -185,36 +185,36 @@ func coderAgent(r *vcr.Recorder, env fakeEnv, large, small fantasy.LanguageModel
// NOTE(@andreynering): Set a fixed config to ensure cassettes match
// independently of user config on `$HOME/.config/crush/crush.json`.
- cfg.Options.Attribution = &config.Attribution{
+ cfg.Config().Options.Attribution = &config.Attribution{
TrailerStyle: "co-authored-by",
GeneratedWith: true,
}
// Clear some fields to avoid issues with VCR cassette matching.
- cfg.Options.SkillsPaths = nil
- cfg.Options.ContextPaths = nil
- cfg.LSP = nil
+ cfg.Config().Options.SkillsPaths = nil
+ cfg.Config().Options.ContextPaths = nil
+ cfg.Config().LSP = nil
- systemPrompt, err := prompt.Build(context.TODO(), large.Provider(), large.Model(), *cfg)
+ systemPrompt, err := prompt.Build(context.TODO(), large.Provider(), large.Model(), cfg)
if err != nil {
return nil, err
}
// Get the model name for the bash tool
modelName := large.Model() // fallback to ID if Name not available
- if model := cfg.GetModel(large.Provider(), large.Model()); model != nil {
+ if model := cfg.Config().GetModel(large.Provider(), large.Model()); model != nil {
modelName = model.Name
}
allTools := []fantasy.AgentTool{
- tools.NewBashTool(env.permissions, env.workingDir, cfg.Options.Attribution, modelName),
+ tools.NewBashTool(env.permissions, env.workingDir, cfg.Config().Options.Attribution, modelName),
tools.NewDownloadTool(env.permissions, env.workingDir, r.GetDefaultClient()),
tools.NewEditTool(nil, env.permissions, env.history, *env.filetracker, env.workingDir),
tools.NewMultiEditTool(nil, env.permissions, env.history, *env.filetracker, env.workingDir),
tools.NewFetchTool(env.permissions, env.workingDir, r.GetDefaultClient()),
tools.NewGlobTool(env.workingDir),
- tools.NewGrepTool(env.workingDir, cfg.Tools.Grep),
- tools.NewLsTool(env.permissions, env.workingDir, cfg.Tools.Ls),
+ tools.NewGrepTool(env.workingDir, cfg.Config().Tools.Grep),
+ tools.NewLsTool(env.permissions, env.workingDir, cfg.Config().Tools.Ls),
tools.NewSourcegraphTool(r.GetDefaultClient()),
tools.NewViewTool(nil, env.permissions, *env.filetracker, env.workingDir),
tools.NewWriteTool(nil, env.permissions, env.history, *env.filetracker, env.workingDir),
@@ -74,7 +74,7 @@ type Coordinator interface {
}
type coordinator struct {
- cfg *config.Config
+ cfg *config.ConfigStore
sessions session.Service
messages message.Service
permissions permission.Service
@@ -91,7 +91,7 @@ type coordinator struct {
func NewCoordinator(
ctx context.Context,
- cfg *config.Config,
+ cfg *config.ConfigStore,
sessions session.Service,
messages message.Service,
permissions permission.Service,
@@ -112,7 +112,7 @@ func NewCoordinator(
agents: make(map[string]SessionAgent),
}
- agentCfg, ok := cfg.Agents[config.AgentCoder]
+ agentCfg, ok := cfg.Config().Agents[config.AgentCoder]
if !ok {
return nil, errCoderAgentNotConfigured
}
@@ -160,7 +160,7 @@ func (c *coordinator) Run(ctx context.Context, sessionID string, prompt string,
attachments = filteredAttachments
}
- providerCfg, ok := c.cfg.Providers.Get(model.ModelCfg.Provider)
+ providerCfg, ok := c.cfg.Config().Providers.Get(model.ModelCfg.Provider)
if !ok {
return nil, errModelProviderNotConfigured
}
@@ -383,14 +383,14 @@ func (c *coordinator) buildAgent(ctx context.Context, prompt *prompt.Prompt, age
return nil, err
}
- largeProviderCfg, _ := c.cfg.Providers.Get(large.ModelCfg.Provider)
+ largeProviderCfg, _ := c.cfg.Config().Providers.Get(large.ModelCfg.Provider)
result := NewSessionAgent(SessionAgentOptions{
LargeModel: large,
SmallModel: small,
SystemPromptPrefix: largeProviderCfg.SystemPromptPrefix,
SystemPrompt: "",
IsSubAgent: isSubAgent,
- DisableAutoSummarize: c.cfg.Options.DisableAutoSummarize,
+ DisableAutoSummarize: c.cfg.Config().Options.DisableAutoSummarize,
IsYolo: c.permissions.SkipRequests(),
Sessions: c.sessions,
Messages: c.messages,
@@ -399,7 +399,7 @@ func (c *coordinator) buildAgent(ctx context.Context, prompt *prompt.Prompt, age
})
c.readyWg.Go(func() error {
- systemPrompt, err := prompt.Build(ctx, large.Model.Provider(), large.Model.Model(), *c.cfg)
+ systemPrompt, err := prompt.Build(ctx, large.Model.Provider(), large.Model.Model(), c.cfg)
if err != nil {
return err
}
@@ -439,14 +439,14 @@ func (c *coordinator) buildTools(ctx context.Context, agent config.Agent) ([]fan
// Get the model name for the agent
modelName := ""
- if modelCfg, ok := c.cfg.Models[agent.Model]; ok {
- if model := c.cfg.GetModel(modelCfg.Provider, modelCfg.Model); model != nil {
+ if modelCfg, ok := c.cfg.Config().Models[agent.Model]; ok {
+ if model := c.cfg.Config().GetModel(modelCfg.Provider, modelCfg.Model); model != nil {
modelName = model.Name
}
}
allTools = append(allTools,
- tools.NewBashTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Options.Attribution, modelName),
+ tools.NewBashTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Config().Options.Attribution, modelName),
tools.NewJobOutputTool(),
tools.NewJobKillTool(),
tools.NewDownloadTool(c.permissions, c.cfg.WorkingDir(), nil),
@@ -454,20 +454,20 @@ func (c *coordinator) buildTools(ctx context.Context, agent config.Agent) ([]fan
tools.NewMultiEditTool(c.lspManager, c.permissions, c.history, c.filetracker, c.cfg.WorkingDir()),
tools.NewFetchTool(c.permissions, c.cfg.WorkingDir(), nil),
tools.NewGlobTool(c.cfg.WorkingDir()),
- tools.NewGrepTool(c.cfg.WorkingDir(), c.cfg.Tools.Grep),
- tools.NewLsTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Tools.Ls),
+ tools.NewGrepTool(c.cfg.WorkingDir(), c.cfg.Config().Tools.Grep),
+ tools.NewLsTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Config().Tools.Ls),
tools.NewSourcegraphTool(nil),
tools.NewTodosTool(c.sessions),
- tools.NewViewTool(c.lspManager, c.permissions, c.filetracker, c.cfg.WorkingDir(), c.cfg.Options.SkillsPaths...),
+ tools.NewViewTool(c.lspManager, c.permissions, c.filetracker, c.cfg.WorkingDir(), c.cfg.Config().Options.SkillsPaths...),
tools.NewWriteTool(c.lspManager, c.permissions, c.history, c.filetracker, c.cfg.WorkingDir()),
)
// Add LSP tools if user has configured LSPs or auto_lsp is enabled (nil or true).
- if len(c.cfg.LSP) > 0 || c.cfg.Options.AutoLSP == nil || *c.cfg.Options.AutoLSP {
+ if len(c.cfg.Config().LSP) > 0 || c.cfg.Config().Options.AutoLSP == nil || *c.cfg.Config().Options.AutoLSP {
allTools = append(allTools, tools.NewDiagnosticsTool(c.lspManager), tools.NewReferencesTool(c.lspManager), tools.NewLSPRestartTool(c.lspManager))
}
- if len(c.cfg.MCP) > 0 {
+ if len(c.cfg.Config().MCP) > 0 {
allTools = append(
allTools,
tools.NewListMCPResourcesTool(c.cfg, c.permissions),
@@ -513,16 +513,16 @@ func (c *coordinator) buildTools(ctx context.Context, agent config.Agent) ([]fan
// TODO: when we support multiple agents we need to change this so that we pass in the agent specific model config
func (c *coordinator) buildAgentModels(ctx context.Context, isSubAgent bool) (Model, Model, error) {
- largeModelCfg, ok := c.cfg.Models[config.SelectedModelTypeLarge]
+ largeModelCfg, ok := c.cfg.Config().Models[config.SelectedModelTypeLarge]
if !ok {
return Model{}, Model{}, errLargeModelNotSelected
}
- smallModelCfg, ok := c.cfg.Models[config.SelectedModelTypeSmall]
+ smallModelCfg, ok := c.cfg.Config().Models[config.SelectedModelTypeSmall]
if !ok {
return Model{}, Model{}, errSmallModelNotSelected
}
- largeProviderCfg, ok := c.cfg.Providers.Get(largeModelCfg.Provider)
+ largeProviderCfg, ok := c.cfg.Config().Providers.Get(largeModelCfg.Provider)
if !ok {
return Model{}, Model{}, errLargeModelProviderNotConfigured
}
@@ -532,7 +532,7 @@ func (c *coordinator) buildAgentModels(ctx context.Context, isSubAgent bool) (Mo
return Model{}, Model{}, err
}
- smallProviderCfg, ok := c.cfg.Providers.Get(smallModelCfg.Provider)
+ smallProviderCfg, ok := c.cfg.Config().Providers.Get(smallModelCfg.Provider)
if !ok {
return Model{}, Model{}, errSmallModelProviderNotConfigured
}
@@ -620,7 +620,7 @@ func (c *coordinator) buildAnthropicProvider(baseURL, apiKey string, headers map
opts = append(opts, anthropic.WithBaseURL(baseURL))
}
- if c.cfg.Options.Debug {
+ if c.cfg.Config().Options.Debug {
httpClient := log.NewHTTPClient()
opts = append(opts, anthropic.WithHTTPClient(httpClient))
}
@@ -632,7 +632,7 @@ func (c *coordinator) buildOpenaiProvider(baseURL, apiKey string, headers map[st
openai.WithAPIKey(apiKey),
openai.WithUseResponsesAPI(),
}
- if c.cfg.Options.Debug {
+ if c.cfg.Config().Options.Debug {
httpClient := log.NewHTTPClient()
opts = append(opts, openai.WithHTTPClient(httpClient))
}
@@ -649,7 +649,7 @@ func (c *coordinator) buildOpenrouterProvider(_, apiKey string, headers map[stri
opts := []openrouter.Option{
openrouter.WithAPIKey(apiKey),
}
- if c.cfg.Options.Debug {
+ if c.cfg.Config().Options.Debug {
httpClient := log.NewHTTPClient()
opts = append(opts, openrouter.WithHTTPClient(httpClient))
}
@@ -663,7 +663,7 @@ func (c *coordinator) buildVercelProvider(_, apiKey string, headers map[string]s
opts := []vercel.Option{
vercel.WithAPIKey(apiKey),
}
- if c.cfg.Options.Debug {
+ if c.cfg.Config().Options.Debug {
httpClient := log.NewHTTPClient()
opts = append(opts, vercel.WithHTTPClient(httpClient))
}
@@ -683,8 +683,8 @@ func (c *coordinator) buildOpenaiCompatProvider(baseURL, apiKey string, headers
var httpClient *http.Client
if providerID == string(catwalk.InferenceProviderCopilot) {
opts = append(opts, openaicompat.WithUseResponsesAPI())
- httpClient = copilot.NewClient(isSubAgent, c.cfg.Options.Debug)
- } else if c.cfg.Options.Debug {
+ httpClient = copilot.NewClient(isSubAgent, c.cfg.Config().Options.Debug)
+ } else if c.cfg.Config().Options.Debug {
httpClient = log.NewHTTPClient()
}
if httpClient != nil {
@@ -708,7 +708,7 @@ func (c *coordinator) buildAzureProvider(baseURL, apiKey string, headers map[str
azure.WithAPIKey(apiKey),
azure.WithUseResponsesAPI(),
}
- if c.cfg.Options.Debug {
+ if c.cfg.Config().Options.Debug {
httpClient := log.NewHTTPClient()
opts = append(opts, azure.WithHTTPClient(httpClient))
}
@@ -727,7 +727,7 @@ func (c *coordinator) buildAzureProvider(baseURL, apiKey string, headers map[str
func (c *coordinator) buildBedrockProvider(headers map[string]string) (fantasy.Provider, error) {
var opts []bedrock.Option
- if c.cfg.Options.Debug {
+ if c.cfg.Config().Options.Debug {
httpClient := log.NewHTTPClient()
opts = append(opts, bedrock.WithHTTPClient(httpClient))
}
@@ -746,7 +746,7 @@ func (c *coordinator) buildGoogleProvider(baseURL, apiKey string, headers map[st
google.WithBaseURL(baseURL),
google.WithGeminiAPIKey(apiKey),
}
- if c.cfg.Options.Debug {
+ if c.cfg.Config().Options.Debug {
httpClient := log.NewHTTPClient()
opts = append(opts, google.WithHTTPClient(httpClient))
}
@@ -758,7 +758,7 @@ func (c *coordinator) buildGoogleProvider(baseURL, apiKey string, headers map[st
func (c *coordinator) buildGoogleVertexProvider(headers map[string]string, options map[string]string) (fantasy.Provider, error) {
opts := []google.Option{}
- if c.cfg.Options.Debug {
+ if c.cfg.Config().Options.Debug {
httpClient := log.NewHTTPClient()
opts = append(opts, google.WithHTTPClient(httpClient))
}
@@ -779,7 +779,7 @@ func (c *coordinator) buildHyperProvider(baseURL, apiKey string) (fantasy.Provid
hyper.WithBaseURL(baseURL),
hyper.WithAPIKey(apiKey),
}
- if c.cfg.Options.Debug {
+ if c.cfg.Config().Options.Debug {
httpClient := log.NewHTTPClient()
opts = append(opts, hyper.WithHTTPClient(httpClient))
}
@@ -887,7 +887,7 @@ func (c *coordinator) UpdateModels(ctx context.Context) error {
}
c.currentAgent.SetModels(large, small)
- agentCfg, ok := c.cfg.Agents[config.AgentCoder]
+ agentCfg, ok := c.cfg.Config().Agents[config.AgentCoder]
if !ok {
return errCoderAgentNotConfigured
}
@@ -909,7 +909,7 @@ func (c *coordinator) QueuedPromptsList(sessionID string) []string {
}
func (c *coordinator) Summarize(ctx context.Context, sessionID string) error {
- providerCfg, ok := c.cfg.Providers.Get(c.currentAgent.Model().ModelCfg.Provider)
+ providerCfg, ok := c.cfg.Config().Providers.Get(c.currentAgent.Model().ModelCfg.Provider)
if !ok {
return errModelProviderNotConfigured
}
@@ -922,7 +922,7 @@ func (c *coordinator) isUnauthorized(err error) bool {
}
func (c *coordinator) refreshOAuth2Token(ctx context.Context, providerCfg config.ProviderConfig) error {
- if err := c.cfg.RefreshOAuthToken(ctx, providerCfg.ID); err != nil {
+ if err := c.cfg.RefreshOAuthToken(ctx, config.ScopeGlobal, providerCfg.ID); err != nil {
slog.Error("Failed to refresh OAuth token after 401 error", "provider", providerCfg.ID, "error", err)
return err
}
@@ -940,7 +940,7 @@ func (c *coordinator) refreshApiKeyTemplate(ctx context.Context, providerCfg con
}
providerCfg.APIKey = newAPIKey
- c.cfg.Providers.Set(providerCfg.ID, providerCfg)
+ c.cfg.Config().Providers.Set(providerCfg.ID, providerCfg)
if err := c.UpdateModels(ctx); err != nil {
return err
@@ -984,7 +984,7 @@ func (c *coordinator) runSubAgent(ctx context.Context, params subAgentParams) (f
maxTokens = model.ModelCfg.MaxTokens
}
- providerCfg, ok := c.cfg.Providers.Get(model.ModelCfg.Provider)
+ providerCfg, ok := c.cfg.Config().Providers.Get(model.ModelCfg.Provider)
if !ok {
return fantasy.ToolResponse{}, errModelProviderNotConfigured
}
@@ -44,7 +44,7 @@ func (m *mockSessionAgent) Summarize(context.Context, string, fantasy.ProviderOp
func newTestCoordinator(t *testing.T, env fakeEnv, providerID string, providerCfg config.ProviderConfig) *coordinator {
cfg, err := config.Init(env.workingDir, "", false)
require.NoError(t, err)
- cfg.Providers.Set(providerID, providerCfg)
+ cfg.Config().Providers.Set(providerID, providerCfg)
return &coordinator{
cfg: cfg,
sessions: env.sessions,
@@ -76,13 +76,13 @@ func NewPrompt(name, promptTemplate string, opts ...Option) (*Prompt, error) {
return p, nil
}
-func (p *Prompt) Build(ctx context.Context, provider, model string, cfg config.Config) (string, error) {
+func (p *Prompt) Build(ctx context.Context, provider, model string, store *config.ConfigStore) (string, error) {
t, err := template.New(p.name).Parse(p.template)
if err != nil {
return "", fmt.Errorf("parsing template: %w", err)
}
var sb strings.Builder
- d, err := p.promptData(ctx, provider, model, cfg)
+ d, err := p.promptData(ctx, provider, model, store)
if err != nil {
return "", err
}
@@ -104,11 +104,11 @@ func processFile(filePath string) *ContextFile {
}
}
-func processContextPath(p string, cfg config.Config) []ContextFile {
+func processContextPath(p string, store *config.ConfigStore) []ContextFile {
var contexts []ContextFile
fullPath := p
if !filepath.IsAbs(p) {
- fullPath = filepath.Join(cfg.WorkingDir(), p)
+ fullPath = filepath.Join(store.WorkingDir(), p)
}
info, err := os.Stat(fullPath)
if err != nil {
@@ -136,11 +136,11 @@ func processContextPath(p string, cfg config.Config) []ContextFile {
}
// expandPath expands ~ and environment variables in file paths
-func expandPath(path string, cfg config.Config) string {
+func expandPath(path string, store *config.ConfigStore) string {
path = home.Long(path)
// Handle environment variable expansion using the same pattern as config
if strings.HasPrefix(path, "$") {
- if expanded, err := cfg.Resolver().ResolveValue(path); err == nil {
+ if expanded, err := store.Resolver().ResolveValue(path); err == nil {
path = expanded
}
}
@@ -148,19 +148,20 @@ func expandPath(path string, cfg config.Config) string {
return path
}
-func (p *Prompt) promptData(ctx context.Context, provider, model string, cfg config.Config) (PromptDat, error) {
- workingDir := cmp.Or(p.workingDir, cfg.WorkingDir())
+func (p *Prompt) promptData(ctx context.Context, provider, model string, store *config.ConfigStore) (PromptDat, error) {
+ workingDir := cmp.Or(p.workingDir, store.WorkingDir())
platform := cmp.Or(p.platform, runtime.GOOS)
files := map[string][]ContextFile{}
+ cfg := store.Config()
for _, pth := range cfg.Options.ContextPaths {
- expanded := expandPath(pth, cfg)
+ expanded := expandPath(pth, store)
pathKey := strings.ToLower(expanded)
if _, ok := files[pathKey]; ok {
continue
}
- content := processContextPath(expanded, cfg)
+ content := processContextPath(expanded, store)
files[pathKey] = content
}
@@ -169,18 +170,18 @@ func (p *Prompt) promptData(ctx context.Context, provider, model string, cfg con
if len(cfg.Options.SkillsPaths) > 0 {
expandedPaths := make([]string, 0, len(cfg.Options.SkillsPaths))
for _, pth := range cfg.Options.SkillsPaths {
- expandedPaths = append(expandedPaths, expandPath(pth, cfg))
+ expandedPaths = append(expandedPaths, expandPath(pth, store))
}
if discoveredSkills := skills.Discover(expandedPaths); len(discoveredSkills) > 0 {
availSkillXML = skills.ToPromptXML(discoveredSkills)
}
}
- isGit := isGitRepo(cfg.WorkingDir())
+ isGit := isGitRepo(store.WorkingDir())
data := PromptDat{
Provider: provider,
Model: model,
- Config: cfg,
+ Config: *cfg,
WorkingDir: filepath.ToSlash(workingDir),
IsGitRepo: isGit,
Platform: platform,
@@ -189,7 +190,7 @@ func (p *Prompt) promptData(ctx context.Context, provider, model string, cfg con
}
if isGit {
var err error
- data.GitStatus, err = getGitStatus(ctx, cfg.WorkingDir())
+ data.GitStatus, err = getGitStatus(ctx, store.WorkingDir())
if err != nil {
return PromptDat{}, err
}
@@ -33,7 +33,7 @@ func taskPrompt(opts ...prompt.Option) (*prompt.Prompt, error) {
return systemPrompt, nil
}
-func InitializePrompt(cfg config.Config) (string, error) {
+func InitializePrompt(cfg *config.ConfigStore) (string, error) {
systemPrompt, err := prompt.NewPrompt("initialize", string(initializePromptTmpl))
if err != nil {
return "", err
@@ -28,7 +28,7 @@ const ListMCPResourcesToolName = "list_mcp_resources"
//go:embed list_mcp_resources.md
var listMCPResourcesDescription []byte
-func NewListMCPResourcesTool(cfg *config.Config, permissions permission.Service) fantasy.AgentTool {
+func NewListMCPResourcesTool(cfg *config.ConfigStore, permissions permission.Service) fantasy.AgentTool {
return fantasy.NewParallelAgentTool(
ListMCPResourcesToolName,
string(listMCPResourcesDescription),
@@ -11,7 +11,7 @@ import (
)
// GetMCPTools gets all the currently available MCP tools.
-func GetMCPTools(permissions permission.Service, cfg *config.Config, wd string) []*Tool {
+func GetMCPTools(permissions permission.Service, cfg *config.ConfigStore, wd string) []*Tool {
var result []*Tool
for mcpName, tools := range mcp.Tools() {
for _, tool := range tools {
@@ -31,7 +31,7 @@ func GetMCPTools(permissions permission.Service, cfg *config.Config, wd string)
type Tool struct {
mcpName string
tool *mcp.Tool
- cfg *config.Config
+ cfg *config.ConfigStore
permissions permission.Service
workingDir string
providerOptions fantasy.ProviderOptions
@@ -163,11 +163,11 @@ func Close(ctx context.Context) error {
}
// Initialize initializes MCP clients based on the provided configuration.
-func Initialize(ctx context.Context, permissions permission.Service, cfg *config.Config) {
+func Initialize(ctx context.Context, permissions permission.Service, cfg *config.ConfigStore) {
slog.Info("Initializing MCP clients")
var wg sync.WaitGroup
// Initialize states for all configured MCPs
- for name, m := range cfg.MCP {
+ for name, m := range cfg.Config().MCP {
if m.Disabled {
updateState(name, StateDisabled, nil, nil, Counts{})
slog.Debug("Skipping disabled MCP", "name", name)
@@ -253,13 +253,13 @@ func WaitForInit(ctx context.Context) error {
}
}
-func getOrRenewClient(ctx context.Context, cfg *config.Config, name string) (*ClientSession, error) {
+func getOrRenewClient(ctx context.Context, cfg *config.ConfigStore, name string) (*ClientSession, error) {
sess, ok := sessions.Get(name)
if !ok {
return nil, fmt.Errorf("mcp '%s' not available", name)
}
- m := cfg.MCP[name]
+ m := cfg.Config().MCP[name]
state, _ := states.Get(name)
timeout := mcpTimeout(m)
@@ -20,7 +20,7 @@ func Prompts() iter.Seq2[string, []*Prompt] {
}
// GetPromptMessages retrieves the content of an MCP prompt with the given arguments.
-func GetPromptMessages(ctx context.Context, cfg *config.Config, clientName, promptName string, args map[string]string) ([]string, error) {
+func GetPromptMessages(ctx context.Context, cfg *config.ConfigStore, clientName, promptName string, args map[string]string) ([]string, error) {
c, err := getOrRenewClient(ctx, cfg, clientName)
if err != nil {
return nil, err
@@ -24,7 +24,7 @@ func Resources() iter.Seq2[string, []*Resource] {
}
// ListResources returns the current resources for an MCP server.
-func ListResources(ctx context.Context, cfg *config.Config, name string) ([]*Resource, error) {
+func ListResources(ctx context.Context, cfg *config.ConfigStore, name string) ([]*Resource, error) {
session, err := getOrRenewClient(ctx, cfg, name)
if err != nil {
return nil, err
@@ -43,7 +43,7 @@ func ListResources(ctx context.Context, cfg *config.Config, name string) ([]*Res
}
// ReadResource reads the contents of a resource from an MCP server.
-func ReadResource(ctx context.Context, cfg *config.Config, name, uri string) ([]*ResourceContents, error) {
+func ReadResource(ctx context.Context, cfg *config.ConfigStore, name, uri string) ([]*ResourceContents, error) {
session, err := getOrRenewClient(ctx, cfg, name)
if err != nil {
return nil, err
@@ -32,7 +32,7 @@ func Tools() iter.Seq2[string, []*Tool] {
}
// RunTool runs an MCP tool with the given input parameters.
-func RunTool(ctx context.Context, cfg *config.Config, name, toolName string, input string) (ToolResult, error) {
+func RunTool(ctx context.Context, cfg *config.ConfigStore, name, toolName string, input string) (ToolResult, error) {
var args map[string]any
if err := json.Unmarshal([]byte(input), &args); err != nil {
return ToolResult{}, fmt.Errorf("error parsing parameters: %s", err)
@@ -108,7 +108,7 @@ func RunTool(ctx context.Context, cfg *config.Config, name, toolName string, inp
// RefreshTools gets the updated list of tools from the MCP and updates the
// global state.
-func RefreshTools(ctx context.Context, cfg *config.Config, name string) {
+func RefreshTools(ctx context.Context, cfg *config.ConfigStore, name string) {
session, ok := sessions.Get(name)
if !ok {
slog.Warn("Refresh tools: no session", "name", name)
@@ -139,7 +139,7 @@ func getTools(ctx context.Context, session *ClientSession) ([]*Tool, error) {
return result.Tools, nil
}
-func updateTools(cfg *config.Config, name string, tools []*Tool) int {
+func updateTools(cfg *config.ConfigStore, name string, tools []*Tool) int {
tools = filterDisabledTools(cfg, name, tools)
if len(tools) == 0 {
allTools.Del(name)
@@ -150,8 +150,8 @@ func updateTools(cfg *config.Config, name string, tools []*Tool) int {
}
// filterDisabledTools removes tools that are disabled via config.
-func filterDisabledTools(cfg *config.Config, mcpName string, tools []*Tool) []*Tool {
- mcpCfg, ok := cfg.MCP[mcpName]
+func filterDisabledTools(cfg *config.ConfigStore, mcpName string, tools []*Tool) []*Tool {
+ mcpCfg, ok := cfg.Config().MCP[mcpName]
if !ok || len(mcpCfg.DisabledTools) == 0 {
return tools
}
@@ -30,7 +30,7 @@ const ReadMCPResourceToolName = "read_mcp_resource"
//go:embed read_mcp_resource.md
var readMCPResourceDescription []byte
-func NewReadMCPResourceTool(cfg *config.Config, permissions permission.Service) fantasy.AgentTool {
+func NewReadMCPResourceTool(cfg *config.ConfigStore, permissions permission.Service) fantasy.AgentTool {
return fantasy.NewParallelAgentTool(
ReadMCPResourceToolName,
string(readMCPResourceDescription),
@@ -61,7 +61,7 @@ type App struct {
LSPManager *lsp.Manager
- config *config.Config
+ config *config.ConfigStore
serviceEventsWG *sync.WaitGroup
eventsCtx context.Context
@@ -75,11 +75,12 @@ type App struct {
}
// New initializes a new application instance.
-func New(ctx context.Context, conn *sql.DB, cfg *config.Config) (*App, error) {
+func New(ctx context.Context, conn *sql.DB, store *config.ConfigStore) (*App, error) {
q := db.New(conn)
sessions := session.NewService(q, conn)
messages := message.NewService(q)
files := history.NewService(q, conn)
+ cfg := store.Config()
skipPermissionsRequests := cfg.Permissions != nil && cfg.Permissions.SkipRequests
var allowedTools []string
if cfg.Permissions != nil && cfg.Permissions.AllowedTools != nil {
@@ -90,13 +91,13 @@ func New(ctx context.Context, conn *sql.DB, cfg *config.Config) (*App, error) {
Sessions: sessions,
Messages: messages,
History: files,
- Permissions: permission.NewPermissionService(cfg.WorkingDir(), skipPermissionsRequests, allowedTools),
+ Permissions: permission.NewPermissionService(store.WorkingDir(), skipPermissionsRequests, allowedTools),
FileTracker: filetracker.NewService(q),
- LSPManager: lsp.NewManager(cfg),
+ LSPManager: lsp.NewManager(store),
globalCtx: ctx,
- config: cfg,
+ config: store,
events: make(chan tea.Msg, 100),
serviceEventsWG: &sync.WaitGroup{},
@@ -109,7 +110,7 @@ func New(ctx context.Context, conn *sql.DB, cfg *config.Config) (*App, error) {
// Check for updates in the background.
go app.checkForUpdates(ctx)
- go mcp.Initialize(ctx, app.Permissions, cfg)
+ go mcp.Initialize(ctx, app.Permissions, store)
// cleanup database upon app shutdown
app.cleanupFuncs = append(
@@ -141,8 +142,13 @@ func New(ctx context.Context, conn *sql.DB, cfg *config.Config) (*App, error) {
return app, nil
}
-// Config returns the application configuration.
+// Config returns the pure-data configuration.
func (app *App) Config() *config.Config {
+ return app.config.Config()
+}
+
+// Store returns the config store.
+func (app *App) Store() *config.ConfigStore {
return app.config
}
@@ -178,7 +184,7 @@ func (app *App) RunNonInteractive(ctx context.Context, output io.Writer, prompt,
}
stderrTTY = term.IsTerminal(os.Stderr.Fd())
stdinTTY = term.IsTerminal(os.Stdin.Fd())
- progress = app.config.Options.Progress == nil || *app.config.Options.Progress
+ progress = app.config.Config().Options.Progress == nil || *app.config.Config().Options.Progress
if !hideSpinner && stderrTTY {
t := styles.DefaultStyles()
@@ -331,7 +337,7 @@ func (app *App) UpdateAgentModel(ctx context.Context) error {
// If largeModel is provided but smallModel is not, the small model defaults to
// the provider's default small model.
func (app *App) overrideModelsForNonInteractive(ctx context.Context, largeModel, smallModel string) error {
- providers := app.config.Providers.Copy()
+ providers := app.config.Config().Providers.Copy()
largeMatches, smallMatches, err := findModels(providers, largeModel, smallModel)
if err != nil {
@@ -348,7 +354,7 @@ func (app *App) overrideModelsForNonInteractive(ctx context.Context, largeModel,
}
largeProviderID = found.provider
slog.Info("Overriding large model for non-interactive run", "provider", found.provider, "model", found.modelID)
- app.config.Models[config.SelectedModelTypeLarge] = config.SelectedModel{
+ app.config.Config().Models[config.SelectedModelTypeLarge] = config.SelectedModel{
Provider: found.provider,
Model: found.modelID,
}
@@ -362,7 +368,7 @@ func (app *App) overrideModelsForNonInteractive(ctx context.Context, largeModel,
return err
}
slog.Info("Overriding small model for non-interactive run", "provider", found.provider, "model", found.modelID)
- app.config.Models[config.SelectedModelTypeSmall] = config.SelectedModel{
+ app.config.Config().Models[config.SelectedModelTypeSmall] = config.SelectedModel{
Provider: found.provider,
Model: found.modelID,
}
@@ -370,7 +376,7 @@ func (app *App) overrideModelsForNonInteractive(ctx context.Context, largeModel,
case largeModel != "":
// No small model specified, but large model was - use provider's default.
smallCfg := app.GetDefaultSmallModel(largeProviderID)
- app.config.Models[config.SelectedModelTypeSmall] = smallCfg
+ app.config.Config().Models[config.SelectedModelTypeSmall] = smallCfg
}
return app.AgentCoordinator.UpdateModels(ctx)
@@ -379,7 +385,7 @@ func (app *App) overrideModelsForNonInteractive(ctx context.Context, largeModel,
// GetDefaultSmallModel returns the default small model for the given
// provider. Falls back to the large model if no default is found.
func (app *App) GetDefaultSmallModel(providerID string) config.SelectedModel {
- cfg := app.config
+ cfg := app.config.Config()
largeModelCfg := cfg.Models[config.SelectedModelTypeLarge]
// Find the provider in the known providers list to get its default small model.
@@ -481,7 +487,7 @@ func setupSubscriber[T any](
}
func (app *App) InitCoderAgent(ctx context.Context) error {
- coderAgentCfg := app.config.Agents[config.AgentCoder]
+ coderAgentCfg := app.config.Config().Agents[config.AgentCoder]
if coderAgentCfg.ID == "" {
return fmt.Errorf("coder agent configuration is missing")
}
@@ -52,16 +52,16 @@ crush login copilot
}
switch provider {
case "hyper":
- return loginHyper(app.Config())
+ return loginHyper(app.Store())
case "copilot", "github", "github-copilot":
- return loginCopilot(app.Config())
+ return loginCopilot(app.Store())
default:
return fmt.Errorf("unknown platform: %s", args[0])
}
},
}
-func loginHyper(cfg *config.Config) error {
+func loginHyper(cfg *config.ConfigStore) error {
if !hyperp.Enabled() {
return fmt.Errorf("hyper not enabled")
}
@@ -112,8 +112,8 @@ func loginHyper(cfg *config.Config) error {
}
if err := cmp.Or(
- cfg.SetConfigField("providers.hyper.api_key", token.AccessToken),
- cfg.SetConfigField("providers.hyper.oauth", token),
+ cfg.SetConfigField(config.ScopeGlobal, "providers.hyper.api_key", token.AccessToken),
+ cfg.SetConfigField(config.ScopeGlobal, "providers.hyper.oauth", token),
); err != nil {
return err
}
@@ -123,10 +123,10 @@ func loginHyper(cfg *config.Config) error {
return nil
}
-func loginCopilot(cfg *config.Config) error {
+func loginCopilot(cfg *config.ConfigStore) error {
ctx := getLoginContext()
- if cfg.HasConfigField("providers.copilot.oauth") {
+ if cfg.HasConfigField(config.ScopeGlobal, "providers.copilot.oauth") {
fmt.Println("You are already logged in to GitHub Copilot.")
return nil
}
@@ -177,8 +177,8 @@ func loginCopilot(cfg *config.Config) error {
}
if err := cmp.Or(
- cfg.SetConfigField("providers.copilot.api_key", token.AccessToken),
- cfg.SetConfigField("providers.copilot.oauth", token),
+ cfg.SetConfigField(config.ScopeGlobal, "providers.copilot.api_key", token.AccessToken),
+ cfg.SetConfigField(config.ScopeGlobal, "providers.copilot.oauth", token),
); err != nil {
return err
}
@@ -55,7 +55,7 @@ var logsCmd = &cobra.Command{
if err != nil {
return fmt.Errorf("failed to load configuration: %v", err)
}
- logsFile := filepath.Join(cfg.Options.DataDirectory, "logs", "crush.log")
+ logsFile := filepath.Join(cfg.Config().Options.DataDirectory, "logs", "crush.log")
_, err = os.Stat(logsFile)
if os.IsNotExist(err) {
log.Warn("Looks like you are not in a crush project. No logs found.")
@@ -38,7 +38,7 @@ crush models gpt5`,
return err
}
- if !cfg.IsConfigured() {
+ if !cfg.Config().IsConfigured() {
return fmt.Errorf("no providers configured - please run 'crush' to set up a provider interactively")
}
@@ -55,7 +55,7 @@ crush models gpt5`,
var providerIDs []string
providerModels := make(map[string][]string)
- for providerID, provider := range cfg.Providers.Seq2() {
+ for providerID, provider := range cfg.Config().Providers.Seq2() {
if provider.Disable {
continue
}
@@ -189,11 +189,12 @@ func setupApp(cmd *cobra.Command) (*app.App, error) {
return nil, err
}
- cfg, err := config.Init(cwd, dataDir, debug)
+ store, err := config.Init(cwd, dataDir, debug)
if err != nil {
return nil, err
}
+ cfg := store.Config()
if cfg.Permissions == nil {
cfg.Permissions = &config.Permissions{}
}
@@ -215,7 +216,7 @@ func setupApp(cmd *cobra.Command) (*app.App, error) {
return nil, err
}
- appInstance, err := app.New(ctx, conn, cfg)
+ appInstance, err := app.New(ctx, conn, store)
if err != nil {
slog.Error("Failed to create app instance", "error", err)
return nil, err
@@ -131,7 +131,7 @@ func runStats(cmd *cobra.Command, _ []string) error {
if err != nil {
return fmt.Errorf("failed to initialize config: %w", err)
}
- dataDir = cfg.Options.DataDirectory
+ dataDir = cfg.Config().Options.DataDirectory
}
conn, err := db.Connect(ctx, dataDir)
@@ -227,7 +227,7 @@ func isMarkdownFile(name string) bool {
return strings.HasSuffix(strings.ToLower(name), ".md")
}
-func GetMCPPrompt(cfg *config.Config, clientID, promptID string, args map[string]string) (string, error) {
+func GetMCPPrompt(cfg *config.ConfigStore, clientID, promptID string, args map[string]string) (string, error) {
// TODO: we should pass the context down
result, err := mcp.GetPromptMessages(context.Background(), cfg, clientID, promptID, args)
if err != nil {
@@ -8,22 +8,16 @@ import (
"maps"
"net/http"
"net/url"
- "os"
- "path/filepath"
"slices"
"strings"
"time"
"charm.land/catwalk/pkg/catwalk"
- hyperp "github.com/charmbracelet/crush/internal/agent/hyper"
"github.com/charmbracelet/crush/internal/csync"
"github.com/charmbracelet/crush/internal/env"
"github.com/charmbracelet/crush/internal/oauth"
"github.com/charmbracelet/crush/internal/oauth/copilot"
- "github.com/charmbracelet/crush/internal/oauth/hyper"
"github.com/invopop/jsonschema"
- "github.com/tidwall/gjson"
- "github.com/tidwall/sjson"
)
const (
@@ -398,17 +392,6 @@ type Config struct {
Tools Tools `json:"tools,omitzero" jsonschema:"description=Tool configurations"`
Agents map[string]Agent `json:"-"`
-
- // Internal
- workingDir string `json:"-"`
- // TODO: find a better way to do this this should probably not be part of the config
- resolver VariableResolver
- dataConfigDir string `json:"-"`
- knownProviders []catwalk.Provider `json:"-"`
-}
-
-func (c *Config) WorkingDir() string {
- return c.workingDir
}
func (c *Config) EnabledProviders() []ProviderConfig {
@@ -472,235 +455,8 @@ func (c *Config) SmallModel() *catwalk.Model {
return c.GetModel(model.Provider, model.Model)
}
-func (c *Config) SetCompactMode(enabled bool) error {
- if c.Options == nil {
- c.Options = &Options{}
- }
- c.Options.TUI.CompactMode = enabled
- return c.SetConfigField("options.tui.compact_mode", enabled)
-}
-
-func (c *Config) Resolve(key string) (string, error) {
- if c.resolver == nil {
- return "", fmt.Errorf("no variable resolver configured")
- }
- return c.resolver.ResolveValue(key)
-}
-
-func (c *Config) UpdatePreferredModel(modelType SelectedModelType, model SelectedModel) error {
- c.Models[modelType] = model
- if err := c.SetConfigField(fmt.Sprintf("models.%s", modelType), model); err != nil {
- return fmt.Errorf("failed to update preferred model: %w", err)
- }
- if err := c.recordRecentModel(modelType, model); err != nil {
- return err
- }
- return nil
-}
-
-func (c *Config) HasConfigField(key string) bool {
- data, err := os.ReadFile(c.dataConfigDir)
- if err != nil {
- return false
- }
- return gjson.Get(string(data), key).Exists()
-}
-
-func (c *Config) SetConfigField(key string, value any) error {
- data, err := os.ReadFile(c.dataConfigDir)
- if err != nil {
- if os.IsNotExist(err) {
- data = []byte("{}")
- } else {
- return fmt.Errorf("failed to read config file: %w", err)
- }
- }
-
- newValue, err := sjson.Set(string(data), key, value)
- if err != nil {
- return fmt.Errorf("failed to set config field %s: %w", key, err)
- }
- if err := os.MkdirAll(filepath.Dir(c.dataConfigDir), 0o755); err != nil {
- return fmt.Errorf("failed to create config directory %q: %w", c.dataConfigDir, err)
- }
- if err := os.WriteFile(c.dataConfigDir, []byte(newValue), 0o600); err != nil {
- return fmt.Errorf("failed to write config file: %w", err)
- }
- return nil
-}
-
-func (c *Config) RemoveConfigField(key string) error {
- data, err := os.ReadFile(c.dataConfigDir)
- if err != nil {
- return fmt.Errorf("failed to read config file: %w", err)
- }
-
- newValue, err := sjson.Delete(string(data), key)
- if err != nil {
- return fmt.Errorf("failed to delete config field %s: %w", key, err)
- }
- if err := os.MkdirAll(filepath.Dir(c.dataConfigDir), 0o755); err != nil {
- return fmt.Errorf("failed to create config directory %q: %w", c.dataConfigDir, err)
- }
- if err := os.WriteFile(c.dataConfigDir, []byte(newValue), 0o600); err != nil {
- return fmt.Errorf("failed to write config file: %w", err)
- }
- return nil
-}
-
-// RefreshOAuthToken refreshes the OAuth token for the given provider.
-func (c *Config) RefreshOAuthToken(ctx context.Context, providerID string) error {
- providerConfig, exists := c.Providers.Get(providerID)
- if !exists {
- return fmt.Errorf("provider %s not found", providerID)
- }
-
- if providerConfig.OAuthToken == nil {
- return fmt.Errorf("provider %s does not have an OAuth token", providerID)
- }
-
- var newToken *oauth.Token
- var refreshErr error
- switch providerID {
- case string(catwalk.InferenceProviderCopilot):
- newToken, refreshErr = copilot.RefreshToken(ctx, providerConfig.OAuthToken.RefreshToken)
- case hyperp.Name:
- newToken, refreshErr = hyper.ExchangeToken(ctx, providerConfig.OAuthToken.RefreshToken)
- default:
- return fmt.Errorf("OAuth refresh not supported for provider %s", providerID)
- }
- if refreshErr != nil {
- return fmt.Errorf("failed to refresh OAuth token for provider %s: %w", providerID, refreshErr)
- }
-
- slog.Info("Successfully refreshed OAuth token", "provider", providerID)
- providerConfig.OAuthToken = newToken
- providerConfig.APIKey = newToken.AccessToken
-
- switch providerID {
- case string(catwalk.InferenceProviderCopilot):
- providerConfig.SetupGitHubCopilot()
- }
-
- c.Providers.Set(providerID, providerConfig)
-
- if err := cmp.Or(
- c.SetConfigField(fmt.Sprintf("providers.%s.api_key", providerID), newToken.AccessToken),
- c.SetConfigField(fmt.Sprintf("providers.%s.oauth", providerID), newToken),
- ); err != nil {
- return fmt.Errorf("failed to persist refreshed token: %w", err)
- }
-
- return nil
-}
-
-func (c *Config) SetProviderAPIKey(providerID string, apiKey any) error {
- var providerConfig ProviderConfig
- var exists bool
- var setKeyOrToken func()
-
- switch v := apiKey.(type) {
- case string:
- if err := c.SetConfigField(fmt.Sprintf("providers.%s.api_key", providerID), v); err != nil {
- return fmt.Errorf("failed to save api key to config file: %w", err)
- }
- setKeyOrToken = func() { providerConfig.APIKey = v }
- case *oauth.Token:
- if err := cmp.Or(
- c.SetConfigField(fmt.Sprintf("providers.%s.api_key", providerID), v.AccessToken),
- c.SetConfigField(fmt.Sprintf("providers.%s.oauth", providerID), v),
- ); err != nil {
- return err
- }
- setKeyOrToken = func() {
- providerConfig.APIKey = v.AccessToken
- providerConfig.OAuthToken = v
- switch providerID {
- case string(catwalk.InferenceProviderCopilot):
- providerConfig.SetupGitHubCopilot()
- }
- }
- }
-
- providerConfig, exists = c.Providers.Get(providerID)
- if exists {
- setKeyOrToken()
- c.Providers.Set(providerID, providerConfig)
- return nil
- }
-
- var foundProvider *catwalk.Provider
- for _, p := range c.knownProviders {
- if string(p.ID) == providerID {
- foundProvider = &p
- break
- }
- }
-
- if foundProvider != nil {
- // Create new provider config based on known provider
- providerConfig = ProviderConfig{
- ID: providerID,
- Name: foundProvider.Name,
- BaseURL: foundProvider.APIEndpoint,
- Type: foundProvider.Type,
- Disable: false,
- ExtraHeaders: make(map[string]string),
- ExtraParams: make(map[string]string),
- Models: foundProvider.Models,
- }
- setKeyOrToken()
- } else {
- return fmt.Errorf("provider with ID %s not found in known providers", providerID)
- }
- // Store the updated provider config
- c.Providers.Set(providerID, providerConfig)
- return nil
-}
-
const maxRecentModelsPerType = 5
-func (c *Config) recordRecentModel(modelType SelectedModelType, model SelectedModel) error {
- if model.Provider == "" || model.Model == "" {
- return nil
- }
-
- if c.RecentModels == nil {
- c.RecentModels = make(map[SelectedModelType][]SelectedModel)
- }
-
- eq := func(a, b SelectedModel) bool {
- return a.Provider == b.Provider && a.Model == b.Model
- }
-
- entry := SelectedModel{
- Provider: model.Provider,
- Model: model.Model,
- }
-
- current := c.RecentModels[modelType]
- withoutCurrent := slices.DeleteFunc(slices.Clone(current), func(existing SelectedModel) bool {
- return eq(existing, entry)
- })
-
- updated := append([]SelectedModel{entry}, withoutCurrent...)
- if len(updated) > maxRecentModelsPerType {
- updated = updated[:maxRecentModelsPerType]
- }
-
- if slices.EqualFunc(current, updated, eq) {
- return nil
- }
-
- c.RecentModels[modelType] = updated
-
- if err := c.SetConfigField(fmt.Sprintf("recent_models.%s", modelType), updated); err != nil {
- return fmt.Errorf("failed to persist recent models: %w", err)
- }
-
- return nil
-}
-
func allToolNames() []string {
return []string{
"agent",
@@ -780,10 +536,6 @@ func (c *Config) SetupAgents() {
c.Agents = agents
}
-func (c *Config) Resolver() VariableResolver {
- return c.resolver
-}
-
func (c *ProviderConfig) TestConnection(resolver VariableResolver) error {
var (
providerID = catwalk.InferenceProvider(c.ID)
@@ -1,48 +1 @@
package config
-
-import (
- "cmp"
- "context"
- "log/slog"
- "testing"
-
- "charm.land/catwalk/pkg/catwalk"
- "github.com/charmbracelet/crush/internal/oauth"
- "github.com/charmbracelet/crush/internal/oauth/copilot"
-)
-
-func (c *Config) ImportCopilot() (*oauth.Token, bool) {
- if testing.Testing() {
- return nil, false
- }
-
- if c.HasConfigField("providers.copilot.api_key") || c.HasConfigField("providers.copilot.oauth") {
- return nil, false
- }
-
- diskToken, hasDiskToken := copilot.RefreshTokenFromDisk()
- if !hasDiskToken {
- return nil, false
- }
-
- slog.Info("Found existing GitHub Copilot token on disk. Authenticating...")
- token, err := copilot.RefreshToken(context.TODO(), diskToken)
- if err != nil {
- slog.Error("Unable to import GitHub Copilot token", "error", err)
- return nil, false
- }
-
- if err := c.SetProviderAPIKey(string(catwalk.InferenceProviderCopilot), token); err != nil {
- return token, false
- }
-
- if err := cmp.Or(
- c.SetConfigField("providers.copilot.api_key", token.AccessToken),
- c.SetConfigField("providers.copilot.oauth", token),
- ); err != nil {
- slog.Error("Unable to save GitHub Copilot token to disk", "error", err)
- }
-
- slog.Info("GitHub Copilot successfully imported")
- return token, true
-}
@@ -18,19 +18,20 @@ type ProjectInitFlag struct {
Initialized bool `json:"initialized"`
}
-func Init(workingDir, dataDir string, debug bool) (*Config, error) {
- cfg, err := Load(workingDir, dataDir, debug)
+func Init(workingDir, dataDir string, debug bool) (*ConfigStore, error) {
+ store, err := Load(workingDir, dataDir, debug)
if err != nil {
return nil, err
}
- return cfg, nil
+ return store, nil
}
-func ProjectNeedsInitialization(cfg *Config) (bool, error) {
- if cfg == nil {
+func ProjectNeedsInitialization(store *ConfigStore) (bool, error) {
+ if store == nil {
return false, fmt.Errorf("config not loaded")
}
+ cfg := store.Config()
flagFilePath := filepath.Join(cfg.Options.DataDirectory, InitFlagFilename)
_, err := os.Stat(flagFilePath)
@@ -42,7 +43,7 @@ func ProjectNeedsInitialization(cfg *Config) (bool, error) {
return false, fmt.Errorf("failed to check init flag file: %w", err)
}
- someContextFileExists, err := contextPathsExist(cfg.WorkingDir())
+ someContextFileExists, err := contextPathsExist(store.WorkingDir())
if err != nil {
return false, fmt.Errorf("failed to check for context files: %w", err)
}
@@ -51,7 +52,7 @@ func ProjectNeedsInitialization(cfg *Config) (bool, error) {
}
// If the working directory has no non-ignored files, skip initialization step
- empty, err := dirHasNoVisibleFiles(cfg.WorkingDir())
+ empty, err := dirHasNoVisibleFiles(store.WorkingDir())
if err != nil {
return false, fmt.Errorf("failed to check if directory is empty: %w", err)
}
@@ -90,7 +91,7 @@ func contextPathsExist(dir string) (bool, error) {
return false, nil
}
-// dirHasNoVisibleFiles returns true if the directory has no files/dirs after applying ignore rules
+// dirHasNoVisibleFiles returns true if the directory has no files/dirs after applying ignore rules.
func dirHasNoVisibleFiles(dir string) (bool, error) {
files, _, err := fsext.ListDirectory(dir, nil, 1, 1)
if err != nil {
@@ -99,11 +100,11 @@ func dirHasNoVisibleFiles(dir string) (bool, error) {
return len(files) == 0, nil
}
-func MarkProjectInitialized(cfg *Config) error {
- if cfg == nil {
+func MarkProjectInitialized(store *ConfigStore) error {
+ if store == nil {
return fmt.Errorf("config not loaded")
}
- flagFilePath := filepath.Join(cfg.Options.DataDirectory, InitFlagFilename)
+ flagFilePath := filepath.Join(store.Config().Options.DataDirectory, InitFlagFilename)
file, err := os.Create(flagFilePath)
if err != nil {
@@ -114,13 +115,13 @@ func MarkProjectInitialized(cfg *Config) error {
return nil
}
-func HasInitialDataConfig(cfg *Config) bool {
- if cfg == nil {
+func HasInitialDataConfig(store *ConfigStore) bool {
+ if store == nil {
return false
}
cfgPath := GlobalConfigData()
if _, err := os.Stat(cfgPath); err != nil {
return false
}
- return cfg.IsConfigured()
+ return store.Config().IsConfigured()
}
@@ -29,8 +29,9 @@ import (
const defaultCatwalkURL = "https://catwalk.charm.sh"
-// Load loads the configuration from the default paths.
-func Load(workingDir, dataDir string, debug bool) (*Config, error) {
+// Load loads the configuration from the default paths and returns a
+// ConfigStore that owns both the pure-data Config and all runtime state.
+func Load(workingDir, dataDir string, debug bool) (*ConfigStore, error) {
configPaths := lookupConfigs(workingDir)
cfg, err := loadFromConfigPaths(configPaths)
@@ -38,10 +39,15 @@ func Load(workingDir, dataDir string, debug bool) (*Config, error) {
return nil, fmt.Errorf("failed to load config from paths %v: %w", configPaths, err)
}
- cfg.dataConfigDir = GlobalConfigData()
-
cfg.setDefaults(workingDir, dataDir)
+ store := &ConfigStore{
+ config: cfg,
+ workingDir: workingDir,
+ globalDataPath: GlobalConfigData(),
+ workspacePath: filepath.Join(cfg.Options.DataDirectory, fmt.Sprintf("%s.json", appName)),
+ }
+
if debug {
cfg.Options.Debug = true
}
@@ -52,6 +58,18 @@ func Load(workingDir, dataDir string, debug bool) (*Config, error) {
cfg.Options.Debug,
)
+ // Load workspace config last so it has highest priority.
+ if wsData, err := os.ReadFile(store.workspacePath); err == nil && len(wsData) > 0 {
+ merged, mergeErr := loadFromBytes(append([][]byte{mustMarshalConfig(cfg)}, wsData))
+ if mergeErr == nil {
+ // Preserve defaults that setDefaults already applied.
+ dataDir := cfg.Options.DataDirectory
+ *cfg = *merged
+ cfg.setDefaults(workingDir, dataDir)
+ store.config = cfg
+ }
+ }
+
if !isInsideWorktree() {
const depth = 2
const items = 100
@@ -72,26 +90,36 @@ func Load(workingDir, dataDir string, debug bool) (*Config, error) {
if err != nil {
return nil, err
}
- cfg.knownProviders = providers
+ store.knownProviders = providers
env := env.New()
// Configure providers
valueResolver := NewShellVariableResolver(env)
- cfg.resolver = valueResolver
- if err := cfg.configureProviders(env, valueResolver, cfg.knownProviders); err != nil {
+ store.resolver = valueResolver
+ if err := cfg.configureProviders(store, env, valueResolver, store.knownProviders); err != nil {
return nil, fmt.Errorf("failed to configure providers: %w", err)
}
if !cfg.IsConfigured() {
slog.Warn("No providers configured")
- return cfg, nil
+ return store, nil
}
- if err := cfg.configureSelectedModels(cfg.knownProviders); err != nil {
+ if err := configureSelectedModels(store, store.knownProviders); err != nil {
return nil, fmt.Errorf("failed to configure selected models: %w", err)
}
- cfg.SetupAgents()
- return cfg, nil
+ store.SetupAgents()
+ return store, nil
+}
+
+// mustMarshalConfig marshals the config to JSON bytes, returning empty JSON on
+// error.
+func mustMarshalConfig(cfg *Config) []byte {
+ data, err := json.Marshal(cfg)
+ if err != nil {
+ return []byte("{}")
+ }
+ return data
}
func PushPopCrushEnv() func() {
@@ -122,7 +150,7 @@ func PushPopCrushEnv() func() {
return restore
}
-func (c *Config) configureProviders(env env.Env, resolver VariableResolver, knownProviders []catwalk.Provider) error {
+func (c *Config) configureProviders(store *ConfigStore, env env.Env, resolver VariableResolver, knownProviders []catwalk.Provider) error {
knownProviderNames := make(map[string]bool)
restore := PushPopCrushEnv()
defer restore()
@@ -209,7 +237,7 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know
switch {
case p.ID == catwalk.InferenceProviderAnthropic && config.OAuthToken != nil:
// Claude Code subscription is not supported anymore. Remove to show onboarding.
- c.RemoveConfigField("providers.anthropic")
+ store.RemoveConfigField(ScopeGlobal, "providers.anthropic")
c.Providers.Del(string(p.ID))
continue
case p.ID == catwalk.InferenceProviderCopilot && config.OAuthToken != nil:
@@ -340,7 +368,6 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know
}
func (c *Config) setDefaults(workingDir, dataDir string) {
- c.workingDir = workingDir
if c.Options == nil {
c.Options = &Options{}
}
@@ -524,7 +551,8 @@ func (c *Config) defaultModelSelection(knownProviders []catwalk.Provider) (large
return largeModel, smallModel, err
}
-func (c *Config) configureSelectedModels(knownProviders []catwalk.Provider) error {
+func configureSelectedModels(store *ConfigStore, knownProviders []catwalk.Provider) error {
+ c := store.config
defaultLarge, defaultSmall, err := c.defaultModelSelection(knownProviders)
if err != nil {
return fmt.Errorf("failed to select default models: %w", err)
@@ -543,7 +571,7 @@ func (c *Config) configureSelectedModels(knownProviders []catwalk.Provider) erro
if model == nil {
large = defaultLarge
// override the model type to large
- err := c.UpdatePreferredModel(SelectedModelTypeLarge, large)
+ err := store.UpdatePreferredModel(ScopeGlobal, SelectedModelTypeLarge, large)
if err != nil {
return fmt.Errorf("failed to update preferred large model: %w", err)
}
@@ -587,7 +615,7 @@ func (c *Config) configureSelectedModels(knownProviders []catwalk.Provider) erro
if model == nil {
small = defaultSmall
// override the model type to small
- err := c.UpdatePreferredModel(SelectedModelTypeSmall, small)
+ err := store.UpdatePreferredModel(ScopeGlobal, SelectedModelTypeSmall, small)
if err != nil {
return fmt.Errorf("failed to update preferred small model: %w", err)
}
@@ -36,6 +36,11 @@ func TestConfig_LoadFromBytes(t *testing.T) {
require.Equal(t, "https://api.openai.com/v2", pc.BaseURL)
}
+// testStore wraps a Config in a minimal ConfigStore for testing.
+func testStore(cfg *Config) *ConfigStore {
+ return &ConfigStore{config: cfg}
+}
+
func TestConfig_setDefaults(t *testing.T) {
cfg := &Config{}
@@ -53,7 +58,6 @@ func TestConfig_setDefaults(t *testing.T) {
for _, path := range defaultContextPaths {
require.Contains(t, cfg.Options.ContextPaths, path)
}
- require.Equal(t, "/tmp", cfg.workingDir)
}
func TestConfig_configureProviders(t *testing.T) {
@@ -74,7 +78,7 @@ func TestConfig_configureProviders(t *testing.T) {
"OPENAI_API_KEY": "test-key",
})
resolver := NewEnvironmentVariableResolver(env)
- err := cfg.configureProviders(env, resolver, knownProviders)
+ err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders)
require.NoError(t, err)
require.Equal(t, 1, cfg.Providers.Len())
@@ -117,7 +121,7 @@ func TestConfig_configureProvidersWithOverride(t *testing.T) {
"OPENAI_API_KEY": "test-key",
})
resolver := NewEnvironmentVariableResolver(env)
- err := cfg.configureProviders(env, resolver, knownProviders)
+ err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders)
require.NoError(t, err)
require.Equal(t, 1, cfg.Providers.Len())
@@ -159,7 +163,7 @@ func TestConfig_configureProvidersWithNewProvider(t *testing.T) {
"OPENAI_API_KEY": "test-key",
})
resolver := NewEnvironmentVariableResolver(env)
- err := cfg.configureProviders(env, resolver, knownProviders)
+ err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders)
require.NoError(t, err)
// Should be to because of the env variable
require.Equal(t, cfg.Providers.Len(), 2)
@@ -195,7 +199,7 @@ func TestConfig_configureProvidersBedrockWithCredentials(t *testing.T) {
"AWS_SECRET_ACCESS_KEY": "test-secret-key",
})
resolver := NewEnvironmentVariableResolver(env)
- err := cfg.configureProviders(env, resolver, knownProviders)
+ err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders)
require.NoError(t, err)
require.Equal(t, cfg.Providers.Len(), 1)
@@ -221,7 +225,7 @@ func TestConfig_configureProvidersBedrockWithoutCredentials(t *testing.T) {
cfg.setDefaults("/tmp", "")
env := env.NewFromMap(map[string]string{})
resolver := NewEnvironmentVariableResolver(env)
- err := cfg.configureProviders(env, resolver, knownProviders)
+ err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders)
require.NoError(t, err)
// Provider should not be configured without credentials
require.Equal(t, cfg.Providers.Len(), 0)
@@ -246,7 +250,7 @@ func TestConfig_configureProvidersBedrockWithoutUnsupportedModel(t *testing.T) {
"AWS_SECRET_ACCESS_KEY": "test-secret-key",
})
resolver := NewEnvironmentVariableResolver(env)
- err := cfg.configureProviders(env, resolver, knownProviders)
+ err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders)
require.Error(t, err)
}
@@ -269,7 +273,7 @@ func TestConfig_configureProvidersVertexAIWithCredentials(t *testing.T) {
"VERTEXAI_LOCATION": "us-central1",
})
resolver := NewEnvironmentVariableResolver(env)
- err := cfg.configureProviders(env, resolver, knownProviders)
+ err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders)
require.NoError(t, err)
require.Equal(t, cfg.Providers.Len(), 1)
@@ -301,7 +305,7 @@ func TestConfig_configureProvidersVertexAIWithoutCredentials(t *testing.T) {
"GOOGLE_CLOUD_LOCATION": "us-central1",
})
resolver := NewEnvironmentVariableResolver(env)
- err := cfg.configureProviders(env, resolver, knownProviders)
+ err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders)
require.NoError(t, err)
// Provider should not be configured without proper credentials
require.Equal(t, cfg.Providers.Len(), 0)
@@ -326,7 +330,7 @@ func TestConfig_configureProvidersVertexAIMissingProject(t *testing.T) {
"GOOGLE_CLOUD_LOCATION": "us-central1",
})
resolver := NewEnvironmentVariableResolver(env)
- err := cfg.configureProviders(env, resolver, knownProviders)
+ err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders)
require.NoError(t, err)
// Provider should not be configured without project
require.Equal(t, cfg.Providers.Len(), 0)
@@ -350,7 +354,7 @@ func TestConfig_configureProvidersSetProviderID(t *testing.T) {
"OPENAI_API_KEY": "test-key",
})
resolver := NewEnvironmentVariableResolver(env)
- err := cfg.configureProviders(env, resolver, knownProviders)
+ err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders)
require.NoError(t, err)
require.Equal(t, cfg.Providers.Len(), 1)
@@ -541,7 +545,7 @@ func TestConfig_configureProvidersWithDisabledProvider(t *testing.T) {
"OPENAI_API_KEY": "test-key",
})
resolver := NewEnvironmentVariableResolver(env)
- err := cfg.configureProviders(env, resolver, knownProviders)
+ err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders)
require.NoError(t, err)
require.Equal(t, cfg.Providers.Len(), 1)
@@ -569,7 +573,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
env := env.NewFromMap(map[string]string{})
resolver := NewEnvironmentVariableResolver(env)
- err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
+ err := cfg.configureProviders(testStore(cfg), env, resolver, []catwalk.Provider{})
require.NoError(t, err)
require.Equal(t, cfg.Providers.Len(), 1)
@@ -592,7 +596,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
env := env.NewFromMap(map[string]string{})
resolver := NewEnvironmentVariableResolver(env)
- err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
+ err := cfg.configureProviders(testStore(cfg), env, resolver, []catwalk.Provider{})
require.NoError(t, err)
require.Equal(t, cfg.Providers.Len(), 0)
@@ -614,7 +618,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
env := env.NewFromMap(map[string]string{})
resolver := NewEnvironmentVariableResolver(env)
- err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
+ err := cfg.configureProviders(testStore(cfg), env, resolver, []catwalk.Provider{})
require.NoError(t, err)
require.Equal(t, cfg.Providers.Len(), 0)
@@ -639,7 +643,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
env := env.NewFromMap(map[string]string{})
resolver := NewEnvironmentVariableResolver(env)
- err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
+ err := cfg.configureProviders(testStore(cfg), env, resolver, []catwalk.Provider{})
require.NoError(t, err)
require.Equal(t, cfg.Providers.Len(), 0)
@@ -664,7 +668,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
env := env.NewFromMap(map[string]string{})
resolver := NewEnvironmentVariableResolver(env)
- err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
+ err := cfg.configureProviders(testStore(cfg), env, resolver, []catwalk.Provider{})
require.NoError(t, err)
require.Equal(t, cfg.Providers.Len(), 1)
@@ -692,7 +696,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
env := env.NewFromMap(map[string]string{})
resolver := NewEnvironmentVariableResolver(env)
- err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
+ err := cfg.configureProviders(testStore(cfg), env, resolver, []catwalk.Provider{})
require.NoError(t, err)
require.Equal(t, cfg.Providers.Len(), 1)
@@ -722,7 +726,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
env := env.NewFromMap(map[string]string{})
resolver := NewEnvironmentVariableResolver(env)
- err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
+ err := cfg.configureProviders(testStore(cfg), env, resolver, []catwalk.Provider{})
require.NoError(t, err)
require.Equal(t, cfg.Providers.Len(), 0)
@@ -757,7 +761,7 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) {
"GOOGLE_GENAI_USE_VERTEXAI": "false",
})
resolver := NewEnvironmentVariableResolver(env)
- err := cfg.configureProviders(env, resolver, knownProviders)
+ err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders)
require.NoError(t, err)
require.Equal(t, cfg.Providers.Len(), 0)
@@ -788,7 +792,7 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) {
env := env.NewFromMap(map[string]string{})
resolver := NewEnvironmentVariableResolver(env)
- err := cfg.configureProviders(env, resolver, knownProviders)
+ err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders)
require.NoError(t, err)
require.Equal(t, cfg.Providers.Len(), 0)
@@ -819,7 +823,7 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) {
env := env.NewFromMap(map[string]string{})
resolver := NewEnvironmentVariableResolver(env)
- err := cfg.configureProviders(env, resolver, knownProviders)
+ err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders)
require.NoError(t, err)
require.Equal(t, cfg.Providers.Len(), 0)
@@ -852,7 +856,7 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) {
"OPENAI_API_KEY": "test-key",
})
resolver := NewEnvironmentVariableResolver(env)
- err := cfg.configureProviders(env, resolver, knownProviders)
+ err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders)
require.NoError(t, err)
require.Equal(t, cfg.Providers.Len(), 1)
@@ -886,7 +890,7 @@ func TestConfig_defaultModelSelection(t *testing.T) {
cfg.setDefaults("/tmp", "")
env := env.NewFromMap(map[string]string{})
resolver := NewEnvironmentVariableResolver(env)
- err := cfg.configureProviders(env, resolver, knownProviders)
+ err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders)
require.NoError(t, err)
large, small, err := cfg.defaultModelSelection(knownProviders)
@@ -922,7 +926,7 @@ func TestConfig_defaultModelSelection(t *testing.T) {
cfg.setDefaults("/tmp", "")
env := env.NewFromMap(map[string]string{})
resolver := NewEnvironmentVariableResolver(env)
- err := cfg.configureProviders(env, resolver, knownProviders)
+ err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders)
require.NoError(t, err)
_, _, err = cfg.defaultModelSelection(knownProviders)
@@ -952,7 +956,7 @@ func TestConfig_defaultModelSelection(t *testing.T) {
cfg.setDefaults("/tmp", "")
env := env.NewFromMap(map[string]string{})
resolver := NewEnvironmentVariableResolver(env)
- err := cfg.configureProviders(env, resolver, knownProviders)
+ err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders)
require.NoError(t, err)
_, _, err = cfg.defaultModelSelection(knownProviders)
require.Error(t, err)
@@ -995,7 +999,7 @@ func TestConfig_defaultModelSelection(t *testing.T) {
cfg.setDefaults("/tmp", "")
env := env.NewFromMap(map[string]string{})
resolver := NewEnvironmentVariableResolver(env)
- err := cfg.configureProviders(env, resolver, knownProviders)
+ err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders)
require.NoError(t, err)
large, small, err := cfg.defaultModelSelection(knownProviders)
require.NoError(t, err)
@@ -1039,7 +1043,7 @@ func TestConfig_defaultModelSelection(t *testing.T) {
cfg.setDefaults("/tmp", "")
env := env.NewFromMap(map[string]string{})
resolver := NewEnvironmentVariableResolver(env)
- err := cfg.configureProviders(env, resolver, knownProviders)
+ err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders)
require.NoError(t, err)
_, _, err = cfg.defaultModelSelection(knownProviders)
require.Error(t, err)
@@ -1081,7 +1085,7 @@ func TestConfig_defaultModelSelection(t *testing.T) {
cfg.setDefaults("/tmp", "")
env := env.NewFromMap(map[string]string{})
resolver := NewEnvironmentVariableResolver(env)
- err := cfg.configureProviders(env, resolver, knownProviders)
+ err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders)
require.NoError(t, err)
large, small, err := cfg.defaultModelSelection(knownProviders)
require.NoError(t, err)
@@ -1126,7 +1130,7 @@ func TestConfig_configureProvidersDisableDefaultProviders(t *testing.T) {
"OPENAI_API_KEY": "test-key",
})
resolver := NewEnvironmentVariableResolver(env)
- err := cfg.configureProviders(env, resolver, knownProviders)
+ err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders)
require.ErrorContains(t, err, "no custom providers")
// openai should NOT be present because it lacks base_url and models.
@@ -1169,7 +1173,7 @@ func TestConfig_configureProvidersDisableDefaultProviders(t *testing.T) {
"OPENAI_API_KEY": "test-key",
})
resolver := NewEnvironmentVariableResolver(env)
- err := cfg.configureProviders(env, resolver, knownProviders)
+ err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders)
require.NoError(t, err)
// Only fully specified provider should be present.
@@ -1223,7 +1227,7 @@ func TestConfig_configureProvidersDisableDefaultProviders(t *testing.T) {
"ANTHROPIC_API_KEY": "test-key",
})
resolver := NewEnvironmentVariableResolver(env)
- err := cfg.configureProviders(env, resolver, knownProviders)
+ err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders)
require.NoError(t, err)
// Both providers should be present.
@@ -1251,7 +1255,7 @@ func TestConfig_configureProvidersDisableDefaultProviders(t *testing.T) {
env := env.NewFromMap(map[string]string{})
resolver := NewEnvironmentVariableResolver(env)
- err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
+ err := cfg.configureProviders(testStore(cfg), env, resolver, []catwalk.Provider{})
require.ErrorContains(t, err, "no custom providers")
// Provider should be rejected for missing models.
@@ -1275,7 +1279,7 @@ func TestConfig_configureProvidersDisableDefaultProviders(t *testing.T) {
env := env.NewFromMap(map[string]string{})
resolver := NewEnvironmentVariableResolver(env)
- err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
+ err := cfg.configureProviders(testStore(cfg), env, resolver, []catwalk.Provider{})
require.ErrorContains(t, err, "no custom providers")
// Provider should be rejected for missing base_url.
@@ -1340,10 +1344,10 @@ func TestConfig_configureSelectedModels(t *testing.T) {
cfg.setDefaults("/tmp", "")
env := env.NewFromMap(map[string]string{})
resolver := NewEnvironmentVariableResolver(env)
- err := cfg.configureProviders(env, resolver, knownProviders)
+ err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders)
require.NoError(t, err)
- err = cfg.configureSelectedModels(knownProviders)
+ err = configureSelectedModels(testStore(cfg), knownProviders)
require.NoError(t, err)
large := cfg.Models[SelectedModelTypeLarge]
small := cfg.Models[SelectedModelTypeSmall]
@@ -1402,10 +1406,10 @@ func TestConfig_configureSelectedModels(t *testing.T) {
cfg.setDefaults("/tmp", "")
env := env.NewFromMap(map[string]string{})
resolver := NewEnvironmentVariableResolver(env)
- err := cfg.configureProviders(env, resolver, knownProviders)
+ err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders)
require.NoError(t, err)
- err = cfg.configureSelectedModels(knownProviders)
+ err = configureSelectedModels(testStore(cfg), knownProviders)
require.NoError(t, err)
large := cfg.Models[SelectedModelTypeLarge]
small := cfg.Models[SelectedModelTypeSmall]
@@ -1447,10 +1451,10 @@ func TestConfig_configureSelectedModels(t *testing.T) {
cfg.setDefaults("/tmp", "")
env := env.NewFromMap(map[string]string{})
resolver := NewEnvironmentVariableResolver(env)
- err := cfg.configureProviders(env, resolver, knownProviders)
+ err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders)
require.NoError(t, err)
- err = cfg.configureSelectedModels(knownProviders)
+ err = configureSelectedModels(testStore(cfg), knownProviders)
require.NoError(t, err)
large := cfg.Models[SelectedModelTypeLarge]
require.Equal(t, "large-model", large.Model)
@@ -31,15 +31,23 @@ func readRecentModels(t *testing.T, path string) map[string]any {
return rm
}
+// testStoreWithPath creates a ConfigStore backed by a Config for recent model tests.
+func testStoreWithPath(cfg *Config, dir string) *ConfigStore {
+ return &ConfigStore{
+ config: cfg,
+ globalDataPath: filepath.Join(dir, "config.json"),
+ }
+}
+
func TestRecordRecentModel_AddsAndPersists(t *testing.T) {
t.Parallel()
dir := t.TempDir()
cfg := &Config{}
cfg.setDefaults(dir, "")
- cfg.dataConfigDir = filepath.Join(dir, "config.json")
+ store := testStoreWithPath(cfg, dir)
- err := cfg.recordRecentModel(SelectedModelTypeLarge, SelectedModel{Provider: "openai", Model: "gpt-4o"})
+ err := store.recordRecentModel(ScopeGlobal, SelectedModelTypeLarge, SelectedModel{Provider: "openai", Model: "gpt-4o"})
require.NoError(t, err)
// in-memory state
@@ -48,7 +56,7 @@ func TestRecordRecentModel_AddsAndPersists(t *testing.T) {
require.Equal(t, "gpt-4o", cfg.RecentModels[SelectedModelTypeLarge][0].Model)
// persisted state
- rm := readRecentModels(t, cfg.dataConfigDir)
+ rm := readRecentModels(t, store.globalDataPath)
large, ok := rm[string(SelectedModelTypeLarge)].([]any)
require.True(t, ok)
require.Len(t, large, 1)
@@ -64,13 +72,13 @@ func TestRecordRecentModel_DedupeAndMoveToFront(t *testing.T) {
dir := t.TempDir()
cfg := &Config{}
cfg.setDefaults(dir, "")
- cfg.dataConfigDir = filepath.Join(dir, "config.json")
+ store := testStoreWithPath(cfg, dir)
// Add two entries
- require.NoError(t, cfg.recordRecentModel(SelectedModelTypeLarge, SelectedModel{Provider: "openai", Model: "gpt-4o"}))
- require.NoError(t, cfg.recordRecentModel(SelectedModelTypeLarge, SelectedModel{Provider: "anthropic", Model: "claude"}))
+ require.NoError(t, store.recordRecentModel(ScopeGlobal, SelectedModelTypeLarge, SelectedModel{Provider: "openai", Model: "gpt-4o"}))
+ require.NoError(t, store.recordRecentModel(ScopeGlobal, SelectedModelTypeLarge, SelectedModel{Provider: "anthropic", Model: "claude"}))
// Re-add first; should move to front and not duplicate
- require.NoError(t, cfg.recordRecentModel(SelectedModelTypeLarge, SelectedModel{Provider: "openai", Model: "gpt-4o"}))
+ require.NoError(t, store.recordRecentModel(ScopeGlobal, SelectedModelTypeLarge, SelectedModel{Provider: "openai", Model: "gpt-4o"}))
got := cfg.RecentModels[SelectedModelTypeLarge]
require.Len(t, got, 2)
@@ -84,7 +92,7 @@ func TestRecordRecentModel_TrimsToMax(t *testing.T) {
dir := t.TempDir()
cfg := &Config{}
cfg.setDefaults(dir, "")
- cfg.dataConfigDir = filepath.Join(dir, "config.json")
+ store := testStoreWithPath(cfg, dir)
// Insert 6 unique models; max is 5
entries := []SelectedModel{
@@ -96,7 +104,7 @@ func TestRecordRecentModel_TrimsToMax(t *testing.T) {
{Provider: "p6", Model: "m6"},
}
for _, e := range entries {
- require.NoError(t, cfg.recordRecentModel(SelectedModelTypeLarge, e))
+ require.NoError(t, store.recordRecentModel(ScopeGlobal, SelectedModelTypeLarge, e))
}
// in-memory state
@@ -110,7 +118,7 @@ func TestRecordRecentModel_TrimsToMax(t *testing.T) {
require.Equal(t, SelectedModel{Provider: "p2", Model: "m2"}, got[4])
// persisted state: verify trimmed to 5 and newest-first order
- rm := readRecentModels(t, cfg.dataConfigDir)
+ rm := readRecentModels(t, store.globalDataPath)
large, ok := rm[string(SelectedModelTypeLarge)].([]any)
require.True(t, ok)
require.Len(t, large, 5)
@@ -129,12 +137,12 @@ func TestRecordRecentModel_SkipsEmptyValues(t *testing.T) {
dir := t.TempDir()
cfg := &Config{}
cfg.setDefaults(dir, "")
- cfg.dataConfigDir = filepath.Join(dir, "config.json")
+ store := testStoreWithPath(cfg, dir)
// Missing provider
- require.NoError(t, cfg.recordRecentModel(SelectedModelTypeLarge, SelectedModel{Provider: "", Model: "m"}))
+ require.NoError(t, store.recordRecentModel(ScopeGlobal, SelectedModelTypeLarge, SelectedModel{Provider: "", Model: "m"}))
// Missing model
- require.NoError(t, cfg.recordRecentModel(SelectedModelTypeLarge, SelectedModel{Provider: "p", Model: ""}))
+ require.NoError(t, store.recordRecentModel(ScopeGlobal, SelectedModelTypeLarge, SelectedModel{Provider: "p", Model: ""}))
_, ok := cfg.RecentModels[SelectedModelTypeLarge]
// Map may be initialized, but should have no entries
@@ -142,8 +150,8 @@ func TestRecordRecentModel_SkipsEmptyValues(t *testing.T) {
require.Len(t, cfg.RecentModels[SelectedModelTypeLarge], 0)
}
// No file should be written (stat via fs.FS)
- baseDir := filepath.Dir(cfg.dataConfigDir)
- fileName := filepath.Base(cfg.dataConfigDir)
+ baseDir := filepath.Dir(store.globalDataPath)
+ fileName := filepath.Base(store.globalDataPath)
_, err := fs.Stat(os.DirFS(baseDir), fileName)
require.True(t, os.IsNotExist(err))
}
@@ -154,13 +162,13 @@ func TestRecordRecentModel_NoPersistOnNoop(t *testing.T) {
dir := t.TempDir()
cfg := &Config{}
cfg.setDefaults(dir, "")
- cfg.dataConfigDir = filepath.Join(dir, "config.json")
+ store := testStoreWithPath(cfg, dir)
entry := SelectedModel{Provider: "openai", Model: "gpt-4o"}
- require.NoError(t, cfg.recordRecentModel(SelectedModelTypeLarge, entry))
+ require.NoError(t, store.recordRecentModel(ScopeGlobal, SelectedModelTypeLarge, entry))
- baseDir := filepath.Dir(cfg.dataConfigDir)
- fileName := filepath.Base(cfg.dataConfigDir)
+ baseDir := filepath.Dir(store.globalDataPath)
+ fileName := filepath.Base(store.globalDataPath)
before, err := fs.ReadFile(os.DirFS(baseDir), fileName)
require.NoError(t, err)
@@ -170,7 +178,7 @@ func TestRecordRecentModel_NoPersistOnNoop(t *testing.T) {
beforeMod := stBefore.ModTime()
// Re-record same entry should be a no-op (no write)
- require.NoError(t, cfg.recordRecentModel(SelectedModelTypeLarge, entry))
+ require.NoError(t, store.recordRecentModel(ScopeGlobal, SelectedModelTypeLarge, entry))
after, err := fs.ReadFile(os.DirFS(baseDir), fileName)
require.NoError(t, err)
@@ -188,17 +196,17 @@ func TestUpdatePreferredModel_UpdatesRecents(t *testing.T) {
dir := t.TempDir()
cfg := &Config{}
cfg.setDefaults(dir, "")
- cfg.dataConfigDir = filepath.Join(dir, "config.json")
+ store := testStoreWithPath(cfg, dir)
sel := SelectedModel{Provider: "openai", Model: "gpt-4o"}
- require.NoError(t, cfg.UpdatePreferredModel(SelectedModelTypeSmall, sel))
+ require.NoError(t, store.UpdatePreferredModel(ScopeGlobal, SelectedModelTypeSmall, sel))
// in-memory
require.Equal(t, sel, cfg.Models[SelectedModelTypeSmall])
require.Len(t, cfg.RecentModels[SelectedModelTypeSmall], 1)
// persisted (read via fs.FS)
- rm := readRecentModels(t, cfg.dataConfigDir)
+ rm := readRecentModels(t, store.globalDataPath)
small, ok := rm[string(SelectedModelTypeSmall)].([]any)
require.True(t, ok)
require.Len(t, small, 1)
@@ -210,14 +218,14 @@ func TestRecordRecentModel_TypeIsolation(t *testing.T) {
dir := t.TempDir()
cfg := &Config{}
cfg.setDefaults(dir, "")
- cfg.dataConfigDir = filepath.Join(dir, "config.json")
+ store := testStoreWithPath(cfg, dir)
// Add models to both large and small types
largeModel := SelectedModel{Provider: "openai", Model: "gpt-4o"}
smallModel := SelectedModel{Provider: "anthropic", Model: "claude"}
- require.NoError(t, cfg.recordRecentModel(SelectedModelTypeLarge, largeModel))
- require.NoError(t, cfg.recordRecentModel(SelectedModelTypeSmall, smallModel))
+ require.NoError(t, store.recordRecentModel(ScopeGlobal, SelectedModelTypeLarge, largeModel))
+ require.NoError(t, store.recordRecentModel(ScopeGlobal, SelectedModelTypeSmall, smallModel))
// in-memory: verify types maintain separate histories
require.Len(t, cfg.RecentModels[SelectedModelTypeLarge], 1)
@@ -227,14 +235,14 @@ func TestRecordRecentModel_TypeIsolation(t *testing.T) {
// Add another to large, verify small unchanged
anotherLarge := SelectedModel{Provider: "google", Model: "gemini"}
- require.NoError(t, cfg.recordRecentModel(SelectedModelTypeLarge, anotherLarge))
+ require.NoError(t, store.recordRecentModel(ScopeGlobal, SelectedModelTypeLarge, anotherLarge))
require.Len(t, cfg.RecentModels[SelectedModelTypeLarge], 2)
require.Len(t, cfg.RecentModels[SelectedModelTypeSmall], 1)
require.Equal(t, smallModel, cfg.RecentModels[SelectedModelTypeSmall][0])
// persisted state: verify both types exist with correct lengths and contents
- rm := readRecentModels(t, cfg.dataConfigDir)
+ rm := readRecentModels(t, store.globalDataPath)
large, ok := rm[string(SelectedModelTypeLarge)].([]any)
require.True(t, ok)
@@ -0,0 +1,11 @@
+package config
+
+// Scope determines which config file is targeted for read/write operations.
+type Scope int
+
+const (
+ // ScopeGlobal targets the global data config (~/.local/share/crush/crush.json).
+ ScopeGlobal Scope = iota
+ // ScopeWorkspace targets the workspace config (.crush/crush.json).
+ ScopeWorkspace
+)
@@ -0,0 +1,336 @@
+package config
+
+import (
+ "cmp"
+ "context"
+ "fmt"
+ "log/slog"
+ "os"
+ "path/filepath"
+ "slices"
+
+ "charm.land/catwalk/pkg/catwalk"
+ hyperp "github.com/charmbracelet/crush/internal/agent/hyper"
+ "github.com/charmbracelet/crush/internal/oauth"
+ "github.com/charmbracelet/crush/internal/oauth/copilot"
+ "github.com/charmbracelet/crush/internal/oauth/hyper"
+ "github.com/tidwall/gjson"
+ "github.com/tidwall/sjson"
+)
+
+// ConfigStore is the single entry point for all config access. It owns the
+// pure-data Config, runtime state (working directory, resolver, known
+// providers), and persistence to both global and workspace config files.
+type ConfigStore struct {
+ config *Config
+ workingDir string
+ resolver VariableResolver
+ globalDataPath string // ~/.local/share/crush/crush.json
+ workspacePath string // .crush/crush.json
+ knownProviders []catwalk.Provider
+}
+
+// Config returns the pure-data config struct (read-only after load).
+func (s *ConfigStore) Config() *Config {
+ return s.config
+}
+
+// WorkingDir returns the current working directory.
+func (s *ConfigStore) WorkingDir() string {
+ return s.workingDir
+}
+
+// Resolver returns the variable resolver.
+func (s *ConfigStore) Resolver() VariableResolver {
+ return s.resolver
+}
+
+// Resolve resolves a variable reference using the configured resolver.
+func (s *ConfigStore) Resolve(key string) (string, error) {
+ if s.resolver == nil {
+ return "", fmt.Errorf("no variable resolver configured")
+ }
+ return s.resolver.ResolveValue(key)
+}
+
+// KnownProviders returns the list of known providers.
+func (s *ConfigStore) KnownProviders() []catwalk.Provider {
+ return s.knownProviders
+}
+
+// SetupAgents configures the coder and task agents on the config.
+func (s *ConfigStore) SetupAgents() {
+ s.config.SetupAgents()
+}
+
+// configPath returns the file path for the given scope.
+func (s *ConfigStore) configPath(scope Scope) string {
+ switch scope {
+ case ScopeWorkspace:
+ return s.workspacePath
+ default:
+ return s.globalDataPath
+ }
+}
+
+// HasConfigField checks whether a key exists in the config file for the given
+// scope.
+func (s *ConfigStore) HasConfigField(scope Scope, key string) bool {
+ data, err := os.ReadFile(s.configPath(scope))
+ if err != nil {
+ return false
+ }
+ return gjson.Get(string(data), key).Exists()
+}
+
+// SetConfigField sets a key/value pair in the config file for the given scope.
+func (s *ConfigStore) SetConfigField(scope Scope, key string, value any) error {
+ path := s.configPath(scope)
+ data, err := os.ReadFile(path)
+ if err != nil {
+ if os.IsNotExist(err) {
+ data = []byte("{}")
+ } else {
+ return fmt.Errorf("failed to read config file: %w", err)
+ }
+ }
+
+ newValue, err := sjson.Set(string(data), key, value)
+ if err != nil {
+ return fmt.Errorf("failed to set config field %s: %w", key, err)
+ }
+ if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
+ return fmt.Errorf("failed to create config directory %q: %w", path, err)
+ }
+ if err := os.WriteFile(path, []byte(newValue), 0o600); err != nil {
+ return fmt.Errorf("failed to write config file: %w", err)
+ }
+ return nil
+}
+
+// RemoveConfigField removes a key from the config file for the given scope.
+func (s *ConfigStore) RemoveConfigField(scope Scope, key string) error {
+ path := s.configPath(scope)
+ data, err := os.ReadFile(path)
+ if err != nil {
+ return fmt.Errorf("failed to read config file: %w", err)
+ }
+
+ newValue, err := sjson.Delete(string(data), key)
+ if err != nil {
+ return fmt.Errorf("failed to delete config field %s: %w", key, err)
+ }
+ if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
+ return fmt.Errorf("failed to create config directory %q: %w", path, err)
+ }
+ if err := os.WriteFile(path, []byte(newValue), 0o600); err != nil {
+ return fmt.Errorf("failed to write config file: %w", err)
+ }
+ return nil
+}
+
+// UpdatePreferredModel updates the preferred model for the given type and
+// persists it to the config file at the given scope.
+func (s *ConfigStore) UpdatePreferredModel(scope Scope, modelType SelectedModelType, model SelectedModel) error {
+ s.config.Models[modelType] = model
+ if err := s.SetConfigField(scope, fmt.Sprintf("models.%s", modelType), model); err != nil {
+ return fmt.Errorf("failed to update preferred model: %w", err)
+ }
+ if err := s.recordRecentModel(scope, modelType, model); err != nil {
+ return err
+ }
+ return nil
+}
+
+// SetCompactMode sets the compact mode setting and persists it.
+func (s *ConfigStore) SetCompactMode(scope Scope, enabled bool) error {
+ if s.config.Options == nil {
+ s.config.Options = &Options{}
+ }
+ s.config.Options.TUI.CompactMode = enabled
+ return s.SetConfigField(scope, "options.tui.compact_mode", enabled)
+}
+
+// SetProviderAPIKey sets the API key for a provider and persists it.
+func (s *ConfigStore) SetProviderAPIKey(scope Scope, providerID string, apiKey any) error {
+ var providerConfig ProviderConfig
+ var exists bool
+ var setKeyOrToken func()
+
+ switch v := apiKey.(type) {
+ case string:
+ if err := s.SetConfigField(scope, fmt.Sprintf("providers.%s.api_key", providerID), v); err != nil {
+ return fmt.Errorf("failed to save api key to config file: %w", err)
+ }
+ setKeyOrToken = func() { providerConfig.APIKey = v }
+ case *oauth.Token:
+ if err := cmp.Or(
+ s.SetConfigField(scope, fmt.Sprintf("providers.%s.api_key", providerID), v.AccessToken),
+ s.SetConfigField(scope, fmt.Sprintf("providers.%s.oauth", providerID), v),
+ ); err != nil {
+ return err
+ }
+ setKeyOrToken = func() {
+ providerConfig.APIKey = v.AccessToken
+ providerConfig.OAuthToken = v
+ switch providerID {
+ case string(catwalk.InferenceProviderCopilot):
+ providerConfig.SetupGitHubCopilot()
+ }
+ }
+ }
+
+ providerConfig, exists = s.config.Providers.Get(providerID)
+ if exists {
+ setKeyOrToken()
+ s.config.Providers.Set(providerID, providerConfig)
+ return nil
+ }
+
+ var foundProvider *catwalk.Provider
+ for _, p := range s.knownProviders {
+ if string(p.ID) == providerID {
+ foundProvider = &p
+ break
+ }
+ }
+
+ if foundProvider != nil {
+ providerConfig = ProviderConfig{
+ ID: providerID,
+ Name: foundProvider.Name,
+ BaseURL: foundProvider.APIEndpoint,
+ Type: foundProvider.Type,
+ Disable: false,
+ ExtraHeaders: make(map[string]string),
+ ExtraParams: make(map[string]string),
+ Models: foundProvider.Models,
+ }
+ setKeyOrToken()
+ } else {
+ return fmt.Errorf("provider with ID %s not found in known providers", providerID)
+ }
+ s.config.Providers.Set(providerID, providerConfig)
+ return nil
+}
+
+// RefreshOAuthToken refreshes the OAuth token for the given provider.
+func (s *ConfigStore) RefreshOAuthToken(ctx context.Context, scope Scope, providerID string) error {
+ providerConfig, exists := s.config.Providers.Get(providerID)
+ if !exists {
+ return fmt.Errorf("provider %s not found", providerID)
+ }
+
+ if providerConfig.OAuthToken == nil {
+ return fmt.Errorf("provider %s does not have an OAuth token", providerID)
+ }
+
+ var newToken *oauth.Token
+ var refreshErr error
+ switch providerID {
+ case string(catwalk.InferenceProviderCopilot):
+ newToken, refreshErr = copilot.RefreshToken(ctx, providerConfig.OAuthToken.RefreshToken)
+ case hyperp.Name:
+ newToken, refreshErr = hyper.ExchangeToken(ctx, providerConfig.OAuthToken.RefreshToken)
+ default:
+ return fmt.Errorf("OAuth refresh not supported for provider %s", providerID)
+ }
+ if refreshErr != nil {
+ return fmt.Errorf("failed to refresh OAuth token for provider %s: %w", providerID, refreshErr)
+ }
+
+ slog.Info("Successfully refreshed OAuth token", "provider", providerID)
+ providerConfig.OAuthToken = newToken
+ providerConfig.APIKey = newToken.AccessToken
+
+ switch providerID {
+ case string(catwalk.InferenceProviderCopilot):
+ providerConfig.SetupGitHubCopilot()
+ }
+
+ s.config.Providers.Set(providerID, providerConfig)
+
+ if err := cmp.Or(
+ s.SetConfigField(scope, fmt.Sprintf("providers.%s.api_key", providerID), newToken.AccessToken),
+ s.SetConfigField(scope, fmt.Sprintf("providers.%s.oauth", providerID), newToken),
+ ); err != nil {
+ return fmt.Errorf("failed to persist refreshed token: %w", err)
+ }
+
+ return nil
+}
+
+// recordRecentModel records a model in the recent models list.
+func (s *ConfigStore) recordRecentModel(scope Scope, modelType SelectedModelType, model SelectedModel) error {
+ if model.Provider == "" || model.Model == "" {
+ return nil
+ }
+
+ if s.config.RecentModels == nil {
+ s.config.RecentModels = make(map[SelectedModelType][]SelectedModel)
+ }
+
+ eq := func(a, b SelectedModel) bool {
+ return a.Provider == b.Provider && a.Model == b.Model
+ }
+
+ entry := SelectedModel{
+ Provider: model.Provider,
+ Model: model.Model,
+ }
+
+ current := s.config.RecentModels[modelType]
+ withoutCurrent := slices.DeleteFunc(slices.Clone(current), func(existing SelectedModel) bool {
+ return eq(existing, entry)
+ })
+
+ updated := append([]SelectedModel{entry}, withoutCurrent...)
+ if len(updated) > maxRecentModelsPerType {
+ updated = updated[:maxRecentModelsPerType]
+ }
+
+ if slices.EqualFunc(current, updated, eq) {
+ return nil
+ }
+
+ s.config.RecentModels[modelType] = updated
+
+ if err := s.SetConfigField(scope, fmt.Sprintf("recent_models.%s", modelType), updated); err != nil {
+ return fmt.Errorf("failed to persist recent models: %w", err)
+ }
+
+ return nil
+}
+
+// ImportCopilot attempts to import a GitHub Copilot token from disk.
+func (s *ConfigStore) ImportCopilot() (*oauth.Token, bool) {
+ if s.HasConfigField(ScopeGlobal, "providers.copilot.api_key") || s.HasConfigField(ScopeGlobal, "providers.copilot.oauth") {
+ return nil, false
+ }
+
+ diskToken, hasDiskToken := copilot.RefreshTokenFromDisk()
+ if !hasDiskToken {
+ return nil, false
+ }
+
+ slog.Info("Found existing GitHub Copilot token on disk. Authenticating...")
+ token, err := copilot.RefreshToken(context.TODO(), diskToken)
+ if err != nil {
+ slog.Error("Unable to import GitHub Copilot token", "error", err)
+ return nil, false
+ }
+
+ if err := s.SetProviderAPIKey(ScopeGlobal, string(catwalk.InferenceProviderCopilot), token); err != nil {
+ return token, false
+ }
+
+ if err := cmp.Or(
+ s.SetConfigField(ScopeGlobal, "providers.copilot.api_key", token.AccessToken),
+ s.SetConfigField(ScopeGlobal, "providers.copilot.oauth", token),
+ ); err != nil {
+ slog.Error("Unable to save GitHub Copilot token to disk", "error", err)
+ }
+
+ slog.Info("GitHub Copilot successfully imported")
+ return token, true
+}
@@ -26,18 +26,18 @@ var unavailable = csync.NewMap[string, struct{}]()
// Manager handles lazy initialization of LSP clients based on file types.
type Manager struct {
clients *csync.Map[string, *Client]
- cfg *config.Config
+ cfg *config.ConfigStore
manager *powernapconfig.Manager
callback func(name string, client *Client)
}
// NewManager creates a new LSP manager service.
-func NewManager(cfg *config.Config) *Manager {
+func NewManager(cfg *config.ConfigStore) *Manager {
manager := powernapconfig.NewManager()
manager.LoadDefaults()
// Merge user-configured LSPs into the manager.
- for name, clientConfig := range cfg.LSP {
+ for name, clientConfig := range cfg.Config().LSP {
if clientConfig.Disabled {
slog.Debug("LSP disabled by user config", "name", name)
manager.RemoveServer(name)
@@ -194,7 +194,7 @@ func (s *Manager) startServer(ctx context.Context, name, filepath string, server
cfg,
s.cfg.Resolver(),
s.cfg.WorkingDir(),
- s.cfg.Options.DebugLSP,
+ s.cfg.Config().Options.DebugLSP,
)
if err != nil {
slog.Error("Failed to create LSP client", "name", name, "error", err)
@@ -244,7 +244,7 @@ func (s *Manager) startServer(ctx context.Context, name, filepath string, server
}
func (s *Manager) isUserConfigured(name string) bool {
- cfg, ok := s.cfg.LSP[name]
+ cfg, ok := s.cfg.Config().LSP[name]
return ok && !cfg.Disabled
}
@@ -258,7 +258,7 @@ func (s *Manager) buildConfig(name string, server *powernapconfig.ServerConfig)
InitOptions: server.InitOptions,
Options: server.Settings,
}
- if userCfg, ok := s.cfg.LSP[name]; ok {
+ if userCfg, ok := s.cfg.Config().LSP[name]; ok {
cfg.Timeout = userCfg.Timeout
}
return cfg
@@ -26,11 +26,16 @@ type Common struct {
Styles *styles.Styles
}
-// Config returns the configuration associated with this [Common] instance.
+// Config returns the pure-data configuration associated with this [Common] instance.
func (c *Common) Config() *config.Config {
return c.App.Config()
}
+// Store returns the config store associated with this [Common] instance.
+func (c *Common) Store() *config.ConfigStore {
+ return c.App.Store()
+}
+
// DefaultCommon returns the default common UI configurations.
func DefaultCommon(app *app.App) *Common {
s := styles.DefaultStyles()
@@ -296,7 +296,7 @@ func (m *APIKeyInput) verifyAPIKey() tea.Msg {
Type: m.provider.Type,
BaseURL: m.provider.APIEndpoint,
}
- err := providerConfig.TestConnection(m.com.Config().Resolver())
+ err := providerConfig.TestConnection(m.com.Store().Resolver())
// intentionally wait for at least 750ms to make sure the user sees the spinner
elapsed := time.Since(start)
@@ -312,9 +312,9 @@ func (m *APIKeyInput) verifyAPIKey() tea.Msg {
}
func (m *APIKeyInput) saveKeyAndContinue() Action {
- cfg := m.com.Config()
+ store := m.com.Store()
- err := cfg.SetProviderAPIKey(string(m.provider.ID), m.input.Value())
+ err := store.SetProviderAPIKey(config.ScopeGlobal, string(m.provider.ID), m.input.Value())
if err != nil {
return ActionCmd{util.ReportError(fmt.Errorf("failed to save API key: %w", err))}
}
@@ -123,7 +123,7 @@ func (f *FilePicker) SetImageCapabilities(caps *common.Capabilities) {
// WorkingDir returns the current working directory of the [FilePicker].
func (f *FilePicker) WorkingDir() string {
- wd := f.com.Config().WorkingDir()
+ wd := f.com.Store().WorkingDir()
if len(wd) > 0 {
return wd
}
@@ -490,7 +490,7 @@ func (m *Models) setProviderItems() error {
if len(validRecentItems) != len(recentItems) {
// FIXME: Does this need to be here? Is it mutating the config during a read?
- if err := cfg.SetConfigField(fmt.Sprintf("recent_models.%s", selectedType), validRecentItems); err != nil {
+ if err := m.com.Store().SetConfigField(config.ScopeGlobal, fmt.Sprintf("recent_models.%s", selectedType), validRecentItems); err != nil {
return fmt.Errorf("failed to update recent models: %w", err)
}
}
@@ -373,9 +373,9 @@ func (d *OAuth) copyCodeAndOpenURL() tea.Cmd {
}
func (m *OAuth) saveKeyAndContinue() Action {
- cfg := m.com.Config()
+ store := m.com.Store()
- err := cfg.SetProviderAPIKey(string(m.provider.ID), m.token)
+ err := store.SetProviderAPIKey(config.ScopeGlobal, string(m.provider.ID), m.token)
if err != nil {
return ActionCmd{util.ReportError(fmt.Errorf("failed to save API key: %w", err))}
}
@@ -143,8 +143,7 @@ func renderHeaderDetails(
metadata = dot + metadata
const dirTrimLimit = 4
- cfg := com.Config()
- cwd := fsext.DirTrim(fsext.PrettyPath(cfg.WorkingDir()), dirTrimLimit)
+ cwd := fsext.DirTrim(fsext.PrettyPath(com.Store().WorkingDir()), dirTrimLimit)
cwd = t.Header.WorkingDir.Render(cwd)
result := cwd + metadata
@@ -22,7 +22,7 @@ func (m *UI) selectedLargeModel() *agent.Model {
func (m *UI) landingView() string {
t := m.com.Styles
width := m.layout.main.Dx()
- cwd := common.PrettyPath(t, m.com.Config().WorkingDir(), width)
+ cwd := common.PrettyPath(t, m.com.Store().WorkingDir(), width)
parts := []string{
cwd,
@@ -19,7 +19,7 @@ import (
// markProjectInitialized marks the current project as initialized in the config.
func (m *UI) markProjectInitialized() tea.Msg {
// TODO: handle error so we show it in the tui footer
- err := config.MarkProjectInitialized(m.com.Config())
+ err := config.MarkProjectInitialized(m.com.Store())
if err != nil {
slog.Error(err.Error())
}
@@ -52,10 +52,10 @@ func (m *UI) initializeProject() tea.Cmd {
if cmd := m.newSession(); cmd != nil {
cmds = append(cmds, cmd)
}
- cfg := m.com.Config()
+ cfg := m.com.Store()
initialize := func() tea.Msg {
- initPrompt, err := agent.InitializePrompt(*cfg)
+ initPrompt, err := agent.InitializePrompt(cfg)
if err != nil {
return util.InfoMsg{Type: util.InfoTypeError, Msg: err.Error()}
}
@@ -77,10 +77,9 @@ func (m *UI) skipInitializeProject() tea.Cmd {
// initializeView renders the project initialization prompt with Yes/No buttons.
func (m *UI) initializeView() string {
- cfg := m.com.Config()
s := m.com.Styles.Initialize
- cwd := home.Short(cfg.WorkingDir())
- initFile := cfg.Options.InitializeAs
+ cwd := home.Short(m.com.Store().WorkingDir())
+ initFile := m.com.Config().Options.InitializeAs
header := s.Header.Render("Would you like to initialize this project?")
path := s.Accent.PaddingLeft(2).Render(cwd)
@@ -112,7 +112,7 @@ func (m *UI) drawSidebar(scr uv.Screen, area uv.Rectangle) {
height := area.Dy()
title := t.Muted.Width(width).MaxHeight(2).Render(m.session.Title)
- cwd := common.PrettyPath(t, m.com.Config().WorkingDir(), width)
+ cwd := common.PrettyPath(t, m.com.Store().WorkingDir(), width)
sidebarLogo := m.sidebarLogo
if height < logoHeightBreakpoint {
sidebarLogo = logo.SmallRender(m.com.Styles, width)
@@ -138,7 +138,7 @@ func (m *UI) drawSidebar(scr uv.Screen, area uv.Rectangle) {
lspSection := m.lspInfo(width, maxLSPs, true)
mcpSection := m.mcpInfo(width, maxMCPs, true)
- filesSection := m.filesInfo(m.com.Config().WorkingDir(), width, maxFiles, true)
+ filesSection := m.filesInfo(m.com.Store().WorkingDir(), width, maxFiles, true)
uv.NewStyledString(
lipgloss.NewStyle().
@@ -317,7 +317,7 @@ func New(com *common.Common) *UI {
desiredFocus := uiFocusEditor
if !com.Config().IsConfigured() {
desiredState = uiOnboarding
- } else if n, _ := config.ProjectNeedsInitialization(com.Config()); n {
+ } else if n, _ := config.ProjectNeedsInitialization(com.Store()); n {
desiredState = uiInitialize
}
@@ -579,7 +579,7 @@ func (m *UI) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
case mcp.EventPromptsListChanged:
return m, handleMCPPromptsEvent(msg.Payload.Name)
case mcp.EventToolsListChanged:
- return m, handleMCPToolsEvent(m.com.Config(), msg.Payload.Name)
+ return m, handleMCPToolsEvent(m.com.Store(), msg.Payload.Name)
case mcp.EventResourcesListChanged:
return m, handleMCPResourcesEvent(msg.Payload.Name)
}
@@ -1301,7 +1301,7 @@ func (m *UI) handleDialogMsg(msg tea.Msg) tea.Cmd {
currentModel := cfg.Models[agentCfg.Model]
currentModel.Think = !currentModel.Think
- if err := cfg.UpdatePreferredModel(agentCfg.Model, currentModel); err != nil {
+ if err := m.com.Store().UpdatePreferredModel(config.ScopeGlobal, agentCfg.Model, currentModel); err != nil {
return util.ReportError(err)()
}
m.com.App.UpdateAgentModel(context.TODO())
@@ -1342,7 +1342,7 @@ func (m *UI) handleDialogMsg(msg tea.Msg) tea.Cmd {
// Attempt to import GitHub Copilot tokens from VSCode if available.
if isCopilot && !isConfigured() && !msg.ReAuthenticate {
- m.com.Config().ImportCopilot()
+ m.com.Store().ImportCopilot()
}
if !isConfigured() || msg.ReAuthenticate {
@@ -1353,12 +1353,12 @@ func (m *UI) handleDialogMsg(msg tea.Msg) tea.Cmd {
break
}
- if err := cfg.UpdatePreferredModel(msg.ModelType, msg.Model); err != nil {
+ if err := m.com.Store().UpdatePreferredModel(config.ScopeGlobal, msg.ModelType, msg.Model); err != nil {
cmds = append(cmds, util.ReportError(err))
} else if _, ok := cfg.Models[config.SelectedModelTypeSmall]; !ok {
// Ensure small model is set is unset.
smallModel := m.com.App.GetDefaultSmallModel(providerID)
- if err := cfg.UpdatePreferredModel(config.SelectedModelTypeSmall, smallModel); err != nil {
+ if err := m.com.Store().UpdatePreferredModel(config.ScopeGlobal, config.SelectedModelTypeSmall, smallModel); err != nil {
cmds = append(cmds, util.ReportError(err))
}
}
@@ -1404,7 +1404,7 @@ func (m *UI) handleDialogMsg(msg tea.Msg) tea.Cmd {
currentModel := cfg.Models[agentCfg.Model]
currentModel.ReasoningEffort = msg.Effort
- if err := cfg.UpdatePreferredModel(agentCfg.Model, currentModel); err != nil {
+ if err := m.com.Store().UpdatePreferredModel(config.ScopeGlobal, agentCfg.Model, currentModel); err != nil {
cmds = append(cmds, util.ReportError(err))
break
}
@@ -2016,7 +2016,7 @@ func (m *UI) View() tea.View {
}
v.MouseMode = tea.MouseModeCellMotion
v.ReportFocus = m.caps.ReportFocusEvents
- v.WindowTitle = "crush " + home.Short(m.com.Config().WorkingDir())
+ v.WindowTitle = "crush " + home.Short(m.com.Store().WorkingDir())
canvas := uv.NewScreenBuffer(m.width, m.height)
v.Cursor = m.Draw(canvas, canvas.Bounds())
@@ -2255,7 +2255,7 @@ func (m *UI) FullHelp() [][]key.Binding {
func (m *UI) toggleCompactMode() tea.Cmd {
m.forceCompactMode = !m.forceCompactMode
- err := m.com.Config().SetCompactMode(m.forceCompactMode)
+ err := m.com.Store().SetCompactMode(config.ScopeGlobal, m.forceCompactMode)
if err != nil {
return util.ReportError(err)
}
@@ -2637,7 +2637,7 @@ func (m *UI) insertMCPResourceCompletion(item completions.ResourceCompletionValu
return func() tea.Msg {
contents, err := mcp.ReadResource(
context.Background(),
- m.com.Config(),
+ m.com.Store(),
item.MCPName,
item.URI,
)
@@ -3299,7 +3299,7 @@ func (m *UI) drawSessionDetails(scr uv.Screen, area uv.Rectangle) {
lspSection := m.lspInfo(sectionWidth, maxItemsPerSection, false)
mcpSection := m.mcpInfo(sectionWidth, maxItemsPerSection, false)
- filesSection := m.filesInfo(m.com.Config().WorkingDir(), sectionWidth, maxItemsPerSection, false)
+ filesSection := m.filesInfo(m.com.Store().WorkingDir(), sectionWidth, maxItemsPerSection, false)
sections := lipgloss.JoinHorizontal(lipgloss.Top, filesSection, " ", lspSection, " ", mcpSection)
uv.NewStyledString(
s.CompactDetails.View.
@@ -3317,7 +3317,7 @@ func (m *UI) drawSessionDetails(scr uv.Screen, area uv.Rectangle) {
func (m *UI) runMCPPrompt(clientID, promptID string, arguments map[string]string) tea.Cmd {
load := func() tea.Msg {
- prompt, err := commands.GetMCPPrompt(m.com.Config(), clientID, promptID, arguments)
+ prompt, err := commands.GetMCPPrompt(m.com.Store(), clientID, promptID, arguments)
if err != nil {
// TODO: make this better
return util.ReportError(err)()
@@ -3358,7 +3358,7 @@ func handleMCPPromptsEvent(name string) tea.Cmd {
}
}
-func handleMCPToolsEvent(cfg *config.Config, name string) tea.Cmd {
+func handleMCPToolsEvent(cfg *config.ConfigStore, name string) tea.Cmd {
return func() tea.Msg {
mcp.RefreshTools(
context.Background(),