fix: race in agent.go (#1853)

Carlos Alexandro Becker created

Signed-off-by: Carlos Alexandro Becker <caarlos0@users.noreply.github.com>

Change summary

internal/agent/agent.go       | 107 ++++++++++++++++++++----------------
internal/agent/event.go       |  12 ++--
internal/csync/slices.go      |  44 +++-----------
internal/csync/slices_test.go |  57 -------------------
internal/csync/value.go       |  44 +++++++++++++++
internal/csync/value_test.go  |  99 ++++++++++++++++++++++++++++++++++
6 files changed, 218 insertions(+), 145 deletions(-)

Detailed changes

internal/agent/agent.go 🔗

@@ -87,12 +87,13 @@ type Model struct {
 }
 
 type sessionAgent struct {
-	largeModel           Model
-	smallModel           Model
-	systemPromptPrefix   string
-	systemPrompt         string
+	largeModel         *csync.Value[Model]
+	smallModel         *csync.Value[Model]
+	systemPromptPrefix *csync.Value[string]
+	systemPrompt       *csync.Value[string]
+	tools              *csync.Slice[fantasy.AgentTool]
+
 	isSubAgent           bool
-	tools                []fantasy.AgentTool
 	sessions             session.Service
 	messages             message.Service
 	disableAutoSummarize bool
@@ -119,15 +120,15 @@ func NewSessionAgent(
 	opts SessionAgentOptions,
 ) SessionAgent {
 	return &sessionAgent{
-		largeModel:           opts.LargeModel,
-		smallModel:           opts.SmallModel,
-		systemPromptPrefix:   opts.SystemPromptPrefix,
-		systemPrompt:         opts.SystemPrompt,
+		largeModel:           csync.NewValue(opts.LargeModel),
+		smallModel:           csync.NewValue(opts.SmallModel),
+		systemPromptPrefix:   csync.NewValue(opts.SystemPromptPrefix),
+		systemPrompt:         csync.NewValue(opts.SystemPrompt),
 		isSubAgent:           opts.IsSubAgent,
 		sessions:             opts.Sessions,
 		messages:             opts.Messages,
 		disableAutoSummarize: opts.DisableAutoSummarize,
-		tools:                opts.Tools,
+		tools:                csync.NewSliceFrom(opts.Tools),
 		isYolo:               opts.IsYolo,
 		messageQueue:         csync.NewMap[string, []SessionAgentCall](),
 		activeRequests:       csync.NewMap[string, context.CancelFunc](),
@@ -153,15 +154,21 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy
 		return nil, nil
 	}
 
-	if len(a.tools) > 0 {
+	// Copy mutable fields under lock to avoid races with SetTools/SetModels.
+	agentTools := a.tools.Copy()
+	largeModel := a.largeModel.Get()
+	systemPrompt := a.systemPrompt.Get()
+	promptPrefix := a.systemPromptPrefix.Get()
+
+	if len(agentTools) > 0 {
 		// Add Anthropic caching to the last tool.
-		a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
+		agentTools[len(agentTools)-1].SetProviderOptions(a.getCacheControlOptions())
 	}
 
 	agent := fantasy.NewAgent(
-		a.largeModel.Model,
-		fantasy.WithSystemPrompt(a.systemPrompt),
-		fantasy.WithTools(a.tools...),
+		largeModel.Model,
+		fantasy.WithSystemPrompt(systemPrompt),
+		fantasy.WithTools(agentTools...),
 	)
 
 	sessionLock := sync.Mutex{}
@@ -234,7 +241,7 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy
 				prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
 			}
 
-			prepared.Messages = a.workaroundProviderMediaLimitations(prepared.Messages)
+			prepared.Messages = a.workaroundProviderMediaLimitations(prepared.Messages, largeModel)
 
 			lastSystemRoleInx := 0
 			systemMessageUpdated := false
@@ -252,7 +259,7 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy
 				}
 			}
 
-			if promptPrefix := a.promptPrefix(); promptPrefix != "" {
+			if promptPrefix != "" {
 				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(promptPrefix)}, prepared.Messages...)
 			}
 
@@ -260,15 +267,15 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy
 			assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
 				Role:     message.Assistant,
 				Parts:    []message.ContentPart{},
-				Model:    a.largeModel.ModelCfg.Model,
-				Provider: a.largeModel.ModelCfg.Provider,
+				Model:    largeModel.ModelCfg.Model,
+				Provider: largeModel.ModelCfg.Provider,
 			})
 			if err != nil {
 				return callContext, prepared, err
 			}
 			callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
