diff --git a/internal/config/config.go b/internal/config/config.go index e33aab02a492e8a1a4c55554fe5a3656d101ec1e..f6814db44cdadefd0e88e57e2bcd2521bf8a3c28 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -65,6 +65,7 @@ type Model struct { DefaultMaxTokens int64 `json:"default_max_tokens"` CanReason bool `json:"can_reason"` ReasoningEffort string `json:"reasoning_effort"` + HasReasoningEffort bool `json:"has_reasoning_effort"` SupportsImages bool `json:"supports_attachments"` } @@ -156,8 +157,9 @@ type Options struct { } type PreferredModel struct { - ModelID string `json:"model_id"` - Provider provider.InferenceProvider `json:"provider"` + ModelID string `json:"model_id"` + Provider provider.InferenceProvider `json:"provider"` + ReasoningEffort string `json:"reasoning_effort,omitempty"` } type PreferredModels struct { @@ -693,7 +695,7 @@ func defaultConfigBasedOnEnv() *Config { } providerConfig.BaseURL = baseURL for _, model := range p.Models { - providerConfig.Models = append(providerConfig.Models, Model{ + configModel := Model{ ID: model.ID, Name: model.Name, CostPer1MIn: model.CostPer1MIn, @@ -704,7 +706,13 @@ func defaultConfigBasedOnEnv() *Config { DefaultMaxTokens: model.DefaultMaxTokens, CanReason: model.CanReason, SupportsImages: model.SupportsImages, - }) + } + // Set reasoning effort for reasoning models + if model.HasReasoningEffort && model.DefaultReasoningEffort != "" { + configModel.HasReasoningEffort = model.HasReasoningEffort + configModel.ReasoningEffort = model.DefaultReasoningEffort + } + providerConfig.Models = append(providerConfig.Models, configModel) } cfg.Providers[p.ID] = providerConfig } @@ -980,25 +988,13 @@ func (c *Config) validateProviders(errors *ValidationErrors) { } // Validate provider type - validType := false - for _, vt := range validTypes { - if providerConfig.ProviderType == vt { - validType = true - break - } - } + validType := slices.Contains(validTypes, providerConfig.ProviderType) if !validType { errors.Add(fieldPrefix+".provider_type", fmt.Sprintf("invalid provider type: %s", providerConfig.ProviderType)) } // Validate custom providers - isKnownProvider := false - for _, kp := range knownProviders { - if providerID == kp { - isKnownProvider = true - break - } - } + isKnownProvider := slices.Contains(knownProviders, providerID) if !isKnownProvider { // Custom provider validation @@ -1200,13 +1196,7 @@ func (c *Config) validateAgents(errors *ValidationErrors) { // Validate allowed tools if agent.AllowedTools != nil { for i, tool := range agent.AllowedTools { - validTool := false - for _, vt := range validTools { - if tool == vt { - validTool = true - break - } - } + validTool := slices.Contains(validTools, tool) if !validTool { errors.Add(fmt.Sprintf("%s.allowed_tools[%d]", fieldPrefix, i), fmt.Sprintf("unknown tool: %s", tool)) } diff --git a/internal/config/provider_mock.go b/internal/config/provider_mock.go index 86b87768b95246654e176ca5f40af5aef249c23f..af92cc2c33f0b0adbe65dbd728b29727c35aeaa8 100644 --- a/internal/config/provider_mock.go +++ b/internal/config/provider_mock.go @@ -136,6 +136,34 @@ func MockProviders() []provider.Provider { CanReason: false, SupportsImages: true, }, + { + ID: "o1-preview", + Name: "o1-preview", + CostPer1MIn: 15.0, + CostPer1MOut: 60.0, + CostPer1MInCached: 0.0, + CostPer1MOutCached: 0.0, + ContextWindow: 128000, + DefaultMaxTokens: 32768, + CanReason: true, + HasReasoningEffort: true, + DefaultReasoningEffort: "medium", + SupportsImages: true, + }, + { + ID: "o1-mini", + Name: "o1-mini", + CostPer1MIn: 3.0, + CostPer1MOut: 12.0, + CostPer1MInCached: 0.0, + CostPer1MOutCached: 0.0, + ContextWindow: 128000, + DefaultMaxTokens: 65536, + CanReason: true, + HasReasoningEffort: true, + DefaultReasoningEffort: "medium", + SupportsImages: true, + }, }, }, { @@ -173,5 +201,57 @@ func MockProviders() []provider.Provider { }, }, }, + { + Name: "xAI", + ID: provider.InferenceProviderXAI, + APIKey: "$XAI_API_KEY", + APIEndpoint: "https://api.x.ai/v1", + Type: provider.TypeXAI, + DefaultLargeModelID: "grok-beta", + DefaultSmallModelID: "grok-beta", + Models: []provider.Model{ + { + ID: "grok-beta", + Name: "Grok Beta", + CostPer1MIn: 5.0, + CostPer1MOut: 15.0, + ContextWindow: 131072, + DefaultMaxTokens: 4096, + CanReason: false, + SupportsImages: true, + }, + }, + }, + { + Name: "OpenRouter", + ID: provider.InferenceProviderOpenRouter, + APIKey: "$OPENROUTER_API_KEY", + APIEndpoint: "https://openrouter.ai/api/v1", + Type: provider.TypeOpenAI, + DefaultLargeModelID: "anthropic/claude-3.5-sonnet", + DefaultSmallModelID: "anthropic/claude-3.5-haiku", + Models: []provider.Model{ + { + ID: "anthropic/claude-3.5-sonnet", + Name: "Claude 3.5 Sonnet", + CostPer1MIn: 3.0, + CostPer1MOut: 15.0, + ContextWindow: 200000, + DefaultMaxTokens: 8192, + CanReason: false, + SupportsImages: true, + }, + { + ID: "anthropic/claude-3.5-haiku", + Name: "Claude 3.5 Haiku", + CostPer1MIn: 0.8, + CostPer1MOut: 4.0, + ContextWindow: 200000, + DefaultMaxTokens: 8192, + CanReason: false, + SupportsImages: true, + }, + }, + }, } } diff --git a/internal/config/provider_test.go b/internal/config/provider_test.go index 92547ff2925699d8519c33656395d3979a095b35..b175107d0df2bfaabc29e88550dc89471baf5188 100644 --- a/internal/config/provider_test.go +++ b/internal/config/provider_test.go @@ -1,6 +1,7 @@ package config import ( + "encoding/json" "testing" "github.com/charmbracelet/crush/internal/fur/provider" @@ -103,3 +104,177 @@ func TestResetProviders(t *testing.T) { // Should get the same mock data assert.Equal(t, len(providers1), len(providers2)) } + +func TestReasoningEffortSupport(t *testing.T) { + originalUseMock := UseMockProviders + UseMockProviders = true + defer func() { + UseMockProviders = originalUseMock + ResetProviders() + }() + + ResetProviders() + providers := Providers() + + var openaiProvider provider.Provider + for _, p := range providers { + if p.ID == provider.InferenceProviderOpenAI { + openaiProvider = p + break + } + } + require.NotEmpty(t, openaiProvider.ID) + + var reasoningModel, nonReasoningModel provider.Model + for _, model := range openaiProvider.Models { + if model.CanReason && model.HasReasoningEffort { + reasoningModel = model + } else if !model.CanReason { + nonReasoningModel = model + } + } + + require.NotEmpty(t, reasoningModel.ID) + assert.Equal(t, "medium", reasoningModel.DefaultReasoningEffort) + assert.True(t, reasoningModel.HasReasoningEffort) + + require.NotEmpty(t, nonReasoningModel.ID) + assert.False(t, nonReasoningModel.HasReasoningEffort) + assert.Empty(t, nonReasoningModel.DefaultReasoningEffort) +} + +func TestReasoningEffortConfigTransfer(t *testing.T) { + originalUseMock := UseMockProviders + UseMockProviders = true + defer func() { + UseMockProviders = originalUseMock + ResetProviders() + }() + + ResetProviders() + t.Setenv("OPENAI_API_KEY", "test-openai-key") + + cfg, err := Init(t.TempDir(), false) + require.NoError(t, err) + + openaiProviderConfig, exists := cfg.Providers[provider.InferenceProviderOpenAI] + require.True(t, exists) + + var foundReasoning, foundNonReasoning bool + for _, model := range openaiProviderConfig.Models { + if model.CanReason && model.HasReasoningEffort && model.ReasoningEffort != "" { + assert.Equal(t, "medium", model.ReasoningEffort) + assert.True(t, model.HasReasoningEffort) + foundReasoning = true + } else if !model.CanReason { + assert.Empty(t, model.ReasoningEffort) + assert.False(t, model.HasReasoningEffort) + foundNonReasoning = true + } + } + + assert.True(t, foundReasoning, "Should find at least one reasoning model") + assert.True(t, foundNonReasoning, "Should find at least one non-reasoning model") +} + +func TestNewProviders(t *testing.T) { + originalUseMock := UseMockProviders + UseMockProviders = true + defer func() { + UseMockProviders = originalUseMock + ResetProviders() + }() + + ResetProviders() + providers := Providers() + require.NotEmpty(t, providers) + + var xaiProvider, openRouterProvider provider.Provider + for _, p := range providers { + switch p.ID { + case provider.InferenceProviderXAI: + xaiProvider = p + case provider.InferenceProviderOpenRouter: + openRouterProvider = p + } + } + + require.NotEmpty(t, xaiProvider.ID) + assert.Equal(t, "xAI", xaiProvider.Name) + assert.Equal(t, "grok-beta", xaiProvider.DefaultLargeModelID) + + require.NotEmpty(t, openRouterProvider.ID) + assert.Equal(t, "OpenRouter", openRouterProvider.Name) + assert.Equal(t, "anthropic/claude-3.5-sonnet", openRouterProvider.DefaultLargeModelID) +} + +func TestO1ModelsInMockProvider(t *testing.T) { + originalUseMock := UseMockProviders + UseMockProviders = true + defer func() { + UseMockProviders = originalUseMock + ResetProviders() + }() + + ResetProviders() + providers := Providers() + + var openaiProvider provider.Provider + for _, p := range providers { + if p.ID == provider.InferenceProviderOpenAI { + openaiProvider = p + break + } + } + require.NotEmpty(t, openaiProvider.ID) + + modelTests := []struct { + id string + name string + }{ + {"o1-preview", "o1-preview"}, + {"o1-mini", "o1-mini"}, + } + + for _, test := range modelTests { + var model provider.Model + var found bool + for _, m := range openaiProvider.Models { + if m.ID == test.id { + model = m + found = true + break + } + } + require.True(t, found, "Should find %s model", test.id) + assert.Equal(t, test.name, model.Name) + assert.True(t, model.CanReason) + assert.True(t, model.HasReasoningEffort) + assert.Equal(t, "medium", model.DefaultReasoningEffort) + } +} + +func TestPreferredModelReasoningEffort(t *testing.T) { + // Test that PreferredModel struct can hold reasoning effort + preferredModel := PreferredModel{ + ModelID: "o1-preview", + Provider: provider.InferenceProviderOpenAI, + ReasoningEffort: "high", + } + + assert.Equal(t, "o1-preview", preferredModel.ModelID) + assert.Equal(t, provider.InferenceProviderOpenAI, preferredModel.Provider) + assert.Equal(t, "high", preferredModel.ReasoningEffort) + + // Test JSON marshaling/unmarshaling + jsonData, err := json.Marshal(preferredModel) + require.NoError(t, err) + + var unmarshaled PreferredModel + err = json.Unmarshal(jsonData, &unmarshaled) + require.NoError(t, err) + + assert.Equal(t, preferredModel.ModelID, unmarshaled.ModelID) + assert.Equal(t, preferredModel.Provider, unmarshaled.Provider) + assert.Equal(t, preferredModel.ReasoningEffort, unmarshaled.ReasoningEffort) +} diff --git a/internal/llm/provider/gemini.go b/internal/llm/provider/gemini.go index a5c012861ad9e6b537c0e9bca8e957ef3f38bf2f..3531c5cc89cced262d5c22f2598216d9cfe4883e 100644 --- a/internal/llm/provider/gemini.go +++ b/internal/llm/provider/gemini.go @@ -407,27 +407,6 @@ func (g *geminiClient) shouldRetry(attempts int, err error) (bool, int64, error) return true, int64(retryMs), nil } -func (g *geminiClient) toolCalls(resp *genai.GenerateContentResponse) []message.ToolCall { - var toolCalls []message.ToolCall - - if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil { - for _, part := range resp.Candidates[0].Content.Parts { - if part.FunctionCall != nil { - id := "call_" + uuid.New().String() - args, _ := json.Marshal(part.FunctionCall.Args) - toolCalls = append(toolCalls, message.ToolCall{ - ID: id, - Name: part.FunctionCall.Name, - Input: string(args), - Type: "function", - }) - } - } - } - - return toolCalls -} - func (g *geminiClient) usage(resp *genai.GenerateContentResponse) TokenUsage { if resp == nil || resp.UsageMetadata == nil { return TokenUsage{} diff --git a/internal/llm/provider/openai.go b/internal/llm/provider/openai.go index 9af060a80f75309e1e314e3c33df72e607c9c77a..f6aaacb0ce09ae665fab4bdbb14b28e13e7684c7 100644 --- a/internal/llm/provider/openai.go +++ b/internal/llm/provider/openai.go @@ -18,23 +18,14 @@ import ( "github.com/openai/openai-go/shared" ) -type openaiOptions struct { - reasoningEffort string -} - type openaiClient struct { providerOptions providerClientOptions - options openaiOptions client openai.Client } type OpenAIClient ProviderClient func newOpenAIClient(opts providerClientOptions) OpenAIClient { - openaiOpts := openaiOptions{ - reasoningEffort: "medium", - } - openaiClientOptions := []option.RequestOption{} if opts.apiKey != "" { openaiClientOptions = append(openaiClientOptions, option.WithAPIKey(opts.apiKey)) @@ -52,7 +43,6 @@ func newOpenAIClient(opts providerClientOptions) OpenAIClient { client := openai.NewClient(openaiClientOptions...) return &openaiClient{ providerOptions: opts, - options: openaiOpts, client: client, } } @@ -153,6 +143,18 @@ func (o *openaiClient) finishReason(reason string) message.FinishReason { func (o *openaiClient) preparedParams(messages []openai.ChatCompletionMessageParamUnion, tools []openai.ChatCompletionToolParam) openai.ChatCompletionNewParams { model := o.providerOptions.model(o.providerOptions.modelType) + cfg := config.Get() + + modelConfig := cfg.Models.Large + if o.providerOptions.modelType == config.SmallModel { + modelConfig = cfg.Models.Small + } + + reasoningEffort := model.ReasoningEffort + if modelConfig.ReasoningEffort != "" { + reasoningEffort = modelConfig.ReasoningEffort + } + params := openai.ChatCompletionNewParams{ Model: openai.ChatModel(model.ID), Messages: messages, @@ -160,7 +162,7 @@ func (o *openaiClient) preparedParams(messages []openai.ChatCompletionMessagePar } if model.CanReason { params.MaxCompletionTokens = openai.Int(o.providerOptions.maxTokens) - switch o.options.reasoningEffort { + switch reasoningEffort { case "low": params.ReasoningEffort = shared.ReasoningEffortLow case "medium":