diff --git a/internal/agent/agent.go b/internal/agent/agent.go index 198159d53adbcbba8f8598bf24a8eef55825acfc..c0b9080bb640085c6fd0fdbde8db0fbfe7f476dd 100644 --- a/internal/agent/agent.go +++ b/internal/agent/agent.go @@ -87,12 +87,13 @@ type Model struct { } type sessionAgent struct { - largeModel Model - smallModel Model - systemPromptPrefix string - systemPrompt string + largeModel *csync.Value[Model] + smallModel *csync.Value[Model] + systemPromptPrefix *csync.Value[string] + systemPrompt *csync.Value[string] + tools *csync.Slice[fantasy.AgentTool] + isSubAgent bool - tools []fantasy.AgentTool sessions session.Service messages message.Service disableAutoSummarize bool @@ -119,15 +120,15 @@ func NewSessionAgent( opts SessionAgentOptions, ) SessionAgent { return &sessionAgent{ - largeModel: opts.LargeModel, - smallModel: opts.SmallModel, - systemPromptPrefix: opts.SystemPromptPrefix, - systemPrompt: opts.SystemPrompt, + largeModel: csync.NewValue(opts.LargeModel), + smallModel: csync.NewValue(opts.SmallModel), + systemPromptPrefix: csync.NewValue(opts.SystemPromptPrefix), + systemPrompt: csync.NewValue(opts.SystemPrompt), isSubAgent: opts.IsSubAgent, sessions: opts.Sessions, messages: opts.Messages, disableAutoSummarize: opts.DisableAutoSummarize, - tools: opts.Tools, + tools: csync.NewSliceFrom(opts.Tools), isYolo: opts.IsYolo, messageQueue: csync.NewMap[string, []SessionAgentCall](), activeRequests: csync.NewMap[string, context.CancelFunc](), @@ -153,15 +154,21 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy return nil, nil } - if len(a.tools) > 0 { + // Copy mutable fields under lock to avoid races with SetTools/SetModels. + agentTools := a.tools.Copy() + largeModel := a.largeModel.Get() + systemPrompt := a.systemPrompt.Get() + promptPrefix := a.systemPromptPrefix.Get() + + if len(agentTools) > 0 { // Add Anthropic caching to the last tool. - a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions()) + agentTools[len(agentTools)-1].SetProviderOptions(a.getCacheControlOptions()) } agent := fantasy.NewAgent( - a.largeModel.Model, - fantasy.WithSystemPrompt(a.systemPrompt), - fantasy.WithTools(a.tools...), + largeModel.Model, + fantasy.WithSystemPrompt(systemPrompt), + fantasy.WithTools(agentTools...), ) sessionLock := sync.Mutex{} @@ -234,7 +241,7 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...) } - prepared.Messages = a.workaroundProviderMediaLimitations(prepared.Messages) + prepared.Messages = a.workaroundProviderMediaLimitations(prepared.Messages, largeModel) lastSystemRoleInx := 0 systemMessageUpdated := false @@ -252,7 +259,7 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy } } - if promptPrefix := a.promptPrefix(); promptPrefix != "" { + if promptPrefix != "" { prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(promptPrefix)}, prepared.Messages...) } @@ -260,15 +267,15 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{ Role: message.Assistant, Parts: []message.ContentPart{}, - Model: a.largeModel.ModelCfg.Model, - Provider: a.largeModel.ModelCfg.Provider, + Model: largeModel.ModelCfg.Model, + Provider: largeModel.ModelCfg.Provider, }) if err != nil { return callContext, prepared, err } callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID) - callContext = context.WithValue(callContext, tools.SupportsImagesContextKey, a.largeModel.CatwalkCfg.SupportsImages) - callContext = context.WithValue(callContext, tools.ModelNameContextKey, a.largeModel.CatwalkCfg.Name) + callContext = context.WithValue(callContext, tools.SupportsImagesContextKey, largeModel.CatwalkCfg.SupportsImages) + callContext = context.WithValue(callContext, tools.ModelNameContextKey, largeModel.CatwalkCfg.Name) currentAssistant = &assistantMsg return callContext, prepared, err }, @@ -362,7 +369,7 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy sessionLock.Unlock() return getSessionErr } - a.updateSessionUsage(a.largeModel, &updatedSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata)) + a.updateSessionUsage(largeModel, &updatedSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata)) _, sessionErr := a.sessions.Save(genCtx, updatedSession) sessionLock.Unlock() if sessionErr != nil { @@ -372,7 +379,7 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy }, StopWhen: []fantasy.StopCondition{ func(_ []fantasy.StepResult) bool { - cw := int64(a.largeModel.CatwalkCfg.ContextWindow) + cw := int64(largeModel.CatwalkCfg.ContextWindow) tokens := currentSession.CompletionTokens + currentSession.PromptTokens remaining := cw - tokens var threshold int64 @@ -474,7 +481,7 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy currentAssistant.AddFinish( message.FinishReasonError, "Copilot model not enabled", - fmt.Sprintf("%q is not enabled in Copilot. Go to the following page to enable it. Then, wait 5 minutes before trying again. %s", a.largeModel.CatwalkCfg.Name, link), + fmt.Sprintf("%q is not enabled in Copilot. Go to the following page to enable it. Then, wait 5 minutes before trying again. %s", largeModel.CatwalkCfg.Name, link), ) } else { currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message) @@ -529,6 +536,10 @@ func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fan return ErrSessionBusy } + // Copy mutable fields under lock to avoid races with SetModels. + largeModel := a.largeModel.Get() + systemPromptPrefix := a.systemPromptPrefix.Get() + currentSession, err := a.sessions.Get(ctx, sessionID) if err != nil { return fmt.Errorf("failed to get session: %w", err) @@ -549,13 +560,13 @@ func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fan defer a.activeRequests.Del(sessionID) defer cancel() - agent := fantasy.NewAgent(a.largeModel.Model, + agent := fantasy.NewAgent(largeModel.Model, fantasy.WithSystemPrompt(string(summaryPrompt)), ) summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{ Role: message.Assistant, - Model: a.largeModel.Model.Model(), - Provider: a.largeModel.Model.Provider(), + Model: largeModel.Model.Model(), + Provider: largeModel.Model.Provider(), IsSummaryMessage: true, }) if err != nil { @@ -570,8 +581,8 @@ func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fan ProviderOptions: opts, PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) { prepared.Messages = options.Messages - if a.systemPromptPrefix != "" { - prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...) + if systemPromptPrefix != "" { + prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(systemPromptPrefix)}, prepared.Messages...) } return callContext, prepared, nil }, @@ -622,7 +633,7 @@ func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fan } } - a.updateSessionUsage(a.largeModel, ¤tSession, resp.TotalUsage, openrouterCost) + a.updateSessionUsage(largeModel, ¤tSession, resp.TotalUsage, openrouterCost) // Just in case, get just the last usage info. usage := resp.Response.Usage @@ -730,9 +741,13 @@ func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, user return } + smallModel := a.smallModel.Get() + largeModel := a.largeModel.Get() + systemPromptPrefix := a.systemPromptPrefix.Get() + var maxOutputTokens int64 = 40 - if a.smallModel.CatwalkCfg.CanReason { - maxOutputTokens = a.smallModel.CatwalkCfg.DefaultMaxTokens + if smallModel.CatwalkCfg.CanReason { + maxOutputTokens = smallModel.CatwalkCfg.DefaultMaxTokens } newAgent := func(m fantasy.LanguageModel, p []byte, tok int64) fantasy.Agent { @@ -746,9 +761,9 @@ func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, user Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n \n\n", userPrompt), PrepareStep: func(callCtx context.Context, opts fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) { prepared.Messages = opts.Messages - if a.systemPromptPrefix != "" { + if systemPromptPrefix != "" { prepared.Messages = append([]fantasy.Message{ - fantasy.NewSystemMessage(a.systemPromptPrefix), + fantasy.NewSystemMessage(systemPromptPrefix), }, prepared.Messages...) } return callCtx, prepared, nil @@ -756,7 +771,7 @@ func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, user } // Use the small model to generate the title. - model := &a.smallModel + model := smallModel agent := newAgent(model.Model, titlePrompt, maxOutputTokens) resp, err := agent.Stream(ctx, streamCall) if err == nil { @@ -765,7 +780,7 @@ func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, user } else { // It didn't work. Let's try with the big model. slog.Error("error generating title with small model; trying big model", "err", err) - model = &a.largeModel + model = largeModel agent = newAgent(model.Model, titlePrompt, maxOutputTokens) resp, err = agent.Stream(ctx, streamCall) if err == nil { @@ -960,24 +975,20 @@ func (a *sessionAgent) QueuedPromptsList(sessionID string) []string { } func (a *sessionAgent) SetModels(large Model, small Model) { - a.largeModel = large - a.smallModel = small + a.largeModel.Set(large) + a.smallModel.Set(small) } func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) { - a.tools = tools + a.tools.SetSlice(tools) } func (a *sessionAgent) SetSystemPrompt(systemPrompt string) { - a.systemPrompt = systemPrompt + a.systemPrompt.Set(systemPrompt) } func (a *sessionAgent) Model() Model { - return a.largeModel -} - -func (a *sessionAgent) promptPrefix() string { - return a.systemPromptPrefix + return a.largeModel.Get() } // convertToToolResult converts a fantasy tool result to a message tool result. @@ -1034,9 +1045,9 @@ func (a *sessionAgent) convertToToolResult(result fantasy.ToolResultContent) mes // // BEFORE: [tool result: image data] // AFTER: [tool result: "Image loaded - see attached"], [user: image attachment] -func (a *sessionAgent) workaroundProviderMediaLimitations(messages []fantasy.Message) []fantasy.Message { - providerSupportsMedia := a.largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderAnthropic) || - a.largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderBedrock) +func (a *sessionAgent) workaroundProviderMediaLimitations(messages []fantasy.Message, largeModel Model) []fantasy.Message { + providerSupportsMedia := largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderAnthropic) || + largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderBedrock) if providerSupportsMedia { return messages diff --git a/internal/agent/event.go b/internal/agent/event.go index bf36ec84bf4270bd2e63ae0efae0440474288565..3f6c640f6a983c515034e0698676632d0cb57824 100644 --- a/internal/agent/event.go +++ b/internal/agent/event.go @@ -7,23 +7,23 @@ import ( "github.com/charmbracelet/crush/internal/event" ) -func (a sessionAgent) eventPromptSent(sessionID string) { +func (a *sessionAgent) eventPromptSent(sessionID string) { event.PromptSent( - a.eventCommon(sessionID, a.largeModel)..., + a.eventCommon(sessionID, a.largeModel.Get())..., ) } -func (a sessionAgent) eventPromptResponded(sessionID string, duration time.Duration) { +func (a *sessionAgent) eventPromptResponded(sessionID string, duration time.Duration) { event.PromptResponded( append( - a.eventCommon(sessionID, a.largeModel), + a.eventCommon(sessionID, a.largeModel.Get()), "prompt duration pretty", duration.String(), "prompt duration in seconds", int64(duration.Seconds()), )..., ) } -func (a sessionAgent) eventTokensUsed(sessionID string, model Model, usage fantasy.Usage, cost float64) { +func (a *sessionAgent) eventTokensUsed(sessionID string, model Model, usage fantasy.Usage, cost float64) { event.TokensUsed( append( a.eventCommon(sessionID, model), @@ -37,7 +37,7 @@ func (a sessionAgent) eventTokensUsed(sessionID string, model Model, usage fanta ) } -func (a sessionAgent) eventCommon(sessionID string, model Model) []any { +func (a *sessionAgent) eventCommon(sessionID string, model Model) []any { m := model.ModelCfg return []any{ diff --git a/internal/csync/slices.go b/internal/csync/slices.go index c5c635683e70046694f1cdf647aac8cb425abd24..fcce9881b6e27021adcc9462b123f49d469dcd9f 100644 --- a/internal/csync/slices.go +++ b/internal/csync/slices.go @@ -2,7 +2,6 @@ package csync import ( "iter" - "slices" "sync" ) @@ -63,24 +62,6 @@ func (s *Slice[T]) Append(items ...T) { s.inner = append(s.inner, items...) } -// Prepend adds an element to the beginning of the slice. -func (s *Slice[T]) Prepend(item T) { - s.mu.Lock() - defer s.mu.Unlock() - s.inner = append([]T{item}, s.inner...) -} - -// Delete removes the element at the specified index. -func (s *Slice[T]) Delete(index int) bool { - s.mu.Lock() - defer s.mu.Unlock() - if index < 0 || index >= len(s.inner) { - return false - } - s.inner = slices.Delete(s.inner, index, index+1) - return true -} - // Get returns the element at the specified index. func (s *Slice[T]) Get(index int) (T, bool) { s.mu.RLock() @@ -92,17 +73,6 @@ func (s *Slice[T]) Get(index int) (T, bool) { return s.inner[index], true } -// Set updates the element at the specified index. -func (s *Slice[T]) Set(index int, item T) bool { - s.mu.Lock() - defer s.mu.Unlock() - if index < 0 || index >= len(s.inner) { - return false - } - s.inner[index] = item - return true -} - // Len returns the number of elements in the slice. func (s *Slice[T]) Len() int { s.mu.RLock() @@ -131,10 +101,7 @@ func (s *Slice[T]) Seq() iter.Seq[T] { // Seq2 returns an iterator that yields index-value pairs from the slice. func (s *Slice[T]) Seq2() iter.Seq2[int, T] { - s.mu.RLock() - items := make([]T, len(s.inner)) - copy(items, s.inner) - s.mu.RUnlock() + items := s.Copy() return func(yield func(int, T) bool) { for i, v := range items { if !yield(i, v) { @@ -143,3 +110,12 @@ func (s *Slice[T]) Seq2() iter.Seq2[int, T] { } } } + +// Copy returns a copy of the inner slice. +func (s *Slice[T]) Copy() []T { + s.mu.RLock() + defer s.mu.RUnlock() + items := make([]T, len(s.inner)) + copy(items, s.inner) + return items +} diff --git a/internal/csync/slices_test.go b/internal/csync/slices_test.go index 85aedbaba40103ff9a8979e5c70299223f74591f..c7946ac6f1a84614def05b7b6e7e9b0ed11b3a73 100644 --- a/internal/csync/slices_test.go +++ b/internal/csync/slices_test.go @@ -109,44 +109,6 @@ func TestSlice(t *testing.T) { require.Equal(t, "world", val) }) - t.Run("Prepend", func(t *testing.T) { - s := NewSlice[string]() - s.Append("world") - s.Prepend("hello") - - require.Equal(t, 2, s.Len()) - val, ok := s.Get(0) - require.True(t, ok) - require.Equal(t, "hello", val) - - val, ok = s.Get(1) - require.True(t, ok) - require.Equal(t, "world", val) - }) - - t.Run("Delete", func(t *testing.T) { - s := NewSliceFrom([]int{1, 2, 3, 4, 5}) - - // Delete middle element - ok := s.Delete(2) - require.True(t, ok) - require.Equal(t, 4, s.Len()) - - expected := []int{1, 2, 4, 5} - actual := slices.Collect(s.Seq()) - require.Equal(t, expected, actual) - - // Delete out of bounds - ok = s.Delete(10) - require.False(t, ok) - require.Equal(t, 4, s.Len()) - - // Delete negative index - ok = s.Delete(-1) - require.False(t, ok) - require.Equal(t, 4, s.Len()) - }) - t.Run("Get", func(t *testing.T) { s := NewSliceFrom([]string{"a", "b", "c"}) @@ -163,25 +125,6 @@ func TestSlice(t *testing.T) { require.False(t, ok) }) - t.Run("Set", func(t *testing.T) { - s := NewSliceFrom([]string{"a", "b", "c"}) - - ok := s.Set(1, "modified") - require.True(t, ok) - - val, ok := s.Get(1) - require.True(t, ok) - require.Equal(t, "modified", val) - - // Out of bounds - ok = s.Set(10, "invalid") - require.False(t, ok) - - // Negative index - ok = s.Set(-1, "invalid") - require.False(t, ok) - }) - t.Run("SetSlice", func(t *testing.T) { s := NewSlice[int]() s.Append(1) diff --git a/internal/csync/value.go b/internal/csync/value.go new file mode 100644 index 0000000000000000000000000000000000000000..17528a281e0d34d49b206a7c3901b892370c18ba --- /dev/null +++ b/internal/csync/value.go @@ -0,0 +1,44 @@ +package csync + +import ( + "reflect" + "sync" +) + +// Value is a generic thread-safe wrapper for any value type. +// +// For slices, use [Slice]. For maps, use [Map]. Pointers are not supported. +type Value[T any] struct { + v T + mu sync.RWMutex +} + +// NewValue creates a new Value with the given initial value. +// +// Panics if t is a pointer, slice, or map. Use the dedicated types for those. +func NewValue[T any](t T) *Value[T] { + v := reflect.ValueOf(t) + switch v.Kind() { + case reflect.Pointer: + panic("csync.Value does not support pointer types") + case reflect.Slice: + panic("csync.Value does not support slice types; use csync.Slice") + case reflect.Map: + panic("csync.Value does not support map types; use csync.Map") + } + return &Value[T]{v: t} +} + +// Get returns the current value. +func (v *Value[T]) Get() T { + v.mu.RLock() + defer v.mu.RUnlock() + return v.v +} + +// Set updates the value. +func (v *Value[T]) Set(t T) { + v.mu.Lock() + defer v.mu.Unlock() + v.v = t +} diff --git a/internal/csync/value_test.go b/internal/csync/value_test.go new file mode 100644 index 0000000000000000000000000000000000000000..3fa41d85144ea9373c7d440238c0321f52286330 --- /dev/null +++ b/internal/csync/value_test.go @@ -0,0 +1,99 @@ +package csync + +import ( + "sync" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestValue_GetSet(t *testing.T) { + t.Parallel() + + v := NewValue(42) + require.Equal(t, 42, v.Get()) + + v.Set(100) + require.Equal(t, 100, v.Get()) +} + +func TestValue_ZeroValue(t *testing.T) { + t.Parallel() + + v := NewValue("") + require.Equal(t, "", v.Get()) + + v.Set("hello") + require.Equal(t, "hello", v.Get()) +} + +func TestValue_Struct(t *testing.T) { + t.Parallel() + + type config struct { + Name string + Count int + } + + v := NewValue(config{Name: "test", Count: 1}) + require.Equal(t, config{Name: "test", Count: 1}, v.Get()) + + v.Set(config{Name: "updated", Count: 2}) + require.Equal(t, config{Name: "updated", Count: 2}, v.Get()) +} + +func TestValue_PointerPanics(t *testing.T) { + t.Parallel() + + require.Panics(t, func() { + NewValue(&struct{}{}) + }) +} + +func TestValue_SlicePanics(t *testing.T) { + t.Parallel() + + require.Panics(t, func() { + NewValue([]string{"a", "b"}) + }) +} + +func TestValue_MapPanics(t *testing.T) { + t.Parallel() + + require.Panics(t, func() { + NewValue(map[string]int{"a": 1}) + }) +} + +func TestValue_ConcurrentAccess(t *testing.T) { + t.Parallel() + + v := NewValue(0) + var wg sync.WaitGroup + + // Concurrent writers. + for i := range 100 { + wg.Add(1) + go func(val int) { + defer wg.Done() + v.Set(val) + }(i) + } + + // Concurrent readers. + for range 100 { + wg.Add(1) + go func() { + defer wg.Done() + _ = v.Get() + }() + } + + wg.Wait() + + // Value should be one of the set values (0-99). + got := v.Get() + require.GreaterOrEqual(t, got, 0) + require.Less(t, got, 100) +}