-			callContext = context.WithValue(callContext, tools.SupportsImagesContextKey, a.largeModel.CatwalkCfg.SupportsImages)
-			callContext = context.WithValue(callContext, tools.ModelNameContextKey, a.largeModel.CatwalkCfg.Name)
+			callContext = context.WithValue(callContext, tools.SupportsImagesContextKey, largeModel.CatwalkCfg.SupportsImages)
+			callContext = context.WithValue(callContext, tools.ModelNameContextKey, largeModel.CatwalkCfg.Name)
 			currentAssistant = &assistantMsg
 			return callContext, prepared, err
 		},
@@ -362,7 +369,7 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy
 				sessionLock.Unlock()
 				return getSessionErr
 			}
-			a.updateSessionUsage(a.largeModel, &updatedSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
+			a.updateSessionUsage(largeModel, &updatedSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
 			_, sessionErr := a.sessions.Save(genCtx, updatedSession)
 			sessionLock.Unlock()
 			if sessionErr != nil {
@@ -372,7 +379,7 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy
 		},
 		StopWhen: []fantasy.StopCondition{
 			func(_ []fantasy.StepResult) bool {
-				cw := int64(a.largeModel.CatwalkCfg.ContextWindow)
+				cw := int64(largeModel.CatwalkCfg.ContextWindow)
 				tokens := currentSession.CompletionTokens + currentSession.PromptTokens
 				remaining := cw - tokens
 				var threshold int64
@@ -474,7 +481,7 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy
 				currentAssistant.AddFinish(
 					message.FinishReasonError,
 					"Copilot model not enabled",
-					fmt.Sprintf("%q is not enabled in Copilot. Go to the following page to enable it. Then, wait 5 minutes before trying again. %s", a.largeModel.CatwalkCfg.Name, link),
+					fmt.Sprintf("%q is not enabled in Copilot. Go to the following page to enable it. Then, wait 5 minutes before trying again. %s", largeModel.CatwalkCfg.Name, link),
 				)
 			} else {
 				currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message)
@@ -529,6 +536,10 @@ func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fan
 		return ErrSessionBusy
 	}
 
+	// Copy mutable fields under lock to avoid races with SetModels.
+	largeModel := a.largeModel.Get()
+	systemPromptPrefix := a.systemPromptPrefix.Get()
+
 	currentSession, err := a.sessions.Get(ctx, sessionID)
 	if err != nil {
 		return fmt.Errorf("failed to get session: %w", err)
@@ -549,13 +560,13 @@ func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fan
 	defer a.activeRequests.Del(sessionID)
 	defer cancel()
 
-	agent := fantasy.NewAgent(a.largeModel.Model,
+	agent := fantasy.NewAgent(largeModel.Model,
 		fantasy.WithSystemPrompt(string(summaryPrompt)),
 	)
 	summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
 		Role:             message.Assistant,
-		Model:            a.largeModel.Model.Model(),
-		Provider:         a.largeModel.Model.Provider(),
+		Model:            largeModel.Model.Model(),
+		Provider:         largeModel.Model.Provider(),
 		IsSummaryMessage: true,
 	})
 	if err != nil {
@@ -570,8 +581,8 @@ func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fan
 		ProviderOptions: opts,
 		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
 			prepared.Messages = options.Messages
-			if a.systemPromptPrefix != "" {
-				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
+			if systemPromptPrefix != "" {
+				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(systemPromptPrefix)}, prepared.Messages...)
 			}
 			return callContext, prepared, nil
 		},
@@ -622,7 +633,7 @@ func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fan
 		}
 	}
 
-	a.updateSessionUsage(a.largeModel, &currentSession, resp.TotalUsage, openrouterCost)
+	a.updateSessionUsage(largeModel, &currentSession, resp.TotalUsage, openrouterCost)
 
 	// Just in case, get just the last usage info.
 	usage := resp.Response.Usage
@@ -730,9 +741,13 @@ func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, user
 		return
 	}
 
+	smallModel := a.smallModel.Get()
+	largeModel := a.largeModel.Get()
+	systemPromptPrefix := a.systemPromptPrefix.Get()
+
 	var maxOutputTokens int64 = 40
