chore: change config signature

Kujtim Hoxha created

- make some functions pointer methods
- add env tests

Change summary

cmd/logs.go             |   3 
pkg/config/config.go    |  59 ++++++++++++++++-
pkg/config/load.go      |  25 +++++-
pkg/config/load_test.go |  39 ++++++-----
pkg/env/env_test.go     | 142 +++++++++++++++++++++++++++++++++++++++++++
pkg/log/log.go          |   9 -
6 files changed, 239 insertions(+), 38 deletions(-)

Detailed changes

cmd/logs.go 🔗

@@ -8,7 +8,6 @@ import (
 	"time"
 
 	"github.com/charmbracelet/crush/pkg/config"
-	"github.com/charmbracelet/crush/pkg/env"
 	"github.com/charmbracelet/log/v2"
 	"github.com/nxadm/tail"
 	"github.com/spf13/cobra"
@@ -27,7 +26,7 @@ var logsCmd = &cobra.Command{
 		if err != nil {
 			return fmt.Errorf("failed to get current working directory: %v", err)
 		}
-		cfg, err := config.Load(cwd, env.New())
+		cfg, err := config.Load(cwd, false)
 		if err != nil {
 			return fmt.Errorf("failed to load configuration: %v", err)
 		}

pkg/config/config.go 🔗

@@ -1,6 +1,11 @@
 package config
 
-import "github.com/charmbracelet/crush/internal/fur/provider"
+import (
+	"slices"
+	"strings"
+
+	"github.com/charmbracelet/crush/internal/fur/provider"
+)
 
 const (
 	appName              = "crush"
@@ -70,7 +75,7 @@ const (
 	MCPHttp  MCPType = "http"
 )
 
-type MCP struct {
+type MCPConfig struct {
 	Command string   `json:"command,omitempty" `
 	Env     []string `json:"env,omitempty"`
 	Args    []string `json:"args,omitempty"`
@@ -103,20 +108,64 @@ type Options struct {
 	DataDirectory string `json:"data_directory,omitempty"`
 }
 
+type MCPs map[string]MCPConfig
+
+type MCP struct {
+	Name string    `json:"name"`
+	MCP  MCPConfig `json:"mcp"`
+}
+
+func (m MCPs) Sorted() []MCP {
+	sorted := make([]MCP, 0, len(m))
+	for k, v := range m {
+		sorted = append(sorted, MCP{
+			Name: k,
+			MCP:  v,
+		})
+	}
+	slices.SortFunc(sorted, func(a, b MCP) int {
+		return strings.Compare(a.Name, b.Name)
+	})
+	return sorted
+}
+
+type LSPs map[string]LSPConfig
+
+type LSP struct {
+	Name string    `json:"name"`
+	LSP  LSPConfig `json:"lsp"`
+}
+
+func (l LSPs) Sorted() []LSP {
+	sorted := make([]LSP, 0, len(l))
+	for k, v := range l {
+		sorted = append(sorted, LSP{
+			Name: k,
+			LSP:  v,
+		})
+	}
+	slices.SortFunc(sorted, func(a, b LSP) int {
+		return strings.Compare(a.Name, b.Name)
+	})
+	return sorted
+}
+
 // Config holds the configuration for crush.
 type Config struct {
-	workingDir string `json:"-"`
 	// We currently only support large/small as values here.
 	Models map[string]SelectedModel `json:"models,omitempty"`
 
 	// The providers that are configured
 	Providers map[string]ProviderConfig `json:"providers,omitempty"`
 
-	MCP map[string]MCP `json:"mcp,omitempty"`
+	MCP MCPs `json:"mcp,omitempty"`
 
-	LSP map[string]LSPConfig `json:"lsp,omitempty"`
+	LSP LSPs `json:"lsp,omitempty"`
 
 	Options *Options `json:"options,omitempty"`
+
+	// Internal
+	workingDir string `json:"-"`
 }
 
 func (c *Config) WorkingDir() string {

pkg/config/load.go 🔗

@@ -13,6 +13,7 @@ import (
 	"github.com/charmbracelet/crush/internal/fur/client"
 	"github.com/charmbracelet/crush/internal/fur/provider"
 	"github.com/charmbracelet/crush/pkg/env"
+	"github.com/charmbracelet/crush/pkg/log"
 )
 
 // LoadReader config via io.Reader.
@@ -31,7 +32,7 @@ func LoadReader(fd io.Reader) (*Config, error) {
 }
 
 // Load loads the configuration from the default paths.
-func Load(workingDir string, env env.Env) (*Config, error) {
+func Load(workingDir string, debug bool) (*Config, error) {
 	// uses default config paths
 	configPaths := []string{
 		globalConfig(),
@@ -40,6 +41,17 @@ func Load(workingDir string, env env.Env) (*Config, error) {
 		filepath.Join(workingDir, fmt.Sprintf(".%s.json", appName)),
 	}
 	cfg, err := loadFromConfigPaths(configPaths)
+
+	if debug {
+		cfg.Options.Debug = true
+	}
+
+	// Init logs
+	log.Init(
+		filepath.Join(cfg.Options.DataDirectory, "logs", fmt.Sprintf("%s.log", appName)),
+		cfg.Options.Debug,
+	)
+
 	if err != nil {
 		return nil, fmt.Errorf("failed to load config: %w", err)
 	}
@@ -47,7 +59,7 @@ func Load(workingDir string, env env.Env) (*Config, error) {
 	// e.x validate the models
 	// e.x validate provider config
 
-	setDefaults(workingDir, cfg)
+	cfg.setDefaults(workingDir)
 
 	// Load known providers, this loads the config from fur
 	providers, err := LoadProviders(client.New())
@@ -55,16 +67,17 @@ func Load(workingDir string, env env.Env) (*Config, error) {
 		return nil, fmt.Errorf("failed to load providers: %w", err)
 	}
 
+	env := env.New()
 	// Configure providers
 	valueResolver := NewShellVariableResolver(env)
-	if err := configureProviders(cfg, env, valueResolver, providers); err != nil {
+	if err := cfg.configureProviders(env, valueResolver, providers); err != nil {
 		return nil, fmt.Errorf("failed to configure providers: %w", err)
 	}
 
 	return cfg, nil
 }
 
-func configureProviders(cfg *Config, env env.Env, resolver VariableResolver, knownProviders []provider.Provider) error {
+func (cfg *Config) configureProviders(env env.Env, resolver VariableResolver, knownProviders []provider.Provider) error {
 	for _, p := range knownProviders {
 
 		config, ok := cfg.Providers[string(p.ID)]
@@ -169,7 +182,7 @@ func hasAWSCredentials(env env.Env) bool {
 	return false
 }
 
-func setDefaults(workingDir string, cfg *Config) {
+func (cfg *Config) setDefaults(workingDir string) {
 	cfg.workingDir = workingDir
 	if cfg.Options == nil {
 		cfg.Options = &Options{}
@@ -190,7 +203,7 @@ func setDefaults(workingDir string, cfg *Config) {
 		cfg.Models = make(map[string]SelectedModel)
 	}
 	if cfg.MCP == nil {
-		cfg.MCP = make(map[string]MCP)
+		cfg.MCP = make(map[string]MCPConfig)
 	}
 	if cfg.LSP == nil {
 		cfg.LSP = make(map[string]LSPConfig)

pkg/config/load_test.go 🔗

@@ -27,7 +27,7 @@ func TestConfig_LoadFromReaders(t *testing.T) {
 func TestConfig_setDefaults(t *testing.T) {
 	cfg := &Config{}
 
-	setDefaults("/tmp", cfg)
+	cfg.setDefaults("/tmp")
 
 	assert.NotNil(t, cfg.Options)
 	assert.NotNil(t, cfg.Options.TUI)
@@ -56,12 +56,12 @@ func TestConfig_configureProviders(t *testing.T) {
 	}
 
 	cfg := &Config{}
-	setDefaults("/tmp", cfg)
+	cfg.setDefaults("/tmp")
 	env := env.NewFromMap(map[string]string{
 		"OPENAI_API_KEY": "test-key",
 	})
 	resolver := NewEnvironmentVariableResolver(env)
-	err := configureProviders(cfg, env, resolver, knownProviders)
+	err := cfg.configureProviders(env, resolver, knownProviders)
 	assert.NoError(t, err)
 	assert.Len(t, cfg.Providers, 1)
 
@@ -98,12 +98,13 @@ func TestConfig_configureProvidersWithOverride(t *testing.T) {
 			},
 		},
 	}
-	setDefaults("/tmp", cfg)
+	cfg.setDefaults("/tmp")
+
 	env := env.NewFromMap(map[string]string{
 		"OPENAI_API_KEY": "test-key",
 	})
 	resolver := NewEnvironmentVariableResolver(env)
-	err := configureProviders(cfg, env, resolver, knownProviders)
+	err := cfg.configureProviders(env, resolver, knownProviders)
 	assert.NoError(t, err)
 	assert.Len(t, cfg.Providers, 1)
 
@@ -139,12 +140,12 @@ func TestConfig_configureProvidersWithNewProvider(t *testing.T) {
 			},
 		},
 	}
-	setDefaults("/tmp", cfg)
+	cfg.setDefaults("/tmp")
 	env := env.NewFromMap(map[string]string{
 		"OPENAI_API_KEY": "test-key",
 	})
 	resolver := NewEnvironmentVariableResolver(env)
-	err := configureProviders(cfg, env, resolver, knownProviders)
+	err := cfg.configureProviders(env, resolver, knownProviders)
 	assert.NoError(t, err)
 	// Should be to because of the env variable
 	assert.Len(t, cfg.Providers, 2)
@@ -171,13 +172,13 @@ func TestConfig_configureProvidersBedrockWithCredentials(t *testing.T) {
 	}
 
 	cfg := &Config{}
-	setDefaults("/tmp", cfg)
+	cfg.setDefaults("/tmp")
 	env := env.NewFromMap(map[string]string{
 		"AWS_ACCESS_KEY_ID":     "test-key-id",
 		"AWS_SECRET_ACCESS_KEY": "test-secret-key",
 	})
 	resolver := NewEnvironmentVariableResolver(env)
-	err := configureProviders(cfg, env, resolver, knownProviders)
+	err := cfg.configureProviders(env, resolver, knownProviders)
 	assert.NoError(t, err)
 	assert.Len(t, cfg.Providers, 1)
 
@@ -200,10 +201,10 @@ func TestConfig_configureProvidersBedrockWithoutCredentials(t *testing.T) {
 	}
 
 	cfg := &Config{}
-	setDefaults("/tmp", cfg)
+	cfg.setDefaults("/tmp")
 	env := env.NewFromMap(map[string]string{})
 	resolver := NewEnvironmentVariableResolver(env)
-	err := configureProviders(cfg, env, resolver, knownProviders)
+	err := cfg.configureProviders(env, resolver, knownProviders)
 	assert.NoError(t, err)
 	// Provider should not be configured without credentials
 	assert.Len(t, cfg.Providers, 0)
@@ -222,13 +223,13 @@ func TestConfig_configureProvidersBedrockWithoutUnsupportedModel(t *testing.T) {
 	}
 
 	cfg := &Config{}
-	setDefaults("/tmp", cfg)
+	cfg.setDefaults("/tmp")
 	env := env.NewFromMap(map[string]string{
 		"AWS_ACCESS_KEY_ID":     "test-key-id",
 		"AWS_SECRET_ACCESS_KEY": "test-secret-key",
 	})
 	resolver := NewEnvironmentVariableResolver(env)
-	err := configureProviders(cfg, env, resolver, knownProviders)
+	err := cfg.configureProviders(env, resolver, knownProviders)
 	assert.Error(t, err)
 }
 
@@ -245,14 +246,14 @@ func TestConfig_configureProvidersVertexAIWithCredentials(t *testing.T) {
 	}
 
 	cfg := &Config{}
-	setDefaults("/tmp", cfg)
+	cfg.setDefaults("/tmp")
 	env := env.NewFromMap(map[string]string{
 		"GOOGLE_GENAI_USE_VERTEXAI": "true",
 		"GOOGLE_CLOUD_PROJECT":      "test-project",
 		"GOOGLE_CLOUD_LOCATION":     "us-central1",
 	})
 	resolver := NewEnvironmentVariableResolver(env)
-	err := configureProviders(cfg, env, resolver, knownProviders)
+	err := cfg.configureProviders(env, resolver, knownProviders)
 	assert.NoError(t, err)
 	assert.Len(t, cfg.Providers, 1)
 
@@ -277,14 +278,14 @@ func TestConfig_configureProvidersVertexAIWithoutCredentials(t *testing.T) {
 	}
 
 	cfg := &Config{}
-	setDefaults("/tmp", cfg)
+	cfg.setDefaults("/tmp")
 	env := env.NewFromMap(map[string]string{
 		"GOOGLE_GENAI_USE_VERTEXAI": "false",
 		"GOOGLE_CLOUD_PROJECT":      "test-project",
 		"GOOGLE_CLOUD_LOCATION":     "us-central1",
 	})
 	resolver := NewEnvironmentVariableResolver(env)
-	err := configureProviders(cfg, env, resolver, knownProviders)
+	err := cfg.configureProviders(env, resolver, knownProviders)
 	assert.NoError(t, err)
 	// Provider should not be configured without proper credentials
 	assert.Len(t, cfg.Providers, 0)
@@ -303,13 +304,13 @@ func TestConfig_configureProvidersVertexAIMissingProject(t *testing.T) {
 	}
 
 	cfg := &Config{}
-	setDefaults("/tmp", cfg)
+	cfg.setDefaults("/tmp")
 	env := env.NewFromMap(map[string]string{
 		"GOOGLE_GENAI_USE_VERTEXAI": "true",
 		"GOOGLE_CLOUD_LOCATION":     "us-central1",
 	})
 	resolver := NewEnvironmentVariableResolver(env)
-	err := configureProviders(cfg, env, resolver, knownProviders)
+	err := cfg.configureProviders(env, resolver, knownProviders)
 	assert.NoError(t, err)
 	// Provider should not be configured without project
 	assert.Len(t, cfg.Providers, 0)

pkg/env/env_test.go 🔗

@@ -0,0 +1,142 @@
+package env
+
+import (
+	"os"
+	"strings"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestOsEnv_Get(t *testing.T) {
+	env := New()
+
+	// Test getting an existing environment variable
+	os.Setenv("TEST_VAR", "test_value")
+	defer os.Unsetenv("TEST_VAR")
+
+	value := env.Get("TEST_VAR")
+	assert.Equal(t, "test_value", value)
+
+	// Test getting a non-existent environment variable
+	value = env.Get("NON_EXISTENT_VAR")
+	assert.Equal(t, "", value)
+}
+
+func TestOsEnv_Env(t *testing.T) {
+	env := New()
+
+	envVars := env.Env()
+
+	// Environment should not be empty in normal circumstances
+	assert.NotNil(t, envVars)
+	assert.Greater(t, len(envVars), 0)
+
+	// Each environment variable should be in key=value format
+	for _, envVar := range envVars {
+		assert.Contains(t, envVar, "=")
+	}
+}
+
+func TestNewFromMap(t *testing.T) {
+	testMap := map[string]string{
+		"KEY1": "value1",
+		"KEY2": "value2",
+	}
+
+	env := NewFromMap(testMap)
+	assert.NotNil(t, env)
+	assert.IsType(t, &mapEnv{}, env)
+}
+
+func TestMapEnv_Get(t *testing.T) {
+	testMap := map[string]string{
+		"KEY1": "value1",
+		"KEY2": "value2",
+	}
+
+	env := NewFromMap(testMap)
+
+	// Test getting existing keys
+	assert.Equal(t, "value1", env.Get("KEY1"))
+	assert.Equal(t, "value2", env.Get("KEY2"))
+
+	// Test getting non-existent key
+	assert.Equal(t, "", env.Get("NON_EXISTENT"))
+}
+
+func TestMapEnv_Env(t *testing.T) {
+	t.Run("with values", func(t *testing.T) {
+		testMap := map[string]string{
+			"KEY1": "value1",
+			"KEY2": "value2",
+		}
+
+		env := NewFromMap(testMap)
+		envVars := env.Env()
+
+		assert.Len(t, envVars, 2)
+
+		// Convert to map for easier testing (order is not guaranteed)
+		envMap := make(map[string]string)
+		for _, envVar := range envVars {
+			parts := strings.SplitN(envVar, "=", 2)
+			assert.Len(t, parts, 2)
+			envMap[parts[0]] = parts[1]
+		}
+
+		assert.Equal(t, "value1", envMap["KEY1"])
+		assert.Equal(t, "value2", envMap["KEY2"])
+	})
+
+	t.Run("empty map", func(t *testing.T) {
+		env := NewFromMap(map[string]string{})
+		envVars := env.Env()
+		assert.Nil(t, envVars)
+	})
+
+	t.Run("nil map", func(t *testing.T) {
+		env := NewFromMap(nil)
+		envVars := env.Env()
+		assert.Nil(t, envVars)
+	})
+}
+
+func TestMapEnv_GetEmptyValue(t *testing.T) {
+	testMap := map[string]string{
+		"EMPTY_KEY":  "",
+		"NORMAL_KEY": "value",
+	}
+
+	env := NewFromMap(testMap)
+
+	// Test that empty values are returned correctly
+	assert.Equal(t, "", env.Get("EMPTY_KEY"))
+	assert.Equal(t, "value", env.Get("NORMAL_KEY"))
+}
+
+func TestMapEnv_EnvFormat(t *testing.T) {
+	testMap := map[string]string{
+		"KEY_WITH_EQUALS": "value=with=equals",
+		"KEY_WITH_SPACES": "value with spaces",
+	}
+
+	env := NewFromMap(testMap)
+	envVars := env.Env()
+
+	assert.Len(t, envVars, 2)
+
+	// Check that the format is correct even with special characters
+	found := make(map[string]bool)
+	for _, envVar := range envVars {
+		if envVar == "KEY_WITH_EQUALS=value=with=equals" {
+			found["equals"] = true
+		}
+		if envVar == "KEY_WITH_SPACES=value with spaces" {
+			found["spaces"] = true
+		}
+	}
+
+	assert.True(t, found["equals"], "Should handle values with equals signs")
+	assert.True(t, found["spaces"], "Should handle values with spaces")
+}

pkg/log/log.go 🔗

@@ -2,20 +2,17 @@ package log
 
 import (
 	"log/slog"
-	"path/filepath"
 	"sync"
 
-	"github.com/charmbracelet/crush/pkg/config"
-
 	"gopkg.in/natefinch/lumberjack.v2"
 )
 
 var initOnce sync.Once
 
-func Init(cfg *config.Config) {
+func Init(logFile string, debug bool) {
 	initOnce.Do(func() {
 		logRotator := &lumberjack.Logger{
-			Filename:   filepath.Join(cfg.Options.DataDirectory, "logs", "crush.log"),
+			Filename:   logFile,
 			MaxSize:    10,    // Max size in MB
 			MaxBackups: 0,     // Number of backups
 			MaxAge:     30,    // Days
@@ -23,7 +20,7 @@ func Init(cfg *config.Config) {
 		}
 
 		level := slog.LevelInfo
-		if cfg.Options.Debug {
+		if debug {
 			level = slog.LevelDebug
 		}