diff --git a/csync/slices.go b/csync/slices.go new file mode 100644 index 0000000000000000000000000000000000000000..388ad074d53a9bd7188418b231afbf39adca0565 --- /dev/null +++ b/csync/slices.go @@ -0,0 +1,34 @@ +package csync + +import ( + "iter" + "sync" +) + +type LazySlice[K any] struct { + inner []K + mu sync.Mutex +} + +func NewLazySlice[K any](load func() []K) *LazySlice[K] { + s := &LazySlice[K]{} + s.mu.Lock() + go func() { + s.inner = load() + s.mu.Unlock() + }() + return s +} + +func (s *LazySlice[K]) Iter() iter.Seq[K] { + s.mu.Lock() + inner := s.inner + s.mu.Unlock() + return func(yield func(K) bool) { + for _, v := range inner { + if !yield(v) { + return + } + } + } +} diff --git a/csync/slices_test.go b/csync/slices_test.go new file mode 100644 index 0000000000000000000000000000000000000000..d1c7af8cf30f3d58a84046f899f8dd89f80beb51 --- /dev/null +++ b/csync/slices_test.go @@ -0,0 +1,86 @@ +package csync + +import ( + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestLazySlice_Iter(t *testing.T) { + t.Parallel() + + data := []string{"a", "b", "c"} + s := NewLazySlice(func() []string { + // TODO: use synctest when new Go is out. + time.Sleep(10 * time.Millisecond) // Small delay to ensure loading happens + return data + }) + + var result []string + for v := range s.Iter() { + result = append(result, v) + } + + assert.Equal(t, data, result) +} + +func TestLazySlice_IterWaitsForLoading(t *testing.T) { + t.Parallel() + + var loaded atomic.Bool + data := []string{"x", "y", "z"} + + s := NewLazySlice(func() []string { + // TODO: use synctest when new Go is out. + time.Sleep(100 * time.Millisecond) + loaded.Store(true) + return data + }) + + assert.False(t, loaded.Load(), "should not be loaded immediately") + + var result []string + for v := range s.Iter() { + result = append(result, v) + } + + assert.True(t, loaded.Load(), "should be loaded after Iter") + assert.Equal(t, data, result) +} + +func TestLazySlice_EmptySlice(t *testing.T) { + t.Parallel() + + s := NewLazySlice(func() []string { + return []string{} + }) + + var result []string + for v := range s.Iter() { + result = append(result, v) + } + + assert.Empty(t, result) +} + +func TestLazySlice_EarlyBreak(t *testing.T) { + t.Parallel() + + data := []string{"a", "b", "c", "d", "e"} + s := NewLazySlice(func() []string { + time.Sleep(10 * time.Millisecond) // Small delay to ensure loading happens + return data + }) + + var result []string + for v := range s.Iter() { + result = append(result, v) + if len(result) == 2 { + break + } + } + + assert.Equal(t, []string{"a", "b"}, result) +} diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go index 523a81e417242bd4b4f939a73629d0f12b54e3e6..2f76cc7771e3f0383f20b4ef1dffe448e06a253c 100644 --- a/internal/llm/agent/agent.go +++ b/internal/llm/agent/agent.go @@ -8,9 +8,9 @@ import ( "slices" "strings" "sync" - "sync/atomic" "time" + "github.com/charmbracelet/crush/csync" "github.com/charmbracelet/crush/internal/config" fur "github.com/charmbracelet/crush/internal/fur/provider" "github.com/charmbracelet/crush/internal/history" @@ -68,8 +68,7 @@ type agent struct { sessions session.Service messages message.Service - toolsDone atomic.Bool - tools []tools.BaseTool + tools *csync.LazySlice[tools.BaseTool] provider provider.Provider providerID string @@ -168,24 +167,10 @@ func NewAgent( return nil, err } - agent := &agent{ - Broker: pubsub.NewBroker[AgentEvent](), - agentCfg: agentCfg, - provider: agentProvider, - providerID: string(providerCfg.ID), - messages: messages, - sessions: sessions, - titleProvider: titleProvider, - summarizeProvider: summarizeProvider, - summarizeProviderID: string(smallModelProviderCfg.ID), - activeRequests: sync.Map{}, - } - - go func() { + toolFn := func() []tools.BaseTool { slog.Info("Initializing agent tools", "agent", agentCfg.ID) defer func() { slog.Info("Initialized agent tools", "agent", agentCfg.ID) - agent.toolsDone.Store(true) }() cwd := cfg.WorkingDir() @@ -214,8 +199,7 @@ func NewAgent( } if agentCfg.AllowedTools == nil { - agent.tools = allTools - return + return allTools } var filteredTools []tools.BaseTool @@ -224,10 +208,22 @@ func NewAgent( filteredTools = append(filteredTools, tool) } } - agent.tools = filteredTools - }() + return filteredTools + } - return agent, nil + return &agent{ + Broker: pubsub.NewBroker[AgentEvent](), + agentCfg: agentCfg, + provider: agentProvider, + providerID: string(providerCfg.ID), + messages: messages, + sessions: sessions, + titleProvider: titleProvider, + summarizeProvider: summarizeProvider, + summarizeProviderID: string(smallModelProviderCfg.ID), + activeRequests: sync.Map{}, + tools: csync.NewLazySlice(toolFn), + }, nil } func (a *agent) Model() fur.Model { @@ -449,10 +445,7 @@ func (a *agent) createUserMessage(ctx context.Context, sessionID, content string func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) { ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID) - if !a.toolsDone.Load() { - return message.Message{}, nil, fmt.Errorf("agent is still initializing, please wait a moment and try again") - } - eventChan := a.provider.StreamResponse(ctx, msgHistory, a.tools) + eventChan := a.provider.StreamResponse(ctx, msgHistory, slices.Collect(a.tools.Iter())) assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{ Role: message.Assistant, @@ -501,7 +494,7 @@ func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msg default: // Continue processing var tool tools.BaseTool - for _, availableTool := range a.tools { + for availableTool := range a.tools.Iter() { if availableTool.Info().Name == toolCall.Name { tool = availableTool break