-	if a.smallModel.CatwalkCfg.CanReason {
-		maxOutputTokens = a.smallModel.CatwalkCfg.DefaultMaxTokens
+	if smallModel.CatwalkCfg.CanReason {
+		maxOutputTokens = smallModel.CatwalkCfg.DefaultMaxTokens
 	}
 
 	newAgent := func(m fantasy.LanguageModel, p []byte, tok int64) fantasy.Agent {
@@ -746,9 +761,9 @@ func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, user
 		Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", userPrompt),
 		PrepareStep: func(callCtx context.Context, opts fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
 			prepared.Messages = opts.Messages
-			if a.systemPromptPrefix != "" {
+			if systemPromptPrefix != "" {
 				prepared.Messages = append([]fantasy.Message{
-					fantasy.NewSystemMessage(a.systemPromptPrefix),
+					fantasy.NewSystemMessage(systemPromptPrefix),
 				}, prepared.Messages...)
 			}
 			return callCtx, prepared, nil
@@ -756,7 +771,7 @@ func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, user
 	}
 
 	// Use the small model to generate the title.
-	model := &a.smallModel
+	model := smallModel
 	agent := newAgent(model.Model, titlePrompt, maxOutputTokens)
 	resp, err := agent.Stream(ctx, streamCall)
 	if err == nil {
@@ -765,7 +780,7 @@ func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, user
 	} else {
 		// It didn't work. Let's try with the big model.
 		slog.Error("error generating title with small model; trying big model", "err", err)
-		model = &a.largeModel
+		model = largeModel
 		agent = newAgent(model.Model, titlePrompt, maxOutputTokens)
 		resp, err = agent.Stream(ctx, streamCall)
 		if err == nil {
@@ -960,24 +975,20 @@ func (a *sessionAgent) QueuedPromptsList(sessionID string) []string {
 }
 
 func (a *sessionAgent) SetModels(large Model, small Model) {
-	a.largeModel = large
-	a.smallModel = small
+	a.largeModel.Set(large)
+	a.smallModel.Set(small)
 }
 
 func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
-	a.tools = tools
+	a.tools.SetSlice(tools)
 }
 
 func (a *sessionAgent) SetSystemPrompt(systemPrompt string) {
-	a.systemPrompt = systemPrompt
+	a.systemPrompt.Set(systemPrompt)
 }
 
 func (a *sessionAgent) Model() Model {
-	return a.largeModel
-}
-
-func (a *sessionAgent) promptPrefix() string {
-	return a.systemPromptPrefix
+	return a.largeModel.Get()
 }
 
 // convertToToolResult converts a fantasy tool result to a message tool result.
@@ -1034,9 +1045,9 @@ func (a *sessionAgent) convertToToolResult(result fantasy.ToolResultContent) mes
 //
 //	BEFORE: [tool result: image data]
 //	AFTER:  [tool result: "Image loaded - see attached"], [user: image attachment]
-func (a *sessionAgent) workaroundProviderMediaLimitations(messages []fantasy.Message) []fantasy.Message {
-	providerSupportsMedia := a.largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderAnthropic) ||
-		a.largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderBedrock)
+func (a *sessionAgent) workaroundProviderMediaLimitations(messages []fantasy.Message, largeModel Model) []fantasy.Message {
+	providerSupportsMedia := largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderAnthropic) ||
+		largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderBedrock)
 
 	if providerSupportsMedia {
 		return messages

internal/agent/event.go 🔗

@@ -7,23 +7,23 @@ import (
 	"github.com/charmbracelet/crush/internal/event"
 )
 
-func (a sessionAgent) eventPromptSent(sessionID string) {
+func (a *sessionAgent) eventPromptSent(sessionID string) {
 	event.PromptSent(
-		a.eventCommon(sessionID, a.largeModel)...,
+		a.eventCommon(sessionID, a.largeModel.Get())...,
 	)
 }
 
-func (a sessionAgent) eventPromptResponded(sessionID string, duration time.Duration) {
+func (a *sessionAgent) eventPromptResponded(sessionID string, duration time.Duration) {
 	event.PromptResponded(
 		append(
-			a.eventCommon(sessionID, a.largeModel),
+			a.eventCommon(sessionID, a.largeModel.Get()),
 			"prompt duration pretty", duration.String(),
 			"prompt duration in seconds", int64(duration.Seconds()),
 		)...,
 	)
 }
 
-func (a sessionAgent) eventTokensUsed(sessionID string, model Model, usage fantasy.Usage, cost float64) {
+func (a *sessionAgent) eventTokensUsed(sessionID string, model Model, usage fantasy.Usage, cost float64) {
 	event.TokensUsed(
 		append(
 			a.eventCommon(sessionID, model),
@@ -37,7 +37,7 @@ func (a sessionAgent) eventTokensUsed(sessionID string, model Model, usage fanta
 	)
 }
 
-func (a sessionAgent) eventCommon(sessionID string, model Model) []any {
+func (a *sessionAgent) eventCommon(sessionID string, model Model) []any {
 	m := model.ModelCfg
 
 	return []any{

internal/csync/slices.go 🔗

@@ -2,7 +2,6 @@ package csync
 
 import (
 	"iter"
-	"slices"
 	"sync"
 )
 
@@ -63,24 +62,6 @@ func (s *Slice[T]) Append(items ...T) {
 	s.inner = append(s.inner, items...)
 }
 
-// Prepend adds an element to the beginning of the slice.
-func (s *Slice[T]) Prepend(item T) {
-	s.mu.Lock()
-	defer s.mu.Unlock()
-	s.inner = append([]T{item}, s.inner...)
-}
-
-// Delete removes the element at the specified index.
-func (s *Slice[T]) Delete(index int) bool {
-	s.mu.Lock()
-	defer s.mu.Unlock()
-	if index < 0 || index >= len(s.inner) {
-		return false
-	}
-	s.inner = slices.Delete(s.inner, index, index+1)
-	return true
-}
-
 // Get returns the element at the specified index.
 func (s *Slice[T]) Get(index int) (T, bool) {
 	s.mu.RLock()
@@ -92,17 +73,6 @@ func (s *Slice[T]) Get(index int) (T, bool) {
 	return s.inner[index], true
 }
 
-// Set updates the element at the specified index.
-func (s *Slice[T]) Set(index int, item T) bool {
-	s.mu.Lock()
-	defer s.mu.Unlock()
-	if index < 0 || index >= len(s.inner) {
-		return false
-	}
-	s.inner[index] = item
-	return true
-}
-
 // Len returns the number of elements in the slice.
 func (s *Slice[T]) Len() int {
 	s.mu.RLock()
@@ -131,10 +101,7 @@ func (s *Slice[T]) Seq() iter.Seq[T] {
 
 // Seq2 returns an iterator that yields index-value pairs from the slice.
 func (s *Slice[T]) Seq2() iter.Seq2[int, T] {
-	s.mu.RLock()
-	items := make([]T, len(s.inner))
-	copy(items, s.inner)
-	s.mu.RUnlock()
+	items := s.Copy()
 	return func(yield func(int, T) bool) {
 		for i, v := range items {
 			if !yield(i, v) {
@@ -143,3 +110,12 @@ func (s *Slice[T]) Seq2() iter.Seq2[int, T] {
 		}
 	}
 }
+
+// Copy returns a copy of the inner slice.
+func (s *Slice[T]) Copy() []T {
+	s.mu.RLock()
+	defer s.mu.RUnlock()
+	items := make([]T, len(s.inner))
+	copy(items, s.inner)
+	return items
+}

internal/csync/slices_test.go 🔗

@@ -109,44 +109,6 @@ func TestSlice(t *testing.T) {
 		require.Equal(t, "world", val)
 	})
 
-	t.Run("Prepend", func(t *testing.T) {
-		s := NewSlice[string]()
-		s.Append("world")
-		s.Prepend("hello")
-
-		require.Equal(t, 2, s.Len())
-		val, ok := s.Get(0)
-		require.True(t, ok)
-		require.Equal(t, "hello", val)
-
-		val, ok = s.Get(1)
-		require.True(t, ok)
-		require.Equal(t, "world", val)
-	})
-
-	t.Run("Delete", func(t *testing.T) {
-		s := NewSliceFrom([]int{1, 2, 3, 4, 5})
-
-		// Delete middle element
-		ok := s.Delete(2)
-		require.True(t, ok)
-		require.Equal(t, 4, s.Len())
-
-		expected := []int{1, 2, 4, 5}
-		actual := slices.Collect(s.Seq())
-		require.Equal(t, expected, actual)
-
-		// Delete out of bounds
-		ok = s.Delete(10)
-		require.False(t, ok)
-		require.Equal(t, 4, s.Len())
-
-		// Delete negative index
-		ok = s.Delete(-1)
-		require.False(t, ok)
-		require.Equal(t, 4, s.Len())
-	})
-
 	t.Run("Get", func(t *testing.T) {
 		s := NewSliceFrom([]string{"a", "b", "c"})
 
@@ -163,25 +125,6 @@ func TestSlice(t *testing.T) {
 		require.False(t, ok)
 	})
 
-	t.Run("Set", func(t *testing.T) {
-		s := NewSliceFrom([]string{"a", "b", "c"})
-
-		ok := s.Set(1, "modified")
-		require.True(t, ok)
-
-		val, ok := s.Get(1)
-		require.True(t, ok)
-		require.Equal(t, "modified", val)
-
-		// Out of bounds
-		ok = s.Set(10, "invalid")
-		require.False(t, ok)
-
-		// Negative index
-		ok = s.Set(-1, "invalid")
-		require.False(t, ok)
-	})
-
 	t.Run("SetSlice", func(t *testing.T) {
 		s := NewSlice[int]()
 		s.Append(1)

internal/csync/value.go 🔗

@@ -0,0 +1,44 @@
+package csync
+
+import (
+	"reflect"
+	"sync"
+)
+
+// Value is a generic thread-safe wrapper for any value type.
+//
+// For slices, use [Slice]. For maps, use [Map]. Pointers are not supported.
+type Value[T any] struct {
+	v  T
+	mu sync.RWMutex
+}
+
+// NewValue creates a new Value with the given initial value.
+//
+// Panics if t is a pointer, slice, or map. Use the dedicated types for those.
+func NewValue[T any](t T) *Value[T] {
+	v := reflect.ValueOf(t)
+	switch v.Kind() {
+	case reflect.Pointer:
+		panic("csync.Value does not support pointer types")
+	case reflect.Slice:
+		panic("csync.Value does not support slice types; use csync.Slice")
+	case reflect.Map:
+		panic("csync.Value does not support map types; use csync.Map")
+	}
+	return &Value[T]{v: t}
+}
+
+// Get returns the current value.
+func (v *Value[T]) Get() T {
+	v.mu.RLock()
+	defer v.mu.RUnlock()
+	return v.v
+}
+
+// Set updates the value.
+func (v *Value[T]) Set(t T) {
+	v.mu.Lock()
+	defer v.mu.Unlock()
+	v.v = t
+}

internal/csync/value_test.go 🔗

@@ -0,0 +1,99 @@
+package csync
+
+import (
+	"sync"
+	"testing"
+
+	"github.com/stretchr/testify/require"
+)
+
+func TestValue_GetSet(t *testing.T) {
+	t.Parallel()
+
+	v := NewValue(42)
+	require.Equal(t, 42, v.Get())
+
+	v.Set(100)
+	require.Equal(t, 100, v.Get())
+}
+
+func TestValue_ZeroValue(t *testing.T) {
+	t.Parallel()
+
+	v := NewValue("")
+	require.Equal(t, "", v.Get())
+
+	v.Set("hello")
+	require.Equal(t, "hello", v.Get())
+}
+
+func TestValue_Struct(t *testing.T) {
+	t.Parallel()
+
+	type config struct {
+		Name  string
+		Count int
+	}
+
+	v := NewValue(config{Name: "test", Count: 1})
+	require.Equal(t, config{Name: "test", Count: 1}, v.Get())
+
+	v.Set(config{Name: "updated", Count: 2})
+	require.Equal(t, config{Name: "updated", Count: 2}, v.Get())
+}
+
+func TestValue_PointerPanics(t *testing.T) {
+	t.Parallel()
+
+	require.Panics(t, func() {
+		NewValue(&struct{}{})
+	})
+}
+
+func TestValue_SlicePanics(t *testing.T) {
+	t.Parallel()
+
+	require.Panics(t, func() {
+		NewValue([]string{"a", "b"})
+	})
+}
+
+func TestValue_MapPanics(t *testing.T) {
+	t.Parallel()
+
+	require.Panics(t, func() {
+		NewValue(map[string]int{"a": 1})
+	})
+}
+
+func TestValue_ConcurrentAccess(t *testing.T) {
+	t.Parallel()
+
+	v := NewValue(0)
+	var wg sync.WaitGroup
+
+	// Concurrent writers.
+	for i := range 100 {
+		wg.Add(1)
+		go func(val int) {
+			defer wg.Done()
+			v.Set(val)
+		}(i)
+	}
+
+	// Concurrent readers.
+	for range 100 {
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+			_ = v.Get()
+		}()
+	}
+
+	wg.Wait()
+
+	// Value should be one of the set values (0-99).
+	got := v.Get()
+	require.GreaterOrEqual(t, got, 0)
+	require.Less(t, got, 100)
+}