feat(cmd): support overriding the data directory

Vincent Ambo created

In some cases it might not be desirable to write `.crush` folders
everywhere, or to reuse `.crush` folders between different filesystem
locations.

This change makes it possible to specify `-D` / `--data-dir` to set a
different directory to load/create the project-specific data in.

Change summary

internal/cmd/logs.go                 |  7 ++
internal/cmd/root.go                 |  7 ++
internal/config/init.go              |  4 
internal/config/load.go              | 10 +++-
internal/config/load_test.go         | 64 +++++++++++++++---------------
internal/llm/provider/openai_test.go |  2 
6 files changed, 54 insertions(+), 40 deletions(-)

Detailed changes

internal/cmd/logs.go 🔗

@@ -28,6 +28,11 @@ var logsCmd = &cobra.Command{
 			return fmt.Errorf("failed to get current working directory: %v", err)
 		}
 
+		dataDir, err := cmd.Flags().GetString("data-dir")
+		if err != nil {
+			return fmt.Errorf("failed to get data directory: %v", err)
+		}
+
 		follow, err := cmd.Flags().GetBool("follow")
 		if err != nil {
 			return fmt.Errorf("failed to get follow flag: %v", err)
@@ -41,7 +46,7 @@ var logsCmd = &cobra.Command{
 		log.SetLevel(log.DebugLevel)
 		log.SetOutput(os.Stdout)
 
-		cfg, err := config.Load(cwd, false)
+		cfg, err := config.Load(cwd, dataDir, false)
 		if err != nil {
 			return fmt.Errorf("failed to load configuration: %v", err)
 		}

internal/cmd/root.go 🔗

@@ -21,6 +21,7 @@ import (
 
 func init() {
 	rootCmd.PersistentFlags().StringP("cwd", "c", "", "Current working directory")
+	rootCmd.PersistentFlags().StringP("data-dir", "D", "", "Custom crush data directory")
 	rootCmd.PersistentFlags().BoolP("debug", "d", false, "Debug")
 
 	rootCmd.Flags().BoolP("help", "h", false, "Help")
@@ -45,6 +46,9 @@ crush -d
 # Run with debug logging in a specific directory
 crush -d -c /path/to/project
 
+# Run with custom data directory
+crush -D /path/to/custom/.crush
+
 # Print version
 crush -v
 
@@ -96,6 +100,7 @@ func Execute() {
 func setupApp(cmd *cobra.Command) (*app.App, error) {
 	debug, _ := cmd.Flags().GetBool("debug")
 	yolo, _ := cmd.Flags().GetBool("yolo")
+	dataDir, _ := cmd.Flags().GetString("data-dir")
 	ctx := cmd.Context()
 
 	cwd, err := ResolveCwd(cmd)
@@ -103,7 +108,7 @@ func setupApp(cmd *cobra.Command) (*app.App, error) {
 		return nil, err
 	}
 
-	cfg, err := config.Init(cwd, debug)
+	cfg, err := config.Init(cwd, dataDir, debug)
 	if err != nil {
 		return nil, err
 	}

internal/config/init.go 🔗

@@ -19,8 +19,8 @@ type ProjectInitFlag struct {
 // TODO: we need to remove the global config instance keeping it now just until everything is migrated
 var instance atomic.Pointer[Config]
 
-func Init(workingDir string, debug bool) (*Config, error) {
-	cfg, err := Load(workingDir, debug)
+func Init(workingDir, dataDir string, debug bool) (*Config, error) {
+	cfg, err := Load(workingDir, dataDir, debug)
 	if err != nil {
 		return nil, err
 	}

internal/config/load.go 🔗

@@ -37,7 +37,7 @@ func LoadReader(fd io.Reader) (*Config, error) {
 }
 
 // Load loads the configuration from the default paths.
-func Load(workingDir string, debug bool) (*Config, error) {
+func Load(workingDir, dataDir string, debug bool) (*Config, error) {
 	// uses default config paths
 	configPaths := []string{
 		globalConfig(),
@@ -52,7 +52,7 @@ func Load(workingDir string, debug bool) (*Config, error) {
 
 	cfg.dataConfigDir = GlobalConfigData()
 
-	cfg.setDefaults(workingDir)
+	cfg.setDefaults(workingDir, dataDir)
 
 	if debug {
 		cfg.Options.Debug = true
@@ -299,7 +299,7 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know
 	return nil
 }
 
-func (c *Config) setDefaults(workingDir string) {
+func (c *Config) setDefaults(workingDir, dataDir string) {
 	c.workingDir = workingDir
 	if c.Options == nil {
 		c.Options = &Options{}
@@ -317,6 +317,10 @@ func (c *Config) setDefaults(workingDir string) {
 			c.Options.DataDirectory = filepath.Join(workingDir, defaultDataDirectory)
 		}
 	}
+	// explicit dataDir flag always takes precedence
+	if dataDir != "" {
+		c.Options.DataDirectory = dataDir
+	}
 	if c.Providers == nil {
 		c.Providers = csync.NewMap[string, ProviderConfig]()
 	}

internal/config/load_test.go 🔗

@@ -39,7 +39,7 @@ func TestConfig_LoadFromReaders(t *testing.T) {
 func TestConfig_setDefaults(t *testing.T) {
 	cfg := &Config{}
 
-	cfg.setDefaults("/tmp")
+	cfg.setDefaults("/tmp", "")
 
 	require.NotNil(t, cfg.Options)
 	require.NotNil(t, cfg.Options.TUI)
@@ -68,7 +68,7 @@ func TestConfig_configureProviders(t *testing.T) {
 	}
 
 	cfg := &Config{}
-	cfg.setDefaults("/tmp")
+	cfg.setDefaults("/tmp", "")
 	env := env.NewFromMap(map[string]string{
 		"OPENAI_API_KEY": "test-key",
 	})
@@ -110,7 +110,7 @@ func TestConfig_configureProvidersWithOverride(t *testing.T) {
 			},
 		},
 	})
-	cfg.setDefaults("/tmp")
+	cfg.setDefaults("/tmp", "")
 
 	env := env.NewFromMap(map[string]string{
 		"OPENAI_API_KEY": "test-key",
@@ -153,7 +153,7 @@ func TestConfig_configureProvidersWithNewProvider(t *testing.T) {
 			},
 		}),
 	}
-	cfg.setDefaults("/tmp")
+	cfg.setDefaults("/tmp", "")
 	env := env.NewFromMap(map[string]string{
 		"OPENAI_API_KEY": "test-key",
 	})
@@ -188,7 +188,7 @@ func TestConfig_configureProvidersBedrockWithCredentials(t *testing.T) {
 	}
 
 	cfg := &Config{}
-	cfg.setDefaults("/tmp")
+	cfg.setDefaults("/tmp", "")
 	env := env.NewFromMap(map[string]string{
 		"AWS_ACCESS_KEY_ID":     "test-key-id",
 		"AWS_SECRET_ACCESS_KEY": "test-secret-key",
@@ -217,7 +217,7 @@ func TestConfig_configureProvidersBedrockWithoutCredentials(t *testing.T) {
 	}
 
 	cfg := &Config{}
-	cfg.setDefaults("/tmp")
+	cfg.setDefaults("/tmp", "")
 	env := env.NewFromMap(map[string]string{})
 	resolver := NewEnvironmentVariableResolver(env)
 	err := cfg.configureProviders(env, resolver, knownProviders)
@@ -239,7 +239,7 @@ func TestConfig_configureProvidersBedrockWithoutUnsupportedModel(t *testing.T) {
 	}
 
 	cfg := &Config{}
-	cfg.setDefaults("/tmp")
+	cfg.setDefaults("/tmp", "")
 	env := env.NewFromMap(map[string]string{
 		"AWS_ACCESS_KEY_ID":     "test-key-id",
 		"AWS_SECRET_ACCESS_KEY": "test-secret-key",
@@ -262,7 +262,7 @@ func TestConfig_configureProvidersVertexAIWithCredentials(t *testing.T) {
 	}
 
 	cfg := &Config{}
-	cfg.setDefaults("/tmp")
+	cfg.setDefaults("/tmp", "")
 	env := env.NewFromMap(map[string]string{
 		"VERTEXAI_PROJECT":  "test-project",
 		"VERTEXAI_LOCATION": "us-central1",
@@ -293,7 +293,7 @@ func TestConfig_configureProvidersVertexAIWithoutCredentials(t *testing.T) {
 	}
 
 	cfg := &Config{}
-	cfg.setDefaults("/tmp")
+	cfg.setDefaults("/tmp", "")
 	env := env.NewFromMap(map[string]string{
 		"GOOGLE_GENAI_USE_VERTEXAI": "false",
 		"GOOGLE_CLOUD_PROJECT":      "test-project",
@@ -319,7 +319,7 @@ func TestConfig_configureProvidersVertexAIMissingProject(t *testing.T) {
 	}
 
 	cfg := &Config{}
-	cfg.setDefaults("/tmp")
+	cfg.setDefaults("/tmp", "")
 	env := env.NewFromMap(map[string]string{
 		"GOOGLE_GENAI_USE_VERTEXAI": "true",
 		"GOOGLE_CLOUD_LOCATION":     "us-central1",
@@ -344,7 +344,7 @@ func TestConfig_configureProvidersSetProviderID(t *testing.T) {
 	}
 
 	cfg := &Config{}
-	cfg.setDefaults("/tmp")
+	cfg.setDefaults("/tmp", "")
 	env := env.NewFromMap(map[string]string{
 		"OPENAI_API_KEY": "test-key",
 	})
@@ -472,7 +472,7 @@ func TestConfig_configureProvidersWithDisabledProvider(t *testing.T) {
 			},
 		}),
 	}
-	cfg.setDefaults("/tmp")
+	cfg.setDefaults("/tmp", "")
 
 	env := env.NewFromMap(map[string]string{
 		"OPENAI_API_KEY": "test-key",
@@ -502,7 +502,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
 				},
 			}),
 		}
-		cfg.setDefaults("/tmp")
+		cfg.setDefaults("/tmp", "")
 
 		env := env.NewFromMap(map[string]string{})
 		resolver := NewEnvironmentVariableResolver(env)
@@ -525,7 +525,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
 				},
 			}),
 		}
-		cfg.setDefaults("/tmp")
+		cfg.setDefaults("/tmp", "")
 
 		env := env.NewFromMap(map[string]string{})
 		resolver := NewEnvironmentVariableResolver(env)
@@ -547,7 +547,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
 				},
 			}),
 		}
-		cfg.setDefaults("/tmp")
+		cfg.setDefaults("/tmp", "")
 
 		env := env.NewFromMap(map[string]string{})
 		resolver := NewEnvironmentVariableResolver(env)
@@ -572,7 +572,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
 				},
 			}),
 		}
-		cfg.setDefaults("/tmp")
+		cfg.setDefaults("/tmp", "")
 
 		env := env.NewFromMap(map[string]string{})
 		resolver := NewEnvironmentVariableResolver(env)
@@ -597,7 +597,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
 				},
 			}),
 		}
-		cfg.setDefaults("/tmp")
+		cfg.setDefaults("/tmp", "")
 
 		env := env.NewFromMap(map[string]string{})
 		resolver := NewEnvironmentVariableResolver(env)
@@ -625,7 +625,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
 				},
 			}),
 		}
