chore: move fur structs, small provider changes

Kujtim Hoxha created

Change summary

go.mod                             |   3 
internal/config_v2/config.go       | 326 ++++++++++++++++++++++++++-----
internal/config_v2/config_test.go  |   4 
internal/config_v2/provider.go     |   4 
internal/fur/client/client.go      |  63 ++++++
internal/fur/provider/provider.go  |  72 +++++++
internal/llm/agent/agent.go        |  24 -
internal/llm/provider/anthropic.go |  73 ++-----
internal/llm/provider/bedrock.go   |  19 -
internal/llm/provider/gemini.go    |  19 -
internal/llm/provider/openai.go    |  47 ----
internal/llm/provider/provider.go  |  72 ++----
internal/llm/provider/vertexai.go  |   6 
13 files changed, 477 insertions(+), 255 deletions(-)

Detailed changes

go.mod 🔗

@@ -2,8 +2,6 @@ module github.com/charmbracelet/crush
 
 go 1.24.3
 
-replace github.com/charmbracelet/fur => ../fur
-
 require (
 	github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.7.0
 	github.com/JohannesKaufmann/html-to-markdown v1.6.0
@@ -17,7 +15,6 @@ require (
 	github.com/charmbracelet/bubbles/v2 v2.0.0-beta.1.0.20250607113720-eb5e1cf3b09e
 	github.com/charmbracelet/bubbletea/v2 v2.0.0-beta.3.0.20250609143341-c76fa36f1b94
 	github.com/charmbracelet/fang v0.1.0
-	github.com/charmbracelet/fur v0.0.0-00010101000000-000000000000
 	github.com/charmbracelet/glamour/v2 v2.0.0-20250516160903-6f1e2c8f9ebe
 	github.com/charmbracelet/lipgloss/v2 v2.0.0-beta.1.0.20250523195325-2d1af06b557c
 	github.com/charmbracelet/x/ansi v0.9.3-0.20250602153603-fb931ed90413

internal/config_v2/config.go 🔗

@@ -10,8 +10,8 @@ import (
 	"strings"
 	"sync"
 
+	"github.com/charmbracelet/crush/internal/fur/provider"
 	"github.com/charmbracelet/crush/internal/logging"
-	"github.com/charmbracelet/fur/pkg/provider"
 )
 
 const (
@@ -22,6 +22,29 @@ const (
 	MaxTokensFallbackDefault = 4096
 )
 
+var defaultContextPaths = []string{
+	".github/copilot-instructions.md",
+	".cursorrules",
+	".cursor/rules/",
+	"CLAUDE.md",
+	"CLAUDE.local.md",
+	"crush.md",
+	"crush.local.md",
+	"Crush.md",
+	"Crush.local.md",
+	"CRUSH.md",
+	"CRUSH.local.md",
+}
+
+type AgentID string
+
+const (
+	AgentCoder     AgentID = "coder"
+	AgentTask      AgentID = "task"
+	AgentTitle     AgentID = "title"
+	AgentSummarize AgentID = "summarize"
+)
+
 type Model struct {
 	ID                 string  `json:"id"`
 	Name               string  `json:"model"`
@@ -43,40 +66,43 @@ type VertexAIOptions struct {
 }
 
 type ProviderConfig struct {
-	BaseURL      string            `json:"base_url,omitempty"`
-	ProviderType provider.Type     `json:"provider_type"`
-	APIKey       string            `json:"api_key,omitempty"`
-	Disabled     bool              `json:"disabled"`
-	ExtraHeaders map[string]string `json:"extra_headers,omitempty"`
+	ID           provider.InferenceProvider `json:"id"`
+	BaseURL      string                     `json:"base_url,omitempty"`
+	ProviderType provider.Type              `json:"provider_type"`
+	APIKey       string                     `json:"api_key,omitempty"`
+	Disabled     bool                       `json:"disabled"`
+	ExtraHeaders map[string]string          `json:"extra_headers,omitempty"`
 	// used for e.x for vertex to set the project
 	ExtraParams map[string]string `json:"extra_params,omitempty"`
 
-	DefaultModel string `json:"default_model"`
+	DefaultLargeModel string `json:"default_large_model,omitempty"`
+	DefaultSmallModel string `json:"default_small_model,omitempty"`
+
+	Models []Model `json:"models,omitempty"`
 }
 
 type Agent struct {
-	Name string `json:"name"`
+	Name        string `json:"name"`
+	Description string `json:"description,omitempty"`
 	// This is the id of the system prompt used by the agent
-	//  TODO: still needs to be implemented
-	PromptID string `json:"prompt_id"`
-	Disabled bool   `json:"disabled"`
+	Disabled bool `json:"disabled"`
 
 	Provider provider.InferenceProvider `json:"provider"`
-	Model    Model                      `json:"model"`
+	Model    string                     `json:"model"`
 
 	// The available tools for the agent
-	//  if this is empty, all tools are available
+	//  if this is nil, all tools are available
 	AllowedTools []string `json:"allowed_tools"`
 
 	// this tells us which MCPs are available for this agent
 	//  if this is empty all mcps are available
-	//  the string array is the list of tools from the MCP the agent has available
-	//  if the string array is empty, all tools from the MCP are available
-	MCP map[string][]string `json:"mcp"`
+	//  the string array is the list of tools from the AllowedMCP the agent has available
+	//  if the string array is nil, all tools from the AllowedMCP are available
+	AllowedMCP map[string][]string `json:"allowed_mcp"`
 
 	// The list of LSPs that this agent can use
-	//  if this is empty, all LSPs are available
-	LSP []string `json:"lsp"`
+	//  if this is nil, all LSPs are available
+	AllowedLSP []string `json:"allowed_lsp"`
 
 	// Overrides the context paths for this agent
 	ContextPaths []string `json:"context_paths"`
@@ -125,7 +151,7 @@ type Config struct {
 	Providers map[provider.InferenceProvider]ProviderConfig `json:"providers,omitempty"`
 
 	// List of configured agents
-	Agents map[string]Agent `json:"agents,omitempty"`
+	Agents map[AgentID]Agent `json:"agents,omitempty"`
 
 	// List of configured MCPs
 	MCP map[string]MCP `json:"mcp,omitempty"`
@@ -135,15 +161,13 @@ type Config struct {
 
 	// Miscellaneous options
 	Options Options `json:"options"`
-
-	// Used to add models that are not already in the repository
-	Models map[provider.InferenceProvider][]provider.Model `json:"models,omitempty"`
 }
 
 var (
 	instance *Config // The single instance of the Singleton
 	cwd      string
 	once     sync.Once // Ensures the initialization happens only once
+
 )
 
 func loadConfig(cwd string) (*Config, error) {
@@ -190,10 +214,73 @@ func loadConfig(cwd string) (*Config, error) {
 	}
 
 	// merge options
-	cfg.Options = mergeOptions(cfg.Options, globalCfg.Options)
-	cfg.Options = mergeOptions(cfg.Options, localConfig.Options)
+	mergeOptions(cfg, globalCfg, localConfig)
 
 	mergeProviderConfigs(cfg, globalCfg, localConfig)
+	// no providers found the app is not initialized yet
+	if len(cfg.Providers) == 0 {
+		return cfg, nil
+	}
+	preferredProvider := getPreferredProvider(cfg.Providers)
+
+	if preferredProvider == nil {
+		return nil, errors.New("no valid providers configured")
+	}
+
+	agents := map[AgentID]Agent{
+		AgentCoder: {
+			Name:         "Coder",
+			Description:  "An agent that helps with executing coding tasks.",
+			Provider:     preferredProvider.ID,
+			Model:        preferredProvider.DefaultLargeModel,
+			ContextPaths: cfg.Options.ContextPaths,
+			// All tools allowed
+		},
+		AgentTask: {
+			Name:         "Task",
+			Description:  "An agent that helps with searching for context and finding implementation details.",
+			Provider:     preferredProvider.ID,
+			Model:        preferredProvider.DefaultLargeModel,
+			ContextPaths: cfg.Options.ContextPaths,
+			AllowedTools: []string{
+				"glob",
+				"grep",
+				"ls",
+				"sourcegraph",
+				"view",
+			},
+			// NO MCPs or LSPs by default
+			AllowedMCP: map[string][]string{},
+			AllowedLSP: []string{},
+		},
+		AgentTitle: {
+			Name:         "Title",
+			Description:  "An agent that helps with generating titles for sessions.",
+			Provider:     preferredProvider.ID,
+			Model:        preferredProvider.DefaultSmallModel,
+			ContextPaths: cfg.Options.ContextPaths,
+			AllowedTools: []string{},
+			// NO MCPs or LSPs by default
+			AllowedMCP: map[string][]string{},
+			AllowedLSP: []string{},
+		},
+		AgentSummarize: {
+			Name:         "Summarize",
+			Description:  "An agent that helps with summarizing sessions.",
+			Provider:     preferredProvider.ID,
+			Model:        preferredProvider.DefaultSmallModel,
+			ContextPaths: cfg.Options.ContextPaths,
+			AllowedTools: []string{},
+			// NO MCPs or LSPs by default
+			AllowedMCP: map[string][]string{},
+			AllowedLSP: []string{},
+		},
+	}
+	cfg.Agents = agents
+	mergeAgents(cfg, globalCfg, localConfig)
+	mergeMCPs(cfg, globalCfg, localConfig)
+	mergeLSPs(cfg, globalCfg, localConfig)
+
 	return cfg, nil
 }
 
@@ -219,6 +306,22 @@ func GetConfig() *Config {
 	return instance
 }
 
+func getPreferredProvider(configuredProviders map[provider.InferenceProvider]ProviderConfig) *ProviderConfig {
+	providers := Providers()
+	for _, p := range providers {
+		if providerConfig, ok := configuredProviders[p.ID]; ok && !providerConfig.Disabled {
+			return &providerConfig
+		}
+	}
+	// if none found return the first configured provider
+	for _, providerConfig := range configuredProviders {
+		if !providerConfig.Disabled {
+			return &providerConfig
+		}
+	}
+	return nil
+}
+
 func mergeProviderConfig(p provider.InferenceProvider, base, other ProviderConfig) ProviderConfig {
 	if other.APIKey != "" {
 		base.APIKey = other.APIKey
@@ -249,6 +352,26 @@ func mergeProviderConfig(p provider.InferenceProvider, base, other ProviderConfi
 		base.Disabled = other.Disabled
 	}
 
+	if other.DefaultLargeModel != "" {
+		base.DefaultLargeModel = other.DefaultLargeModel
+	}
+	// Add new models if they don't exist
+	if other.Models != nil {
+		for _, model := range other.Models {
+			// check if the model already exists
+			exists := false
+			for _, existingModel := range base.Models {
+				if existingModel.ID == model.ID {
+					exists = true
+					break
+				}
+			}
+			if !exists {
+				base.Models = append(base.Models, model)
+			}
+		}
+	}
+
 	return base
 }
 
@@ -267,39 +390,114 @@ func validateProvider(p provider.InferenceProvider, providerConfig ProviderConfi
 	return nil
 }
 
-func mergeOptions(base, other Options) Options {
-	result := base
+func mergeOptions(base, global, local *Config) {
+	for _, cfg := range []*Config{global, local} {
+		if cfg == nil {
+			continue
+		}
+		baseOptions := base.Options
+		other := cfg.Options
+		if len(other.ContextPaths) > 0 {
+			baseOptions.ContextPaths = append(baseOptions.ContextPaths, other.ContextPaths...)
+		}
 
-	if len(other.ContextPaths) > 0 {
-		base.ContextPaths = append(base.ContextPaths, other.ContextPaths...)
-	}
+		if other.TUI.CompactMode {
+			baseOptions.TUI.CompactMode = other.TUI.CompactMode
+		}
 
-	if other.TUI.CompactMode {
-		result.TUI.CompactMode = other.TUI.CompactMode
-	}
+		if other.Debug {
+			baseOptions.Debug = other.Debug
+		}
 
-	if other.Debug {
-		result.Debug = other.Debug
-	}
+		if other.DebugLSP {
+			baseOptions.DebugLSP = other.DebugLSP
+		}
 
-	if other.DebugLSP {
-		result.DebugLSP = other.DebugLSP
+		if other.DisableAutoSummarize {
+			baseOptions.DisableAutoSummarize = other.DisableAutoSummarize
+		}
+
+		if other.DataDirectory != "" {
+			baseOptions.DataDirectory = other.DataDirectory
+		}
+		base.Options = baseOptions
 	}
+}
 
-	if other.DisableAutoSummarize {
-		result.DisableAutoSummarize = other.DisableAutoSummarize
+func mergeAgents(base, global, local *Config) {
+	for _, cfg := range []*Config{global, local} {
+		if cfg == nil {
+			continue
+		}
+		for agentID, globalAgent := range cfg.Agents {
+			if _, ok := base.Agents[agentID]; !ok {
+				base.Agents[agentID] = globalAgent
+			} else {
+				switch agentID {
+				case AgentCoder:
+					baseAgent := base.Agents[agentID]
+					baseAgent.Model = globalAgent.Model
+					baseAgent.Provider = globalAgent.Provider
+					baseAgent.AllowedMCP = globalAgent.AllowedMCP
+					baseAgent.AllowedLSP = globalAgent.AllowedLSP
+					base.Agents[agentID] = baseAgent
+				case AgentTask:
+					baseAgent := base.Agents[agentID]
+					baseAgent.Model = globalAgent.Model
+					baseAgent.Provider = globalAgent.Provider
+					base.Agents[agentID] = baseAgent
+				case AgentTitle:
+					baseAgent := base.Agents[agentID]
+					baseAgent.Model = globalAgent.Model
+					baseAgent.Provider = globalAgent.Provider
+					base.Agents[agentID] = baseAgent
+				case AgentSummarize:
+					baseAgent := base.Agents[agentID]
+					baseAgent.Model = globalAgent.Model
+					baseAgent.Provider = globalAgent.Provider
+					base.Agents[agentID] = baseAgent
+				default:
+					baseAgent := base.Agents[agentID]
+					baseAgent.Name = globalAgent.Name
+					baseAgent.Description = globalAgent.Description
+					baseAgent.Disabled = globalAgent.Disabled
+					baseAgent.Provider = globalAgent.Provider
+					baseAgent.Model = globalAgent.Model
+					baseAgent.AllowedTools = globalAgent.AllowedTools
+					baseAgent.AllowedMCP = globalAgent.AllowedMCP
+					baseAgent.AllowedLSP = globalAgent.AllowedLSP
+					base.Agents[agentID] = baseAgent
+
+				}
+			}
+		}
 	}
+}
 
-	if other.DataDirectory != "" {
-		result.DataDirectory = other.DataDirectory
+func mergeMCPs(base, global, local *Config) {
+	for _, cfg := range []*Config{global, local} {
+		if cfg == nil {
+			continue
+		}
+		maps.Copy(base.MCP, cfg.MCP)
 	}
+}
 
-	return result
+func mergeLSPs(base, global, local *Config) {
+	for _, cfg := range []*Config{global, local} {
+		if cfg == nil {
+			continue
+		}
+		maps.Copy(base.LSP, cfg.LSP)
+	}
 }
 
 func mergeProviderConfigs(base, global, local *Config) {
-	if global != nil {
-		for providerName, globalProvider := range global.Providers {
+	for _, cfg := range []*Config{global, local} {
+		if cfg == nil {
+			continue
+		}
+		for providerName, globalProvider := range cfg.Providers {
 			if _, ok := base.Providers[providerName]; !ok {
 				base.Providers[providerName] = globalProvider
 			} else {
@@ -307,15 +505,6 @@ func mergeProviderConfigs(base, global, local *Config) {
 			}
 		}
 	}
-	if local != nil {
-		for providerName, localProvider := range local.Providers {
-			if _, ok := base.Providers[providerName]; !ok {
-				base.Providers[providerName] = localProvider
-			} else {
-				base.Providers[providerName] = mergeProviderConfig(providerName, base.Providers[providerName], localProvider)
-			}
-		}
-	}
 
 	finalProviders := make(map[provider.InferenceProvider]ProviderConfig)
 	for providerName, providerConfig := range base.Providers {
@@ -328,30 +517,36 @@ func mergeProviderConfigs(base, global, local *Config) {
 	base.Providers = finalProviders
 }
 
-func providerDefaultConfig(providerName provider.InferenceProvider) ProviderConfig {
-	switch providerName {
+func providerDefaultConfig(providerId provider.InferenceProvider) ProviderConfig {
+	switch providerId {
 	case provider.InferenceProviderAnthropic:
 		return ProviderConfig{
+			ID:           providerId,
 			ProviderType: provider.TypeAnthropic,
 		}
 	case provider.InferenceProviderOpenAI:
 		return ProviderConfig{
+			ID:           providerId,
 			ProviderType: provider.TypeOpenAI,
 		}
 	case provider.InferenceProviderGemini:
 		return ProviderConfig{
+			ID:           providerId,
 			ProviderType: provider.TypeGemini,
 		}
 	case provider.InferenceProviderBedrock:
 		return ProviderConfig{
+			ID:           providerId,
 			ProviderType: provider.TypeBedrock,
 		}
 	case provider.InferenceProviderAzure:
 		return ProviderConfig{
+			ID:           providerId,
 			ProviderType: provider.TypeAzure,
 		}
 	case provider.InferenceProviderOpenRouter:
 		return ProviderConfig{
+			ID:           providerId,
 			ProviderType: provider.TypeOpenAI,
 			BaseURL:      "https://openrouter.ai/api/v1",
 			ExtraHeaders: map[string]string{
@@ -361,15 +556,18 @@ func providerDefaultConfig(providerName provider.InferenceProvider) ProviderConf
 		}
 	case provider.InferenceProviderXAI:
 		return ProviderConfig{
+			ID:           providerId,
 			ProviderType: provider.TypeXAI,
 			BaseURL:      "https://api.x.ai/v1",
 		}
 	case provider.InferenceProviderVertexAI:
 		return ProviderConfig{
+			ID:           providerId,
 			ProviderType: provider.TypeVertexAI,
 		}
 	default:
 		return ProviderConfig{
+			ID:           providerId,
 			ProviderType: provider.TypeOpenAI,
 		}
 	}
@@ -379,6 +577,7 @@ func defaultConfigBasedOnEnv() *Config {
 	cfg := &Config{
 		Options: Options{
 			DataDirectory: defaultDataDirectory,
+			ContextPaths:  defaultContextPaths,
 		},
 		Providers: make(map[provider.InferenceProvider]ProviderConfig),
 	}
@@ -391,7 +590,22 @@ func defaultConfigBasedOnEnv() *Config {
 			if apiKey := os.Getenv(envVar); apiKey != "" {
 				providerConfig := providerDefaultConfig(p.ID)
 				providerConfig.APIKey = apiKey
-				providerConfig.DefaultModel = p.DefaultModelID
+				providerConfig.DefaultLargeModel = p.DefaultLargeModelID
+				providerConfig.DefaultSmallModel = p.DefaultSmallModelID
+				for _, model := range p.Models {
+					providerConfig.Models = append(providerConfig.Models, Model{
+						ID:                 model.ID,
+						Name:               model.Name,
+						CostPer1MIn:        model.CostPer1MIn,
+						CostPer1MOut:       model.CostPer1MOut,
+						CostPer1MInCached:  model.CostPer1MInCached,
+						CostPer1MOutCached: model.CostPer1MOutCached,
+						ContextWindow:      model.ContextWindow,
+						DefaultMaxTokens:   model.DefaultMaxTokens,
+						CanReason:          model.CanReason,
+						SupportsImages:     model.SupportsImages,
+					})
+				}
 				cfg.Providers[p.ID] = providerConfig
 			}
 		}

internal/config_v2/config_test.go 🔗

@@ -1,6 +1,7 @@
 package configv2
 
 import (
+	"encoding/json"
 	"fmt"
 	"os"
 	"testing"
@@ -28,6 +29,7 @@ func TestConfigWithEnv(t *testing.T) {
 	os.Setenv("XAI_API_KEY", "test-xai-key")
 	os.Setenv("OPENROUTER_API_KEY", "test-openrouter-key")
 	cfg := InitConfig(cwdDir)
-	fmt.Println(cfg)
+	data, _ := json.MarshalIndent(cfg, "", "  ")
+	fmt.Println(string(data))
 	assert.Len(t, cfg.Providers, 5)
 }

internal/config_v2/provider.go 🔗

@@ -6,8 +6,8 @@ import (
 	"path/filepath"
 	"sync"
 
-	"github.com/charmbracelet/fur/pkg/client"
-	"github.com/charmbracelet/fur/pkg/provider"
+	"github.com/charmbracelet/crush/internal/fur/client"
+	"github.com/charmbracelet/crush/internal/fur/provider"
 )
 
 var fur = client.New()

internal/fur/client/client.go 🔗

@@ -0,0 +1,63 @@
+// Package client provides a client for interacting with the fur service.
+package client
+
+import (
+	"encoding/json"
+	"fmt"
+	"net/http"
+	"os"
+
+	"github.com/charmbracelet/crush/internal/fur/provider"
+)
+
+const defaultURL = "http://localhost:8080"
+
+// Client represents a client for the fur service.
+type Client struct {
+	baseURL    string
+	httpClient *http.Client
+}
+
+// New creates a new client instance
+// Uses FUR_URL environment variable or falls back to localhost:8080.
+func New() *Client {
+	baseURL := os.Getenv("FUR_URL")
+	if baseURL == "" {
+		baseURL = defaultURL
+	}
+
+	return &Client{
+		baseURL:    baseURL,
+		httpClient: &http.Client{},
+	}
+}
+
+// NewWithURL creates a new client with a specific URL.
+func NewWithURL(url string) *Client {
+	return &Client{
+		baseURL:    url,
+		httpClient: &http.Client{},
+	}
+}
+
+// GetProviders retrieves all available providers from the service.
+func (c *Client) GetProviders() ([]provider.Provider, error) {
+	url := fmt.Sprintf("%s/providers", c.baseURL)
+
+	resp, err := c.httpClient.Get(url) //nolint:noctx
+	if err != nil {
+		return nil, fmt.Errorf("failed to make request: %w", err)
+	}
+	defer resp.Body.Close() //nolint:errcheck
+
+	if resp.StatusCode != http.StatusOK {
+		return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
+	}
+
+	var providers []provider.Provider
+	if err := json.NewDecoder(resp.Body).Decode(&providers); err != nil {
+		return nil, fmt.Errorf("failed to decode response: %w", err)
+	}
+
+	return providers, nil
+}

internal/fur/provider/provider.go 🔗

@@ -0,0 +1,72 @@
+// Package provider provides types and constants for AI providers.
+package provider
+
+// Type represents the type of AI provider.
+type Type string
+
+// All the supported AI provider types.
+const (
+	TypeOpenAI     Type = "openai"
+	TypeAnthropic  Type = "anthropic"
+	TypeGemini     Type = "gemini"
+	TypeAzure      Type = "azure"
+	TypeBedrock    Type = "bedrock"
+	TypeVertexAI   Type = "vertexai"
+	TypeXAI        Type = "xai"
+	TypeOpenRouter Type = "openrouter"
+)
+
+// InferenceProvider represents the inference provider identifier.
+type InferenceProvider string
+
+// All the inference providers supported by the system.
+const (
+	InferenceProviderOpenAI     InferenceProvider = "openai"
+	InferenceProviderAnthropic  InferenceProvider = "anthropic"
+	InferenceProviderGemini     InferenceProvider = "gemini"
+	InferenceProviderAzure      InferenceProvider = "azure"
+	InferenceProviderBedrock    InferenceProvider = "bedrock"
+	InferenceProviderVertexAI   InferenceProvider = "vertexai"
+	InferenceProviderXAI        InferenceProvider = "xai"
+	InferenceProviderOpenRouter InferenceProvider = "openrouter"
+)
+
+// Provider represents an AI provider configuration.
+type Provider struct {
+	Name                string            `json:"name"`
+	ID                  InferenceProvider `json:"id"`
+	APIKey              string            `json:"api_key,omitempty"`
+	APIEndpoint         string            `json:"api_endpoint,omitempty"`
+	Type                Type              `json:"type,omitempty"`
+	DefaultLargeModelID string            `json:"default_large_model_id,omitempty"`
+	DefaultSmallModelID string            `json:"default_small_model_id,omitempty"`
+	Models              []Model           `json:"models,omitempty"`
+}
+
+// Model represents an AI model configuration.
+type Model struct {
+	ID                 string  `json:"id"`
+	Name               string  `json:"model"`
+	CostPer1MIn        float64 `json:"cost_per_1m_in"`
+	CostPer1MOut       float64 `json:"cost_per_1m_out"`
+	CostPer1MInCached  float64 `json:"cost_per_1m_in_cached"`
+	CostPer1MOutCached float64 `json:"cost_per_1m_out_cached"`
+	ContextWindow      int64   `json:"context_window"`
+	DefaultMaxTokens   int64   `json:"default_max_tokens"`
+	CanReason          bool    `json:"can_reason"`
+	SupportsImages     bool    `json:"supports_attachments"`
+}
+
+// KnownProviders returns all the known inference providers.
+func KnownProviders() []InferenceProvider {
+	return []InferenceProvider{
+		InferenceProviderOpenAI,
+		InferenceProviderAnthropic,
+		InferenceProviderGemini,
+		InferenceProviderAzure,
+		InferenceProviderBedrock,
+		InferenceProviderVertexAI,
+		InferenceProviderXAI,
+		InferenceProviderOpenRouter,
+	}
+}

internal/llm/agent/agent.go 🔗

@@ -734,21 +734,15 @@ func createAgentProvider(agentName config.AgentName) (provider.Provider, error)
 		provider.WithSystemMessage(prompt.GetAgentPrompt(agentName, model.Provider)),
 		provider.WithMaxTokens(maxTokens),
 	}
-	if (model.Provider == models.ProviderOpenAI || model.Provider == models.ProviderLocal) && model.CanReason {
-		opts = append(
-			opts,
-			provider.WithOpenAIOptions(
-				provider.WithReasoningEffort(agentConfig.ReasoningEffort),
-			),
-		)
-	} else if model.Provider == models.ProviderAnthropic && model.CanReason && agentName == config.AgentCoder {
-		opts = append(
-			opts,
-			provider.WithAnthropicOptions(
-				provider.WithAnthropicShouldThinkFn(provider.DefaultShouldThinkFn),
-			),
-		)
-	}
+	// TODO: reimplement
+	// if model.Provider == models.ProviderOpenAI || model.Provider == models.ProviderLocal && model.CanReason {
+	// 	opts = append(
+	// 		opts,
+	// 		provider.WithOpenAIOptions(
+	// 			provider.WithReasoningEffort(agentConfig.ReasoningEffort),
+	// 		),
+	// 	)
+	// }
 	agentProvider, err := provider.NewProvider(
 		model.Provider,
 		opts...,

internal/llm/provider/anthropic.go 🔗

@@ -19,40 +19,25 @@ import (
 	"github.com/charmbracelet/crush/internal/message"
 )
 
-type anthropicOptions struct {
-	useBedrock   bool
-	disableCache bool
-	shouldThink  func(userMessage string) bool
-}
-
-type AnthropicOption func(*anthropicOptions)
-
 type anthropicClient struct {
 	providerOptions providerClientOptions
-	options         anthropicOptions
 	client          anthropic.Client
 }
 
 type AnthropicClient ProviderClient
 
-func newAnthropicClient(opts providerClientOptions) AnthropicClient {
-	anthropicOpts := anthropicOptions{}
-	for _, o := range opts.anthropicOptions {
-		o(&anthropicOpts)
-	}
-
+func newAnthropicClient(opts providerClientOptions, useBedrock bool) AnthropicClient {
 	anthropicClientOptions := []option.RequestOption{}
 	if opts.apiKey != "" {
 		anthropicClientOptions = append(anthropicClientOptions, option.WithAPIKey(opts.apiKey))
 	}
-	if anthropicOpts.useBedrock {
+	if useBedrock {
 		anthropicClientOptions = append(anthropicClientOptions, bedrock.WithLoadDefaultConfig(context.Background()))
 	}
 
 	client := anthropic.NewClient(anthropicClientOptions...)
 	return &anthropicClient{
 		providerOptions: opts,
-		options:         anthropicOpts,
 		client:          client,
 	}
 }
@@ -66,7 +51,7 @@ func (a *anthropicClient) convertMessages(messages []message.Message) (anthropic
 		switch msg.Role {
 		case message.User:
 			content := anthropic.NewTextBlock(msg.Content().String())
-			if cache && !a.options.disableCache {
+			if cache && !a.providerOptions.disableCache {
 				content.OfText.CacheControl = anthropic.CacheControlEphemeralParam{
 					Type: "ephemeral",
 				}
@@ -84,7 +69,7 @@ func (a *anthropicClient) convertMessages(messages []message.Message) (anthropic
 			blocks := []anthropic.ContentBlockParamUnion{}
 			if msg.Content().String() != "" {
 				content := anthropic.NewTextBlock(msg.Content().String())
-				if cache && !a.options.disableCache {
+				if cache && !a.providerOptions.disableCache {
 					content.OfText.CacheControl = anthropic.CacheControlEphemeralParam{
 						Type: "ephemeral",
 					}
@@ -132,7 +117,7 @@ func (a *anthropicClient) convertTools(tools []tools.BaseTool) []anthropic.ToolU
 			},
 		}
 
-		if i == len(tools)-1 && !a.options.disableCache {
+		if i == len(tools)-1 && !a.providerOptions.disableCache {
 			toolParam.CacheControl = anthropic.CacheControlEphemeralParam{
 				Type: "ephemeral",
 			}
@@ -161,21 +146,22 @@ func (a *anthropicClient) finishReason(reason string) message.FinishReason {
 
 func (a *anthropicClient) preparedMessages(messages []anthropic.MessageParam, tools []anthropic.ToolUnionParam) anthropic.MessageNewParams {
 	var thinkingParam anthropic.ThinkingConfigParamUnion
-	lastMessage := messages[len(messages)-1]
-	isUser := lastMessage.Role == anthropic.MessageParamRoleUser
-	messageContent := ""
+	// TODO: Implement a proper thinking function
+	// lastMessage := messages[len(messages)-1]
+	// isUser := lastMessage.Role == anthropic.MessageParamRoleUser
+	// messageContent := ""
 	temperature := anthropic.Float(0)
-	if isUser {
-		for _, m := range lastMessage.Content {
-			if m.OfText != nil && m.OfText.Text != "" {
-				messageContent = m.OfText.Text
-			}
-		}
-		if messageContent != "" && a.options.shouldThink != nil && a.options.shouldThink(messageContent) {
-			thinkingParam = anthropic.ThinkingConfigParamOfEnabled(int64(float64(a.providerOptions.maxTokens) * 0.8))
-			temperature = anthropic.Float(1)
-		}
-	}
+	// if isUser {
+	// 	for _, m := range lastMessage.Content {
+	// 		if m.OfText != nil && m.OfText.Text != "" {
+	// 			messageContent = m.OfText.Text
+	// 		}
+	// 	}
+	// 	if messageContent != "" && a.shouldThink != nil && a.options.shouldThink(messageContent) {
+	// 		thinkingParam = anthropic.ThinkingConfigParamOfEnabled(int64(float64(a.providerOptions.maxTokens) * 0.8))
+	// 		temperature = anthropic.Float(1)
+	// 	}
+	// }
 
 	return anthropic.MessageNewParams{
 		Model:       anthropic.Model(a.providerOptions.model.APIModel),
@@ -439,24 +425,7 @@ func (a *anthropicClient) usage(msg anthropic.Message) TokenUsage {
 	}
 }
 
-func WithAnthropicBedrock(useBedrock bool) AnthropicOption {
-	return func(options *anthropicOptions) {
-		options.useBedrock = useBedrock
-	}
-}
-
-func WithAnthropicDisableCache() AnthropicOption {
-	return func(options *anthropicOptions) {
-		options.disableCache = true
-	}
-}
-
+// TODO: check if we need
 func DefaultShouldThinkFn(s string) bool {
 	return strings.Contains(strings.ToLower(s), "think")
 }
-
-func WithAnthropicShouldThinkFn(fn func(string) bool) AnthropicOption {
-	return func(options *anthropicOptions) {
-		options.shouldThink = fn
-	}
-}

internal/llm/provider/bedrock.go 🔗

@@ -11,22 +11,14 @@ import (
 	"github.com/charmbracelet/crush/internal/message"
 )
 
-type bedrockOptions struct {
-	// Bedrock specific options can be added here
-}
-
-type BedrockOption func(*bedrockOptions)
-
 type bedrockClient struct {
 	providerOptions providerClientOptions
-	options         bedrockOptions
 	childProvider   ProviderClient
 }
 
 type BedrockClient ProviderClient
 
 func newBedrockClient(opts providerClientOptions) BedrockClient {
-	bedrockOpts := bedrockOptions{}
 	// Apply bedrock specific options if they are added in the future
 
 	// Get AWS region from environment
@@ -41,7 +33,6 @@ func newBedrockClient(opts providerClientOptions) BedrockClient {
 	if len(region) < 2 {
 		return &bedrockClient{
 			providerOptions: opts,
-			options:         bedrockOpts,
 			childProvider:   nil, // Will cause an error when used
 		}
 	}
@@ -55,14 +46,11 @@ func newBedrockClient(opts providerClientOptions) BedrockClient {
 	if strings.Contains(string(opts.model.APIModel), "anthropic") {
 		// Create Anthropic client with Bedrock configuration
 		anthropicOpts := opts
-		anthropicOpts.anthropicOptions = append(anthropicOpts.anthropicOptions,
-			WithAnthropicBedrock(true),
-			WithAnthropicDisableCache(),
-		)
+		// TODO: later find a way to check if the AWS account has caching enabled
+		opts.disableCache = true // Disable cache for Bedrock
 		return &bedrockClient{
 			providerOptions: opts,
-			options:         bedrockOpts,
-			childProvider:   newAnthropicClient(anthropicOpts),
+			childProvider:   newAnthropicClient(anthropicOpts, true),
 		}
 	}
 
@@ -70,7 +58,6 @@ func newBedrockClient(opts providerClientOptions) BedrockClient {
 	// This will cause an error when used
 	return &bedrockClient{
 		providerOptions: opts,
-		options:         bedrockOpts,
 		childProvider:   nil,
 	}
 }

internal/llm/provider/gemini.go 🔗

@@ -17,26 +17,14 @@ import (
 	"google.golang.org/genai"
 )
 
-type geminiOptions struct {
-	disableCache bool
-}
-
-type GeminiOption func(*geminiOptions)
-
 type geminiClient struct {
 	providerOptions providerClientOptions
-	options         geminiOptions
 	client          *genai.Client
 }
 
 type GeminiClient ProviderClient
 
 func newGeminiClient(opts providerClientOptions) GeminiClient {
-	geminiOpts := geminiOptions{}
-	for _, o := range opts.geminiOptions {
-		o(&geminiOpts)
-	}
-
 	client, err := genai.NewClient(context.Background(), &genai.ClientConfig{APIKey: opts.apiKey, Backend: genai.BackendGeminiAPI})
 	if err != nil {
 		logging.Error("Failed to create Gemini client", "error", err)
@@ -45,7 +33,6 @@ func newGeminiClient(opts providerClientOptions) GeminiClient {
 
 	return &geminiClient{
 		providerOptions: opts,
-		options:         geminiOpts,
 		client:          client,
 	}
 }
@@ -452,12 +439,6 @@ func (g *geminiClient) usage(resp *genai.GenerateContentResponse) TokenUsage {
 	}
 }
 
-func WithGeminiDisableCache() GeminiOption {
-	return func(options *geminiOptions) {
-		options.disableCache = true
-	}
-}
-
 // Helper functions
 func parseJsonToMap(jsonStr string) (map[string]any, error) {
 	var result map[string]any

internal/llm/provider/openai.go 🔗

@@ -19,14 +19,9 @@ import (
 )
 
 type openaiOptions struct {
-	baseURL         string
-	disableCache    bool
 	reasoningEffort string
-	extraHeaders    map[string]string
 }
 
-type OpenAIOption func(*openaiOptions)
-
 type openaiClient struct {
 	providerOptions providerClientOptions
 	options         openaiOptions
@@ -39,20 +34,17 @@ func newOpenAIClient(opts providerClientOptions) OpenAIClient {
 	openaiOpts := openaiOptions{
 		reasoningEffort: "medium",
 	}
-	for _, o := range opts.openaiOptions {
-		o(&openaiOpts)
-	}
 
 	openaiClientOptions := []option.RequestOption{}
 	if opts.apiKey != "" {
 		openaiClientOptions = append(openaiClientOptions, option.WithAPIKey(opts.apiKey))
 	}
-	if openaiOpts.baseURL != "" {
-		openaiClientOptions = append(openaiClientOptions, option.WithBaseURL(openaiOpts.baseURL))
+	if opts.baseURL != "" {
+		openaiClientOptions = append(openaiClientOptions, option.WithBaseURL(opts.baseURL))
 	}
 
-	if openaiOpts.extraHeaders != nil {
-		for key, value := range openaiOpts.extraHeaders {
+	if opts.extraHeaders != nil {
+		for key, value := range opts.extraHeaders {
 			openaiClientOptions = append(openaiClientOptions, option.WithHeader(key, value))
 		}
 	}
@@ -392,34 +384,3 @@ func (o *openaiClient) usage(completion openai.ChatCompletion) TokenUsage {
 		CacheReadTokens:     cachedTokens,
 	}
 }
-
-func WithOpenAIBaseURL(baseURL string) OpenAIOption {
-	return func(options *openaiOptions) {
-		options.baseURL = baseURL
-	}
-}
-
-func WithOpenAIExtraHeaders(headers map[string]string) OpenAIOption {
-	return func(options *openaiOptions) {
-		options.extraHeaders = headers
-	}
-}
-
-func WithOpenAIDisableCache() OpenAIOption {
-	return func(options *openaiOptions) {
-		options.disableCache = true
-	}
-}
-
-func WithReasoningEffort(effort string) OpenAIOption {
-	return func(options *openaiOptions) {
-		defaultReasoningEffort := "medium"
-		switch effort {
-		case "low", "medium", "high":
-			defaultReasoningEffort = effort
-		default:
-			logging.Warn("Invalid reasoning effort, using default: medium")
-		}
-		options.reasoningEffort = defaultReasoningEffort
-	}
-}

internal/llm/provider/provider.go 🔗

@@ -3,6 +3,7 @@ package provider
 import (
 	"context"
 	"fmt"
+	"maps"
 	"os"
 
 	"github.com/charmbracelet/crush/internal/llm/models"
@@ -59,15 +60,13 @@ type Provider interface {
 }
 
 type providerClientOptions struct {
+	baseURL       string
 	apiKey        string
 	model         models.Model
+	disableCache  bool
 	maxTokens     int64
 	systemMessage string
-
-	anthropicOptions []AnthropicOption
-	openaiOptions    []OpenAIOption
-	geminiOptions    []GeminiOption
-	bedrockOptions   []BedrockOption
+	extraHeaders  map[string]string
 }
 
 type ProviderClientOption func(*providerClientOptions)
@@ -91,7 +90,7 @@ func NewProvider(providerName models.InferenceProvider, opts ...ProviderClientOp
 	case models.ProviderAnthropic:
 		return &baseProvider[AnthropicClient]{
 			options: clientOptions,
-			client:  newAnthropicClient(clientOptions),
+			client:  newAnthropicClient(clientOptions, false),
 		}, nil
 	case models.ProviderOpenAI:
 		return &baseProvider[OpenAIClient]{
@@ -109,9 +108,7 @@ func NewProvider(providerName models.InferenceProvider, opts ...ProviderClientOp
 			client:  newBedrockClient(clientOptions),
 		}, nil
 	case models.ProviderGROQ:
-		clientOptions.openaiOptions = append(clientOptions.openaiOptions,
-			WithOpenAIBaseURL("https://api.groq.com/openai/v1"),
-		)
+		clientOptions.baseURL = "https://api.groq.com/openai/v1"
 		return &baseProvider[OpenAIClient]{
 			options: clientOptions,
 			client:  newOpenAIClient(clientOptions),
@@ -127,29 +124,23 @@ func NewProvider(providerName models.InferenceProvider, opts ...ProviderClientOp
 			client:  newVertexAIClient(clientOptions),
 		}, nil
 	case models.ProviderOpenRouter:
-		clientOptions.openaiOptions = append(clientOptions.openaiOptions,
-			WithOpenAIBaseURL("https://openrouter.ai/api/v1"),
-			WithOpenAIExtraHeaders(map[string]string{
-				"HTTP-Referer": "crush.charm.land",
-				"X-Title":      "Crush",
-			}),
-		)
+		clientOptions.baseURL = "https://openrouter.ai/api/v1"
+		clientOptions.extraHeaders = map[string]string{
+			"HTTP-Referer": "crush.charm.land",
+			"X-Title":      "Crush",
+		}
 		return &baseProvider[OpenAIClient]{
 			options: clientOptions,
 			client:  newOpenAIClient(clientOptions),
 		}, nil
 	case models.ProviderXAI:
-		clientOptions.openaiOptions = append(clientOptions.openaiOptions,
-			WithOpenAIBaseURL("https://api.x.ai/v1"),
-		)
+		clientOptions.baseURL = "https://api.x.ai/v1"
 		return &baseProvider[OpenAIClient]{
 			options: clientOptions,
 			client:  newOpenAIClient(clientOptions),
 		}, nil
 	case models.ProviderLocal:
-		clientOptions.openaiOptions = append(clientOptions.openaiOptions,
-			WithOpenAIBaseURL(os.Getenv("LOCAL_ENDPOINT")),
-		)
+		clientOptions.baseURL = os.Getenv("LOCAL_ENDPOINT")
 		return &baseProvider[OpenAIClient]{
 			options: clientOptions,
 			client:  newOpenAIClient(clientOptions),
@@ -186,50 +177,47 @@ func (p *baseProvider[C]) StreamResponse(ctx context.Context, messages []message
 	return p.client.stream(ctx, messages, tools)
 }
 
-func WithAPIKey(apiKey string) ProviderClientOption {
-	return func(options *providerClientOptions) {
-		options.apiKey = apiKey
-	}
-}
-
-func WithModel(model models.Model) ProviderClientOption {
+func WithBaseURL(baseURL string) ProviderClientOption {
 	return func(options *providerClientOptions) {
-		options.model = model
+		options.baseURL = baseURL
 	}
 }
 
-func WithMaxTokens(maxTokens int64) ProviderClientOption {
+func WithAPIKey(apiKey string) ProviderClientOption {
 	return func(options *providerClientOptions) {
-		options.maxTokens = maxTokens
+		options.apiKey = apiKey
 	}
 }
 
-func WithSystemMessage(systemMessage string) ProviderClientOption {
+func WithModel(model models.Model) ProviderClientOption {
 	return func(options *providerClientOptions) {
-		options.systemMessage = systemMessage
+		options.model = model
 	}
 }
 
-func WithAnthropicOptions(anthropicOptions ...AnthropicOption) ProviderClientOption {
+func WithDisableCache(disableCache bool) ProviderClientOption {
 	return func(options *providerClientOptions) {
-		options.anthropicOptions = anthropicOptions
+		options.disableCache = disableCache
 	}
 }
 
-func WithOpenAIOptions(openaiOptions ...OpenAIOption) ProviderClientOption {
+func WithExtraHeaders(extraHeaders map[string]string) ProviderClientOption {
 	return func(options *providerClientOptions) {
-		options.openaiOptions = openaiOptions
+		if options.extraHeaders == nil {
+			options.extraHeaders = make(map[string]string)
+		}
+		maps.Copy(options.extraHeaders, extraHeaders)
 	}
 }
 
-func WithGeminiOptions(geminiOptions ...GeminiOption) ProviderClientOption {
+func WithMaxTokens(maxTokens int64) ProviderClientOption {
 	return func(options *providerClientOptions) {
-		options.geminiOptions = geminiOptions
+		options.maxTokens = maxTokens
 	}
 }
 
-func WithBedrockOptions(bedrockOptions ...BedrockOption) ProviderClientOption {
+func WithSystemMessage(systemMessage string) ProviderClientOption {
 	return func(options *providerClientOptions) {
-		options.bedrockOptions = bedrockOptions
+		options.systemMessage = systemMessage
 	}
 }

internal/llm/provider/vertexai.go 🔗

@@ -11,11 +11,6 @@ import (
 type VertexAIClient ProviderClient
 
 func newVertexAIClient(opts providerClientOptions) VertexAIClient {
-	geminiOpts := geminiOptions{}
-	for _, o := range opts.geminiOptions {
-		o(&geminiOpts)
-	}
-
 	client, err := genai.NewClient(context.Background(), &genai.ClientConfig{
 		Project:  os.Getenv("GOOGLE_CLOUD_PROJECT"),
 		Location: os.Getenv("GOOGLE_CLOUD_LOCATION"),
@@ -28,7 +23,6 @@ func newVertexAIClient(opts providerClientOptions) VertexAIClient {
 
 	return &geminiClient{
 		providerOptions: opts,
-		options:         geminiOpts,
 		client:          client,
 	}
 }