Detailed changes
@@ -2,8 +2,6 @@ module github.com/charmbracelet/crush
go 1.24.3
-replace github.com/charmbracelet/fur => ../fur
-
require (
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.7.0
github.com/JohannesKaufmann/html-to-markdown v1.6.0
@@ -17,7 +15,6 @@ require (
github.com/charmbracelet/bubbles/v2 v2.0.0-beta.1.0.20250607113720-eb5e1cf3b09e
github.com/charmbracelet/bubbletea/v2 v2.0.0-beta.3.0.20250609143341-c76fa36f1b94
github.com/charmbracelet/fang v0.1.0
- github.com/charmbracelet/fur v0.0.0-00010101000000-000000000000
github.com/charmbracelet/glamour/v2 v2.0.0-20250516160903-6f1e2c8f9ebe
github.com/charmbracelet/lipgloss/v2 v2.0.0-beta.1.0.20250523195325-2d1af06b557c
github.com/charmbracelet/x/ansi v0.9.3-0.20250602153603-fb931ed90413
@@ -10,8 +10,8 @@ import (
"strings"
"sync"
+ "github.com/charmbracelet/crush/internal/fur/provider"
"github.com/charmbracelet/crush/internal/logging"
- "github.com/charmbracelet/fur/pkg/provider"
)
const (
@@ -22,6 +22,29 @@ const (
MaxTokensFallbackDefault = 4096
)
+var defaultContextPaths = []string{
+ ".github/copilot-instructions.md",
+ ".cursorrules",
+ ".cursor/rules/",
+ "CLAUDE.md",
+ "CLAUDE.local.md",
+ "crush.md",
+ "crush.local.md",
+ "Crush.md",
+ "Crush.local.md",
+ "CRUSH.md",
+ "CRUSH.local.md",
+}
+
+type AgentID string
+
+const (
+ AgentCoder AgentID = "coder"
+ AgentTask AgentID = "task"
+ AgentTitle AgentID = "title"
+ AgentSummarize AgentID = "summarize"
+)
+
type Model struct {
ID string `json:"id"`
Name string `json:"model"`
@@ -43,40 +66,43 @@ type VertexAIOptions struct {
}
type ProviderConfig struct {
- BaseURL string `json:"base_url,omitempty"`
- ProviderType provider.Type `json:"provider_type"`
- APIKey string `json:"api_key,omitempty"`
- Disabled bool `json:"disabled"`
- ExtraHeaders map[string]string `json:"extra_headers,omitempty"`
+ ID provider.InferenceProvider `json:"id"`
+ BaseURL string `json:"base_url,omitempty"`
+ ProviderType provider.Type `json:"provider_type"`
+ APIKey string `json:"api_key,omitempty"`
+ Disabled bool `json:"disabled"`
+ ExtraHeaders map[string]string `json:"extra_headers,omitempty"`
// used for e.x for vertex to set the project
ExtraParams map[string]string `json:"extra_params,omitempty"`
- DefaultModel string `json:"default_model"`
+ DefaultLargeModel string `json:"default_large_model,omitempty"`
+ DefaultSmallModel string `json:"default_small_model,omitempty"`
+
+ Models []Model `json:"models,omitempty"`
}
type Agent struct {
- Name string `json:"name"`
+ Name string `json:"name"`
+ Description string `json:"description,omitempty"`
// This is the id of the system prompt used by the agent
- // TODO: still needs to be implemented
- PromptID string `json:"prompt_id"`
- Disabled bool `json:"disabled"`
+ Disabled bool `json:"disabled"`
Provider provider.InferenceProvider `json:"provider"`
- Model Model `json:"model"`
+ Model string `json:"model"`
// The available tools for the agent
- // if this is empty, all tools are available
+ // if this is nil, all tools are available
AllowedTools []string `json:"allowed_tools"`
// this tells us which MCPs are available for this agent
// if this is empty all mcps are available
- // the string array is the list of tools from the MCP the agent has available
- // if the string array is empty, all tools from the MCP are available
- MCP map[string][]string `json:"mcp"`
+ // the string array is the list of tools from the AllowedMCP the agent has available
+ // if the string array is nil, all tools from the AllowedMCP are available
+ AllowedMCP map[string][]string `json:"allowed_mcp"`
// The list of LSPs that this agent can use
- // if this is empty, all LSPs are available
- LSP []string `json:"lsp"`
+ // if this is nil, all LSPs are available
+ AllowedLSP []string `json:"allowed_lsp"`
// Overrides the context paths for this agent
ContextPaths []string `json:"context_paths"`
@@ -125,7 +151,7 @@ type Config struct {
Providers map[provider.InferenceProvider]ProviderConfig `json:"providers,omitempty"`
// List of configured agents
- Agents map[string]Agent `json:"agents,omitempty"`
+ Agents map[AgentID]Agent `json:"agents,omitempty"`
// List of configured MCPs
MCP map[string]MCP `json:"mcp,omitempty"`
@@ -135,15 +161,13 @@ type Config struct {
// Miscellaneous options
Options Options `json:"options"`
-
- // Used to add models that are not already in the repository
- Models map[provider.InferenceProvider][]provider.Model `json:"models,omitempty"`
}
var (
instance *Config // The single instance of the Singleton
cwd string
once sync.Once // Ensures the initialization happens only once
+
)
func loadConfig(cwd string) (*Config, error) {
@@ -190,10 +214,73 @@ func loadConfig(cwd string) (*Config, error) {
}
// merge options
- cfg.Options = mergeOptions(cfg.Options, globalCfg.Options)
- cfg.Options = mergeOptions(cfg.Options, localConfig.Options)
+ mergeOptions(cfg, globalCfg, localConfig)
mergeProviderConfigs(cfg, globalCfg, localConfig)
+ // no providers found the app is not initialized yet
+ if len(cfg.Providers) == 0 {
+ return cfg, nil
+ }
+ preferredProvider := getPreferredProvider(cfg.Providers)
+
+ if preferredProvider == nil {
+ return nil, errors.New("no valid providers configured")
+ }
+
+ agents := map[AgentID]Agent{
+ AgentCoder: {
+ Name: "Coder",
+ Description: "An agent that helps with executing coding tasks.",
+ Provider: preferredProvider.ID,
+ Model: preferredProvider.DefaultLargeModel,
+ ContextPaths: cfg.Options.ContextPaths,
+ // All tools allowed
+ },
+ AgentTask: {
+ Name: "Task",
+ Description: "An agent that helps with searching for context and finding implementation details.",
+ Provider: preferredProvider.ID,
+ Model: preferredProvider.DefaultLargeModel,
+ ContextPaths: cfg.Options.ContextPaths,
+ AllowedTools: []string{
+ "glob",
+ "grep",
+ "ls",
+ "sourcegraph",
+ "view",
+ },
+ // NO MCPs or LSPs by default
+ AllowedMCP: map[string][]string{},
+ AllowedLSP: []string{},
+ },
+ AgentTitle: {
+ Name: "Title",
+ Description: "An agent that helps with generating titles for sessions.",
+ Provider: preferredProvider.ID,
+ Model: preferredProvider.DefaultSmallModel,
+ ContextPaths: cfg.Options.ContextPaths,
+ AllowedTools: []string{},
+ // NO MCPs or LSPs by default
+ AllowedMCP: map[string][]string{},
+ AllowedLSP: []string{},
+ },
+ AgentSummarize: {
+ Name: "Summarize",
+ Description: "An agent that helps with summarizing sessions.",
+ Provider: preferredProvider.ID,
+ Model: preferredProvider.DefaultSmallModel,
+ ContextPaths: cfg.Options.ContextPaths,
+ AllowedTools: []string{},
+ // NO MCPs or LSPs by default
+ AllowedMCP: map[string][]string{},
+ AllowedLSP: []string{},
+ },
+ }
+ cfg.Agents = agents
+ mergeAgents(cfg, globalCfg, localConfig)
+ mergeMCPs(cfg, globalCfg, localConfig)
+ mergeLSPs(cfg, globalCfg, localConfig)
+
return cfg, nil
}
@@ -219,6 +306,22 @@ func GetConfig() *Config {
return instance
}
+func getPreferredProvider(configuredProviders map[provider.InferenceProvider]ProviderConfig) *ProviderConfig {
+ providers := Providers()
+ for _, p := range providers {
+ if providerConfig, ok := configuredProviders[p.ID]; ok && !providerConfig.Disabled {
+ return &providerConfig
+ }
+ }
+ // if none found return the first configured provider
+ for _, providerConfig := range configuredProviders {
+ if !providerConfig.Disabled {
+ return &providerConfig
+ }
+ }
+ return nil
+}
+
func mergeProviderConfig(p provider.InferenceProvider, base, other ProviderConfig) ProviderConfig {
if other.APIKey != "" {
base.APIKey = other.APIKey
@@ -249,6 +352,26 @@ func mergeProviderConfig(p provider.InferenceProvider, base, other ProviderConfi
base.Disabled = other.Disabled
}
+	if other.DefaultLargeModel != "" {
+		base.DefaultLargeModel = other.DefaultLargeModel
+	}
+	if other.DefaultSmallModel != "" {
+		base.DefaultSmallModel = other.DefaultSmallModel
+	}
+ // Add new models if they don't exist
+ if other.Models != nil {
+ for _, model := range other.Models {
+ // check if the model already exists
+ exists := false
+ for _, existingModel := range base.Models {
+ if existingModel.ID == model.ID {
+ exists = true
+ break
+ }
+ }
+ if !exists {
+ base.Models = append(base.Models, model)
+ }
+ }
+ }
+
return base
}
@@ -267,39 +390,114 @@ func validateProvider(p provider.InferenceProvider, providerConfig ProviderConfi
return nil
}
-func mergeOptions(base, other Options) Options {
- result := base
+func mergeOptions(base, global, local *Config) {
+ for _, cfg := range []*Config{global, local} {
+ if cfg == nil {
+ continue
+ }
+ baseOptions := base.Options
+ other := cfg.Options
+ if len(other.ContextPaths) > 0 {
+ baseOptions.ContextPaths = append(baseOptions.ContextPaths, other.ContextPaths...)
+ }
- if len(other.ContextPaths) > 0 {
- base.ContextPaths = append(base.ContextPaths, other.ContextPaths...)
- }
+ if other.TUI.CompactMode {
+ baseOptions.TUI.CompactMode = other.TUI.CompactMode
+ }
- if other.TUI.CompactMode {
- result.TUI.CompactMode = other.TUI.CompactMode
- }
+ if other.Debug {
+ baseOptions.Debug = other.Debug
+ }
- if other.Debug {
- result.Debug = other.Debug
- }
+ if other.DebugLSP {
+ baseOptions.DebugLSP = other.DebugLSP
+ }
- if other.DebugLSP {
- result.DebugLSP = other.DebugLSP
+ if other.DisableAutoSummarize {
+ baseOptions.DisableAutoSummarize = other.DisableAutoSummarize
+ }
+
+ if other.DataDirectory != "" {
+ baseOptions.DataDirectory = other.DataDirectory
+ }
+ base.Options = baseOptions
}
+}
- if other.DisableAutoSummarize {
- result.DisableAutoSummarize = other.DisableAutoSummarize
+func mergeAgents(base, global, local *Config) {
+ for _, cfg := range []*Config{global, local} {
+ if cfg == nil {
+ continue
+ }
+ for agentID, globalAgent := range cfg.Agents {
+ if _, ok := base.Agents[agentID]; !ok {
+ base.Agents[agentID] = globalAgent
+ } else {
+ switch agentID {
+ case AgentCoder:
+ baseAgent := base.Agents[agentID]
+ baseAgent.Model = globalAgent.Model
+ baseAgent.Provider = globalAgent.Provider
+ baseAgent.AllowedMCP = globalAgent.AllowedMCP
+ baseAgent.AllowedLSP = globalAgent.AllowedLSP
+ base.Agents[agentID] = baseAgent
+ case AgentTask:
+ baseAgent := base.Agents[agentID]
+ baseAgent.Model = globalAgent.Model
+ baseAgent.Provider = globalAgent.Provider
+ base.Agents[agentID] = baseAgent
+ case AgentTitle:
+ baseAgent := base.Agents[agentID]
+ baseAgent.Model = globalAgent.Model
+ baseAgent.Provider = globalAgent.Provider
+ base.Agents[agentID] = baseAgent
+ case AgentSummarize:
+ baseAgent := base.Agents[agentID]
+ baseAgent.Model = globalAgent.Model
+ baseAgent.Provider = globalAgent.Provider
+ base.Agents[agentID] = baseAgent
+ default:
+ baseAgent := base.Agents[agentID]
+ baseAgent.Name = globalAgent.Name
+ baseAgent.Description = globalAgent.Description
+ baseAgent.Disabled = globalAgent.Disabled
+ baseAgent.Provider = globalAgent.Provider
+ baseAgent.Model = globalAgent.Model
+ baseAgent.AllowedTools = globalAgent.AllowedTools
+ baseAgent.AllowedMCP = globalAgent.AllowedMCP
+ baseAgent.AllowedLSP = globalAgent.AllowedLSP
+ base.Agents[agentID] = baseAgent
+
+ }
+ }
+ }
}
+}
- if other.DataDirectory != "" {
- result.DataDirectory = other.DataDirectory
+func mergeMCPs(base, global, local *Config) {
+ for _, cfg := range []*Config{global, local} {
+ if cfg == nil {
+ continue
+ }
+ maps.Copy(base.MCP, cfg.MCP)
}
+}
- return result
+func mergeLSPs(base, global, local *Config) {
+ for _, cfg := range []*Config{global, local} {
+ if cfg == nil {
+ continue
+ }
+ maps.Copy(base.LSP, cfg.LSP)
+ }
}
func mergeProviderConfigs(base, global, local *Config) {
- if global != nil {
- for providerName, globalProvider := range global.Providers {
+ for _, cfg := range []*Config{global, local} {
+ if cfg == nil {
+ continue
+ }
+ for providerName, globalProvider := range cfg.Providers {
if _, ok := base.Providers[providerName]; !ok {
base.Providers[providerName] = globalProvider
} else {
@@ -307,15 +505,6 @@ func mergeProviderConfigs(base, global, local *Config) {
}
}
}
- if local != nil {
- for providerName, localProvider := range local.Providers {
- if _, ok := base.Providers[providerName]; !ok {
- base.Providers[providerName] = localProvider
- } else {
- base.Providers[providerName] = mergeProviderConfig(providerName, base.Providers[providerName], localProvider)
- }
- }
- }
finalProviders := make(map[provider.InferenceProvider]ProviderConfig)
for providerName, providerConfig := range base.Providers {
@@ -328,30 +517,36 @@ func mergeProviderConfigs(base, global, local *Config) {
base.Providers = finalProviders
}
-func providerDefaultConfig(providerName provider.InferenceProvider) ProviderConfig {
- switch providerName {
+func providerDefaultConfig(providerId provider.InferenceProvider) ProviderConfig {
+ switch providerId {
case provider.InferenceProviderAnthropic:
return ProviderConfig{
+ ID: providerId,
ProviderType: provider.TypeAnthropic,
}
case provider.InferenceProviderOpenAI:
return ProviderConfig{
+ ID: providerId,
ProviderType: provider.TypeOpenAI,
}
case provider.InferenceProviderGemini:
return ProviderConfig{
+ ID: providerId,
ProviderType: provider.TypeGemini,
}
case provider.InferenceProviderBedrock:
return ProviderConfig{
+ ID: providerId,
ProviderType: provider.TypeBedrock,
}
case provider.InferenceProviderAzure:
return ProviderConfig{
+ ID: providerId,
ProviderType: provider.TypeAzure,
}
case provider.InferenceProviderOpenRouter:
return ProviderConfig{
+ ID: providerId,
ProviderType: provider.TypeOpenAI,
BaseURL: "https://openrouter.ai/api/v1",
ExtraHeaders: map[string]string{
@@ -361,15 +556,18 @@ func providerDefaultConfig(providerName provider.InferenceProvider) ProviderConf
}
case provider.InferenceProviderXAI:
return ProviderConfig{
+ ID: providerId,
ProviderType: provider.TypeXAI,
BaseURL: "https://api.x.ai/v1",
}
case provider.InferenceProviderVertexAI:
return ProviderConfig{
+ ID: providerId,
ProviderType: provider.TypeVertexAI,
}
default:
return ProviderConfig{
+ ID: providerId,
ProviderType: provider.TypeOpenAI,
}
}
@@ -379,6 +577,7 @@ func defaultConfigBasedOnEnv() *Config {
cfg := &Config{
Options: Options{
DataDirectory: defaultDataDirectory,
+ ContextPaths: defaultContextPaths,
},
Providers: make(map[provider.InferenceProvider]ProviderConfig),
}
@@ -391,7 +590,22 @@ func defaultConfigBasedOnEnv() *Config {
if apiKey := os.Getenv(envVar); apiKey != "" {
providerConfig := providerDefaultConfig(p.ID)
providerConfig.APIKey = apiKey
- providerConfig.DefaultModel = p.DefaultModelID
+ providerConfig.DefaultLargeModel = p.DefaultLargeModelID
+ providerConfig.DefaultSmallModel = p.DefaultSmallModelID
+ for _, model := range p.Models {
+ providerConfig.Models = append(providerConfig.Models, Model{
+ ID: model.ID,
+ Name: model.Name,
+ CostPer1MIn: model.CostPer1MIn,
+ CostPer1MOut: model.CostPer1MOut,
+ CostPer1MInCached: model.CostPer1MInCached,
+ CostPer1MOutCached: model.CostPer1MOutCached,
+ ContextWindow: model.ContextWindow,
+ DefaultMaxTokens: model.DefaultMaxTokens,
+ CanReason: model.CanReason,
+ SupportsImages: model.SupportsImages,
+ })
+ }
cfg.Providers[p.ID] = providerConfig
}
}
@@ -1,6 +1,7 @@
package configv2
import (
+ "encoding/json"
"fmt"
"os"
"testing"
@@ -28,6 +29,7 @@ func TestConfigWithEnv(t *testing.T) {
os.Setenv("XAI_API_KEY", "test-xai-key")
os.Setenv("OPENROUTER_API_KEY", "test-openrouter-key")
cfg := InitConfig(cwdDir)
- fmt.Println(cfg)
+ data, _ := json.MarshalIndent(cfg, "", " ")
+ fmt.Println(string(data))
assert.Len(t, cfg.Providers, 5)
}
@@ -6,8 +6,8 @@ import (
"path/filepath"
"sync"
- "github.com/charmbracelet/fur/pkg/client"
- "github.com/charmbracelet/fur/pkg/provider"
+ "github.com/charmbracelet/crush/internal/fur/client"
+ "github.com/charmbracelet/crush/internal/fur/provider"
)
var fur = client.New()
@@ -0,0 +1,63 @@
+// Package client provides a client for interacting with the fur service.
+package client
+
+import (
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "os"
+
+ "github.com/charmbracelet/crush/internal/fur/provider"
+)
+
+const defaultURL = "http://localhost:8080"
+
+// Client represents a client for the fur service.
+type Client struct {
+ baseURL string
+ httpClient *http.Client
+}
+
+// New creates a new client instance
+// Uses FUR_URL environment variable or falls back to localhost:8080.
+func New() *Client {
+ baseURL := os.Getenv("FUR_URL")
+ if baseURL == "" {
+ baseURL = defaultURL
+ }
+
+ return &Client{
+ baseURL: baseURL,
+ httpClient: &http.Client{},
+ }
+}
+
+// NewWithURL creates a new client with a specific URL.
+func NewWithURL(url string) *Client {
+ return &Client{
+ baseURL: url,
+ httpClient: &http.Client{},
+ }
+}
+
+// GetProviders retrieves all available providers from the service.
+func (c *Client) GetProviders() ([]provider.Provider, error) {
+ url := fmt.Sprintf("%s/providers", c.baseURL)
+
+ resp, err := c.httpClient.Get(url) //nolint:noctx
+ if err != nil {
+ return nil, fmt.Errorf("failed to make request: %w", err)
+ }
+ defer resp.Body.Close() //nolint:errcheck
+
+ if resp.StatusCode != http.StatusOK {
+ return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
+ }
+
+ var providers []provider.Provider
+ if err := json.NewDecoder(resp.Body).Decode(&providers); err != nil {
+ return nil, fmt.Errorf("failed to decode response: %w", err)
+ }
+
+ return providers, nil
+}
@@ -0,0 +1,72 @@
+// Package provider provides types and constants for AI providers.
+package provider
+
+// Type represents the type of AI provider.
+type Type string
+
+// All the supported AI provider types.
+const (
+ TypeOpenAI Type = "openai"
+ TypeAnthropic Type = "anthropic"
+ TypeGemini Type = "gemini"
+ TypeAzure Type = "azure"
+ TypeBedrock Type = "bedrock"
+ TypeVertexAI Type = "vertexai"
+ TypeXAI Type = "xai"
+ TypeOpenRouter Type = "openrouter"
+)
+
+// InferenceProvider represents the inference provider identifier.
+type InferenceProvider string
+
+// All the inference providers supported by the system.
+const (
+ InferenceProviderOpenAI InferenceProvider = "openai"
+ InferenceProviderAnthropic InferenceProvider = "anthropic"
+ InferenceProviderGemini InferenceProvider = "gemini"
+ InferenceProviderAzure InferenceProvider = "azure"
+ InferenceProviderBedrock InferenceProvider = "bedrock"
+ InferenceProviderVertexAI InferenceProvider = "vertexai"
+ InferenceProviderXAI InferenceProvider = "xai"
+ InferenceProviderOpenRouter InferenceProvider = "openrouter"
+)
+
+// Provider represents an AI provider configuration.
+type Provider struct {
+ Name string `json:"name"`
+ ID InferenceProvider `json:"id"`
+ APIKey string `json:"api_key,omitempty"`
+ APIEndpoint string `json:"api_endpoint,omitempty"`
+ Type Type `json:"type,omitempty"`
+ DefaultLargeModelID string `json:"default_large_model_id,omitempty"`
+ DefaultSmallModelID string `json:"default_small_model_id,omitempty"`
+ Models []Model `json:"models,omitempty"`
+}
+
+// Model represents an AI model configuration.
+type Model struct {
+ ID string `json:"id"`
+ Name string `json:"model"`
+ CostPer1MIn float64 `json:"cost_per_1m_in"`
+ CostPer1MOut float64 `json:"cost_per_1m_out"`
+ CostPer1MInCached float64 `json:"cost_per_1m_in_cached"`
+ CostPer1MOutCached float64 `json:"cost_per_1m_out_cached"`
+ ContextWindow int64 `json:"context_window"`
+ DefaultMaxTokens int64 `json:"default_max_tokens"`
+ CanReason bool `json:"can_reason"`
+ SupportsImages bool `json:"supports_attachments"`
+}
+
+// KnownProviders returns all the known inference providers.
+func KnownProviders() []InferenceProvider {
+ return []InferenceProvider{
+ InferenceProviderOpenAI,
+ InferenceProviderAnthropic,
+ InferenceProviderGemini,
+ InferenceProviderAzure,
+ InferenceProviderBedrock,
+ InferenceProviderVertexAI,
+ InferenceProviderXAI,
+ InferenceProviderOpenRouter,
+ }
+}
@@ -734,21 +734,15 @@ func createAgentProvider(agentName config.AgentName) (provider.Provider, error)
provider.WithSystemMessage(prompt.GetAgentPrompt(agentName, model.Provider)),
provider.WithMaxTokens(maxTokens),
}
- if (model.Provider == models.ProviderOpenAI || model.Provider == models.ProviderLocal) && model.CanReason {
- opts = append(
- opts,
- provider.WithOpenAIOptions(
- provider.WithReasoningEffort(agentConfig.ReasoningEffort),
- ),
- )
- } else if model.Provider == models.ProviderAnthropic && model.CanReason && agentName == config.AgentCoder {
- opts = append(
- opts,
- provider.WithAnthropicOptions(
- provider.WithAnthropicShouldThinkFn(provider.DefaultShouldThinkFn),
- ),
- )
- }
+ // TODO: reimplement
+	// if (model.Provider == models.ProviderOpenAI || model.Provider == models.ProviderLocal) && model.CanReason {
+ // opts = append(
+ // opts,
+ // provider.WithOpenAIOptions(
+ // provider.WithReasoningEffort(agentConfig.ReasoningEffort),
+ // ),
+ // )
+ // }
agentProvider, err := provider.NewProvider(
model.Provider,
opts...,
@@ -19,40 +19,25 @@ import (
"github.com/charmbracelet/crush/internal/message"
)
-type anthropicOptions struct {
- useBedrock bool
- disableCache bool
- shouldThink func(userMessage string) bool
-}
-
-type AnthropicOption func(*anthropicOptions)
-
type anthropicClient struct {
providerOptions providerClientOptions
- options anthropicOptions
client anthropic.Client
}
type AnthropicClient ProviderClient
-func newAnthropicClient(opts providerClientOptions) AnthropicClient {
- anthropicOpts := anthropicOptions{}
- for _, o := range opts.anthropicOptions {
- o(&anthropicOpts)
- }
-
+func newAnthropicClient(opts providerClientOptions, useBedrock bool) AnthropicClient {
anthropicClientOptions := []option.RequestOption{}
if opts.apiKey != "" {
anthropicClientOptions = append(anthropicClientOptions, option.WithAPIKey(opts.apiKey))
}
- if anthropicOpts.useBedrock {
+ if useBedrock {
anthropicClientOptions = append(anthropicClientOptions, bedrock.WithLoadDefaultConfig(context.Background()))
}
client := anthropic.NewClient(anthropicClientOptions...)
return &anthropicClient{
providerOptions: opts,
- options: anthropicOpts,
client: client,
}
}
@@ -66,7 +51,7 @@ func (a *anthropicClient) convertMessages(messages []message.Message) (anthropic
switch msg.Role {
case message.User:
content := anthropic.NewTextBlock(msg.Content().String())
- if cache && !a.options.disableCache {
+ if cache && !a.providerOptions.disableCache {
content.OfText.CacheControl = anthropic.CacheControlEphemeralParam{
Type: "ephemeral",
}
@@ -84,7 +69,7 @@ func (a *anthropicClient) convertMessages(messages []message.Message) (anthropic
blocks := []anthropic.ContentBlockParamUnion{}
if msg.Content().String() != "" {
content := anthropic.NewTextBlock(msg.Content().String())
- if cache && !a.options.disableCache {
+ if cache && !a.providerOptions.disableCache {
content.OfText.CacheControl = anthropic.CacheControlEphemeralParam{
Type: "ephemeral",
}
@@ -132,7 +117,7 @@ func (a *anthropicClient) convertTools(tools []tools.BaseTool) []anthropic.ToolU
},
}
- if i == len(tools)-1 && !a.options.disableCache {
+ if i == len(tools)-1 && !a.providerOptions.disableCache {
toolParam.CacheControl = anthropic.CacheControlEphemeralParam{
Type: "ephemeral",
}
@@ -161,21 +146,22 @@ func (a *anthropicClient) finishReason(reason string) message.FinishReason {
func (a *anthropicClient) preparedMessages(messages []anthropic.MessageParam, tools []anthropic.ToolUnionParam) anthropic.MessageNewParams {
var thinkingParam anthropic.ThinkingConfigParamUnion
- lastMessage := messages[len(messages)-1]
- isUser := lastMessage.Role == anthropic.MessageParamRoleUser
- messageContent := ""
+ // TODO: Implement a proper thinking function
+ // lastMessage := messages[len(messages)-1]
+ // isUser := lastMessage.Role == anthropic.MessageParamRoleUser
+ // messageContent := ""
temperature := anthropic.Float(0)
- if isUser {
- for _, m := range lastMessage.Content {
- if m.OfText != nil && m.OfText.Text != "" {
- messageContent = m.OfText.Text
- }
- }
- if messageContent != "" && a.options.shouldThink != nil && a.options.shouldThink(messageContent) {
- thinkingParam = anthropic.ThinkingConfigParamOfEnabled(int64(float64(a.providerOptions.maxTokens) * 0.8))
- temperature = anthropic.Float(1)
- }
- }
+ // if isUser {
+ // for _, m := range lastMessage.Content {
+ // if m.OfText != nil && m.OfText.Text != "" {
+ // messageContent = m.OfText.Text
+ // }
+ // }
+	// if messageContent != "" && a.options.shouldThink != nil && a.options.shouldThink(messageContent) {
+ // thinkingParam = anthropic.ThinkingConfigParamOfEnabled(int64(float64(a.providerOptions.maxTokens) * 0.8))
+ // temperature = anthropic.Float(1)
+ // }
+ // }
return anthropic.MessageNewParams{
Model: anthropic.Model(a.providerOptions.model.APIModel),
@@ -439,24 +425,7 @@ func (a *anthropicClient) usage(msg anthropic.Message) TokenUsage {
}
}
-func WithAnthropicBedrock(useBedrock bool) AnthropicOption {
- return func(options *anthropicOptions) {
- options.useBedrock = useBedrock
- }
-}
-
-func WithAnthropicDisableCache() AnthropicOption {
- return func(options *anthropicOptions) {
- options.disableCache = true
- }
-}
-
+// TODO: check if we need
func DefaultShouldThinkFn(s string) bool {
return strings.Contains(strings.ToLower(s), "think")
}
-
-func WithAnthropicShouldThinkFn(fn func(string) bool) AnthropicOption {
- return func(options *anthropicOptions) {
- options.shouldThink = fn
- }
-}
@@ -11,22 +11,14 @@ import (
"github.com/charmbracelet/crush/internal/message"
)
-type bedrockOptions struct {
- // Bedrock specific options can be added here
-}
-
-type BedrockOption func(*bedrockOptions)
-
type bedrockClient struct {
providerOptions providerClientOptions
- options bedrockOptions
childProvider ProviderClient
}
type BedrockClient ProviderClient
func newBedrockClient(opts providerClientOptions) BedrockClient {
- bedrockOpts := bedrockOptions{}
// Apply bedrock specific options if they are added in the future
// Get AWS region from environment
@@ -41,7 +33,6 @@ func newBedrockClient(opts providerClientOptions) BedrockClient {
if len(region) < 2 {
return &bedrockClient{
providerOptions: opts,
- options: bedrockOpts,
childProvider: nil, // Will cause an error when used
}
}
@@ -55,14 +46,11 @@ func newBedrockClient(opts providerClientOptions) BedrockClient {
if strings.Contains(string(opts.model.APIModel), "anthropic") {
// Create Anthropic client with Bedrock configuration
anthropicOpts := opts
- anthropicOpts.anthropicOptions = append(anthropicOpts.anthropicOptions,
- WithAnthropicBedrock(true),
- WithAnthropicDisableCache(),
- )
+ // TODO: later find a way to check if the AWS account has caching enabled
+	anthropicOpts.disableCache = true // Disable cache for Bedrock
return &bedrockClient{
providerOptions: opts,
- options: bedrockOpts,
- childProvider: newAnthropicClient(anthropicOpts),
+ childProvider: newAnthropicClient(anthropicOpts, true),
}
}
@@ -70,7 +58,6 @@ func newBedrockClient(opts providerClientOptions) BedrockClient {
// This will cause an error when used
return &bedrockClient{
providerOptions: opts,
- options: bedrockOpts,
childProvider: nil,
}
}
@@ -17,26 +17,14 @@ import (
"google.golang.org/genai"
)
-type geminiOptions struct {
- disableCache bool
-}
-
-type GeminiOption func(*geminiOptions)
-
type geminiClient struct {
providerOptions providerClientOptions
- options geminiOptions
client *genai.Client
}
type GeminiClient ProviderClient
func newGeminiClient(opts providerClientOptions) GeminiClient {
- geminiOpts := geminiOptions{}
- for _, o := range opts.geminiOptions {
- o(&geminiOpts)
- }
-
client, err := genai.NewClient(context.Background(), &genai.ClientConfig{APIKey: opts.apiKey, Backend: genai.BackendGeminiAPI})
if err != nil {
logging.Error("Failed to create Gemini client", "error", err)
@@ -45,7 +33,6 @@ func newGeminiClient(opts providerClientOptions) GeminiClient {
return &geminiClient{
providerOptions: opts,
- options: geminiOpts,
client: client,
}
}
@@ -452,12 +439,6 @@ func (g *geminiClient) usage(resp *genai.GenerateContentResponse) TokenUsage {
}
}
-func WithGeminiDisableCache() GeminiOption {
- return func(options *geminiOptions) {
- options.disableCache = true
- }
-}
-
// Helper functions
func parseJsonToMap(jsonStr string) (map[string]any, error) {
var result map[string]any
@@ -19,14 +19,9 @@ import (
)
type openaiOptions struct {
- baseURL string
- disableCache bool
reasoningEffort string
- extraHeaders map[string]string
}
-type OpenAIOption func(*openaiOptions)
-
type openaiClient struct {
providerOptions providerClientOptions
options openaiOptions
@@ -39,20 +34,17 @@ func newOpenAIClient(opts providerClientOptions) OpenAIClient {
openaiOpts := openaiOptions{
reasoningEffort: "medium",
}
- for _, o := range opts.openaiOptions {
- o(&openaiOpts)
- }
openaiClientOptions := []option.RequestOption{}
if opts.apiKey != "" {
openaiClientOptions = append(openaiClientOptions, option.WithAPIKey(opts.apiKey))
}
- if openaiOpts.baseURL != "" {
- openaiClientOptions = append(openaiClientOptions, option.WithBaseURL(openaiOpts.baseURL))
+ if opts.baseURL != "" {
+ openaiClientOptions = append(openaiClientOptions, option.WithBaseURL(opts.baseURL))
}
- if openaiOpts.extraHeaders != nil {
- for key, value := range openaiOpts.extraHeaders {
+ if opts.extraHeaders != nil {
+ for key, value := range opts.extraHeaders {
openaiClientOptions = append(openaiClientOptions, option.WithHeader(key, value))
}
}
@@ -392,34 +384,3 @@ func (o *openaiClient) usage(completion openai.ChatCompletion) TokenUsage {
CacheReadTokens: cachedTokens,
}
}
-
-func WithOpenAIBaseURL(baseURL string) OpenAIOption {
- return func(options *openaiOptions) {
- options.baseURL = baseURL
- }
-}
-
-func WithOpenAIExtraHeaders(headers map[string]string) OpenAIOption {
- return func(options *openaiOptions) {
- options.extraHeaders = headers
- }
-}
-
-func WithOpenAIDisableCache() OpenAIOption {
- return func(options *openaiOptions) {
- options.disableCache = true
- }
-}
-
-func WithReasoningEffort(effort string) OpenAIOption {
- return func(options *openaiOptions) {
- defaultReasoningEffort := "medium"
- switch effort {
- case "low", "medium", "high":
- defaultReasoningEffort = effort
- default:
- logging.Warn("Invalid reasoning effort, using default: medium")
- }
- options.reasoningEffort = defaultReasoningEffort
- }
-}
@@ -3,6 +3,7 @@ package provider
import (
"context"
"fmt"
+ "maps"
"os"
"github.com/charmbracelet/crush/internal/llm/models"
@@ -59,15 +60,13 @@ type Provider interface {
}
type providerClientOptions struct {
+ baseURL string
apiKey string
model models.Model
+ disableCache bool
maxTokens int64
systemMessage string
-
- anthropicOptions []AnthropicOption
- openaiOptions []OpenAIOption
- geminiOptions []GeminiOption
- bedrockOptions []BedrockOption
+ extraHeaders map[string]string
}
type ProviderClientOption func(*providerClientOptions)
@@ -91,7 +90,7 @@ func NewProvider(providerName models.InferenceProvider, opts ...ProviderClientOp
case models.ProviderAnthropic:
return &baseProvider[AnthropicClient]{
options: clientOptions,
- client: newAnthropicClient(clientOptions),
+ client: newAnthropicClient(clientOptions, false),
}, nil
case models.ProviderOpenAI:
return &baseProvider[OpenAIClient]{
@@ -109,9 +108,7 @@ func NewProvider(providerName models.InferenceProvider, opts ...ProviderClientOp
client: newBedrockClient(clientOptions),
}, nil
case models.ProviderGROQ:
- clientOptions.openaiOptions = append(clientOptions.openaiOptions,
- WithOpenAIBaseURL("https://api.groq.com/openai/v1"),
- )
+ clientOptions.baseURL = "https://api.groq.com/openai/v1"
return &baseProvider[OpenAIClient]{
options: clientOptions,
client: newOpenAIClient(clientOptions),
@@ -127,29 +124,23 @@ func NewProvider(providerName models.InferenceProvider, opts ...ProviderClientOp
client: newVertexAIClient(clientOptions),
}, nil
case models.ProviderOpenRouter:
- clientOptions.openaiOptions = append(clientOptions.openaiOptions,
- WithOpenAIBaseURL("https://openrouter.ai/api/v1"),
- WithOpenAIExtraHeaders(map[string]string{
- "HTTP-Referer": "crush.charm.land",
- "X-Title": "Crush",
- }),
- )
+ clientOptions.baseURL = "https://openrouter.ai/api/v1"
+ clientOptions.extraHeaders = map[string]string{
+ "HTTP-Referer": "crush.charm.land",
+ "X-Title": "Crush",
+ }
return &baseProvider[OpenAIClient]{
options: clientOptions,
client: newOpenAIClient(clientOptions),
}, nil
case models.ProviderXAI:
- clientOptions.openaiOptions = append(clientOptions.openaiOptions,
- WithOpenAIBaseURL("https://api.x.ai/v1"),
- )
+ clientOptions.baseURL = "https://api.x.ai/v1"
return &baseProvider[OpenAIClient]{
options: clientOptions,
client: newOpenAIClient(clientOptions),
}, nil
case models.ProviderLocal:
- clientOptions.openaiOptions = append(clientOptions.openaiOptions,
- WithOpenAIBaseURL(os.Getenv("LOCAL_ENDPOINT")),
- )
+ clientOptions.baseURL = os.Getenv("LOCAL_ENDPOINT")
return &baseProvider[OpenAIClient]{
options: clientOptions,
client: newOpenAIClient(clientOptions),
@@ -186,50 +177,47 @@ func (p *baseProvider[C]) StreamResponse(ctx context.Context, messages []message
return p.client.stream(ctx, messages, tools)
}
-func WithAPIKey(apiKey string) ProviderClientOption {
- return func(options *providerClientOptions) {
- options.apiKey = apiKey
- }
-}
-
-func WithModel(model models.Model) ProviderClientOption {
+func WithBaseURL(baseURL string) ProviderClientOption {
return func(options *providerClientOptions) {
- options.model = model
+ options.baseURL = baseURL
}
}
-func WithMaxTokens(maxTokens int64) ProviderClientOption {
+func WithAPIKey(apiKey string) ProviderClientOption {
return func(options *providerClientOptions) {
- options.maxTokens = maxTokens
+ options.apiKey = apiKey
}
}
-func WithSystemMessage(systemMessage string) ProviderClientOption {
+func WithModel(model models.Model) ProviderClientOption {
return func(options *providerClientOptions) {
- options.systemMessage = systemMessage
+ options.model = model
}
}
-func WithAnthropicOptions(anthropicOptions ...AnthropicOption) ProviderClientOption {
+func WithDisableCache(disableCache bool) ProviderClientOption {
return func(options *providerClientOptions) {
- options.anthropicOptions = anthropicOptions
+ options.disableCache = disableCache
}
}
-func WithOpenAIOptions(openaiOptions ...OpenAIOption) ProviderClientOption {
+func WithExtraHeaders(extraHeaders map[string]string) ProviderClientOption {
return func(options *providerClientOptions) {
- options.openaiOptions = openaiOptions
+ if options.extraHeaders == nil {
+ options.extraHeaders = make(map[string]string)
+ }
+ maps.Copy(options.extraHeaders, extraHeaders)
}
}
-func WithGeminiOptions(geminiOptions ...GeminiOption) ProviderClientOption {
+func WithMaxTokens(maxTokens int64) ProviderClientOption {
return func(options *providerClientOptions) {
- options.geminiOptions = geminiOptions
+ options.maxTokens = maxTokens
}
}
-func WithBedrockOptions(bedrockOptions ...BedrockOption) ProviderClientOption {
+func WithSystemMessage(systemMessage string) ProviderClientOption {
return func(options *providerClientOptions) {
- options.bedrockOptions = bedrockOptions
+ options.systemMessage = systemMessage
}
}
@@ -11,11 +11,6 @@ import (
type VertexAIClient ProviderClient
func newVertexAIClient(opts providerClientOptions) VertexAIClient {
- geminiOpts := geminiOptions{}
- for _, o := range opts.geminiOptions {
- o(&geminiOpts)
- }
-
client, err := genai.NewClient(context.Background(), &genai.ClientConfig{
Project: os.Getenv("GOOGLE_CLOUD_PROJECT"),
Location: os.Getenv("GOOGLE_CLOUD_LOCATION"),
@@ -28,7 +23,6 @@ func newVertexAIClient(opts providerClientOptions) VertexAIClient {
return &geminiClient{
providerOptions: opts,
- options: geminiOpts,
client: client,
}
}