-		cfg.setDefaults("/tmp")
+		cfg.setDefaults("/tmp", "")
 
 		env := env.NewFromMap(map[string]string{})
 		resolver := NewEnvironmentVariableResolver(env)
@@ -655,7 +655,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
 				},
 			}),
 		}
-		cfg.setDefaults("/tmp")
+		cfg.setDefaults("/tmp", "")
 
 		env := env.NewFromMap(map[string]string{})
 		resolver := NewEnvironmentVariableResolver(env)
@@ -688,7 +688,7 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) {
 				},
 			}),
 		}
-		cfg.setDefaults("/tmp")
+		cfg.setDefaults("/tmp", "")
 
 		env := env.NewFromMap(map[string]string{
 			"GOOGLE_GENAI_USE_VERTEXAI": "false",
@@ -721,7 +721,7 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) {
 				},
 			}),
 		}
-		cfg.setDefaults("/tmp")
+		cfg.setDefaults("/tmp", "")
 
 		env := env.NewFromMap(map[string]string{})
 		resolver := NewEnvironmentVariableResolver(env)
@@ -752,7 +752,7 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) {
 				},
 			}),
 		}
-		cfg.setDefaults("/tmp")
+		cfg.setDefaults("/tmp", "")
 
 		env := env.NewFromMap(map[string]string{})
 		resolver := NewEnvironmentVariableResolver(env)
