Detailed changes
@@ -157,9 +157,12 @@ type Options struct {
}
type PreferredModel struct {
- ModelID string `json:"model_id"`
- Provider provider.InferenceProvider `json:"provider"`
- ReasoningEffort string `json:"reasoning_effort,omitempty"`
+ ModelID string `json:"model_id"`
+ Provider provider.InferenceProvider `json:"provider"`
+ // Overrides the default reasoning effort for this model
+ ReasoningEffort string `json:"reasoning_effort,omitempty"`
+ // Overrides the default max tokens for this model
+ MaxTokens int64 `json:"max_tokens,omitempty"`
}
type PreferredModels struct {
@@ -147,7 +147,6 @@ func NewAgent(
opts := []provider.ProviderClientOption{
provider.WithModel(agentCfg.Model),
provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID)),
- provider.WithMaxTokens(model.DefaultMaxTokens),
}
agentProvider, err := provider.NewProviderV2(providerCfg, opts...)
if err != nil {
@@ -184,7 +183,6 @@ func NewAgent(
titleOpts := []provider.ProviderClientOption{
provider.WithModel(config.SmallModel),
provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
- provider.WithMaxTokens(40),
}
titleProvider, err := provider.NewProviderV2(smallModelProviderCfg, titleOpts...)
if err != nil {
@@ -193,7 +191,6 @@ func NewAgent(
summarizeOpts := []provider.ProviderClientOption{
provider.WithModel(config.SmallModel),
provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
- provider.WithMaxTokens(smallModel.DefaultMaxTokens),
}
summarizeProvider, err := provider.NewProviderV2(smallModelProviderCfg, summarizeOpts...)
if err != nil {
@@ -832,7 +829,6 @@ func (a *agent) UpdateModel() error {
opts := []provider.ProviderClientOption{
provider.WithModel(a.agentCfg.Model),
provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID)),
- provider.WithMaxTokens(model.DefaultMaxTokens),
}
newProvider, err := provider.NewProviderV2(currentProviderCfg, opts...)
@@ -877,7 +873,6 @@ func (a *agent) UpdateModel() error {
titleOpts := []provider.ProviderClientOption{
provider.WithModel(config.SmallModel),
provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
- provider.WithMaxTokens(40),
}
newTitleProvider, err := provider.NewProviderV2(smallModelProviderCfg, titleOpts...)
if err != nil {
@@ -888,7 +883,6 @@ func (a *agent) UpdateModel() error {
summarizeOpts := []provider.ProviderClientOption{
provider.WithModel(config.SmallModel),
provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
- provider.WithMaxTokens(smallModel.DefaultMaxTokens),
}
newSummarizeProvider, err := provider.NewProviderV2(smallModelProviderCfg, summarizeOpts...)
if err != nil {
@@ -164,9 +164,19 @@ func (a *anthropicClient) preparedMessages(messages []anthropic.MessageParam, to
// }
// }
+ cfg := config.Get()
+ modelConfig := cfg.Models.Large
+ if a.providerOptions.modelType == config.SmallModel {
+ modelConfig = cfg.Models.Small
+ }
+ maxTokens := model.DefaultMaxTokens
+ if modelConfig.MaxTokens > 0 {
+ maxTokens = modelConfig.MaxTokens
+ }
+
return anthropic.MessageNewParams{
Model: anthropic.Model(model.ID),
- MaxTokens: a.providerOptions.maxTokens,
+ MaxTokens: maxTokens,
Temperature: temperature,
Messages: messages,
Tools: tools,
@@ -155,17 +155,26 @@ func (g *geminiClient) finishReason(reason genai.FinishReason) message.FinishRea
func (g *geminiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
// Convert messages
geminiMessages := g.convertMessages(messages)
-
+ model := g.providerOptions.model(g.providerOptions.modelType)
cfg := config.Get()
if cfg.Options.Debug {
jsonData, _ := json.Marshal(geminiMessages)
logging.Debug("Prepared messages", "messages", string(jsonData))
}
+ modelConfig := cfg.Models.Large
+ if g.providerOptions.modelType == config.SmallModel {
+ modelConfig = cfg.Models.Small
+ }
+
+ maxTokens := model.DefaultMaxTokens
+ if modelConfig.MaxTokens > 0 {
+ maxTokens = modelConfig.MaxTokens
+ }
history := geminiMessages[:len(geminiMessages)-1] // All but last message
lastMsg := geminiMessages[len(geminiMessages)-1]
config := &genai.GenerateContentConfig{
- MaxOutputTokens: int32(g.providerOptions.maxTokens),
+ MaxOutputTokens: int32(maxTokens),
SystemInstruction: &genai.Content{
Parts: []*genai.Part{{Text: g.providerOptions.systemMessage}},
},
@@ -173,7 +182,6 @@ func (g *geminiClient) send(ctx context.Context, messages []message.Message, too
if len(tools) > 0 {
config.Tools = g.convertTools(tools)
}
- model := g.providerOptions.model(g.providerOptions.modelType)
chat, _ := g.client.Chats.Create(ctx, model.ID, config, history)
attempts := 0
@@ -245,16 +253,25 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t
// Convert messages
geminiMessages := g.convertMessages(messages)
+ model := g.providerOptions.model(g.providerOptions.modelType)
cfg := config.Get()
if cfg.Options.Debug {
jsonData, _ := json.Marshal(geminiMessages)
logging.Debug("Prepared messages", "messages", string(jsonData))
}
+ modelConfig := cfg.Models.Large
+ if g.providerOptions.modelType == config.SmallModel {
+ modelConfig = cfg.Models.Small
+ }
+ maxTokens := model.DefaultMaxTokens
+ if modelConfig.MaxTokens > 0 {
+ maxTokens = modelConfig.MaxTokens
+ }
history := geminiMessages[:len(geminiMessages)-1] // All but last message
lastMsg := geminiMessages[len(geminiMessages)-1]
config := &genai.GenerateContentConfig{
- MaxOutputTokens: int32(g.providerOptions.maxTokens),
+ MaxOutputTokens: int32(maxTokens),
SystemInstruction: &genai.Content{
Parts: []*genai.Part{{Text: g.providerOptions.systemMessage}},
},
@@ -262,7 +279,6 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t
if len(tools) > 0 {
config.Tools = g.convertTools(tools)
}
- model := g.providerOptions.model(g.providerOptions.modelType)
chat, _ := g.client.Chats.Create(ctx, model.ID, config, history)
attempts := 0
@@ -160,8 +160,13 @@ func (o *openaiClient) preparedParams(messages []openai.ChatCompletionMessagePar
Messages: messages,
Tools: tools,
}
+
+ maxTokens := model.DefaultMaxTokens
+ if modelConfig.MaxTokens > 0 {
+ maxTokens = modelConfig.MaxTokens
+ }
if model.CanReason {
- params.MaxCompletionTokens = openai.Int(o.providerOptions.maxTokens)
+ params.MaxCompletionTokens = openai.Int(maxTokens)
switch reasoningEffort {
case "low":
params.ReasoningEffort = shared.ReasoningEffortLow
@@ -173,7 +178,7 @@ func (o *openaiClient) preparedParams(messages []openai.ChatCompletionMessagePar
params.ReasoningEffort = shared.ReasoningEffortMedium
}
} else {
- params.MaxTokens = openai.Int(o.providerOptions.maxTokens)
+ params.MaxTokens = openai.Int(maxTokens)
}
return params
@@ -64,7 +64,6 @@ type providerClientOptions struct {
modelType config.ModelType
model func(config.ModelType) config.Model
disableCache bool
- maxTokens int64
systemMessage string
extraHeaders map[string]string
extraParams map[string]string
@@ -121,12 +120,6 @@ func WithDisableCache(disableCache bool) ProviderClientOption {
}
}
-func WithMaxTokens(maxTokens int64) ProviderClientOption {
- return func(options *providerClientOptions) {
- options.maxTokens = maxTokens
- }
-}
-
func WithSystemMessage(systemMessage string) ProviderClientOption {
return func(options *providerClientOptions) {
options.systemMessage = systemMessage