Detailed changes
@@ -65,6 +65,7 @@ type Model struct {
DefaultMaxTokens int64 `json:"default_max_tokens"`
CanReason bool `json:"can_reason"`
ReasoningEffort string `json:"reasoning_effort"`
+ HasReasoningEffort bool `json:"has_reasoning_effort"`
SupportsImages bool `json:"supports_attachments"`
}
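
To make the new flag concrete, here is an illustrative Go literal (values borrowed from the o1-preview mock entry added later in this changeset): `CanReason` says the model reasons at all, while `HasReasoningEffort` says the API additionally accepts an effort knob, with `ReasoningEffort` holding the selected level.

```go
// Illustrative only — values mirror the o1-preview mock entry added below.
// CanReason marks a reasoning model; HasReasoningEffort marks that the API
// also accepts a low/medium/high effort setting.
m := Model{
	ID:                 "o1-preview",
	CanReason:          true,
	HasReasoningEffort: true,
	ReasoningEffort:    "medium", // the effort currently in effect
}
```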
@@ -156,8 +157,9 @@ type Options struct {
}
type PreferredModel struct {
- ModelID  string                     `json:"model_id"`
- Provider provider.InferenceProvider `json:"provider"`
+ ModelID         string                     `json:"model_id"`
+ Provider        provider.InferenceProvider `json:"provider"`
+ ReasoningEffort string                     `json:"reasoning_effort,omitempty"`
}
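
Because the new field is tagged `omitempty`, existing preferred-model entries keep serializing exactly as before; the field only appears in JSON when set. A hypothetical entry with an explicit override (mirroring the test added below):

```go
pm := PreferredModel{
	ModelID:         "o1-preview",
	Provider:        provider.InferenceProviderOpenAI,
	ReasoningEffort: "high", // omitted from JSON entirely when left empty
}
```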
type PreferredModels struct {
@@ -693,7 +695,7 @@ func defaultConfigBasedOnEnv() *Config {
}
providerConfig.BaseURL = baseURL
for _, model := range p.Models {
- providerConfig.Models = append(providerConfig.Models, Model{
+ configModel := Model{
ID: model.ID,
Name: model.Name,
CostPer1MIn: model.CostPer1MIn,
@@ -704,7 +706,13 @@ func defaultConfigBasedOnEnv() *Config {
DefaultMaxTokens: model.DefaultMaxTokens,
CanReason: model.CanReason,
SupportsImages: model.SupportsImages,
- })
+ }
+ // Set reasoning effort for reasoning models
+ if model.HasReasoningEffort && model.DefaultReasoningEffort != "" {
+ configModel.HasReasoningEffort = model.HasReasoningEffort
+ configModel.ReasoningEffort = model.DefaultReasoningEffort
+ }
+ providerConfig.Models = append(providerConfig.Models, configModel)
}
cfg.Providers[p.ID] = providerConfig
}
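
The guard means an effort value only lands in the config when the provider metadata both flags support and supplies a non-empty default. Distilled into a hypothetical helper (not part of the diff):

```go
// reasoningEffortFor restates the transfer rule above: copy the provider's
// default only when the model flags support AND the default is non-empty.
func reasoningEffortFor(hasEffort bool, defaultEffort string) (effort string, has bool) {
	if hasEffort && defaultEffort != "" {
		return defaultEffort, true
	}
	return "", false
}
```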
@@ -980,25 +988,13 @@ func (c *Config) validateProviders(errors *ValidationErrors) {
}
// Validate provider type
- validType := false
- for _, vt := range validTypes {
- if providerConfig.ProviderType == vt {
- validType = true
- break
- }
- }
+ validType := slices.Contains(validTypes, providerConfig.ProviderType)
if !validType {
errors.Add(fieldPrefix+".provider_type", fmt.Sprintf("invalid provider type: %s", providerConfig.ProviderType))
}
// Validate custom providers
- isKnownProvider := false
- for _, kp := range knownProviders {
- if providerID == kp {
- isKnownProvider = true
- break
- }
- }
+ isKnownProvider := slices.Contains(knownProviders, providerID)
if !isKnownProvider {
// Custom provider validation
@@ -1200,13 +1196,7 @@ func (c *Config) validateAgents(errors *ValidationErrors) {
// Validate allowed tools
if agent.AllowedTools != nil {
for i, tool := range agent.AllowedTools {
- validTool := false
- for _, vt := range validTools {
- if tool == vt {
- validTool = true
- break
- }
- }
+ validTool := slices.Contains(validTools, tool)
if !validTool {
errors.Add(fmt.Sprintf("%s.allowed_tools[%d]", fieldPrefix, i), fmt.Sprintf("unknown tool: %s", tool))
}
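
All three hand-rolled membership loops collapse to the standard library's `slices.Contains` (Go 1.21+); note that the `slices` import must be present in the file, though the import hunk is not shown here. A minimal, runnable illustration with made-up values:

```go
package main

import (
	"fmt"
	"slices"
)

func main() {
	validTools := []string{"bash", "edit", "fetch"} // hypothetical tool names
	fmt.Println(slices.Contains(validTools, "edit"))    // true
	fmt.Println(slices.Contains(validTools, "compile")) // false
}
```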
@@ -136,6 +136,34 @@ func MockProviders() []provider.Provider {
CanReason: false,
SupportsImages: true,
},
+ {
+ ID: "o1-preview",
+ Name: "o1-preview",
+ CostPer1MIn: 15.0,
+ CostPer1MOut: 60.0,
+ CostPer1MInCached: 0.0,
+ CostPer1MOutCached: 0.0,
+ ContextWindow: 128000,
+ DefaultMaxTokens: 32768,
+ CanReason: true,
+ HasReasoningEffort: true,
+ DefaultReasoningEffort: "medium",
+ SupportsImages: true,
+ },
+ {
+ ID: "o1-mini",
+ Name: "o1-mini",
+ CostPer1MIn: 3.0,
+ CostPer1MOut: 12.0,
+ CostPer1MInCached: 0.0,
+ CostPer1MOutCached: 0.0,
+ ContextWindow: 128000,
+ DefaultMaxTokens: 65536,
+ CanReason: true,
+ HasReasoningEffort: true,
+ DefaultReasoningEffort: "medium",
+ SupportsImages: true,
+ },
},
},
{
@@ -173,5 +201,57 @@ func MockProviders() []provider.Provider {
},
},
},
+ {
+ Name: "xAI",
+ ID: provider.InferenceProviderXAI,
+ APIKey: "$XAI_API_KEY",
+ APIEndpoint: "https://api.x.ai/v1",
+ Type: provider.TypeXAI,
+ DefaultLargeModelID: "grok-beta",
+ DefaultSmallModelID: "grok-beta",
+ Models: []provider.Model{
+ {
+ ID: "grok-beta",
+ Name: "Grok Beta",
+ CostPer1MIn: 5.0,
+ CostPer1MOut: 15.0,
+ ContextWindow: 131072,
+ DefaultMaxTokens: 4096,
+ CanReason: false,
+ SupportsImages: true,
+ },
+ },
+ },
+ {
+ Name: "OpenRouter",
+ ID: provider.InferenceProviderOpenRouter,
+ APIKey: "$OPENROUTER_API_KEY",
+ APIEndpoint: "https://openrouter.ai/api/v1",
+ Type: provider.TypeOpenAI,
+ DefaultLargeModelID: "anthropic/claude-3.5-sonnet",
+ DefaultSmallModelID: "anthropic/claude-3.5-haiku",
+ Models: []provider.Model{
+ {
+ ID: "anthropic/claude-3.5-sonnet",
+ Name: "Claude 3.5 Sonnet",
+ CostPer1MIn: 3.0,
+ CostPer1MOut: 15.0,
+ ContextWindow: 200000,
+ DefaultMaxTokens: 8192,
+ CanReason: false,
+ SupportsImages: true,
+ },
+ {
+ ID: "anthropic/claude-3.5-haiku",
+ Name: "Claude 3.5 Haiku",
+ CostPer1MIn: 0.8,
+ CostPer1MOut: 4.0,
+ ContextWindow: 200000,
+ DefaultMaxTokens: 8192,
+ CanReason: false,
+ SupportsImages: true,
+ },
+ },
+ },
}
}
@@ -1,6 +1,7 @@
package config
import (
+ "encoding/json"
"testing"
"github.com/charmbracelet/crush/internal/fur/provider"
@@ -103,3 +104,177 @@ func TestResetProviders(t *testing.T) {
// Should get the same mock data
assert.Equal(t, len(providers1), len(providers2))
}
+
+func TestReasoningEffortSupport(t *testing.T) {
+ originalUseMock := UseMockProviders
+ UseMockProviders = true
+ defer func() {
+ UseMockProviders = originalUseMock
+ ResetProviders()
+ }()
+
+ ResetProviders()
+ providers := Providers()
+
+ var openaiProvider provider.Provider
+ for _, p := range providers {
+ if p.ID == provider.InferenceProviderOpenAI {
+ openaiProvider = p
+ break
+ }
+ }
+ require.NotEmpty(t, openaiProvider.ID)
+
+ var reasoningModel, nonReasoningModel provider.Model
+ for _, model := range openaiProvider.Models {
+ if model.CanReason && model.HasReasoningEffort {
+ reasoningModel = model
+ } else if !model.CanReason {
+ nonReasoningModel = model
+ }
+ }
+
+ require.NotEmpty(t, reasoningModel.ID)
+ assert.Equal(t, "medium", reasoningModel.DefaultReasoningEffort)
+ assert.True(t, reasoningModel.HasReasoningEffort)
+
+ require.NotEmpty(t, nonReasoningModel.ID)
+ assert.False(t, nonReasoningModel.HasReasoningEffort)
+ assert.Empty(t, nonReasoningModel.DefaultReasoningEffort)
+}
+
+func TestReasoningEffortConfigTransfer(t *testing.T) {
+ originalUseMock := UseMockProviders
+ UseMockProviders = true
+ defer func() {
+ UseMockProviders = originalUseMock
+ ResetProviders()
+ }()
+
+ ResetProviders()
+ t.Setenv("OPENAI_API_KEY", "test-openai-key")
+
+ cfg, err := Init(t.TempDir(), false)
+ require.NoError(t, err)
+
+ openaiProviderConfig, exists := cfg.Providers[provider.InferenceProviderOpenAI]
+ require.True(t, exists)
+
+ var foundReasoning, foundNonReasoning bool
+ for _, model := range openaiProviderConfig.Models {
+ if model.CanReason && model.HasReasoningEffort && model.ReasoningEffort != "" {
+ assert.Equal(t, "medium", model.ReasoningEffort)
+ assert.True(t, model.HasReasoningEffort)
+ foundReasoning = true
+ } else if !model.CanReason {
+ assert.Empty(t, model.ReasoningEffort)
+ assert.False(t, model.HasReasoningEffort)
+ foundNonReasoning = true
+ }
+ }
+
+ assert.True(t, foundReasoning, "Should find at least one reasoning model")
+ assert.True(t, foundNonReasoning, "Should find at least one non-reasoning model")
+}
+
+func TestNewProviders(t *testing.T) {
+ originalUseMock := UseMockProviders
+ UseMockProviders = true
+ defer func() {
+ UseMockProviders = originalUseMock
+ ResetProviders()
+ }()
+
+ ResetProviders()
+ providers := Providers()
+ require.NotEmpty(t, providers)
+
+ var xaiProvider, openRouterProvider provider.Provider
+ for _, p := range providers {
+ switch p.ID {
+ case provider.InferenceProviderXAI:
+ xaiProvider = p
+ case provider.InferenceProviderOpenRouter:
+ openRouterProvider = p
+ }
+ }
+
+ require.NotEmpty(t, xaiProvider.ID)
+ assert.Equal(t, "xAI", xaiProvider.Name)
+ assert.Equal(t, "grok-beta", xaiProvider.DefaultLargeModelID)
+
+ require.NotEmpty(t, openRouterProvider.ID)
+ assert.Equal(t, "OpenRouter", openRouterProvider.Name)
+ assert.Equal(t, "anthropic/claude-3.5-sonnet", openRouterProvider.DefaultLargeModelID)
+}
+
+func TestO1ModelsInMockProvider(t *testing.T) {
+ originalUseMock := UseMockProviders
+ UseMockProviders = true
+ defer func() {
+ UseMockProviders = originalUseMock
+ ResetProviders()
+ }()
+
+ ResetProviders()
+ providers := Providers()
+
+ var openaiProvider provider.Provider
+ for _, p := range providers {
+ if p.ID == provider.InferenceProviderOpenAI {
+ openaiProvider = p
+ break
+ }
+ }
+ require.NotEmpty(t, openaiProvider.ID)
+
+ modelTests := []struct {
+ id string
+ name string
+ }{
+ {"o1-preview", "o1-preview"},
+ {"o1-mini", "o1-mini"},
+ }
+
+ for _, test := range modelTests {
+ var model provider.Model
+ var found bool
+ for _, m := range openaiProvider.Models {
+ if m.ID == test.id {
+ model = m
+ found = true
+ break
+ }
+ }
+ require.True(t, found, "Should find %s model", test.id)
+ assert.Equal(t, test.name, model.Name)
+ assert.True(t, model.CanReason)
+ assert.True(t, model.HasReasoningEffort)
+ assert.Equal(t, "medium", model.DefaultReasoningEffort)
+ }
+}
+
+func TestPreferredModelReasoningEffort(t *testing.T) {
+ // Test that PreferredModel struct can hold reasoning effort
+ preferredModel := PreferredModel{
+ ModelID: "o1-preview",
+ Provider: provider.InferenceProviderOpenAI,
+ ReasoningEffort: "high",
+ }
+
+ assert.Equal(t, "o1-preview", preferredModel.ModelID)
+ assert.Equal(t, provider.InferenceProviderOpenAI, preferredModel.Provider)
+ assert.Equal(t, "high", preferredModel.ReasoningEffort)
+
+ // Test JSON marshaling/unmarshaling
+ jsonData, err := json.Marshal(preferredModel)
+ require.NoError(t, err)
+
+ var unmarshaled PreferredModel
+ err = json.Unmarshal(jsonData, &unmarshaled)
+ require.NoError(t, err)
+
+ assert.Equal(t, preferredModel.ModelID, unmarshaled.ModelID)
+ assert.Equal(t, preferredModel.Provider, unmarshaled.Provider)
+ assert.Equal(t, preferredModel.ReasoningEffort, unmarshaled.ReasoningEffort)
+}
@@ -407,27 +407,6 @@ func (g *geminiClient) shouldRetry(attempts int, err error) (bool, int64, error)
return true, int64(retryMs), nil
}
-func (g *geminiClient) toolCalls(resp *genai.GenerateContentResponse) []message.ToolCall {
- var toolCalls []message.ToolCall
-
- if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
- for _, part := range resp.Candidates[0].Content.Parts {
- if part.FunctionCall != nil {
- id := "call_" + uuid.New().String()
- args, _ := json.Marshal(part.FunctionCall.Args)
- toolCalls = append(toolCalls, message.ToolCall{
- ID: id,
- Name: part.FunctionCall.Name,
- Input: string(args),
- Type: "function",
- })
- }
- }
- }
-
- return toolCalls
-}
-
func (g *geminiClient) usage(resp *genai.GenerateContentResponse) TokenUsage {
if resp == nil || resp.UsageMetadata == nil {
return TokenUsage{}
@@ -18,23 +18,14 @@ import (
"github.com/openai/openai-go/shared"
)
-type openaiOptions struct {
- reasoningEffort string
-}
-
type openaiClient struct {
providerOptions providerClientOptions
- options openaiOptions
client openai.Client
}
type OpenAIClient ProviderClient
func newOpenAIClient(opts providerClientOptions) OpenAIClient {
- openaiOpts := openaiOptions{
- reasoningEffort: "medium",
- }
-
openaiClientOptions := []option.RequestOption{}
if opts.apiKey != "" {
openaiClientOptions = append(openaiClientOptions, option.WithAPIKey(opts.apiKey))
@@ -52,7 +43,6 @@ func newOpenAIClient(opts providerClientOptions) OpenAIClient {
client := openai.NewClient(openaiClientOptions...)
return &openaiClient{
providerOptions: opts,
- options: openaiOpts,
client: client,
}
}
@@ -153,6 +143,18 @@ func (o *openaiClient) finishReason(reason string) message.FinishReason {
func (o *openaiClient) preparedParams(messages []openai.ChatCompletionMessageParamUnion, tools []openai.ChatCompletionToolParam) openai.ChatCompletionNewParams {
model := o.providerOptions.model(o.providerOptions.modelType)
+ cfg := config.Get()
+
+ modelConfig := cfg.Models.Large
+ if o.providerOptions.modelType == config.SmallModel {
+ modelConfig = cfg.Models.Small
+ }
+
+ reasoningEffort := model.ReasoningEffort
+ if modelConfig.ReasoningEffort != "" {
+ reasoningEffort = modelConfig.ReasoningEffort
+ }
+
params := openai.ChatCompletionNewParams{
Model: openai.ChatModel(model.ID),
Messages: messages,
@@ -160,7 +162,7 @@ func (o *openaiClient) preparedParams(messages []openai.ChatCompletionMessagePar
}
if model.CanReason {
params.MaxCompletionTokens = openai.Int(o.providerOptions.maxTokens)
- switch o.options.reasoningEffort {
+ switch reasoningEffort {
case "low":
params.ReasoningEffort = shared.ReasoningEffortLow
case "medium":