From 464477c829ee1f39c7d5bd0e2020364ecd8f0ecb Mon Sep 17 00:00:00 2001 From: Carlos Alexandro Becker Date: Wed, 23 Jul 2025 15:13:22 -0300 Subject: [PATCH 1/3] feat: use new catwalk --- go.mod | 1 + go.sum | 4 +- internal/config/config.go | 24 +-- internal/config/load.go | 37 ++-- internal/config/load_test.go | 160 +++++++++--------- internal/config/provider.go | 21 ++- internal/config/provider_test.go | 8 +- internal/fur/client/client.go | 63 ------- internal/fur/provider/provider.go | 75 -------- internal/llm/agent/agent.go | 8 +- internal/llm/prompt/coder.go | 6 +- internal/llm/provider/anthropic.go | 6 +- internal/llm/provider/bedrock.go | 6 +- internal/llm/provider/gemini.go | 4 +- internal/llm/provider/openai.go | 6 +- internal/llm/provider/openai_test.go | 10 +- internal/llm/provider/provider.go | 26 +-- internal/message/content.go | 6 +- .../tui/components/chat/messages/messages.go | 8 +- .../tui/components/chat/sidebar/sidebar.go | 8 +- internal/tui/components/chat/splash/splash.go | 6 +- .../components/dialogs/commands/commands.go | 4 +- .../tui/components/dialogs/models/list.go | 22 +-- .../tui/components/dialogs/models/models.go | 8 +- internal/tui/page/chat/chat.go | 2 +- 25 files changed, 196 insertions(+), 333 deletions(-) delete mode 100644 internal/fur/client/client.go delete mode 100644 internal/fur/provider/provider.go diff --git a/go.mod b/go.mod index 8607bd2923fc6b7c4146368d906bd81dfd29dd5a..e17354c051a21b593a385b1e3995cc543aafd0dd 100644 --- a/go.mod +++ b/go.mod @@ -13,6 +13,7 @@ require ( github.com/charlievieth/fastwalk v1.0.11 github.com/charmbracelet/bubbles/v2 v2.0.0-beta.1.0.20250716191546-1e2ffbbcf5c5 github.com/charmbracelet/bubbletea/v2 v2.0.0-beta.4.0.20250717140350-bb75e8f6b6ac + github.com/charmbracelet/catwalk v0.3.1 github.com/charmbracelet/fang v0.3.1-0.20250711140230-d5ebb8c1d674 github.com/charmbracelet/glamour/v2 v2.0.0-20250516160903-6f1e2c8f9ebe github.com/charmbracelet/lipgloss/v2 v2.0.0-beta.3.0.20250716211347-10c048e36112 diff --git a/go.sum b/go.sum index df50a1529358a935fc628070ecc0498ff06024a3..755edeb81ead60da60196e2834c9e6354af168b7 100644 --- a/go.sum +++ b/go.sum @@ -72,6 +72,8 @@ github.com/charmbracelet/bubbles/v2 v2.0.0-beta.1.0.20250716191546-1e2ffbbcf5c5 github.com/charmbracelet/bubbles/v2 v2.0.0-beta.1.0.20250716191546-1e2ffbbcf5c5/go.mod h1:6HamsBKWqEC/FVHuQMHgQL+knPyvHH55HwJDHl/adMw= github.com/charmbracelet/bubbletea/v2 v2.0.0-beta.4.0.20250717140350-bb75e8f6b6ac h1:murtkvFYxZ/73vk4Z/tpE4biB+WDZcFmmBp8je/yV6M= github.com/charmbracelet/bubbletea/v2 v2.0.0-beta.4.0.20250717140350-bb75e8f6b6ac/go.mod h1:m240IQxo1/eDQ7klblSzOCAUyc3LddHcV3Rc/YEGAgw= +github.com/charmbracelet/catwalk v0.3.1 h1:MkGWspcMyE659zDkqS+9wsaCMTKRFEDBFY2A2sap6+U= +github.com/charmbracelet/catwalk v0.3.1/go.mod h1:gUUCqqZ8bk4D7ZzGTu3I77k7cC2x4exRuJBN1H2u2pc= github.com/charmbracelet/colorprofile v0.3.1 h1:k8dTHMd7fgw4bnFd7jXTLZrSU/CQrKnL3m+AxCzDz40= github.com/charmbracelet/colorprofile v0.3.1/go.mod h1:/GkGusxNs8VB/RSOh3fu0TJmQ4ICMMPApIIVn0KszZ0= github.com/charmbracelet/fang v0.3.1-0.20250711140230-d5ebb8c1d674 h1:+Cz+VfxD5DO+JT1LlswXWhre0HYLj6l2HW8HVGfMuC0= @@ -82,8 +84,6 @@ github.com/charmbracelet/lipgloss/v2 v2.0.0-beta.3.0.20250716211347-10c048e36112 github.com/charmbracelet/lipgloss/v2 v2.0.0-beta.3.0.20250716211347-10c048e36112/go.mod h1:BXY7j7rZgAprFwzNcO698++5KTd6GKI6lU83Pr4o0r0= github.com/charmbracelet/log/v2 v2.0.0-20250226163916-c379e29ff706 h1:WkwO6Ks3mSIGnGuSdKl9qDSyfbYK50z2wc2gGMggegE= github.com/charmbracelet/log/v2 v2.0.0-20250226163916-c379e29ff706/go.mod h1:mjJGp00cxcfvD5xdCa+bso251Jt4owrQvuimJtVmEmM= -github.com/charmbracelet/ultraviolet v0.0.0-20250721205647-f6ac6eda5d42 h1:Zqw2oP9Wo8VzMijVJbtIJcAaZviYyU07stvmCFCfn0Y= -github.com/charmbracelet/ultraviolet v0.0.0-20250721205647-f6ac6eda5d42/go.mod h1:XrrgNFfXLrFAyd9DUmrqVc3yQFVv8Uk+okj4PsNNzpc= github.com/charmbracelet/ultraviolet v0.0.0-20250723145313-809e6f5b43a1 h1:tsw1mOuIEIKlmm614bXctvJ3aavaFhyPG+y+wrKtuKQ= github.com/charmbracelet/ultraviolet v0.0.0-20250723145313-809e6f5b43a1/go.mod h1:XrrgNFfXLrFAyd9DUmrqVc3yQFVv8Uk+okj4PsNNzpc= github.com/charmbracelet/x/ansi v0.9.3 h1:BXt5DHS/MKF+LjuK4huWrC6NCvHtexww7dMayh6GXd0= diff --git a/internal/config/config.go b/internal/config/config.go index 1c20188a12a3955fde6b6eeed9f12ea39288e328..bfbcc5ed91ec86075d269cc53d0e5bf22252ed47 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -9,8 +9,8 @@ import ( "strings" "time" + "github.com/charmbracelet/catwalk/pkg/catwalk" "github.com/charmbracelet/crush/internal/env" - "github.com/charmbracelet/crush/internal/fur/provider" "github.com/tidwall/sjson" "golang.org/x/exp/slog" ) @@ -70,7 +70,7 @@ type ProviderConfig struct { // The provider's API endpoint. BaseURL string `json:"base_url,omitempty"` // The provider type, e.g. "openai", "anthropic", etc. if empty it defaults to openai. - Type provider.Type `json:"type,omitempty"` + Type catwalk.Type `json:"type,omitempty"` // The provider's API key. APIKey string `json:"api_key,omitempty"` // Marks the provider as disabled. @@ -85,7 +85,7 @@ type ProviderConfig struct { ExtraParams map[string]string `json:"-"` // The provider models - Models []provider.Model `json:"models,omitempty"` + Models []catwalk.Model `json:"models,omitempty"` } type MCPType string @@ -250,8 +250,8 @@ type Config struct { Agents map[string]Agent `json:"-"` // TODO: find a better way to do this this should probably not be part of the config resolver VariableResolver - dataConfigDir string `json:"-"` - knownProviders []provider.Provider `json:"-"` + dataConfigDir string `json:"-"` + knownProviders []catwalk.Provider `json:"-"` } func (c *Config) WorkingDir() string { @@ -273,7 +273,7 @@ func (c *Config) IsConfigured() bool { return len(c.EnabledProviders()) > 0 } -func (c *Config) GetModel(provider, model string) *provider.Model { +func (c *Config) GetModel(provider, model string) *catwalk.Model { if providerConfig, ok := c.Providers[provider]; ok { for _, m := range providerConfig.Models { if m.ID == model { @@ -295,7 +295,7 @@ func (c *Config) GetProviderForModel(modelType SelectedModelType) *ProviderConfi return nil } -func (c *Config) GetModelByType(modelType SelectedModelType) *provider.Model { +func (c *Config) GetModelByType(modelType SelectedModelType) *catwalk.Model { model, ok := c.Models[modelType] if !ok { return nil @@ -303,7 +303,7 @@ func (c *Config) GetModelByType(modelType SelectedModelType) *provider.Model { return c.GetModel(model.Provider, model.Model) } -func (c *Config) LargeModel() *provider.Model { +func (c *Config) LargeModel() *catwalk.Model { model, ok := c.Models[SelectedModelTypeLarge] if !ok { return nil @@ -311,7 +311,7 @@ func (c *Config) LargeModel() *provider.Model { return c.GetModel(model.Provider, model.Model) } -func (c *Config) SmallModel() *provider.Model { +func (c *Config) SmallModel() *catwalk.Model { model, ok := c.Models[SelectedModelTypeSmall] if !ok { return nil @@ -381,7 +381,7 @@ func (c *Config) SetProviderAPIKey(providerID, apiKey string) error { return nil } - var foundProvider *provider.Provider + var foundProvider *catwalk.Provider for _, p := range c.knownProviders { if string(p.ID) == providerID { foundProvider = &p @@ -450,14 +450,14 @@ func (c *ProviderConfig) TestConnection(resolver VariableResolver) error { headers := make(map[string]string) apiKey, _ := resolver.ResolveValue(c.APIKey) switch c.Type { - case provider.TypeOpenAI: + case catwalk.TypeOpenAI: baseURL, _ := resolver.ResolveValue(c.BaseURL) if baseURL == "" { baseURL = "https://api.openai.com/v1" } testURL = baseURL + "/models" headers["Authorization"] = "Bearer " + apiKey - case provider.TypeAnthropic: + case catwalk.TypeAnthropic: baseURL, _ := resolver.ResolveValue(c.BaseURL) if baseURL == "" { baseURL = "https://api.anthropic.com/v1" diff --git a/internal/config/load.go b/internal/config/load.go index cd4ccd08c46e48155091407962137da2cb913869..bf0fc3d562d6a38544399a095f4efe0a5f75fcd2 100644 --- a/internal/config/load.go +++ b/internal/config/load.go @@ -11,13 +11,14 @@ import ( "strings" "sync" + "github.com/charmbracelet/catwalk/pkg/catwalk" "github.com/charmbracelet/crush/internal/env" - "github.com/charmbracelet/crush/internal/fur/client" - "github.com/charmbracelet/crush/internal/fur/provider" "github.com/charmbracelet/crush/internal/log" "golang.org/x/exp/slog" ) +const catwalkURL = "https://catwalk.charm.sh" + // LoadReader config via io.Reader. func LoadReader(fd io.Reader) (*Config, error) { data, err := io.ReadAll(fd) @@ -61,8 +62,8 @@ func Load(workingDir string, debug bool) (*Config, error) { cfg.Options.Debug, ) - // Load known providers, this loads the config from fur - providers, err := LoadProviders(client.New()) + // Load known providers, this loads the config from catwalk + providers, err := LoadProviders(catwalk.NewWithURL(catwalkURL)) if err != nil || len(providers) == 0 { return nil, fmt.Errorf("failed to load providers: %w", err) } @@ -81,7 +82,7 @@ func Load(workingDir string, debug bool) (*Config, error) { var wg sync.WaitGroup for _, p := range cfg.Providers { - if p.Type == provider.TypeOpenAI || p.Type == provider.TypeAnthropic { + if p.Type == catwalk.TypeOpenAI || p.Type == catwalk.TypeAnthropic { wg.Add(1) go func(provider ProviderConfig) { defer wg.Done() @@ -117,7 +118,7 @@ func Load(workingDir string, debug bool) (*Config, error) { return cfg, nil } -func (c *Config) configureProviders(env env.Env, resolver VariableResolver, knownProviders []provider.Provider) error { +func (c *Config) configureProviders(env env.Env, resolver VariableResolver, knownProviders []catwalk.Provider) error { knownProviderNames := make(map[string]bool) for _, p := range knownProviders { knownProviderNames[string(p.ID)] = true @@ -136,7 +137,7 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know p.APIKey = config.APIKey } if len(config.Models) > 0 { - models := []provider.Model{} + models := []catwalk.Model{} seen := make(map[string]bool) for _, model := range config.Models { @@ -144,8 +145,8 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know continue } seen[model.ID] = true - if model.Model == "" { - model.Model = model.ID + if model.Name == "" { + model.Name = model.ID } models = append(models, model) } @@ -154,8 +155,8 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know continue } seen[model.ID] = true - if model.Model == "" { - model.Model = model.ID + if model.Name == "" { + model.Name = model.ID } models = append(models, model) } @@ -178,7 +179,7 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know switch p.ID { // Handle specific providers that require additional configuration - case provider.InferenceProviderVertexAI: + case catwalk.InferenceProviderVertexAI: if !hasVertexCredentials(env) { if configExists { slog.Warn("Skipping Vertex AI provider due to missing credentials") @@ -188,7 +189,7 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know } prepared.ExtraParams["project"] = env.Get("GOOGLE_CLOUD_PROJECT") prepared.ExtraParams["location"] = env.Get("GOOGLE_CLOUD_LOCATION") - case provider.InferenceProviderAzure: + case catwalk.InferenceProviderAzure: endpoint, err := resolver.ResolveValue(p.APIEndpoint) if err != nil || endpoint == "" { if configExists { @@ -199,7 +200,7 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know } prepared.BaseURL = endpoint prepared.ExtraParams["apiVersion"] = env.Get("AZURE_OPENAI_API_VERSION") - case provider.InferenceProviderBedrock: + case catwalk.InferenceProviderBedrock: if !hasAWSCredentials(env) { if configExists { slog.Warn("Skipping Bedrock provider due to missing AWS credentials") @@ -239,7 +240,7 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know } // default to OpenAI if not set if providerConfig.Type == "" { - providerConfig.Type = provider.TypeOpenAI + providerConfig.Type = catwalk.TypeOpenAI } if providerConfig.Disable { @@ -260,7 +261,7 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know delete(c.Providers, id) continue } - if providerConfig.Type != provider.TypeOpenAI { + if providerConfig.Type != catwalk.TypeOpenAI { slog.Warn("Skipping custom provider because the provider type is not supported", "provider", id, "type", providerConfig.Type) delete(c.Providers, id) continue @@ -315,7 +316,7 @@ func (c *Config) setDefaults(workingDir string) { c.Options.ContextPaths = slices.Compact(c.Options.ContextPaths) } -func (c *Config) defaultModelSelection(knownProviders []provider.Provider) (largeModel SelectedModel, smallModel SelectedModel, err error) { +func (c *Config) defaultModelSelection(knownProviders []catwalk.Provider) (largeModel SelectedModel, smallModel SelectedModel, err error) { if len(knownProviders) == 0 && len(c.Providers) == 0 { err = fmt.Errorf("no providers configured, please configure at least one provider") return @@ -384,7 +385,7 @@ func (c *Config) defaultModelSelection(knownProviders []provider.Provider) (larg return } -func (c *Config) configureSelectedModels(knownProviders []provider.Provider) error { +func (c *Config) configureSelectedModels(knownProviders []catwalk.Provider) error { defaultLarge, defaultSmall, err := c.defaultModelSelection(knownProviders) if err != nil { return fmt.Errorf("failed to select default models: %w", err) diff --git a/internal/config/load_test.go b/internal/config/load_test.go index b96ca5e81cd265cbcd1bdf9d456603ad3f22c558..a3d224d443b8995b747481871a82a097afa02e1b 100644 --- a/internal/config/load_test.go +++ b/internal/config/load_test.go @@ -8,8 +8,8 @@ import ( "strings" "testing" + "github.com/charmbracelet/catwalk/pkg/catwalk" "github.com/charmbracelet/crush/internal/env" - "github.com/charmbracelet/crush/internal/fur/provider" "github.com/stretchr/testify/assert" ) @@ -54,12 +54,12 @@ func TestConfig_setDefaults(t *testing.T) { } func TestConfig_configureProviders(t *testing.T) { - knownProviders := []provider.Provider{ + knownProviders := []catwalk.Provider{ { ID: "openai", APIKey: "$OPENAI_API_KEY", APIEndpoint: "https://api.openai.com/v1", - Models: []provider.Model{{ + Models: []catwalk.Model{{ ID: "test-model", }}, }, @@ -80,12 +80,12 @@ func TestConfig_configureProviders(t *testing.T) { } func TestConfig_configureProvidersWithOverride(t *testing.T) { - knownProviders := []provider.Provider{ + knownProviders := []catwalk.Provider{ { ID: "openai", APIKey: "$OPENAI_API_KEY", APIEndpoint: "https://api.openai.com/v1", - Models: []provider.Model{{ + Models: []catwalk.Model{{ ID: "test-model", }}, }, @@ -96,10 +96,10 @@ func TestConfig_configureProvidersWithOverride(t *testing.T) { "openai": { APIKey: "xyz", BaseURL: "https://api.openai.com/v2", - Models: []provider.Model{ + Models: []catwalk.Model{ { - ID: "test-model", - Model: "Updated", + ID: "test-model", + Name: "Updated", }, { ID: "another-model", @@ -122,16 +122,16 @@ func TestConfig_configureProvidersWithOverride(t *testing.T) { assert.Equal(t, "xyz", cfg.Providers["openai"].APIKey) assert.Equal(t, "https://api.openai.com/v2", cfg.Providers["openai"].BaseURL) assert.Len(t, cfg.Providers["openai"].Models, 2) - assert.Equal(t, "Updated", cfg.Providers["openai"].Models[0].Model) + assert.Equal(t, "Updated", cfg.Providers["openai"].Models[0].Name) } func TestConfig_configureProvidersWithNewProvider(t *testing.T) { - knownProviders := []provider.Provider{ + knownProviders := []catwalk.Provider{ { ID: "openai", APIKey: "$OPENAI_API_KEY", APIEndpoint: "https://api.openai.com/v1", - Models: []provider.Model{{ + Models: []catwalk.Model{{ ID: "test-model", }}, }, @@ -142,7 +142,7 @@ func TestConfig_configureProvidersWithNewProvider(t *testing.T) { "custom": { APIKey: "xyz", BaseURL: "https://api.someendpoint.com/v2", - Models: []provider.Model{ + Models: []catwalk.Model{ { ID: "test-model", }, @@ -172,12 +172,12 @@ func TestConfig_configureProvidersWithNewProvider(t *testing.T) { } func TestConfig_configureProvidersBedrockWithCredentials(t *testing.T) { - knownProviders := []provider.Provider{ + knownProviders := []catwalk.Provider{ { - ID: provider.InferenceProviderBedrock, + ID: catwalk.InferenceProviderBedrock, APIKey: "", APIEndpoint: "", - Models: []provider.Model{{ + Models: []catwalk.Model{{ ID: "anthropic.claude-sonnet-4-20250514-v1:0", }}, }, @@ -201,12 +201,12 @@ func TestConfig_configureProvidersBedrockWithCredentials(t *testing.T) { } func TestConfig_configureProvidersBedrockWithoutCredentials(t *testing.T) { - knownProviders := []provider.Provider{ + knownProviders := []catwalk.Provider{ { - ID: provider.InferenceProviderBedrock, + ID: catwalk.InferenceProviderBedrock, APIKey: "", APIEndpoint: "", - Models: []provider.Model{{ + Models: []catwalk.Model{{ ID: "anthropic.claude-sonnet-4-20250514-v1:0", }}, }, @@ -223,12 +223,12 @@ func TestConfig_configureProvidersBedrockWithoutCredentials(t *testing.T) { } func TestConfig_configureProvidersBedrockWithoutUnsupportedModel(t *testing.T) { - knownProviders := []provider.Provider{ + knownProviders := []catwalk.Provider{ { - ID: provider.InferenceProviderBedrock, + ID: catwalk.InferenceProviderBedrock, APIKey: "", APIEndpoint: "", - Models: []provider.Model{{ + Models: []catwalk.Model{{ ID: "some-random-model", }}, }, @@ -246,12 +246,12 @@ func TestConfig_configureProvidersBedrockWithoutUnsupportedModel(t *testing.T) { } func TestConfig_configureProvidersVertexAIWithCredentials(t *testing.T) { - knownProviders := []provider.Provider{ + knownProviders := []catwalk.Provider{ { - ID: provider.InferenceProviderVertexAI, + ID: catwalk.InferenceProviderVertexAI, APIKey: "", APIEndpoint: "", - Models: []provider.Model{{ + Models: []catwalk.Model{{ ID: "gemini-pro", }}, }, @@ -278,12 +278,12 @@ func TestConfig_configureProvidersVertexAIWithCredentials(t *testing.T) { } func TestConfig_configureProvidersVertexAIWithoutCredentials(t *testing.T) { - knownProviders := []provider.Provider{ + knownProviders := []catwalk.Provider{ { - ID: provider.InferenceProviderVertexAI, + ID: catwalk.InferenceProviderVertexAI, APIKey: "", APIEndpoint: "", - Models: []provider.Model{{ + Models: []catwalk.Model{{ ID: "gemini-pro", }}, }, @@ -304,12 +304,12 @@ func TestConfig_configureProvidersVertexAIWithoutCredentials(t *testing.T) { } func TestConfig_configureProvidersVertexAIMissingProject(t *testing.T) { - knownProviders := []provider.Provider{ + knownProviders := []catwalk.Provider{ { - ID: provider.InferenceProviderVertexAI, + ID: catwalk.InferenceProviderVertexAI, APIKey: "", APIEndpoint: "", - Models: []provider.Model{{ + Models: []catwalk.Model{{ ID: "gemini-pro", }}, }, @@ -329,12 +329,12 @@ func TestConfig_configureProvidersVertexAIMissingProject(t *testing.T) { } func TestConfig_configureProvidersSetProviderID(t *testing.T) { - knownProviders := []provider.Provider{ + knownProviders := []catwalk.Provider{ { ID: "openai", APIKey: "$OPENAI_API_KEY", APIEndpoint: "https://api.openai.com/v1", - Models: []provider.Model{{ + Models: []catwalk.Model{{ ID: "test-model", }}, }, @@ -450,12 +450,12 @@ func TestConfig_IsConfigured(t *testing.T) { } func TestConfig_configureProvidersWithDisabledProvider(t *testing.T) { - knownProviders := []provider.Provider{ + knownProviders := []catwalk.Provider{ { ID: "openai", APIKey: "$OPENAI_API_KEY", APIEndpoint: "https://api.openai.com/v1", - Models: []provider.Model{{ + Models: []catwalk.Model{{ ID: "test-model", }}, }, @@ -489,7 +489,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { Providers: map[string]ProviderConfig{ "custom": { BaseURL: "https://api.custom.com/v1", - Models: []provider.Model{{ + Models: []catwalk.Model{{ ID: "test-model", }}, }, @@ -502,7 +502,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { env := env.NewFromMap(map[string]string{}) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, []provider.Provider{}) + err := cfg.configureProviders(env, resolver, []catwalk.Provider{}) assert.NoError(t, err) assert.Len(t, cfg.Providers, 1) @@ -515,7 +515,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { Providers: map[string]ProviderConfig{ "custom": { APIKey: "test-key", - Models: []provider.Model{{ + Models: []catwalk.Model{{ ID: "test-model", }}, }, @@ -525,7 +525,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { env := env.NewFromMap(map[string]string{}) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, []provider.Provider{}) + err := cfg.configureProviders(env, resolver, []catwalk.Provider{}) assert.NoError(t, err) assert.Len(t, cfg.Providers, 0) @@ -539,7 +539,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { "custom": { APIKey: "test-key", BaseURL: "https://api.custom.com/v1", - Models: []provider.Model{}, + Models: []catwalk.Model{}, }, }, } @@ -547,7 +547,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { env := env.NewFromMap(map[string]string{}) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, []provider.Provider{}) + err := cfg.configureProviders(env, resolver, []catwalk.Provider{}) assert.NoError(t, err) assert.Len(t, cfg.Providers, 0) @@ -562,7 +562,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { APIKey: "test-key", BaseURL: "https://api.custom.com/v1", Type: "unsupported", - Models: []provider.Model{{ + Models: []catwalk.Model{{ ID: "test-model", }}, }, @@ -572,7 +572,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { env := env.NewFromMap(map[string]string{}) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, []provider.Provider{}) + err := cfg.configureProviders(env, resolver, []catwalk.Provider{}) assert.NoError(t, err) assert.Len(t, cfg.Providers, 0) @@ -586,8 +586,8 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { "custom": { APIKey: "test-key", BaseURL: "https://api.custom.com/v1", - Type: provider.TypeOpenAI, - Models: []provider.Model{{ + Type: catwalk.TypeOpenAI, + Models: []catwalk.Model{{ ID: "test-model", }}, }, @@ -597,7 +597,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { env := env.NewFromMap(map[string]string{}) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, []provider.Provider{}) + err := cfg.configureProviders(env, resolver, []catwalk.Provider{}) assert.NoError(t, err) assert.Len(t, cfg.Providers, 1) @@ -614,9 +614,9 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { "custom": { APIKey: "test-key", BaseURL: "https://api.custom.com/v1", - Type: provider.TypeOpenAI, + Type: catwalk.TypeOpenAI, Disable: true, - Models: []provider.Model{{ + Models: []catwalk.Model{{ ID: "test-model", }}, }, @@ -626,7 +626,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { env := env.NewFromMap(map[string]string{}) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, []provider.Provider{}) + err := cfg.configureProviders(env, resolver, []catwalk.Provider{}) assert.NoError(t, err) assert.Len(t, cfg.Providers, 0) @@ -637,12 +637,12 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) { t.Run("VertexAI provider removed when credentials missing with existing config", func(t *testing.T) { - knownProviders := []provider.Provider{ + knownProviders := []catwalk.Provider{ { - ID: provider.InferenceProviderVertexAI, + ID: catwalk.InferenceProviderVertexAI, APIKey: "", APIEndpoint: "", - Models: []provider.Model{{ + Models: []catwalk.Model{{ ID: "gemini-pro", }}, }, @@ -670,12 +670,12 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) { }) t.Run("Bedrock provider removed when AWS credentials missing with existing config", func(t *testing.T) { - knownProviders := []provider.Provider{ + knownProviders := []catwalk.Provider{ { - ID: provider.InferenceProviderBedrock, + ID: catwalk.InferenceProviderBedrock, APIKey: "", APIEndpoint: "", - Models: []provider.Model{{ + Models: []catwalk.Model{{ ID: "anthropic.claude-sonnet-4-20250514-v1:0", }}, }, @@ -701,12 +701,12 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) { }) t.Run("provider removed when API key missing with existing config", func(t *testing.T) { - knownProviders := []provider.Provider{ + knownProviders := []catwalk.Provider{ { ID: "openai", APIKey: "$MISSING_API_KEY", APIEndpoint: "https://api.openai.com/v1", - Models: []provider.Model{{ + Models: []catwalk.Model{{ ID: "test-model", }}, }, @@ -732,12 +732,12 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) { }) t.Run("known provider should still be added if the endpoint is missing the client will use default endpoints", func(t *testing.T) { - knownProviders := []provider.Provider{ + knownProviders := []catwalk.Provider{ { ID: "openai", APIKey: "$OPENAI_API_KEY", APIEndpoint: "$MISSING_ENDPOINT", - Models: []provider.Model{{ + Models: []catwalk.Model{{ ID: "test-model", }}, }, @@ -767,13 +767,13 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) { func TestConfig_defaultModelSelection(t *testing.T) { t.Run("default behavior uses the default models for given provider", func(t *testing.T) { - knownProviders := []provider.Provider{ + knownProviders := []catwalk.Provider{ { ID: "openai", APIKey: "abc", DefaultLargeModelID: "large-model", DefaultSmallModelID: "small-model", - Models: []provider.Model{ + Models: []catwalk.Model{ { ID: "large-model", DefaultMaxTokens: 1000, @@ -803,13 +803,13 @@ func TestConfig_defaultModelSelection(t *testing.T) { assert.Equal(t, int64(500), small.MaxTokens) }) t.Run("should error if no providers configured", func(t *testing.T) { - knownProviders := []provider.Provider{ + knownProviders := []catwalk.Provider{ { ID: "openai", APIKey: "$MISSING_KEY", DefaultLargeModelID: "large-model", DefaultSmallModelID: "small-model", - Models: []provider.Model{ + Models: []catwalk.Model{ { ID: "large-model", DefaultMaxTokens: 1000, @@ -833,13 +833,13 @@ func TestConfig_defaultModelSelection(t *testing.T) { assert.Error(t, err) }) t.Run("should error if model is missing", func(t *testing.T) { - knownProviders := []provider.Provider{ + knownProviders := []catwalk.Provider{ { ID: "openai", APIKey: "abc", DefaultLargeModelID: "large-model", DefaultSmallModelID: "small-model", - Models: []provider.Model{ + Models: []catwalk.Model{ { ID: "not-large-model", DefaultMaxTokens: 1000, @@ -863,13 +863,13 @@ func TestConfig_defaultModelSelection(t *testing.T) { }) t.Run("should configure the default models with a custom provider", func(t *testing.T) { - knownProviders := []provider.Provider{ + knownProviders := []catwalk.Provider{ { ID: "openai", APIKey: "$MISSING", // will not be included in the config DefaultLargeModelID: "large-model", DefaultSmallModelID: "small-model", - Models: []provider.Model{ + Models: []catwalk.Model{ { ID: "not-large-model", DefaultMaxTokens: 1000, @@ -887,7 +887,7 @@ func TestConfig_defaultModelSelection(t *testing.T) { "custom": { APIKey: "test-key", BaseURL: "https://api.custom.com/v1", - Models: []provider.Model{ + Models: []catwalk.Model{ { ID: "model", DefaultMaxTokens: 600, @@ -912,13 +912,13 @@ func TestConfig_defaultModelSelection(t *testing.T) { }) t.Run("should fail if no model configured", func(t *testing.T) { - knownProviders := []provider.Provider{ + knownProviders := []catwalk.Provider{ { ID: "openai", APIKey: "$MISSING", // will not be included in the config DefaultLargeModelID: "large-model", DefaultSmallModelID: "small-model", - Models: []provider.Model{ + Models: []catwalk.Model{ { ID: "not-large-model", DefaultMaxTokens: 1000, @@ -936,7 +936,7 @@ func TestConfig_defaultModelSelection(t *testing.T) { "custom": { APIKey: "test-key", BaseURL: "https://api.custom.com/v1", - Models: []provider.Model{}, + Models: []catwalk.Model{}, }, }, } @@ -949,13 +949,13 @@ func TestConfig_defaultModelSelection(t *testing.T) { assert.Error(t, err) }) t.Run("should use the default provider first", func(t *testing.T) { - knownProviders := []provider.Provider{ + knownProviders := []catwalk.Provider{ { ID: "openai", APIKey: "set", DefaultLargeModelID: "large-model", DefaultSmallModelID: "small-model", - Models: []provider.Model{ + Models: []catwalk.Model{ { ID: "large-model", DefaultMaxTokens: 1000, @@ -973,7 +973,7 @@ func TestConfig_defaultModelSelection(t *testing.T) { "custom": { APIKey: "test-key", BaseURL: "https://api.custom.com/v1", - Models: []provider.Model{ + Models: []catwalk.Model{ { ID: "large-model", DefaultMaxTokens: 1000, @@ -1000,13 +1000,13 @@ func TestConfig_defaultModelSelection(t *testing.T) { func TestConfig_configureSelectedModels(t *testing.T) { t.Run("should override defaults", func(t *testing.T) { - knownProviders := []provider.Provider{ + knownProviders := []catwalk.Provider{ { ID: "openai", APIKey: "abc", DefaultLargeModelID: "large-model", DefaultSmallModelID: "small-model", - Models: []provider.Model{ + Models: []catwalk.Model{ { ID: "larger-model", DefaultMaxTokens: 2000, @@ -1048,13 +1048,13 @@ func TestConfig_configureSelectedModels(t *testing.T) { assert.Equal(t, int64(500), small.MaxTokens) }) t.Run("should be possible to use multiple providers", func(t *testing.T) { - knownProviders := []provider.Provider{ + knownProviders := []catwalk.Provider{ { ID: "openai", APIKey: "abc", DefaultLargeModelID: "large-model", DefaultSmallModelID: "small-model", - Models: []provider.Model{ + Models: []catwalk.Model{ { ID: "large-model", DefaultMaxTokens: 1000, @@ -1070,7 +1070,7 @@ func TestConfig_configureSelectedModels(t *testing.T) { APIKey: "abc", DefaultLargeModelID: "a-large-model", DefaultSmallModelID: "a-small-model", - Models: []provider.Model{ + Models: []catwalk.Model{ { ID: "a-large-model", DefaultMaxTokens: 1000, @@ -1111,13 +1111,13 @@ func TestConfig_configureSelectedModels(t *testing.T) { }) t.Run("should override the max tokens only", func(t *testing.T) { - knownProviders := []provider.Provider{ + knownProviders := []catwalk.Provider{ { ID: "openai", APIKey: "abc", DefaultLargeModelID: "large-model", DefaultSmallModelID: "small-model", - Models: []provider.Model{ + Models: []catwalk.Model{ { ID: "large-model", DefaultMaxTokens: 1000, diff --git a/internal/config/provider.go b/internal/config/provider.go index b8369b934963aca0a7f449fb219764ee079493ef..9b5cdf608c36c36d62faffdb19e84c74013a1884 100644 --- a/internal/config/provider.go +++ b/internal/config/provider.go @@ -7,17 +7,16 @@ import ( "runtime" "sync" - "github.com/charmbracelet/crush/internal/fur/client" - "github.com/charmbracelet/crush/internal/fur/provider" + "github.com/charmbracelet/catwalk/pkg/catwalk" ) type ProviderClient interface { - GetProviders() ([]provider.Provider, error) + GetProviders() ([]catwalk.Provider, error) } var ( providerOnce sync.Once - providerList []provider.Provider + providerList []catwalk.Provider ) // file to cache provider data @@ -41,7 +40,7 @@ func providerCacheFileData() string { return filepath.Join(os.Getenv("HOME"), ".local", "share", appName, "providers.json") } -func saveProvidersInCache(path string, providers []provider.Provider) error { +func saveProvidersInCache(path string, providers []catwalk.Provider) error { dir := filepath.Dir(path) if err := os.MkdirAll(dir, 0o755); err != nil { return err @@ -55,18 +54,18 @@ func saveProvidersInCache(path string, providers []provider.Provider) error { return os.WriteFile(path, data, 0o644) } -func loadProvidersFromCache(path string) ([]provider.Provider, error) { +func loadProvidersFromCache(path string) ([]catwalk.Provider, error) { data, err := os.ReadFile(path) if err != nil { return nil, err } - var providers []provider.Provider + var providers []catwalk.Provider err = json.Unmarshal(data, &providers) return providers, err } -func loadProviders(path string, client ProviderClient) ([]provider.Provider, error) { +func loadProviders(path string, client ProviderClient) ([]catwalk.Provider, error) { providers, err := client.GetProviders() if err != nil { fallbackToCache, err := loadProvidersFromCache(path) @@ -82,11 +81,11 @@ func loadProviders(path string, client ProviderClient) ([]provider.Provider, err return providers, nil } -func Providers() ([]provider.Provider, error) { - return LoadProviders(client.New()) +func Providers() ([]catwalk.Provider, error) { + return LoadProviders(catwalk.NewWithURL(catwalkURL)) } -func LoadProviders(client ProviderClient) ([]provider.Provider, error) { +func LoadProviders(client ProviderClient) ([]catwalk.Provider, error) { var err error providerOnce.Do(func() { providerList, err = loadProviders(providerCacheFileData(), client) diff --git a/internal/config/provider_test.go b/internal/config/provider_test.go index a3562838c7103239aa303c906c866220164a4ba0..a63099ee27c96abb97d2781b186bb5aa9e060396 100644 --- a/internal/config/provider_test.go +++ b/internal/config/provider_test.go @@ -6,7 +6,7 @@ import ( "os" "testing" - "github.com/charmbracelet/crush/internal/fur/provider" + "github.com/charmbracelet/catwalk/pkg/catwalk" "github.com/stretchr/testify/assert" ) @@ -14,11 +14,11 @@ type mockProviderClient struct { shouldFail bool } -func (m *mockProviderClient) GetProviders() ([]provider.Provider, error) { +func (m *mockProviderClient) GetProviders() ([]catwalk.Provider, error) { if m.shouldFail { return nil, errors.New("failed to load providers") } - return []provider.Provider{ + return []catwalk.Provider{ { Name: "Mock", }, @@ -43,7 +43,7 @@ func TestProvider_loadProvidersWithIssues(t *testing.T) { client := &mockProviderClient{shouldFail: true} tmpPath := t.TempDir() + "/providers.json" // store providers to a temporary file - oldProviders := []provider.Provider{ + oldProviders := []catwalk.Provider{ { Name: "OldProvider", }, diff --git a/internal/fur/client/client.go b/internal/fur/client/client.go deleted file mode 100644 index d007c9aee18f77c8b03fe804726b4196e474d0b4..0000000000000000000000000000000000000000 --- a/internal/fur/client/client.go +++ /dev/null @@ -1,63 +0,0 @@ -// Package client provides a client for interacting with the fur service. -package client - -import ( - "encoding/json" - "fmt" - "net/http" - "os" - - "github.com/charmbracelet/crush/internal/fur/provider" -) - -const defaultURL = "https://fur.charm.sh" - -// Client represents a client for the fur service. -type Client struct { - baseURL string - httpClient *http.Client -} - -// New creates a new client instance -// Uses FUR_URL environment variable or falls back to localhost:8080. -func New() *Client { - baseURL := os.Getenv("FUR_URL") - if baseURL == "" { - baseURL = defaultURL - } - - return &Client{ - baseURL: baseURL, - httpClient: &http.Client{}, - } -} - -// NewWithURL creates a new client with a specific URL. -func NewWithURL(url string) *Client { - return &Client{ - baseURL: url, - httpClient: &http.Client{}, - } -} - -// GetProviders retrieves all available providers from the service. -func (c *Client) GetProviders() ([]provider.Provider, error) { - url := fmt.Sprintf("%s/providers", c.baseURL) - - resp, err := c.httpClient.Get(url) //nolint:noctx - if err != nil { - return nil, fmt.Errorf("failed to make request: %w", err) - } - defer resp.Body.Close() //nolint:errcheck - - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode) - } - - var providers []provider.Provider - if err := json.NewDecoder(resp.Body).Decode(&providers); err != nil { - return nil, fmt.Errorf("failed to decode response: %w", err) - } - - return providers, nil -} diff --git a/internal/fur/provider/provider.go b/internal/fur/provider/provider.go deleted file mode 100644 index 2bfe95a5bc3db4f1e52feebcaf7d484f4d5de948..0000000000000000000000000000000000000000 --- a/internal/fur/provider/provider.go +++ /dev/null @@ -1,75 +0,0 @@ -// Package provider provides types and constants for AI providers. -package provider - -// Type represents the type of AI provider. -type Type string - -// All the supported AI provider types. -const ( - TypeOpenAI Type = "openai" - TypeAnthropic Type = "anthropic" - TypeGemini Type = "gemini" - TypeAzure Type = "azure" - TypeBedrock Type = "bedrock" - TypeVertexAI Type = "vertexai" - TypeXAI Type = "xai" -) - -// InferenceProvider represents the inference provider identifier. -type InferenceProvider string - -// All the inference providers supported by the system. -const ( - InferenceProviderOpenAI InferenceProvider = "openai" - InferenceProviderAnthropic InferenceProvider = "anthropic" - InferenceProviderGemini InferenceProvider = "gemini" - InferenceProviderAzure InferenceProvider = "azure" - InferenceProviderBedrock InferenceProvider = "bedrock" - InferenceProviderVertexAI InferenceProvider = "vertexai" - InferenceProviderXAI InferenceProvider = "xai" - InferenceProviderGROQ InferenceProvider = "groq" - InferenceProviderOpenRouter InferenceProvider = "openrouter" -) - -// Provider represents an AI provider configuration. -type Provider struct { - Name string `json:"name"` - ID InferenceProvider `json:"id"` - APIKey string `json:"api_key,omitempty"` - APIEndpoint string `json:"api_endpoint,omitempty"` - Type Type `json:"type,omitempty"` - DefaultLargeModelID string `json:"default_large_model_id,omitempty"` - DefaultSmallModelID string `json:"default_small_model_id,omitempty"` - Models []Model `json:"models,omitempty"` -} - -// Model represents an AI model configuration. -type Model struct { - ID string `json:"id"` - Model string `json:"model"` - CostPer1MIn float64 `json:"cost_per_1m_in"` - CostPer1MOut float64 `json:"cost_per_1m_out"` - CostPer1MInCached float64 `json:"cost_per_1m_in_cached"` - CostPer1MOutCached float64 `json:"cost_per_1m_out_cached"` - ContextWindow int64 `json:"context_window"` - DefaultMaxTokens int64 `json:"default_max_tokens"` - CanReason bool `json:"can_reason"` - HasReasoningEffort bool `json:"has_reasoning_efforts"` - DefaultReasoningEffort string `json:"default_reasoning_effort,omitempty"` - SupportsImages bool `json:"supports_attachments"` -} - -// KnownProviders returns all the known inference providers. -func KnownProviders() []InferenceProvider { - return []InferenceProvider{ - InferenceProviderOpenAI, - InferenceProviderAnthropic, - InferenceProviderGemini, - InferenceProviderAzure, - InferenceProviderBedrock, - InferenceProviderVertexAI, - InferenceProviderXAI, - InferenceProviderGROQ, - InferenceProviderOpenRouter, - } -} diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go index 39c762991019f339348efab8cd9b769077e316f5..25545db4c53895d389d8f94c4e260926232a0251 100644 --- a/internal/llm/agent/agent.go +++ b/internal/llm/agent/agent.go @@ -10,8 +10,8 @@ import ( "sync" "time" + "github.com/charmbracelet/catwalk/pkg/catwalk" "github.com/charmbracelet/crush/internal/config" - fur "github.com/charmbracelet/crush/internal/fur/provider" "github.com/charmbracelet/crush/internal/history" "github.com/charmbracelet/crush/internal/llm/prompt" "github.com/charmbracelet/crush/internal/llm/provider" @@ -52,7 +52,7 @@ type AgentEvent struct { type Service interface { pubsub.Suscriber[AgentEvent] - Model() fur.Model + Model() catwalk.Model Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) Cancel(sessionID string) CancelAll() @@ -219,7 +219,7 @@ func NewAgent( return agent, nil } -func (a *agent) Model() fur.Model { +func (a *agent) Model() catwalk.Model { return *config.Get().GetModelByType(a.agentCfg.Model) } @@ -638,7 +638,7 @@ func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg return nil } -func (a *agent) TrackUsage(ctx context.Context, sessionID string, model fur.Model, usage provider.TokenUsage) error { +func (a *agent) TrackUsage(ctx context.Context, sessionID string, model catwalk.Model, usage provider.TokenUsage) error { sess, err := a.sessions.Get(ctx, sessionID) if err != nil { return fmt.Errorf("failed to get session: %w", err) diff --git a/internal/llm/prompt/coder.go b/internal/llm/prompt/coder.go index f4284faccee052e82e8ed82a820b16af58ccc64c..2ffbf2111931ad111751af1bfcd492422da205ee 100644 --- a/internal/llm/prompt/coder.go +++ b/internal/llm/prompt/coder.go @@ -9,17 +9,17 @@ import ( "runtime" "time" + "github.com/charmbracelet/catwalk/pkg/catwalk" "github.com/charmbracelet/crush/internal/config" - "github.com/charmbracelet/crush/internal/fur/provider" "github.com/charmbracelet/crush/internal/llm/tools" ) func CoderPrompt(p string, contextFiles ...string) string { var basePrompt string switch p { - case string(provider.InferenceProviderOpenAI): + case string(catwalk.InferenceProviderOpenAI): basePrompt = baseOpenAICoderPrompt - case string(provider.InferenceProviderGemini), string(provider.InferenceProviderVertexAI): + case string(catwalk.InferenceProviderGemini), string(catwalk.InferenceProviderVertexAI): basePrompt = baseGeminiCoderPrompt default: basePrompt = baseAnthropicCoderPrompt diff --git a/internal/llm/provider/anthropic.go b/internal/llm/provider/anthropic.go index 1e8364b08cb76ec7210d9937302cd1c647857b2d..00a84be57422df97a4773a6849c2c09561a6bd77 100644 --- a/internal/llm/provider/anthropic.go +++ b/internal/llm/provider/anthropic.go @@ -15,8 +15,8 @@ import ( "github.com/anthropics/anthropic-sdk-go" "github.com/anthropics/anthropic-sdk-go/bedrock" "github.com/anthropics/anthropic-sdk-go/option" + "github.com/charmbracelet/catwalk/pkg/catwalk" "github.com/charmbracelet/crush/internal/config" - "github.com/charmbracelet/crush/internal/fur/provider" "github.com/charmbracelet/crush/internal/llm/tools" "github.com/charmbracelet/crush/internal/message" ) @@ -71,7 +71,7 @@ func (a *anthropicClient) convertMessages(messages []message.Message) (anthropic var contentBlocks []anthropic.ContentBlockParamUnion contentBlocks = append(contentBlocks, content) for _, binaryContent := range msg.BinaryContent() { - base64Image := binaryContent.String(provider.InferenceProviderAnthropic) + base64Image := binaryContent.String(catwalk.InferenceProviderAnthropic) imageBlock := anthropic.NewImageBlockBase64(binaryContent.MIMEType, base64Image) contentBlocks = append(contentBlocks, imageBlock) } @@ -529,6 +529,6 @@ func (a *anthropicClient) usage(msg anthropic.Message) TokenUsage { } } -func (a *anthropicClient) Model() provider.Model { +func (a *anthropicClient) Model() catwalk.Model { return a.providerOptions.model(a.providerOptions.modelType) } diff --git a/internal/llm/provider/bedrock.go b/internal/llm/provider/bedrock.go index 0c0ccdbab2d642f139a2b1ab2f19f6298f1ac73d..8b5b21c36a390e80843504c7c9f6c257156f6379 100644 --- a/internal/llm/provider/bedrock.go +++ b/internal/llm/provider/bedrock.go @@ -6,8 +6,8 @@ import ( "fmt" "strings" + "github.com/charmbracelet/catwalk/pkg/catwalk" "github.com/charmbracelet/crush/internal/config" - "github.com/charmbracelet/crush/internal/fur/provider" "github.com/charmbracelet/crush/internal/llm/tools" "github.com/charmbracelet/crush/internal/message" ) @@ -32,7 +32,7 @@ func newBedrockClient(opts providerClientOptions) BedrockClient { } } - opts.model = func(modelType config.SelectedModelType) provider.Model { + opts.model = func(modelType config.SelectedModelType) catwalk.Model { model := config.Get().GetModelByType(modelType) // Prefix the model name with region @@ -88,6 +88,6 @@ func (b *bedrockClient) stream(ctx context.Context, messages []message.Message, return b.childProvider.stream(ctx, messages, tools) } -func (b *bedrockClient) Model() provider.Model { +func (b *bedrockClient) Model() catwalk.Model { return b.providerOptions.model(b.providerOptions.modelType) } diff --git a/internal/llm/provider/gemini.go b/internal/llm/provider/gemini.go index d2aee5090029e207ef1bdf5e0dad8e011e763267..5d73eb82461e08e236e10e898da221c12a69e576 100644 --- a/internal/llm/provider/gemini.go +++ b/internal/llm/provider/gemini.go @@ -10,8 +10,8 @@ import ( "strings" "time" + "github.com/charmbracelet/catwalk/pkg/catwalk" "github.com/charmbracelet/crush/internal/config" - "github.com/charmbracelet/crush/internal/fur/provider" "github.com/charmbracelet/crush/internal/llm/tools" "github.com/charmbracelet/crush/internal/message" "github.com/google/uuid" @@ -463,7 +463,7 @@ func (g *geminiClient) usage(resp *genai.GenerateContentResponse) TokenUsage { } } -func (g *geminiClient) Model() provider.Model { +func (g *geminiClient) Model() catwalk.Model { return g.providerOptions.model(g.providerOptions.modelType) } diff --git a/internal/llm/provider/openai.go b/internal/llm/provider/openai.go index f55914520774e2fcf5e6283e22365f4ce3621dc1..9ce48b70adf987883c93bc3f21cc7e0abaa1e38a 100644 --- a/internal/llm/provider/openai.go +++ b/internal/llm/provider/openai.go @@ -9,8 +9,8 @@ import ( "log/slog" "time" + "github.com/charmbracelet/catwalk/pkg/catwalk" "github.com/charmbracelet/crush/internal/config" - "github.com/charmbracelet/crush/internal/fur/provider" "github.com/charmbracelet/crush/internal/llm/tools" "github.com/charmbracelet/crush/internal/message" "github.com/openai/openai-go" @@ -66,7 +66,7 @@ func (o *openaiClient) convertMessages(messages []message.Message) (openaiMessag textBlock := openai.ChatCompletionContentPartTextParam{Text: msg.Content().String()} content = append(content, openai.ChatCompletionContentPartUnionParam{OfText: &textBlock}) for _, binaryContent := range msg.BinaryContent() { - imageURL := openai.ChatCompletionContentPartImageImageURLParam{URL: binaryContent.String(provider.InferenceProviderOpenAI)} + imageURL := openai.ChatCompletionContentPartImageImageURLParam{URL: binaryContent.String(catwalk.InferenceProviderOpenAI)} imageBlock := openai.ChatCompletionContentPartImageParam{ImageURL: imageURL} content = append(content, openai.ChatCompletionContentPartUnionParam{OfImageURL: &imageBlock}) @@ -486,6 +486,6 @@ func (o *openaiClient) usage(completion openai.ChatCompletion) TokenUsage { } } -func (o *openaiClient) Model() provider.Model { +func (o *openaiClient) Model() catwalk.Model { return o.providerOptions.model(o.providerOptions.modelType) } diff --git a/internal/llm/provider/openai_test.go b/internal/llm/provider/openai_test.go index c11e8ff14d7995859cccd3c95eeae4008fb20ac9..26c4d85ae35bbf4681719a12b568befccd8012af 100644 --- a/internal/llm/provider/openai_test.go +++ b/internal/llm/provider/openai_test.go @@ -9,8 +9,8 @@ import ( "testing" "time" + "github.com/charmbracelet/catwalk/pkg/catwalk" "github.com/charmbracelet/crush/internal/config" - "github.com/charmbracelet/crush/internal/fur/provider" "github.com/charmbracelet/crush/internal/llm/tools" "github.com/charmbracelet/crush/internal/message" "github.com/openai/openai-go" @@ -55,10 +55,10 @@ func TestOpenAIClientStreamChoices(t *testing.T) { modelType: config.SelectedModelTypeLarge, apiKey: "test-key", systemMessage: "test", - model: func(config.SelectedModelType) provider.Model { - return provider.Model{ - ID: "test-model", - Model: "test-model", + model: func(config.SelectedModelType) catwalk.Model { + return catwalk.Model{ + ID: "test-model", + Name: "test-model", } }, }, diff --git a/internal/llm/provider/provider.go b/internal/llm/provider/provider.go index 412093334169b4c0d59fdd4f3f72b1e427651307..062c2aa977c6ff101d1d8ab6f32809845bd48ff3 100644 --- a/internal/llm/provider/provider.go +++ b/internal/llm/provider/provider.go @@ -4,8 +4,8 @@ import ( "context" "fmt" + "github.com/charmbracelet/catwalk/pkg/catwalk" "github.com/charmbracelet/crush/internal/config" - "github.com/charmbracelet/crush/internal/fur/provider" "github.com/charmbracelet/crush/internal/llm/tools" "github.com/charmbracelet/crush/internal/message" ) @@ -57,7 +57,7 @@ type Provider interface { StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent - Model() provider.Model + Model() catwalk.Model } type providerClientOptions struct { @@ -65,7 +65,7 @@ type providerClientOptions struct { config config.ProviderConfig apiKey string modelType config.SelectedModelType - model func(config.SelectedModelType) provider.Model + model func(config.SelectedModelType) catwalk.Model disableCache bool systemMessage string maxTokens int64 @@ -80,7 +80,7 @@ type ProviderClient interface { send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent - Model() provider.Model + Model() catwalk.Model } type baseProvider[C ProviderClient] struct { @@ -109,7 +109,7 @@ func (p *baseProvider[C]) StreamResponse(ctx context.Context, messages []message return p.client.stream(ctx, messages, tools) } -func (p *baseProvider[C]) Model() provider.Model { +func (p *baseProvider[C]) Model() catwalk.Model { return p.client.Model() } @@ -149,7 +149,7 @@ func NewProvider(cfg config.ProviderConfig, opts ...ProviderClientOption) (Provi apiKey: resolvedAPIKey, extraHeaders: cfg.ExtraHeaders, extraBody: cfg.ExtraBody, - model: func(tp config.SelectedModelType) provider.Model { + model: func(tp config.SelectedModelType) catwalk.Model { return *config.Get().GetModelByType(tp) }, } @@ -157,37 +157,37 @@ func NewProvider(cfg config.ProviderConfig, opts ...ProviderClientOption) (Provi o(&clientOptions) } switch cfg.Type { - case provider.TypeAnthropic: + case catwalk.TypeAnthropic: return &baseProvider[AnthropicClient]{ options: clientOptions, client: newAnthropicClient(clientOptions, false), }, nil - case provider.TypeOpenAI: + case catwalk.TypeOpenAI: return &baseProvider[OpenAIClient]{ options: clientOptions, client: newOpenAIClient(clientOptions), }, nil - case provider.TypeGemini: + case catwalk.TypeGemini: return &baseProvider[GeminiClient]{ options: clientOptions, client: newGeminiClient(clientOptions), }, nil - case provider.TypeBedrock: + case catwalk.TypeBedrock: return &baseProvider[BedrockClient]{ options: clientOptions, client: newBedrockClient(clientOptions), }, nil - case provider.TypeAzure: + case catwalk.TypeAzure: return &baseProvider[AzureClient]{ options: clientOptions, client: newAzureClient(clientOptions), }, nil - case provider.TypeVertexAI: + case catwalk.TypeVertexAI: return &baseProvider[VertexAIClient]{ options: clientOptions, client: newVertexAIClient(clientOptions), }, nil - case provider.TypeXAI: + case catwalk.TypeXAI: clientOptions.baseURL = "https://api.x.ai/v1" return &baseProvider[OpenAIClient]{ options: clientOptions, diff --git a/internal/message/content.go b/internal/message/content.go index bdaf1577e34a4667bdb5c8cd2683865ec5cd08ac..b3f212187c86fb57667d95943fd15b8c6e3cccdb 100644 --- a/internal/message/content.go +++ b/internal/message/content.go @@ -5,7 +5,7 @@ import ( "slices" "time" - "github.com/charmbracelet/crush/internal/fur/provider" + "github.com/charmbracelet/catwalk/pkg/catwalk" ) type MessageRole string @@ -74,9 +74,9 @@ type BinaryContent struct { Data []byte } -func (bc BinaryContent) String(p provider.InferenceProvider) string { +func (bc BinaryContent) String(p catwalk.InferenceProvider) string { base64Encoded := base64.StdEncoding.EncodeToString(bc.Data) - if p == provider.InferenceProviderOpenAI { + if p == catwalk.InferenceProviderOpenAI { return "data:" + bc.MIMEType + ";base64," + base64Encoded } return base64Encoded diff --git a/internal/tui/components/chat/messages/messages.go b/internal/tui/components/chat/messages/messages.go index 2ffa1601f84fcb9028faf67bd94d70920a193864..d5aca88108cad83115cad5bd046c72e146935f78 100644 --- a/internal/tui/components/chat/messages/messages.go +++ b/internal/tui/components/chat/messages/messages.go @@ -8,11 +8,11 @@ import ( "github.com/charmbracelet/bubbles/v2/viewport" tea "github.com/charmbracelet/bubbletea/v2" + "github.com/charmbracelet/catwalk/pkg/catwalk" "github.com/charmbracelet/lipgloss/v2" "github.com/charmbracelet/x/ansi" "github.com/charmbracelet/crush/internal/config" - "github.com/charmbracelet/crush/internal/fur/provider" "github.com/charmbracelet/crush/internal/message" "github.com/charmbracelet/crush/internal/tui/components/anim" "github.com/charmbracelet/crush/internal/tui/components/core" @@ -369,11 +369,11 @@ func (m *assistantSectionModel) View() string { model := config.Get().GetModel(m.message.Provider, m.message.Model) if model == nil { // This means the model is not configured anymore - model = &provider.Model{ - Model: "Unknown Model", + model = &catwalk.Model{ + Name: "Unknown Model", } } - modelFormatted := t.S().Muted.Render(model.Model) + modelFormatted := t.S().Muted.Render(model.Name) assistant := fmt.Sprintf("%s %s %s", icon, modelFormatted, infoMsg) return t.S().Base.PaddingLeft(2).Render( core.Section(assistant, m.width-2), diff --git a/internal/tui/components/chat/sidebar/sidebar.go b/internal/tui/components/chat/sidebar/sidebar.go index 3d9e572b5192354bd97fd6274c482057646ad41c..1aa239bdc15cec6898a4cba1e4dc7a867b5e4ce0 100644 --- a/internal/tui/components/chat/sidebar/sidebar.go +++ b/internal/tui/components/chat/sidebar/sidebar.go @@ -9,10 +9,10 @@ import ( "sync" tea "github.com/charmbracelet/bubbletea/v2" + "github.com/charmbracelet/catwalk/pkg/catwalk" "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/diff" "github.com/charmbracelet/crush/internal/fsext" - "github.com/charmbracelet/crush/internal/fur/provider" "github.com/charmbracelet/crush/internal/history" "github.com/charmbracelet/crush/internal/lsp" "github.com/charmbracelet/crush/internal/lsp/protocol" @@ -897,7 +897,7 @@ func (s *sidebarCmp) currentModelBlock() string { t := styles.CurrentTheme() modelIcon := t.S().Base.Foreground(t.FgSubtle).Render(styles.ModelIcon) - modelName := t.S().Text.Render(model.Model) + modelName := t.S().Text.Render(model.Name) modelInfo := fmt.Sprintf("%s %s", modelIcon, modelName) parts := []string{ modelInfo, @@ -905,14 +905,14 @@ func (s *sidebarCmp) currentModelBlock() string { if model.CanReason { reasoningInfoStyle := t.S().Subtle.PaddingLeft(2) switch modelProvider.Type { - case provider.TypeOpenAI: + case catwalk.TypeOpenAI: reasoningEffort := model.DefaultReasoningEffort if selectedModel.ReasoningEffort != "" { reasoningEffort = selectedModel.ReasoningEffort } formatter := cases.Title(language.English, cases.NoLower) parts = append(parts, reasoningInfoStyle.Render(formatter.String(fmt.Sprintf("Reasoning %s", reasoningEffort)))) - case provider.TypeAnthropic: + case catwalk.TypeAnthropic: formatter := cases.Title(language.English, cases.NoLower) if selectedModel.Think { parts = append(parts, reasoningInfoStyle.Render(formatter.String("Thinking on"))) diff --git a/internal/tui/components/chat/splash/splash.go b/internal/tui/components/chat/splash/splash.go index f7a6dce4baa2c3a2798c30baa6b995f6da72d05b..a7cf0b27de678dc63d1f9058a5d9d6bd3957d2ae 100644 --- a/internal/tui/components/chat/splash/splash.go +++ b/internal/tui/components/chat/splash/splash.go @@ -10,8 +10,8 @@ import ( "github.com/charmbracelet/bubbles/v2/key" "github.com/charmbracelet/bubbles/v2/spinner" tea "github.com/charmbracelet/bubbletea/v2" + "github.com/charmbracelet/catwalk/pkg/catwalk" "github.com/charmbracelet/crush/internal/config" - "github.com/charmbracelet/crush/internal/fur/provider" "github.com/charmbracelet/crush/internal/llm/prompt" "github.com/charmbracelet/crush/internal/tui/components/chat" "github.com/charmbracelet/crush/internal/tui/components/completions" @@ -109,7 +109,7 @@ func (s *splashCmp) SetOnboarding(onboarding bool) { if err != nil { return } - filteredProviders := []provider.Provider{} + filteredProviders := []catwalk.Provider{} simpleProviders := []string{ "anthropic", "openai", @@ -407,7 +407,7 @@ func (s *splashCmp) setPreferredModel(selectedItem models.ModelOption) tea.Cmd { return nil } -func (s *splashCmp) getProvider(providerID provider.InferenceProvider) (*provider.Provider, error) { +func (s *splashCmp) getProvider(providerID catwalk.InferenceProvider) (*catwalk.Provider, error) { providers, err := config.Providers() if err != nil { return nil, err diff --git a/internal/tui/components/dialogs/commands/commands.go b/internal/tui/components/dialogs/commands/commands.go index a14138ff51ecf8164cf0fc595c758b0247aa3277..c1b96f0bac7d0b665aad77794392b7417d60457a 100644 --- a/internal/tui/components/dialogs/commands/commands.go +++ b/internal/tui/components/dialogs/commands/commands.go @@ -4,10 +4,10 @@ import ( "github.com/charmbracelet/bubbles/v2/help" "github.com/charmbracelet/bubbles/v2/key" tea "github.com/charmbracelet/bubbletea/v2" + "github.com/charmbracelet/catwalk/pkg/catwalk" "github.com/charmbracelet/lipgloss/v2" "github.com/charmbracelet/crush/internal/config" - "github.com/charmbracelet/crush/internal/fur/provider" "github.com/charmbracelet/crush/internal/llm/prompt" "github.com/charmbracelet/crush/internal/tui/components/chat" "github.com/charmbracelet/crush/internal/tui/components/completions" @@ -270,7 +270,7 @@ func (c *commandDialogCmp) defaultCommands() []Command { providerCfg := cfg.GetProviderForModel(agentCfg.Model) model := cfg.GetModelByType(agentCfg.Model) if providerCfg != nil && model != nil && - providerCfg.Type == provider.TypeAnthropic && model.CanReason { + providerCfg.Type == catwalk.TypeAnthropic && model.CanReason { selectedModel := cfg.Models[agentCfg.Model] status := "Enable" if selectedModel.Think { diff --git a/internal/tui/components/dialogs/models/list.go b/internal/tui/components/dialogs/models/list.go index 86b1b9a3fa0b4b6faa56a927a9011673aa8365af..13051067413379b7b80968ca4d8eec4bc354d893 100644 --- a/internal/tui/components/dialogs/models/list.go +++ b/internal/tui/components/dialogs/models/list.go @@ -5,8 +5,8 @@ import ( "slices" tea "github.com/charmbracelet/bubbletea/v2" + "github.com/charmbracelet/catwalk/pkg/catwalk" "github.com/charmbracelet/crush/internal/config" - "github.com/charmbracelet/crush/internal/fur/provider" "github.com/charmbracelet/crush/internal/tui/components/completions" "github.com/charmbracelet/crush/internal/tui/components/core/list" "github.com/charmbracelet/crush/internal/tui/components/dialogs/commands" @@ -18,7 +18,7 @@ import ( type ModelListComponent struct { list list.ListModel modelType int - providers []provider.Provider + providers []catwalk.Provider } func NewModelListComponent(keyMap list.KeyMap, inputStyle lipgloss.Style, inputPlaceholder string) *ModelListComponent { @@ -109,19 +109,19 @@ func (m *ModelListComponent) SetModelType(modelType int) tea.Cmd { } // Check if this provider is not in the known providers list - if !slices.ContainsFunc(knownProviders, func(p provider.Provider) bool { return p.ID == provider.InferenceProvider(providerID) }) { + if !slices.ContainsFunc(knownProviders, func(p catwalk.Provider) bool { return p.ID == catwalk.InferenceProvider(providerID) }) { // Convert config provider to provider.Provider format - configProvider := provider.Provider{ + configProvider := catwalk.Provider{ Name: providerConfig.Name, - ID: provider.InferenceProvider(providerID), - Models: make([]provider.Model, len(providerConfig.Models)), + ID: catwalk.InferenceProvider(providerID), + Models: make([]catwalk.Model, len(providerConfig.Models)), } // Convert models for i, model := range providerConfig.Models { - configProvider.Models[i] = provider.Model{ + configProvider.Models[i] = catwalk.Model{ ID: model.ID, - Model: model.Model, + Name: model.Name, CostPer1MIn: model.CostPer1MIn, CostPer1MOut: model.CostPer1MOut, CostPer1MInCached: model.CostPer1MInCached, @@ -144,7 +144,7 @@ func (m *ModelListComponent) SetModelType(modelType int) tea.Cmd { section.SetInfo(configured) modelItems = append(modelItems, section) for _, model := range configProvider.Models { - modelItems = append(modelItems, completions.NewCompletionItem(model.Model, ModelOption{ + modelItems = append(modelItems, completions.NewCompletionItem(model.Name, ModelOption{ Provider: configProvider, Model: model, })) @@ -179,7 +179,7 @@ func (m *ModelListComponent) SetModelType(modelType int) tea.Cmd { } modelItems = append(modelItems, section) for _, model := range provider.Models { - modelItems = append(modelItems, completions.NewCompletionItem(model.Model, ModelOption{ + modelItems = append(modelItems, completions.NewCompletionItem(model.Name, ModelOption{ Provider: provider, Model: model, })) @@ -201,6 +201,6 @@ func (m *ModelListComponent) SetInputPlaceholder(placeholder string) { m.list.SetFilterPlaceholder(placeholder) } -func (m *ModelListComponent) SetProviders(providers []provider.Provider) { +func (m *ModelListComponent) SetProviders(providers []catwalk.Provider) { m.providers = providers } diff --git a/internal/tui/components/dialogs/models/models.go b/internal/tui/components/dialogs/models/models.go index b28efc6010582a503c34e87ad101832925d8acca..eb0ed9eebcb5ebce41eff33ab09f7c0e5b995bde 100644 --- a/internal/tui/components/dialogs/models/models.go +++ b/internal/tui/components/dialogs/models/models.go @@ -8,8 +8,8 @@ import ( "github.com/charmbracelet/bubbles/v2/key" "github.com/charmbracelet/bubbles/v2/spinner" tea "github.com/charmbracelet/bubbletea/v2" + "github.com/charmbracelet/catwalk/pkg/catwalk" "github.com/charmbracelet/crush/internal/config" - "github.com/charmbracelet/crush/internal/fur/provider" "github.com/charmbracelet/crush/internal/tui/components/completions" "github.com/charmbracelet/crush/internal/tui/components/core" "github.com/charmbracelet/crush/internal/tui/components/core/list" @@ -48,8 +48,8 @@ type ModelDialog interface { } type ModelOption struct { - Provider provider.Provider - Model provider.Model + Provider catwalk.Provider + Model catwalk.Model } type modelDialogCmp struct { @@ -363,7 +363,7 @@ func (m *modelDialogCmp) isProviderConfigured(providerID string) bool { return false } -func (m *modelDialogCmp) getProvider(providerID provider.InferenceProvider) (*provider.Provider, error) { +func (m *modelDialogCmp) getProvider(providerID catwalk.InferenceProvider) (*catwalk.Provider, error) { providers, err := config.Providers() if err != nil { return nil, err diff --git a/internal/tui/page/chat/chat.go b/internal/tui/page/chat/chat.go index 0d28f13f3ca0a42c9ae15612f21678cdeb8f4bf2..9deac1e9e48c1cff576e84746d3976b4b670a700 100644 --- a/internal/tui/page/chat/chat.go +++ b/internal/tui/page/chat/chat.go @@ -279,7 +279,7 @@ func (p *chatPage) Update(msg tea.Msg) (tea.Model, tea.Cmd) { if model.SupportsImages { return p, util.CmdHandler(OpenFilePickerMsg{}) } else { - return p, util.ReportWarn("File attachments are not supported by the current model: " + model.Model) + return p, util.ReportWarn("File attachments are not supported by the current model: " + model.Name) } case key.Matches(msg, p.keyMap.Tab): if p.session.ID == "" { From 9d77028fc90b4cf441654e54a5b48ffe916c3c96 Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Thu, 24 Jul 2025 09:39:58 +0200 Subject: [PATCH 3/3] chore: update readme --- README.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index b21dae25fb7f32169e0fdd9528b3ec06f5c739f0..8c0900d173783918d49b3f5c33ce08bd54f523dd 100644 --- a/README.md +++ b/README.md @@ -43,7 +43,7 @@ rm -rf ./crush Then, run Crush by typing `crush`. -*** +--- @@ -108,7 +108,7 @@ Crush supports Model Context Protocol (MCP) servers through three transport type "mcp": { "filesystem": { "type": "stdio", - "command": "node", + "command": "node", "args": ["/path/to/mcp-server.js"], "env": { "NODE_ENV": "production" @@ -143,7 +143,7 @@ crush -d # View last 1000 lines crush logs -# Follow logs in real-time +# Follow logs in real-time crush logs -f # Show last 500 lines @@ -174,7 +174,7 @@ Crush supports all OpenAI-compatible APIs. Here's an example configuration for D "models": [ { "id": "deepseek-chat", - "model": "Deepseek V3", + "name": "Deepseek V3", "cost_per_1m_in": 0.27, "cost_per_1m_out": 1.1, "cost_per_1m_in_cached": 0.07,