From e74e9dfc43d8d1ef68fa2543d75961cbb179e9c3 Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Fri, 4 Jul 2025 23:25:34 +0200 Subject: [PATCH] chore: change config signature - make some functions pointer methods - add env tests --- cmd/logs.go | 3 +- pkg/config/config.go | 59 +++++++++++++++-- pkg/config/load.go | 25 +++++-- pkg/config/load_test.go | 39 +++++------ pkg/env/env_test.go | 142 ++++++++++++++++++++++++++++++++++++++++ pkg/log/log.go | 9 +-- 6 files changed, 239 insertions(+), 38 deletions(-) create mode 100644 pkg/env/env_test.go diff --git a/cmd/logs.go b/cmd/logs.go index cac9cdbc4749911dd6db52e29776be346242add5..4e69e14b9790dc5985bc8f5e89c33066ddce628a 100644 --- a/cmd/logs.go +++ b/cmd/logs.go @@ -8,7 +8,6 @@ import ( "time" "github.com/charmbracelet/crush/pkg/config" - "github.com/charmbracelet/crush/pkg/env" "github.com/charmbracelet/log/v2" "github.com/nxadm/tail" "github.com/spf13/cobra" @@ -27,7 +26,7 @@ var logsCmd = &cobra.Command{ if err != nil { return fmt.Errorf("failed to get current working directory: %v", err) } - cfg, err := config.Load(cwd, env.New()) + cfg, err := config.Load(cwd, false) if err != nil { return fmt.Errorf("failed to load configuration: %v", err) } diff --git a/pkg/config/config.go b/pkg/config/config.go index 70f569b4e998230c5770122c164ea5232d67d919..476c356e95868ce1ac7679c309abd939050b8b6a 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -1,6 +1,11 @@ package config -import "github.com/charmbracelet/crush/internal/fur/provider" +import ( + "slices" + "strings" + + "github.com/charmbracelet/crush/internal/fur/provider" +) const ( appName = "crush" @@ -70,7 +75,7 @@ const ( MCPHttp MCPType = "http" ) -type MCP struct { +type MCPConfig struct { Command string `json:"command,omitempty" ` Env []string `json:"env,omitempty"` Args []string `json:"args,omitempty"` @@ -103,20 +108,64 @@ type Options struct { DataDirectory string `json:"data_directory,omitempty"` } +type MCPs map[string]MCPConfig + +type MCP struct { + Name string `json:"name"` + MCP MCPConfig `json:"mcp"` +} + +func (m MCPs) Sorted() []MCP { + sorted := make([]MCP, 0, len(m)) + for k, v := range m { + sorted = append(sorted, MCP{ + Name: k, + MCP: v, + }) + } + slices.SortFunc(sorted, func(a, b MCP) int { + return strings.Compare(a.Name, b.Name) + }) + return sorted +} + +type LSPs map[string]LSPConfig + +type LSP struct { + Name string `json:"name"` + LSP LSPConfig `json:"lsp"` +} + +func (l LSPs) Sorted() []LSP { + sorted := make([]LSP, 0, len(l)) + for k, v := range l { + sorted = append(sorted, LSP{ + Name: k, + LSP: v, + }) + } + slices.SortFunc(sorted, func(a, b LSP) int { + return strings.Compare(a.Name, b.Name) + }) + return sorted +} + // Config holds the configuration for crush. type Config struct { - workingDir string `json:"-"` // We currently only support large/small as values here. Models map[string]SelectedModel `json:"models,omitempty"` // The providers that are configured Providers map[string]ProviderConfig `json:"providers,omitempty"` - MCP map[string]MCP `json:"mcp,omitempty"` + MCP MCPs `json:"mcp,omitempty"` - LSP map[string]LSPConfig `json:"lsp,omitempty"` + LSP LSPs `json:"lsp,omitempty"` Options *Options `json:"options,omitempty"` + + // Internal + workingDir string `json:"-"` } func (c *Config) WorkingDir() string { diff --git a/pkg/config/load.go b/pkg/config/load.go index 19f84625891616a4e60cbdd47566af6bb9caa6a4..7b996aa9c99867ade245a92fe4b783f34148fe66 100644 --- a/pkg/config/load.go +++ b/pkg/config/load.go @@ -13,6 +13,7 @@ import ( "github.com/charmbracelet/crush/internal/fur/client" "github.com/charmbracelet/crush/internal/fur/provider" "github.com/charmbracelet/crush/pkg/env" + "github.com/charmbracelet/crush/pkg/log" ) // LoadReader config via io.Reader. @@ -31,7 +32,7 @@ func LoadReader(fd io.Reader) (*Config, error) { } // Load loads the configuration from the default paths. -func Load(workingDir string, env env.Env) (*Config, error) { +func Load(workingDir string, debug bool) (*Config, error) { // uses default config paths configPaths := []string{ globalConfig(), @@ -40,6 +41,17 @@ func Load(workingDir string, env env.Env) (*Config, error) { filepath.Join(workingDir, fmt.Sprintf(".%s.json", appName)), } cfg, err := loadFromConfigPaths(configPaths) + + if debug { + cfg.Options.Debug = true + } + + // Init logs + log.Init( + filepath.Join(cfg.Options.DataDirectory, "logs", fmt.Sprintf("%s.log", appName)), + cfg.Options.Debug, + ) + if err != nil { return nil, fmt.Errorf("failed to load config: %w", err) } @@ -47,7 +59,7 @@ func Load(workingDir string, env env.Env) (*Config, error) { // e.x validate the models // e.x validate provider config - setDefaults(workingDir, cfg) + cfg.setDefaults(workingDir) // Load known providers, this loads the config from fur providers, err := LoadProviders(client.New()) @@ -55,16 +67,17 @@ func Load(workingDir string, env env.Env) (*Config, error) { return nil, fmt.Errorf("failed to load providers: %w", err) } + env := env.New() // Configure providers valueResolver := NewShellVariableResolver(env) - if err := configureProviders(cfg, env, valueResolver, providers); err != nil { + if err := cfg.configureProviders(env, valueResolver, providers); err != nil { return nil, fmt.Errorf("failed to configure providers: %w", err) } return cfg, nil } -func configureProviders(cfg *Config, env env.Env, resolver VariableResolver, knownProviders []provider.Provider) error { +func (cfg *Config) configureProviders(env env.Env, resolver VariableResolver, knownProviders []provider.Provider) error { for _, p := range knownProviders { config, ok := cfg.Providers[string(p.ID)] @@ -169,7 +182,7 @@ func hasAWSCredentials(env env.Env) bool { return false } -func setDefaults(workingDir string, cfg *Config) { +func (cfg *Config) setDefaults(workingDir string) { cfg.workingDir = workingDir if cfg.Options == nil { cfg.Options = &Options{} @@ -190,7 +203,7 @@ func setDefaults(workingDir string, cfg *Config) { cfg.Models = make(map[string]SelectedModel) } if cfg.MCP == nil { - cfg.MCP = make(map[string]MCP) + cfg.MCP = make(map[string]MCPConfig) } if cfg.LSP == nil { cfg.LSP = make(map[string]LSPConfig) diff --git a/pkg/config/load_test.go b/pkg/config/load_test.go index 149258c0ff3ffeafb9db4744cbc6ee70afec33bc..e2dd943d58b24f55b433bbd7050d95c9bda18277 100644 --- a/pkg/config/load_test.go +++ b/pkg/config/load_test.go @@ -27,7 +27,7 @@ func TestConfig_LoadFromReaders(t *testing.T) { func TestConfig_setDefaults(t *testing.T) { cfg := &Config{} - setDefaults("/tmp", cfg) + cfg.setDefaults("/tmp") assert.NotNil(t, cfg.Options) assert.NotNil(t, cfg.Options.TUI) @@ -56,12 +56,12 @@ func TestConfig_configureProviders(t *testing.T) { } cfg := &Config{} - setDefaults("/tmp", cfg) + cfg.setDefaults("/tmp") env := env.NewFromMap(map[string]string{ "OPENAI_API_KEY": "test-key", }) resolver := NewEnvironmentVariableResolver(env) - err := configureProviders(cfg, env, resolver, knownProviders) + err := cfg.configureProviders(env, resolver, knownProviders) assert.NoError(t, err) assert.Len(t, cfg.Providers, 1) @@ -98,12 +98,13 @@ func TestConfig_configureProvidersWithOverride(t *testing.T) { }, }, } - setDefaults("/tmp", cfg) + cfg.setDefaults("/tmp") + env := env.NewFromMap(map[string]string{ "OPENAI_API_KEY": "test-key", }) resolver := NewEnvironmentVariableResolver(env) - err := configureProviders(cfg, env, resolver, knownProviders) + err := cfg.configureProviders(env, resolver, knownProviders) assert.NoError(t, err) assert.Len(t, cfg.Providers, 1) @@ -139,12 +140,12 @@ func TestConfig_configureProvidersWithNewProvider(t *testing.T) { }, }, } - setDefaults("/tmp", cfg) + cfg.setDefaults("/tmp") env := env.NewFromMap(map[string]string{ "OPENAI_API_KEY": "test-key", }) resolver := NewEnvironmentVariableResolver(env) - err := configureProviders(cfg, env, resolver, knownProviders) + err := cfg.configureProviders(env, resolver, knownProviders) assert.NoError(t, err) // Should be to because of the env variable assert.Len(t, cfg.Providers, 2) @@ -171,13 +172,13 @@ func TestConfig_configureProvidersBedrockWithCredentials(t *testing.T) { } cfg := &Config{} - setDefaults("/tmp", cfg) + cfg.setDefaults("/tmp") env := env.NewFromMap(map[string]string{ "AWS_ACCESS_KEY_ID": "test-key-id", "AWS_SECRET_ACCESS_KEY": "test-secret-key", }) resolver := NewEnvironmentVariableResolver(env) - err := configureProviders(cfg, env, resolver, knownProviders) + err := cfg.configureProviders(env, resolver, knownProviders) assert.NoError(t, err) assert.Len(t, cfg.Providers, 1) @@ -200,10 +201,10 @@ func TestConfig_configureProvidersBedrockWithoutCredentials(t *testing.T) { } cfg := &Config{} - setDefaults("/tmp", cfg) + cfg.setDefaults("/tmp") env := env.NewFromMap(map[string]string{}) resolver := NewEnvironmentVariableResolver(env) - err := configureProviders(cfg, env, resolver, knownProviders) + err := cfg.configureProviders(env, resolver, knownProviders) assert.NoError(t, err) // Provider should not be configured without credentials assert.Len(t, cfg.Providers, 0) @@ -222,13 +223,13 @@ func TestConfig_configureProvidersBedrockWithoutUnsupportedModel(t *testing.T) { } cfg := &Config{} - setDefaults("/tmp", cfg) + cfg.setDefaults("/tmp") env := env.NewFromMap(map[string]string{ "AWS_ACCESS_KEY_ID": "test-key-id", "AWS_SECRET_ACCESS_KEY": "test-secret-key", }) resolver := NewEnvironmentVariableResolver(env) - err := configureProviders(cfg, env, resolver, knownProviders) + err := cfg.configureProviders(env, resolver, knownProviders) assert.Error(t, err) } @@ -245,14 +246,14 @@ func TestConfig_configureProvidersVertexAIWithCredentials(t *testing.T) { } cfg := &Config{} - setDefaults("/tmp", cfg) + cfg.setDefaults("/tmp") env := env.NewFromMap(map[string]string{ "GOOGLE_GENAI_USE_VERTEXAI": "true", "GOOGLE_CLOUD_PROJECT": "test-project", "GOOGLE_CLOUD_LOCATION": "us-central1", }) resolver := NewEnvironmentVariableResolver(env) - err := configureProviders(cfg, env, resolver, knownProviders) + err := cfg.configureProviders(env, resolver, knownProviders) assert.NoError(t, err) assert.Len(t, cfg.Providers, 1) @@ -277,14 +278,14 @@ func TestConfig_configureProvidersVertexAIWithoutCredentials(t *testing.T) { } cfg := &Config{} - setDefaults("/tmp", cfg) + cfg.setDefaults("/tmp") env := env.NewFromMap(map[string]string{ "GOOGLE_GENAI_USE_VERTEXAI": "false", "GOOGLE_CLOUD_PROJECT": "test-project", "GOOGLE_CLOUD_LOCATION": "us-central1", }) resolver := NewEnvironmentVariableResolver(env) - err := configureProviders(cfg, env, resolver, knownProviders) + err := cfg.configureProviders(env, resolver, knownProviders) assert.NoError(t, err) // Provider should not be configured without proper credentials assert.Len(t, cfg.Providers, 0) @@ -303,13 +304,13 @@ func TestConfig_configureProvidersVertexAIMissingProject(t *testing.T) { } cfg := &Config{} - setDefaults("/tmp", cfg) + cfg.setDefaults("/tmp") env := env.NewFromMap(map[string]string{ "GOOGLE_GENAI_USE_VERTEXAI": "true", "GOOGLE_CLOUD_LOCATION": "us-central1", }) resolver := NewEnvironmentVariableResolver(env) - err := configureProviders(cfg, env, resolver, knownProviders) + err := cfg.configureProviders(env, resolver, knownProviders) assert.NoError(t, err) // Provider should not be configured without project assert.Len(t, cfg.Providers, 0) diff --git a/pkg/env/env_test.go b/pkg/env/env_test.go new file mode 100644 index 0000000000000000000000000000000000000000..73fcb1c2cec876d88f686b8cb2861bb02fd1a632 --- /dev/null +++ b/pkg/env/env_test.go @@ -0,0 +1,142 @@ +package env + +import ( + "os" + "strings" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestOsEnv_Get(t *testing.T) { + env := New() + + // Test getting an existing environment variable + os.Setenv("TEST_VAR", "test_value") + defer os.Unsetenv("TEST_VAR") + + value := env.Get("TEST_VAR") + assert.Equal(t, "test_value", value) + + // Test getting a non-existent environment variable + value = env.Get("NON_EXISTENT_VAR") + assert.Equal(t, "", value) +} + +func TestOsEnv_Env(t *testing.T) { + env := New() + + envVars := env.Env() + + // Environment should not be empty in normal circumstances + assert.NotNil(t, envVars) + assert.Greater(t, len(envVars), 0) + + // Each environment variable should be in key=value format + for _, envVar := range envVars { + assert.Contains(t, envVar, "=") + } +} + +func TestNewFromMap(t *testing.T) { + testMap := map[string]string{ + "KEY1": "value1", + "KEY2": "value2", + } + + env := NewFromMap(testMap) + assert.NotNil(t, env) + assert.IsType(t, &mapEnv{}, env) +} + +func TestMapEnv_Get(t *testing.T) { + testMap := map[string]string{ + "KEY1": "value1", + "KEY2": "value2", + } + + env := NewFromMap(testMap) + + // Test getting existing keys + assert.Equal(t, "value1", env.Get("KEY1")) + assert.Equal(t, "value2", env.Get("KEY2")) + + // Test getting non-existent key + assert.Equal(t, "", env.Get("NON_EXISTENT")) +} + +func TestMapEnv_Env(t *testing.T) { + t.Run("with values", func(t *testing.T) { + testMap := map[string]string{ + "KEY1": "value1", + "KEY2": "value2", + } + + env := NewFromMap(testMap) + envVars := env.Env() + + assert.Len(t, envVars, 2) + + // Convert to map for easier testing (order is not guaranteed) + envMap := make(map[string]string) + for _, envVar := range envVars { + parts := strings.SplitN(envVar, "=", 2) + assert.Len(t, parts, 2) + envMap[parts[0]] = parts[1] + } + + assert.Equal(t, "value1", envMap["KEY1"]) + assert.Equal(t, "value2", envMap["KEY2"]) + }) + + t.Run("empty map", func(t *testing.T) { + env := NewFromMap(map[string]string{}) + envVars := env.Env() + assert.Nil(t, envVars) + }) + + t.Run("nil map", func(t *testing.T) { + env := NewFromMap(nil) + envVars := env.Env() + assert.Nil(t, envVars) + }) +} + +func TestMapEnv_GetEmptyValue(t *testing.T) { + testMap := map[string]string{ + "EMPTY_KEY": "", + "NORMAL_KEY": "value", + } + + env := NewFromMap(testMap) + + // Test that empty values are returned correctly + assert.Equal(t, "", env.Get("EMPTY_KEY")) + assert.Equal(t, "value", env.Get("NORMAL_KEY")) +} + +func TestMapEnv_EnvFormat(t *testing.T) { + testMap := map[string]string{ + "KEY_WITH_EQUALS": "value=with=equals", + "KEY_WITH_SPACES": "value with spaces", + } + + env := NewFromMap(testMap) + envVars := env.Env() + + assert.Len(t, envVars, 2) + + // Check that the format is correct even with special characters + found := make(map[string]bool) + for _, envVar := range envVars { + if envVar == "KEY_WITH_EQUALS=value=with=equals" { + found["equals"] = true + } + if envVar == "KEY_WITH_SPACES=value with spaces" { + found["spaces"] = true + } + } + + assert.True(t, found["equals"], "Should handle values with equals signs") + assert.True(t, found["spaces"], "Should handle values with spaces") +} diff --git a/pkg/log/log.go b/pkg/log/log.go index c17ca366936ce540b9bfdfc588ef8c862e74ac95..11174a5071c72b6773cec03d3c849d4faff9bc39 100644 --- a/pkg/log/log.go +++ b/pkg/log/log.go @@ -2,20 +2,17 @@ package log import ( "log/slog" - "path/filepath" "sync" - "github.com/charmbracelet/crush/pkg/config" - "gopkg.in/natefinch/lumberjack.v2" ) var initOnce sync.Once -func Init(cfg *config.Config) { +func Init(logFile string, debug bool) { initOnce.Do(func() { logRotator := &lumberjack.Logger{ - Filename: filepath.Join(cfg.Options.DataDirectory, "logs", "crush.log"), + Filename: logFile, MaxSize: 10, // Max size in MB MaxBackups: 0, // Number of backups MaxAge: 30, // Days @@ -23,7 +20,7 @@ func Init(cfg *config.Config) { } level := slog.LevelInfo - if cfg.Options.Debug { + if debug { level = slog.LevelDebug }