@@ -783,7 +783,7 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) {
 				},
 			}),
 		}
-		cfg.setDefaults("/tmp")
+		cfg.setDefaults("/tmp", "")
 
 		env := env.NewFromMap(map[string]string{
 			"OPENAI_API_KEY": "test-key",
@@ -820,7 +820,7 @@ func TestConfig_defaultModelSelection(t *testing.T) {
 		}
 
 		cfg := &Config{}
-		cfg.setDefaults("/tmp")
+		cfg.setDefaults("/tmp", "")
 		env := env.NewFromMap(map[string]string{})
 		resolver := NewEnvironmentVariableResolver(env)
 		err := cfg.configureProviders(env, resolver, knownProviders)
@@ -856,7 +856,7 @@ func TestConfig_defaultModelSelection(t *testing.T) {
 		}
 
 		cfg := &Config{}
-		cfg.setDefaults("/tmp")
+		cfg.setDefaults("/tmp", "")
 		env := env.NewFromMap(map[string]string{})
 		resolver := NewEnvironmentVariableResolver(env)
 		err := cfg.configureProviders(env, resolver, knownProviders)
@@ -886,7 +886,7 @@ func TestConfig_defaultModelSelection(t *testing.T) {
 		}
 
 		cfg := &Config{}
-		cfg.setDefaults("/tmp")
+		cfg.setDefaults("/tmp", "")
 		env := env.NewFromMap(map[string]string{})
 		resolver := NewEnvironmentVariableResolver(env)
 		err := cfg.configureProviders(env, resolver, knownProviders)
@@ -929,7 +929,7 @@ func TestConfig_defaultModelSelection(t *testing.T) {
 				},
 			}),
 		}
