@@ -87,12 +87,13 @@ type Model struct {
}
type sessionAgent struct {
- largeModel Model
- smallModel Model
- systemPromptPrefix string
- systemPrompt string
+ largeModel *csync.Value[Model]
+ smallModel *csync.Value[Model]
+ systemPromptPrefix *csync.Value[string]
+ systemPrompt *csync.Value[string]
+ tools *csync.Slice[fantasy.AgentTool]
+
isSubAgent bool
- tools []fantasy.AgentTool
sessions session.Service
messages message.Service
disableAutoSummarize bool
@@ -119,15 +120,15 @@ func NewSessionAgent(
opts SessionAgentOptions,
) SessionAgent {
return &sessionAgent{
- largeModel: opts.LargeModel,
- smallModel: opts.SmallModel,
- systemPromptPrefix: opts.SystemPromptPrefix,
- systemPrompt: opts.SystemPrompt,
+ largeModel: csync.NewValue(opts.LargeModel),
+ smallModel: csync.NewValue(opts.SmallModel),
+ systemPromptPrefix: csync.NewValue(opts.SystemPromptPrefix),
+ systemPrompt: csync.NewValue(opts.SystemPrompt),
isSubAgent: opts.IsSubAgent,
sessions: opts.Sessions,
messages: opts.Messages,
disableAutoSummarize: opts.DisableAutoSummarize,
- tools: opts.Tools,
+ tools: csync.NewSliceFrom(opts.Tools),
isYolo: opts.IsYolo,
messageQueue: csync.NewMap[string, []SessionAgentCall](),
activeRequests: csync.NewMap[string, context.CancelFunc](),
@@ -153,15 +154,21 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy
return nil, nil
}
- if len(a.tools) > 0 {
+ // Copy mutable fields under lock to avoid races with SetTools/SetModels.
+ agentTools := a.tools.Copy()
+ largeModel := a.largeModel.Get()
+ systemPrompt := a.systemPrompt.Get()
+ promptPrefix := a.systemPromptPrefix.Get()
+
+ if len(agentTools) > 0 {
// Add Anthropic caching to the last tool.
- a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
+ agentTools[len(agentTools)-1].SetProviderOptions(a.getCacheControlOptions())
}
agent := fantasy.NewAgent(
- a.largeModel.Model,
- fantasy.WithSystemPrompt(a.systemPrompt),
- fantasy.WithTools(a.tools...),
+ largeModel.Model,
+ fantasy.WithSystemPrompt(systemPrompt),
+ fantasy.WithTools(agentTools...),
)
sessionLock := sync.Mutex{}
@@ -234,7 +241,7 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy
prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
}
- prepared.Messages = a.workaroundProviderMediaLimitations(prepared.Messages)
+ prepared.Messages = a.workaroundProviderMediaLimitations(prepared.Messages, largeModel)
lastSystemRoleInx := 0
systemMessageUpdated := false
@@ -252,7 +259,7 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy
}
}
- if promptPrefix := a.promptPrefix(); promptPrefix != "" {
+ if promptPrefix != "" {
prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(promptPrefix)}, prepared.Messages...)
}
@@ -260,15 +267,15 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy
assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
Role: message.Assistant,
Parts: []message.ContentPart{},
- Model: a.largeModel.ModelCfg.Model,
- Provider: a.largeModel.ModelCfg.Provider,
+ Model: largeModel.ModelCfg.Model,
+ Provider: largeModel.ModelCfg.Provider,
})
if err != nil {
return callContext, prepared, err
}
callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
- callContext = context.WithValue(callContext, tools.SupportsImagesContextKey, a.largeModel.CatwalkCfg.SupportsImages)
- callContext = context.WithValue(callContext, tools.ModelNameContextKey, a.largeModel.CatwalkCfg.Name)
+ callContext = context.WithValue(callContext, tools.SupportsImagesContextKey, largeModel.CatwalkCfg.SupportsImages)
+ callContext = context.WithValue(callContext, tools.ModelNameContextKey, largeModel.CatwalkCfg.Name)
currentAssistant = &assistantMsg
return callContext, prepared, err
},
@@ -362,7 +369,7 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy
sessionLock.Unlock()
return getSessionErr
}
- a.updateSessionUsage(a.largeModel, &updatedSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
+ a.updateSessionUsage(largeModel, &updatedSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
_, sessionErr := a.sessions.Save(genCtx, updatedSession)
sessionLock.Unlock()
if sessionErr != nil {
@@ -372,7 +379,7 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy
},
StopWhen: []fantasy.StopCondition{
func(_ []fantasy.StepResult) bool {
- cw := int64(a.largeModel.CatwalkCfg.ContextWindow)
+ cw := int64(largeModel.CatwalkCfg.ContextWindow)
tokens := currentSession.CompletionTokens + currentSession.PromptTokens
remaining := cw - tokens
var threshold int64
@@ -474,7 +481,7 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy
currentAssistant.AddFinish(
message.FinishReasonError,
"Copilot model not enabled",
- fmt.Sprintf("%q is not enabled in Copilot. Go to the following page to enable it. Then, wait 5 minutes before trying again. %s", a.largeModel.CatwalkCfg.Name, link),
+ fmt.Sprintf("%q is not enabled in Copilot. Go to the following page to enable it. Then, wait 5 minutes before trying again. %s", largeModel.CatwalkCfg.Name, link),
)
} else {
currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message)
@@ -529,6 +536,10 @@ func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fan
return ErrSessionBusy
}
+ // Copy mutable fields under lock to avoid races with SetModels.
+ largeModel := a.largeModel.Get()
+ systemPromptPrefix := a.systemPromptPrefix.Get()
+
currentSession, err := a.sessions.Get(ctx, sessionID)
if err != nil {
return fmt.Errorf("failed to get session: %w", err)
@@ -549,13 +560,13 @@ func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fan
defer a.activeRequests.Del(sessionID)
defer cancel()
- agent := fantasy.NewAgent(a.largeModel.Model,
+ agent := fantasy.NewAgent(largeModel.Model,
fantasy.WithSystemPrompt(string(summaryPrompt)),
)
summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
Role: message.Assistant,
- Model: a.largeModel.Model.Model(),
- Provider: a.largeModel.Model.Provider(),
+ Model: largeModel.Model.Model(),
+ Provider: largeModel.Model.Provider(),
IsSummaryMessage: true,
})
if err != nil {
@@ -570,8 +581,8 @@ func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fan
ProviderOptions: opts,
PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
prepared.Messages = options.Messages
- if a.systemPromptPrefix != "" {
- prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
+ if systemPromptPrefix != "" {
+ prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(systemPromptPrefix)}, prepared.Messages...)
}
return callContext, prepared, nil
},
@@ -622,7 +633,7 @@ func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fan
}
}
- a.updateSessionUsage(a.largeModel, ¤tSession, resp.TotalUsage, openrouterCost)
+ a.updateSessionUsage(largeModel, ¤tSession, resp.TotalUsage, openrouterCost)
// Just in case, get just the last usage info.
usage := resp.Response.Usage
@@ -730,9 +741,13 @@ func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, user
return
}
+ smallModel := a.smallModel.Get()
+ largeModel := a.largeModel.Get()
+ systemPromptPrefix := a.systemPromptPrefix.Get()
+
var maxOutputTokens int64 = 40
- if a.smallModel.CatwalkCfg.CanReason {
- maxOutputTokens = a.smallModel.CatwalkCfg.DefaultMaxTokens
+ if smallModel.CatwalkCfg.CanReason {
+ maxOutputTokens = smallModel.CatwalkCfg.DefaultMaxTokens
}
newAgent := func(m fantasy.LanguageModel, p []byte, tok int64) fantasy.Agent {
@@ -746,9 +761,9 @@ func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, user
Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", userPrompt),
PrepareStep: func(callCtx context.Context, opts fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
prepared.Messages = opts.Messages
- if a.systemPromptPrefix != "" {
+ if systemPromptPrefix != "" {
prepared.Messages = append([]fantasy.Message{
- fantasy.NewSystemMessage(a.systemPromptPrefix),
+ fantasy.NewSystemMessage(systemPromptPrefix),
}, prepared.Messages...)
}
return callCtx, prepared, nil
@@ -756,7 +771,7 @@ func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, user
}
// Use the small model to generate the title.
- model := &a.smallModel
+ model := smallModel
agent := newAgent(model.Model, titlePrompt, maxOutputTokens)
resp, err := agent.Stream(ctx, streamCall)
if err == nil {
@@ -765,7 +780,7 @@ func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, user
} else {
// It didn't work. Let's try with the big model.
slog.Error("error generating title with small model; trying big model", "err", err)
- model = &a.largeModel
+ model = largeModel
agent = newAgent(model.Model, titlePrompt, maxOutputTokens)
resp, err = agent.Stream(ctx, streamCall)
if err == nil {
@@ -960,24 +975,20 @@ func (a *sessionAgent) QueuedPromptsList(sessionID string) []string {
}
func (a *sessionAgent) SetModels(large Model, small Model) {
- a.largeModel = large
- a.smallModel = small
+ a.largeModel.Set(large)
+ a.smallModel.Set(small)
}
func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
- a.tools = tools
+ a.tools.SetSlice(tools)
}
func (a *sessionAgent) SetSystemPrompt(systemPrompt string) {
- a.systemPrompt = systemPrompt
+ a.systemPrompt.Set(systemPrompt)
}
func (a *sessionAgent) Model() Model {
- return a.largeModel
-}
-
-func (a *sessionAgent) promptPrefix() string {
- return a.systemPromptPrefix
+ return a.largeModel.Get()
}
// convertToToolResult converts a fantasy tool result to a message tool result.
@@ -1034,9 +1045,9 @@ func (a *sessionAgent) convertToToolResult(result fantasy.ToolResultContent) mes
//
// BEFORE: [tool result: image data]
// AFTER: [tool result: "Image loaded - see attached"], [user: image attachment]
-func (a *sessionAgent) workaroundProviderMediaLimitations(messages []fantasy.Message) []fantasy.Message {
- providerSupportsMedia := a.largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderAnthropic) ||
- a.largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderBedrock)
+func (a *sessionAgent) workaroundProviderMediaLimitations(messages []fantasy.Message, largeModel Model) []fantasy.Message {
+ providerSupportsMedia := largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderAnthropic) ||
+ largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderBedrock)
if providerSupportsMedia {
return messages
@@ -109,44 +109,6 @@ func TestSlice(t *testing.T) {
require.Equal(t, "world", val)
})
- t.Run("Prepend", func(t *testing.T) {
- s := NewSlice[string]()
- s.Append("world")
- s.Prepend("hello")
-
- require.Equal(t, 2, s.Len())
- val, ok := s.Get(0)
- require.True(t, ok)
- require.Equal(t, "hello", val)
-
- val, ok = s.Get(1)
- require.True(t, ok)
- require.Equal(t, "world", val)
- })
-
- t.Run("Delete", func(t *testing.T) {
- s := NewSliceFrom([]int{1, 2, 3, 4, 5})
-
- // Delete middle element
- ok := s.Delete(2)
- require.True(t, ok)
- require.Equal(t, 4, s.Len())
-
- expected := []int{1, 2, 4, 5}
- actual := slices.Collect(s.Seq())
- require.Equal(t, expected, actual)
-
- // Delete out of bounds
- ok = s.Delete(10)
- require.False(t, ok)
- require.Equal(t, 4, s.Len())
-
- // Delete negative index
- ok = s.Delete(-1)
- require.False(t, ok)
- require.Equal(t, 4, s.Len())
- })
-
t.Run("Get", func(t *testing.T) {
s := NewSliceFrom([]string{"a", "b", "c"})
@@ -163,25 +125,6 @@ func TestSlice(t *testing.T) {
require.False(t, ok)
})
- t.Run("Set", func(t *testing.T) {
- s := NewSliceFrom([]string{"a", "b", "c"})
-
- ok := s.Set(1, "modified")
- require.True(t, ok)
-
- val, ok := s.Get(1)
- require.True(t, ok)
- require.Equal(t, "modified", val)
-
- // Out of bounds
- ok = s.Set(10, "invalid")
- require.False(t, ok)
-
- // Negative index
- ok = s.Set(-1, "invalid")
- require.False(t, ok)
- })
-
t.Run("SetSlice", func(t *testing.T) {
s := NewSlice[int]()
s.Append(1)