@@ -0,0 +1,86 @@
+package csync
+
+import (
+ "sync/atomic"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+)
+
+func TestLazySlice_Iter(t *testing.T) {
+ t.Parallel()
+
+ data := []string{"a", "b", "c"}
+ s := NewLazySlice(func() []string {
+ // TODO: use synctest when new Go is out.
+ time.Sleep(10 * time.Millisecond) // Small delay to ensure loading happens
+ return data
+ })
+
+ var result []string
+ for v := range s.Iter() {
+ result = append(result, v)
+ }
+
+ assert.Equal(t, data, result)
+}
+
+func TestLazySlice_IterWaitsForLoading(t *testing.T) {
+ t.Parallel()
+
+ var loaded atomic.Bool
+ data := []string{"x", "y", "z"}
+
+ s := NewLazySlice(func() []string {
+ // TODO: use synctest when new Go is out.
+ time.Sleep(100 * time.Millisecond)
+ loaded.Store(true)
+ return data
+ })
+
+ assert.False(t, loaded.Load(), "should not be loaded immediately")
+
+ var result []string
+ for v := range s.Iter() {
+ result = append(result, v)
+ }
+
+ assert.True(t, loaded.Load(), "should be loaded after Iter")
+ assert.Equal(t, data, result)
+}
+
+func TestLazySlice_EmptySlice(t *testing.T) {
+ t.Parallel()
+
+ s := NewLazySlice(func() []string {
+ return []string{}
+ })
+
+ var result []string
+ for v := range s.Iter() {
+ result = append(result, v)
+ }
+
+ assert.Empty(t, result)
+}
+
+func TestLazySlice_EarlyBreak(t *testing.T) {
+ t.Parallel()
+
+ data := []string{"a", "b", "c", "d", "e"}
+ s := NewLazySlice(func() []string {
+ time.Sleep(10 * time.Millisecond) // Small delay to ensure loading happens
+ return data
+ })
+
+ var result []string
+ for v := range s.Iter() {
+ result = append(result, v)
+ if len(result) == 2 {
+ break
+ }
+ }
+
+ assert.Equal(t, []string{"a", "b"}, result)
+}
@@ -8,9 +8,9 @@ import (
"slices"
"strings"
"sync"
- "sync/atomic"
"time"
+ "github.com/charmbracelet/crush/csync"
"github.com/charmbracelet/crush/internal/config"
fur "github.com/charmbracelet/crush/internal/fur/provider"
"github.com/charmbracelet/crush/internal/history"
@@ -68,8 +68,7 @@ type agent struct {
sessions session.Service
messages message.Service
- toolsDone atomic.Bool
- tools []tools.BaseTool
+ tools *csync.LazySlice[tools.BaseTool]
provider provider.Provider
providerID string
@@ -168,24 +167,10 @@ func NewAgent(
return nil, err
}
- agent := &agent{
- Broker: pubsub.NewBroker[AgentEvent](),
- agentCfg: agentCfg,
- provider: agentProvider,
- providerID: string(providerCfg.ID),
- messages: messages,
- sessions: sessions,
- titleProvider: titleProvider,
- summarizeProvider: summarizeProvider,
- summarizeProviderID: string(smallModelProviderCfg.ID),
- activeRequests: sync.Map{},
- }
-
- go func() {
+ toolFn := func() []tools.BaseTool {
slog.Info("Initializing agent tools", "agent", agentCfg.ID)
defer func() {
slog.Info("Initialized agent tools", "agent", agentCfg.ID)
- agent.toolsDone.Store(true)
}()
cwd := cfg.WorkingDir()
@@ -214,8 +199,7 @@ func NewAgent(
}
if agentCfg.AllowedTools == nil {
- agent.tools = allTools
- return
+ return allTools
}
var filteredTools []tools.BaseTool
@@ -224,10 +208,22 @@ func NewAgent(
filteredTools = append(filteredTools, tool)
}
}
- agent.tools = filteredTools
- }()
+ return filteredTools
+ }
- return agent, nil
+ return &agent{
+ Broker: pubsub.NewBroker[AgentEvent](),
+ agentCfg: agentCfg,
+ provider: agentProvider,
+ providerID: string(providerCfg.ID),
+ messages: messages,
+ sessions: sessions,
+ titleProvider: titleProvider,
+ summarizeProvider: summarizeProvider,
+ summarizeProviderID: string(smallModelProviderCfg.ID),
+ activeRequests: sync.Map{},
+ tools: csync.NewLazySlice(toolFn),
+ }, nil
}
func (a *agent) Model() fur.Model {
@@ -449,10 +445,7 @@ func (a *agent) createUserMessage(ctx context.Context, sessionID, content string
func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
- if !a.toolsDone.Load() {
- return message.Message{}, nil, fmt.Errorf("agent is still initializing, please wait a moment and try again")
- }
- eventChan := a.provider.StreamResponse(ctx, msgHistory, a.tools)
+ eventChan := a.provider.StreamResponse(ctx, msgHistory, slices.Collect(a.tools.Iter()))
assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
Role: message.Assistant,
@@ -501,7 +494,7 @@ func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msg
default:
// Continue processing
var tool tools.BaseTool
- for _, availableTool := range a.tools {
+ for availableTool := range a.tools.Iter() {
if availableTool.Info().Name == toolCall.Name {
tool = availableTool
break