From 380d524ec1f9c77dbff62298adb05c40d7f4b937 Mon Sep 17 00:00:00 2001 From: Carlos Alexandro Becker Date: Tue, 22 Jul 2025 12:10:26 -0300 Subject: [PATCH 01/21] perf: init tools in background --- internal/app/app.go | 2 +- internal/cmd/root.go | 4 ++ internal/llm/agent/agent.go | 114 ++++++++++++++++++++---------------- 3 files changed, 70 insertions(+), 50 deletions(-) diff --git a/internal/app/app.go b/internal/app/app.go index d63c90c6e2599f63e3a65cd8069b53638f45cc5f..170debad340e2cb33fcb7a1c9fe814c184573c9b 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -80,7 +80,7 @@ func New(ctx context.Context, conn *sql.DB, cfg *config.Config) (*App, error) { app.setupEvents() // Initialize LSP clients in the background. - go app.initLSPClients(ctx) + app.initLSPClients(ctx) // TODO: remove the concept of agent config, most likely. if cfg.IsConfigured() { diff --git a/internal/cmd/root.go b/internal/cmd/root.go index d63160992141da26b6a26610b06f1b601213e00d..e782898f42e38bc4b06f890b76a2deebb96b9df1 100644 --- a/internal/cmd/root.go +++ b/internal/cmd/root.go @@ -6,6 +6,7 @@ import ( "io" "log/slog" "os" + "time" tea "github.com/charmbracelet/bubbletea/v2" "github.com/charmbracelet/crush/internal/app" @@ -83,12 +84,15 @@ to assist developers in writing, debugging, and understanding code directly from return err } + slog.Info("Initing...") + now := time.Now() app, err := app.New(ctx, conn, cfg) if err != nil { slog.Error(fmt.Sprintf("Failed to create app instance: %v", err)) return err } defer app.Shutdown() + slog.Info("Init done", "took", time.Since(now).String()) prompt, err = maybePrependStdin(prompt) if err != nil { diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go index 907a2348f838aa2f2ba6792db9b768eb656904a8..291528d36594a5178f72ae64bb565f18e36c1cae 100644 --- a/internal/llm/agent/agent.go +++ b/internal/llm/agent/agent.go @@ -8,6 +8,7 @@ import ( "slices" "strings" "sync" + "sync/atomic" "time" "github.com/charmbracelet/crush/internal/config" @@ -67,7 +68,9 @@ type agent struct { sessions session.Service messages message.Service - tools []tools.BaseTool + toolsDone atomic.Bool + tools []tools.BaseTool + provider provider.Provider providerID string @@ -94,46 +97,7 @@ func NewAgent( ) (Service, error) { ctx := context.Background() cfg := config.Get() - otherTools := GetMCPTools(ctx, permissions, cfg) - if len(lspClients) > 0 { - otherTools = append(otherTools, tools.NewDiagnosticsTool(lspClients)) - } - cwd := cfg.WorkingDir() - allTools := []tools.BaseTool{ - tools.NewBashTool(permissions, cwd), - tools.NewDownloadTool(permissions, cwd), - tools.NewEditTool(lspClients, permissions, history, cwd), - tools.NewFetchTool(permissions, cwd), - tools.NewGlobTool(cwd), - tools.NewGrepTool(cwd), - tools.NewLsTool(cwd), - tools.NewSourcegraphTool(), - tools.NewViewTool(lspClients, cwd), - tools.NewWriteTool(lspClients, permissions, history, cwd), - } - - if agentCfg.ID == "coder" { - taskAgentCfg := config.Get().Agents["task"] - if taskAgentCfg.ID == "" { - return nil, fmt.Errorf("task agent not found in config") - } - taskAgent, err := NewAgent(taskAgentCfg, permissions, sessions, messages, history, lspClients) - if err != nil { - return nil, fmt.Errorf("failed to create task agent: %w", err) - } - - allTools = append( - allTools, - NewAgentTool( - taskAgent, - sessions, - messages, - ), - ) - } - - allTools = append(allTools, otherTools...) providerCfg := config.Get().GetProviderForModel(agentCfg.Model) if providerCfg == nil { return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name) @@ -190,15 +154,22 @@ func NewAgent( return nil, err } - agentTools := []tools.BaseTool{} - if agentCfg.AllowedTools == nil { - agentTools = allTools - } else { - for _, tool := range allTools { - if slices.Contains(agentCfg.AllowedTools, tool.Name()) { - agentTools = append(agentTools, tool) - } + var agentTool tools.BaseTool + if agentCfg.ID == "coder" { + taskAgentCfg := config.Get().Agents["task"] + if taskAgentCfg.ID == "" { + return nil, fmt.Errorf("task agent not found in config") } + taskAgent, err := NewAgent(taskAgentCfg, permissions, sessions, messages, history, lspClients) + if err != nil { + return nil, fmt.Errorf("failed to create task agent: %w", err) + } + + agentTool = NewAgentTool( + taskAgent, + sessions, + messages, + ) } agent := &agent{ @@ -208,13 +179,55 @@ func NewAgent( providerID: string(providerCfg.ID), messages: messages, sessions: sessions, - tools: agentTools, titleProvider: titleProvider, summarizeProvider: summarizeProvider, summarizeProviderID: string(smallModelProviderCfg.ID), activeRequests: sync.Map{}, } + go func() { + slog.Info("Initializing agent tools", "agent", agentCfg.ID) + + cwd := cfg.WorkingDir() + allTools := []tools.BaseTool{ + tools.NewBashTool(permissions, cwd), + tools.NewDownloadTool(permissions, cwd), + tools.NewEditTool(lspClients, permissions, history, cwd), + tools.NewFetchTool(permissions, cwd), + tools.NewGlobTool(cwd), + tools.NewGrepTool(cwd), + tools.NewLsTool(cwd), + tools.NewSourcegraphTool(), + tools.NewViewTool(lspClients, cwd), + tools.NewWriteTool(lspClients, permissions, history, cwd), + } + + mcpTools := GetMCPTools(ctx, permissions, cfg) + if len(lspClients) > 0 { + mcpTools = append(mcpTools, tools.NewDiagnosticsTool(lspClients)) + } + allTools = append(allTools, mcpTools...) + + if agentTool != nil { + allTools = append(allTools, agentTool) + } + + agentTools := []tools.BaseTool{} + if agentCfg.AllowedTools == nil { + agentTools = allTools + } else { + for _, tool := range allTools { + if slices.Contains(agentCfg.AllowedTools, tool.Name()) { + agentTools = append(agentTools, tool) + } + } + } + + slog.Info("Initialized agent tools", "agent", agentCfg.ID) + agent.tools = agentTools + agent.toolsDone.Store(true) + }() + return agent, nil } @@ -437,6 +450,9 @@ func (a *agent) createUserMessage(ctx context.Context, sessionID, content string func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) { ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID) + if !a.toolsDone.Load() { + return message.Message{}, nil, fmt.Errorf("tools not initialized yet") + } eventChan := a.provider.StreamResponse(ctx, msgHistory, a.tools) assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{ From 4399c9ab30335e52cd09ad45c6424c33b5a51771 Mon Sep 17 00:00:00 2001 From: Carlos Alexandro Becker Date: Tue, 22 Jul 2025 12:15:04 -0300 Subject: [PATCH 02/21] fix: improvements --- internal/cmd/root.go | 4 ---- internal/llm/agent/agent.go | 29 ++++++++++++++++------------- 2 files changed, 16 insertions(+), 17 deletions(-) diff --git a/internal/cmd/root.go b/internal/cmd/root.go index e782898f42e38bc4b06f890b76a2deebb96b9df1..d63160992141da26b6a26610b06f1b601213e00d 100644 --- a/internal/cmd/root.go +++ b/internal/cmd/root.go @@ -6,7 +6,6 @@ import ( "io" "log/slog" "os" - "time" tea "github.com/charmbracelet/bubbletea/v2" "github.com/charmbracelet/crush/internal/app" @@ -84,15 +83,12 @@ to assist developers in writing, debugging, and understanding code directly from return err } - slog.Info("Initing...") - now := time.Now() app, err := app.New(ctx, conn, cfg) if err != nil { slog.Error(fmt.Sprintf("Failed to create app instance: %v", err)) return err } defer app.Shutdown() - slog.Info("Init done", "took", time.Since(now).String()) prompt, err = maybePrependStdin(prompt) if err != nil { diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go index 291528d36594a5178f72ae64bb565f18e36c1cae..84ef27a86e1a91557a7af88c868fdd4ba8ce9362 100644 --- a/internal/llm/agent/agent.go +++ b/internal/llm/agent/agent.go @@ -187,6 +187,10 @@ func NewAgent( go func() { slog.Info("Initializing agent tools", "agent", agentCfg.ID) + defer func() { + slog.Info("Initialized agent tools", "agent", agentCfg.ID) + agent.toolsDone.Store(true) + }() cwd := cfg.WorkingDir() allTools := []tools.BaseTool{ @@ -203,29 +207,28 @@ func NewAgent( } mcpTools := GetMCPTools(ctx, permissions, cfg) + allTools = append(allTools, mcpTools...) + if len(lspClients) > 0 { - mcpTools = append(mcpTools, tools.NewDiagnosticsTool(lspClients)) + allTools = append(allTools, tools.NewDiagnosticsTool(lspClients)) } - allTools = append(allTools, mcpTools...) if agentTool != nil { allTools = append(allTools, agentTool) } - agentTools := []tools.BaseTool{} if agentCfg.AllowedTools == nil { - agentTools = allTools - } else { - for _, tool := range allTools { - if slices.Contains(agentCfg.AllowedTools, tool.Name()) { - agentTools = append(agentTools, tool) - } - } + agent.tools = allTools + return } - slog.Info("Initialized agent tools", "agent", agentCfg.ID) - agent.tools = agentTools - agent.toolsDone.Store(true) + var filteredTools []tools.BaseTool + for _, tool := range allTools { + if slices.Contains(agentCfg.AllowedTools, tool.Name()) { + filteredTools = append(filteredTools, tool) + } + } + agent.tools = filteredTools }() return agent, nil From 05c73bbbdf341f9e8625b0f57027676a6a45228e Mon Sep 17 00:00:00 2001 From: Carlos Alexandro Becker Date: Tue, 22 Jul 2025 12:16:11 -0300 Subject: [PATCH 03/21] fix: improv diff --- internal/llm/agent/agent.go | 36 ++++++++++++++++++------------------ 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go index 84ef27a86e1a91557a7af88c868fdd4ba8ce9362..501cf3aad0d1f6fa57dbda1f055e7f23a855b862 100644 --- a/internal/llm/agent/agent.go +++ b/internal/llm/agent/agent.go @@ -98,6 +98,24 @@ func NewAgent( ctx := context.Background() cfg := config.Get() + var agentTool tools.BaseTool + if agentCfg.ID == "coder" { + taskAgentCfg := config.Get().Agents["task"] + if taskAgentCfg.ID == "" { + return nil, fmt.Errorf("task agent not found in config") + } + taskAgent, err := NewAgent(taskAgentCfg, permissions, sessions, messages, history, lspClients) + if err != nil { + return nil, fmt.Errorf("failed to create task agent: %w", err) + } + + agentTool = NewAgentTool( + taskAgent, + sessions, + messages, + ) + } + providerCfg := config.Get().GetProviderForModel(agentCfg.Model) if providerCfg == nil { return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name) @@ -154,24 +172,6 @@ func NewAgent( return nil, err } - var agentTool tools.BaseTool - if agentCfg.ID == "coder" { - taskAgentCfg := config.Get().Agents["task"] - if taskAgentCfg.ID == "" { - return nil, fmt.Errorf("task agent not found in config") - } - taskAgent, err := NewAgent(taskAgentCfg, permissions, sessions, messages, history, lspClients) - if err != nil { - return nil, fmt.Errorf("failed to create task agent: %w", err) - } - - agentTool = NewAgentTool( - taskAgent, - sessions, - messages, - ) - } - agent := &agent{ Broker: pubsub.NewBroker[AgentEvent](), agentCfg: agentCfg, From e9ec258bc9896dfd27c55b15ea5eaffa940fa5d8 Mon Sep 17 00:00:00 2001 From: Carlos Alexandro Becker Date: Tue, 22 Jul 2025 12:16:40 -0300 Subject: [PATCH 04/21] wip --- internal/llm/agent/agent.go | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go index 501cf3aad0d1f6fa57dbda1f055e7f23a855b862..4160e6ef5211a7161799106be4574ab0da61e62d 100644 --- a/internal/llm/agent/agent.go +++ b/internal/llm/agent/agent.go @@ -109,11 +109,7 @@ func NewAgent( return nil, fmt.Errorf("failed to create task agent: %w", err) } - agentTool = NewAgentTool( - taskAgent, - sessions, - messages, - ) + agentTool = NewAgentTool(taskAgent, sessions, messages) } providerCfg := config.Get().GetProviderForModel(agentCfg.Model) From 8eeb04336e744134a910a67ac3478ec60295f229 Mon Sep 17 00:00:00 2001 From: Carlos Alexandro Becker Date: Tue, 22 Jul 2025 12:58:23 -0300 Subject: [PATCH 05/21] Update agent.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- internal/llm/agent/agent.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go index 4160e6ef5211a7161799106be4574ab0da61e62d..69e675156c5c358a7f160ff1c3f7af08d0b2372a 100644 --- a/internal/llm/agent/agent.go +++ b/internal/llm/agent/agent.go @@ -450,7 +450,7 @@ func (a *agent) createUserMessage(ctx context.Context, sessionID, content string func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) { ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID) if !a.toolsDone.Load() { - return message.Message{}, nil, fmt.Errorf("tools not initialized yet") + return message.Message{}, nil, fmt.Errorf("Agent is still initializing, please wait a moment and try again") } eventChan := a.provider.StreamResponse(ctx, msgHistory, a.tools) From e25a42c8f04d5f0fe97db4c384a9430d4e8b661a Mon Sep 17 00:00:00 2001 From: Carlos Alexandro Becker Date: Tue, 22 Jul 2025 15:11:00 -0300 Subject: [PATCH 06/21] Update internal/llm/agent/agent.go Co-authored-by: Andrey Nering --- internal/llm/agent/agent.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go index 69e675156c5c358a7f160ff1c3f7af08d0b2372a..e920651d0faeb87da765c4ab67735c1c2d285001 100644 --- a/internal/llm/agent/agent.go +++ b/internal/llm/agent/agent.go @@ -450,7 +450,7 @@ func (a *agent) createUserMessage(ctx context.Context, sessionID, content string func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) { ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID) if !a.toolsDone.Load() { - return message.Message{}, nil, fmt.Errorf("Agent is still initializing, please wait a moment and try again") + return message.Message{}, nil, fmt.Errorf("agent is still initializing, please wait a moment and try again") } eventChan := a.provider.StreamResponse(ctx, msgHistory, a.tools) From df56ffd4b1875a04c7b157ac5308decc6e0bab8d Mon Sep 17 00:00:00 2001 From: Carlos Alexandro Becker Date: Tue, 22 Jul 2025 16:52:42 -0300 Subject: [PATCH 07/21] feat: improve providers startup (#287) * feat: check providers in background Signed-off-by: Carlos Alexandro Becker * test: test csync.Map * fix: improvements --------- Signed-off-by: Carlos Alexandro Becker --- csync/slices.go | 34 ++ csync/slices_test.go | 86 ++++ internal/config/config.go | 21 +- internal/config/load.go | 85 ++-- internal/config/load_test.go | 207 ++++---- internal/csync/maps.go | 84 ++++ internal/csync/maps_test.go | 450 ++++++++++++++++++ internal/llm/agent/agent.go | 51 +- internal/tui/components/chat/splash/splash.go | 2 +- .../tui/components/dialogs/models/list.go | 6 +- .../tui/components/dialogs/models/models.go | 2 +- 11 files changed, 841 insertions(+), 187 deletions(-) create mode 100644 csync/slices.go create mode 100644 csync/slices_test.go create mode 100644 internal/csync/maps.go create mode 100644 internal/csync/maps_test.go diff --git a/csync/slices.go b/csync/slices.go new file mode 100644 index 0000000000000000000000000000000000000000..388ad074d53a9bd7188418b231afbf39adca0565 --- /dev/null +++ b/csync/slices.go @@ -0,0 +1,34 @@ +package csync + +import ( + "iter" + "sync" +) + +type LazySlice[K any] struct { + inner []K + mu sync.Mutex +} + +func NewLazySlice[K any](load func() []K) *LazySlice[K] { + s := &LazySlice[K]{} + s.mu.Lock() + go func() { + s.inner = load() + s.mu.Unlock() + }() + return s +} + +func (s *LazySlice[K]) Iter() iter.Seq[K] { + s.mu.Lock() + inner := s.inner + s.mu.Unlock() + return func(yield func(K) bool) { + for _, v := range inner { + if !yield(v) { + return + } + } + } +} diff --git a/csync/slices_test.go b/csync/slices_test.go new file mode 100644 index 0000000000000000000000000000000000000000..d1c7af8cf30f3d58a84046f899f8dd89f80beb51 --- /dev/null +++ b/csync/slices_test.go @@ -0,0 +1,86 @@ +package csync + +import ( + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestLazySlice_Iter(t *testing.T) { + t.Parallel() + + data := []string{"a", "b", "c"} + s := NewLazySlice(func() []string { + // TODO: use synctest when new Go is out. + time.Sleep(10 * time.Millisecond) // Small delay to ensure loading happens + return data + }) + + var result []string + for v := range s.Iter() { + result = append(result, v) + } + + assert.Equal(t, data, result) +} + +func TestLazySlice_IterWaitsForLoading(t *testing.T) { + t.Parallel() + + var loaded atomic.Bool + data := []string{"x", "y", "z"} + + s := NewLazySlice(func() []string { + // TODO: use synctest when new Go is out. + time.Sleep(100 * time.Millisecond) + loaded.Store(true) + return data + }) + + assert.False(t, loaded.Load(), "should not be loaded immediately") + + var result []string + for v := range s.Iter() { + result = append(result, v) + } + + assert.True(t, loaded.Load(), "should be loaded after Iter") + assert.Equal(t, data, result) +} + +func TestLazySlice_EmptySlice(t *testing.T) { + t.Parallel() + + s := NewLazySlice(func() []string { + return []string{} + }) + + var result []string + for v := range s.Iter() { + result = append(result, v) + } + + assert.Empty(t, result) +} + +func TestLazySlice_EarlyBreak(t *testing.T) { + t.Parallel() + + data := []string{"a", "b", "c", "d", "e"} + s := NewLazySlice(func() []string { + time.Sleep(10 * time.Millisecond) // Small delay to ensure loading happens + return data + }) + + var result []string + for v := range s.Iter() { + result = append(result, v) + if len(result) == 2 { + break + } + } + + assert.Equal(t, []string{"a", "b"}, result) +} diff --git a/internal/config/config.go b/internal/config/config.go index 1c20188a12a3955fde6b6eeed9f12ea39288e328..18eca04912189415606599c5849e8a7beb592cb4 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -9,6 +9,7 @@ import ( "strings" "time" + "github.com/charmbracelet/crush/internal/csync" "github.com/charmbracelet/crush/internal/env" "github.com/charmbracelet/crush/internal/fur/provider" "github.com/tidwall/sjson" @@ -236,7 +237,7 @@ type Config struct { Models map[SelectedModelType]SelectedModel `json:"models,omitempty"` // The providers that are configured - Providers map[string]ProviderConfig `json:"providers,omitempty"` + Providers *csync.Map[string, ProviderConfig] `json:"providers,omitempty"` MCP MCPs `json:"mcp,omitempty"` @@ -259,8 +260,8 @@ func (c *Config) WorkingDir() string { } func (c *Config) EnabledProviders() []ProviderConfig { - enabled := make([]ProviderConfig, 0, len(c.Providers)) - for _, p := range c.Providers { + var enabled []ProviderConfig + for _, p := range c.Providers.Seq2() { if !p.Disable { enabled = append(enabled, p) } @@ -274,7 +275,7 @@ func (c *Config) IsConfigured() bool { } func (c *Config) GetModel(provider, model string) *provider.Model { - if providerConfig, ok := c.Providers[provider]; ok { + if providerConfig, ok := c.Providers.Get(provider); ok { for _, m := range providerConfig.Models { if m.ID == model { return &m @@ -289,7 +290,7 @@ func (c *Config) GetProviderForModel(modelType SelectedModelType) *ProviderConfi if !ok { return nil } - if providerConfig, ok := c.Providers[model.Provider]; ok { + if providerConfig, ok := c.Providers.Get(model.Provider); ok { return &providerConfig } return nil @@ -370,14 +371,10 @@ func (c *Config) SetProviderAPIKey(providerID, apiKey string) error { return fmt.Errorf("failed to save API key to config file: %w", err) } - if c.Providers == nil { - c.Providers = make(map[string]ProviderConfig) - } - - providerConfig, exists := c.Providers[providerID] + providerConfig, exists := c.Providers.Get(providerID) if exists { providerConfig.APIKey = apiKey - c.Providers[providerID] = providerConfig + c.Providers.Set(providerID, providerConfig) return nil } @@ -406,7 +403,7 @@ func (c *Config) SetProviderAPIKey(providerID, apiKey string) error { return fmt.Errorf("provider with ID %s not found in known providers", providerID) } // Store the updated provider config - c.Providers[providerID] = providerConfig + c.Providers.Set(providerID, providerConfig) return nil } diff --git a/internal/config/load.go b/internal/config/load.go index cd4ccd08c46e48155091407962137da2cb913869..09d65e5391b94a1f80b15e7e576ba5d3e38ef19d 100644 --- a/internal/config/load.go +++ b/internal/config/load.go @@ -11,6 +11,7 @@ import ( "strings" "sync" + "github.com/charmbracelet/crush/internal/csync" "github.com/charmbracelet/crush/internal/env" "github.com/charmbracelet/crush/internal/fur/client" "github.com/charmbracelet/crush/internal/fur/provider" @@ -80,30 +81,34 @@ func Load(workingDir string, debug bool) (*Config, error) { var testResults sync.Map var wg sync.WaitGroup - for _, p := range cfg.Providers { - if p.Type == provider.TypeOpenAI || p.Type == provider.TypeAnthropic { - wg.Add(1) - go func(provider ProviderConfig) { - defer wg.Done() - err := provider.TestConnection(cfg.resolver) - testResults.Store(provider.ID, err == nil) - if err != nil { - slog.Error("Provider connection test failed", "provider", provider.ID, "error", err) - } - }(p) - } - } - wg.Wait() - - // Remove failed providers - testResults.Range(func(key, value any) bool { - providerID := key.(string) - passed := value.(bool) - if !passed { - delete(cfg.Providers, providerID) + go func() { + slog.Info("Testing provider connections") + defer slog.Info("Provider connection tests completed") + for _, p := range cfg.Providers.Seq2() { + if p.Type == provider.TypeOpenAI || p.Type == provider.TypeAnthropic { + wg.Add(1) + go func(provider ProviderConfig) { + defer wg.Done() + err := provider.TestConnection(cfg.resolver) + testResults.Store(provider.ID, err == nil) + if err != nil { + slog.Error("Provider connection test failed", "provider", provider.ID, "error", err) + } + }(p) + } } - return true - }) + wg.Wait() + + // Remove failed providers + testResults.Range(func(key, value any) bool { + providerID := key.(string) + passed := value.(bool) + if !passed { + cfg.Providers.Del(providerID) + } + return true + }) + }() if !cfg.IsConfigured() { slog.Warn("No providers configured") @@ -121,12 +126,12 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know knownProviderNames := make(map[string]bool) for _, p := range knownProviders { knownProviderNames[string(p.ID)] = true - config, configExists := c.Providers[string(p.ID)] + config, configExists := c.Providers.Get(string(p.ID)) // if the user configured a known provider we need to allow it to override a couple of parameters if configExists { if config.Disable { slog.Debug("Skipping provider due to disable flag", "provider", p.ID) - delete(c.Providers, string(p.ID)) + c.Providers.Del(string(p.ID)) continue } if config.BaseURL != "" { @@ -182,7 +187,7 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know if !hasVertexCredentials(env) { if configExists { slog.Warn("Skipping Vertex AI provider due to missing credentials") - delete(c.Providers, string(p.ID)) + c.Providers.Del(string(p.ID)) } continue } @@ -193,7 +198,7 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know if err != nil || endpoint == "" { if configExists { slog.Warn("Skipping Azure provider due to missing API endpoint", "provider", p.ID, "error", err) - delete(c.Providers, string(p.ID)) + c.Providers.Del(string(p.ID)) } continue } @@ -203,7 +208,7 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know if !hasAWSCredentials(env) { if configExists { slog.Warn("Skipping Bedrock provider due to missing AWS credentials") - delete(c.Providers, string(p.ID)) + c.Providers.Del(string(p.ID)) } continue } @@ -218,16 +223,16 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know if v == "" || err != nil { if configExists { slog.Warn("Skipping provider due to missing API key", "provider", p.ID) - delete(c.Providers, string(p.ID)) + c.Providers.Del(string(p.ID)) } continue } } - c.Providers[string(p.ID)] = prepared + c.Providers.Set(string(p.ID), prepared) } // validate the custom providers - for id, providerConfig := range c.Providers { + for id, providerConfig := range c.Providers.Seq2() { if knownProviderNames[id] { continue } @@ -244,7 +249,7 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know if providerConfig.Disable { slog.Debug("Skipping custom provider due to disable flag", "provider", id) - delete(c.Providers, id) + c.Providers.Del(id) continue } if providerConfig.APIKey == "" { @@ -252,17 +257,17 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know } if providerConfig.BaseURL == "" { slog.Warn("Skipping custom provider due to missing API endpoint", "provider", id) - delete(c.Providers, id) + c.Providers.Del(id) continue } if len(providerConfig.Models) == 0 { slog.Warn("Skipping custom provider because the provider has no models", "provider", id) - delete(c.Providers, id) + c.Providers.Del(id) continue } if providerConfig.Type != provider.TypeOpenAI { slog.Warn("Skipping custom provider because the provider type is not supported", "provider", id, "type", providerConfig.Type) - delete(c.Providers, id) + c.Providers.Del(id) continue } @@ -273,11 +278,11 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know baseURL, err := resolver.ResolveValue(providerConfig.BaseURL) if baseURL == "" || err != nil { slog.Warn("Skipping custom provider due to missing API endpoint", "provider", id, "error", err) - delete(c.Providers, id) + c.Providers.Del(id) continue } - c.Providers[id] = providerConfig + c.Providers.Set(id, providerConfig) } return nil } @@ -297,7 +302,7 @@ func (c *Config) setDefaults(workingDir string) { c.Options.DataDirectory = filepath.Join(workingDir, defaultDataDirectory) } if c.Providers == nil { - c.Providers = make(map[string]ProviderConfig) + c.Providers = csync.NewMap[string, ProviderConfig]() } if c.Models == nil { c.Models = make(map[SelectedModelType]SelectedModel) @@ -316,7 +321,7 @@ func (c *Config) setDefaults(workingDir string) { } func (c *Config) defaultModelSelection(knownProviders []provider.Provider) (largeModel SelectedModel, smallModel SelectedModel, err error) { - if len(knownProviders) == 0 && len(c.Providers) == 0 { + if len(knownProviders) == 0 { // TODO:}&& len(c.Providers) == 0 { err = fmt.Errorf("no providers configured, please configure at least one provider") return } @@ -324,7 +329,7 @@ func (c *Config) defaultModelSelection(knownProviders []provider.Provider) (larg // Use the first provider enabled based on the known providers order // if no provider found that is known use the first provider configured for _, p := range knownProviders { - providerConfig, ok := c.Providers[string(p.ID)] + providerConfig, ok := c.Providers.Get(string(p.ID)) if !ok || providerConfig.Disable { continue } diff --git a/internal/config/load_test.go b/internal/config/load_test.go index b96ca5e81cd265cbcd1bdf9d456603ad3f22c558..86a2356da2021dc22de88de05a80717e95aa492a 100644 --- a/internal/config/load_test.go +++ b/internal/config/load_test.go @@ -8,6 +8,7 @@ import ( "strings" "testing" + "github.com/charmbracelet/crush/internal/csync" "github.com/charmbracelet/crush/internal/env" "github.com/charmbracelet/crush/internal/fur/provider" "github.com/stretchr/testify/assert" @@ -29,9 +30,10 @@ func TestConfig_LoadFromReaders(t *testing.T) { assert.NoError(t, err) assert.NotNil(t, loadedConfig) - assert.Len(t, loadedConfig.Providers, 1) - assert.Equal(t, "key2", loadedConfig.Providers["openai"].APIKey) - assert.Equal(t, "https://api.openai.com/v2", loadedConfig.Providers["openai"].BaseURL) + assert.Equal(t, 1, loadedConfig.Providers.Len()) + pc, _ := loadedConfig.Providers.Get("openai") + assert.Equal(t, "key2", pc.APIKey) + assert.Equal(t, "https://api.openai.com/v2", pc.BaseURL) } func TestConfig_setDefaults(t *testing.T) { @@ -73,10 +75,11 @@ func TestConfig_configureProviders(t *testing.T) { resolver := NewEnvironmentVariableResolver(env) err := cfg.configureProviders(env, resolver, knownProviders) assert.NoError(t, err) - assert.Len(t, cfg.Providers, 1) + assert.Equal(t, 1, cfg.Providers.Len()) // We want to make sure that we keep the configured API key as a placeholder - assert.Equal(t, "$OPENAI_API_KEY", cfg.Providers["openai"].APIKey) + pc, _ := cfg.Providers.Get("openai") + assert.Equal(t, "$OPENAI_API_KEY", pc.APIKey) } func TestConfig_configureProvidersWithOverride(t *testing.T) { @@ -92,22 +95,21 @@ func TestConfig_configureProvidersWithOverride(t *testing.T) { } cfg := &Config{ - Providers: map[string]ProviderConfig{ - "openai": { - APIKey: "xyz", - BaseURL: "https://api.openai.com/v2", - Models: []provider.Model{ - { - ID: "test-model", - Model: "Updated", - }, - { - ID: "another-model", - }, - }, + Providers: csync.NewMap[string, ProviderConfig](), + } + cfg.Providers.Set("openai", ProviderConfig{ + APIKey: "xyz", + BaseURL: "https://api.openai.com/v2", + Models: []provider.Model{ + { + ID: "test-model", + Model: "Updated", + }, + { + ID: "another-model", }, }, - } + }) cfg.setDefaults("/tmp") env := env.NewFromMap(map[string]string{ @@ -116,13 +118,14 @@ func TestConfig_configureProvidersWithOverride(t *testing.T) { resolver := NewEnvironmentVariableResolver(env) err := cfg.configureProviders(env, resolver, knownProviders) assert.NoError(t, err) - assert.Len(t, cfg.Providers, 1) + assert.Equal(t, 1, cfg.Providers.Len()) // We want to make sure that we keep the configured API key as a placeholder - assert.Equal(t, "xyz", cfg.Providers["openai"].APIKey) - assert.Equal(t, "https://api.openai.com/v2", cfg.Providers["openai"].BaseURL) - assert.Len(t, cfg.Providers["openai"].Models, 2) - assert.Equal(t, "Updated", cfg.Providers["openai"].Models[0].Model) + pc, _ := cfg.Providers.Get("openai") + assert.Equal(t, "xyz", pc.APIKey) + assert.Equal(t, "https://api.openai.com/v2", pc.BaseURL) + assert.Len(t, pc.Models, 2) + assert.Equal(t, "Updated", pc.Models[0].Model) } func TestConfig_configureProvidersWithNewProvider(t *testing.T) { @@ -138,7 +141,7 @@ func TestConfig_configureProvidersWithNewProvider(t *testing.T) { } cfg := &Config{ - Providers: map[string]ProviderConfig{ + Providers: csync.NewMapFrom(map[string]ProviderConfig{ "custom": { APIKey: "xyz", BaseURL: "https://api.someendpoint.com/v2", @@ -148,7 +151,7 @@ func TestConfig_configureProvidersWithNewProvider(t *testing.T) { }, }, }, - }, + }), } cfg.setDefaults("/tmp") env := env.NewFromMap(map[string]string{ @@ -158,16 +161,17 @@ func TestConfig_configureProvidersWithNewProvider(t *testing.T) { err := cfg.configureProviders(env, resolver, knownProviders) assert.NoError(t, err) // Should be to because of the env variable - assert.Len(t, cfg.Providers, 2) + assert.Equal(t, cfg.Providers.Len(), 2) // We want to make sure that we keep the configured API key as a placeholder - assert.Equal(t, "xyz", cfg.Providers["custom"].APIKey) + pc, _ := cfg.Providers.Get("custom") + assert.Equal(t, "xyz", pc.APIKey) // Make sure we set the ID correctly - assert.Equal(t, "custom", cfg.Providers["custom"].ID) - assert.Equal(t, "https://api.someendpoint.com/v2", cfg.Providers["custom"].BaseURL) - assert.Len(t, cfg.Providers["custom"].Models, 1) + assert.Equal(t, "custom", pc.ID) + assert.Equal(t, "https://api.someendpoint.com/v2", pc.BaseURL) + assert.Len(t, pc.Models, 1) - _, ok := cfg.Providers["openai"] + _, ok := cfg.Providers.Get("openai") assert.True(t, ok, "OpenAI provider should still be present") } @@ -192,9 +196,9 @@ func TestConfig_configureProvidersBedrockWithCredentials(t *testing.T) { resolver := NewEnvironmentVariableResolver(env) err := cfg.configureProviders(env, resolver, knownProviders) assert.NoError(t, err) - assert.Len(t, cfg.Providers, 1) + assert.Equal(t, cfg.Providers.Len(), 1) - bedrockProvider, ok := cfg.Providers["bedrock"] + bedrockProvider, ok := cfg.Providers.Get("bedrock") assert.True(t, ok, "Bedrock provider should be present") assert.Len(t, bedrockProvider.Models, 1) assert.Equal(t, "anthropic.claude-sonnet-4-20250514-v1:0", bedrockProvider.Models[0].ID) @@ -219,7 +223,7 @@ func TestConfig_configureProvidersBedrockWithoutCredentials(t *testing.T) { err := cfg.configureProviders(env, resolver, knownProviders) assert.NoError(t, err) // Provider should not be configured without credentials - assert.Len(t, cfg.Providers, 0) + assert.Equal(t, cfg.Providers.Len(), 0) } func TestConfig_configureProvidersBedrockWithoutUnsupportedModel(t *testing.T) { @@ -267,9 +271,9 @@ func TestConfig_configureProvidersVertexAIWithCredentials(t *testing.T) { resolver := NewEnvironmentVariableResolver(env) err := cfg.configureProviders(env, resolver, knownProviders) assert.NoError(t, err) - assert.Len(t, cfg.Providers, 1) + assert.Equal(t, cfg.Providers.Len(), 1) - vertexProvider, ok := cfg.Providers["vertexai"] + vertexProvider, ok := cfg.Providers.Get("vertexai") assert.True(t, ok, "VertexAI provider should be present") assert.Len(t, vertexProvider.Models, 1) assert.Equal(t, "gemini-pro", vertexProvider.Models[0].ID) @@ -300,7 +304,7 @@ func TestConfig_configureProvidersVertexAIWithoutCredentials(t *testing.T) { err := cfg.configureProviders(env, resolver, knownProviders) assert.NoError(t, err) // Provider should not be configured without proper credentials - assert.Len(t, cfg.Providers, 0) + assert.Equal(t, cfg.Providers.Len(), 0) } func TestConfig_configureProvidersVertexAIMissingProject(t *testing.T) { @@ -325,7 +329,7 @@ func TestConfig_configureProvidersVertexAIMissingProject(t *testing.T) { err := cfg.configureProviders(env, resolver, knownProviders) assert.NoError(t, err) // Provider should not be configured without project - assert.Len(t, cfg.Providers, 0) + assert.Equal(t, cfg.Providers.Len(), 0) } func TestConfig_configureProvidersSetProviderID(t *testing.T) { @@ -348,16 +352,17 @@ func TestConfig_configureProvidersSetProviderID(t *testing.T) { resolver := NewEnvironmentVariableResolver(env) err := cfg.configureProviders(env, resolver, knownProviders) assert.NoError(t, err) - assert.Len(t, cfg.Providers, 1) + assert.Equal(t, cfg.Providers.Len(), 1) // Provider ID should be set - assert.Equal(t, "openai", cfg.Providers["openai"].ID) + pc, _ := cfg.Providers.Get("openai") + assert.Equal(t, "openai", pc.ID) } func TestConfig_EnabledProviders(t *testing.T) { t.Run("all providers enabled", func(t *testing.T) { cfg := &Config{ - Providers: map[string]ProviderConfig{ + Providers: csync.NewMapFrom(map[string]ProviderConfig{ "openai": { ID: "openai", APIKey: "key1", @@ -368,7 +373,7 @@ func TestConfig_EnabledProviders(t *testing.T) { APIKey: "key2", Disable: false, }, - }, + }), } enabled := cfg.EnabledProviders() @@ -377,7 +382,7 @@ func TestConfig_EnabledProviders(t *testing.T) { t.Run("some providers disabled", func(t *testing.T) { cfg := &Config{ - Providers: map[string]ProviderConfig{ + Providers: csync.NewMapFrom(map[string]ProviderConfig{ "openai": { ID: "openai", APIKey: "key1", @@ -388,7 +393,7 @@ func TestConfig_EnabledProviders(t *testing.T) { APIKey: "key2", Disable: true, }, - }, + }), } enabled := cfg.EnabledProviders() @@ -398,7 +403,7 @@ func TestConfig_EnabledProviders(t *testing.T) { t.Run("empty providers map", func(t *testing.T) { cfg := &Config{ - Providers: map[string]ProviderConfig{}, + Providers: csync.NewMap[string, ProviderConfig](), } enabled := cfg.EnabledProviders() @@ -409,13 +414,13 @@ func TestConfig_EnabledProviders(t *testing.T) { func TestConfig_IsConfigured(t *testing.T) { t.Run("returns true when at least one provider is enabled", func(t *testing.T) { cfg := &Config{ - Providers: map[string]ProviderConfig{ + Providers: csync.NewMapFrom(map[string]ProviderConfig{ "openai": { ID: "openai", APIKey: "key1", Disable: false, }, - }, + }), } assert.True(t, cfg.IsConfigured()) @@ -423,7 +428,7 @@ func TestConfig_IsConfigured(t *testing.T) { t.Run("returns false when no providers are configured", func(t *testing.T) { cfg := &Config{ - Providers: map[string]ProviderConfig{}, + Providers: csync.NewMap[string, ProviderConfig](), } assert.False(t, cfg.IsConfigured()) @@ -431,7 +436,7 @@ func TestConfig_IsConfigured(t *testing.T) { t.Run("returns false when all providers are disabled", func(t *testing.T) { cfg := &Config{ - Providers: map[string]ProviderConfig{ + Providers: csync.NewMapFrom(map[string]ProviderConfig{ "openai": { ID: "openai", APIKey: "key1", @@ -442,7 +447,7 @@ func TestConfig_IsConfigured(t *testing.T) { APIKey: "key2", Disable: true, }, - }, + }), } assert.False(t, cfg.IsConfigured()) @@ -462,11 +467,11 @@ func TestConfig_configureProvidersWithDisabledProvider(t *testing.T) { } cfg := &Config{ - Providers: map[string]ProviderConfig{ + Providers: csync.NewMapFrom(map[string]ProviderConfig{ "openai": { Disable: true, }, - }, + }), } cfg.setDefaults("/tmp") @@ -478,15 +483,15 @@ func TestConfig_configureProvidersWithDisabledProvider(t *testing.T) { assert.NoError(t, err) // Provider should be removed from config when disabled - assert.Len(t, cfg.Providers, 0) - _, exists := cfg.Providers["openai"] + assert.Equal(t, cfg.Providers.Len(), 0) + _, exists := cfg.Providers.Get("openai") assert.False(t, exists) } func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { t.Run("custom provider with missing API key is allowed, but not known providers", func(t *testing.T) { cfg := &Config{ - Providers: map[string]ProviderConfig{ + Providers: csync.NewMapFrom(map[string]ProviderConfig{ "custom": { BaseURL: "https://api.custom.com/v1", Models: []provider.Model{{ @@ -496,7 +501,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { "openai": { APIKey: "$MISSING", }, - }, + }), } cfg.setDefaults("/tmp") @@ -505,21 +510,21 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { err := cfg.configureProviders(env, resolver, []provider.Provider{}) assert.NoError(t, err) - assert.Len(t, cfg.Providers, 1) - _, exists := cfg.Providers["custom"] + assert.Equal(t, cfg.Providers.Len(), 1) + _, exists := cfg.Providers.Get("custom") assert.True(t, exists) }) t.Run("custom provider with missing BaseURL is removed", func(t *testing.T) { cfg := &Config{ - Providers: map[string]ProviderConfig{ + Providers: csync.NewMapFrom(map[string]ProviderConfig{ "custom": { APIKey: "test-key", Models: []provider.Model{{ ID: "test-model", }}, }, - }, + }), } cfg.setDefaults("/tmp") @@ -528,20 +533,20 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { err := cfg.configureProviders(env, resolver, []provider.Provider{}) assert.NoError(t, err) - assert.Len(t, cfg.Providers, 0) - _, exists := cfg.Providers["custom"] + assert.Equal(t, cfg.Providers.Len(), 0) + _, exists := cfg.Providers.Get("custom") assert.False(t, exists) }) t.Run("custom provider with no models is removed", func(t *testing.T) { cfg := &Config{ - Providers: map[string]ProviderConfig{ + Providers: csync.NewMapFrom(map[string]ProviderConfig{ "custom": { APIKey: "test-key", BaseURL: "https://api.custom.com/v1", Models: []provider.Model{}, }, - }, + }), } cfg.setDefaults("/tmp") @@ -550,14 +555,14 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { err := cfg.configureProviders(env, resolver, []provider.Provider{}) assert.NoError(t, err) - assert.Len(t, cfg.Providers, 0) - _, exists := cfg.Providers["custom"] + assert.Equal(t, cfg.Providers.Len(), 0) + _, exists := cfg.Providers.Get("custom") assert.False(t, exists) }) t.Run("custom provider with unsupported type is removed", func(t *testing.T) { cfg := &Config{ - Providers: map[string]ProviderConfig{ + Providers: csync.NewMapFrom(map[string]ProviderConfig{ "custom": { APIKey: "test-key", BaseURL: "https://api.custom.com/v1", @@ -566,7 +571,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { ID: "test-model", }}, }, - }, + }), } cfg.setDefaults("/tmp") @@ -575,14 +580,14 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { err := cfg.configureProviders(env, resolver, []provider.Provider{}) assert.NoError(t, err) - assert.Len(t, cfg.Providers, 0) - _, exists := cfg.Providers["custom"] + assert.Equal(t, cfg.Providers.Len(), 0) + _, exists := cfg.Providers.Get("custom") assert.False(t, exists) }) t.Run("valid custom provider is kept and ID is set", func(t *testing.T) { cfg := &Config{ - Providers: map[string]ProviderConfig{ + Providers: csync.NewMapFrom(map[string]ProviderConfig{ "custom": { APIKey: "test-key", BaseURL: "https://api.custom.com/v1", @@ -591,7 +596,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { ID: "test-model", }}, }, - }, + }), } cfg.setDefaults("/tmp") @@ -600,8 +605,8 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { err := cfg.configureProviders(env, resolver, []provider.Provider{}) assert.NoError(t, err) - assert.Len(t, cfg.Providers, 1) - customProvider, exists := cfg.Providers["custom"] + assert.Equal(t, cfg.Providers.Len(), 1) + customProvider, exists := cfg.Providers.Get("custom") assert.True(t, exists) assert.Equal(t, "custom", customProvider.ID) assert.Equal(t, "test-key", customProvider.APIKey) @@ -610,7 +615,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { t.Run("disabled custom provider is removed", func(t *testing.T) { cfg := &Config{ - Providers: map[string]ProviderConfig{ + Providers: csync.NewMapFrom(map[string]ProviderConfig{ "custom": { APIKey: "test-key", BaseURL: "https://api.custom.com/v1", @@ -620,7 +625,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { ID: "test-model", }}, }, - }, + }), } cfg.setDefaults("/tmp") @@ -629,8 +634,8 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { err := cfg.configureProviders(env, resolver, []provider.Provider{}) assert.NoError(t, err) - assert.Len(t, cfg.Providers, 0) - _, exists := cfg.Providers["custom"] + assert.Equal(t, cfg.Providers.Len(), 0) + _, exists := cfg.Providers.Get("custom") assert.False(t, exists) }) } @@ -649,11 +654,11 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) { } cfg := &Config{ - Providers: map[string]ProviderConfig{ + Providers: csync.NewMapFrom(map[string]ProviderConfig{ "vertexai": { BaseURL: "custom-url", }, - }, + }), } cfg.setDefaults("/tmp") @@ -664,8 +669,8 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) { err := cfg.configureProviders(env, resolver, knownProviders) assert.NoError(t, err) - assert.Len(t, cfg.Providers, 0) - _, exists := cfg.Providers["vertexai"] + assert.Equal(t, cfg.Providers.Len(), 0) + _, exists := cfg.Providers.Get("vertexai") assert.False(t, exists) }) @@ -682,11 +687,11 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) { } cfg := &Config{ - Providers: map[string]ProviderConfig{ + Providers: csync.NewMapFrom(map[string]ProviderConfig{ "bedrock": { BaseURL: "custom-url", }, - }, + }), } cfg.setDefaults("/tmp") @@ -695,8 +700,8 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) { err := cfg.configureProviders(env, resolver, knownProviders) assert.NoError(t, err) - assert.Len(t, cfg.Providers, 0) - _, exists := cfg.Providers["bedrock"] + assert.Equal(t, cfg.Providers.Len(), 0) + _, exists := cfg.Providers.Get("bedrock") assert.False(t, exists) }) @@ -713,11 +718,11 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) { } cfg := &Config{ - Providers: map[string]ProviderConfig{ + Providers: csync.NewMapFrom(map[string]ProviderConfig{ "openai": { BaseURL: "custom-url", }, - }, + }), } cfg.setDefaults("/tmp") @@ -726,8 +731,8 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) { err := cfg.configureProviders(env, resolver, knownProviders) assert.NoError(t, err) - assert.Len(t, cfg.Providers, 0) - _, exists := cfg.Providers["openai"] + assert.Equal(t, cfg.Providers.Len(), 0) + _, exists := cfg.Providers.Get("openai") assert.False(t, exists) }) @@ -744,11 +749,11 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) { } cfg := &Config{ - Providers: map[string]ProviderConfig{ + Providers: csync.NewMapFrom(map[string]ProviderConfig{ "openai": { APIKey: "test-key", }, - }, + }), } cfg.setDefaults("/tmp") @@ -759,8 +764,8 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) { err := cfg.configureProviders(env, resolver, knownProviders) assert.NoError(t, err) - assert.Len(t, cfg.Providers, 1) - _, exists := cfg.Providers["openai"] + assert.Equal(t, cfg.Providers.Len(), 1) + _, exists := cfg.Providers.Get("openai") assert.True(t, exists) }) } @@ -883,7 +888,7 @@ func TestConfig_defaultModelSelection(t *testing.T) { } cfg := &Config{ - Providers: map[string]ProviderConfig{ + Providers: csync.NewMapFrom(map[string]ProviderConfig{ "custom": { APIKey: "test-key", BaseURL: "https://api.custom.com/v1", @@ -894,7 +899,7 @@ func TestConfig_defaultModelSelection(t *testing.T) { }, }, }, - }, + }), } cfg.setDefaults("/tmp") env := env.NewFromMap(map[string]string{}) @@ -932,13 +937,13 @@ func TestConfig_defaultModelSelection(t *testing.T) { } cfg := &Config{ - Providers: map[string]ProviderConfig{ + Providers: csync.NewMapFrom(map[string]ProviderConfig{ "custom": { APIKey: "test-key", BaseURL: "https://api.custom.com/v1", Models: []provider.Model{}, }, - }, + }), } cfg.setDefaults("/tmp") env := env.NewFromMap(map[string]string{}) @@ -969,7 +974,7 @@ func TestConfig_defaultModelSelection(t *testing.T) { } cfg := &Config{ - Providers: map[string]ProviderConfig{ + Providers: csync.NewMapFrom(map[string]ProviderConfig{ "custom": { APIKey: "test-key", BaseURL: "https://api.custom.com/v1", @@ -980,7 +985,7 @@ func TestConfig_defaultModelSelection(t *testing.T) { }, }, }, - }, + }), } cfg.setDefaults("/tmp") env := env.NewFromMap(map[string]string{}) diff --git a/internal/csync/maps.go b/internal/csync/maps.go new file mode 100644 index 0000000000000000000000000000000000000000..69b56050d45b13abd189d3eb2da75120fe13589f --- /dev/null +++ b/internal/csync/maps.go @@ -0,0 +1,84 @@ +package csync + +import ( + "encoding/json" + "iter" + "maps" + "sync" +) + +type Map[K comparable, V any] struct { + inner map[K]V + mu sync.RWMutex +} + +func NewMap[K comparable, V any]() *Map[K, V] { + return &Map[K, V]{ + inner: make(map[K]V), + } +} + +func NewMapFrom[K comparable, V any](m map[K]V) *Map[K, V] { + return &Map[K, V]{ + inner: m, + } +} + +func (m *Map[K, V]) Set(key K, value V) { + m.mu.Lock() + defer m.mu.Unlock() + m.inner[key] = value +} + +func (m *Map[K, V]) Del(key K) { + m.mu.Lock() + defer m.mu.Unlock() + delete(m.inner, key) +} + +func (m *Map[K, V]) Get(key K) (V, bool) { + m.mu.RLock() + defer m.mu.RUnlock() + v, ok := m.inner[key] + return v, ok +} + +func (m *Map[K, V]) Len() int { + m.mu.RLock() + defer m.mu.RUnlock() + return len(m.inner) +} + +func (m *Map[K, V]) Seq2() iter.Seq2[K, V] { + dst := make(map[K]V) + m.mu.RLock() + maps.Copy(dst, m.inner) + m.mu.RUnlock() + return func(yield func(K, V) bool) { + for k, v := range dst { + if !yield(k, v) { + return + } + } + } +} + +var ( + _ json.Unmarshaler = &Map[string, any]{} + _ json.Marshaler = &Map[string, any]{} +) + +// UnmarshalJSON implements json.Unmarshaler. +func (m *Map[K, V]) UnmarshalJSON(data []byte) error { + m.mu.Lock() + defer m.mu.Unlock() + m.inner = make(map[K]V) + return json.Unmarshal(data, &m.inner) +} + +// MarshalJSON implements json.Marshaler. +func (m *Map[K, V]) MarshalJSON() ([]byte, error) { + m.mu.RLock() + defer m.mu.RUnlock() + return json.Marshal(m.inner) +} diff --git a/internal/csync/maps_test.go b/internal/csync/maps_test.go new file mode 100644 index 0000000000000000000000000000000000000000..5eddd92ce201f12d5f59620817a9e04c4e2f3008 --- /dev/null +++ b/internal/csync/maps_test.go @@ -0,0 +1,450 @@ +package csync + +import ( + "encoding/json" + "maps" + "sync" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestNewMap(t *testing.T) { + t.Parallel() + + m := NewMap[string, int]() + assert.NotNil(t, m) + assert.NotNil(t, m.inner) + assert.Equal(t, 0, m.Len()) +} + +func TestNewMapFrom(t *testing.T) { + t.Parallel() + + original := map[string]int{ + "key1": 1, + "key2": 2, + } + + m := NewMapFrom(original) + assert.NotNil(t, m) + assert.Equal(t, original, m.inner) + assert.Equal(t, 2, m.Len()) + + value, ok := m.Get("key1") + assert.True(t, ok) + assert.Equal(t, 1, value) +} + +func TestMap_Set(t *testing.T) { + t.Parallel() + + m := NewMap[string, int]() + + m.Set("key1", 42) + value, ok := m.Get("key1") + assert.True(t, ok) + assert.Equal(t, 42, value) + assert.Equal(t, 1, m.Len()) + + m.Set("key1", 100) + value, ok = m.Get("key1") + assert.True(t, ok) + assert.Equal(t, 100, value) + assert.Equal(t, 1, m.Len()) +} + +func TestMap_Get(t *testing.T) { + t.Parallel() + + m := NewMap[string, int]() + + value, ok := m.Get("nonexistent") + assert.False(t, ok) + assert.Equal(t, 0, value) + + m.Set("key1", 42) + value, ok = m.Get("key1") + assert.True(t, ok) + assert.Equal(t, 42, value) +} + +func TestMap_Del(t *testing.T) { + t.Parallel() + + m := NewMap[string, int]() + m.Set("key1", 42) + m.Set("key2", 100) + + assert.Equal(t, 2, m.Len()) + + m.Del("key1") + _, ok := m.Get("key1") + assert.False(t, ok) + assert.Equal(t, 1, m.Len()) + + value, ok := m.Get("key2") + assert.True(t, ok) + assert.Equal(t, 100, value) + + m.Del("nonexistent") + assert.Equal(t, 1, m.Len()) +} + +func TestMap_Len(t *testing.T) { + t.Parallel() + + m := NewMap[string, int]() + assert.Equal(t, 0, m.Len()) + + m.Set("key1", 1) + assert.Equal(t, 1, m.Len()) + + m.Set("key2", 2) + assert.Equal(t, 2, m.Len()) + + m.Del("key1") + assert.Equal(t, 1, m.Len()) + + m.Del("key2") + assert.Equal(t, 0, m.Len()) +} + +func TestMap_Seq2(t *testing.T) { + t.Parallel() + + m := NewMap[string, int]() + m.Set("key1", 1) + m.Set("key2", 2) + m.Set("key3", 3) + + collected := maps.Collect(m.Seq2()) + + assert.Equal(t, 3, len(collected)) + assert.Equal(t, 1, collected["key1"]) + assert.Equal(t, 2, collected["key2"]) + assert.Equal(t, 3, collected["key3"]) +} + +func TestMap_Seq2_EarlyReturn(t *testing.T) { + t.Parallel() + + m := NewMap[string, int]() + m.Set("key1", 1) + m.Set("key2", 2) + m.Set("key3", 3) + + count := 0 + for range m.Seq2() { + count++ + if count == 2 { + break + } + } + + assert.Equal(t, 2, count) +} + +func TestMap_Seq2_EmptyMap(t *testing.T) { + t.Parallel() + + m := NewMap[string, int]() + + count := 0 + for range m.Seq2() { + count++ + } + + assert.Equal(t, 0, count) +} + +func TestMap_MarshalJSON(t *testing.T) { + t.Parallel() + + m := NewMap[string, int]() + m.Set("key1", 1) + m.Set("key2", 2) + + data, err := json.Marshal(m) + assert.NoError(t, err) + + var result map[string]int + err = json.Unmarshal(data, &result) + assert.NoError(t, err) + assert.Equal(t, 2, len(result)) + assert.Equal(t, 1, result["key1"]) + assert.Equal(t, 2, result["key2"]) +} + +func TestMap_MarshalJSON_EmptyMap(t *testing.T) { + t.Parallel() + + m := NewMap[string, int]() + + data, err := json.Marshal(m) + assert.NoError(t, err) + assert.Equal(t, "{}", string(data)) +} + +func TestMap_UnmarshalJSON(t *testing.T) { + t.Parallel() + + jsonData := `{"key1": 1, "key2": 2}` + + m := NewMap[string, int]() + err := json.Unmarshal([]byte(jsonData), m) + assert.NoError(t, err) + + assert.Equal(t, 2, m.Len()) + value, ok := m.Get("key1") + assert.True(t, ok) + assert.Equal(t, 1, value) + + value, ok = m.Get("key2") + assert.True(t, ok) + assert.Equal(t, 2, value) +} + +func TestMap_UnmarshalJSON_EmptyJSON(t *testing.T) { + t.Parallel() + + jsonData := `{}` + + m := NewMap[string, int]() + err := json.Unmarshal([]byte(jsonData), m) + assert.NoError(t, err) + assert.Equal(t, 0, m.Len()) +} + +func TestMap_UnmarshalJSON_InvalidJSON(t *testing.T) { + t.Parallel() + + jsonData := `{"key1": 1, "key2":}` + + m := NewMap[string, int]() + err := json.Unmarshal([]byte(jsonData), m) + assert.Error(t, err) +} + +func TestMap_UnmarshalJSON_OverwritesExistingData(t *testing.T) { + t.Parallel() + + m := NewMap[string, int]() + m.Set("existing", 999) + + jsonData := `{"key1": 1, "key2": 2}` + err := json.Unmarshal([]byte(jsonData), m) + assert.NoError(t, err) + + assert.Equal(t, 2, m.Len()) + _, ok := m.Get("existing") + assert.False(t, ok) + + value, ok := m.Get("key1") + assert.True(t, ok) + assert.Equal(t, 1, value) +} + +func TestMap_JSONRoundTrip(t *testing.T) { + t.Parallel() + + original := NewMap[string, int]() + original.Set("key1", 1) + original.Set("key2", 2) + original.Set("key3", 3) + + data, err := json.Marshal(original) + assert.NoError(t, err) + + restored := NewMap[string, int]() + err = json.Unmarshal(data, restored) + assert.NoError(t, err) + + assert.Equal(t, original.Len(), restored.Len()) + + for k, v := range original.Seq2() { + restoredValue, ok := restored.Get(k) + assert.True(t, ok) + assert.Equal(t, v, restoredValue) + } +} + +func TestMap_ConcurrentAccess(t *testing.T) { + t.Parallel() + + m := NewMap[int, int]() + const numGoroutines = 100 + const numOperations = 100 + + var wg sync.WaitGroup + wg.Add(numGoroutines) + + for i := range numGoroutines { + go func(id int) { + defer wg.Done() + for j := range numOperations { + key := id*numOperations + j + m.Set(key, key*2) + value, ok := m.Get(key) + assert.True(t, ok) + assert.Equal(t, key*2, value) + } + }(i) + } + + wg.Wait() + + assert.Equal(t, numGoroutines*numOperations, m.Len()) +} + +func TestMap_ConcurrentReadWrite(t *testing.T) { + t.Parallel() + + m := NewMap[int, int]() + const numReaders = 50 + const numWriters = 50 + const numOperations = 100 + + for i := range 1000 { + m.Set(i, i) + } + + var wg sync.WaitGroup + wg.Add(numReaders + numWriters) + + for range numReaders { + go func() { + defer wg.Done() + for j := range numOperations { + key := j % 1000 + value, ok := m.Get(key) + if ok { + assert.Equal(t, key, value) + } + _ = m.Len() + } + }() + } + + for i := range numWriters { + go func(id int) { + defer wg.Done() + for j := range numOperations { + key := 1000 + id*numOperations + j + m.Set(key, key) + if j%10 == 0 { + m.Del(key) + } + } + }(i) + } + + wg.Wait() +} + +func TestMap_ConcurrentSeq2(t *testing.T) { + t.Parallel() + + m := NewMap[int, int]() + for i := range 100 { + m.Set(i, i*2) + } + + var wg sync.WaitGroup + const numIterators = 10 + + wg.Add(numIterators) + for range numIterators { + go func() { + defer wg.Done() + count := 0 + for k, v := range m.Seq2() { + assert.Equal(t, k*2, v) + count++ + } + assert.Equal(t, 100, count) + }() + } + + wg.Wait() +} + +func TestMap_TypeSafety(t *testing.T) { + t.Parallel() + + stringIntMap := NewMap[string, int]() + stringIntMap.Set("key", 42) + value, ok := stringIntMap.Get("key") + assert.True(t, ok) + assert.Equal(t, 42, value) + + intStringMap := NewMap[int, string]() + intStringMap.Set(42, "value") + strValue, ok := intStringMap.Get(42) + assert.True(t, ok) + assert.Equal(t, "value", strValue) + + structMap := NewMap[string, struct{ Name string }]() + structMap.Set("key", struct{ Name string }{Name: "test"}) + structValue, ok := structMap.Get("key") + assert.True(t, ok) + assert.Equal(t, "test", structValue.Name) +} + +func TestMap_InterfaceCompliance(t *testing.T) { + t.Parallel() + + var _ json.Marshaler = &Map[string, any]{} + var _ json.Unmarshaler = &Map[string, any]{} +} + +func BenchmarkMap_Set(b *testing.B) { + m := NewMap[int, int]() + + for i := 0; b.Loop(); i++ { + m.Set(i, i*2) + } +} + +func BenchmarkMap_Get(b *testing.B) { + m := NewMap[int, int]() + for i := range 1000 { + m.Set(i, i*2) + } + + for i := 0; b.Loop(); i++ { + m.Get(i % 1000) + } +} + +func BenchmarkMap_Seq2(b *testing.B) { + m := NewMap[int, int]() + for i := range 1000 { + m.Set(i, i*2) + } + + for b.Loop() { + for range m.Seq2() { + } + } +} + +func BenchmarkMap_ConcurrentReadWrite(b *testing.B) { + m := NewMap[int, int]() + for i := range 1000 { + m.Set(i, i*2) + } + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + i := 0 + for pb.Next() { + if i%2 == 0 { + m.Get(i % 1000) + } else { + m.Set(i+1000, i*2) + } + i++ + } + }) +} diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go index e920651d0faeb87da765c4ab67735c1c2d285001..2f76cc7771e3f0383f20b4ef1dffe448e06a253c 100644 --- a/internal/llm/agent/agent.go +++ b/internal/llm/agent/agent.go @@ -8,9 +8,9 @@ import ( "slices" "strings" "sync" - "sync/atomic" "time" + "github.com/charmbracelet/crush/csync" "github.com/charmbracelet/crush/internal/config" fur "github.com/charmbracelet/crush/internal/fur/provider" "github.com/charmbracelet/crush/internal/history" @@ -68,8 +68,7 @@ type agent struct { sessions session.Service messages message.Service - toolsDone atomic.Bool - tools []tools.BaseTool + tools *csync.LazySlice[tools.BaseTool] provider provider.Provider providerID string @@ -168,24 +167,10 @@ func NewAgent( return nil, err } - agent := &agent{ - Broker: pubsub.NewBroker[AgentEvent](), - agentCfg: agentCfg, - provider: agentProvider, - providerID: string(providerCfg.ID), - messages: messages, - sessions: sessions, - titleProvider: titleProvider, - summarizeProvider: summarizeProvider, - summarizeProviderID: string(smallModelProviderCfg.ID), - activeRequests: sync.Map{}, - } - - go func() { + toolFn := func() []tools.BaseTool { slog.Info("Initializing agent tools", "agent", agentCfg.ID) defer func() { slog.Info("Initialized agent tools", "agent", agentCfg.ID) - agent.toolsDone.Store(true) }() cwd := cfg.WorkingDir() @@ -214,8 +199,7 @@ func NewAgent( } if agentCfg.AllowedTools == nil { - agent.tools = allTools - return + return allTools } var filteredTools []tools.BaseTool @@ -224,10 +208,22 @@ func NewAgent( filteredTools = append(filteredTools, tool) } } - agent.tools = filteredTools - }() + return filteredTools + } - return agent, nil + return &agent{ + Broker: pubsub.NewBroker[AgentEvent](), + agentCfg: agentCfg, + provider: agentProvider, + providerID: string(providerCfg.ID), + messages: messages, + sessions: sessions, + titleProvider: titleProvider, + summarizeProvider: summarizeProvider, + summarizeProviderID: string(smallModelProviderCfg.ID), + activeRequests: sync.Map{}, + tools: csync.NewLazySlice(toolFn), + }, nil } func (a *agent) Model() fur.Model { @@ -449,10 +445,7 @@ func (a *agent) createUserMessage(ctx context.Context, sessionID, content string func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) { ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID) - if !a.toolsDone.Load() { - return message.Message{}, nil, fmt.Errorf("agent is still initializing, please wait a moment and try again") - } - eventChan := a.provider.StreamResponse(ctx, msgHistory, a.tools) + eventChan := a.provider.StreamResponse(ctx, msgHistory, slices.Collect(a.tools.Iter())) assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{ Role: message.Assistant, @@ -501,7 +494,7 @@ func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msg default: // Continue processing var tool tools.BaseTool - for _, availableTool := range a.tools { + for availableTool := range a.tools.Iter() { if availableTool.Info().Name == toolCall.Name { tool = availableTool break @@ -911,7 +904,7 @@ func (a *agent) UpdateModel() error { smallModelCfg := cfg.Models[config.SelectedModelTypeSmall] var smallModelProviderCfg config.ProviderConfig - for _, p := range cfg.Providers { + for _, p := range cfg.Providers.Seq2() { if p.ID == smallModelCfg.Provider { smallModelProviderCfg = p break diff --git a/internal/tui/components/chat/splash/splash.go b/internal/tui/components/chat/splash/splash.go index f7a6dce4baa2c3a2798c30baa6b995f6da72d05b..2a2c47a171a7ac685d644005e61e507a3964389f 100644 --- a/internal/tui/components/chat/splash/splash.go +++ b/internal/tui/components/chat/splash/splash.go @@ -422,7 +422,7 @@ func (s *splashCmp) getProvider(providerID provider.InferenceProvider) (*provide func (s *splashCmp) isProviderConfigured(providerID string) bool { cfg := config.Get() - if _, ok := cfg.Providers[providerID]; ok { + if _, ok := cfg.Providers.Get(providerID); ok { return true } return false diff --git a/internal/tui/components/dialogs/models/list.go b/internal/tui/components/dialogs/models/list.go index 86b1b9a3fa0b4b6faa56a927a9011673aa8365af..5f558364eec801d77a250c891a80110e0c9a3b86 100644 --- a/internal/tui/components/dialogs/models/list.go +++ b/internal/tui/components/dialogs/models/list.go @@ -103,7 +103,7 @@ func (m *ModelListComponent) SetModelType(modelType int) tea.Cmd { if err != nil { return util.ReportError(err) } - for providerID, providerConfig := range cfg.Providers { + for providerID, providerConfig := range cfg.Providers.Seq2() { if providerConfig.Disable { continue } @@ -164,7 +164,7 @@ func (m *ModelListComponent) SetModelType(modelType int) tea.Cmd { } // Check if this provider is configured and not disabled - if providerConfig, exists := cfg.Providers[string(provider.ID)]; exists && providerConfig.Disable { + if providerConfig, exists := cfg.Providers.Get(string(provider.ID)); exists && providerConfig.Disable { continue } @@ -174,7 +174,7 @@ func (m *ModelListComponent) SetModelType(modelType int) tea.Cmd { } section := commands.NewItemSection(name) - if _, ok := cfg.Providers[string(provider.ID)]; ok { + if _, ok := cfg.Providers.Get(string(provider.ID)); ok { section.SetInfo(configured) } modelItems = append(modelItems, section) diff --git a/internal/tui/components/dialogs/models/models.go b/internal/tui/components/dialogs/models/models.go index b28efc6010582a503c34e87ad101832925d8acca..b53388d16f17bbae8612cc66d1525e3e0e616db5 100644 --- a/internal/tui/components/dialogs/models/models.go +++ b/internal/tui/components/dialogs/models/models.go @@ -357,7 +357,7 @@ func (m *modelDialogCmp) modelTypeRadio() string { func (m *modelDialogCmp) isProviderConfigured(providerID string) bool { cfg := config.Get() - if _, ok := cfg.Providers[providerID]; ok { + if _, ok := cfg.Providers.Get(providerID); ok { return true } return false From 7f10a030744571179aa197cbe87323009e72e7d5 Mon Sep 17 00:00:00 2001 From: Carlos Alexandro Becker Date: Tue, 22 Jul 2025 16:56:42 -0300 Subject: [PATCH 08/21] fix: improvements --- internal/config/load.go | 65 +++++++++++++++++++++-------------------- 1 file changed, 33 insertions(+), 32 deletions(-) diff --git a/internal/config/load.go b/internal/config/load.go index 09d65e5391b94a1f80b15e7e576ba5d3e38ef19d..c05861a2431303591274568502ae090fde71a4dd 100644 --- a/internal/config/load.go +++ b/internal/config/load.go @@ -77,38 +77,7 @@ func Load(workingDir string, debug bool) (*Config, error) { return nil, fmt.Errorf("failed to configure providers: %w", err) } - // Test provider connections in parallel - var testResults sync.Map - var wg sync.WaitGroup - - go func() { - slog.Info("Testing provider connections") - defer slog.Info("Provider connection tests completed") - for _, p := range cfg.Providers.Seq2() { - if p.Type == provider.TypeOpenAI || p.Type == provider.TypeAnthropic { - wg.Add(1) - go func(provider ProviderConfig) { - defer wg.Done() - err := provider.TestConnection(cfg.resolver) - testResults.Store(provider.ID, err == nil) - if err != nil { - slog.Error("Provider connection test failed", "provider", provider.ID, "error", err) - } - }(p) - } - } - wg.Wait() - - // Remove failed providers - testResults.Range(func(key, value any) bool { - providerID := key.(string) - passed := value.(bool) - if !passed { - cfg.Providers.Del(providerID) - } - return true - }) - }() + go cfg.removeUnresponsiveProviders() if !cfg.IsConfigured() { slog.Warn("No providers configured") @@ -122,6 +91,38 @@ func Load(workingDir string, debug bool) (*Config, error) { return cfg, nil } +func (c *Config) removeUnresponsiveProviders() { + // Test provider connections in parallel + var testResults sync.Map + var wg sync.WaitGroup + slog.Info("Testing provider connections") + defer slog.Info("Provider connection tests completed") + for _, p := range c.Providers.Seq2() { + if p.Type == provider.TypeOpenAI || p.Type == provider.TypeAnthropic { + wg.Add(1) + go func(provider ProviderConfig) { + defer wg.Done() + err := provider.TestConnection(c.resolver) + testResults.Store(provider.ID, err == nil) + if err != nil { + slog.Error("Provider connection test failed", "provider", provider.ID, "error", err) + } + }(p) + } + } + wg.Wait() + + // Remove failed providers + testResults.Range(func(key, value any) bool { + providerID := key.(string) + passed := value.(bool) + if !passed { + c.Providers.Del(providerID) + } + return true + }) +} + func (c *Config) configureProviders(env env.Env, resolver VariableResolver, knownProviders []provider.Provider) error { knownProviderNames := make(map[string]bool) for _, p := range knownProviders { From 9f66b30091725d2088ac7a9496117ad7fae31553 Mon Sep 17 00:00:00 2001 From: Carlos Alexandro Becker Date: Tue, 22 Jul 2025 20:35:39 -0300 Subject: [PATCH 09/21] fix: pkg --- {csync => internal/csync}/slices.go | 0 {csync => internal/csync}/slices_test.go | 0 internal/llm/agent/agent.go | 2 +- 3 files changed, 1 insertion(+), 1 deletion(-) rename {csync => internal/csync}/slices.go (100%) rename {csync => internal/csync}/slices_test.go (100%) diff --git a/csync/slices.go b/internal/csync/slices.go similarity index 100% rename from csync/slices.go rename to internal/csync/slices.go diff --git a/csync/slices_test.go b/internal/csync/slices_test.go similarity index 100% rename from csync/slices_test.go rename to internal/csync/slices_test.go diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go index 2f76cc7771e3f0383f20b4ef1dffe448e06a253c..02961969f8d16ed316f91132c944be1ea2311f48 100644 --- a/internal/llm/agent/agent.go +++ b/internal/llm/agent/agent.go @@ -10,8 +10,8 @@ import ( "sync" "time" - "github.com/charmbracelet/crush/csync" "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/csync" fur "github.com/charmbracelet/crush/internal/fur/provider" "github.com/charmbracelet/crush/internal/history" "github.com/charmbracelet/crush/internal/llm/prompt" From c3cde2f933b489cdcfbbe875e9e8214b51cbeefb Mon Sep 17 00:00:00 2001 From: Carlos Alexandro Becker Date: Tue, 22 Jul 2025 20:36:11 -0300 Subject: [PATCH 10/21] fix: todo --- internal/config/load.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/config/load.go b/internal/config/load.go index c05861a2431303591274568502ae090fde71a4dd..044ef504859bcbcc051b93322099f6d03b1fa601 100644 --- a/internal/config/load.go +++ b/internal/config/load.go @@ -322,7 +322,7 @@ func (c *Config) setDefaults(workingDir string) { } func (c *Config) defaultModelSelection(knownProviders []provider.Provider) (largeModel SelectedModel, smallModel SelectedModel, err error) { - if len(knownProviders) == 0 { // TODO:}&& len(c.Providers) == 0 { + if len(knownProviders) == 0 && c.Providers.Len() == 0 { err = fmt.Errorf("no providers configured, please configure at least one provider") return } From a6f21b2794162b316c6fb4c18725e01524cfeaa1 Mon Sep 17 00:00:00 2001 From: Carlos Alexandro Becker Date: Tue, 22 Jul 2025 20:38:44 -0300 Subject: [PATCH 11/21] fix: test --- internal/csync/maps_test.go | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/internal/csync/maps_test.go b/internal/csync/maps_test.go index 5eddd92ce201f12d5f59620817a9e04c4e2f3008..73e6f1db245231e9fad82103366d96a326acc4f6 100644 --- a/internal/csync/maps_test.go +++ b/internal/csync/maps_test.go @@ -168,12 +168,14 @@ func TestMap_MarshalJSON(t *testing.T) { data, err := json.Marshal(m) assert.NoError(t, err) - var result map[string]int - err = json.Unmarshal(data, &result) + result := &Map[string, int]{} + err = json.Unmarshal(data, result) assert.NoError(t, err) - assert.Equal(t, 2, len(result)) - assert.Equal(t, 1, result["key1"]) - assert.Equal(t, 2, result["key2"]) + assert.Equal(t, 2, result.Len()) + v1, _ := result.Get("key1") + v2, _ := result.Get("key2") + assert.Equal(t, 1, v1) + assert.Equal(t, 2, v2) } func TestMap_MarshalJSON_EmptyMap(t *testing.T) { From 7e7a69fa6b86d5eaac25b9c0fe5a33c3e805b843 Mon Sep 17 00:00:00 2001 From: Carlos Alexandro Becker Date: Tue, 22 Jul 2025 20:42:38 -0300 Subject: [PATCH 12/21] docs: godoc --- internal/csync/maps.go | 8 ++++++++ internal/csync/slices.go | 4 ++++ 2 files changed, 12 insertions(+) diff --git a/internal/csync/maps.go b/internal/csync/maps.go index 69b56050d45b13abd189d3eb2da75120fe13589f..45e426630a4e50b45125d41dcca54d4e183b4f6f 100644 --- a/internal/csync/maps.go +++ b/internal/csync/maps.go @@ -7,35 +7,41 @@ import ( "sync" ) +// Map is a concurrent map implementation that provides thread-safe access. type Map[K comparable, V any] struct { inner map[K]V mu sync.RWMutex } +// NewMap creates a new thread-safe map with the specified key and value types. func NewMap[K comparable, V any]() *Map[K, V] { return &Map[K, V]{ inner: make(map[K]V), } } +// NewMapFrom creates a new thread-safe map from an existing map. func NewMapFrom[K comparable, V any](m map[K]V) *Map[K, V] { return &Map[K, V]{ inner: m, } } +// Set sets the value for the specified key in the map. func (m *Map[K, V]) Set(key K, value V) { m.mu.Lock() defer m.mu.Unlock() m.inner[key] = value } +// Del deletes the specified key from the map. func (m *Map[K, V]) Del(key K) { m.mu.Lock() defer m.mu.Unlock() delete(m.inner, key) } +// Get gets the value for the specified key from the map. func (m *Map[K, V]) Get(key K) (V, bool) { m.mu.RLock() defer m.mu.RUnlock() @@ -43,12 +49,14 @@ func (m *Map[K, V]) Get(key K) (V, bool) { return v, ok } +// Len returns the number of items in the map. func (m *Map[K, V]) Len() int { m.mu.RLock() defer m.mu.RUnlock() return len(m.inner) } +// Seq2 returns an iter.Seq2 that yields key-value pairs from the map. func (m *Map[K, V]) Seq2() iter.Seq2[K, V] { dst := make(map[K]V) m.mu.RLock() diff --git a/internal/csync/slices.go b/internal/csync/slices.go index 388ad074d53a9bd7188418b231afbf39adca0565..2ce448ef2a1371ecd2e36f955cd1096140cd798a 100644 --- a/internal/csync/slices.go +++ b/internal/csync/slices.go @@ -5,11 +5,14 @@ import ( "sync" ) +// LazySlice is a thread-safe lazy-loaded slice. type LazySlice[K any] struct { inner []K mu sync.Mutex } +// NewLazySlice creates a new slice and runs the [load] function in a goroutine +// to populate it. func NewLazySlice[K any](load func() []K) *LazySlice[K] { s := &LazySlice[K]{} s.mu.Lock() @@ -20,6 +23,7 @@ func NewLazySlice[K any](load func() []K) *LazySlice[K] { return s } +// Iter returns an iterator that yields elements from the slice. func (s *LazySlice[K]) Iter() iter.Seq[K] { s.mu.Lock() inner := s.inner From af1514b627e5d130690739a0a3b09c696d9ef68f Mon Sep 17 00:00:00 2001 From: Carlos Alexandro Becker Date: Tue, 22 Jul 2025 20:43:08 -0300 Subject: [PATCH 13/21] docs: todo --- internal/csync/slices_test.go | 1 + 1 file changed, 1 insertion(+) diff --git a/internal/csync/slices_test.go b/internal/csync/slices_test.go index d1c7af8cf30f3d58a84046f899f8dd89f80beb51..31f70d73c4b1c975592c17cf41cf0d57b0050b43 100644 --- a/internal/csync/slices_test.go +++ b/internal/csync/slices_test.go @@ -70,6 +70,7 @@ func TestLazySlice_EarlyBreak(t *testing.T) { data := []string{"a", "b", "c", "d", "e"} s := NewLazySlice(func() []string { + // TODO: use synctest when new Go is out. time.Sleep(10 * time.Millisecond) // Small delay to ensure loading happens return data }) From 26be55f6f678afeb8ca2a4510d2437e3c3a872db Mon Sep 17 00:00:00 2001 From: Carlos Alexandro Becker Date: Tue, 22 Jul 2025 20:44:41 -0300 Subject: [PATCH 14/21] fix: method name --- internal/csync/slices.go | 4 ++-- internal/csync/slices_test.go | 14 +++++++------- internal/llm/agent/agent.go | 4 ++-- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/internal/csync/slices.go b/internal/csync/slices.go index 2ce448ef2a1371ecd2e36f955cd1096140cd798a..bc295fc66fa3fcfa8a3e6a7ecf28492b15269200 100644 --- a/internal/csync/slices.go +++ b/internal/csync/slices.go @@ -23,8 +23,8 @@ func NewLazySlice[K any](load func() []K) *LazySlice[K] { return s } -// Iter returns an iterator that yields elements from the slice. -func (s *LazySlice[K]) Iter() iter.Seq[K] { +// Seq returns an iterator that yields elements from the slice. +func (s *LazySlice[K]) Seq() iter.Seq[K] { s.mu.Lock() inner := s.inner s.mu.Unlock() diff --git a/internal/csync/slices_test.go b/internal/csync/slices_test.go index 31f70d73c4b1c975592c17cf41cf0d57b0050b43..731cb96f55dd24cae74f55c0ef8e97ebd28aacaa 100644 --- a/internal/csync/slices_test.go +++ b/internal/csync/slices_test.go @@ -8,7 +8,7 @@ import ( "github.com/stretchr/testify/assert" ) -func TestLazySlice_Iter(t *testing.T) { +func TestLazySlice_Seq(t *testing.T) { t.Parallel() data := []string{"a", "b", "c"} @@ -19,14 +19,14 @@ func TestLazySlice_Iter(t *testing.T) { }) var result []string - for v := range s.Iter() { + for v := range s.Seq() { result = append(result, v) } assert.Equal(t, data, result) } -func TestLazySlice_IterWaitsForLoading(t *testing.T) { +func TestLazySlice_SeqWaitsForLoading(t *testing.T) { t.Parallel() var loaded atomic.Bool @@ -42,11 +42,11 @@ func TestLazySlice_IterWaitsForLoading(t *testing.T) { assert.False(t, loaded.Load(), "should not be loaded immediately") var result []string - for v := range s.Iter() { + for v := range s.Seq() { result = append(result, v) } - assert.True(t, loaded.Load(), "should be loaded after Iter") + assert.True(t, loaded.Load(), "should be loaded after Seq") assert.Equal(t, data, result) } @@ -58,7 +58,7 @@ func TestLazySlice_EmptySlice(t *testing.T) { }) var result []string - for v := range s.Iter() { + for v := range s.Seq() { result = append(result, v) } @@ -76,7 +76,7 @@ func TestLazySlice_EarlyBreak(t *testing.T) { }) var result []string - for v := range s.Iter() { + for v := range s.Seq() { result = append(result, v) if len(result) == 2 { break diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go index 02961969f8d16ed316f91132c944be1ea2311f48..72697cb0ac801f013a094dc5c44a3152f1443af1 100644 --- a/internal/llm/agent/agent.go +++ b/internal/llm/agent/agent.go @@ -445,7 +445,7 @@ func (a *agent) createUserMessage(ctx context.Context, sessionID, content string func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) { ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID) - eventChan := a.provider.StreamResponse(ctx, msgHistory, slices.Collect(a.tools.Iter())) + eventChan := a.provider.StreamResponse(ctx, msgHistory, slices.Collect(a.tools.Seq())) assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{ Role: message.Assistant, @@ -494,7 +494,7 @@ func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msg default: // Continue processing var tool tools.BaseTool - for availableTool := range a.tools.Iter() { + for availableTool := range a.tools.Seq() { if availableTool.Info().Name == toolCall.Name { tool = availableTool break From f36edc61db5c0abe4dd4e6cbb28636d3d93538e5 Mon Sep 17 00:00:00 2001 From: Carlos Alexandro Becker Date: Wed, 23 Jul 2025 00:47:10 -0300 Subject: [PATCH 15/21] fix: load providers in background --- internal/config/load.go | 5 +- internal/config/provider.go | 84 ++++++++++++++++++-------- internal/config/provider_empty_test.go | 47 ++++++++++++++ internal/config/provider_test.go | 6 +- 4 files changed, 111 insertions(+), 31 deletions(-) create mode 100644 internal/config/provider_empty_test.go diff --git a/internal/config/load.go b/internal/config/load.go index 044ef504859bcbcc051b93322099f6d03b1fa601..48ef9b1caf1e5d9ec1877f7fc9c3a53ab996d129 100644 --- a/internal/config/load.go +++ b/internal/config/load.go @@ -4,6 +4,7 @@ import ( "encoding/json" "fmt" "io" + "log/slog" "os" "path/filepath" "runtime" @@ -13,10 +14,8 @@ import ( "github.com/charmbracelet/crush/internal/csync" "github.com/charmbracelet/crush/internal/env" - "github.com/charmbracelet/crush/internal/fur/client" "github.com/charmbracelet/crush/internal/fur/provider" "github.com/charmbracelet/crush/internal/log" - "golang.org/x/exp/slog" ) // LoadReader config via io.Reader. @@ -63,7 +62,7 @@ func Load(workingDir string, debug bool) (*Config, error) { ) // Load known providers, this loads the config from fur - providers, err := LoadProviders(client.New()) + providers, err := Providers() if err != nil || len(providers) == 0 { return nil, fmt.Errorf("failed to load providers: %w", err) } diff --git a/internal/config/provider.go b/internal/config/provider.go index b8369b934963aca0a7f449fb219764ee079493ef..caeba48707be933d222313729934cc69c819f68e 100644 --- a/internal/config/provider.go +++ b/internal/config/provider.go @@ -2,10 +2,13 @@ package config import ( "encoding/json" + "fmt" + "log/slog" "os" "path/filepath" "runtime" "sync" + "time" "github.com/charmbracelet/crush/internal/fur/client" "github.com/charmbracelet/crush/internal/fur/provider" @@ -42,57 +45,88 @@ func providerCacheFileData() string { } func saveProvidersInCache(path string, providers []provider.Provider) error { - dir := filepath.Dir(path) - if err := os.MkdirAll(dir, 0o755); err != nil { - return err + slog.Info("Caching provider data") + if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { + return fmt.Errorf("failed to create directory for provider cache: %w", err) } data, err := json.MarshalIndent(providers, "", " ") if err != nil { - return err + return fmt.Errorf("failed to marshal provider data: %w", err) } - return os.WriteFile(path, data, 0o644) + if err := os.WriteFile(path, data, 0o644); err != nil { + return fmt.Errorf("failed to write provider data to cache: %w", err) + } + return nil } func loadProvidersFromCache(path string) ([]provider.Provider, error) { data, err := os.ReadFile(path) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to read provider cache file: %w", err) } var providers []provider.Provider - err = json.Unmarshal(data, &providers) - return providers, err -} - -func loadProviders(path string, client ProviderClient) ([]provider.Provider, error) { - providers, err := client.GetProviders() - if err != nil { - fallbackToCache, err := loadProvidersFromCache(path) - if err != nil { - return nil, err - } - providers = fallbackToCache - } else { - if err := saveProvidersInCache(path, providerList); err != nil { - return nil, err - } + if err := json.Unmarshal(data, &providers); err != nil { + return nil, fmt.Errorf("failed to unmarshal provider data from cache: %w", err) } return providers, nil } func Providers() ([]provider.Provider, error) { - return LoadProviders(client.New()) + client := client.New() + path := providerCacheFileData() + return loadProvidersOnce(client, path) } -func LoadProviders(client ProviderClient) ([]provider.Provider, error) { +func loadProvidersOnce(client ProviderClient, path string) ([]provider.Provider, error) { var err error providerOnce.Do(func() { - providerList, err = loadProviders(providerCacheFileData(), client) + providerList, err = loadProviders(client, path) }) if err != nil { return nil, err } return providerList, nil } + +func loadProviders(client ProviderClient, path string) (providerList []provider.Provider, err error) { + // if cache is not stale, load from it + stale, exists := isCacheStale(path) + if !stale { + slog.Info("Using cached provider data") + providerList, err = loadProvidersFromCache(path) + if len(providerList) > 0 && err == nil { + go func() { + slog.Info("Updating provider cache in background") + updated, uerr := client.GetProviders() + if len(updated) == 0 && uerr == nil { + _ = saveProvidersInCache(path, updated) + } + }() + return + } + } + + slog.Info("Getting live provider data") + providerList, err = client.GetProviders() + if len(providerList) > 0 && err == nil { + err = saveProvidersInCache(path, providerList) + return + } + if !exists { + err = fmt.Errorf("failed to load providers") + return + } + providerList, err = loadProvidersFromCache(path) + return +} + +func isCacheStale(path string) (stale, exists bool) { + info, err := os.Stat(path) + if err != nil { + return true, false + } + return time.Since(info.ModTime()) > 24*time.Hour, true +} diff --git a/internal/config/provider_empty_test.go b/internal/config/provider_empty_test.go new file mode 100644 index 0000000000000000000000000000000000000000..480869d98e4d69087aefc5759de0776f7910ebec --- /dev/null +++ b/internal/config/provider_empty_test.go @@ -0,0 +1,47 @@ +package config + +import ( + "encoding/json" + "os" + "testing" + + "github.com/charmbracelet/crush/internal/fur/provider" + "github.com/stretchr/testify/require" +) + +type emptyProviderClient struct{} + +func (m *emptyProviderClient) GetProviders() ([]provider.Provider, error) { + return []provider.Provider{}, nil +} + +func TestProvider_loadProvidersEmptyResult(t *testing.T) { + client := &emptyProviderClient{} + tmpPath := t.TempDir() + "/providers.json" + + providers, err := loadProviders(client, tmpPath) + require.EqualError(t, err, "failed to load providers") + require.Empty(t, providers) + require.Len(t, providers, 0) + + // Check that no cache file was created for empty results + require.NoFileExists(t, tmpPath, "Cache file should not exist for empty results") +} + +func TestProvider_loadProvidersEmptyCache(t *testing.T) { + client := &mockProviderClient{shouldFail: false} + tmpPath := t.TempDir() + "/providers.json" + + // Create an empty cache file + emptyProviders := []provider.Provider{} + data, err := json.Marshal(emptyProviders) + require.NoError(t, err) + require.NoError(t, os.WriteFile(tmpPath, data, 0o644)) + + // Should refresh and get real providers instead of using empty cache + providers, err := loadProviders(client, tmpPath) + require.NoError(t, err) + require.NotNil(t, providers) + require.Len(t, providers, 1) + require.Equal(t, "Mock", providers[0].Name) +} diff --git a/internal/config/provider_test.go b/internal/config/provider_test.go index a3562838c7103239aa303c906c866220164a4ba0..abfb6592bcd5e46a7cbf40dba54a10722ee69980 100644 --- a/internal/config/provider_test.go +++ b/internal/config/provider_test.go @@ -28,7 +28,7 @@ func (m *mockProviderClient) GetProviders() ([]provider.Provider, error) { func TestProvider_loadProvidersNoIssues(t *testing.T) { client := &mockProviderClient{shouldFail: false} tmpPath := t.TempDir() + "/providers.json" - providers, err := loadProviders(tmpPath, client) + providers, err := loadProviders(client, tmpPath) assert.NoError(t, err) assert.NotNil(t, providers) assert.Len(t, providers, 1) @@ -57,7 +57,7 @@ func TestProvider_loadProvidersWithIssues(t *testing.T) { if err != nil { t.Fatalf("Failed to write old providers to file: %v", err) } - providers, err := loadProviders(tmpPath, client) + providers, err := loadProviders(client, tmpPath) assert.NoError(t, err) assert.NotNil(t, providers) assert.Len(t, providers, 1) @@ -67,7 +67,7 @@ func TestProvider_loadProvidersWithIssues(t *testing.T) { func TestProvider_loadProvidersWithIssuesAndNoCache(t *testing.T) { client := &mockProviderClient{shouldFail: true} tmpPath := t.TempDir() + "/providers.json" - providers, err := loadProviders(tmpPath, client) + providers, err := loadProviders(client, tmpPath) assert.Error(t, err) assert.Nil(t, providers, "Expected nil providers when loading fails and no cache exists") } From 574ec2e85070d83115c4a4ef49ea5c6652fc663d Mon Sep 17 00:00:00 2001 From: Carlos Alexandro Becker Date: Thu, 24 Jul 2025 10:51:56 -0300 Subject: [PATCH 16/21] fix: cache update logic --- internal/config/provider.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/internal/config/provider.go b/internal/config/provider.go index 98235cd84794812128082533f3a501bfce952cb8..ba02f9d8e1bc0f2ec58c2ed3e736a87e1d7a614b 100644 --- a/internal/config/provider.go +++ b/internal/config/provider.go @@ -44,7 +44,7 @@ func providerCacheFileData() string { } func saveProvidersInCache(path string, providers []catwalk.Provider) error { - slog.Info("Caching provider data") + slog.Info("Saving cached provider data", "path", path) if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { return fmt.Errorf("failed to create directory for provider cache: %w", err) } @@ -94,13 +94,13 @@ func loadProviders(client ProviderClient, path string) (providerList []catwalk.P // if cache is not stale, load from it stale, exists := isCacheStale(path) if !stale { - slog.Info("Using cached provider data") + slog.Info("Using cached provider data", "path", path) providerList, err = loadProvidersFromCache(path) if len(providerList) > 0 && err == nil { go func() { slog.Info("Updating provider cache in background") updated, uerr := client.GetProviders() - if len(updated) == 0 && uerr == nil { + if len(updated) > 0 && uerr == nil { _ = saveProvidersInCache(path, updated) } }() From 69752719f067e2cc37aaead6d088c4f8a4b3aec0 Mon Sep 17 00:00:00 2001 From: Carlos Alexandro Becker Date: Thu, 24 Jul 2025 11:50:45 -0300 Subject: [PATCH 17/21] fix: improve lazy slice --- internal/csync/slices.go | 12 +++++------- internal/llm/agent/agent.go | 4 ++-- internal/llm/provider/gemini.go | 8 ++------ internal/llm/provider/openai_test.go | 3 +-- 4 files changed, 10 insertions(+), 17 deletions(-) diff --git a/internal/csync/slices.go b/internal/csync/slices.go index bc295fc66fa3fcfa8a3e6a7ecf28492b15269200..be723655079ccc6b07f55c3237b706a17bb14d40 100644 --- a/internal/csync/slices.go +++ b/internal/csync/slices.go @@ -8,28 +8,26 @@ import ( // LazySlice is a thread-safe lazy-loaded slice. type LazySlice[K any] struct { inner []K - mu sync.Mutex + wg sync.WaitGroup } // NewLazySlice creates a new slice and runs the [load] function in a goroutine // to populate it. func NewLazySlice[K any](load func() []K) *LazySlice[K] { s := &LazySlice[K]{} - s.mu.Lock() + s.wg.Add(1) go func() { s.inner = load() - s.mu.Unlock() + s.wg.Done() }() return s } // Seq returns an iterator that yields elements from the slice. func (s *LazySlice[K]) Seq() iter.Seq[K] { - s.mu.Lock() - inner := s.inner - s.mu.Unlock() + s.wg.Wait() return func(yield func(K) bool) { - for _, v := range inner { + for _, v := range s.inner { if !yield(v) { return } diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go index 75f1f545929cc2422461ed0e775775689f8567d2..2c3876ccac9ed028b1714ed96b0c6de0cce007c9 100644 --- a/internal/llm/agent/agent.go +++ b/internal/llm/agent/agent.go @@ -292,7 +292,7 @@ func (a *agent) generateTitle(ctx context.Context, sessionID string, content str Parts: parts, }, }, - make([]tools.BaseTool, 0), + nil, ) var finalResponse *provider.ProviderResponse @@ -745,7 +745,7 @@ func (a *agent) Summarize(ctx context.Context, sessionID string) error { response := a.summarizeProvider.StreamResponse( summarizeCtx, msgsWithPrompt, - make([]tools.BaseTool, 0), + nil, ) var finalResponse *provider.ProviderResponse for r := range response { diff --git a/internal/llm/provider/gemini.go b/internal/llm/provider/gemini.go index 4fa0cff4d17c28da16528d33ff54e2a905521387..b2d1da11148e74362e7b529b9ec78dc1810d0f0d 100644 --- a/internal/llm/provider/gemini.go +++ b/internal/llm/provider/gemini.go @@ -188,9 +188,7 @@ func (g *geminiClient) send(ctx context.Context, messages []message.Message, too Parts: []*genai.Part{{Text: g.providerOptions.systemMessage}}, }, } - if len(tools) > 0 { - config.Tools = g.convertTools(tools) - } + config.Tools = g.convertTools(tools) chat, _ := g.client.Chats.Create(ctx, model.ID, config, history) attempts := 0 @@ -290,9 +288,7 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t Parts: []*genai.Part{{Text: g.providerOptions.systemMessage}}, }, } - if len(tools) > 0 { - config.Tools = g.convertTools(tools) - } + config.Tools = g.convertTools(tools) chat, _ := g.client.Chats.Create(ctx, model.ID, config, history) attempts := 0 diff --git a/internal/llm/provider/openai_test.go b/internal/llm/provider/openai_test.go index 26c4d85ae35bbf4681719a12b568befccd8012af..ef79803c8a8aa1ee3fe6cb7de8bc8fa86f26c03c 100644 --- a/internal/llm/provider/openai_test.go +++ b/internal/llm/provider/openai_test.go @@ -11,7 +11,6 @@ import ( "github.com/charmbracelet/catwalk/pkg/catwalk" "github.com/charmbracelet/crush/internal/config" - "github.com/charmbracelet/crush/internal/llm/tools" "github.com/charmbracelet/crush/internal/message" "github.com/openai/openai-go" "github.com/openai/openai-go/option" @@ -79,7 +78,7 @@ func TestOpenAIClientStreamChoices(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - eventsChan := client.stream(ctx, messages, []tools.BaseTool{}) + eventsChan := client.stream(ctx, messages, nil) // Collect events - this will panic without the bounds check for event := range eventsChan { From ac782d1ff81f45a91bb34c0669deb2861b27f733 Mon Sep 17 00:00:00 2001 From: Carlos Alexandro Becker Date: Thu, 24 Jul 2025 11:51:00 -0300 Subject: [PATCH 18/21] fix: prevent nil ptr --- go.mod | 3 ++- internal/config/config.go | 2 +- internal/config/init.go | 19 ++++++++++--------- 3 files changed, 13 insertions(+), 11 deletions(-) diff --git a/go.mod b/go.mod index e17354c051a21b593a385b1e3995cc543aafd0dd..1f24cb1d27dc197ff662d1b4caf3a4aadf828cb9 100644 --- a/go.mod +++ b/go.mod @@ -42,11 +42,12 @@ require ( github.com/tidwall/sjson v1.2.5 github.com/u-root/u-root v0.14.1-0.20250722142936-bf4e78a90dfc github.com/zeebo/xxh3 v1.0.2 - golang.org/x/exp v0.0.0-20250305212735-054e65f0b394 gopkg.in/natefinch/lumberjack.v2 v2.2.1 mvdan.cc/sh/v3 v3.11.0 ) +require golang.org/x/exp v0.0.0-20250305212735-054e65f0b394 // indirect + require ( cloud.google.com/go v0.116.0 // indirect cloud.google.com/go/auth v0.13.0 // indirect diff --git a/internal/config/config.go b/internal/config/config.go index b9d44bc87448d3244d27c426bf0f70dc98ce064a..9709c11a0636d91cb492b7735b63e46e5e843c74 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -3,6 +3,7 @@ package config import ( "context" "fmt" + "log/slog" "net/http" "os" "slices" @@ -13,7 +14,6 @@ import ( "github.com/charmbracelet/crush/internal/csync" "github.com/charmbracelet/crush/internal/env" "github.com/tidwall/sjson" - "golang.org/x/exp/slog" ) const ( diff --git a/internal/config/init.go b/internal/config/init.go index 827a287718e40e1fc5b9b761293c00799ec5ef3d..3a4194afdefb7bb1d0c95bd41a837dced15e9433 100644 --- a/internal/config/init.go +++ b/internal/config/init.go @@ -6,12 +6,9 @@ import ( "path/filepath" "strings" "sync" - "sync/atomic" ) -const ( - InitFlagFilename = "init" -) +const InitFlagFilename = "init" type ProjectInitFlag struct { Initialized bool `json:"initialized"` @@ -19,25 +16,29 @@ type ProjectInitFlag struct { // TODO: we need to remove the global config instance keeping it now just until everything is migrated var ( - instance atomic.Pointer[Config] + instance *Config cwd string - once sync.Once // Ensures the initialization happens only once + once sync.Once + wg sync.WaitGroup ) func Init(workingDir string, debug bool) (*Config, error) { var err error + wg.Add(1) once.Do(func() { cwd = workingDir var cfg *Config cfg, err = Load(cwd, debug) - instance.Store(cfg) + instance = cfg + wg.Done() }) - return instance.Load(), err + return instance, err } func Get() *Config { - return instance.Load() + wg.Wait() + return instance } func ProjectNeedsInitialization() (bool, error) { From a4dee0614060db756b3624db82a5d52b31c15337 Mon Sep 17 00:00:00 2001 From: Carlos Alexandro Becker Date: Thu, 24 Jul 2025 11:58:56 -0300 Subject: [PATCH 19/21] fix: csync.Map --- internal/config/load.go | 35 +++++++++++++++++------------------ 1 file changed, 17 insertions(+), 18 deletions(-) diff --git a/internal/config/load.go b/internal/config/load.go index 5d11901dd4d169041d315eb139d27e6dbec736de..6a683a1c98191e0c9f78e75773fc77839143f51e 100644 --- a/internal/config/load.go +++ b/internal/config/load.go @@ -93,35 +93,34 @@ func Load(workingDir string, debug bool) (*Config, error) { } func (c *Config) removeUnresponsiveProviders() { - // Test provider connections in parallel - var testResults sync.Map - var wg sync.WaitGroup slog.Info("Testing provider connections") defer slog.Info("Provider connection tests completed") + + // Test provider connections in parallel + var wg sync.WaitGroup + testResults := csync.NewMap[string, bool]() for _, p := range c.Providers.Seq2() { - if p.Type == catwalk.TypeOpenAI || p.Type == catwalk.TypeAnthropic { - wg.Add(1) - go func(provider ProviderConfig) { - defer wg.Done() - err := provider.TestConnection(c.resolver) - testResults.Store(provider.ID, err == nil) - if err != nil { - slog.Error("Provider connection test failed", "provider", provider.ID, "error", err) - } - }(p) + if p.Type != catwalk.TypeOpenAI && p.Type != catwalk.TypeAnthropic { + continue } + wg.Add(1) + go func(provider ProviderConfig) { + defer wg.Done() + err := provider.TestConnection(c.resolver) + testResults.Set(provider.ID, err == nil) + if err != nil { + slog.Error("Provider connection test failed", "provider", provider.ID, "error", err) + } + }(p) } wg.Wait() // Remove failed providers - testResults.Range(func(key, value any) bool { - providerID := key.(string) - passed := value.(bool) + for providerID, passed := range testResults.Seq2() { if !passed { c.Providers.Del(providerID) } - return true - }) + } } func (c *Config) configureProviders(env env.Env, resolver VariableResolver, knownProviders []catwalk.Provider) error { From 4509fe77d6b718c19cfab4040a9db03637a86d15 Mon Sep 17 00:00:00 2001 From: Carlos Alexandro Becker Date: Thu, 24 Jul 2025 16:30:50 -0300 Subject: [PATCH 20/21] fix: sync --- internal/config/init.go | 34 +++++++++++++--------------------- 1 file changed, 13 insertions(+), 21 deletions(-) diff --git a/internal/config/init.go b/internal/config/init.go index 3a4194afdefb7bb1d0c95bd41a837dced15e9433..ff44d43bb878f579d003c84537fcd970f9e52f9e 100644 --- a/internal/config/init.go +++ b/internal/config/init.go @@ -5,40 +5,32 @@ import ( "os" "path/filepath" "strings" - "sync" + "sync/atomic" ) -const InitFlagFilename = "init" +const ( + InitFlagFilename = "init" +) type ProjectInitFlag struct { Initialized bool `json:"initialized"` } // TODO: we need to remove the global config instance keeping it now just until everything is migrated -var ( - instance *Config - cwd string - once sync.Once - wg sync.WaitGroup -) +var instance atomic.Pointer[Config] func Init(workingDir string, debug bool) (*Config, error) { - var err error - wg.Add(1) - once.Do(func() { - cwd = workingDir - var cfg *Config - cfg, err = Load(cwd, debug) - instance = cfg - wg.Done() - }) - - return instance, err + cfg, err := Load(workingDir, debug) + if err != nil { + return nil, err + } + instance.Store(cfg) + return instance.Load(), nil } func Get() *Config { - wg.Wait() - return instance + cfg := instance.Load() + return cfg } func ProjectNeedsInitialization() (bool, error) { From 40423175e8a72f35a46fc40b6f914b6c51b556d4 Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Fri, 25 Jul 2025 11:30:24 +0200 Subject: [PATCH 21/21] chore: remove provider tests on startup --- internal/config/load.go | 34 ---------------------------------- 1 file changed, 34 deletions(-) diff --git a/internal/config/load.go b/internal/config/load.go index 6a683a1c98191e0c9f78e75773fc77839143f51e..98569d41be810dd0b9382c4df56cfb3e9c1c5842 100644 --- a/internal/config/load.go +++ b/internal/config/load.go @@ -10,7 +10,6 @@ import ( "runtime" "slices" "strings" - "sync" "github.com/charmbracelet/catwalk/pkg/catwalk" "github.com/charmbracelet/crush/internal/csync" @@ -78,8 +77,6 @@ func Load(workingDir string, debug bool) (*Config, error) { return nil, fmt.Errorf("failed to configure providers: %w", err) } - go cfg.removeUnresponsiveProviders() - if !cfg.IsConfigured() { slog.Warn("No providers configured") return cfg, nil @@ -92,37 +89,6 @@ func Load(workingDir string, debug bool) (*Config, error) { return cfg, nil } -func (c *Config) removeUnresponsiveProviders() { - slog.Info("Testing provider connections") - defer slog.Info("Provider connection tests completed") - - // Test provider connections in parallel - var wg sync.WaitGroup - testResults := csync.NewMap[string, bool]() - for _, p := range c.Providers.Seq2() { - if p.Type != catwalk.TypeOpenAI && p.Type != catwalk.TypeAnthropic { - continue - } - wg.Add(1) - go func(provider ProviderConfig) { - defer wg.Done() - err := provider.TestConnection(c.resolver) - testResults.Set(provider.ID, err == nil) - if err != nil { - slog.Error("Provider connection test failed", "provider", provider.ID, "error", err) - } - }(p) - } - wg.Wait() - - // Remove failed providers - for providerID, passed := range testResults.Seq2() { - if !passed { - c.Providers.Del(providerID) - } - } -} - func (c *Config) configureProviders(env env.Env, resolver VariableResolver, knownProviders []catwalk.Provider) error { knownProviderNames := make(map[string]bool) for _, p := range knownProviders {