From a92be192e1bc42dd6040fc770e2f1ae0872cff20 Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Thu, 26 Jun 2025 14:11:17 +0200 Subject: [PATCH] chore: move fur structs, small provider changes --- go.mod | 3 - internal/config_v2/config.go | 326 ++++++++++++++++++++++++----- internal/config_v2/config_test.go | 4 +- internal/config_v2/provider.go | 4 +- internal/fur/client/client.go | 63 ++++++ internal/fur/provider/provider.go | 72 +++++++ internal/llm/agent/agent.go | 24 +-- internal/llm/provider/anthropic.go | 73 ++----- internal/llm/provider/bedrock.go | 19 +- internal/llm/provider/gemini.go | 19 -- internal/llm/provider/openai.go | 47 +---- internal/llm/provider/provider.go | 72 +++---- internal/llm/provider/vertexai.go | 6 - 13 files changed, 477 insertions(+), 255 deletions(-) create mode 100644 internal/fur/client/client.go create mode 100644 internal/fur/provider/provider.go diff --git a/go.mod b/go.mod index 8bc77cf3ffe7cdd96131027fe09f5b8f1a50796a..99e808c14a24cd34d5274c74eba183229f51dd07 100644 --- a/go.mod +++ b/go.mod @@ -2,8 +2,6 @@ module github.com/charmbracelet/crush go 1.24.3 -replace github.com/charmbracelet/fur => ../fur - require ( github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.7.0 github.com/JohannesKaufmann/html-to-markdown v1.6.0 @@ -17,7 +15,6 @@ require ( github.com/charmbracelet/bubbles/v2 v2.0.0-beta.1.0.20250607113720-eb5e1cf3b09e github.com/charmbracelet/bubbletea/v2 v2.0.0-beta.3.0.20250609143341-c76fa36f1b94 github.com/charmbracelet/fang v0.1.0 - github.com/charmbracelet/fur v0.0.0-00010101000000-000000000000 github.com/charmbracelet/glamour/v2 v2.0.0-20250516160903-6f1e2c8f9ebe github.com/charmbracelet/lipgloss/v2 v2.0.0-beta.1.0.20250523195325-2d1af06b557c github.com/charmbracelet/x/ansi v0.9.3-0.20250602153603-fb931ed90413 diff --git a/internal/config_v2/config.go b/internal/config_v2/config.go index 4ab12a83fe6de3e94105cf00d4045f652dd26cae..0de27aa1b16cf71e97c655561a5c8ce20bae0838 100644 --- a/internal/config_v2/config.go +++ b/internal/config_v2/config.go @@ -10,8 +10,8 @@ import ( "strings" "sync" + "github.com/charmbracelet/crush/internal/fur/provider" "github.com/charmbracelet/crush/internal/logging" - "github.com/charmbracelet/fur/pkg/provider" ) const ( @@ -22,6 +22,29 @@ const ( MaxTokensFallbackDefault = 4096 ) +var defaultContextPaths = []string{ + ".github/copilot-instructions.md", + ".cursorrules", + ".cursor/rules/", + "CLAUDE.md", + "CLAUDE.local.md", + "crush.md", + "crush.local.md", + "Crush.md", + "Crush.local.md", + "CRUSH.md", + "CRUSH.local.md", +} + +type AgentID string + +const ( + AgentCoder AgentID = "coder" + AgentTask AgentID = "task" + AgentTitle AgentID = "title" + AgentSummarize AgentID = "summarize" +) + type Model struct { ID string `json:"id"` Name string `json:"model"` @@ -43,40 +66,43 @@ type VertexAIOptions struct { } type ProviderConfig struct { - BaseURL string `json:"base_url,omitempty"` - ProviderType provider.Type `json:"provider_type"` - APIKey string `json:"api_key,omitempty"` - Disabled bool `json:"disabled"` - ExtraHeaders map[string]string `json:"extra_headers,omitempty"` + ID provider.InferenceProvider `json:"id"` + BaseURL string `json:"base_url,omitempty"` + ProviderType provider.Type `json:"provider_type"` + APIKey string `json:"api_key,omitempty"` + Disabled bool `json:"disabled"` + ExtraHeaders map[string]string `json:"extra_headers,omitempty"` // used for e.x for vertex to set the project ExtraParams map[string]string `json:"extra_params,omitempty"` - DefaultModel string `json:"default_model"` + DefaultLargeModel string `json:"default_large_model,omitempty"` + DefaultSmallModel string `json:"default_small_model,omitempty"` + + Models []Model `json:"models,omitempty"` } type Agent struct { - Name string `json:"name"` + Name string `json:"name"` + Description string `json:"description,omitempty"` // This is the id of the system prompt used by the agent - // TODO: still needs to be implemented - PromptID string `json:"prompt_id"` - Disabled bool `json:"disabled"` + Disabled bool `json:"disabled"` Provider provider.InferenceProvider `json:"provider"` - Model Model `json:"model"` + Model string `json:"model"` // The available tools for the agent - // if this is empty, all tools are available + // if this is nil, all tools are available AllowedTools []string `json:"allowed_tools"` // this tells us which MCPs are available for this agent // if this is empty all mcps are available - // the string array is the list of tools from the MCP the agent has available - // if the string array is empty, all tools from the MCP are available - MCP map[string][]string `json:"mcp"` + // the string array is the list of tools from the AllowedMCP the agent has available + // if the string array is nil, all tools from the AllowedMCP are available + AllowedMCP map[string][]string `json:"allowed_mcp"` // The list of LSPs that this agent can use - // if this is empty, all LSPs are available - LSP []string `json:"lsp"` + // if this is nil, all LSPs are available + AllowedLSP []string `json:"allowed_lsp"` // Overrides the context paths for this agent ContextPaths []string `json:"context_paths"` @@ -125,7 +151,7 @@ type Config struct { Providers map[provider.InferenceProvider]ProviderConfig `json:"providers,omitempty"` // List of configured agents - Agents map[string]Agent `json:"agents,omitempty"` + Agents map[AgentID]Agent `json:"agents,omitempty"` // List of configured MCPs MCP map[string]MCP `json:"mcp,omitempty"` @@ -135,15 +161,13 @@ type Config struct { // Miscellaneous options Options Options `json:"options"` - - // Used to add models that are not already in the repository - Models map[provider.InferenceProvider][]provider.Model `json:"models,omitempty"` } var ( instance *Config // The single instance of the Singleton cwd string once sync.Once // Ensures the initialization happens only once + ) func loadConfig(cwd string) (*Config, error) { @@ -190,10 +214,73 @@ func loadConfig(cwd string) (*Config, error) { } // merge options - cfg.Options = mergeOptions(cfg.Options, globalCfg.Options) - cfg.Options = mergeOptions(cfg.Options, localConfig.Options) + mergeOptions(cfg, globalCfg, localConfig) mergeProviderConfigs(cfg, globalCfg, localConfig) + // no providers found the app is not initialized yet + if len(cfg.Providers) == 0 { + return cfg, nil + } + preferredProvider := getPreferredProvider(cfg.Providers) + + if preferredProvider == nil { + return nil, errors.New("no valid providers configured") + } + + agents := map[AgentID]Agent{ + AgentCoder: { + Name: "Coder", + Description: "An agent that helps with executing coding tasks.", + Provider: preferredProvider.ID, + Model: preferredProvider.DefaultLargeModel, + ContextPaths: cfg.Options.ContextPaths, + // All tools allowed + }, + AgentTask: { + Name: "Task", + Description: "An agent that helps with searching for context and finding implementation details.", + Provider: preferredProvider.ID, + Model: preferredProvider.DefaultLargeModel, + ContextPaths: cfg.Options.ContextPaths, + AllowedTools: []string{ + "glob", + "grep", + "ls", + "sourcegraph", + "view", + }, + // NO MCPs or LSPs by default + AllowedMCP: map[string][]string{}, + AllowedLSP: []string{}, + }, + AgentTitle: { + Name: "Title", + Description: "An agent that helps with generating titles for sessions.", + Provider: preferredProvider.ID, + Model: preferredProvider.DefaultSmallModel, + ContextPaths: cfg.Options.ContextPaths, + AllowedTools: []string{}, + // NO MCPs or LSPs by default + AllowedMCP: map[string][]string{}, + AllowedLSP: []string{}, + }, + AgentSummarize: { + Name: "Summarize", + Description: "An agent that helps with summarizing sessions.", + Provider: preferredProvider.ID, + Model: preferredProvider.DefaultSmallModel, + ContextPaths: cfg.Options.ContextPaths, + AllowedTools: []string{}, + // NO MCPs or LSPs by default + AllowedMCP: map[string][]string{}, + AllowedLSP: []string{}, + }, + } + cfg.Agents = agents + mergeAgents(cfg, globalCfg, localConfig) + mergeMCPs(cfg, globalCfg, localConfig) + mergeLSPs(cfg, globalCfg, localConfig) + return cfg, nil } @@ -219,6 +306,22 @@ func GetConfig() *Config { return instance } +func getPreferredProvider(configuredProviders map[provider.InferenceProvider]ProviderConfig) *ProviderConfig { + providers := Providers() + for _, p := range providers { + if providerConfig, ok := configuredProviders[p.ID]; ok && !providerConfig.Disabled { + return &providerConfig + } + } + // if none found return the first configured provider + for _, providerConfig := range configuredProviders { + if !providerConfig.Disabled { + return &providerConfig + } + } + return nil +} + func mergeProviderConfig(p provider.InferenceProvider, base, other ProviderConfig) ProviderConfig { if other.APIKey != "" { base.APIKey = other.APIKey @@ -249,6 +352,26 @@ func mergeProviderConfig(p provider.InferenceProvider, base, other ProviderConfi base.Disabled = other.Disabled } + if other.DefaultLargeModel != "" { + base.DefaultLargeModel = other.DefaultLargeModel + } + // Add new models if they don't exist + if other.Models != nil { + for _, model := range other.Models { + // check if the model already exists + exists := false + for _, existingModel := range base.Models { + if existingModel.ID == model.ID { + exists = true + break + } + } + if !exists { + base.Models = append(base.Models, model) + } + } + } + return base } @@ -267,39 +390,114 @@ func validateProvider(p provider.InferenceProvider, providerConfig ProviderConfi return nil } -func mergeOptions(base, other Options) Options { - result := base +func mergeOptions(base, global, local *Config) { + for _, cfg := range []*Config{global, local} { + if cfg == nil { + continue + } + baseOptions := base.Options + other := cfg.Options + if len(other.ContextPaths) > 0 { + baseOptions.ContextPaths = append(baseOptions.ContextPaths, other.ContextPaths...) + } - if len(other.ContextPaths) > 0 { - base.ContextPaths = append(base.ContextPaths, other.ContextPaths...) - } + if other.TUI.CompactMode { + baseOptions.TUI.CompactMode = other.TUI.CompactMode + } - if other.TUI.CompactMode { - result.TUI.CompactMode = other.TUI.CompactMode - } + if other.Debug { + baseOptions.Debug = other.Debug + } - if other.Debug { - result.Debug = other.Debug - } + if other.DebugLSP { + baseOptions.DebugLSP = other.DebugLSP + } - if other.DebugLSP { - result.DebugLSP = other.DebugLSP + if other.DisableAutoSummarize { + baseOptions.DisableAutoSummarize = other.DisableAutoSummarize + } + + if other.DataDirectory != "" { + baseOptions.DataDirectory = other.DataDirectory + } + base.Options = baseOptions } +} - if other.DisableAutoSummarize { - result.DisableAutoSummarize = other.DisableAutoSummarize +func mergeAgents(base, global, local *Config) { + for _, cfg := range []*Config{global, local} { + if cfg == nil { + continue + } + for agentID, globalAgent := range cfg.Agents { + if _, ok := base.Agents[agentID]; !ok { + base.Agents[agentID] = globalAgent + } else { + switch agentID { + case AgentCoder: + baseAgent := base.Agents[agentID] + baseAgent.Model = globalAgent.Model + baseAgent.Provider = globalAgent.Provider + baseAgent.AllowedMCP = globalAgent.AllowedMCP + baseAgent.AllowedLSP = globalAgent.AllowedLSP + base.Agents[agentID] = baseAgent + case AgentTask: + baseAgent := base.Agents[agentID] + baseAgent.Model = globalAgent.Model + baseAgent.Provider = globalAgent.Provider + base.Agents[agentID] = baseAgent + case AgentTitle: + baseAgent := base.Agents[agentID] + baseAgent.Model = globalAgent.Model + baseAgent.Provider = globalAgent.Provider + base.Agents[agentID] = baseAgent + case AgentSummarize: + baseAgent := base.Agents[agentID] + baseAgent.Model = globalAgent.Model + baseAgent.Provider = globalAgent.Provider + base.Agents[agentID] = baseAgent + default: + baseAgent := base.Agents[agentID] + baseAgent.Name = globalAgent.Name + baseAgent.Description = globalAgent.Description + baseAgent.Disabled = globalAgent.Disabled + baseAgent.Provider = globalAgent.Provider + baseAgent.Model = globalAgent.Model + baseAgent.AllowedTools = globalAgent.AllowedTools + baseAgent.AllowedMCP = globalAgent.AllowedMCP + baseAgent.AllowedLSP = globalAgent.AllowedLSP + base.Agents[agentID] = baseAgent + + } + } + } } +} - if other.DataDirectory != "" { - result.DataDirectory = other.DataDirectory +func mergeMCPs(base, global, local *Config) { + for _, cfg := range []*Config{global, local} { + if cfg == nil { + continue + } + maps.Copy(base.MCP, cfg.MCP) } +} - return result +func mergeLSPs(base, global, local *Config) { + for _, cfg := range []*Config{global, local} { + if cfg == nil { + continue + } + maps.Copy(base.LSP, cfg.LSP) + } } func mergeProviderConfigs(base, global, local *Config) { - if global != nil { - for providerName, globalProvider := range global.Providers { + for _, cfg := range []*Config{global, local} { + if cfg == nil { + continue + } + for providerName, globalProvider := range cfg.Providers { if _, ok := base.Providers[providerName]; !ok { base.Providers[providerName] = globalProvider } else { @@ -307,15 +505,6 @@ func mergeProviderConfigs(base, global, local *Config) { } } } - if local != nil { - for providerName, localProvider := range local.Providers { - if _, ok := base.Providers[providerName]; !ok { - base.Providers[providerName] = localProvider - } else { - base.Providers[providerName] = mergeProviderConfig(providerName, base.Providers[providerName], localProvider) - } - } - } finalProviders := make(map[provider.InferenceProvider]ProviderConfig) for providerName, providerConfig := range base.Providers { @@ -328,30 +517,36 @@ func mergeProviderConfigs(base, global, local *Config) { base.Providers = finalProviders } -func providerDefaultConfig(providerName provider.InferenceProvider) ProviderConfig { - switch providerName { +func providerDefaultConfig(providerId provider.InferenceProvider) ProviderConfig { + switch providerId { case provider.InferenceProviderAnthropic: return ProviderConfig{ + ID: providerId, ProviderType: provider.TypeAnthropic, } case provider.InferenceProviderOpenAI: return ProviderConfig{ + ID: providerId, ProviderType: provider.TypeOpenAI, } case provider.InferenceProviderGemini: return ProviderConfig{ + ID: providerId, ProviderType: provider.TypeGemini, } case provider.InferenceProviderBedrock: return ProviderConfig{ + ID: providerId, ProviderType: provider.TypeBedrock, } case provider.InferenceProviderAzure: return ProviderConfig{ + ID: providerId, ProviderType: provider.TypeAzure, } case provider.InferenceProviderOpenRouter: return ProviderConfig{ + ID: providerId, ProviderType: provider.TypeOpenAI, BaseURL: "https://openrouter.ai/api/v1", ExtraHeaders: map[string]string{ @@ -361,15 +556,18 @@ func providerDefaultConfig(providerName provider.InferenceProvider) ProviderConf } case provider.InferenceProviderXAI: return ProviderConfig{ + ID: providerId, ProviderType: provider.TypeXAI, BaseURL: "https://api.x.ai/v1", } case provider.InferenceProviderVertexAI: return ProviderConfig{ + ID: providerId, ProviderType: provider.TypeVertexAI, } default: return ProviderConfig{ + ID: providerId, ProviderType: provider.TypeOpenAI, } } @@ -379,6 +577,7 @@ func defaultConfigBasedOnEnv() *Config { cfg := &Config{ Options: Options{ DataDirectory: defaultDataDirectory, + ContextPaths: defaultContextPaths, }, Providers: make(map[provider.InferenceProvider]ProviderConfig), } @@ -391,7 +590,22 @@ func defaultConfigBasedOnEnv() *Config { if apiKey := os.Getenv(envVar); apiKey != "" { providerConfig := providerDefaultConfig(p.ID) providerConfig.APIKey = apiKey - providerConfig.DefaultModel = p.DefaultModelID + providerConfig.DefaultLargeModel = p.DefaultLargeModelID + providerConfig.DefaultSmallModel = p.DefaultSmallModelID + for _, model := range p.Models { + providerConfig.Models = append(providerConfig.Models, Model{ + ID: model.ID, + Name: model.Name, + CostPer1MIn: model.CostPer1MIn, + CostPer1MOut: model.CostPer1MOut, + CostPer1MInCached: model.CostPer1MInCached, + CostPer1MOutCached: model.CostPer1MOutCached, + ContextWindow: model.ContextWindow, + DefaultMaxTokens: model.DefaultMaxTokens, + CanReason: model.CanReason, + SupportsImages: model.SupportsImages, + }) + } cfg.Providers[p.ID] = providerConfig } } diff --git a/internal/config_v2/config_test.go b/internal/config_v2/config_test.go index 50b829271dcd42213141ecf2b9b72f5890480668..9bcfcdc78375e1a3a35726b513f04e3cb1e2c3b3 100644 --- a/internal/config_v2/config_test.go +++ b/internal/config_v2/config_test.go @@ -1,6 +1,7 @@ package configv2 import ( + "encoding/json" "fmt" "os" "testing" @@ -28,6 +29,7 @@ func TestConfigWithEnv(t *testing.T) { os.Setenv("XAI_API_KEY", "test-xai-key") os.Setenv("OPENROUTER_API_KEY", "test-openrouter-key") cfg := InitConfig(cwdDir) - fmt.Println(cfg) + data, _ := json.MarshalIndent(cfg, "", " ") + fmt.Println(string(data)) assert.Len(t, cfg.Providers, 5) } diff --git a/internal/config_v2/provider.go b/internal/config_v2/provider.go index 94fe2d44d74e3039dcdeaa0dc76e95b840a03125..ec6b5bdb701876af4705c9e78fcc55a87646edd2 100644 --- a/internal/config_v2/provider.go +++ b/internal/config_v2/provider.go @@ -6,8 +6,8 @@ import ( "path/filepath" "sync" - "github.com/charmbracelet/fur/pkg/client" - "github.com/charmbracelet/fur/pkg/provider" + "github.com/charmbracelet/crush/internal/fur/client" + "github.com/charmbracelet/crush/internal/fur/provider" ) var fur = client.New() diff --git a/internal/fur/client/client.go b/internal/fur/client/client.go new file mode 100644 index 0000000000000000000000000000000000000000..263e8317ce8ac92d8820ba5288f2e40d2616e0e1 --- /dev/null +++ b/internal/fur/client/client.go @@ -0,0 +1,63 @@ +// Package client provides a client for interacting with the fur service. +package client + +import ( + "encoding/json" + "fmt" + "net/http" + "os" + + "github.com/charmbracelet/crush/internal/fur/provider" +) + +const defaultURL = "http://localhost:8080" + +// Client represents a client for the fur service. +type Client struct { + baseURL string + httpClient *http.Client +} + +// New creates a new client instance +// Uses FUR_URL environment variable or falls back to localhost:8080. +func New() *Client { + baseURL := os.Getenv("FUR_URL") + if baseURL == "" { + baseURL = defaultURL + } + + return &Client{ + baseURL: baseURL, + httpClient: &http.Client{}, + } +} + +// NewWithURL creates a new client with a specific URL. +func NewWithURL(url string) *Client { + return &Client{ + baseURL: url, + httpClient: &http.Client{}, + } +} + +// GetProviders retrieves all available providers from the service. +func (c *Client) GetProviders() ([]provider.Provider, error) { + url := fmt.Sprintf("%s/providers", c.baseURL) + + resp, err := c.httpClient.Get(url) //nolint:noctx + if err != nil { + return nil, fmt.Errorf("failed to make request: %w", err) + } + defer resp.Body.Close() //nolint:errcheck + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode) + } + + var providers []provider.Provider + if err := json.NewDecoder(resp.Body).Decode(&providers); err != nil { + return nil, fmt.Errorf("failed to decode response: %w", err) + } + + return providers, nil +} diff --git a/internal/fur/provider/provider.go b/internal/fur/provider/provider.go new file mode 100644 index 0000000000000000000000000000000000000000..85275f1155eff219c87d85fce3cdcc436f4a4e47 --- /dev/null +++ b/internal/fur/provider/provider.go @@ -0,0 +1,72 @@ +// Package provider provides types and constants for AI providers. +package provider + +// Type represents the type of AI provider. +type Type string + +// All the supported AI provider types. +const ( + TypeOpenAI Type = "openai" + TypeAnthropic Type = "anthropic" + TypeGemini Type = "gemini" + TypeAzure Type = "azure" + TypeBedrock Type = "bedrock" + TypeVertexAI Type = "vertexai" + TypeXAI Type = "xai" + TypeOpenRouter Type = "openrouter" +) + +// InferenceProvider represents the inference provider identifier. +type InferenceProvider string + +// All the inference providers supported by the system. +const ( + InferenceProviderOpenAI InferenceProvider = "openai" + InferenceProviderAnthropic InferenceProvider = "anthropic" + InferenceProviderGemini InferenceProvider = "gemini" + InferenceProviderAzure InferenceProvider = "azure" + InferenceProviderBedrock InferenceProvider = "bedrock" + InferenceProviderVertexAI InferenceProvider = "vertexai" + InferenceProviderXAI InferenceProvider = "xai" + InferenceProviderOpenRouter InferenceProvider = "openrouter" +) + +// Provider represents an AI provider configuration. +type Provider struct { + Name string `json:"name"` + ID InferenceProvider `json:"id"` + APIKey string `json:"api_key,omitempty"` + APIEndpoint string `json:"api_endpoint,omitempty"` + Type Type `json:"type,omitempty"` + DefaultLargeModelID string `json:"default_large_model_id,omitempty"` + DefaultSmallModelID string `json:"default_small_model_id,omitempty"` + Models []Model `json:"models,omitempty"` +} + +// Model represents an AI model configuration. +type Model struct { + ID string `json:"id"` + Name string `json:"model"` + CostPer1MIn float64 `json:"cost_per_1m_in"` + CostPer1MOut float64 `json:"cost_per_1m_out"` + CostPer1MInCached float64 `json:"cost_per_1m_in_cached"` + CostPer1MOutCached float64 `json:"cost_per_1m_out_cached"` + ContextWindow int64 `json:"context_window"` + DefaultMaxTokens int64 `json:"default_max_tokens"` + CanReason bool `json:"can_reason"` + SupportsImages bool `json:"supports_attachments"` +} + +// KnownProviders returns all the known inference providers. +func KnownProviders() []InferenceProvider { + return []InferenceProvider{ + InferenceProviderOpenAI, + InferenceProviderAnthropic, + InferenceProviderGemini, + InferenceProviderAzure, + InferenceProviderBedrock, + InferenceProviderVertexAI, + InferenceProviderXAI, + InferenceProviderOpenRouter, + } +} diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go index 26d952441317607323e1170d6a06559f9173605d..ea2a3bd2b11735c1f0422e859adcfa65a82fdb98 100644 --- a/internal/llm/agent/agent.go +++ b/internal/llm/agent/agent.go @@ -734,21 +734,15 @@ func createAgentProvider(agentName config.AgentName) (provider.Provider, error) provider.WithSystemMessage(prompt.GetAgentPrompt(agentName, model.Provider)), provider.WithMaxTokens(maxTokens), } - if (model.Provider == models.ProviderOpenAI || model.Provider == models.ProviderLocal) && model.CanReason { - opts = append( - opts, - provider.WithOpenAIOptions( - provider.WithReasoningEffort(agentConfig.ReasoningEffort), - ), - ) - } else if model.Provider == models.ProviderAnthropic && model.CanReason && agentName == config.AgentCoder { - opts = append( - opts, - provider.WithAnthropicOptions( - provider.WithAnthropicShouldThinkFn(provider.DefaultShouldThinkFn), - ), - ) - } + // TODO: reimplement + // if model.Provider == models.ProviderOpenAI || model.Provider == models.ProviderLocal && model.CanReason { + // opts = append( + // opts, + // provider.WithOpenAIOptions( + // provider.WithReasoningEffort(agentConfig.ReasoningEffort), + // ), + // ) + // } agentProvider, err := provider.NewProvider( model.Provider, opts..., diff --git a/internal/llm/provider/anthropic.go b/internal/llm/provider/anthropic.go index 634040ebc3dad8d6dc9e7642ebbe95ac3b051c63..709a56263e0a8880d444c8ee7e9cab1373e67344 100644 --- a/internal/llm/provider/anthropic.go +++ b/internal/llm/provider/anthropic.go @@ -19,40 +19,25 @@ import ( "github.com/charmbracelet/crush/internal/message" ) -type anthropicOptions struct { - useBedrock bool - disableCache bool - shouldThink func(userMessage string) bool -} - -type AnthropicOption func(*anthropicOptions) - type anthropicClient struct { providerOptions providerClientOptions - options anthropicOptions client anthropic.Client } type AnthropicClient ProviderClient -func newAnthropicClient(opts providerClientOptions) AnthropicClient { - anthropicOpts := anthropicOptions{} - for _, o := range opts.anthropicOptions { - o(&anthropicOpts) - } - +func newAnthropicClient(opts providerClientOptions, useBedrock bool) AnthropicClient { anthropicClientOptions := []option.RequestOption{} if opts.apiKey != "" { anthropicClientOptions = append(anthropicClientOptions, option.WithAPIKey(opts.apiKey)) } - if anthropicOpts.useBedrock { + if useBedrock { anthropicClientOptions = append(anthropicClientOptions, bedrock.WithLoadDefaultConfig(context.Background())) } client := anthropic.NewClient(anthropicClientOptions...) return &anthropicClient{ providerOptions: opts, - options: anthropicOpts, client: client, } } @@ -66,7 +51,7 @@ func (a *anthropicClient) convertMessages(messages []message.Message) (anthropic switch msg.Role { case message.User: content := anthropic.NewTextBlock(msg.Content().String()) - if cache && !a.options.disableCache { + if cache && !a.providerOptions.disableCache { content.OfText.CacheControl = anthropic.CacheControlEphemeralParam{ Type: "ephemeral", } @@ -84,7 +69,7 @@ func (a *anthropicClient) convertMessages(messages []message.Message) (anthropic blocks := []anthropic.ContentBlockParamUnion{} if msg.Content().String() != "" { content := anthropic.NewTextBlock(msg.Content().String()) - if cache && !a.options.disableCache { + if cache && !a.providerOptions.disableCache { content.OfText.CacheControl = anthropic.CacheControlEphemeralParam{ Type: "ephemeral", } @@ -132,7 +117,7 @@ func (a *anthropicClient) convertTools(tools []tools.BaseTool) []anthropic.ToolU }, } - if i == len(tools)-1 && !a.options.disableCache { + if i == len(tools)-1 && !a.providerOptions.disableCache { toolParam.CacheControl = anthropic.CacheControlEphemeralParam{ Type: "ephemeral", } @@ -161,21 +146,22 @@ func (a *anthropicClient) finishReason(reason string) message.FinishReason { func (a *anthropicClient) preparedMessages(messages []anthropic.MessageParam, tools []anthropic.ToolUnionParam) anthropic.MessageNewParams { var thinkingParam anthropic.ThinkingConfigParamUnion - lastMessage := messages[len(messages)-1] - isUser := lastMessage.Role == anthropic.MessageParamRoleUser - messageContent := "" + // TODO: Implement a proper thinking function + // lastMessage := messages[len(messages)-1] + // isUser := lastMessage.Role == anthropic.MessageParamRoleUser + // messageContent := "" temperature := anthropic.Float(0) - if isUser { - for _, m := range lastMessage.Content { - if m.OfText != nil && m.OfText.Text != "" { - messageContent = m.OfText.Text - } - } - if messageContent != "" && a.options.shouldThink != nil && a.options.shouldThink(messageContent) { - thinkingParam = anthropic.ThinkingConfigParamOfEnabled(int64(float64(a.providerOptions.maxTokens) * 0.8)) - temperature = anthropic.Float(1) - } - } + // if isUser { + // for _, m := range lastMessage.Content { + // if m.OfText != nil && m.OfText.Text != "" { + // messageContent = m.OfText.Text + // } + // } + // if messageContent != "" && a.shouldThink != nil && a.options.shouldThink(messageContent) { + // thinkingParam = anthropic.ThinkingConfigParamOfEnabled(int64(float64(a.providerOptions.maxTokens) * 0.8)) + // temperature = anthropic.Float(1) + // } + // } return anthropic.MessageNewParams{ Model: anthropic.Model(a.providerOptions.model.APIModel), @@ -439,24 +425,7 @@ func (a *anthropicClient) usage(msg anthropic.Message) TokenUsage { } } -func WithAnthropicBedrock(useBedrock bool) AnthropicOption { - return func(options *anthropicOptions) { - options.useBedrock = useBedrock - } -} - -func WithAnthropicDisableCache() AnthropicOption { - return func(options *anthropicOptions) { - options.disableCache = true - } -} - +// TODO: check if we need func DefaultShouldThinkFn(s string) bool { return strings.Contains(strings.ToLower(s), "think") } - -func WithAnthropicShouldThinkFn(fn func(string) bool) AnthropicOption { - return func(options *anthropicOptions) { - options.shouldThink = fn - } -} diff --git a/internal/llm/provider/bedrock.go b/internal/llm/provider/bedrock.go index 8d3a86198aab5a38742e33b167f2545efd808873..8db9c1e84a4e8496be77e69e612de4abb9ce0c07 100644 --- a/internal/llm/provider/bedrock.go +++ b/internal/llm/provider/bedrock.go @@ -11,22 +11,14 @@ import ( "github.com/charmbracelet/crush/internal/message" ) -type bedrockOptions struct { - // Bedrock specific options can be added here -} - -type BedrockOption func(*bedrockOptions) - type bedrockClient struct { providerOptions providerClientOptions - options bedrockOptions childProvider ProviderClient } type BedrockClient ProviderClient func newBedrockClient(opts providerClientOptions) BedrockClient { - bedrockOpts := bedrockOptions{} // Apply bedrock specific options if they are added in the future // Get AWS region from environment @@ -41,7 +33,6 @@ func newBedrockClient(opts providerClientOptions) BedrockClient { if len(region) < 2 { return &bedrockClient{ providerOptions: opts, - options: bedrockOpts, childProvider: nil, // Will cause an error when used } } @@ -55,14 +46,11 @@ func newBedrockClient(opts providerClientOptions) BedrockClient { if strings.Contains(string(opts.model.APIModel), "anthropic") { // Create Anthropic client with Bedrock configuration anthropicOpts := opts - anthropicOpts.anthropicOptions = append(anthropicOpts.anthropicOptions, - WithAnthropicBedrock(true), - WithAnthropicDisableCache(), - ) + // TODO: later find a way to check if the AWS account has caching enabled + opts.disableCache = true // Disable cache for Bedrock return &bedrockClient{ providerOptions: opts, - options: bedrockOpts, - childProvider: newAnthropicClient(anthropicOpts), + childProvider: newAnthropicClient(anthropicOpts, true), } } @@ -70,7 +58,6 @@ func newBedrockClient(opts providerClientOptions) BedrockClient { // This will cause an error when used return &bedrockClient{ providerOptions: opts, - options: bedrockOpts, childProvider: nil, } } diff --git a/internal/llm/provider/gemini.go b/internal/llm/provider/gemini.go index 9481d8d545aab12a3739fe99b4af61f4ed99a514..dd54dac4491634de06a31ee00f1ffd13ea935076 100644 --- a/internal/llm/provider/gemini.go +++ b/internal/llm/provider/gemini.go @@ -17,26 +17,14 @@ import ( "google.golang.org/genai" ) -type geminiOptions struct { - disableCache bool -} - -type GeminiOption func(*geminiOptions) - type geminiClient struct { providerOptions providerClientOptions - options geminiOptions client *genai.Client } type GeminiClient ProviderClient func newGeminiClient(opts providerClientOptions) GeminiClient { - geminiOpts := geminiOptions{} - for _, o := range opts.geminiOptions { - o(&geminiOpts) - } - client, err := genai.NewClient(context.Background(), &genai.ClientConfig{APIKey: opts.apiKey, Backend: genai.BackendGeminiAPI}) if err != nil { logging.Error("Failed to create Gemini client", "error", err) @@ -45,7 +33,6 @@ func newGeminiClient(opts providerClientOptions) GeminiClient { return &geminiClient{ providerOptions: opts, - options: geminiOpts, client: client, } } @@ -452,12 +439,6 @@ func (g *geminiClient) usage(resp *genai.GenerateContentResponse) TokenUsage { } } -func WithGeminiDisableCache() GeminiOption { - return func(options *geminiOptions) { - options.disableCache = true - } -} - // Helper functions func parseJsonToMap(jsonStr string) (map[string]any, error) { var result map[string]any diff --git a/internal/llm/provider/openai.go b/internal/llm/provider/openai.go index 05658dd6db760a1d05a88ae4931de5c70d9cc453..334312f9e8c41f5d68251d9e7bbd890074fa3982 100644 --- a/internal/llm/provider/openai.go +++ b/internal/llm/provider/openai.go @@ -19,14 +19,9 @@ import ( ) type openaiOptions struct { - baseURL string - disableCache bool reasoningEffort string - extraHeaders map[string]string } -type OpenAIOption func(*openaiOptions) - type openaiClient struct { providerOptions providerClientOptions options openaiOptions @@ -39,20 +34,17 @@ func newOpenAIClient(opts providerClientOptions) OpenAIClient { openaiOpts := openaiOptions{ reasoningEffort: "medium", } - for _, o := range opts.openaiOptions { - o(&openaiOpts) - } openaiClientOptions := []option.RequestOption{} if opts.apiKey != "" { openaiClientOptions = append(openaiClientOptions, option.WithAPIKey(opts.apiKey)) } - if openaiOpts.baseURL != "" { - openaiClientOptions = append(openaiClientOptions, option.WithBaseURL(openaiOpts.baseURL)) + if opts.baseURL != "" { + openaiClientOptions = append(openaiClientOptions, option.WithBaseURL(opts.baseURL)) } - if openaiOpts.extraHeaders != nil { - for key, value := range openaiOpts.extraHeaders { + if opts.extraHeaders != nil { + for key, value := range opts.extraHeaders { openaiClientOptions = append(openaiClientOptions, option.WithHeader(key, value)) } } @@ -392,34 +384,3 @@ func (o *openaiClient) usage(completion openai.ChatCompletion) TokenUsage { CacheReadTokens: cachedTokens, } } - -func WithOpenAIBaseURL(baseURL string) OpenAIOption { - return func(options *openaiOptions) { - options.baseURL = baseURL - } -} - -func WithOpenAIExtraHeaders(headers map[string]string) OpenAIOption { - return func(options *openaiOptions) { - options.extraHeaders = headers - } -} - -func WithOpenAIDisableCache() OpenAIOption { - return func(options *openaiOptions) { - options.disableCache = true - } -} - -func WithReasoningEffort(effort string) OpenAIOption { - return func(options *openaiOptions) { - defaultReasoningEffort := "medium" - switch effort { - case "low", "medium", "high": - defaultReasoningEffort = effort - default: - logging.Warn("Invalid reasoning effort, using default: medium") - } - options.reasoningEffort = defaultReasoningEffort - } -} diff --git a/internal/llm/provider/provider.go b/internal/llm/provider/provider.go index 40c7317fba3eb944ba83421bcee8bf1702882fcb..86c47c2e0c24f2f99d91eb51c946da7bbf90dfa0 100644 --- a/internal/llm/provider/provider.go +++ b/internal/llm/provider/provider.go @@ -3,6 +3,7 @@ package provider import ( "context" "fmt" + "maps" "os" "github.com/charmbracelet/crush/internal/llm/models" @@ -59,15 +60,13 @@ type Provider interface { } type providerClientOptions struct { + baseURL string apiKey string model models.Model + disableCache bool maxTokens int64 systemMessage string - - anthropicOptions []AnthropicOption - openaiOptions []OpenAIOption - geminiOptions []GeminiOption - bedrockOptions []BedrockOption + extraHeaders map[string]string } type ProviderClientOption func(*providerClientOptions) @@ -91,7 +90,7 @@ func NewProvider(providerName models.InferenceProvider, opts ...ProviderClientOp case models.ProviderAnthropic: return &baseProvider[AnthropicClient]{ options: clientOptions, - client: newAnthropicClient(clientOptions), + client: newAnthropicClient(clientOptions, false), }, nil case models.ProviderOpenAI: return &baseProvider[OpenAIClient]{ @@ -109,9 +108,7 @@ func NewProvider(providerName models.InferenceProvider, opts ...ProviderClientOp client: newBedrockClient(clientOptions), }, nil case models.ProviderGROQ: - clientOptions.openaiOptions = append(clientOptions.openaiOptions, - WithOpenAIBaseURL("https://api.groq.com/openai/v1"), - ) + clientOptions.baseURL = "https://api.groq.com/openai/v1" return &baseProvider[OpenAIClient]{ options: clientOptions, client: newOpenAIClient(clientOptions), @@ -127,29 +124,23 @@ func NewProvider(providerName models.InferenceProvider, opts ...ProviderClientOp client: newVertexAIClient(clientOptions), }, nil case models.ProviderOpenRouter: - clientOptions.openaiOptions = append(clientOptions.openaiOptions, - WithOpenAIBaseURL("https://openrouter.ai/api/v1"), - WithOpenAIExtraHeaders(map[string]string{ - "HTTP-Referer": "crush.charm.land", - "X-Title": "Crush", - }), - ) + clientOptions.baseURL = "https://openrouter.ai/api/v1" + clientOptions.extraHeaders = map[string]string{ + "HTTP-Referer": "crush.charm.land", + "X-Title": "Crush", + } return &baseProvider[OpenAIClient]{ options: clientOptions, client: newOpenAIClient(clientOptions), }, nil case models.ProviderXAI: - clientOptions.openaiOptions = append(clientOptions.openaiOptions, - WithOpenAIBaseURL("https://api.x.ai/v1"), - ) + clientOptions.baseURL = "https://api.x.ai/v1" return &baseProvider[OpenAIClient]{ options: clientOptions, client: newOpenAIClient(clientOptions), }, nil case models.ProviderLocal: - clientOptions.openaiOptions = append(clientOptions.openaiOptions, - WithOpenAIBaseURL(os.Getenv("LOCAL_ENDPOINT")), - ) + clientOptions.baseURL = os.Getenv("LOCAL_ENDPOINT") return &baseProvider[OpenAIClient]{ options: clientOptions, client: newOpenAIClient(clientOptions), @@ -186,50 +177,47 @@ func (p *baseProvider[C]) StreamResponse(ctx context.Context, messages []message return p.client.stream(ctx, messages, tools) } -func WithAPIKey(apiKey string) ProviderClientOption { - return func(options *providerClientOptions) { - options.apiKey = apiKey - } -} - -func WithModel(model models.Model) ProviderClientOption { +func WithBaseURL(baseURL string) ProviderClientOption { return func(options *providerClientOptions) { - options.model = model + options.baseURL = baseURL } } -func WithMaxTokens(maxTokens int64) ProviderClientOption { +func WithAPIKey(apiKey string) ProviderClientOption { return func(options *providerClientOptions) { - options.maxTokens = maxTokens + options.apiKey = apiKey } } -func WithSystemMessage(systemMessage string) ProviderClientOption { +func WithModel(model models.Model) ProviderClientOption { return func(options *providerClientOptions) { - options.systemMessage = systemMessage + options.model = model } } -func WithAnthropicOptions(anthropicOptions ...AnthropicOption) ProviderClientOption { +func WithDisableCache(disableCache bool) ProviderClientOption { return func(options *providerClientOptions) { - options.anthropicOptions = anthropicOptions + options.disableCache = disableCache } } -func WithOpenAIOptions(openaiOptions ...OpenAIOption) ProviderClientOption { +func WithExtraHeaders(extraHeaders map[string]string) ProviderClientOption { return func(options *providerClientOptions) { - options.openaiOptions = openaiOptions + if options.extraHeaders == nil { + options.extraHeaders = make(map[string]string) + } + maps.Copy(options.extraHeaders, extraHeaders) } } -func WithGeminiOptions(geminiOptions ...GeminiOption) ProviderClientOption { +func WithMaxTokens(maxTokens int64) ProviderClientOption { return func(options *providerClientOptions) { - options.geminiOptions = geminiOptions + options.maxTokens = maxTokens } } -func WithBedrockOptions(bedrockOptions ...BedrockOption) ProviderClientOption { +func WithSystemMessage(systemMessage string) ProviderClientOption { return func(options *providerClientOptions) { - options.bedrockOptions = bedrockOptions + options.systemMessage = systemMessage } } diff --git a/internal/llm/provider/vertexai.go b/internal/llm/provider/vertexai.go index 0e01ef9c3a082faea86bd6c76d6e9d53c9f8a933..49374d33fa81ab42e9f0c4d6e7905bfa37a6154e 100644 --- a/internal/llm/provider/vertexai.go +++ b/internal/llm/provider/vertexai.go @@ -11,11 +11,6 @@ import ( type VertexAIClient ProviderClient func newVertexAIClient(opts providerClientOptions) VertexAIClient { - geminiOpts := geminiOptions{} - for _, o := range opts.geminiOptions { - o(&geminiOpts) - } - client, err := genai.NewClient(context.Background(), &genai.ClientConfig{ Project: os.Getenv("GOOGLE_CLOUD_PROJECT"), Location: os.Getenv("GOOGLE_CLOUD_LOCATION"), @@ -28,7 +23,6 @@ func newVertexAIClient(opts providerClientOptions) VertexAIClient { return &geminiClient{ providerOptions: opts, - options: geminiOpts, client: client, } }