-		cfg.setDefaults("/tmp")
+		cfg.setDefaults("/tmp", "")
 		env := env.NewFromMap(map[string]string{})
 		resolver := NewEnvironmentVariableResolver(env)
 		err := cfg.configureProviders(env, resolver, knownProviders)
@@ -973,7 +973,7 @@ func TestConfig_defaultModelSelection(t *testing.T) {
 				},
 			}),
 		}
-		cfg.setDefaults("/tmp")
+		cfg.setDefaults("/tmp", "")
 		env := env.NewFromMap(map[string]string{})
 		resolver := NewEnvironmentVariableResolver(env)
 		err := cfg.configureProviders(env, resolver, knownProviders)
@@ -1015,7 +1015,7 @@ func TestConfig_defaultModelSelection(t *testing.T) {
 				},
 			}),
 		}
-		cfg.setDefaults("/tmp")
+		cfg.setDefaults("/tmp", "")
 		env := env.NewFromMap(map[string]string{})
 		resolver := NewEnvironmentVariableResolver(env)
 		err := cfg.configureProviders(env, resolver, knownProviders)
@@ -1063,7 +1063,7 @@ func TestConfig_configureSelectedModels(t *testing.T) {
 				},
 			},
 		}
-		cfg.setDefaults("/tmp")
+		cfg.setDefaults("/tmp", "")
 		env := env.NewFromMap(map[string]string{})
 		resolver := NewEnvironmentVariableResolver(env)
 		err := cfg.configureProviders(env, resolver, knownProviders)
@@ -1125,7 +1125,7 @@ func TestConfig_configureSelectedModels(t *testing.T) {
 				},
 			},
 		}
-		cfg.setDefaults("/tmp")
+		cfg.setDefaults("/tmp", "")
 		env := env.NewFromMap(map[string]string{})
 		resolver := NewEnvironmentVariableResolver(env)
 		err := cfg.configureProviders(env, resolver, knownProviders)
@@ -1170,7 +1170,7 @@ func TestConfig_configureSelectedModels(t *testing.T) {
 				},
 			},
 		}
-		cfg.setDefaults("/tmp")
+		cfg.setDefaults("/tmp", "")
 		env := env.NewFromMap(map[string]string{})
 		resolver := NewEnvironmentVariableResolver(env)
 		err := cfg.configureProviders(env, resolver, knownProviders)

internal/llm/provider/openai_test.go 🔗

@@ -17,7 +17,7 @@ import (
 )
 
 func TestMain(m *testing.M) {
-	_, err := config.Init(".", true)
+	_, err := config.Init(".", "", true)
 	if err != nil {
 		panic("Failed to initialize config: " + err.Error())
 	}