From eab7e3d15bd7f909c516fdc847da2dc0480779e7 Mon Sep 17 00:00:00 2001
From: Kujtim Hoxha
Date: Mon, 28 Jul 2025 17:33:20 +0200
Subject: [PATCH] chore: remove global config

---
 {internal/cmd => cmd}/logs.go                 |   0
 {internal/cmd => cmd}/root.go                 |   2 +-
 internal/app/app.go                           |  33 +-
 internal/config/config.go                     | 235 +---------
 internal/config/init.go                       |  31 +-
 internal/config/load.go                       |  28 +-
 internal/config/load_test.go                  | 119 ++---
 internal/csync/slices.go                      |  26 ++
 internal/llm/agent/agent.go                   | 418 ++++++++----------
 internal/llm/agent/coder.go                   |  52 +++
 internal/llm/agent/{mcp-tools.go => mcp.go}   |  97 +++-
 internal/llm/agent/tools.go                   |  56 +++
 internal/llm/prompt/coder.go                  |  27 +-
 internal/llm/prompt/prompt.go                 |  46 +-
 internal/llm/prompt/task.go                   |   4 +-
 internal/llm/provider/anthropic.go            | 109 ++---
 internal/llm/provider/azure.go                |  22 +-
 internal/llm/provider/bedrock.go              |  86 ++--
 internal/llm/provider/gemini.go               | 100 ++---
 internal/llm/provider/openai.go               |  97 ++--
 internal/llm/provider/openai_test.go          |  90 ----
 internal/llm/provider/provider.go             | 268 +++++++----
 internal/llm/provider/vertexai.go             |  14 +-
 internal/llm/tools/tools.go                   |  51 +++
 internal/resolver/resolver.go                 | 188 ++++++++
 internal/resolver/resolver_test.go            | 332 ++++++++++++++
 internal/tui/components/chat/chat.go          |   3 +-
 internal/tui/components/chat/header/header.go |  18 +-
 .../tui/components/chat/messages/messages.go  |   6 +-
 .../tui/components/chat/sidebar/sidebar.go    |  48 +-
 internal/tui/components/chat/splash/splash.go |  60 +--
 .../components/dialogs/commands/commands.go   |  17 +-
 .../tui/components/dialogs/commands/loader.go |   3 +-
 .../tui/components/dialogs/models/list.go     |  18 +-
 .../tui/components/dialogs/models/models.go   |  32 +-
 internal/tui/page/chat/chat.go                |  26 +-
 internal/tui/tui.go                           |  12 +-
 main.go                                       |   2 +-
 38 files changed, 1573 insertions(+), 1203 deletions(-)
 rename {internal/cmd => cmd}/logs.go (100%)
 rename {internal/cmd => cmd}/root.go (99%)
 create mode 100644 internal/llm/agent/coder.go
 rename internal/llm/agent/{mcp-tools.go => mcp.go} (70%)
 create mode 100644 internal/llm/agent/tools.go
 delete mode 100644 internal/llm/provider/openai_test.go
 create mode 100644 internal/resolver/resolver.go
 create mode 100644 internal/resolver/resolver_test.go

diff --git a/internal/cmd/logs.go b/cmd/logs.go
similarity index 100%
rename from internal/cmd/logs.go
rename to cmd/logs.go
diff --git a/internal/cmd/root.go b/cmd/root.go
similarity index 99%
rename from internal/cmd/root.go
rename to cmd/root.go
index c6c24d5963c57981b1e91911146c1893728ffe37..b2a09fc3abd868734877601a4ae7ea71f34bdfae 100644
--- a/internal/cmd/root.go
+++ b/cmd/root.go
@@ -69,7 +69,7 @@ to assist developers in writing, debugging, and understanding code directly from
 		cwd = c
 	}
 
-	cfg, err := config.Init(cwd, debug)
+	cfg, err := config.Load(cwd, debug)
 	if err != nil {
 		return err
 	}
diff --git a/internal/app/app.go b/internal/app/app.go
index f3362c7276389b6669d6c9977d3565f482a44062..7af22cdce94dc9e9d12ff022dbdbed2d9e860be0 100644
--- a/internal/app/app.go
+++ b/internal/app/app.go
@@ -17,12 +17,12 @@ import (
 	"github.com/charmbracelet/crush/internal/format"
 	"github.com/charmbracelet/crush/internal/history"
 	"github.com/charmbracelet/crush/internal/llm/agent"
+	"github.com/charmbracelet/crush/internal/llm/provider"
 	"github.com/charmbracelet/crush/internal/log"
-	"github.com/charmbracelet/crush/internal/pubsub"
-
 	"github.com/charmbracelet/crush/internal/lsp"
 	"github.com/charmbracelet/crush/internal/message"
 	"github.com/charmbracelet/crush/internal/permission"
+
"github.com/charmbracelet/crush/internal/pubsub" "github.com/charmbracelet/crush/internal/session" ) @@ -196,7 +196,9 @@ func (app *App) RunNonInteractive(ctx context.Context, prompt string, quiet bool } func (app *App) UpdateAgentModel() error { - return app.CoderAgent.UpdateModel() + small := app.config.Models[config.SelectedModelTypeSmall] + large := app.config.Models[config.SelectedModelTypeLarge] + return app.CoderAgent.UpdateModels(small, large) } func (app *App) setupEvents() { @@ -250,23 +252,32 @@ func setupSubscriber[T any]( } func (app *App) InitCoderAgent() error { - coderAgentCfg := app.config.Agents["coder"] - if coderAgentCfg.ID == "" { - return fmt.Errorf("coder agent configuration is missing") - } var err error - app.CoderAgent, err = agent.NewAgent( - coderAgentCfg, - app.Permissions, + providers := map[string]provider.Config{} + maps.Insert(providers, app.config.Providers.Seq2()) + app.CoderAgent, err = agent.NewCoderAgent( + app.globalCtx, + app.config.WorkingDir(), + providers, + app.config.Models[config.SelectedModelTypeSmall], + app.config.Models[config.SelectedModelTypeLarge], + app.config.Options.ContextPaths, app.Sessions, app.Messages, - app.History, + app.Permissions, app.LSPClients, + app.History, + app.config.MCP, ) if err != nil { slog.Error("Failed to create coder agent", "err", err) return err } + err = app.CoderAgent.WithAgentTool() + if err != nil { + slog.Error("Failed to create agent tool", "err", err) + return err + } setupSubscriber(app.eventsCtx, app.serviceEventsWG, "coderAgent", app.CoderAgent.Subscribe, app.events) return nil } diff --git a/internal/config/config.go b/internal/config/config.go index 0f9fc99b5ce7677b0009933c447c0f7959825501..729266d5e0f2042cef82196abc8c0ff27a9da3eb 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -1,18 +1,16 @@ package config import ( - "context" "fmt" - "log/slog" - "net/http" "os" "slices" "strings" - "time" "github.com/charmbracelet/catwalk/pkg/catwalk" "github.com/charmbracelet/crush/internal/csync" - "github.com/charmbracelet/crush/internal/env" + "github.com/charmbracelet/crush/internal/llm/agent" + "github.com/charmbracelet/crush/internal/llm/provider" + "github.com/charmbracelet/crush/internal/resolver" "github.com/tidwall/sjson" ) @@ -45,73 +43,6 @@ const ( SelectedModelTypeSmall SelectedModelType = "small" ) -type SelectedModel struct { - // The model id as used by the provider API. - // Required. - Model string `json:"model"` - // The model provider, same as the key/id used in the providers config. - // Required. - Provider string `json:"provider"` - - // Only used by models that use the openai provider and need this set. - ReasoningEffort string `json:"reasoning_effort,omitempty"` - - // Overrides the default model configuration. - MaxTokens int64 `json:"max_tokens,omitempty"` - - // Used by anthropic models that can reason to indicate if the model should think. - Think bool `json:"think,omitempty"` -} - -type ProviderConfig struct { - // The provider's id. - ID string `json:"id,omitempty"` - // The provider's name, used for display purposes. - Name string `json:"name,omitempty"` - // The provider's API endpoint. - BaseURL string `json:"base_url,omitempty"` - // The provider type, e.g. "openai", "anthropic", etc. if empty it defaults to openai. - Type catwalk.Type `json:"type,omitempty"` - // The provider's API key. - APIKey string `json:"api_key,omitempty"` - // Marks the provider as disabled. - Disable bool `json:"disable,omitempty"` - - // Custom system prompt prefix. 
- SystemPromptPrefix string `json:"system_prompt_prefix,omitempty"` - - // Extra headers to send with each request to the provider. - ExtraHeaders map[string]string `json:"extra_headers,omitempty"` - // Extra body - ExtraBody map[string]any `json:"extra_body,omitempty"` - - // Used to pass extra parameters to the provider. - ExtraParams map[string]string `json:"-"` - - // The provider models - Models []catwalk.Model `json:"models,omitempty"` -} - -type MCPType string - -const ( - MCPStdio MCPType = "stdio" - MCPSse MCPType = "sse" - MCPHttp MCPType = "http" -) - -type MCPConfig struct { - Command string `json:"command,omitempty" ` - Env map[string]string `json:"env,omitempty"` - Args []string `json:"args,omitempty"` - Type MCPType `json:"type"` - URL string `json:"url,omitempty"` - Disabled bool `json:"disabled,omitempty"` - - // TODO: maybe make it possible to get the value from the env - Headers map[string]string `json:"headers,omitempty"` -} - type LSPConfig struct { Disabled bool `json:"enabled,omitempty"` Command string `json:"command"` @@ -138,11 +69,11 @@ type Options struct { DataDirectory string `json:"data_directory,omitempty"` // Relative to the cwd } -type MCPs map[string]MCPConfig +type MCPs map[string]agent.MCPConfig type MCP struct { - Name string `json:"name"` - MCP MCPConfig `json:"mcp"` + Name string `json:"name"` + MCP agent.MCPConfig `json:"mcp"` } func (m MCPs) Sorted() []MCP { @@ -180,71 +111,13 @@ func (l LSPs) Sorted() []LSP { return sorted } -func (m MCPConfig) ResolvedEnv() []string { - resolver := NewShellVariableResolver(env.New()) - for e, v := range m.Env { - var err error - m.Env[e], err = resolver.ResolveValue(v) - if err != nil { - slog.Error("error resolving environment variable", "error", err, "variable", e, "value", v) - continue - } - } - - env := make([]string, 0, len(m.Env)) - for k, v := range m.Env { - env = append(env, fmt.Sprintf("%s=%s", k, v)) - } - return env -} - -func (m MCPConfig) ResolvedHeaders() map[string]string { - resolver := NewShellVariableResolver(env.New()) - for e, v := range m.Headers { - var err error - m.Headers[e], err = resolver.ResolveValue(v) - if err != nil { - slog.Error("error resolving header variable", "error", err, "variable", e, "value", v) - continue - } - } - return m.Headers -} - -type Agent struct { - ID string `json:"id,omitempty"` - Name string `json:"name,omitempty"` - Description string `json:"description,omitempty"` - // This is the id of the system prompt used by the agent - Disabled bool `json:"disabled,omitempty"` - - Model SelectedModelType `json:"model"` - - // The available tools for the agent - // if this is nil, all tools are available - AllowedTools []string `json:"allowed_tools,omitempty"` - - // this tells us which MCPs are available for this agent - // if this is empty all mcps are available - // the string array is the list of tools from the AllowedMCP the agent has available - // if the string array is nil, all tools from the AllowedMCP are available - AllowedMCP map[string][]string `json:"allowed_mcp,omitempty"` - - // The list of LSPs that this agent can use - // if this is nil, all LSPs are available - AllowedLSP []string `json:"allowed_lsp,omitempty"` - - // Overrides the context paths for this agent - ContextPaths []string `json:"context_paths,omitempty"` -} - // Config holds the configuration for crush. type Config struct { // We currently only support large/small as values here. 
- Models map[SelectedModelType]SelectedModel `json:"models,omitempty"` + Models map[SelectedModelType]agent.Model `json:"models,omitempty"` // The providers that are configured - Providers *csync.Map[string, ProviderConfig] `json:"providers,omitempty"` + Providers *csync.Map[string, provider.Config] `json:"providers,omitempty"` MCP MCPs `json:"mcp,omitempty"` @@ -256,10 +129,8 @@ type Config struct { // Internal workingDir string `json:"-"` - // TODO: most likely remove this concept when I come back to it - Agents map[string]Agent `json:"-"` // TODO: find a better way to do this this should probably not be part of the config - resolver VariableResolver + resolver resolver.Resolver dataConfigDir string `json:"-"` knownProviders []catwalk.Provider `json:"-"` } @@ -268,8 +139,8 @@ func (c *Config) WorkingDir() string { return c.workingDir } -func (c *Config) EnabledProviders() []ProviderConfig { - var enabled []ProviderConfig +func (c *Config) EnabledProviders() []provider.Config { + var enabled []provider.Config for p := range c.Providers.Seq() { if !p.Disable { enabled = append(enabled, p) @@ -294,7 +165,7 @@ func (c *Config) GetModel(provider, model string) *catwalk.Model { return nil } -func (c *Config) GetProviderForModel(modelType SelectedModelType) *ProviderConfig { +func (c *Config) GetProviderForModel(modelType SelectedModelType) *provider.Config { model, ok := c.Models[modelType] if !ok { return nil @@ -344,7 +215,7 @@ func (c *Config) Resolve(key string) (string, error) { return c.resolver.ResolveValue(key) } -func (c *Config) UpdatePreferredModel(modelType SelectedModelType, model SelectedModel) error { +func (c *Config) UpdatePreferredModel(modelType SelectedModelType, model agent.Model) error { c.Models[modelType] = model if err := c.SetConfigField(fmt.Sprintf("models.%s", modelType), model); err != nil { return fmt.Errorf("failed to update preferred model: %w", err) @@ -397,7 +268,7 @@ func (c *Config) SetProviderAPIKey(providerID, apiKey string) error { if foundProvider != nil { // Create new provider config based on known provider - providerConfig = ProviderConfig{ + providerConfig = provider.Config{ ID: providerID, Name: foundProvider.Name, BaseURL: foundProvider.APIEndpoint, @@ -416,82 +287,6 @@ func (c *Config) SetProviderAPIKey(providerID, apiKey string) error { return nil } -func (c *Config) SetupAgents() { - agents := map[string]Agent{ - "coder": { - ID: "coder", - Name: "Coder", - Description: "An agent that helps with executing coding tasks.", - Model: SelectedModelTypeLarge, - ContextPaths: c.Options.ContextPaths, - // All tools allowed - }, - "task": { - ID: "task", - Name: "Task", - Description: "An agent that helps with searching for context and finding implementation details.", - Model: SelectedModelTypeLarge, - ContextPaths: c.Options.ContextPaths, - AllowedTools: []string{ - "glob", - "grep", - "ls", - "sourcegraph", - "view", - }, - // NO MCPs or LSPs by default - AllowedMCP: map[string][]string{}, - AllowedLSP: []string{}, - }, - } - c.Agents = agents -} - -func (c *Config) Resolver() VariableResolver { +func (c *Config) Resolver() resolver.Resolver { return c.resolver } - -func (c *ProviderConfig) TestConnection(resolver VariableResolver) error { - testURL := "" - headers := make(map[string]string) - apiKey, _ := resolver.ResolveValue(c.APIKey) - switch c.Type { - case catwalk.TypeOpenAI: - baseURL, _ := resolver.ResolveValue(c.BaseURL) - if baseURL == "" { - baseURL = "https://api.openai.com/v1" - } - testURL = baseURL + "/models" - headers["Authorization"] 
= "Bearer " + apiKey - case catwalk.TypeAnthropic: - baseURL, _ := resolver.ResolveValue(c.BaseURL) - if baseURL == "" { - baseURL = "https://api.anthropic.com/v1" - } - testURL = baseURL + "/models" - headers["x-api-key"] = apiKey - headers["anthropic-version"] = "2023-06-01" - } - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - client := &http.Client{} - req, err := http.NewRequestWithContext(ctx, "GET", testURL, nil) - if err != nil { - return fmt.Errorf("failed to create request for provider %s: %w", c.ID, err) - } - for k, v := range headers { - req.Header.Set(k, v) - } - for k, v := range c.ExtraHeaders { - req.Header.Set(k, v) - } - b, err := client.Do(req) - if err != nil { - return fmt.Errorf("failed to create request for provider %s: %w", c.ID, err) - } - if b.StatusCode != http.StatusOK { - return fmt.Errorf("failed to connect to provider %s: %s", c.ID, b.Status) - } - _ = b.Body.Close() - return nil -} diff --git a/internal/config/init.go b/internal/config/init.go index ff44d43bb878f579d003c84537fcd970f9e52f9e..b887e062e5f652fb1ffc59f327faef4756bd47aa 100644 --- a/internal/config/init.go +++ b/internal/config/init.go @@ -5,7 +5,6 @@ import ( "os" "path/filepath" "strings" - "sync/atomic" ) const ( @@ -16,25 +15,7 @@ type ProjectInitFlag struct { Initialized bool `json:"initialized"` } -// TODO: we need to remove the global config instance keeping it now just until everything is migrated -var instance atomic.Pointer[Config] - -func Init(workingDir string, debug bool) (*Config, error) { - cfg, err := Load(workingDir, debug) - if err != nil { - return nil, err - } - instance.Store(cfg) - return instance.Load(), nil -} - -func Get() *Config { - cfg := instance.Load() - return cfg -} - -func ProjectNeedsInitialization() (bool, error) { - cfg := Get() +func ProjectNeedsInitialization(cfg *Config) (bool, error) { if cfg == nil { return false, fmt.Errorf("config not loaded") } @@ -81,8 +62,7 @@ func crushMdExists(dir string) (bool, error) { return false, nil } -func MarkProjectInitialized() error { - cfg := Get() +func MarkProjectInitialized(cfg *Config) error { if cfg == nil { return fmt.Errorf("config not loaded") } @@ -97,10 +77,13 @@ func MarkProjectInitialized() error { return nil } -func HasInitialDataConfig() bool { +func HasInitialDataConfig(cfg *Config) bool { + if cfg == nil { + return false + } cfgPath := GlobalConfigData() if _, err := os.Stat(cfgPath); err != nil { return false } - return Get().IsConfigured() + return cfg.IsConfigured() } diff --git a/internal/config/load.go b/internal/config/load.go index 77f53356b1e529cb5592366e1f2f3a8d757a315f..3ed199fd31713ac1f8b1eaf51dfa348621c78bb1 100644 --- a/internal/config/load.go +++ b/internal/config/load.go @@ -14,7 +14,10 @@ import ( "github.com/charmbracelet/catwalk/pkg/catwalk" "github.com/charmbracelet/crush/internal/csync" "github.com/charmbracelet/crush/internal/env" + "github.com/charmbracelet/crush/internal/llm/agent" + "github.com/charmbracelet/crush/internal/llm/provider" "github.com/charmbracelet/crush/internal/log" + "github.com/charmbracelet/crush/internal/resolver" ) const defaultCatwalkURL = "https://catwalk.charm.sh" @@ -71,7 +74,7 @@ func Load(workingDir string, debug bool) (*Config, error) { env := env.New() // Configure providers - valueResolver := NewShellVariableResolver(env) + valueResolver := resolver.NewShellVariableResolver(env) cfg.resolver = valueResolver if err := cfg.configureProviders(env, valueResolver, providers); err != nil { return nil, 
fmt.Errorf("failed to configure providers: %w", err) @@ -85,11 +88,10 @@ func Load(workingDir string, debug bool) (*Config, error) { if err := cfg.configureSelectedModels(providers); err != nil { return nil, fmt.Errorf("failed to configure selected models: %w", err) } - cfg.SetupAgents() return cfg, nil } -func (c *Config) configureProviders(env env.Env, resolver VariableResolver, knownProviders []catwalk.Provider) error { +func (c *Config) configureProviders(env env.Env, resolver resolver.Resolver, knownProviders []catwalk.Provider) error { knownProviderNames := make(map[string]bool) for _, p := range knownProviders { knownProviderNames[string(p.ID)] = true @@ -135,7 +137,7 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know p.Models = models } } - prepared := ProviderConfig{ + prepared := provider.Config{ ID: string(p.ID), Name: p.Name, BaseURL: p.APIEndpoint, @@ -269,13 +271,13 @@ func (c *Config) setDefaults(workingDir string) { c.Options.DataDirectory = filepath.Join(workingDir, defaultDataDirectory) } if c.Providers == nil { - c.Providers = csync.NewMap[string, ProviderConfig]() + c.Providers = csync.NewMap[string, provider.Config]() } if c.Models == nil { - c.Models = make(map[SelectedModelType]SelectedModel) + c.Models = make(map[SelectedModelType]agent.Model) } if c.MCP == nil { - c.MCP = make(map[string]MCPConfig) + c.MCP = make(map[string]agent.MCPConfig) } if c.LSP == nil { c.LSP = make(map[string]LSPConfig) @@ -287,7 +289,7 @@ func (c *Config) setDefaults(workingDir string) { c.Options.ContextPaths = slices.Compact(c.Options.ContextPaths) } -func (c *Config) defaultModelSelection(knownProviders []catwalk.Provider) (largeModel SelectedModel, smallModel SelectedModel, err error) { +func (c *Config) defaultModelSelection(knownProviders []catwalk.Provider) (largeModel, smallModel agent.Model, err error) { if len(knownProviders) == 0 && c.Providers.Len() == 0 { err = fmt.Errorf("no providers configured, please configure at least one provider") return @@ -305,7 +307,7 @@ func (c *Config) defaultModelSelection(knownProviders []catwalk.Provider) (large err = fmt.Errorf("default large model %s not found for provider %s", p.DefaultLargeModelID, p.ID) return } - largeModel = SelectedModel{ + largeModel = agent.Model{ Provider: string(p.ID), Model: defaultLargeModel.ID, MaxTokens: defaultLargeModel.DefaultMaxTokens, @@ -317,7 +319,7 @@ func (c *Config) defaultModelSelection(knownProviders []catwalk.Provider) (large err = fmt.Errorf("default small model %s not found for provider %s", p.DefaultSmallModelID, p.ID) return } - smallModel = SelectedModel{ + smallModel = agent.Model{ Provider: string(p.ID), Model: defaultSmallModel.ID, MaxTokens: defaultSmallModel.DefaultMaxTokens, @@ -327,7 +329,7 @@ func (c *Config) defaultModelSelection(knownProviders []catwalk.Provider) (large } enabledProviders := c.EnabledProviders() - slices.SortFunc(enabledProviders, func(a, b ProviderConfig) int { + slices.SortFunc(enabledProviders, func(a, b provider.Config) int { return strings.Compare(a.ID, b.ID) }) @@ -342,13 +344,13 @@ func (c *Config) defaultModelSelection(knownProviders []catwalk.Provider) (large return } defaultLargeModel := c.GetModel(providerConfig.ID, providerConfig.Models[0].ID) - largeModel = SelectedModel{ + largeModel = agent.Model{ Provider: providerConfig.ID, Model: defaultLargeModel.ID, MaxTokens: defaultLargeModel.DefaultMaxTokens, } defaultSmallModel := c.GetModel(providerConfig.ID, providerConfig.Models[0].ID) - smallModel = SelectedModel{ + 
smallModel = agent.Model{ Provider: providerConfig.ID, Model: defaultSmallModel.ID, MaxTokens: defaultSmallModel.DefaultMaxTokens, diff --git a/internal/config/load_test.go b/internal/config/load_test.go index 8c2735bd15fb3b52fe0c87401f57534e9b007e5b..4a04c003249542b453fecb524c42355128703e0d 100644 --- a/internal/config/load_test.go +++ b/internal/config/load_test.go @@ -11,6 +11,9 @@ import ( "github.com/charmbracelet/catwalk/pkg/catwalk" "github.com/charmbracelet/crush/internal/csync" "github.com/charmbracelet/crush/internal/env" + "github.com/charmbracelet/crush/internal/llm/agent" + "github.com/charmbracelet/crush/internal/llm/provider" + "github.com/charmbracelet/crush/internal/resolver" "github.com/stretchr/testify/assert" ) @@ -72,7 +75,7 @@ func TestConfig_configureProviders(t *testing.T) { env := env.NewFromMap(map[string]string{ "OPENAI_API_KEY": "test-key", }) - resolver := NewEnvironmentVariableResolver(env) + resolver := resolver.NewEnvironmentVariableResolver(env) err := cfg.configureProviders(env, resolver, knownProviders) assert.NoError(t, err) assert.Equal(t, 1, cfg.Providers.Len()) @@ -95,9 +98,9 @@ func TestConfig_configureProvidersWithOverride(t *testing.T) { } cfg := &Config{ - Providers: csync.NewMap[string, ProviderConfig](), + Providers: csync.NewMap[string, provider.Config](), } - cfg.Providers.Set("openai", ProviderConfig{ + cfg.Providers.Set("openai", provider.Config{ APIKey: "xyz", BaseURL: "https://api.openai.com/v2", Models: []catwalk.Model{ @@ -115,7 +118,7 @@ func TestConfig_configureProvidersWithOverride(t *testing.T) { env := env.NewFromMap(map[string]string{ "OPENAI_API_KEY": "test-key", }) - resolver := NewEnvironmentVariableResolver(env) + resolver := resolver.NewEnvironmentVariableResolver(env) err := cfg.configureProviders(env, resolver, knownProviders) assert.NoError(t, err) assert.Equal(t, 1, cfg.Providers.Len()) @@ -141,7 +144,7 @@ func TestConfig_configureProvidersWithNewProvider(t *testing.T) { } cfg := &Config{ - Providers: csync.NewMapFrom(map[string]ProviderConfig{ + Providers: csync.NewMapFrom(map[string]provider.Config{ "custom": { APIKey: "xyz", BaseURL: "https://api.someendpoint.com/v2", @@ -157,7 +160,7 @@ func TestConfig_configureProvidersWithNewProvider(t *testing.T) { env := env.NewFromMap(map[string]string{ "OPENAI_API_KEY": "test-key", }) - resolver := NewEnvironmentVariableResolver(env) + resolver := resolver.NewEnvironmentVariableResolver(env) err := cfg.configureProviders(env, resolver, knownProviders) assert.NoError(t, err) // Should be to because of the env variable @@ -193,7 +196,7 @@ func TestConfig_configureProvidersBedrockWithCredentials(t *testing.T) { "AWS_ACCESS_KEY_ID": "test-key-id", "AWS_SECRET_ACCESS_KEY": "test-secret-key", }) - resolver := NewEnvironmentVariableResolver(env) + resolver := resolver.NewEnvironmentVariableResolver(env) err := cfg.configureProviders(env, resolver, knownProviders) assert.NoError(t, err) assert.Equal(t, cfg.Providers.Len(), 1) @@ -219,7 +222,7 @@ func TestConfig_configureProvidersBedrockWithoutCredentials(t *testing.T) { cfg := &Config{} cfg.setDefaults("/tmp") env := env.NewFromMap(map[string]string{}) - resolver := NewEnvironmentVariableResolver(env) + resolver := resolver.NewEnvironmentVariableResolver(env) err := cfg.configureProviders(env, resolver, knownProviders) assert.NoError(t, err) // Provider should not be configured without credentials @@ -244,7 +247,7 @@ func TestConfig_configureProvidersBedrockWithoutUnsupportedModel(t *testing.T) { "AWS_ACCESS_KEY_ID": "test-key-id", 
"AWS_SECRET_ACCESS_KEY": "test-secret-key", }) - resolver := NewEnvironmentVariableResolver(env) + resolver := resolver.NewEnvironmentVariableResolver(env) err := cfg.configureProviders(env, resolver, knownProviders) assert.Error(t, err) } @@ -268,7 +271,7 @@ func TestConfig_configureProvidersVertexAIWithCredentials(t *testing.T) { "GOOGLE_CLOUD_PROJECT": "test-project", "GOOGLE_CLOUD_LOCATION": "us-central1", }) - resolver := NewEnvironmentVariableResolver(env) + resolver := resolver.NewEnvironmentVariableResolver(env) err := cfg.configureProviders(env, resolver, knownProviders) assert.NoError(t, err) assert.Equal(t, cfg.Providers.Len(), 1) @@ -300,7 +303,7 @@ func TestConfig_configureProvidersVertexAIWithoutCredentials(t *testing.T) { "GOOGLE_CLOUD_PROJECT": "test-project", "GOOGLE_CLOUD_LOCATION": "us-central1", }) - resolver := NewEnvironmentVariableResolver(env) + resolver := resolver.NewEnvironmentVariableResolver(env) err := cfg.configureProviders(env, resolver, knownProviders) assert.NoError(t, err) // Provider should not be configured without proper credentials @@ -325,7 +328,7 @@ func TestConfig_configureProvidersVertexAIMissingProject(t *testing.T) { "GOOGLE_GENAI_USE_VERTEXAI": "true", "GOOGLE_CLOUD_LOCATION": "us-central1", }) - resolver := NewEnvironmentVariableResolver(env) + resolver := resolver.NewEnvironmentVariableResolver(env) err := cfg.configureProviders(env, resolver, knownProviders) assert.NoError(t, err) // Provider should not be configured without project @@ -349,7 +352,7 @@ func TestConfig_configureProvidersSetProviderID(t *testing.T) { env := env.NewFromMap(map[string]string{ "OPENAI_API_KEY": "test-key", }) - resolver := NewEnvironmentVariableResolver(env) + resolver := resolver.NewEnvironmentVariableResolver(env) err := cfg.configureProviders(env, resolver, knownProviders) assert.NoError(t, err) assert.Equal(t, cfg.Providers.Len(), 1) @@ -362,7 +365,7 @@ func TestConfig_configureProvidersSetProviderID(t *testing.T) { func TestConfig_EnabledProviders(t *testing.T) { t.Run("all providers enabled", func(t *testing.T) { cfg := &Config{ - Providers: csync.NewMapFrom(map[string]ProviderConfig{ + Providers: csync.NewMapFrom(map[string]provider.Config{ "openai": { ID: "openai", APIKey: "key1", @@ -382,7 +385,7 @@ func TestConfig_EnabledProviders(t *testing.T) { t.Run("some providers disabled", func(t *testing.T) { cfg := &Config{ - Providers: csync.NewMapFrom(map[string]ProviderConfig{ + Providers: csync.NewMapFrom(map[string]provider.Config{ "openai": { ID: "openai", APIKey: "key1", @@ -403,7 +406,7 @@ func TestConfig_EnabledProviders(t *testing.T) { t.Run("empty providers map", func(t *testing.T) { cfg := &Config{ - Providers: csync.NewMap[string, ProviderConfig](), + Providers: csync.NewMap[string, provider.Config](), } enabled := cfg.EnabledProviders() @@ -414,7 +417,7 @@ func TestConfig_EnabledProviders(t *testing.T) { func TestConfig_IsConfigured(t *testing.T) { t.Run("returns true when at least one provider is enabled", func(t *testing.T) { cfg := &Config{ - Providers: csync.NewMapFrom(map[string]ProviderConfig{ + Providers: csync.NewMapFrom(map[string]provider.Config{ "openai": { ID: "openai", APIKey: "key1", @@ -428,7 +431,7 @@ func TestConfig_IsConfigured(t *testing.T) { t.Run("returns false when no providers are configured", func(t *testing.T) { cfg := &Config{ - Providers: csync.NewMap[string, ProviderConfig](), + Providers: csync.NewMap[string, provider.Config](), } assert.False(t, cfg.IsConfigured()) @@ -436,7 +439,7 @@ func TestConfig_IsConfigured(t 
*testing.T) { t.Run("returns false when all providers are disabled", func(t *testing.T) { cfg := &Config{ - Providers: csync.NewMapFrom(map[string]ProviderConfig{ + Providers: csync.NewMapFrom(map[string]provider.Config{ "openai": { ID: "openai", APIKey: "key1", @@ -467,7 +470,7 @@ func TestConfig_configureProvidersWithDisabledProvider(t *testing.T) { } cfg := &Config{ - Providers: csync.NewMapFrom(map[string]ProviderConfig{ + Providers: csync.NewMapFrom(map[string]provider.Config{ "openai": { Disable: true, }, @@ -478,7 +481,7 @@ func TestConfig_configureProvidersWithDisabledProvider(t *testing.T) { env := env.NewFromMap(map[string]string{ "OPENAI_API_KEY": "test-key", }) - resolver := NewEnvironmentVariableResolver(env) + resolver := resolver.NewEnvironmentVariableResolver(env) err := cfg.configureProviders(env, resolver, knownProviders) assert.NoError(t, err) @@ -491,7 +494,7 @@ func TestConfig_configureProvidersWithDisabledProvider(t *testing.T) { func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { t.Run("custom provider with missing API key is allowed, but not known providers", func(t *testing.T) { cfg := &Config{ - Providers: csync.NewMapFrom(map[string]ProviderConfig{ + Providers: csync.NewMapFrom(map[string]provider.Config{ "custom": { BaseURL: "https://api.custom.com/v1", Models: []catwalk.Model{{ @@ -506,7 +509,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { cfg.setDefaults("/tmp") env := env.NewFromMap(map[string]string{}) - resolver := NewEnvironmentVariableResolver(env) + resolver := resolver.NewEnvironmentVariableResolver(env) err := cfg.configureProviders(env, resolver, []catwalk.Provider{}) assert.NoError(t, err) @@ -517,7 +520,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { t.Run("custom provider with missing BaseURL is removed", func(t *testing.T) { cfg := &Config{ - Providers: csync.NewMapFrom(map[string]ProviderConfig{ + Providers: csync.NewMapFrom(map[string]provider.Config{ "custom": { APIKey: "test-key", Models: []catwalk.Model{{ @@ -529,7 +532,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { cfg.setDefaults("/tmp") env := env.NewFromMap(map[string]string{}) - resolver := NewEnvironmentVariableResolver(env) + resolver := resolver.NewEnvironmentVariableResolver(env) err := cfg.configureProviders(env, resolver, []catwalk.Provider{}) assert.NoError(t, err) @@ -540,7 +543,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { t.Run("custom provider with no models is removed", func(t *testing.T) { cfg := &Config{ - Providers: csync.NewMapFrom(map[string]ProviderConfig{ + Providers: csync.NewMapFrom(map[string]provider.Config{ "custom": { APIKey: "test-key", BaseURL: "https://api.custom.com/v1", @@ -551,7 +554,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { cfg.setDefaults("/tmp") env := env.NewFromMap(map[string]string{}) - resolver := NewEnvironmentVariableResolver(env) + resolver := resolver.NewEnvironmentVariableResolver(env) err := cfg.configureProviders(env, resolver, []catwalk.Provider{}) assert.NoError(t, err) @@ -562,7 +565,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { t.Run("custom provider with unsupported type is removed", func(t *testing.T) { cfg := &Config{ - Providers: csync.NewMapFrom(map[string]ProviderConfig{ + Providers: csync.NewMapFrom(map[string]provider.Config{ "custom": { APIKey: "test-key", BaseURL: "https://api.custom.com/v1", @@ -576,7 
+579,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { cfg.setDefaults("/tmp") env := env.NewFromMap(map[string]string{}) - resolver := NewEnvironmentVariableResolver(env) + resolver := resolver.NewEnvironmentVariableResolver(env) err := cfg.configureProviders(env, resolver, []catwalk.Provider{}) assert.NoError(t, err) @@ -587,7 +590,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { t.Run("valid custom provider is kept and ID is set", func(t *testing.T) { cfg := &Config{ - Providers: csync.NewMapFrom(map[string]ProviderConfig{ + Providers: csync.NewMapFrom(map[string]provider.Config{ "custom": { APIKey: "test-key", BaseURL: "https://api.custom.com/v1", @@ -601,7 +604,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { cfg.setDefaults("/tmp") env := env.NewFromMap(map[string]string{}) - resolver := NewEnvironmentVariableResolver(env) + resolver := resolver.NewEnvironmentVariableResolver(env) err := cfg.configureProviders(env, resolver, []catwalk.Provider{}) assert.NoError(t, err) @@ -615,7 +618,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { t.Run("custom anthropic provider is supported", func(t *testing.T) { cfg := &Config{ - Providers: csync.NewMapFrom(map[string]ProviderConfig{ + Providers: csync.NewMapFrom(map[string]provider.Config{ "custom-anthropic": { APIKey: "test-key", BaseURL: "https://api.anthropic.com/v1", @@ -629,7 +632,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { cfg.setDefaults("/tmp") env := env.NewFromMap(map[string]string{}) - resolver := NewEnvironmentVariableResolver(env) + resolver := resolver.NewEnvironmentVariableResolver(env) err := cfg.configureProviders(env, resolver, []catwalk.Provider{}) assert.NoError(t, err) @@ -644,7 +647,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { t.Run("disabled custom provider is removed", func(t *testing.T) { cfg := &Config{ - Providers: csync.NewMapFrom(map[string]ProviderConfig{ + Providers: csync.NewMapFrom(map[string]provider.Config{ "custom": { APIKey: "test-key", BaseURL: "https://api.custom.com/v1", @@ -659,7 +662,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { cfg.setDefaults("/tmp") env := env.NewFromMap(map[string]string{}) - resolver := NewEnvironmentVariableResolver(env) + resolver := resolver.NewEnvironmentVariableResolver(env) err := cfg.configureProviders(env, resolver, []catwalk.Provider{}) assert.NoError(t, err) @@ -683,7 +686,7 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) { } cfg := &Config{ - Providers: csync.NewMapFrom(map[string]ProviderConfig{ + Providers: csync.NewMapFrom(map[string]provider.Config{ "vertexai": { BaseURL: "custom-url", }, @@ -694,7 +697,7 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) { env := env.NewFromMap(map[string]string{ "GOOGLE_GENAI_USE_VERTEXAI": "false", }) - resolver := NewEnvironmentVariableResolver(env) + resolver := resolver.NewEnvironmentVariableResolver(env) err := cfg.configureProviders(env, resolver, knownProviders) assert.NoError(t, err) @@ -716,7 +719,7 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) { } cfg := &Config{ - Providers: csync.NewMapFrom(map[string]ProviderConfig{ + Providers: csync.NewMapFrom(map[string]provider.Config{ "bedrock": { BaseURL: "custom-url", }, @@ -725,7 +728,7 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t 
*testing.T) { cfg.setDefaults("/tmp") env := env.NewFromMap(map[string]string{}) - resolver := NewEnvironmentVariableResolver(env) + resolver := resolver.NewEnvironmentVariableResolver(env) err := cfg.configureProviders(env, resolver, knownProviders) assert.NoError(t, err) @@ -747,7 +750,7 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) { } cfg := &Config{ - Providers: csync.NewMapFrom(map[string]ProviderConfig{ + Providers: csync.NewMapFrom(map[string]provider.Config{ "openai": { BaseURL: "custom-url", }, @@ -756,7 +759,7 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) { cfg.setDefaults("/tmp") env := env.NewFromMap(map[string]string{}) - resolver := NewEnvironmentVariableResolver(env) + resolver := resolver.NewEnvironmentVariableResolver(env) err := cfg.configureProviders(env, resolver, knownProviders) assert.NoError(t, err) @@ -778,7 +781,7 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) { } cfg := &Config{ - Providers: csync.NewMapFrom(map[string]ProviderConfig{ + Providers: csync.NewMapFrom(map[string]provider.Config{ "openai": { APIKey: "test-key", }, @@ -789,7 +792,7 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) { env := env.NewFromMap(map[string]string{ "OPENAI_API_KEY": "test-key", }) - resolver := NewEnvironmentVariableResolver(env) + resolver := resolver.NewEnvironmentVariableResolver(env) err := cfg.configureProviders(env, resolver, knownProviders) assert.NoError(t, err) @@ -823,7 +826,7 @@ func TestConfig_defaultModelSelection(t *testing.T) { cfg := &Config{} cfg.setDefaults("/tmp") env := env.NewFromMap(map[string]string{}) - resolver := NewEnvironmentVariableResolver(env) + resolver := resolver.NewEnvironmentVariableResolver(env) err := cfg.configureProviders(env, resolver, knownProviders) assert.NoError(t, err) @@ -859,7 +862,7 @@ func TestConfig_defaultModelSelection(t *testing.T) { cfg := &Config{} cfg.setDefaults("/tmp") env := env.NewFromMap(map[string]string{}) - resolver := NewEnvironmentVariableResolver(env) + resolver := resolver.NewEnvironmentVariableResolver(env) err := cfg.configureProviders(env, resolver, knownProviders) assert.NoError(t, err) @@ -889,7 +892,7 @@ func TestConfig_defaultModelSelection(t *testing.T) { cfg := &Config{} cfg.setDefaults("/tmp") env := env.NewFromMap(map[string]string{}) - resolver := NewEnvironmentVariableResolver(env) + resolver := resolver.NewEnvironmentVariableResolver(env) err := cfg.configureProviders(env, resolver, knownProviders) assert.NoError(t, err) _, _, err = cfg.defaultModelSelection(knownProviders) @@ -917,7 +920,7 @@ func TestConfig_defaultModelSelection(t *testing.T) { } cfg := &Config{ - Providers: csync.NewMapFrom(map[string]ProviderConfig{ + Providers: csync.NewMapFrom(map[string]provider.Config{ "custom": { APIKey: "test-key", BaseURL: "https://api.custom.com/v1", @@ -932,7 +935,7 @@ func TestConfig_defaultModelSelection(t *testing.T) { } cfg.setDefaults("/tmp") env := env.NewFromMap(map[string]string{}) - resolver := NewEnvironmentVariableResolver(env) + resolver := resolver.NewEnvironmentVariableResolver(env) err := cfg.configureProviders(env, resolver, knownProviders) assert.NoError(t, err) large, small, err := cfg.defaultModelSelection(knownProviders) @@ -966,7 +969,7 @@ func TestConfig_defaultModelSelection(t *testing.T) { } cfg := &Config{ - Providers: csync.NewMapFrom(map[string]ProviderConfig{ + Providers: csync.NewMapFrom(map[string]provider.Config{ "custom": { APIKey: 
"test-key", BaseURL: "https://api.custom.com/v1", @@ -976,7 +979,7 @@ func TestConfig_defaultModelSelection(t *testing.T) { } cfg.setDefaults("/tmp") env := env.NewFromMap(map[string]string{}) - resolver := NewEnvironmentVariableResolver(env) + resolver := resolver.NewEnvironmentVariableResolver(env) err := cfg.configureProviders(env, resolver, knownProviders) assert.NoError(t, err) _, _, err = cfg.defaultModelSelection(knownProviders) @@ -1003,7 +1006,7 @@ func TestConfig_defaultModelSelection(t *testing.T) { } cfg := &Config{ - Providers: csync.NewMapFrom(map[string]ProviderConfig{ + Providers: csync.NewMapFrom(map[string]provider.Config{ "custom": { APIKey: "test-key", BaseURL: "https://api.custom.com/v1", @@ -1018,7 +1021,7 @@ func TestConfig_defaultModelSelection(t *testing.T) { } cfg.setDefaults("/tmp") env := env.NewFromMap(map[string]string{}) - resolver := NewEnvironmentVariableResolver(env) + resolver := resolver.NewEnvironmentVariableResolver(env) err := cfg.configureProviders(env, resolver, knownProviders) assert.NoError(t, err) large, small, err := cfg.defaultModelSelection(knownProviders) @@ -1058,7 +1061,7 @@ func TestConfig_configureSelectedModels(t *testing.T) { } cfg := &Config{ - Models: map[SelectedModelType]SelectedModel{ + Models: map[SelectedModelType]agent.Model{ "large": { Model: "larger-model", }, @@ -1066,7 +1069,7 @@ func TestConfig_configureSelectedModels(t *testing.T) { } cfg.setDefaults("/tmp") env := env.NewFromMap(map[string]string{}) - resolver := NewEnvironmentVariableResolver(env) + resolver := resolver.NewEnvironmentVariableResolver(env) err := cfg.configureProviders(env, resolver, knownProviders) assert.NoError(t, err) @@ -1118,7 +1121,7 @@ func TestConfig_configureSelectedModels(t *testing.T) { } cfg := &Config{ - Models: map[SelectedModelType]SelectedModel{ + Models: map[SelectedModelType]agent.Model{ "small": { Model: "a-small-model", Provider: "anthropic", @@ -1128,7 +1131,7 @@ func TestConfig_configureSelectedModels(t *testing.T) { } cfg.setDefaults("/tmp") env := env.NewFromMap(map[string]string{}) - resolver := NewEnvironmentVariableResolver(env) + resolver := resolver.NewEnvironmentVariableResolver(env) err := cfg.configureProviders(env, resolver, knownProviders) assert.NoError(t, err) @@ -1165,7 +1168,7 @@ func TestConfig_configureSelectedModels(t *testing.T) { } cfg := &Config{ - Models: map[SelectedModelType]SelectedModel{ + Models: map[SelectedModelType]agent.Model{ "large": { MaxTokens: 100, }, @@ -1173,7 +1176,7 @@ func TestConfig_configureSelectedModels(t *testing.T) { } cfg.setDefaults("/tmp") env := env.NewFromMap(map[string]string{}) - resolver := NewEnvironmentVariableResolver(env) + resolver := resolver.NewEnvironmentVariableResolver(env) err := cfg.configureProviders(env, resolver, knownProviders) assert.NoError(t, err) diff --git a/internal/csync/slices.go b/internal/csync/slices.go index 3913a054c166c2bd29b3fafb7e6a0fa1998463a8..e71345b6806ae442957eaa18750ca096a9832d2c 100644 --- a/internal/csync/slices.go +++ b/internal/csync/slices.go @@ -36,6 +36,32 @@ func (s *LazySlice[K]) Seq() iter.Seq[K] { } } +func (s *LazySlice[K]) Seq2() iter.Seq2[int, K] { + s.wg.Wait() + return func(yield func(int, K) bool) { + for i, v := range s.inner { + if !yield(i, v) { + return + } + } + } +} + +func (s *LazySlice[K]) Set(index int, item K) bool { + s.wg.Wait() + if index < 0 || index >= len(s.inner) { + return false + } + s.inner[index] = item + return true +} + +func (s *LazySlice[K]) Append(item K) bool { + s.wg.Wait() + s.inner = 
append(s.inner, item) + return true +} + // Slice is a thread-safe slice implementation that provides concurrent access. type Slice[T any] struct { inner []T diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go index 17a67f810b335f1dad105321a0bb0a8b354c9bfc..89726b4d7560be055501baf57cf459600e595b4f 100644 --- a/internal/llm/agent/agent.go +++ b/internal/llm/agent/agent.go @@ -5,19 +5,15 @@ import ( "errors" "fmt" "log/slog" - "slices" "strings" "time" "github.com/charmbracelet/catwalk/pkg/catwalk" - "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/csync" - "github.com/charmbracelet/crush/internal/history" "github.com/charmbracelet/crush/internal/llm/prompt" "github.com/charmbracelet/crush/internal/llm/provider" "github.com/charmbracelet/crush/internal/llm/tools" "github.com/charmbracelet/crush/internal/log" - "github.com/charmbracelet/crush/internal/lsp" "github.com/charmbracelet/crush/internal/message" "github.com/charmbracelet/crush/internal/permission" "github.com/charmbracelet/crush/internal/pubsub" @@ -52,200 +48,126 @@ type AgentEvent struct { type Service interface { pubsub.Suscriber[AgentEvent] - Model() catwalk.Model Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) Cancel(sessionID string) CancelAll() IsSessionBusy(sessionID string) bool IsBusy() bool Summarize(ctx context.Context, sessionID string) error - UpdateModel() error + SetDebug(debug bool) + UpdateModels(large, small Model) error + // for now, not really sure how to handle this better + WithAgentTool() error + + ModelConfig() Model + Model() *catwalk.Model + Provider() *provider.Config +} + +type Model struct { + // The model id as used by the provider API. + // Required. + Model string `json:"model"` + // The model provider, same as the key/id used in the providers config. + // Required. + Provider string `json:"provider"` + + // Only used by models that use the openai provider and need this set. + ReasoningEffort string `json:"reasoning_effort,omitempty"` + + // Overrides the default model configuration. + MaxTokens int64 `json:"max_tokens,omitempty"` + + // Used by anthropic models that can reason to indicate if the model should think. 
+ Think bool `json:"think,omitempty"` } type agent struct { *pubsub.Broker[AgentEvent] - agentCfg config.Agent + ctx context.Context + cwd string + systemPrompt string + providers map[string]provider.Config + sessions session.Service messages message.Service - tools *csync.LazySlice[tools.BaseTool] + toolsRegistry tools.Registry - provider provider.Provider - providerID string - - titleProvider provider.Provider - summarizeProvider provider.Provider - summarizeProviderID string + large, small Model + provider provider.Provider + titleProvider provider.Provider + summarizeProvider provider.Provider activeRequests *csync.Map[string, context.CancelFunc] -} -var agentPromptMap = map[string]prompt.PromptID{ - "coder": prompt.PromptCoder, - "task": prompt.PromptTask, + debug bool } func NewAgent( - agentCfg config.Agent, - // These services are needed in the tools - permissions permission.Service, + ctx context.Context, + cwd string, + systemPrompt string, + toolsRegistry tools.Registry, + providers map[string]provider.Config, + + smallModel Model, + largeModel Model, + sessions session.Service, messages message.Service, - history history.Service, - lspClients map[string]*lsp.Client, ) (Service, error) { - ctx := context.Background() - cfg := config.Get() - - var agentTool tools.BaseTool - if agentCfg.ID == "coder" { - taskAgentCfg := config.Get().Agents["task"] - if taskAgentCfg.ID == "" { - return nil, fmt.Errorf("task agent not found in config") - } - taskAgent, err := NewAgent(taskAgentCfg, permissions, sessions, messages, history, lspClients) - if err != nil { - return nil, fmt.Errorf("failed to create task agent: %w", err) - } - - agentTool = NewAgentTool(taskAgent, sessions, messages) - } - - providerCfg := config.Get().GetProviderForModel(agentCfg.Model) - if providerCfg == nil { - return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name) - } - model := config.Get().GetModelByType(agentCfg.Model) - - if model == nil { - return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name) - } + agent := &agent{ + Broker: pubsub.NewBroker[AgentEvent](), + ctx: ctx, + providers: providers, + cwd: cwd, + systemPrompt: systemPrompt, + toolsRegistry: toolsRegistry, + small: smallModel, + large: largeModel, + messages: messages, + sessions: sessions, + activeRequests: csync.NewMap[string, context.CancelFunc](), + } + + err := agent.setProviders() + return agent, err +} - promptID := agentPromptMap[agentCfg.ID] - if promptID == "" { - promptID = prompt.PromptDefault - } - opts := []provider.ProviderClientOption{ - provider.WithModel(agentCfg.Model), - provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)), - } - agentProvider, err := provider.NewProvider(*providerCfg, opts...) 
- if err != nil { - return nil, err - } +func (a *agent) ModelConfig() Model { + return a.large +} - smallModelCfg := cfg.Models[config.SelectedModelTypeSmall] - var smallModelProviderCfg *config.ProviderConfig - if smallModelCfg.Provider == providerCfg.ID { - smallModelProviderCfg = providerCfg - } else { - smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall) +func (a *agent) Model() *catwalk.Model { + return a.provider.Model(a.large.Model) +} - if smallModelProviderCfg.ID == "" { - return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider) +func (a *agent) Provider() *provider.Config { + for _, provider := range a.providers { + if provider.ID == a.large.Provider { + return &provider } } - smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall) - if smallModel.ID == "" { - return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID) - } - - titleOpts := []provider.ProviderClientOption{ - provider.WithModel(config.SelectedModelTypeSmall), - provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)), - } - titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...) - if err != nil { - return nil, err - } - summarizeOpts := []provider.ProviderClientOption{ - provider.WithModel(config.SelectedModelTypeSmall), - provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)), - } - summarizeProvider, err := provider.NewProvider(*smallModelProviderCfg, summarizeOpts...) - if err != nil { - return nil, err - } - - toolFn := func() []tools.BaseTool { - slog.Info("Initializing agent tools", "agent", agentCfg.ID) - defer func() { - slog.Info("Initialized agent tools", "agent", agentCfg.ID) - }() - - cwd := cfg.WorkingDir() - allTools := []tools.BaseTool{ - tools.NewBashTool(permissions, cwd), - tools.NewDownloadTool(permissions, cwd), - tools.NewEditTool(lspClients, permissions, history, cwd), - tools.NewFetchTool(permissions, cwd), - tools.NewGlobTool(cwd), - tools.NewGrepTool(cwd), - tools.NewLsTool(cwd), - tools.NewSourcegraphTool(), - tools.NewViewTool(lspClients, cwd), - tools.NewWriteTool(lspClients, permissions, history, cwd), - } - - mcpTools := GetMCPTools(ctx, permissions, cfg) - allTools = append(allTools, mcpTools...) 
- - if len(lspClients) > 0 { - allTools = append(allTools, tools.NewDiagnosticsTool(lspClients)) - } - - if agentTool != nil { - allTools = append(allTools, agentTool) - } - - if agentCfg.AllowedTools == nil { - return allTools - } - - var filteredTools []tools.BaseTool - for _, tool := range allTools { - if slices.Contains(agentCfg.AllowedTools, tool.Name()) { - filteredTools = append(filteredTools, tool) - } - } - return filteredTools - } - - return &agent{ - Broker: pubsub.NewBroker[AgentEvent](), - agentCfg: agentCfg, - provider: agentProvider, - providerID: string(providerCfg.ID), - messages: messages, - sessions: sessions, - titleProvider: titleProvider, - summarizeProvider: summarizeProvider, - summarizeProviderID: string(smallModelProviderCfg.ID), - activeRequests: csync.NewMap[string, context.CancelFunc](), - tools: csync.NewLazySlice(toolFn), - }, nil -} - -func (a *agent) Model() catwalk.Model { - return *config.Get().GetModelByType(a.agentCfg.Model) + return nil } func (a *agent) Cancel(sessionID string) { // Cancel regular requests - if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil { + if cancel, exists := a.activeRequests.Take(sessionID); exists { slog.Info("Request cancellation initiated", "session_id", sessionID) cancel() } // Also check for summarize requests - if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil { + if cancel, exists := a.activeRequests.Take(sessionID + "-summarize"); exists { slog.Info("Summarize cancellation initiated", "session_id", sessionID) cancel() } } func (a *agent) IsBusy() bool { - var busy bool + busy := false for cancelFunc := range a.activeRequests.Seq() { if cancelFunc != nil { busy = true @@ -275,9 +197,9 @@ func (a *agent) generateTitle(ctx context.Context, sessionID string, content str Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content), }} - // Use streaming approach like summarization - response := a.titleProvider.StreamResponse( + response := a.titleProvider.Stream( ctx, + a.small.Model, []message.Message{ { Role: message.User, @@ -352,7 +274,6 @@ func (a *agent) Run(ctx context.Context, sessionID string, content string, attac } func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent { - cfg := config.Get() // List existing messages; if none, start title generation asynchronously. 
msgs, err := a.messages.List(ctx, sessionID) if err != nil { @@ -411,7 +332,7 @@ func (a *agent) processGeneration(ctx context.Context, sessionID, content string } return a.err(fmt.Errorf("failed to process events: %w", err)) } - if cfg.Options.Debug { + if a.debug { slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults) } if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil { @@ -438,13 +359,13 @@ func (a *agent) createUserMessage(ctx context.Context, sessionID, content string func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) { ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID) - eventChan := a.provider.StreamResponse(ctx, msgHistory, slices.Collect(a.tools.Seq())) + eventChan := a.provider.Stream(ctx, a.large.Model, msgHistory, a.toolsRegistry.GetAllTools()) assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{ Role: message.Assistant, Parts: []message.ContentPart{}, - Model: a.Model().ID, - Provider: a.providerID, + Model: a.large.Model, + Provider: a.large.Provider, }) if err != nil { return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err) @@ -487,7 +408,7 @@ func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msg default: // Continue processing var tool tools.BaseTool - for availableTool := range a.tools.Seq() { + for _, availableTool := range a.toolsRegistry.GetAllTools() { if availableTool.Info().Name == toolCall.Name { tool = availableTool break @@ -578,7 +499,8 @@ out: msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{ Role: message.Tool, Parts: parts, - Provider: a.providerID, + Model: a.large.Model, + Provider: a.large.Provider, }) if err != nil { return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err) @@ -632,7 +554,11 @@ func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg if err := a.messages.Update(ctx, *assistantMsg); err != nil { return fmt.Errorf("failed to update message: %w", err) } - return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage) + model := a.Model() + if model == nil { + return nil + } + return a.TrackUsage(ctx, sessionID, *model, event.Response.Usage) } return nil @@ -734,8 +660,9 @@ func (a *agent) Summarize(ctx context.Context, sessionID string) error { a.Publish(pubsub.CreatedEvent, event) // Send the messages to the summarize provider - response := a.summarizeProvider.StreamResponse( + response := a.summarizeProvider.Stream( summarizeCtx, + a.large.Model, msgsWithPrompt, nil, ) @@ -763,7 +690,7 @@ func (a *agent) Summarize(ctx context.Context, sessionID string) error { a.Publish(pubsub.CreatedEvent, event) return } - shell := shell.GetPersistentShell(config.Get().WorkingDir()) + shell := shell.GetPersistentShell(a.cwd) summary += "\n\n**Current working directory of the persistent shell**\n\n" + shell.GetWorkingDir() event = AgentEvent{ Type: AgentEventTypeSummarize, @@ -792,8 +719,8 @@ func (a *agent) Summarize(ctx context.Context, sessionID string) error { Time: time.Now().Unix(), }, }, - Model: a.summarizeProvider.Model().ID, - Provider: a.summarizeProviderID, + Model: a.large.Model, + Provider: a.large.Provider, }) if err != nil { event = AgentEvent{ @@ -808,7 +735,7 @@ func (a *agent) Summarize(ctx context.Context, sessionID string) error { 
oldSession.SummaryMessageID = msg.ID oldSession.CompletionTokens = finalResponse.Usage.OutputTokens oldSession.PromptTokens = 0 - model := a.summarizeProvider.Model() + model := a.summarizeProvider.Model(a.large.Model) usage := finalResponse.Usage cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) + model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) + @@ -857,92 +784,101 @@ func (a *agent) CancelAll() { } } -func (a *agent) UpdateModel() error { - cfg := config.Get() +func (a *agent) UpdateModels(small, large Model) error { + a.small = small + a.large = large + return a.setProviders() +} - // Get current provider configuration - currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model) - if currentProviderCfg == nil || currentProviderCfg.ID == "" { - return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name) +func (a *agent) SetDebug(debug bool) { + a.debug = debug + if a.provider != nil { + a.provider.SetDebug(debug) } + if a.titleProvider != nil { + a.titleProvider.SetDebug(debug) + } + if a.summarizeProvider != nil { + a.summarizeProvider.SetDebug(debug) + } +} - // Check if provider has changed - if string(currentProviderCfg.ID) != a.providerID { - // Provider changed, need to recreate the main provider - model := cfg.GetModelByType(a.agentCfg.Model) - if model.ID == "" { - return fmt.Errorf("model not found for agent %s", a.agentCfg.Name) - } - - promptID := agentPromptMap[a.agentCfg.ID] - if promptID == "" { - promptID = prompt.PromptDefault - } - - opts := []provider.ProviderClientOption{ - provider.WithModel(a.agentCfg.Model), - provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)), - } - - newProvider, err := provider.NewProvider(*currentProviderCfg, opts...) - if err != nil { - return fmt.Errorf("failed to create new provider: %w", err) - } +func (a *agent) setProviders() error { + opts := []provider.Option{ + provider.WithSystemMessage(a.systemPrompt), + provider.WithThinking(a.large.Think), + } - // Update the provider and provider ID - a.provider = newProvider - a.providerID = string(currentProviderCfg.ID) + if a.large.MaxTokens > 0 { + opts = append(opts, provider.WithMaxTokens(a.large.MaxTokens)) + } + if a.large.ReasoningEffort != "" { + opts = append(opts, provider.WithReasoningEffort(a.large.ReasoningEffort)) } - // Check if small model provider has changed (affects title and summarize providers) - smallModelCfg := cfg.Models[config.SelectedModelTypeSmall] - var smallModelProviderCfg config.ProviderConfig + providerCfg, ok := a.providers[a.large.Provider] + if !ok { + return fmt.Errorf("provider %s not found in config", a.large.Provider) + } + var err error + a.provider, err = provider.NewProvider(providerCfg, opts...) 
+ if err != nil { + return fmt.Errorf("failed to create provider: %w", err) + } - for p := range cfg.Providers.Seq() { - if p.ID == smallModelCfg.Provider { - smallModelProviderCfg = p - break - } + titleOpts := []provider.Option{ + provider.WithSystemMessage(prompt.TitlePrompt()), + provider.WithMaxTokens(40), } - if smallModelProviderCfg.ID == "" { - return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider) + titleProviderCfg, ok := a.providers[a.small.Provider] + if !ok { + return fmt.Errorf("small model provider %s not found in config", a.small.Provider) } - // Check if summarize provider has changed - if string(smallModelProviderCfg.ID) != a.summarizeProviderID { - smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall) - if smallModel == nil { - return fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID) - } + a.titleProvider, err = provider.NewProvider(titleProviderCfg, titleOpts...) + if err != nil { + return err + } + summarizeOpts := []provider.Option{ + provider.WithSystemMessage(prompt.SummarizerPrompt()), + } + a.summarizeProvider, err = provider.NewProvider(providerCfg, summarizeOpts...) + if err != nil { + return err + } - // Recreate title provider - titleOpts := []provider.ProviderClientOption{ - provider.WithModel(config.SelectedModelTypeSmall), - provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)), - // We want the title to be short, so we limit the max tokens - provider.WithMaxTokens(40), - } - newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...) - if err != nil { - return fmt.Errorf("failed to create new title provider: %w", err) - } + if _, ok := a.toolsRegistry.GetTool(AgentToolName); ok { + // reset the agent tool + a.WithAgentTool() + } - // Recreate summarize provider - summarizeOpts := []provider.ProviderClientOption{ - provider.WithModel(config.SelectedModelTypeSmall), - provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)), - } - newSummarizeProvider, err := provider.NewProvider(smallModelProviderCfg, summarizeOpts...) 
- if err != nil { - return fmt.Errorf("failed to create new summarize provider: %w", err) - } + a.SetDebug(a.debug) + return nil +} - // Update the providers and provider ID - a.titleProvider = newTitleProvider - a.summarizeProvider = newSummarizeProvider - a.summarizeProviderID = string(smallModelProviderCfg.ID) +func (a *agent) WithAgentTool() error { + agent, err := NewAgent( + a.ctx, + a.cwd, + prompt.TaskPrompt(a.cwd), + NewTaskTools(a.cwd), + a.providers, + a.small, + a.large, + a.sessions, + a.messages, + ) + if err != nil { + return err } + agentTool := NewAgentTool( + agent, + a.sessions, + a.messages, + ) + + a.toolsRegistry.SetTool(AgentToolName, agentTool) return nil } diff --git a/internal/llm/agent/coder.go b/internal/llm/agent/coder.go new file mode 100644 index 0000000000000000000000000000000000000000..90ef8e9066ea3c60183ea3685912a854b0500aff --- /dev/null +++ b/internal/llm/agent/coder.go @@ -0,0 +1,52 @@ +package agent + +import ( + "context" + + "github.com/charmbracelet/crush/internal/history" + "github.com/charmbracelet/crush/internal/llm/prompt" + "github.com/charmbracelet/crush/internal/llm/provider" + "github.com/charmbracelet/crush/internal/lsp" + "github.com/charmbracelet/crush/internal/message" + "github.com/charmbracelet/crush/internal/permission" + "github.com/charmbracelet/crush/internal/session" +) + +func NewCoderAgent( + ctx context.Context, + cwd string, + providers map[string]provider.Config, + smallModel Model, + largeModel Model, + contextFiles []string, + sessions session.Service, + messages message.Service, + permissions permission.Service, + lspClients map[string]*lsp.Client, + history history.Service, + mcps map[string]MCPConfig, +) (Service, error) { + systemPrompt := prompt.CoderPrompt(cwd, contextFiles...) 
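+	// Both the prompt and the tool registry below are built from the explicit
+	// arguments (cwd, context files, MCP configs) rather than a global config.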
+ tools := NewCoderTools( + ctx, + cwd, + sessions, + messages, + permissions, + lspClients, + history, + mcps, + ) + + return NewAgent( + ctx, + cwd, + systemPrompt, + tools, + providers, + smallModel, + largeModel, + sessions, + messages, + ) +} diff --git a/internal/llm/agent/mcp-tools.go b/internal/llm/agent/mcp.go similarity index 70% rename from internal/llm/agent/mcp-tools.go rename to internal/llm/agent/mcp.go index e17a5527fb46979a8cd056473b3bcd184c014d60..db9e63ccf68aedc5d7294f4a5007ef3362386180 100644 --- a/internal/llm/agent/mcp-tools.go +++ b/internal/llm/agent/mcp.go @@ -8,22 +8,41 @@ import ( "slices" "sync" - "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/csync" "github.com/charmbracelet/crush/internal/llm/tools" + "github.com/charmbracelet/crush/internal/resolver" + "github.com/charmbracelet/crush/internal/version" "github.com/charmbracelet/crush/internal/permission" - "github.com/charmbracelet/crush/internal/version" "github.com/mark3labs/mcp-go/client" "github.com/mark3labs/mcp-go/client/transport" "github.com/mark3labs/mcp-go/mcp" ) +type MCPType string + +const ( + MCPStdio MCPType = "stdio" + MCPSse MCPType = "sse" + MCPHttp MCPType = "http" +) + +type MCPConfig struct { + Command string `json:"command,omitempty" ` + Env map[string]string `json:"env,omitempty"` + Args []string `json:"args,omitempty"` + Type MCPType `json:"type"` + URL string `json:"url,omitempty"` + Disabled bool `json:"disabled,omitempty"` + + Headers map[string]string `json:"headers,omitempty"` +} + type mcpTool struct { mcpName string tool mcp.Tool - mcpConfig config.MCPConfig + mcpConfig MCPConfig permissions permission.Service workingDir string } @@ -60,7 +79,7 @@ func runTool(ctx context.Context, c MCPClient, toolName string, input string) (t initRequest := mcp.InitializeRequest{} initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION initRequest.Params.ClientInfo = mcp.Implementation{ - Name: "Crush", + Name: "crush", Version: version.Version, } @@ -115,7 +134,7 @@ func (b *mcpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolRes } switch b.mcpConfig.Type { - case config.MCPStdio: + case MCPStdio: c, err := client.NewStdioMCPClient( b.mcpConfig.Command, b.mcpConfig.ResolvedEnv(), @@ -125,7 +144,7 @@ func (b *mcpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolRes return tools.NewTextErrorResponse(err.Error()), nil } return runTool(ctx, c, b.tool.Name, params.Input) - case config.MCPHttp: + case MCPHttp: c, err := client.NewStreamableHttpClient( b.mcpConfig.URL, transport.WithHTTPHeaders(b.mcpConfig.ResolvedHeaders()), @@ -134,7 +153,7 @@ func (b *mcpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolRes return tools.NewTextErrorResponse(err.Error()), nil } return runTool(ctx, c, b.tool.Name, params.Input) - case config.MCPSse: + case MCPSse: c, err := client.NewSSEMCPClient( b.mcpConfig.URL, client.WithHeaders(b.mcpConfig.ResolvedHeaders()), @@ -148,23 +167,22 @@ func (b *mcpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolRes return tools.NewTextErrorResponse("invalid mcp type"), nil } -func NewMcpTool(name string, tool mcp.Tool, permissions permission.Service, mcpConfig config.MCPConfig, workingDir string) tools.BaseTool { +func NewMcpTool(name, cwd string, tool mcp.Tool, permissions permission.Service, mcpConfig MCPConfig) tools.BaseTool { return &mcpTool{ mcpName: name, tool: tool, mcpConfig: mcpConfig, permissions: permissions, - workingDir: workingDir, + workingDir: cwd, } } 
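
A minimal usage sketch (not part of the diff) of the new MCPConfig type: the
server name, command, and env values below are hypothetical placeholders, and
t stands in for a tool definition listed by the server.

	cfg := MCPConfig{
		Type:    MCPStdio,
		Command: "my-mcp-server",
		Args:    []string{"--stdio"},
		Env:     map[string]string{"API_TOKEN": "$MY_TOKEN"}, // expanded by ResolvedEnv
	}
	var t mcp.Tool // placeholder for a tool listed by the server
	wrapped := NewMcpTool("my-mcp", cwd, t, permissions, cfg)
	_ = wrapped

For the sse and http transports, URL and Headers play the same role, with
header values expanded by ResolvedHeaders.
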
-func getTools(ctx context.Context, name string, m config.MCPConfig, permissions permission.Service, c MCPClient, workingDir string) []tools.BaseTool {
+func getTools(ctx context.Context, cwd string, name string, m MCPConfig, permissions permission.Service, c MCPClient) []tools.BaseTool {
 	var stdioTools []tools.BaseTool
 	initRequest := mcp.InitializeRequest{}
 	initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
 	initRequest.Params.ClientInfo = mcp.Implementation{
-		Name:    "Crush",
+		Name:    "crush",
 		Version: version.Version,
 	}
 
 	_, err := c.Initialize(ctx, initRequest)
@@ -179,7 +197,7 @@ func getTools(ctx context.Context, name string, m config.MCPConfig, permissions
 		return stdioTools
 	}
 	for _, t := range tools.Tools {
-		stdioTools = append(stdioTools, NewMcpTool(name, t, permissions, m, workingDir))
+		stdioTools = append(stdioTools, NewMcpTool(name, cwd, t, permissions, m))
 	}
 	defer c.Close()
 	return stdioTools
@@ -190,26 +208,26 @@ var (
 	mcpTools     []tools.BaseTool
 )
 
-func GetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) []tools.BaseTool {
+func GetMCPTools(ctx context.Context, cwd string, mcps map[string]MCPConfig, permissions permission.Service) []tools.BaseTool {
 	mcpToolsOnce.Do(func() {
-		mcpTools = doGetMCPTools(ctx, permissions, cfg)
+		mcpTools = doGetMCPTools(ctx, cwd, mcps, permissions)
 	})
 	return mcpTools
 }
 
-func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) []tools.BaseTool {
+func doGetMCPTools(ctx context.Context, cwd string, mcps map[string]MCPConfig, permissions permission.Service) []tools.BaseTool {
 	var wg sync.WaitGroup
 	result := csync.NewSlice[tools.BaseTool]()
-	for name, m := range cfg.MCP {
+	for name, m := range mcps {
 		if m.Disabled {
 			slog.Debug("skipping disabled mcp", "name", name)
 			continue
 		}
 		wg.Add(1)
-		go func(name string, m config.MCPConfig) {
+		go func(name string, m MCPConfig) {
 			defer wg.Done()
 			switch m.Type {
-			case config.MCPStdio:
+			case MCPStdio:
 				c, err := client.NewStdioMCPClient(
 					m.Command,
 					m.ResolvedEnv(),
@@ -220,8 +238,8 @@ func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *con
 					return
 				}
 
-				result.Append(getTools(ctx, name, m, permissions, c, cfg.WorkingDir())...)
-			case config.MCPHttp:
+				result.Append(getTools(ctx, cwd, name, m, permissions, c)...)
+			case MCPHttp:
 				c, err := client.NewStreamableHttpClient(
 					m.URL,
 					transport.WithHTTPHeaders(m.ResolvedHeaders()),
@@ -230,8 +248,8 @@ func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *con
 					slog.Error("error creating mcp client", "error", err)
 					return
 				}
-				result.Append(getTools(ctx, name, m, permissions, c, cfg.WorkingDir())...)
-			case config.MCPSse:
+				result.Append(getTools(ctx, cwd, name, m, permissions, c)...)
+			case MCPSse:
 				c, err := client.NewSSEMCPClient(
 					m.URL,
 					client.WithHeaders(m.ResolvedHeaders()),
@@ -240,10 +258,41 @@ func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *con
 					slog.Error("error creating mcp client", "error", err)
 					return
 				}
-				result.Append(getTools(ctx, name, m, permissions, c, cfg.WorkingDir())...)
+				result.Append(getTools(ctx, cwd, name, m, permissions, c)...)
} }(name, m) } wg.Wait() return slices.Collect(result.Seq()) } + +func (m MCPConfig) ResolvedEnv() []string { + resolver := resolver.New() + for e, v := range m.Env { + var err error + m.Env[e], err = resolver.ResolveValue(v) + if err != nil { + slog.Error("error resolving environment variable", "error", err, "variable", e, "value", v) + continue + } + } + + env := make([]string, 0, len(m.Env)) + for k, v := range m.Env { + env = append(env, fmt.Sprintf("%s=%s", k, v)) + } + return env +} + +func (m MCPConfig) ResolvedHeaders() map[string]string { + resolver := resolver.New() + for e, v := range m.Headers { + var err error + m.Headers[e], err = resolver.ResolveValue(v) + if err != nil { + slog.Error("error resolving header variable", "error", err, "variable", e, "value", v) + continue + } + } + return m.Headers +} diff --git a/internal/llm/agent/tools.go b/internal/llm/agent/tools.go new file mode 100644 index 0000000000000000000000000000000000000000..8f19aa651b39c7ed9147f5903900deeb157ca7f8 --- /dev/null +++ b/internal/llm/agent/tools.go @@ -0,0 +1,56 @@ +package agent + +import ( + "context" + + "github.com/charmbracelet/crush/internal/history" + "github.com/charmbracelet/crush/internal/llm/tools" + "github.com/charmbracelet/crush/internal/lsp" + "github.com/charmbracelet/crush/internal/message" + "github.com/charmbracelet/crush/internal/permission" + "github.com/charmbracelet/crush/internal/session" +) + +func NewCoderTools( + ctx context.Context, + cwd string, + sessions session.Service, + messages message.Service, + permissions permission.Service, + lspClients map[string]*lsp.Client, + history history.Service, + mcps map[string]MCPConfig, +) tools.Registry { + toolFn := func() []tools.BaseTool { + allTools := []tools.BaseTool{ + tools.NewBashTool(permissions, cwd), + tools.NewDownloadTool(permissions, cwd), + tools.NewEditTool(lspClients, permissions, history, cwd), + tools.NewFetchTool(permissions, cwd), + tools.NewGlobTool(cwd), + tools.NewGrepTool(cwd), + tools.NewLsTool(cwd), + tools.NewSourcegraphTool(), + tools.NewViewTool(lspClients, cwd), + tools.NewWriteTool(lspClients, permissions, history, cwd), + } + mcpTools := GetMCPTools(ctx, cwd, mcps, permissions) + allTools = append(allTools, mcpTools...) 
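+		// MCP servers are only queried once per process (GetMCPTools is backed
+		// by a sync.Once); their tools are appended after the built-in ones.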
+ if len(lspClients) > 0 { + allTools = append(allTools, tools.NewDiagnosticsTool(lspClients)) + } + return allTools + } + return tools.NewRegistry(toolFn) +} + +func NewTaskTools(cwd string) tools.Registry { + return tools.NewRegistryFromTools([]tools.BaseTool{ + tools.NewGlobTool(cwd), + tools.NewGrepTool(cwd), + tools.NewLsTool(cwd), + tools.NewSourcegraphTool(), + // no need for LSP info here + tools.NewViewTool(map[string]*lsp.Client{}, cwd), + }) +} diff --git a/internal/llm/prompt/coder.go b/internal/llm/prompt/coder.go index ed879754c7c8c78debda98fb6b89c33d75fcab24..9412d00b404e2e4373eadb7d1d1108ced2dc4825 100644 --- a/internal/llm/prompt/coder.go +++ b/internal/llm/prompt/coder.go @@ -4,26 +4,23 @@ import ( "context" _ "embed" "fmt" - "log/slog" "os" "path/filepath" "runtime" "time" - "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/llm/tools" ) -func CoderPrompt(p string, contextFiles ...string) string { +func CoderPrompt(cwd string, contextFiles ...string) string { var basePrompt string basePrompt = string(baseCoderPrompt) - envInfo := getEnvironmentInfo() + envInfo := getEnvironmentInfo(cwd) - basePrompt = fmt.Sprintf("%s\n\n%s\n%s", basePrompt, envInfo, lspInformation()) + basePrompt = fmt.Sprintf("%s\n\n%s", basePrompt, envInfo) - contextContent := getContextFromPaths(config.Get().WorkingDir(), contextFiles) - slog.Debug("Context content", "Context", contextContent) + contextContent := getContextFromPaths(cwd, contextFiles) if contextContent != "" { return fmt.Sprintf("%s\n\n# Project-Specific Context\n Make sure to follow the instructions in the context below\n%s", basePrompt, contextContent) } @@ -33,8 +30,7 @@ func CoderPrompt(p string, contextFiles ...string) string { //go:embed coder.md var baseCoderPrompt []byte -func getEnvironmentInfo() string { - cwd := config.Get().WorkingDir() +func getEnvironmentInfo(cwd string) string { isGit := isGitRepo(cwd) platform := runtime.GOOS date := time.Now().Format("1/2/2006") @@ -60,18 +56,7 @@ func isGitRepo(dir string) bool { return err == nil } -func lspInformation() string { - cfg := config.Get() - hasLSP := false - for _, v := range cfg.LSP { - if !v.Disabled { - hasLSP = true - break - } - } - if !hasLSP { - return "" - } +func LSPInformation() string { return `# LSP Information Tools that support it will also include useful diagnostics such as linting and typechecking. - These diagnostics will be automatically enabled when you run the tool, and will be displayed in the output at the bottom within the and tags. diff --git a/internal/llm/prompt/prompt.go b/internal/llm/prompt/prompt.go index 8c87482a71679f5bc682e6fdd8c1f5a03b89c184..f1d5be556b5cfeeed8ec986cc338171801da8905 100644 --- a/internal/llm/prompt/prompt.go +++ b/internal/llm/prompt/prompt.go @@ -6,9 +6,7 @@ import ( "strings" "sync" - "github.com/charmbracelet/crush/internal/config" - "github.com/charmbracelet/crush/internal/csync" - "github.com/charmbracelet/crush/internal/env" + "github.com/charmbracelet/crush/internal/resolver" ) type PromptID string @@ -21,23 +19,6 @@ const ( PromptDefault PromptID = "default" ) -func GetPrompt(promptID PromptID, provider string, contextPaths ...string) string { - basePrompt := "" - switch promptID { - case PromptCoder: - basePrompt = CoderPrompt(provider, contextPaths...) 
- case PromptTitle: - basePrompt = TitlePrompt() - case PromptTask: - basePrompt = TaskPrompt() - case PromptSummarizer: - basePrompt = SummarizerPrompt() - default: - basePrompt = "You are a helpful assistant" - } - return basePrompt -} - func getContextFromPaths(workingDir string, contextPaths []string) string { return processContextPaths(workingDir, contextPaths) } @@ -59,7 +40,7 @@ func expandPath(path string) string { // Handle environment variable expansion using the same pattern as config if strings.HasPrefix(path, "$") { - resolver := config.NewEnvironmentVariableResolver(env.New()) + resolver := resolver.New() if expanded, err := resolver.ResolveValue(path); err == nil { path = expanded } @@ -75,7 +56,8 @@ func processContextPaths(workDir string, paths []string) string { ) // Track processed files to avoid duplicates - processedFiles := csync.NewMap[string, bool]() + processedFiles := make(map[string]bool) + var processedMutex sync.Mutex for _, path := range paths { wg.Add(1) @@ -106,8 +88,14 @@ func processContextPaths(workDir string, paths []string) string { // Check if we've already processed this file (case-insensitive) lowerPath := strings.ToLower(path) - if alreadyProcessed, _ := processedFiles.Get(lowerPath); !alreadyProcessed { - processedFiles.Set(lowerPath, true) + processedMutex.Lock() + alreadyProcessed := processedFiles[lowerPath] + if !alreadyProcessed { + processedFiles[lowerPath] = true + } + processedMutex.Unlock() + + if !alreadyProcessed { if result := processFile(path); result != "" { resultCh <- result } @@ -120,8 +108,14 @@ func processContextPaths(workDir string, paths []string) string { // Check if we've already processed this file (case-insensitive) lowerPath := strings.ToLower(fullPath) - if alreadyProcessed, _ := processedFiles.Get(lowerPath); !alreadyProcessed { - processedFiles.Set(lowerPath, true) + processedMutex.Lock() + alreadyProcessed := processedFiles[lowerPath] + if !alreadyProcessed { + processedFiles[lowerPath] = true + } + processedMutex.Unlock() + + if !alreadyProcessed { result := processFile(fullPath) if result != "" { resultCh <- result diff --git a/internal/llm/prompt/task.go b/internal/llm/prompt/task.go index e4f021d4ab7ef9f49873bc6893a231d72f2f3994..362105b4e41fddaec94eb784baad817f2b47ebdd 100644 --- a/internal/llm/prompt/task.go +++ b/internal/llm/prompt/task.go @@ -4,12 +4,12 @@ import ( "fmt" ) -func TaskPrompt() string { +func TaskPrompt(cwd string) string { agentPrompt := `You are an agent for Crush. Given the user's prompt, you should use the tools available to you to answer the user's question. Notes: 1. IMPORTANT: You should be concise, direct, and to the point, since your responses will be displayed on a command line interface. Answer the user's question directly, without elaboration, explanation, or details. One word answers are best. Avoid introductions, conclusions, and explanations. You MUST avoid text before/after your response, such as "The answer is .", "Here is the content of the file..." or "Based on the information provided, the answer is..." or "Here is what I will do next...". 2. When relevant, share file names and code snippets relevant to the query 3. Any file paths you return in your final response MUST be absolute. 
DO NOT use relative paths.` - return fmt.Sprintf("%s\n%s\n", agentPrompt, getEnvironmentInfo()) + return fmt.Sprintf("%s\n%s\n", agentPrompt, getEnvironmentInfo(cwd)) } diff --git a/internal/llm/provider/anthropic.go b/internal/llm/provider/anthropic.go index 3de8c805b3f0cfa08b1b2bb6b60577742ce8cc1d..7e190f6cf6b3b418d915e4759961581e721f5edb 100644 --- a/internal/llm/provider/anthropic.go +++ b/internal/llm/provider/anthropic.go @@ -16,28 +16,25 @@ import ( "github.com/anthropics/anthropic-sdk-go/bedrock" "github.com/anthropics/anthropic-sdk-go/option" "github.com/charmbracelet/catwalk/pkg/catwalk" - "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/llm/tools" "github.com/charmbracelet/crush/internal/message" ) -type anthropicClient struct { - providerOptions providerClientOptions +type anthropicProvider struct { + *baseProvider useBedrock bool client anthropic.Client adjustedMaxTokens int // Used when context limit is hit } -type AnthropicClient ProviderClient - -func newAnthropicClient(opts providerClientOptions, useBedrock bool) AnthropicClient { - return &anthropicClient{ - providerOptions: opts, - client: createAnthropicClient(opts, useBedrock), +func NewAnthropicProvider(base *baseProvider, useBedrock bool) Provider { + return &anthropicProvider{ + baseProvider: base, + client: createAnthropicClient(base, useBedrock), } } -func createAnthropicClient(opts providerClientOptions, useBedrock bool) anthropic.Client { +func createAnthropicClient(opts *baseProvider, useBedrock bool) anthropic.Client { anthropicClientOptions := []option.RequestOption{} // Check if Authorization header is provided in extra headers @@ -76,7 +73,7 @@ func createAnthropicClient(opts providerClientOptions, useBedrock bool) anthropi return anthropic.NewClient(anthropicClientOptions...) } -func (a *anthropicClient) convertMessages(messages []message.Message) (anthropicMessages []anthropic.MessageParam) { +func (a *anthropicProvider) convertMessages(messages []message.Message) (anthropicMessages []anthropic.MessageParam) { for i, msg := range messages { cache := false if i > len(messages)-3 { @@ -85,7 +82,7 @@ func (a *anthropicClient) convertMessages(messages []message.Message) (anthropic switch msg.Role { case message.User: content := anthropic.NewTextBlock(msg.Content().String()) - if cache && !a.providerOptions.disableCache { + if cache && !a.disableCache { content.OfText.CacheControl = anthropic.CacheControlEphemeralParam{ Type: "ephemeral", } @@ -110,7 +107,7 @@ func (a *anthropicClient) convertMessages(messages []message.Message) (anthropic if msg.Content().String() != "" { content := anthropic.NewTextBlock(msg.Content().String()) - if cache && !a.providerOptions.disableCache { + if cache && !a.disableCache { content.OfText.CacheControl = anthropic.CacheControlEphemeralParam{ Type: "ephemeral", } @@ -144,7 +141,7 @@ func (a *anthropicClient) convertMessages(messages []message.Message) (anthropic return } -func (a *anthropicClient) convertTools(tools []tools.BaseTool) []anthropic.ToolUnionParam { +func (a *anthropicProvider) convertTools(tools []tools.BaseTool) []anthropic.ToolUnionParam { anthropicTools := make([]anthropic.ToolUnionParam, len(tools)) for i, tool := range tools { @@ -154,11 +151,11 @@ func (a *anthropicClient) convertTools(tools []tools.BaseTool) []anthropic.ToolU Description: anthropic.String(info.Description), InputSchema: anthropic.ToolInputSchemaParam{ Properties: info.Parameters, - // TODO: figure out how we can tell claude the required fields? 
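+				// Pass the tool's required parameter names through the input schema.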
+ Required: info.Required, }, } - if i == len(tools)-1 && !a.providerOptions.disableCache { + if i == len(tools)-1 && !a.disableCache { toolParam.CacheControl = anthropic.CacheControlEphemeralParam{ Type: "ephemeral", } @@ -170,7 +167,7 @@ func (a *anthropicClient) convertTools(tools []tools.BaseTool) []anthropic.ToolU return anthropicTools } -func (a *anthropicClient) finishReason(reason string) message.FinishReason { +func (a *anthropicProvider) finishReason(reason string) message.FinishReason { switch reason { case "end_turn": return message.FinishReasonEndTurn @@ -185,37 +182,23 @@ func (a *anthropicClient) finishReason(reason string) message.FinishReason { } } -func (a *anthropicClient) isThinkingEnabled() bool { - cfg := config.Get() - modelConfig := cfg.Models[config.SelectedModelTypeLarge] - if a.providerOptions.modelType == config.SelectedModelTypeSmall { - modelConfig = cfg.Models[config.SelectedModelTypeSmall] - } - return a.Model().CanReason && modelConfig.Think +func (a *anthropicProvider) isThinkingEnabled(model string) bool { + return a.Model(model).CanReason && a.think } -func (a *anthropicClient) preparedMessages(messages []anthropic.MessageParam, tools []anthropic.ToolUnionParam) anthropic.MessageNewParams { - model := a.providerOptions.model(a.providerOptions.modelType) +func (a *anthropicProvider) preparedMessages(modelID string, messages []anthropic.MessageParam, tools []anthropic.ToolUnionParam) anthropic.MessageNewParams { + model := a.Model(modelID) var thinkingParam anthropic.ThinkingConfigParamUnion - cfg := config.Get() - modelConfig := cfg.Models[config.SelectedModelTypeLarge] - if a.providerOptions.modelType == config.SelectedModelTypeSmall { - modelConfig = cfg.Models[config.SelectedModelTypeSmall] - } temperature := anthropic.Float(0) maxTokens := model.DefaultMaxTokens - if modelConfig.MaxTokens > 0 { - maxTokens = modelConfig.MaxTokens + if a.maxTokens > 0 { + maxTokens = a.maxTokens } - if a.isThinkingEnabled() { + if a.isThinkingEnabled(modelID) { thinkingParam = anthropic.ThinkingConfigParamOfEnabled(int64(float64(maxTokens) * 0.8)) temperature = anthropic.Float(1) } - // Override max tokens if set in provider options - if a.providerOptions.maxTokens > 0 { - maxTokens = a.providerOptions.maxTokens - } // Use adjusted max tokens if context limit was hit if a.adjustedMaxTokens > 0 { @@ -225,9 +208,9 @@ func (a *anthropicClient) preparedMessages(messages []anthropic.MessageParam, to systemBlocks := []anthropic.TextBlockParam{} // Add custom system prompt prefix if configured - if a.providerOptions.systemPromptPrefix != "" { + if a.systemPromptPrefix != "" { systemBlocks = append(systemBlocks, anthropic.TextBlockParam{ - Text: a.providerOptions.systemPromptPrefix, + Text: a.systemPromptPrefix, CacheControl: anthropic.CacheControlEphemeralParam{ Type: "ephemeral", }, @@ -235,7 +218,7 @@ func (a *anthropicClient) preparedMessages(messages []anthropic.MessageParam, to } systemBlocks = append(systemBlocks, anthropic.TextBlockParam{ - Text: a.providerOptions.systemMessage, + Text: a.systemMessage, CacheControl: anthropic.CacheControlEphemeralParam{ Type: "ephemeral", }, @@ -252,21 +235,24 @@ func (a *anthropicClient) preparedMessages(messages []anthropic.MessageParam, to } } -func (a *anthropicClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (response *ProviderResponse, err error) { - cfg := config.Get() +func (a *anthropicProvider) Send(ctx context.Context, model string, messages []message.Message, tools []tools.BaseTool) 
(*ProviderResponse, error) { + messages = a.cleanMessages(messages) + return a.send(ctx, model, messages, tools) +} +func (a *anthropicProvider) send(ctx context.Context, model string, messages []message.Message, tools []tools.BaseTool) (response *ProviderResponse, err error) { attempts := 0 for { attempts++ // Prepare messages on each attempt in case max_tokens was adjusted - preparedMessages := a.preparedMessages(a.convertMessages(messages), a.convertTools(tools)) - if cfg.Options.Debug { + preparedMessages := a.preparedMessages(model, a.convertMessages(messages), a.convertTools(tools)) + if a.debug { jsonData, _ := json.Marshal(preparedMessages) slog.Debug("Prepared messages", "messages", string(jsonData)) } var opts []option.RequestOption - if a.isThinkingEnabled() { + if a.isThinkingEnabled(model) { opts = append(opts, option.WithHeaderAdd("anthropic-beta", "interleaved-thinking-2025-05-14")) } anthropicResponse, err := a.client.Messages.New( @@ -308,22 +294,26 @@ func (a *anthropicClient) send(ctx context.Context, messages []message.Message, } } -func (a *anthropicClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent { - cfg := config.Get() +func (a *anthropicProvider) Stream(ctx context.Context, model string, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent { + messages = a.cleanMessages(messages) + return a.stream(ctx, model, messages, tools) +} + +func (a *anthropicProvider) stream(ctx context.Context, model string, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent { attempts := 0 eventChan := make(chan ProviderEvent) go func() { for { attempts++ // Prepare messages on each attempt in case max_tokens was adjusted - preparedMessages := a.preparedMessages(a.convertMessages(messages), a.convertTools(tools)) - if cfg.Options.Debug { + preparedMessages := a.preparedMessages(model, a.convertMessages(messages), a.convertTools(tools)) + if a.debug { jsonData, _ := json.Marshal(preparedMessages) slog.Debug("Prepared messages", "messages", string(jsonData)) } var opts []option.RequestOption - if a.isThinkingEnabled() { + if a.isThinkingEnabled(model) { opts = append(opts, option.WithHeaderAdd("anthropic-beta", "interleaved-thinking-2025-05-14")) } @@ -460,7 +450,7 @@ func (a *anthropicClient) stream(ctx context.Context, messages []message.Message return eventChan } -func (a *anthropicClient) shouldRetry(attempts int, err error) (bool, int64, error) { +func (a *anthropicProvider) shouldRetry(attempts int, err error) (bool, int64, error) { var apiErr *anthropic.Error if !errors.As(err, &apiErr) { return false, 0, err @@ -471,11 +461,12 @@ func (a *anthropicClient) shouldRetry(attempts int, err error) (bool, int64, err } if apiErr.StatusCode == 401 { - a.providerOptions.apiKey, err = config.Get().Resolve(a.providerOptions.config.APIKey) + a.apiKey, err = a.resolver.ResolveValue(a.config.APIKey) if err != nil { return false, 0, fmt.Errorf("failed to resolve API key: %w", err) } - a.client = createAnthropicClient(a.providerOptions, a.useBedrock) + + a.client = createAnthropicClient(a.baseProvider, a.useBedrock) return true, 0, nil } @@ -508,7 +499,7 @@ func (a *anthropicClient) shouldRetry(attempts int, err error) (bool, int64, err } // handleContextLimitError parses context limit error and returns adjusted max_tokens -func (a *anthropicClient) handleContextLimitError(apiErr *anthropic.Error) (int, bool) { +func (a *anthropicProvider) handleContextLimitError(apiErr *anthropic.Error) (int, bool) 
{ // Parse error message like: "input length and max_tokens exceed context limit: 154978 + 50000 > 200000" errorMsg := apiErr.Error() @@ -535,7 +526,7 @@ func (a *anthropicClient) handleContextLimitError(apiErr *anthropic.Error) (int, return safeMaxTokens, true } -func (a *anthropicClient) toolCalls(msg anthropic.Message) []message.ToolCall { +func (a *anthropicProvider) toolCalls(msg anthropic.Message) []message.ToolCall { var toolCalls []message.ToolCall for _, block := range msg.Content { @@ -555,7 +546,7 @@ func (a *anthropicClient) toolCalls(msg anthropic.Message) []message.ToolCall { return toolCalls } -func (a *anthropicClient) usage(msg anthropic.Message) TokenUsage { +func (a *anthropicProvider) usage(msg anthropic.Message) TokenUsage { return TokenUsage{ InputTokens: msg.Usage.InputTokens, OutputTokens: msg.Usage.OutputTokens, @@ -563,7 +554,3 @@ func (a *anthropicClient) usage(msg anthropic.Message) TokenUsage { CacheReadTokens: msg.Usage.CacheReadInputTokens, } } - -func (a *anthropicClient) Model() catwalk.Model { - return a.providerOptions.model(a.providerOptions.modelType) -} diff --git a/internal/llm/provider/azure.go b/internal/llm/provider/azure.go index 31d06bd1b040d8f8cce3afa28fad53b0fe12eaa3..53304e4d2109feeda9ad8ab9c915ba158c3caf6b 100644 --- a/internal/llm/provider/azure.go +++ b/internal/llm/provider/azure.go @@ -6,27 +6,25 @@ import ( "github.com/openai/openai-go/option" ) -type azureClient struct { - *openaiClient +type azureProvider struct { + *openaiProvider } -type AzureClient ProviderClient - -func newAzureClient(opts providerClientOptions) AzureClient { - apiVersion := opts.extraParams["apiVersion"] +func NewAzureProvider(base *baseProvider) Provider { + apiVersion := base.extraParams["apiVersion"] if apiVersion == "" { apiVersion = "2025-01-01-preview" } reqOpts := []option.RequestOption{ - azure.WithEndpoint(opts.baseURL, apiVersion), + azure.WithEndpoint(base.baseURL, apiVersion), } - reqOpts = append(reqOpts, azure.WithAPIKey(opts.apiKey)) - base := &openaiClient{ - providerOptions: opts, - client: openai.NewClient(reqOpts...), + reqOpts = append(reqOpts, azure.WithAPIKey(base.apiKey)) + client := &openaiProvider{ + baseProvider: base, + client: openai.NewClient(reqOpts...), } - return &azureClient{openaiClient: base} + return &azureProvider{openaiProvider: client} } diff --git a/internal/llm/provider/bedrock.go b/internal/llm/provider/bedrock.go index 8b5b21c36a390e80843504c7c9f6c257156f6379..d69a841328a66215f937ac189c8fe82577dff0fb 100644 --- a/internal/llm/provider/bedrock.go +++ b/internal/llm/provider/bedrock.go @@ -4,90 +4,56 @@ import ( "context" "errors" "fmt" - "strings" - "github.com/charmbracelet/catwalk/pkg/catwalk" - "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/llm/tools" "github.com/charmbracelet/crush/internal/message" ) -type bedrockClient struct { - providerOptions providerClientOptions - childProvider ProviderClient +type bedrockProvider struct { + *baseProvider + region string + childProvider Provider } -type BedrockClient ProviderClient - -func newBedrockClient(opts providerClientOptions) BedrockClient { +func NewBedrockProvider(base *baseProvider) Provider { // Get AWS region from environment - region := opts.extraParams["region"] + region := base.extraParams["region"] if region == "" { region = "us-east-1" // default region } - if len(region) < 2 { - return &bedrockClient{ - providerOptions: opts, - childProvider: nil, // Will cause an error when used - } - } - - opts.model = 
func(modelType config.SelectedModelType) catwalk.Model {
-		model := config.Get().GetModelByType(modelType)
-
-		// Prefix the model name with region
-		regionPrefix := region[:2]
-		modelName := model.ID
-		model.ID = fmt.Sprintf("%s.%s", regionPrefix, modelName)
-		return *model
-	}
-
-	model := opts.model(opts.modelType)
-
-	// Determine which provider to use based on the model
-	if strings.Contains(string(model.ID), "anthropic") {
-		// Create Anthropic client with Bedrock configuration
-		anthropicOpts := opts
-		// TODO: later find a way to check if the AWS account has caching enabled
-		opts.disableCache = true // Disable cache for Bedrock
-		return &bedrockClient{
-			providerOptions: opts,
-			childProvider:   newAnthropicClient(anthropicOpts, true),
-		}
-	}
-	// Return client with nil childProvider if model is not supported
-	// This will cause an error when used
-	return &bedrockClient{
-		providerOptions: opts,
-		childProvider:   nil,
+	return &bedrockProvider{
+		baseProvider:  base,
+		region:        region,
+		childProvider: NewAnthropicProvider(base, true),
 	}
 }
 
-func (b *bedrockClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
-	if b.childProvider == nil {
-		return nil, errors.New("unsupported model for bedrock provider")
+func (b *bedrockProvider) Send(ctx context.Context, model string, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
+	if len(b.region) < 2 {
+		return nil, errors.New("no region selected")
 	}
-	return b.childProvider.send(ctx, messages, tools)
+	regionPrefix := b.region[:2]
+	modelName := model
+	model = fmt.Sprintf("%s.%s", regionPrefix, modelName)
+	messages = b.cleanMessages(messages)
+	return b.childProvider.Send(ctx, model, messages, tools)
 }
 
-func (b *bedrockClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
-	eventChan := make(chan ProviderEvent)
-
-	if b.childProvider == nil {
+func (b *bedrockProvider) Stream(ctx context.Context, model string, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
+	if len(b.region) < 2 {
+		eventChan := make(chan ProviderEvent)
 		go func() {
 			eventChan <- ProviderEvent{
 				Type:  EventError,
-				Error: errors.New("unsupported model for bedrock provider"),
+				Error: errors.New("no region selected"),
 			}
 			close(eventChan)
 		}()
 		return eventChan
 	}
-
-	return b.childProvider.stream(ctx, messages, tools)
-}
-
-func (b *bedrockClient) Model() catwalk.Model {
-	return b.providerOptions.model(b.providerOptions.modelType)
+	regionPrefix := b.region[:2]
+	modelName := model
+	model = fmt.Sprintf("%s.%s", regionPrefix, modelName)
+	messages = b.cleanMessages(messages)
+	return b.childProvider.Stream(ctx, model, messages, tools)
 }
diff --git a/internal/llm/provider/gemini.go b/internal/llm/provider/gemini.go
index 0070d246012547a691f8c6a8cbd8de2234cd93ec..4b25c16857ace330993aa595c819a4889a25df74 100644
--- a/internal/llm/provider/gemini.go
+++ b/internal/llm/provider/gemini.go
@@ -10,43 +10,39 @@ import (
 	"strings"
 	"time"
 
-	"github.com/charmbracelet/catwalk/pkg/catwalk"
-	"github.com/charmbracelet/crush/internal/config"
 	"github.com/charmbracelet/crush/internal/llm/tools"
 	"github.com/charmbracelet/crush/internal/message"
 	"github.com/google/uuid"
 	"google.golang.org/genai"
 )
 
-type geminiClient struct {
-	providerOptions providerClientOptions
-	client          *genai.Client
+type geminiProvider struct {
+	*baseProvider
+	client *genai.Client
 }
 
-type GeminiClient ProviderClient
-
-func newGeminiClient(opts providerClientOptions) 
GeminiClient { - client, err := createGeminiClient(opts) +func NewGeminiProvider(base *baseProvider) Provider { + client, err := createGeminiClient(base) if err != nil { slog.Error("Failed to create Gemini client", "error", err) return nil } - return &geminiClient{ - providerOptions: opts, - client: client, + return &geminiProvider{ + baseProvider: base, + client: client, } } -func createGeminiClient(opts providerClientOptions) (*genai.Client, error) { - client, err := genai.NewClient(context.Background(), &genai.ClientConfig{APIKey: opts.apiKey, Backend: genai.BackendGeminiAPI}) +func createGeminiClient(base *baseProvider) (*genai.Client, error) { + client, err := genai.NewClient(context.Background(), &genai.ClientConfig{APIKey: base.apiKey, Backend: genai.BackendGeminiAPI}) if err != nil { return nil, err } return client, nil } -func (g *geminiClient) convertMessages(messages []message.Message) []*genai.Content { +func (g *geminiProvider) convertMessages(messages []message.Message) []*genai.Content { var history []*genai.Content for _, msg := range messages { switch msg.Role { @@ -128,7 +124,7 @@ func (g *geminiClient) convertMessages(messages []message.Message) []*genai.Cont return history } -func (g *geminiClient) convertTools(tools []tools.BaseTool) []*genai.Tool { +func (g *geminiProvider) convertTools(tools []tools.BaseTool) []*genai.Tool { geminiTool := &genai.Tool{} geminiTool.FunctionDeclarations = make([]*genai.FunctionDeclaration, 0, len(tools)) @@ -150,7 +146,7 @@ func (g *geminiClient) convertTools(tools []tools.BaseTool) []*genai.Tool { return []*genai.Tool{geminiTool} } -func (g *geminiClient) finishReason(reason genai.FinishReason) message.FinishReason { +func (g *geminiProvider) finishReason(reason genai.FinishReason) message.FinishReason { switch reason { case genai.FinishReasonStop: return message.FinishReasonEndTurn @@ -161,28 +157,27 @@ func (g *geminiClient) finishReason(reason genai.FinishReason) message.FinishRea } } -func (g *geminiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) { +func (g *geminiProvider) Send(ctx context.Context, model string, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) { + messages = g.cleanMessages(messages) + return g.send(ctx, model, messages, tools) +} + +func (g *geminiProvider) send(ctx context.Context, modelID string, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) { // Convert messages geminiMessages := g.convertMessages(messages) - model := g.providerOptions.model(g.providerOptions.modelType) - cfg := config.Get() - if cfg.Options.Debug { + if g.debug { jsonData, _ := json.Marshal(geminiMessages) slog.Debug("Prepared messages", "messages", string(jsonData)) } - modelConfig := cfg.Models[config.SelectedModelTypeLarge] - if g.providerOptions.modelType == config.SelectedModelTypeSmall { - modelConfig = cfg.Models[config.SelectedModelTypeSmall] - } - + model := g.Model(modelID) maxTokens := model.DefaultMaxTokens - if modelConfig.MaxTokens > 0 { - maxTokens = modelConfig.MaxTokens + if g.maxTokens > 0 { + maxTokens = g.maxTokens } - systemMessage := g.providerOptions.systemMessage - if g.providerOptions.systemPromptPrefix != "" { - systemMessage = g.providerOptions.systemPromptPrefix + "\n" + systemMessage + systemMessage := g.systemMessage + if g.systemPromptPrefix != "" { + systemMessage = g.systemPromptPrefix + "\n" + systemMessage } history := geminiMessages[:len(geminiMessages)-1] // All but last message 
lastMsg := geminiMessages[len(geminiMessages)-1] @@ -260,34 +255,31 @@ func (g *geminiClient) send(ctx context.Context, messages []message.Message, too } } -func (g *geminiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent { +func (g *geminiProvider) Stream(ctx context.Context, model string, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent { + messages = g.cleanMessages(messages) + return g.stream(ctx, model, messages, tools) +} + +func (g *geminiProvider) stream(ctx context.Context, modelID string, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent { // Convert messages geminiMessages := g.convertMessages(messages) - model := g.providerOptions.model(g.providerOptions.modelType) - cfg := config.Get() - if cfg.Options.Debug { + model := g.Model(modelID) + if g.debug { jsonData, _ := json.Marshal(geminiMessages) slog.Debug("Prepared messages", "messages", string(jsonData)) } - modelConfig := cfg.Models[config.SelectedModelTypeLarge] - if g.providerOptions.modelType == config.SelectedModelTypeSmall { - modelConfig = cfg.Models[config.SelectedModelTypeSmall] - } maxTokens := model.DefaultMaxTokens - if modelConfig.MaxTokens > 0 { - maxTokens = modelConfig.MaxTokens + if g.maxTokens > 0 { + maxTokens = g.maxTokens } - // Override max tokens if set in provider options - if g.providerOptions.maxTokens > 0 { - maxTokens = g.providerOptions.maxTokens - } - systemMessage := g.providerOptions.systemMessage - if g.providerOptions.systemPromptPrefix != "" { - systemMessage = g.providerOptions.systemPromptPrefix + "\n" + systemMessage + systemMessage := g.systemMessage + if g.systemPromptPrefix != "" { + systemMessage = g.systemPromptPrefix + "\n" + systemMessage } + history := geminiMessages[:len(geminiMessages)-1] // All but last message lastMsg := geminiMessages[len(geminiMessages)-1] config := &genai.GenerateContentConfig{ @@ -412,7 +404,7 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t return eventChan } -func (g *geminiClient) shouldRetry(attempts int, err error) (bool, int64, error) { +func (g *geminiProvider) shouldRetry(attempts int, err error) (bool, int64, error) { // Check if error is a rate limit error if attempts > maxRetries { return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries) @@ -429,11 +421,11 @@ func (g *geminiClient) shouldRetry(attempts int, err error) (bool, int64, error) // Check for token expiration (401 Unauthorized) if contains(errMsg, "unauthorized", "invalid api key", "api key expired") { - g.providerOptions.apiKey, err = config.Get().Resolve(g.providerOptions.config.APIKey) + g.apiKey, err = g.resolver.ResolveValue(g.config.APIKey) if err != nil { return false, 0, fmt.Errorf("failed to resolve API key: %w", err) } - g.client, err = createGeminiClient(g.providerOptions) + g.client, err = createGeminiClient(g.baseProvider) if err != nil { return false, 0, fmt.Errorf("failed to create Gemini client after API key refresh: %w", err) } @@ -454,7 +446,7 @@ func (g *geminiClient) shouldRetry(attempts int, err error) (bool, int64, error) return true, int64(retryMs), nil } -func (g *geminiClient) usage(resp *genai.GenerateContentResponse) TokenUsage { +func (g *geminiProvider) usage(resp *genai.GenerateContentResponse) TokenUsage { if resp == nil || resp.UsageMetadata == nil { return TokenUsage{} } @@ -467,10 +459,6 @@ func (g *geminiClient) usage(resp *genai.GenerateContentResponse) TokenUsage { } 
} -func (g *geminiClient) Model() catwalk.Model { - return g.providerOptions.model(g.providerOptions.modelType) -} - // Helper functions func parseJSONToMap(jsonStr string) (map[string]any, error) { var result map[string]any diff --git a/internal/llm/provider/openai.go b/internal/llm/provider/openai.go index 23e247830a48ba1860ba7bde5059da69fab6d3ac..7e8d0701fbe9f1551d9d6b59dc56dc1a5a947dc1 100644 --- a/internal/llm/provider/openai.go +++ b/internal/llm/provider/openai.go @@ -10,7 +10,6 @@ import ( "time" "github.com/charmbracelet/catwalk/pkg/catwalk" - "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/llm/tools" "github.com/charmbracelet/crush/internal/message" "github.com/openai/openai-go" @@ -18,30 +17,25 @@ import ( "github.com/openai/openai-go/shared" ) -type openaiClient struct { - providerOptions providerClientOptions - client openai.Client +type openaiProvider struct { + *baseProvider + client openai.Client } -type OpenAIClient ProviderClient - -func newOpenAIClient(opts providerClientOptions) OpenAIClient { - return &openaiClient{ - providerOptions: opts, - client: createOpenAIClient(opts), +func NewOpenAIProvider(base *baseProvider) Provider { + return &openaiProvider{ + baseProvider: base, + client: createOpenAIClient(base), } } -func createOpenAIClient(opts providerClientOptions) openai.Client { +func createOpenAIClient(opts *baseProvider) openai.Client { openaiClientOptions := []option.RequestOption{} if opts.apiKey != "" { openaiClientOptions = append(openaiClientOptions, option.WithAPIKey(opts.apiKey)) } if opts.baseURL != "" { - resolvedBaseURL, err := config.Get().Resolve(opts.baseURL) - if err == nil { - openaiClientOptions = append(openaiClientOptions, option.WithBaseURL(resolvedBaseURL)) - } + openaiClientOptions = append(openaiClientOptions, option.WithBaseURL(opts.baseURL)) } for key, value := range opts.extraHeaders { @@ -55,11 +49,11 @@ func createOpenAIClient(opts providerClientOptions) openai.Client { return openai.NewClient(openaiClientOptions...) 
} -func (o *openaiClient) convertMessages(messages []message.Message) (openaiMessages []openai.ChatCompletionMessageParamUnion) { +func (o *openaiProvider) convertMessages(messages []message.Message) (openaiMessages []openai.ChatCompletionMessageParamUnion) { // Add system message first - systemMessage := o.providerOptions.systemMessage - if o.providerOptions.systemPromptPrefix != "" { - systemMessage = o.providerOptions.systemPromptPrefix + "\n" + systemMessage + systemMessage := o.systemMessage + if o.systemPromptPrefix != "" { + systemMessage = o.systemPromptPrefix + "\n" + systemMessage } openaiMessages = append(openaiMessages, openai.SystemMessage(systemMessage)) @@ -126,7 +120,7 @@ func (o *openaiClient) convertMessages(messages []message.Message) (openaiMessag return } -func (o *openaiClient) convertTools(tools []tools.BaseTool) []openai.ChatCompletionToolParam { +func (o *openaiProvider) convertTools(tools []tools.BaseTool) []openai.ChatCompletionToolParam { openaiTools := make([]openai.ChatCompletionToolParam, len(tools)) for i, tool := range tools { @@ -147,7 +141,7 @@ func (o *openaiClient) convertTools(tools []tools.BaseTool) []openai.ChatComplet return openaiTools } -func (o *openaiClient) finishReason(reason string) message.FinishReason { +func (o *openaiProvider) finishReason(reason string) message.FinishReason { switch reason { case "stop": return message.FinishReasonEndTurn @@ -160,17 +154,14 @@ func (o *openaiClient) finishReason(reason string) message.FinishReason { } } -func (o *openaiClient) preparedParams(messages []openai.ChatCompletionMessageParamUnion, tools []openai.ChatCompletionToolParam) openai.ChatCompletionNewParams { - model := o.providerOptions.model(o.providerOptions.modelType) - cfg := config.Get() +func (o *openaiProvider) preparedParams(modelID string, messages []openai.ChatCompletionMessageParamUnion, tools []openai.ChatCompletionToolParam) openai.ChatCompletionNewParams { + model := o.Model(modelID) - modelConfig := cfg.Models[config.SelectedModelTypeLarge] - if o.providerOptions.modelType == config.SelectedModelTypeSmall { - modelConfig = cfg.Models[config.SelectedModelTypeSmall] + reasoningEffort := o.reasoningEffort + if reasoningEffort == "" { + reasoningEffort = model.DefaultReasoningEffort } - reasoningEffort := modelConfig.ReasoningEffort - params := openai.ChatCompletionNewParams{ Model: openai.ChatModel(model.ID), Messages: messages, @@ -178,14 +169,10 @@ func (o *openaiClient) preparedParams(messages []openai.ChatCompletionMessagePar } maxTokens := model.DefaultMaxTokens - if modelConfig.MaxTokens > 0 { - maxTokens = modelConfig.MaxTokens + if o.maxTokens > 0 { + maxTokens = o.maxTokens } - // Override max tokens if set in provider options - if o.providerOptions.maxTokens > 0 { - maxTokens = o.providerOptions.maxTokens - } if model.CanReason { params.MaxCompletionTokens = openai.Int(maxTokens) switch reasoningEffort { @@ -205,10 +192,14 @@ func (o *openaiClient) preparedParams(messages []openai.ChatCompletionMessagePar return params } -func (o *openaiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (response *ProviderResponse, err error) { - params := o.preparedParams(o.convertMessages(messages), o.convertTools(tools)) - cfg := config.Get() - if cfg.Options.Debug { +func (o *openaiProvider) Send(ctx context.Context, model string, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) { + messages = o.cleanMessages(messages) + return o.send(ctx, model, messages, tools) +} + +func (o 
*openaiProvider) send(ctx context.Context, model string, messages []message.Message, tools []tools.BaseTool) (response *ProviderResponse, err error) { + params := o.preparedParams(model, o.convertMessages(messages), o.convertTools(tools)) + if o.debug { jsonData, _ := json.Marshal(params) slog.Debug("Prepared messages", "messages", string(jsonData)) } @@ -262,14 +253,18 @@ func (o *openaiClient) send(ctx context.Context, messages []message.Message, too } } -func (o *openaiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent { - params := o.preparedParams(o.convertMessages(messages), o.convertTools(tools)) +func (o *openaiProvider) Stream(ctx context.Context, model string, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent { + messages = o.cleanMessages(messages) + return o.stream(ctx, model, messages, tools) +} + +func (o *openaiProvider) stream(ctx context.Context, model string, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent { + params := o.preparedParams(model, o.convertMessages(messages), o.convertTools(tools)) params.StreamOptions = openai.ChatCompletionStreamOptionsParam{ IncludeUsage: openai.Bool(true), } - cfg := config.Get() - if cfg.Options.Debug { + if o.debug { jsonData, _ := json.Marshal(params) slog.Debug("Prepared messages", "messages", string(jsonData)) } @@ -350,7 +345,7 @@ func (o *openaiClient) stream(ctx context.Context, messages []message.Message, t err := openaiStream.Err() if err == nil || errors.Is(err, io.EOF) { - if cfg.Options.Debug { + if o.debug { jsonData, _ := json.Marshal(acc.ChatCompletion) slog.Debug("Response", "messages", string(jsonData)) } @@ -421,7 +416,7 @@ func (o *openaiClient) stream(ctx context.Context, messages []message.Message, t return eventChan } -func (o *openaiClient) shouldRetry(attempts int, err error) (bool, int64, error) { +func (o *openaiProvider) shouldRetry(attempts int, err error) (bool, int64, error) { var apiErr *openai.Error if !errors.As(err, &apiErr) { return false, 0, err @@ -433,11 +428,11 @@ func (o *openaiClient) shouldRetry(attempts int, err error) (bool, int64, error) // Check for token expiration (401 Unauthorized) if apiErr.StatusCode == 401 { - o.providerOptions.apiKey, err = config.Get().Resolve(o.providerOptions.config.APIKey) + o.apiKey, err = o.resolver.ResolveValue(o.config.APIKey) if err != nil { return false, 0, fmt.Errorf("failed to resolve API key: %w", err) } - o.client = createOpenAIClient(o.providerOptions) + o.client = createOpenAIClient(o.baseProvider) return true, 0, nil } @@ -459,7 +454,7 @@ func (o *openaiClient) shouldRetry(attempts int, err error) (bool, int64, error) return true, int64(retryMs), nil } -func (o *openaiClient) toolCalls(completion openai.ChatCompletion) []message.ToolCall { +func (o *openaiProvider) toolCalls(completion openai.ChatCompletion) []message.ToolCall { var toolCalls []message.ToolCall if len(completion.Choices) > 0 && len(completion.Choices[0].Message.ToolCalls) > 0 { @@ -478,7 +473,7 @@ func (o *openaiClient) toolCalls(completion openai.ChatCompletion) []message.Too return toolCalls } -func (o *openaiClient) usage(completion openai.ChatCompletion) TokenUsage { +func (o *openaiProvider) usage(completion openai.ChatCompletion) TokenUsage { cachedTokens := completion.Usage.PromptTokensDetails.CachedTokens inputTokens := completion.Usage.PromptTokens - cachedTokens @@ -489,7 +484,3 @@ func (o *openaiClient) usage(completion openai.ChatCompletion) TokenUsage { 
CacheReadTokens: cachedTokens, } } - -func (o *openaiClient) Model() catwalk.Model { - return o.providerOptions.model(o.providerOptions.modelType) -} diff --git a/internal/llm/provider/openai_test.go b/internal/llm/provider/openai_test.go deleted file mode 100644 index ef79803c8a8aa1ee3fe6cb7de8bc8fa86f26c03c..0000000000000000000000000000000000000000 --- a/internal/llm/provider/openai_test.go +++ /dev/null @@ -1,90 +0,0 @@ -package provider - -import ( - "context" - "encoding/json" - "net/http" - "net/http/httptest" - "os" - "testing" - "time" - - "github.com/charmbracelet/catwalk/pkg/catwalk" - "github.com/charmbracelet/crush/internal/config" - "github.com/charmbracelet/crush/internal/message" - "github.com/openai/openai-go" - "github.com/openai/openai-go/option" -) - -func TestMain(m *testing.M) { - _, err := config.Init(".", true) - if err != nil { - panic("Failed to initialize config: " + err.Error()) - } - - os.Exit(m.Run()) -} - -func TestOpenAIClientStreamChoices(t *testing.T) { - // Create a mock server that returns Server-Sent Events with empty choices - // This simulates the 🤡 behavior when a server returns 200 instead of 404 - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "text/event-stream") - w.Header().Set("Cache-Control", "no-cache") - w.Header().Set("Connection", "keep-alive") - w.WriteHeader(http.StatusOK) - - emptyChoicesChunk := map[string]any{ - "id": "chat-completion-test", - "object": "chat.completion.chunk", - "created": time.Now().Unix(), - "model": "test-model", - "choices": []any{}, // Empty choices array that causes panic - } - - jsonData, _ := json.Marshal(emptyChoicesChunk) - w.Write([]byte("data: " + string(jsonData) + "\n\n")) - w.Write([]byte("data: [DONE]\n\n")) - })) - defer server.Close() - - // Create OpenAI client pointing to our mock server - client := &openaiClient{ - providerOptions: providerClientOptions{ - modelType: config.SelectedModelTypeLarge, - apiKey: "test-key", - systemMessage: "test", - model: func(config.SelectedModelType) catwalk.Model { - return catwalk.Model{ - ID: "test-model", - Name: "test-model", - } - }, - }, - client: openai.NewClient( - option.WithAPIKey("test-key"), - option.WithBaseURL(server.URL), - ), - } - - // Create test messages - messages := []message.Message{ - { - Role: message.User, - Parts: []message.ContentPart{message.TextContent{Text: "Hello"}}, - }, - } - - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - - eventsChan := client.stream(ctx, messages, nil) - - // Collect events - this will panic without the bounds check - for event := range eventsChan { - t.Logf("Received event: %+v", event) - if event.Type == EventError || event.Type == EventComplete { - break - } - } -} diff --git a/internal/llm/provider/provider.go b/internal/llm/provider/provider.go index c236c10f0b0e9bf9b4db50544ca664291ef13b65..49afca4aba855cdc09fec803af43ae9c73876011 100644 --- a/internal/llm/provider/provider.go +++ b/internal/llm/provider/provider.go @@ -3,11 +3,13 @@ package provider import ( "context" "fmt" + "net/http" + "time" "github.com/charmbracelet/catwalk/pkg/catwalk" - "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/llm/tools" "github.com/charmbracelet/crush/internal/message" + "github.com/charmbracelet/crush/internal/resolver" ) type EventType string @@ -52,159 +54,233 @@ type ProviderEvent struct { ToolCall *message.ToolCall Error error } + +type Config struct { + 
// The provider's id. + ID string `json:"id,omitempty"` + // The provider's name, used for display purposes. + Name string `json:"name,omitempty"` + // The provider's API endpoint. + BaseURL string `json:"base_url,omitempty"` + // The provider type, e.g. "openai", "anthropic", etc. if empty it defaults to openai. + Type catwalk.Type `json:"type,omitempty"` + // The provider's API key. + APIKey string `json:"api_key,omitempty"` + // Marks the provider as disabled. + Disable bool `json:"disable,omitempty"` + + // Custom system prompt prefix. + SystemPromptPrefix string `json:"system_prompt_prefix,omitempty"` + + // Extra headers to send with each request to the provider. + ExtraHeaders map[string]string `json:"extra_headers,omitempty"` + // Extra body + ExtraBody map[string]any `json:"extra_body,omitempty"` + + // Used to pass extra parameters to the provider. + ExtraParams map[string]string `json:"-"` + + // The provider models + Models []catwalk.Model `json:"models,omitempty"` +} + type Provider interface { - SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) + Send(ctx context.Context, model string, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) + + Stream(ctx context.Context, model string, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent - StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent + Model(modelID string) *catwalk.Model - Model() catwalk.Model + SetDebug(debug bool) } -type providerClientOptions struct { +type baseProvider struct { baseURL string - config config.ProviderConfig + debug bool + config Config apiKey string - modelType config.SelectedModelType - model func(config.SelectedModelType) catwalk.Model disableCache bool systemMessage string systemPromptPrefix string maxTokens int64 + think bool + reasoningEffort string + resolver resolver.Resolver extraHeaders map[string]string extraBody map[string]any extraParams map[string]string } -type ProviderClientOption func(*providerClientOptions) - -type ProviderClient interface { - send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) - stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent - - Model() catwalk.Model -} - -type baseProvider[C ProviderClient] struct { - options providerClientOptions - client C -} +type Option func(*baseProvider) -func (p *baseProvider[C]) cleanMessages(messages []message.Message) (cleaned []message.Message) { - for _, msg := range messages { - // The message has no content - if len(msg.Parts) == 0 { - continue - } - cleaned = append(cleaned, msg) +func WithDisableCache(disableCache bool) Option { + return func(options *baseProvider) { + options.disableCache = disableCache } - return } -func (p *baseProvider[C]) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) { - messages = p.cleanMessages(messages) - return p.client.send(ctx, messages, tools) +func WithSystemMessage(systemMessage string) Option { + return func(options *baseProvider) { + options.systemMessage = systemMessage + } } -func (p *baseProvider[C]) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent { - messages = p.cleanMessages(messages) - return p.client.stream(ctx, messages, tools) +func WithMaxTokens(maxTokens int64) Option { + return func(options 
*baseProvider) { + options.maxTokens = maxTokens + } } -func (p *baseProvider[C]) Model() catwalk.Model { - return p.client.Model() +func WithThinking(think bool) Option { + return func(options *baseProvider) { + options.think = think + } } -func WithModel(model config.SelectedModelType) ProviderClientOption { - return func(options *providerClientOptions) { - options.modelType = model +func WithReasoningEffort(reasoningEffort string) Option { + return func(options *baseProvider) { + options.reasoningEffort = reasoningEffort } } -func WithDisableCache(disableCache bool) ProviderClientOption { - return func(options *providerClientOptions) { - options.disableCache = disableCache +func WithDebug(debug bool) Option { + return func(options *baseProvider) { + options.debug = debug } } -func WithSystemMessage(systemMessage string) ProviderClientOption { - return func(options *providerClientOptions) { - options.systemMessage = systemMessage +func WithResolver(resolver resolver.Resolver) Option { + return func(options *baseProvider) { + options.resolver = resolver } } -func WithMaxTokens(maxTokens int64) ProviderClientOption { - return func(options *providerClientOptions) { - options.maxTokens = maxTokens +func newBaseProvider(cfg Config, opts ...Option) (*baseProvider, error) { + provider := &baseProvider{ + baseURL: cfg.BaseURL, + config: cfg, + apiKey: cfg.APIKey, + extraHeaders: cfg.ExtraHeaders, + extraBody: cfg.ExtraBody, + systemPromptPrefix: cfg.SystemPromptPrefix, + resolver: resolver.New(), + } + for _, o := range opts { + o(provider) } -} -func NewProvider(cfg config.ProviderConfig, opts ...ProviderClientOption) (Provider, error) { - resolvedAPIKey, err := config.Get().Resolve(cfg.APIKey) + resolvedAPIKey, err := provider.resolver.ResolveValue(cfg.APIKey) if err != nil { return nil, fmt.Errorf("failed to resolve API key for provider %s: %w", cfg.ID, err) } + resolvedBaseURL, err := provider.resolver.ResolveValue(cfg.BaseURL) + if err != nil { + resolvedBaseURL = "" + } // Resolve extra headers resolvedExtraHeaders := make(map[string]string) for key, value := range cfg.ExtraHeaders { - resolvedValue, err := config.Get().Resolve(value) + resolvedValue, err := provider.resolver.ResolveValue(value) if err != nil { return nil, fmt.Errorf("failed to resolve extra header %s for provider %s: %w", key, cfg.ID, err) } resolvedExtraHeaders[key] = resolvedValue } - clientOptions := providerClientOptions{ - baseURL: cfg.BaseURL, - config: cfg, - apiKey: resolvedAPIKey, - extraHeaders: resolvedExtraHeaders, - extraBody: cfg.ExtraBody, - systemPromptPrefix: cfg.SystemPromptPrefix, - model: func(tp config.SelectedModelType) catwalk.Model { - return *config.Get().GetModelByType(tp) - }, - } - for _, o := range opts { - o(&clientOptions) + provider.apiKey = resolvedAPIKey + provider.baseURL = resolvedBaseURL + provider.extraHeaders = resolvedExtraHeaders + return provider, nil +} + +func NewProvider(cfg Config, opts ...Option) (Provider, error) { + base, err := newBaseProvider(cfg, opts...) 
+ if err != nil { + return nil, err } switch cfg.Type { case catwalk.TypeAnthropic: - return &baseProvider[AnthropicClient]{ - options: clientOptions, - client: newAnthropicClient(clientOptions, false), - }, nil + return NewAnthropicProvider(base, false), nil case catwalk.TypeOpenAI: - return &baseProvider[OpenAIClient]{ - options: clientOptions, - client: newOpenAIClient(clientOptions), - }, nil + return NewOpenAIProvider(base), nil case catwalk.TypeGemini: - return &baseProvider[GeminiClient]{ - options: clientOptions, - client: newGeminiClient(clientOptions), - }, nil + return NewGeminiProvider(base), nil case catwalk.TypeBedrock: - return &baseProvider[BedrockClient]{ - options: clientOptions, - client: newBedrockClient(clientOptions), - }, nil + return NewBedrockProvider(base), nil case catwalk.TypeAzure: - return &baseProvider[AzureClient]{ - options: clientOptions, - client: newAzureClient(clientOptions), - }, nil + return NewAzureProvider(base), nil case catwalk.TypeVertexAI: - return &baseProvider[VertexAIClient]{ - options: clientOptions, - client: newVertexAIClient(clientOptions), - }, nil - case catwalk.TypeXAI: - clientOptions.baseURL = "https://api.x.ai/v1" - return &baseProvider[OpenAIClient]{ - options: clientOptions, - client: newOpenAIClient(clientOptions), - }, nil + return NewVertexAIProvider(base), nil } return nil, fmt.Errorf("provider not supported: %s", cfg.Type) } + +func (p *baseProvider) cleanMessages(messages []message.Message) (cleaned []message.Message) { + for _, msg := range messages { + // The message has no content + if len(msg.Parts) == 0 { + continue + } + cleaned = append(cleaned, msg) + } + return +} + +func (o *baseProvider) Model(model string) *catwalk.Model { + for _, m := range o.config.Models { + if m.ID == model { + return &m + } + } + return nil +} + +func (o *baseProvider) SetDebug(debug bool) { + o.debug = debug +} + +func (c *Config) TestConnection(resolver resolver.Resolver) error { + testURL := "" + headers := make(map[string]string) + apiKey, _ := resolver.ResolveValue(c.APIKey) + switch c.Type { + case catwalk.TypeOpenAI: + baseURL, _ := resolver.ResolveValue(c.BaseURL) + if baseURL == "" { + baseURL = "https://api.openai.com/v1" + } + testURL = baseURL + "/models" + headers["Authorization"] = "Bearer " + apiKey + case catwalk.TypeAnthropic: + baseURL, _ := resolver.ResolveValue(c.BaseURL) + if baseURL == "" { + baseURL = "https://api.anthropic.com/v1" + } + testURL = baseURL + "/models" + headers["x-api-key"] = apiKey + headers["anthropic-version"] = "2023-06-01" + } + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + client := &http.Client{} + req, err := http.NewRequestWithContext(ctx, "GET", testURL, nil) + if err != nil { + return fmt.Errorf("failed to create request for provider %s: %w", c.ID, err) + } + for k, v := range headers { + req.Header.Set(k, v) + } + for k, v := range c.ExtraHeaders { + req.Header.Set(k, v) + } + b, err := client.Do(req) + if err != nil { + return fmt.Errorf("failed to create request for provider %s: %w", c.ID, err) + } + if b.StatusCode != http.StatusOK { + return fmt.Errorf("failed to connect to provider %s: %s", c.ID, b.Status) + } + _ = b.Body.Close() + return nil +} diff --git a/internal/llm/provider/vertexai.go b/internal/llm/provider/vertexai.go index 1baa08927dcfacd40e3dc3a9909311b7be452826..bdf103b81cf6efa74fcc9d0e0c8ad54fa1e988a0 100644 --- a/internal/llm/provider/vertexai.go +++ b/internal/llm/provider/vertexai.go @@ -7,11 +7,9 @@ import ( 
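
To make the reworked provider API above concrete, here is a minimal usage sketch: build a provider.Config, optionally verify it with TestConnection, construct a Provider through NewProvider with functional options, and pass the model ID explicitly on every Send/Stream call. This is illustrative only: the packages are internal to this module, OPENAI_API_KEY being set is an assumption, and the model ID and messages are placeholders rather than anything prescribed by the patch.

package main

import (
	"context"
	"fmt"

	"github.com/charmbracelet/catwalk/pkg/catwalk"
	"github.com/charmbracelet/crush/internal/llm/provider"
	"github.com/charmbracelet/crush/internal/message"
	"github.com/charmbracelet/crush/internal/resolver"
)

func main() {
	// A provider.Config as introduced above; the API key goes through the
	// resolver, so "$OPENAI_API_KEY" is read from the environment (assumed set).
	cfg := provider.Config{
		ID:      "openai",
		Name:    "OpenAI",
		Type:    catwalk.TypeOpenAI,
		BaseURL: "https://api.openai.com/v1",
		APIKey:  "$OPENAI_API_KEY",
		Models:  []catwalk.Model{{ID: "gpt-4o", Name: "GPT-4o"}},
	}

	// Optional connectivity check before building the provider.
	if err := cfg.TestConnection(resolver.New()); err != nil {
		fmt.Println("connection test failed:", err)
		return
	}

	p, err := provider.NewProvider(cfg,
		provider.WithSystemMessage("You are a helpful coding assistant."),
		provider.WithMaxTokens(4096),
		provider.WithDebug(true),
	)
	if err != nil {
		fmt.Println("failed to build provider:", err)
		return
	}

	// The model ID is now an argument on every call rather than being fixed
	// on the client at construction time.
	msgs := []message.Message{{
		Role:  message.User,
		Parts: []message.ContentPart{message.TextContent{Text: "Hello"}},
	}}
	resp, err := p.Send(context.Background(), "gpt-4o", msgs, nil)
	if err != nil {
		fmt.Println("send failed:", err)
		return
	}
	fmt.Printf("%+v\n", resp)
}
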
"google.golang.org/genai" ) -type VertexAIClient ProviderClient - -func newVertexAIClient(opts providerClientOptions) VertexAIClient { - project := opts.extraHeaders["project"] - location := opts.extraHeaders["location"] +func NewVertexAIProvider(base *baseProvider) Provider { + project := base.extraHeaders["project"] + location := base.extraHeaders["location"] client, err := genai.NewClient(context.Background(), &genai.ClientConfig{ Project: project, Location: location, @@ -22,8 +20,8 @@ func newVertexAIClient(opts providerClientOptions) VertexAIClient { return nil } - return &geminiClient{ - providerOptions: opts, - client: client, + return &geminiProvider{ + baseProvider: base, + client: client, } } diff --git a/internal/llm/tools/tools.go b/internal/llm/tools/tools.go index 41c0515616032b117f3c09a0056cac9e86b62c66..1791cb7c4bdd1091d34b451f3f4c845fe46ba310 100644 --- a/internal/llm/tools/tools.go +++ b/internal/llm/tools/tools.go @@ -3,6 +3,9 @@ package tools import ( "context" "encoding/json" + "slices" + + "github.com/charmbracelet/crush/internal/csync" ) type ToolInfo struct { @@ -83,3 +86,51 @@ func GetContextValues(ctx context.Context) (string, string) { } return sessionID.(string), messageID.(string) } + +type Registry interface { + GetTool(name string) (BaseTool, bool) + SetTool(name string, tool BaseTool) + GetAllTools() []BaseTool +} + +type registry struct { + tools *csync.LazySlice[BaseTool] +} + +func (r *registry) GetAllTools() []BaseTool { + return slices.Collect(r.tools.Seq()) +} + +func (r *registry) GetTool(name string) (BaseTool, bool) { + for tool := range r.tools.Seq() { + if tool.Name() == name { + return tool, true + } + } + + return nil, false +} + +func (r *registry) SetTool(name string, tool BaseTool) { + for k, tool := range r.tools.Seq2() { + if tool.Name() == name { + r.tools.Set(k, tool) + return + } + } + r.tools.Append(tool) +} + +type LazyToolsFn func() []BaseTool + +func NewRegistry(lazyTools LazyToolsFn) Registry { + return ®istry{ + tools: csync.NewLazySlice(lazyTools), + } +} + +func NewRegistryFromTools(tools []BaseTool) Registry { + return ®istry{ + tools: csync.NewLazySlice(func() []BaseTool { return tools }), + } +} diff --git a/internal/resolver/resolver.go b/internal/resolver/resolver.go new file mode 100644 index 0000000000000000000000000000000000000000..09bf0a43eebb937181675d5236d7bd0079496aae --- /dev/null +++ b/internal/resolver/resolver.go @@ -0,0 +1,188 @@ +package resolver + +import ( + "context" + "fmt" + "strings" + "time" + + "github.com/charmbracelet/crush/internal/env" + "github.com/charmbracelet/crush/internal/shell" +) + +type Resolver interface { + ResolveValue(value string) (string, error) +} + +type Shell interface { + Exec(ctx context.Context, command string) (stdout, stderr string, err error) +} + +type shellVariableResolver struct { + shell Shell + env env.Env +} + +func NewShellVariableResolver(env env.Env) Resolver { + return &shellVariableResolver{ + env: env, + shell: shell.NewShell( + &shell.Options{ + Env: env.Env(), + }, + ), + } +} + +// ResolveValue is a method for resolving values, such as environment variables. 
+// it will resolve shell-like variable substitution anywhere in the string, including: +// - $(command) for command substitution +// - $VAR or ${VAR} for environment variables +func (r *shellVariableResolver) ResolveValue(value string) (string, error) { + // Special case: lone $ is an error (backward compatibility) + if value == "$" { + return "", fmt.Errorf("invalid value format: %s", value) + } + + // If no $ found, return as-is + if !strings.Contains(value, "$") { + return value, nil + } + + result := value + + // Handle command substitution: $(command) + for { + start := strings.Index(result, "$(") + if start == -1 { + break + } + + // Find matching closing parenthesis + depth := 0 + end := -1 + for i := start + 2; i < len(result); i++ { + if result[i] == '(' { + depth++ + } else if result[i] == ')' { + if depth == 0 { + end = i + break + } + depth-- + } + } + + if end == -1 { + return "", fmt.Errorf("unmatched $( in value: %s", value) + } + + command := result[start+2 : end] + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) + + stdout, _, err := r.shell.Exec(ctx, command) + cancel() + if err != nil { + return "", fmt.Errorf("command execution failed for '%s': %w", command, err) + } + + // Replace the $(command) with the output + replacement := strings.TrimSpace(stdout) + result = result[:start] + replacement + result[end+1:] + } + + // Handle environment variables: $VAR and ${VAR} + searchStart := 0 + for { + start := strings.Index(result[searchStart:], "$") + if start == -1 { + break + } + start += searchStart // Adjust for the offset + + // Skip if this is part of $( which we already handled + if start+1 < len(result) && result[start+1] == '(' { + // Skip past this $(...) + searchStart = start + 1 + continue + } + var varName string + var end int + + if start+1 < len(result) && result[start+1] == '{' { + // Handle ${VAR} format + closeIdx := strings.Index(result[start+2:], "}") + if closeIdx == -1 { + return "", fmt.Errorf("unmatched ${ in value: %s", value) + } + varName = result[start+2 : start+2+closeIdx] + end = start + 2 + closeIdx + 1 + } else { + // Handle $VAR format - variable names must start with letter or underscore + if start+1 >= len(result) { + return "", fmt.Errorf("incomplete variable reference at end of string: %s", value) + } + + if result[start+1] != '_' && + (result[start+1] < 'a' || result[start+1] > 'z') && + (result[start+1] < 'A' || result[start+1] > 'Z') { + return "", fmt.Errorf("invalid variable name starting with '%c' in: %s", result[start+1], value) + } + + end = start + 1 + for end < len(result) && (result[end] == '_' || + (result[end] >= 'a' && result[end] <= 'z') || + (result[end] >= 'A' && result[end] <= 'Z') || + (result[end] >= '0' && result[end] <= '9')) { + end++ + } + varName = result[start+1 : end] + } + + envValue := r.env.Get(varName) + if envValue == "" { + return "", fmt.Errorf("environment variable %q not set", varName) + } + + result = result[:start] + envValue + result[end:] + searchStart = start + len(envValue) // Continue searching after the replacement + } + + return result, nil +} + +type environmentVariableResolver struct { + env env.Env +} + +func NewEnvironmentVariableResolver(env env.Env) Resolver { + return &environmentVariableResolver{ + env: env, + } +} + +// ResolveValue resolves environment variables from the provided env.Env. 
+func (r *environmentVariableResolver) ResolveValue(value string) (string, error) { + if !strings.HasPrefix(value, "$") { + return value, nil + } + + varName := strings.TrimPrefix(value, "$") + resolvedValue := r.env.Get(varName) + if resolvedValue == "" { + return "", fmt.Errorf("environment variable %q not set", varName) + } + return resolvedValue, nil +} + +func New() Resolver { + env := env.New() + return &shellVariableResolver{ + env: env, + shell: shell.NewShell( + &shell.Options{ + Env: env.Env(), + }, + ), + } +} diff --git a/internal/resolver/resolver_test.go b/internal/resolver/resolver_test.go new file mode 100644 index 0000000000000000000000000000000000000000..1a95683ec7908551de17cd2dc2f35031cb4c7150 --- /dev/null +++ b/internal/resolver/resolver_test.go @@ -0,0 +1,332 @@ +package resolver + +import ( + "context" + "errors" + "testing" + + "github.com/charmbracelet/crush/internal/env" + "github.com/stretchr/testify/assert" +) + +// mockShell implements the Shell interface for testing +type mockShell struct { + execFunc func(ctx context.Context, command string) (stdout, stderr string, err error) +} + +func (m *mockShell) Exec(ctx context.Context, command string) (stdout, stderr string, err error) { + if m.execFunc != nil { + return m.execFunc(ctx, command) + } + return "", "", nil +} + +func TestShellVariableResolver_ResolveValue(t *testing.T) { + tests := []struct { + name string + value string + envVars map[string]string + shellFunc func(ctx context.Context, command string) (stdout, stderr string, err error) + expected string + expectError bool + }{ + { + name: "non-variable string returns as-is", + value: "plain-string", + expected: "plain-string", + }, + { + name: "environment variable resolution", + value: "$HOME", + envVars: map[string]string{"HOME": "/home/user"}, + expected: "/home/user", + }, + { + name: "missing environment variable returns error", + value: "$MISSING_VAR", + envVars: map[string]string{}, + expectError: true, + }, + + { + name: "shell command with whitespace trimming", + value: "$(echo ' spaced ')", + shellFunc: func(ctx context.Context, command string) (stdout, stderr string, err error) { + if command == "echo ' spaced '" { + return " spaced \n", "", nil + } + return "", "", errors.New("unexpected command") + }, + expected: "spaced", + }, + { + name: "shell command execution error", + value: "$(false)", + shellFunc: func(ctx context.Context, command string) (stdout, stderr string, err error) { + return "", "", errors.New("command failed") + }, + expectError: true, + }, + { + name: "invalid format returns error", + value: "$", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + testEnv := env.NewFromMap(tt.envVars) + resolver := &shellVariableResolver{ + shell: &mockShell{execFunc: tt.shellFunc}, + env: testEnv, + } + + result, err := resolver.ResolveValue(tt.value) + + if tt.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.expected, result) + } + }) + } +} + +func TestShellVariableResolver_EnhancedResolveValue(t *testing.T) { + tests := []struct { + name string + value string + envVars map[string]string + shellFunc func(ctx context.Context, command string) (stdout, stderr string, err error) + expected string + expectError bool + }{ + { + name: "command substitution within string", + value: "Bearer $(echo token123)", + shellFunc: func(ctx context.Context, command string) (stdout, stderr string, err error) { + if command == "echo token123" { + return 
"token123\n", "", nil + } + return "", "", errors.New("unexpected command") + }, + expected: "Bearer token123", + }, + { + name: "environment variable within string", + value: "Bearer $TOKEN", + envVars: map[string]string{"TOKEN": "sk-ant-123"}, + expected: "Bearer sk-ant-123", + }, + { + name: "environment variable with braces within string", + value: "Bearer ${TOKEN}", + envVars: map[string]string{"TOKEN": "sk-ant-456"}, + expected: "Bearer sk-ant-456", + }, + { + name: "mixed command and environment substitution", + value: "$USER-$(date +%Y)-$HOST", + envVars: map[string]string{ + "USER": "testuser", + "HOST": "localhost", + }, + shellFunc: func(ctx context.Context, command string) (stdout, stderr string, err error) { + if command == "date +%Y" { + return "2024\n", "", nil + } + return "", "", errors.New("unexpected command") + }, + expected: "testuser-2024-localhost", + }, + { + name: "multiple command substitutions", + value: "$(echo hello) $(echo world)", + shellFunc: func(ctx context.Context, command string) (stdout, stderr string, err error) { + switch command { + case "echo hello": + return "hello\n", "", nil + case "echo world": + return "world\n", "", nil + } + return "", "", errors.New("unexpected command") + }, + expected: "hello world", + }, + { + name: "nested parentheses in command", + value: "$(echo $(echo inner))", + shellFunc: func(ctx context.Context, command string) (stdout, stderr string, err error) { + if command == "echo $(echo inner)" { + return "nested\n", "", nil + } + return "", "", errors.New("unexpected command") + }, + expected: "nested", + }, + { + name: "lone dollar with non-variable chars", + value: "prefix$123suffix", // Numbers can't start variable names + expectError: true, + }, + { + name: "dollar with special chars", + value: "a$@b$#c", // Special chars aren't valid in variable names + expectError: true, + }, + { + name: "empty environment variable substitution", + value: "Bearer $EMPTY_VAR", + envVars: map[string]string{}, + expectError: true, + }, + { + name: "unmatched command substitution opening", + value: "Bearer $(echo test", + expectError: true, + }, + { + name: "unmatched environment variable braces", + value: "Bearer ${TOKEN", + expectError: true, + }, + { + name: "command substitution with error", + value: "Bearer $(false)", + shellFunc: func(ctx context.Context, command string) (stdout, stderr string, err error) { + return "", "", errors.New("command failed") + }, + expectError: true, + }, + { + name: "complex real-world example", + value: "Bearer $(cat /tmp/token.txt | base64 -w 0)", + shellFunc: func(ctx context.Context, command string) (stdout, stderr string, err error) { + if command == "cat /tmp/token.txt | base64 -w 0" { + return "c2stYW50LXRlc3Q=\n", "", nil + } + return "", "", errors.New("unexpected command") + }, + expected: "Bearer c2stYW50LXRlc3Q=", + }, + { + name: "environment variable with underscores and numbers", + value: "Bearer $API_KEY_V2", + envVars: map[string]string{"API_KEY_V2": "sk-test-123"}, + expected: "Bearer sk-test-123", + }, + { + name: "no substitution needed", + value: "Bearer sk-ant-static-token", + expected: "Bearer sk-ant-static-token", + }, + { + name: "incomplete variable at end", + value: "Bearer $", + expectError: true, + }, + { + name: "variable with invalid character", + value: "Bearer $VAR-NAME", // Hyphen not allowed in variable names + expectError: true, + }, + { + name: "multiple invalid variables", + value: "$1$2$3", + expectError: true, + }, + } + + for _, tt := range tests { + 
t.Run(tt.name, func(t *testing.T) { + testEnv := env.NewFromMap(tt.envVars) + resolver := &shellVariableResolver{ + shell: &mockShell{execFunc: tt.shellFunc}, + env: testEnv, + } + + result, err := resolver.ResolveValue(tt.value) + + if tt.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.expected, result) + } + }) + } +} + +func TestEnvironmentVariableResolver_ResolveValue(t *testing.T) { + tests := []struct { + name string + value string + envVars map[string]string + expected string + expectError bool + }{ + { + name: "non-variable string returns as-is", + value: "plain-string", + expected: "plain-string", + }, + { + name: "environment variable resolution", + value: "$HOME", + envVars: map[string]string{"HOME": "/home/user"}, + expected: "/home/user", + }, + { + name: "environment variable with complex value", + value: "$PATH", + envVars: map[string]string{"PATH": "/usr/bin:/bin:/usr/local/bin"}, + expected: "/usr/bin:/bin:/usr/local/bin", + }, + { + name: "missing environment variable returns error", + value: "$MISSING_VAR", + envVars: map[string]string{}, + expectError: true, + }, + { + name: "empty environment variable returns error", + value: "$EMPTY_VAR", + envVars: map[string]string{"EMPTY_VAR": ""}, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + testEnv := env.NewFromMap(tt.envVars) + resolver := NewEnvironmentVariableResolver(testEnv) + + result, err := resolver.ResolveValue(tt.value) + + if tt.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.expected, result) + } + }) + } +} + +func TestNewShellVariableResolver(t *testing.T) { + testEnv := env.NewFromMap(map[string]string{"TEST": "value"}) + resolver := NewShellVariableResolver(testEnv) + + assert.NotNil(t, resolver) + assert.Implements(t, (*Resolver)(nil), resolver) +} + +func TestNewEnvironmentVariableResolver(t *testing.T) { + testEnv := env.NewFromMap(map[string]string{"TEST": "value"}) + resolver := NewEnvironmentVariableResolver(testEnv) + + assert.NotNil(t, resolver) + assert.Implements(t, (*Resolver)(nil), resolver) +} diff --git a/internal/tui/components/chat/chat.go b/internal/tui/components/chat/chat.go index d994c1ffd608ba42eeabb01a510c6a04fe67a2df..fb4ee6f501fb1e53e2a49519368e3c0dac3ef164 100644 --- a/internal/tui/components/chat/chat.go +++ b/internal/tui/components/chat/chat.go @@ -351,6 +351,7 @@ func (m *messageListCmp) updateAssistantMessageContent(msg message.Message, assi messages.NewAssistantSection( msg, time.Unix(m.lastUserMessageTime, 0), + m.app.Config(), ), ) } @@ -472,7 +473,7 @@ func (m *messageListCmp) convertMessagesToUI(sessionMessages []message.Message, case message.Assistant: uiMessages = append(uiMessages, m.convertAssistantMessage(msg, toolResultMap)...) 
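
To tie the resolver behavior exercised above together, a small usage sketch: plain strings pass through untouched, $VAR and ${VAR} are read from the environment, and $(command) is run through the shell with its trimmed stdout spliced back into the string. The API_TOKEN variable and the hostname command are illustrative stand-ins.

package main

import (
	"fmt"

	"github.com/charmbracelet/crush/internal/resolver"
)

func main() {
	r := resolver.New() // shell-backed resolver over the current environment

	// Environment variable substitution; this errors if API_TOKEN is unset.
	header, err := r.ResolveValue("Bearer ${API_TOKEN}")
	if err != nil {
		fmt.Println("resolve failed:", err)
		return
	}
	fmt.Println(header)

	// Command substitution: the command runs in a shell and its trimmed
	// stdout replaces the $(...) span.
	label, err := r.ResolveValue("host-$(hostname)")
	if err != nil {
		fmt.Println("resolve failed:", err)
		return
	}
	fmt.Println(label)
}
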
if msg.FinishPart() != nil && msg.FinishPart().Reason == message.FinishReasonEndTurn { - uiMessages = append(uiMessages, messages.NewAssistantSection(msg, time.Unix(m.lastUserMessageTime, 0))) + uiMessages = append(uiMessages, messages.NewAssistantSection(msg, time.Unix(m.lastUserMessageTime, 0), m.app.Config())) } } } diff --git a/internal/tui/components/chat/header/header.go b/internal/tui/components/chat/header/header.go index 4eac0c2444321a59c06d2e83d328fd1ea9e8512c..7ae0838882448138366fdda67fa1ab72fcbfacd0 100644 --- a/internal/tui/components/chat/header/header.go +++ b/internal/tui/components/chat/header/header.go @@ -5,9 +5,8 @@ import ( "strings" tea "github.com/charmbracelet/bubbletea/v2" - "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/app" "github.com/charmbracelet/crush/internal/fsext" - "github.com/charmbracelet/crush/internal/lsp" "github.com/charmbracelet/crush/internal/lsp/protocol" "github.com/charmbracelet/crush/internal/pubsub" "github.com/charmbracelet/crush/internal/session" @@ -27,14 +26,14 @@ type Header interface { type header struct { width int session session.Session - lspClients map[string]*lsp.Client + app *app.App detailsOpen bool } -func New(lspClients map[string]*lsp.Client) Header { +func New(app *app.App) Header { return &header{ - lspClients: lspClients, - width: 0, + app: app, + width: 0, } } @@ -88,13 +87,13 @@ func (p *header) View() string { func (h *header) details() string { t := styles.CurrentTheme() - cwd := fsext.DirTrim(fsext.PrettyPath(config.Get().WorkingDir()), 4) + cwd := fsext.DirTrim(fsext.PrettyPath(h.app.Config().WorkingDir()), 4) parts := []string{ t.S().Muted.Render(cwd), } errorCount := 0 - for _, l := range h.lspClients { + for _, l := range h.app.LSPClients { for _, diagnostics := range l.GetDiagnostics() { for _, diagnostic := range diagnostics { if diagnostic.Severity == protocol.SeverityError { @@ -108,8 +107,7 @@ func (h *header) details() string { parts = append(parts, t.S().Error.Render(fmt.Sprintf("%s%d", styles.ErrorIcon, errorCount))) } - agentCfg := config.Get().Agents["coder"] - model := config.Get().GetModelByType(agentCfg.Model) + model := h.app.CoderAgent.Model() percentage := (float64(h.session.CompletionTokens+h.session.PromptTokens) / float64(model.ContextWindow)) * 100 formattedPercentage := t.S().Muted.Render(fmt.Sprintf("%d%%", int(percentage))) parts = append(parts, formattedPercentage) diff --git a/internal/tui/components/chat/messages/messages.go b/internal/tui/components/chat/messages/messages.go index cb1ea90cf34a3cd3b206ce6ef019feea9bc240f9..d1847cdafbeee0b0df6ee3f89b4c0810b976a33f 100644 --- a/internal/tui/components/chat/messages/messages.go +++ b/internal/tui/components/chat/messages/messages.go @@ -352,6 +352,7 @@ type AssistantSection interface { } type assistantSectionModel struct { width int + config *config.Config id string message message.Message lastUserMessageTime time.Time @@ -362,9 +363,10 @@ func (m *assistantSectionModel) ID() string { return m.id } -func NewAssistantSection(message message.Message, lastUserMessageTime time.Time) AssistantSection { +func NewAssistantSection(message message.Message, lastUserMessageTime time.Time, cfg *config.Config) AssistantSection { return &assistantSectionModel{ width: 0, + config: cfg, id: uuid.NewString(), message: message, lastUserMessageTime: lastUserMessageTime, @@ -386,7 +388,7 @@ func (m *assistantSectionModel) View() string { duration := finishTime.Sub(m.lastUserMessageTime) infoMsg := 
t.S().Subtle.Render(duration.String()) icon := t.S().Subtle.Render(styles.ModelIcon) - model := config.Get().GetModel(m.message.Provider, m.message.Model) + model := m.config.GetModel(m.message.Provider, m.message.Model) if model == nil { // This means the model is not configured anymore model = &catwalk.Model{ diff --git a/internal/tui/components/chat/sidebar/sidebar.go b/internal/tui/components/chat/sidebar/sidebar.go index 1f5fd2a672e3d643efbed4ca35b08ed88c55d2eb..0a4221ea8bf550e6c8132e86d1c4634147d0b7eb 100644 --- a/internal/tui/components/chat/sidebar/sidebar.go +++ b/internal/tui/components/chat/sidebar/sidebar.go @@ -10,12 +10,11 @@ import ( tea "github.com/charmbracelet/bubbletea/v2" "github.com/charmbracelet/catwalk/pkg/catwalk" - "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/app" "github.com/charmbracelet/crush/internal/csync" "github.com/charmbracelet/crush/internal/diff" "github.com/charmbracelet/crush/internal/fsext" "github.com/charmbracelet/crush/internal/history" - "github.com/charmbracelet/crush/internal/lsp" "github.com/charmbracelet/crush/internal/lsp/protocol" "github.com/charmbracelet/crush/internal/pubsub" "github.com/charmbracelet/crush/internal/session" @@ -69,16 +68,14 @@ type sidebarCmp struct { session session.Session logo string cwd string - lspClients map[string]*lsp.Client compactMode bool - history history.Service files *csync.Map[string, SessionFile] + app *app.App } -func New(history history.Service, lspClients map[string]*lsp.Client, compact bool) Sidebar { +func New(app *app.App, compact bool) Sidebar { return &sidebarCmp{ - lspClients: lspClients, - history: history, + app: app, compactMode: compact, files: csync.NewMap[string, SessionFile](), } @@ -194,7 +191,7 @@ func (m *sidebarCmp) handleFileHistoryEvent(event pubsub.Event[history.File]) te before := existing.History.initialVersion.Content after := existing.History.latestVersion.Content path := existing.History.initialVersion.Path - cwd := config.Get().WorkingDir() + cwd := m.app.Config().WorkingDir() path = strings.TrimPrefix(path, cwd) _, additions, deletions := diff.GenerateDiff(before, after, path) existing.Additions = additions @@ -221,7 +218,7 @@ func (m *sidebarCmp) handleFileHistoryEvent(event pubsub.Event[history.File]) te } func (m *sidebarCmp) loadSessionFiles() tea.Msg { - files, err := m.history.ListBySession(context.Background(), m.session.ID) + files, err := m.app.History.ListBySession(context.Background(), m.session.ID) if err != nil { return util.InfoMsg{ Type: util.InfoTypeError, @@ -247,7 +244,7 @@ func (m *sidebarCmp) loadSessionFiles() tea.Msg { sessionFiles := make([]SessionFile, 0, len(fileMap)) for path, fh := range fileMap { - cwd := config.Get().WorkingDir() + cwd := m.app.Config().WorkingDir() path = strings.TrimPrefix(path, cwd) _, additions, deletions := diff.GenerateDiff(fh.initialVersion.Content, fh.latestVersion.Content, path) sessionFiles = append(sessionFiles, SessionFile{ @@ -265,7 +262,7 @@ func (m *sidebarCmp) loadSessionFiles() tea.Msg { func (m *sidebarCmp) SetSize(width, height int) tea.Cmd { m.logo = m.logoBlock() - m.cwd = cwd() + m.cwd = cwd(m.app.Config().WorkingDir()) m.width = width m.height = height return nil @@ -428,7 +425,7 @@ func (m *sidebarCmp) filesBlockCompact(maxWidth int) string { } extraContent := strings.Join(statusParts, " ") - cwd := config.Get().WorkingDir() + string(os.PathSeparator) + cwd := m.app.Config().WorkingDir() + string(os.PathSeparator) filePath := file.FilePath filePath = 
strings.TrimPrefix(filePath, cwd) filePath = fsext.DirTrim(fsext.PrettyPath(filePath), 2) @@ -471,7 +468,7 @@ func (m *sidebarCmp) lspBlockCompact(maxWidth int) string { lspList := []string{section, ""} - lsp := config.Get().LSP.Sorted() + lsp := m.app.Config().LSP.Sorted() if len(lsp) == 0 { content := lipgloss.JoinVertical( lipgloss.Left, @@ -505,7 +502,7 @@ func (m *sidebarCmp) lspBlockCompact(maxWidth int) string { protocol.SeverityHint: 0, protocol.SeverityInformation: 0, } - if client, ok := m.lspClients[l.Name]; ok { + if client, ok := m.app.LSPClients[l.Name]; ok { for _, diagnostics := range client.GetDiagnostics() { for _, diagnostic := range diagnostics { if severity, ok := lspErrs[diagnostic.Severity]; ok { @@ -559,7 +556,7 @@ func (m *sidebarCmp) mcpBlockCompact(maxWidth int) string { mcpList := []string{section, ""} - mcps := config.Get().MCP.Sorted() + mcps := m.app.Config().MCP.Sorted() if len(mcps) == 0 { content := lipgloss.JoinVertical( lipgloss.Left, @@ -653,7 +650,7 @@ func (m *sidebarCmp) filesBlock() string { } extraContent := strings.Join(statusParts, " ") - cwd := config.Get().WorkingDir() + string(os.PathSeparator) + cwd := m.app.Config().WorkingDir() + string(os.PathSeparator) filePath := file.FilePath filePath = strings.TrimPrefix(filePath, cwd) filePath = fsext.DirTrim(fsext.PrettyPath(filePath), 2) @@ -701,7 +698,7 @@ func (m *sidebarCmp) lspBlock() string { lspList := []string{section, ""} - lsp := config.Get().LSP.Sorted() + lsp := m.app.Config().LSP.Sorted() if len(lsp) == 0 { return lipgloss.JoinVertical( lipgloss.Left, @@ -729,7 +726,7 @@ func (m *sidebarCmp) lspBlock() string { protocol.SeverityHint: 0, protocol.SeverityInformation: 0, } - if client, ok := m.lspClients[l.Name]; ok { + if client, ok := m.app.LSPClients[l.Name]; ok { for _, diagnostics := range client.GetDiagnostics() { for _, diagnostic := range diagnostics { if severity, ok := lspErrs[diagnostic.Severity]; ok { @@ -789,7 +786,7 @@ func (m *sidebarCmp) mcpBlock() string { mcpList := []string{section, ""} - mcps := config.Get().MCP.Sorted() + mcps := m.app.Config().MCP.Sorted() if len(mcps) == 0 { return lipgloss.JoinVertical( lipgloss.Left, @@ -876,13 +873,9 @@ func formatTokensAndCost(tokens, contextWindow int64, cost float64) string { } func (s *sidebarCmp) currentModelBlock() string { - cfg := config.Get() - agentCfg := cfg.Agents["coder"] - - selectedModel := cfg.Models[agentCfg.Model] - - model := config.Get().GetModelByType(agentCfg.Model) - modelProvider := config.Get().GetProviderForModel(agentCfg.Model) + model := s.app.CoderAgent.Model() + selectedModel := s.app.CoderAgent.ModelConfig() + modelProvider := s.app.CoderAgent.Provider() t := styles.CurrentTheme() @@ -938,8 +931,7 @@ func (m *sidebarCmp) SetCompactMode(compact bool) { m.compactMode = compact } -func cwd() string { - cwd := config.Get().WorkingDir() +func cwd(cwd string) string { t := styles.CurrentTheme() // Replace home directory with ~, unless we're at the top level of the // home directory). 
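
The header and sidebar changes above, and the splash and dialog changes that follow, all apply the same pattern: instead of reaching for a process-wide config.Get(), each component receives its dependencies at construction time, either the whole *app.App or just the *config.Config it needs. A minimal sketch of a component written in that style; the status package and statusCmp type are made up for illustration.

package status

import "github.com/charmbracelet/crush/internal/app"

// statusCmp is a hypothetical component following the injection pattern used
// throughout this patch: no package-level config access, only what was passed in.
type statusCmp struct {
	app *app.App
}

func New(app *app.App) *statusCmp {
	return &statusCmp{app: app}
}

func (s *statusCmp) View() string {
	// Configuration is read through the injected app, never via a global.
	return s.app.Config().WorkingDir()
}
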
diff --git a/internal/tui/components/chat/splash/splash.go b/internal/tui/components/chat/splash/splash.go index 66f8f697aa4e6c51b199d1fee8667263f4608714..f17d6981dc0a7665132883efd8357b193b1c66c7 100644 --- a/internal/tui/components/chat/splash/splash.go +++ b/internal/tui/components/chat/splash/splash.go @@ -12,7 +12,9 @@ import ( tea "github.com/charmbracelet/bubbletea/v2" "github.com/charmbracelet/catwalk/pkg/catwalk" "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/llm/agent" "github.com/charmbracelet/crush/internal/llm/prompt" + "github.com/charmbracelet/crush/internal/llm/provider" "github.com/charmbracelet/crush/internal/tui/components/chat" "github.com/charmbracelet/crush/internal/tui/components/core" "github.com/charmbracelet/crush/internal/tui/components/core/layout" @@ -71,9 +73,10 @@ type splashCmp struct { selectedModel *models.ModelOption isAPIKeyValid bool apiKeyValue string + config *config.Config } -func New() Splash { +func New(cfg *config.Config) Splash { keyMap := DefaultKeyMap() listKeyMap := list.DefaultKeyMap() listKeyMap.Down.SetEnabled(false) @@ -85,12 +88,13 @@ func New() Splash { listKeyMap.DownOneItem = keyMap.Next listKeyMap.UpOneItem = keyMap.Previous - modelList := models.NewModelListComponent(listKeyMap, "Find your fave", false) + modelList := models.NewModelListComponent(cfg, listKeyMap, "Find your fave", false) apiKeyInput := models.NewAPIKeyInput() return &splashCmp{ width: 0, height: 0, + config: cfg, keyMap: keyMap, logoRendered: "", modelList: modelList, @@ -214,16 +218,16 @@ func (s *splashCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { return s, nil } - provider, err := s.getProvider(s.selectedModel.Provider.ID) - if err != nil || provider == nil { + selectedProvider, err := s.getProvider(s.selectedModel.Provider.ID) + if err != nil || selectedProvider == nil { return s, util.ReportError(fmt.Errorf("provider %s not found", s.selectedModel.Provider.ID)) } - providerConfig := config.ProviderConfig{ + providerConfig := provider.Config{ ID: string(s.selectedModel.Provider.ID), Name: s.selectedModel.Provider.Name, APIKey: s.apiKeyValue, - Type: provider.Type, - BaseURL: provider.APIEndpoint, + Type: selectedProvider.Type, + BaseURL: selectedProvider.APIEndpoint, } return s, tea.Sequence( util.CmdHandler(models.APIKeyStateChangeMsg{ @@ -231,7 +235,7 @@ func (s *splashCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { }), func() tea.Msg { start := time.Now() - err := providerConfig.TestConnection(config.Get().Resolver()) + err := providerConfig.TestConnection(s.config.Resolver()) // intentionally wait for at least 750ms to make sure the user sees the spinner elapsed := time.Since(start) if elapsed < 750*time.Millisecond { @@ -320,8 +324,7 @@ func (s *splashCmp) saveAPIKeyAndContinue(apiKey string) tea.Cmd { return util.ReportError(fmt.Errorf("no model selected")) } - cfg := config.Get() - err := cfg.SetProviderAPIKey(string(s.selectedModel.Provider.ID), apiKey) + err := s.config.SetProviderAPIKey(string(s.selectedModel.Provider.ID), apiKey) if err != nil { return util.ReportError(fmt.Errorf("failed to save API key: %w", err)) } @@ -338,7 +341,7 @@ func (s *splashCmp) saveAPIKeyAndContinue(apiKey string) tea.Cmd { func (s *splashCmp) initializeProject() tea.Cmd { s.needsProjectInit = false - if err := config.MarkProjectInitialized(); err != nil { + if err := config.MarkProjectInitialized(s.config); err != nil { return util.ReportError(err) } var cmds []tea.Cmd @@ -356,20 +359,19 @@ func (s *splashCmp) 
initializeProject() tea.Cmd { } func (s *splashCmp) setPreferredModel(selectedItem models.ModelOption) tea.Cmd { - cfg := config.Get() - model := cfg.GetModel(string(selectedItem.Provider.ID), selectedItem.Model.ID) + model := s.config.GetModel(string(selectedItem.Provider.ID), selectedItem.Model.ID) if model == nil { return util.ReportError(fmt.Errorf("model %s not found for provider %s", selectedItem.Model.ID, selectedItem.Provider.ID)) } - selectedModel := config.SelectedModel{ + selectedModel := agent.Model{ Model: selectedItem.Model.ID, Provider: string(selectedItem.Provider.ID), ReasoningEffort: model.DefaultReasoningEffort, MaxTokens: model.DefaultMaxTokens, } - err := cfg.UpdatePreferredModel(config.SelectedModelTypeLarge, selectedModel) + err := s.config.UpdatePreferredModel(config.SelectedModelTypeLarge, selectedModel) if err != nil { return util.ReportError(err) } @@ -381,33 +383,32 @@ func (s *splashCmp) setPreferredModel(selectedItem models.ModelOption) tea.Cmd { } if knownProvider == nil { // for local provider we just use the same model - err = cfg.UpdatePreferredModel(config.SelectedModelTypeSmall, selectedModel) + err = s.config.UpdatePreferredModel(config.SelectedModelTypeSmall, selectedModel) if err != nil { return util.ReportError(err) } } else { smallModel := knownProvider.DefaultSmallModelID - model := cfg.GetModel(string(selectedItem.Provider.ID), smallModel) + model := s.config.GetModel(string(selectedItem.Provider.ID), smallModel) // should never happen if model == nil { - err = cfg.UpdatePreferredModel(config.SelectedModelTypeSmall, selectedModel) + err = s.config.UpdatePreferredModel(config.SelectedModelTypeSmall, selectedModel) if err != nil { return util.ReportError(err) } return nil } - smallSelectedModel := config.SelectedModel{ + smallSelectedModel := agent.Model{ Model: smallModel, Provider: string(selectedItem.Provider.ID), ReasoningEffort: model.DefaultReasoningEffort, MaxTokens: model.DefaultMaxTokens, } - err = cfg.UpdatePreferredModel(config.SelectedModelTypeSmall, smallSelectedModel) + err = s.config.UpdatePreferredModel(config.SelectedModelTypeSmall, smallSelectedModel) if err != nil { return util.ReportError(err) } } - cfg.SetupAgents() return nil } @@ -425,8 +426,7 @@ func (s *splashCmp) getProvider(providerID catwalk.InferenceProvider) (*catwalk. } func (s *splashCmp) isProviderConfigured(providerID string) bool { - cfg := config.Get() - if _, ok := cfg.Providers.Get(providerID); ok { + if _, ok := s.config.Providers.Get(providerID); ok { return true } return false @@ -652,7 +652,7 @@ func (s *splashCmp) getMaxInfoWidth() int { } func (s *splashCmp) cwd() string { - cwd := config.Get().WorkingDir() + cwd := s.config.WorkingDir() t := styles.CurrentTheme() homeDir, err := os.UserHomeDir() if err == nil && cwd != homeDir { @@ -662,10 +662,10 @@ func (s *splashCmp) cwd() string { return t.S().Muted.Width(maxWidth).Render(cwd) } -func LSPList(maxWidth int) []string { +func LSPList(cfg *config.Config, maxWidth int) []string { t := styles.CurrentTheme() lspList := []string{} - lsp := config.Get().LSP.Sorted() + lsp := cfg.LSP.Sorted() if len(lsp) == 0 { return []string{t.S().Base.Foreground(t.Border).Render("None")} } @@ -692,7 +692,7 @@ func (s *splashCmp) lspBlock() string { t := styles.CurrentTheme() maxWidth := s.getMaxInfoWidth() / 2 section := t.S().Subtle.Render("LSPs") - lspList := append([]string{section, ""}, LSPList(maxWidth-1)...) + lspList := append([]string{section, ""}, LSPList(s.config, maxWidth-1)...) 
return t.S().Base.Width(maxWidth).PaddingRight(1).Render( lipgloss.JoinVertical( lipgloss.Left, @@ -701,10 +701,10 @@ func (s *splashCmp) lspBlock() string { ) } -func MCPList(maxWidth int) []string { +func MCPList(cfg *config.Config, maxWidth int) []string { t := styles.CurrentTheme() mcpList := []string{} - mcps := config.Get().MCP.Sorted() + mcps := cfg.MCP.Sorted() if len(mcps) == 0 { return []string{t.S().Base.Foreground(t.Border).Render("None")} } @@ -731,7 +731,7 @@ func (s *splashCmp) mcpBlock() string { t := styles.CurrentTheme() maxWidth := s.getMaxInfoWidth() / 2 section := t.S().Subtle.Render("MCPs") - mcpList := append([]string{section, ""}, MCPList(maxWidth-1)...) + mcpList := append([]string{section, ""}, MCPList(s.config, maxWidth-1)...) return t.S().Base.Width(maxWidth).PaddingRight(1).Render( lipgloss.JoinVertical( lipgloss.Left, diff --git a/internal/tui/components/dialogs/commands/commands.go b/internal/tui/components/dialogs/commands/commands.go index 50a67b77be373f987849953d0d60d9773caeb752..89cc36b9333d06fe115409d51365becb51329484 100644 --- a/internal/tui/components/dialogs/commands/commands.go +++ b/internal/tui/components/dialogs/commands/commands.go @@ -7,7 +7,7 @@ import ( "github.com/charmbracelet/catwalk/pkg/catwalk" "github.com/charmbracelet/lipgloss/v2" - "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/app" "github.com/charmbracelet/crush/internal/llm/prompt" "github.com/charmbracelet/crush/internal/tui/components/chat" "github.com/charmbracelet/crush/internal/tui/components/core" @@ -49,6 +49,7 @@ type commandDialogCmp struct { wWidth int // Width of the terminal window wHeight int // Height of the terminal window + app *app.App commandList listModel keyMap CommandsDialogKeyMap help help.Model @@ -67,7 +68,7 @@ type ( } ) -func NewCommandDialog(sessionID string) CommandsDialog { +func NewCommandDialog(app *app.App, sessionID string) CommandsDialog { keyMap := DefaultCommandsDialogKeyMap() listKeyMap := list.DefaultKeyMap() listKeyMap.Down.SetEnabled(false) @@ -89,6 +90,7 @@ func NewCommandDialog(sessionID string) CommandsDialog { help := help.New() help.Styles = t.S().Help return &commandDialogCmp{ + app: app, commandList: commandList, width: defaultWidth, keyMap: DefaultCommandsDialogKeyMap(), @@ -99,7 +101,7 @@ func NewCommandDialog(sessionID string) CommandsDialog { } func (c *commandDialogCmp) Init() tea.Cmd { - commands, err := LoadCustomCommands() + commands, err := LoadCustomCommands(c.app.Config()) if err != nil { return util.ReportError(err) } @@ -274,13 +276,12 @@ func (c *commandDialogCmp) defaultCommands() []Command { } // Only show thinking toggle for Anthropic models that can reason - cfg := config.Get() - if agentCfg, ok := cfg.Agents["coder"]; ok { - providerCfg := cfg.GetProviderForModel(agentCfg.Model) - model := cfg.GetModelByType(agentCfg.Model) + if c.app.CoderAgent != nil { + providerCfg := c.app.CoderAgent.Provider() + model := c.app.CoderAgent.Model() if providerCfg != nil && model != nil && providerCfg.Type == catwalk.TypeAnthropic && model.CanReason { - selectedModel := cfg.Models[agentCfg.Model] + selectedModel := c.app.CoderAgent.ModelConfig() status := "Enable" if selectedModel.Think { status = "Disable" diff --git a/internal/tui/components/dialogs/commands/loader.go b/internal/tui/components/dialogs/commands/loader.go index 9aee528ee48d0f23e48c417f8bee5bc0e3f381c5..89256846c62631edaa9fe6f904acc6eb0d149e10 100644 --- a/internal/tui/components/dialogs/commands/loader.go +++ 
b/internal/tui/components/dialogs/commands/loader.go @@ -29,8 +29,7 @@ type commandSource struct { prefix string } -func LoadCustomCommands() ([]Command, error) { - cfg := config.Get() +func LoadCustomCommands(cfg *config.Config) ([]Command, error) { if cfg == nil { return nil, fmt.Errorf("config not loaded") } diff --git a/internal/tui/components/dialogs/models/list.go b/internal/tui/components/dialogs/models/list.go index d68e701160e99f36d68a453f0f8095a281d584ed..c94bba1e9e65c8a3f45e975dba4b10189a1700e4 100644 --- a/internal/tui/components/dialogs/models/list.go +++ b/internal/tui/components/dialogs/models/list.go @@ -7,6 +7,7 @@ import ( tea "github.com/charmbracelet/bubbletea/v2" "github.com/charmbracelet/catwalk/pkg/catwalk" "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/llm/agent" "github.com/charmbracelet/crush/internal/tui/exp/list" "github.com/charmbracelet/crush/internal/tui/styles" "github.com/charmbracelet/crush/internal/tui/util" @@ -18,9 +19,10 @@ type ModelListComponent struct { list listModel modelType int providers []catwalk.Provider + config *config.Config } -func NewModelListComponent(keyMap list.KeyMap, inputPlaceholder string, shouldResize bool) *ModelListComponent { +func NewModelListComponent(cfg *config.Config, keyMap list.KeyMap, inputPlaceholder string, shouldResize bool) *ModelListComponent { t := styles.CurrentTheme() inputStyle := t.S().Base.PaddingLeft(1).PaddingBottom(1) options := []list.ListOption{ @@ -42,6 +44,7 @@ func NewModelListComponent(keyMap list.KeyMap, inputPlaceholder string, shouldRe return &ModelListComponent{ list: modelList, modelType: LargeModelType, + config: cfg, } } @@ -94,12 +97,11 @@ func (m *ModelListComponent) SetModelType(modelType int) tea.Cmd { // first none section selectedItemID := "" - cfg := config.Get() - var currentModel config.SelectedModel + var currentModel agent.Model if m.modelType == LargeModelType { - currentModel = cfg.Models[config.SelectedModelTypeLarge] + currentModel = m.config.Models[config.SelectedModelTypeLarge] } else { - currentModel = cfg.Models[config.SelectedModelTypeSmall] + currentModel = m.config.Models[config.SelectedModelTypeSmall] } configuredIcon := t.S().Base.Foreground(t.Success).Render(styles.CheckIcon) @@ -114,7 +116,7 @@ func (m *ModelListComponent) SetModelType(modelType int) tea.Cmd { if err != nil { return util.ReportError(err) } - for providerID, providerConfig := range cfg.Providers.Seq2() { + for providerID, providerConfig := range m.config.Providers.Seq2() { if providerConfig.Disable { continue } @@ -185,7 +187,7 @@ func (m *ModelListComponent) SetModelType(modelType int) tea.Cmd { } // Check if this provider is configured and not disabled - if providerConfig, exists := cfg.Providers.Get(string(provider.ID)); exists && providerConfig.Disable { + if providerConfig, exists := m.config.Providers.Get(string(provider.ID)); exists && providerConfig.Disable { continue } @@ -195,7 +197,7 @@ func (m *ModelListComponent) SetModelType(modelType int) tea.Cmd { } section := list.NewItemSection(name) - if _, ok := cfg.Providers.Get(string(provider.ID)); ok { + if _, ok := m.config.Providers.Get(string(provider.ID)); ok { section.SetInfo(configured) } group := list.Group[list.CompletionItem[ModelOption]]{ diff --git a/internal/tui/components/dialogs/models/models.go b/internal/tui/components/dialogs/models/models.go index e09b040a52ebf911ceefc455b0892c7c9ceba754..2cb14839b123594f7b6c915cffe5e43df8584719 100644 --- 
a/internal/tui/components/dialogs/models/models.go +++ b/internal/tui/components/dialogs/models/models.go @@ -10,6 +10,8 @@ import ( tea "github.com/charmbracelet/bubbletea/v2" "github.com/charmbracelet/catwalk/pkg/catwalk" "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/llm/agent" + "github.com/charmbracelet/crush/internal/llm/provider" "github.com/charmbracelet/crush/internal/tui/components/core" "github.com/charmbracelet/crush/internal/tui/components/dialogs" "github.com/charmbracelet/crush/internal/tui/exp/list" @@ -34,7 +36,7 @@ const ( // ModelSelectedMsg is sent when a model is selected type ModelSelectedMsg struct { - Model config.SelectedModel + Model agent.Model ModelType config.SelectedModelType } @@ -56,6 +58,7 @@ type modelDialogCmp struct { wWidth int wHeight int + config *config.Config modelList *ModelListComponent keyMap KeyMap help help.Model @@ -69,7 +72,7 @@ type modelDialogCmp struct { apiKeyValue string } -func NewModelDialogCmp() ModelDialog { +func NewModelDialogCmp(cfg *config.Config) ModelDialog { keyMap := DefaultKeyMap() listKeyMap := list.DefaultKeyMap() @@ -79,7 +82,7 @@ func NewModelDialogCmp() ModelDialog { listKeyMap.UpOneItem = keyMap.Previous t := styles.CurrentTheme() - modelList := NewModelListComponent(listKeyMap, "Choose a model for large, complex tasks", true) + modelList := NewModelListComponent(cfg, listKeyMap, "Choose a model for large, complex tasks", true) apiKeyInput := NewAPIKeyInput() apiKeyInput.SetShowTitle(false) help := help.New() @@ -91,6 +94,7 @@ func NewModelDialogCmp() ModelDialog { width: defaultWidth, keyMap: DefaultKeyMap(), help: help, + config: cfg, } } @@ -119,16 +123,16 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { if m.needsAPIKey { // Handle API key submission m.apiKeyValue = m.apiKeyInput.Value() - provider, err := m.getProvider(m.selectedModel.Provider.ID) - if err != nil || provider == nil { + selectedProvider, err := m.getProvider(m.selectedModel.Provider.ID) + if err != nil || selectedProvider == nil { return m, util.ReportError(fmt.Errorf("provider %s not found", m.selectedModel.Provider.ID)) } - providerConfig := config.ProviderConfig{ + providerConfig := provider.Config{ ID: string(m.selectedModel.Provider.ID), Name: m.selectedModel.Provider.Name, APIKey: m.apiKeyValue, - Type: provider.Type, - BaseURL: provider.APIEndpoint, + Type: selectedProvider.Type, + BaseURL: selectedProvider.APIEndpoint, } return m, tea.Sequence( util.CmdHandler(APIKeyStateChangeMsg{ @@ -136,7 +140,7 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { }), func() tea.Msg { start := time.Now() - err := providerConfig.TestConnection(config.Get().Resolver()) + err := providerConfig.TestConnection(m.config.Resolver()) // intentionally wait for at least 750ms to make sure the user sees the spinner elapsed := time.Since(start) if elapsed < 750*time.Millisecond { @@ -169,7 +173,7 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { return m, tea.Sequence( util.CmdHandler(dialogs.CloseDialogMsg{}), util.CmdHandler(ModelSelectedMsg{ - Model: config.SelectedModel{ + Model: agent.Model{ Model: selectedItem.Model.ID, Provider: string(selectedItem.Provider.ID), }, @@ -342,8 +346,7 @@ func (m *modelDialogCmp) modelTypeRadio() string { } func (m *modelDialogCmp) isProviderConfigured(providerID string) bool { - cfg := config.Get() - if _, ok := cfg.Providers.Get(providerID); ok { + if _, ok := m.config.Providers.Get(providerID); ok { return true } return 
false @@ -367,8 +370,7 @@ func (m *modelDialogCmp) saveAPIKeyAndContinue(apiKey string) tea.Cmd { return util.ReportError(fmt.Errorf("no model selected")) } - cfg := config.Get() - err := cfg.SetProviderAPIKey(string(m.selectedModel.Provider.ID), apiKey) + err := m.config.SetProviderAPIKey(string(m.selectedModel.Provider.ID), apiKey) if err != nil { return util.ReportError(fmt.Errorf("failed to save API key: %w", err)) } @@ -378,7 +380,7 @@ func (m *modelDialogCmp) saveAPIKeyAndContinue(apiKey string) tea.Cmd { return tea.Sequence( util.CmdHandler(dialogs.CloseDialogMsg{}), util.CmdHandler(ModelSelectedMsg{ - Model: config.SelectedModel{ + Model: agent.Model{ Model: selectedModel.Model.ID, Provider: string(selectedModel.Provider.ID), }, diff --git a/internal/tui/page/chat/chat.go b/internal/tui/page/chat/chat.go index 4b4495709d6359919b8525af73b6fcb1a09db330..785ea226625adee6c746dd01c209cc4c7bf20adb 100644 --- a/internal/tui/page/chat/chat.go +++ b/internal/tui/page/chat/chat.go @@ -117,29 +117,28 @@ func New(app *app.App) ChatPage { return &chatPage{ app: app, keyMap: DefaultKeyMap(), - header: header.New(app.LSPClients), - sidebar: sidebar.New(app.History, app.LSPClients, false), + header: header.New(app), + sidebar: sidebar.New(app, false), chat: chat.New(app), editor: editor.New(app), - splash: splash.New(), + splash: splash.New(app.Config()), focusedPane: PanelTypeSplash, } } func (p *chatPage) Init() tea.Cmd { - cfg := config.Get() - compact := cfg.Options.TUI.CompactMode + compact := p.app.Config().Options.TUI.CompactMode p.compact = compact p.forceCompact = compact p.sidebar.SetCompactMode(p.compact) // Set splash state based on config - if !config.HasInitialDataConfig() { + if !config.HasInitialDataConfig(p.app.Config()) { // First-time setup: show model selection p.splash.SetOnboarding(true) p.isOnboarding = true p.splashFullScreen = true - } else if b, _ := config.ProjectNeedsInitialization(); b { + } else if b, _ := config.ProjectNeedsInitialization(p.app.Config()); b { // Project needs CRUSH.md initialization p.splash.SetProjectInit(true) p.isProjectInit = true @@ -275,7 +274,7 @@ func (p *chatPage) Update(msg tea.Msg) (tea.Model, tea.Cmd) { } case splash.OnboardingCompleteMsg: p.splashFullScreen = false - if b, _ := config.ProjectNeedsInitialization(); b { + if b, _ := config.ProjectNeedsInitialization(p.app.Config()); b { p.splash.SetProjectInit(true) p.splashFullScreen = true return p, p.SetSize(p.width, p.height) @@ -296,8 +295,7 @@ func (p *chatPage) Update(msg tea.Msg) (tea.Model, tea.Cmd) { } return p, p.newSession() case key.Matches(msg, p.keyMap.AddAttachment): - agentCfg := config.Get().Agents["coder"] - model := config.Get().GetModelByType(agentCfg.Model) + model := p.app.CoderAgent.Model() if model.SupportsImages { return p, util.CmdHandler(OpenFilePickerMsg{}) } else { @@ -441,7 +439,7 @@ func (p *chatPage) View() string { func (p *chatPage) updateCompactConfig(compact bool) tea.Cmd { return func() tea.Msg { - err := config.Get().SetCompactMode(compact) + err := p.app.Config().SetCompactMode(compact) if err != nil { return util.InfoMsg{ Type: util.InfoTypeError, @@ -454,13 +452,11 @@ func (p *chatPage) updateCompactConfig(compact bool) tea.Cmd { func (p *chatPage) toggleThinking() tea.Cmd { return func() tea.Msg { - cfg := config.Get() - agentCfg := cfg.Agents["coder"] - currentModel := cfg.Models[agentCfg.Model] + currentModel := p.app.CoderAgent.ModelConfig() // Toggle the thinking mode currentModel.Think = !currentModel.Think - cfg.Models[agentCfg.Model] = 
currentModel + p.app.Config().Models[config.SelectedModelTypeLarge] = currentModel // Update the agent with the new configuration if err := p.app.UpdateAgentModel(); err != nil { diff --git a/internal/tui/tui.go b/internal/tui/tui.go index 62d38fd595d876dad2a384f155cc01d62db59cc9..702317fdccd0dca4ac80351b5fb1d1cfe9c0d533 100644 --- a/internal/tui/tui.go +++ b/internal/tui/tui.go @@ -93,7 +93,7 @@ func (a appModel) Init() tea.Cmd { func (a *appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { var cmds []tea.Cmd var cmd tea.Cmd - a.isConfigured = config.HasInitialDataConfig() + a.isConfigured = config.HasInitialDataConfig(a.app.Config()) switch msg := msg.(type) { case tea.KeyboardEnhancementsMsg: @@ -162,7 +162,7 @@ func (a *appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { case commands.SwitchModelMsg: return a, util.CmdHandler( dialogs.OpenDialogMsg{ - Model: models.NewModelDialogCmp(), + Model: models.NewModelDialogCmp(a.app.Config()), }, ) // Compact @@ -173,7 +173,7 @@ func (a *appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { // Model Switch case models.ModelSelectedMsg: - config.Get().UpdatePreferredModel(msg.ModelType, msg.Model) + a.app.Config().UpdatePreferredModel(msg.ModelType, msg.Model) // Update the agent with the new model/provider configuration if err := a.app.UpdateAgentModel(); err != nil { @@ -234,7 +234,7 @@ func (a *appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { model := a.app.CoderAgent.Model() contextWindow := model.ContextWindow tokens := session.CompletionTokens + session.PromptTokens - if (tokens >= int64(float64(contextWindow)*0.95)) && !config.Get().Options.DisableAutoSummarize { // Show compact confirmation dialog + if (tokens >= int64(float64(contextWindow)*0.95)) && !a.app.Config().Options.DisableAutoSummarize { // Show compact confirmation dialog cmds = append(cmds, util.CmdHandler(dialogs.OpenDialogMsg{ Model: compact.NewCompactDialogCmp(a.app.CoderAgent, a.selectedSessionID, false), })) @@ -244,7 +244,7 @@ func (a *appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { return a, tea.Batch(cmds...) case splash.OnboardingCompleteMsg: - a.isConfigured = config.HasInitialDataConfig() + a.isConfigured = config.HasInitialDataConfig(a.app.Config()) updated, pageCmd := a.pages[a.currentPage].Update(msg) a.pages[a.currentPage] = updated.(util.Model) cmds = append(cmds, pageCmd) @@ -348,7 +348,7 @@ func (a *appModel) handleKeyPressMsg(msg tea.KeyPressMsg) tea.Cmd { return nil } return util.CmdHandler(dialogs.OpenDialogMsg{ - Model: commands.NewCommandDialog(a.selectedSessionID), + Model: commands.NewCommandDialog(a.app, a.selectedSessionID), }) case key.Matches(msg, a.keyMap.Sessions): // if the app is not configured show no sessions diff --git a/main.go b/main.go index 072e3b35d2a2f408d8ed6a09423712b324df8b96..356228c578602aa171fb5b20e82ad1cda08c6885 100644 --- a/main.go +++ b/main.go @@ -9,7 +9,7 @@ import ( _ "github.com/joho/godotenv/autoload" // automatically load .env files - "github.com/charmbracelet/crush/internal/cmd" + "github.com/charmbracelet/crush/cmd" "github.com/charmbracelet/crush/internal/log" )
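
With the TUI changes above, model switching no longer touches a global: the model dialog (or the thinking toggle) updates the injected config and then asks the app to rebuild its agent. A condensed sketch of that flow; switchLargeModel is an illustrative helper, not part of the patch.

package example

import (
	"github.com/charmbracelet/crush/internal/app"
	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/llm/agent"
)

// switchLargeModel mirrors what the model dialog does after a selection:
// persist the preferred model on the config, then rebuild the coder agent
// from the updated small/large selections.
func switchLargeModel(a *app.App, selected agent.Model) error {
	if err := a.Config().UpdatePreferredModel(config.SelectedModelTypeLarge, selected); err != nil {
		return err
	}
	return a.UpdateAgentModel()
}
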