diff --git a/internal/cmd/logs.go b/internal/cmd/logs.go index c30b2b177506d1e5db90820e8e651435611f9093..85921c4e4354194d0d260e814fc61222c114d3ef 100644 --- a/internal/cmd/logs.go +++ b/internal/cmd/logs.go @@ -28,6 +28,11 @@ var logsCmd = &cobra.Command{ return fmt.Errorf("failed to get current working directory: %v", err) } + dataDir, err := cmd.Flags().GetString("data-dir") + if err != nil { + return fmt.Errorf("failed to get data directory: %v", err) + } + follow, err := cmd.Flags().GetBool("follow") if err != nil { return fmt.Errorf("failed to get follow flag: %v", err) @@ -41,7 +46,7 @@ var logsCmd = &cobra.Command{ log.SetLevel(log.DebugLevel) log.SetOutput(os.Stdout) - cfg, err := config.Load(cwd, false) + cfg, err := config.Load(cwd, dataDir, false) if err != nil { return fmt.Errorf("failed to load configuration: %v", err) } diff --git a/internal/cmd/root.go b/internal/cmd/root.go index bb0294658e4502838594242a40344edaeeed716c..ee167814a1688ae45238d92f0cae78a7e86c0ccd 100644 --- a/internal/cmd/root.go +++ b/internal/cmd/root.go @@ -21,6 +21,7 @@ import ( func init() { rootCmd.PersistentFlags().StringP("cwd", "c", "", "Current working directory") + rootCmd.PersistentFlags().StringP("data-dir", "D", "", "Custom crush data directory") rootCmd.PersistentFlags().BoolP("debug", "d", false, "Debug") rootCmd.Flags().BoolP("help", "h", false, "Help") @@ -45,6 +46,9 @@ crush -d # Run with debug logging in a specific directory crush -d -c /path/to/project +# Run with custom data directory +crush -D /path/to/custom/.crush + # Print version crush -v @@ -96,6 +100,7 @@ func Execute() { func setupApp(cmd *cobra.Command) (*app.App, error) { debug, _ := cmd.Flags().GetBool("debug") yolo, _ := cmd.Flags().GetBool("yolo") + dataDir, _ := cmd.Flags().GetString("data-dir") ctx := cmd.Context() cwd, err := ResolveCwd(cmd) @@ -103,7 +108,7 @@ func setupApp(cmd *cobra.Command) (*app.App, error) { return nil, err } - cfg, err := config.Init(cwd, debug) + cfg, err := config.Init(cwd, dataDir, debug) if err != nil { return nil, err } diff --git a/internal/config/init.go b/internal/config/init.go index ff44d43bb878f579d003c84537fcd970f9e52f9e..f97272cefa779319a752927456c34fbcff97e3b6 100644 --- a/internal/config/init.go +++ b/internal/config/init.go @@ -19,8 +19,8 @@ type ProjectInitFlag struct { // TODO: we need to remove the global config instance keeping it now just until everything is migrated var instance atomic.Pointer[Config] -func Init(workingDir string, debug bool) (*Config, error) { - cfg, err := Load(workingDir, debug) +func Init(workingDir, dataDir string, debug bool) (*Config, error) { + cfg, err := Load(workingDir, dataDir, debug) if err != nil { return nil, err } diff --git a/internal/config/load.go b/internal/config/load.go index 9810f64cdc4a8672aca95ac4aa6fbffc0881e72c..4858b423903f9898a3d0b74e610b8689d41fc84c 100644 --- a/internal/config/load.go +++ b/internal/config/load.go @@ -37,7 +37,7 @@ func LoadReader(fd io.Reader) (*Config, error) { } // Load loads the configuration from the default paths. -func Load(workingDir string, debug bool) (*Config, error) { +func Load(workingDir, dataDir string, debug bool) (*Config, error) { // uses default config paths configPaths := []string{ globalConfig(), @@ -52,7 +52,7 @@ func Load(workingDir string, debug bool) (*Config, error) { cfg.dataConfigDir = GlobalConfigData() - cfg.setDefaults(workingDir) + cfg.setDefaults(workingDir, dataDir) if debug { cfg.Options.Debug = true @@ -299,7 +299,7 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know return nil } -func (c *Config) setDefaults(workingDir string) { +func (c *Config) setDefaults(workingDir, dataDir string) { c.workingDir = workingDir if c.Options == nil { c.Options = &Options{} @@ -317,6 +317,10 @@ func (c *Config) setDefaults(workingDir string) { c.Options.DataDirectory = filepath.Join(workingDir, defaultDataDirectory) } } + // explicit dataDir flag always takes precedence + if dataDir != "" { + c.Options.DataDirectory = dataDir + } if c.Providers == nil { c.Providers = csync.NewMap[string, ProviderConfig]() } diff --git a/internal/config/load_test.go b/internal/config/load_test.go index 7d9a3038e1c56ec0ab1672ba9a53e4752e418804..a83ab2b94fa29ade149b968c700f22b34b4e86fd 100644 --- a/internal/config/load_test.go +++ b/internal/config/load_test.go @@ -39,7 +39,7 @@ func TestConfig_LoadFromReaders(t *testing.T) { func TestConfig_setDefaults(t *testing.T) { cfg := &Config{} - cfg.setDefaults("/tmp") + cfg.setDefaults("/tmp", "") require.NotNil(t, cfg.Options) require.NotNil(t, cfg.Options.TUI) @@ -68,7 +68,7 @@ func TestConfig_configureProviders(t *testing.T) { } cfg := &Config{} - cfg.setDefaults("/tmp") + cfg.setDefaults("/tmp", "") env := env.NewFromMap(map[string]string{ "OPENAI_API_KEY": "test-key", }) @@ -110,7 +110,7 @@ func TestConfig_configureProvidersWithOverride(t *testing.T) { }, }, }) - cfg.setDefaults("/tmp") + cfg.setDefaults("/tmp", "") env := env.NewFromMap(map[string]string{ "OPENAI_API_KEY": "test-key", @@ -153,7 +153,7 @@ func TestConfig_configureProvidersWithNewProvider(t *testing.T) { }, }), } - cfg.setDefaults("/tmp") + cfg.setDefaults("/tmp", "") env := env.NewFromMap(map[string]string{ "OPENAI_API_KEY": "test-key", }) @@ -188,7 +188,7 @@ func TestConfig_configureProvidersBedrockWithCredentials(t *testing.T) { } cfg := &Config{} - cfg.setDefaults("/tmp") + cfg.setDefaults("/tmp", "") env := env.NewFromMap(map[string]string{ "AWS_ACCESS_KEY_ID": "test-key-id", "AWS_SECRET_ACCESS_KEY": "test-secret-key", @@ -217,7 +217,7 @@ func TestConfig_configureProvidersBedrockWithoutCredentials(t *testing.T) { } cfg := &Config{} - cfg.setDefaults("/tmp") + cfg.setDefaults("/tmp", "") env := env.NewFromMap(map[string]string{}) resolver := NewEnvironmentVariableResolver(env) err := cfg.configureProviders(env, resolver, knownProviders) @@ -239,7 +239,7 @@ func TestConfig_configureProvidersBedrockWithoutUnsupportedModel(t *testing.T) { } cfg := &Config{} - cfg.setDefaults("/tmp") + cfg.setDefaults("/tmp", "") env := env.NewFromMap(map[string]string{ "AWS_ACCESS_KEY_ID": "test-key-id", "AWS_SECRET_ACCESS_KEY": "test-secret-key", @@ -262,7 +262,7 @@ func TestConfig_configureProvidersVertexAIWithCredentials(t *testing.T) { } cfg := &Config{} - cfg.setDefaults("/tmp") + cfg.setDefaults("/tmp", "") env := env.NewFromMap(map[string]string{ "VERTEXAI_PROJECT": "test-project", "VERTEXAI_LOCATION": "us-central1", @@ -293,7 +293,7 @@ func TestConfig_configureProvidersVertexAIWithoutCredentials(t *testing.T) { } cfg := &Config{} - cfg.setDefaults("/tmp") + cfg.setDefaults("/tmp", "") env := env.NewFromMap(map[string]string{ "GOOGLE_GENAI_USE_VERTEXAI": "false", "GOOGLE_CLOUD_PROJECT": "test-project", @@ -319,7 +319,7 @@ func TestConfig_configureProvidersVertexAIMissingProject(t *testing.T) { } cfg := &Config{} - cfg.setDefaults("/tmp") + cfg.setDefaults("/tmp", "") env := env.NewFromMap(map[string]string{ "GOOGLE_GENAI_USE_VERTEXAI": "true", "GOOGLE_CLOUD_LOCATION": "us-central1", @@ -344,7 +344,7 @@ func TestConfig_configureProvidersSetProviderID(t *testing.T) { } cfg := &Config{} - cfg.setDefaults("/tmp") + cfg.setDefaults("/tmp", "") env := env.NewFromMap(map[string]string{ "OPENAI_API_KEY": "test-key", }) @@ -472,7 +472,7 @@ func TestConfig_configureProvidersWithDisabledProvider(t *testing.T) { }, }), } - cfg.setDefaults("/tmp") + cfg.setDefaults("/tmp", "") env := env.NewFromMap(map[string]string{ "OPENAI_API_KEY": "test-key", @@ -502,7 +502,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { }, }), } - cfg.setDefaults("/tmp") + cfg.setDefaults("/tmp", "") env := env.NewFromMap(map[string]string{}) resolver := NewEnvironmentVariableResolver(env) @@ -525,7 +525,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { }, }), } - cfg.setDefaults("/tmp") + cfg.setDefaults("/tmp", "") env := env.NewFromMap(map[string]string{}) resolver := NewEnvironmentVariableResolver(env) @@ -547,7 +547,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { }, }), } - cfg.setDefaults("/tmp") + cfg.setDefaults("/tmp", "") env := env.NewFromMap(map[string]string{}) resolver := NewEnvironmentVariableResolver(env) @@ -572,7 +572,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { }, }), } - cfg.setDefaults("/tmp") + cfg.setDefaults("/tmp", "") env := env.NewFromMap(map[string]string{}) resolver := NewEnvironmentVariableResolver(env) @@ -597,7 +597,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { }, }), } - cfg.setDefaults("/tmp") + cfg.setDefaults("/tmp", "") env := env.NewFromMap(map[string]string{}) resolver := NewEnvironmentVariableResolver(env) @@ -625,7 +625,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { }, }), } - cfg.setDefaults("/tmp") + cfg.setDefaults("/tmp", "") env := env.NewFromMap(map[string]string{}) resolver := NewEnvironmentVariableResolver(env) @@ -655,7 +655,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { }, }), } - cfg.setDefaults("/tmp") + cfg.setDefaults("/tmp", "") env := env.NewFromMap(map[string]string{}) resolver := NewEnvironmentVariableResolver(env) @@ -688,7 +688,7 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) { }, }), } - cfg.setDefaults("/tmp") + cfg.setDefaults("/tmp", "") env := env.NewFromMap(map[string]string{ "GOOGLE_GENAI_USE_VERTEXAI": "false", @@ -721,7 +721,7 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) { }, }), } - cfg.setDefaults("/tmp") + cfg.setDefaults("/tmp", "") env := env.NewFromMap(map[string]string{}) resolver := NewEnvironmentVariableResolver(env) @@ -752,7 +752,7 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) { }, }), } - cfg.setDefaults("/tmp") + cfg.setDefaults("/tmp", "") env := env.NewFromMap(map[string]string{}) resolver := NewEnvironmentVariableResolver(env) @@ -783,7 +783,7 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) { }, }), } - cfg.setDefaults("/tmp") + cfg.setDefaults("/tmp", "") env := env.NewFromMap(map[string]string{ "OPENAI_API_KEY": "test-key", @@ -820,7 +820,7 @@ func TestConfig_defaultModelSelection(t *testing.T) { } cfg := &Config{} - cfg.setDefaults("/tmp") + cfg.setDefaults("/tmp", "") env := env.NewFromMap(map[string]string{}) resolver := NewEnvironmentVariableResolver(env) err := cfg.configureProviders(env, resolver, knownProviders) @@ -856,7 +856,7 @@ func TestConfig_defaultModelSelection(t *testing.T) { } cfg := &Config{} - cfg.setDefaults("/tmp") + cfg.setDefaults("/tmp", "") env := env.NewFromMap(map[string]string{}) resolver := NewEnvironmentVariableResolver(env) err := cfg.configureProviders(env, resolver, knownProviders) @@ -886,7 +886,7 @@ func TestConfig_defaultModelSelection(t *testing.T) { } cfg := &Config{} - cfg.setDefaults("/tmp") + cfg.setDefaults("/tmp", "") env := env.NewFromMap(map[string]string{}) resolver := NewEnvironmentVariableResolver(env) err := cfg.configureProviders(env, resolver, knownProviders) @@ -929,7 +929,7 @@ func TestConfig_defaultModelSelection(t *testing.T) { }, }), } - cfg.setDefaults("/tmp") + cfg.setDefaults("/tmp", "") env := env.NewFromMap(map[string]string{}) resolver := NewEnvironmentVariableResolver(env) err := cfg.configureProviders(env, resolver, knownProviders) @@ -973,7 +973,7 @@ func TestConfig_defaultModelSelection(t *testing.T) { }, }), } - cfg.setDefaults("/tmp") + cfg.setDefaults("/tmp", "") env := env.NewFromMap(map[string]string{}) resolver := NewEnvironmentVariableResolver(env) err := cfg.configureProviders(env, resolver, knownProviders) @@ -1015,7 +1015,7 @@ func TestConfig_defaultModelSelection(t *testing.T) { }, }), } - cfg.setDefaults("/tmp") + cfg.setDefaults("/tmp", "") env := env.NewFromMap(map[string]string{}) resolver := NewEnvironmentVariableResolver(env) err := cfg.configureProviders(env, resolver, knownProviders) @@ -1063,7 +1063,7 @@ func TestConfig_configureSelectedModels(t *testing.T) { }, }, } - cfg.setDefaults("/tmp") + cfg.setDefaults("/tmp", "") env := env.NewFromMap(map[string]string{}) resolver := NewEnvironmentVariableResolver(env) err := cfg.configureProviders(env, resolver, knownProviders) @@ -1125,7 +1125,7 @@ func TestConfig_configureSelectedModels(t *testing.T) { }, }, } - cfg.setDefaults("/tmp") + cfg.setDefaults("/tmp", "") env := env.NewFromMap(map[string]string{}) resolver := NewEnvironmentVariableResolver(env) err := cfg.configureProviders(env, resolver, knownProviders) @@ -1170,7 +1170,7 @@ func TestConfig_configureSelectedModels(t *testing.T) { }, }, } - cfg.setDefaults("/tmp") + cfg.setDefaults("/tmp", "") env := env.NewFromMap(map[string]string{}) resolver := NewEnvironmentVariableResolver(env) err := cfg.configureProviders(env, resolver, knownProviders) diff --git a/internal/llm/provider/openai_test.go b/internal/llm/provider/openai_test.go index ef79803c8a8aa1ee3fe6cb7de8bc8fa86f26c03c..db2edbb7e9829af0c07ada532ee1d0cefb51463b 100644 --- a/internal/llm/provider/openai_test.go +++ b/internal/llm/provider/openai_test.go @@ -17,7 +17,7 @@ import ( ) func TestMain(m *testing.M) { - _, err := config.Init(".", true) + _, err := config.Init(".", "", true) if err != nil { panic("Failed to initialize config: " + err.Error()) }