internal/cmd/logs.go → cmd/logs.go 🔗
Kujtim Hoxha created
cmd/logs.go | 0
cmd/root.go | 2
internal/app/app.go | 33
internal/config/config.go | 235 -------
internal/config/init.go | 31
internal/config/load.go | 28
internal/config/load_test.go | 119 ++-
internal/csync/slices.go | 26
internal/llm/agent/agent.go | 418 +++++--------
internal/llm/agent/coder.go | 52 +
internal/llm/agent/mcp.go | 97 ++
internal/llm/agent/tools.go | 56 +
internal/llm/prompt/coder.go | 27
internal/llm/prompt/prompt.go | 46
internal/llm/prompt/task.go | 4
internal/llm/provider/anthropic.go | 109 +--
internal/llm/provider/azure.go | 22
internal/llm/provider/bedrock.go | 86 --
internal/llm/provider/gemini.go | 100 +-
internal/llm/provider/openai.go | 97 +-
internal/llm/provider/openai_test.go | 90 ---
internal/llm/provider/provider.go | 268 +++++---
internal/llm/provider/vertexai.go | 14
internal/llm/tools/tools.go | 51 +
internal/resolver/resolver.go | 188 ++++++
internal/resolver/resolver_test.go | 332 +++++++++++
internal/tui/components/chat/chat.go | 3
internal/tui/components/chat/header/header.go | 18
internal/tui/components/chat/messages/messages.go | 6
internal/tui/components/chat/sidebar/sidebar.go | 48
internal/tui/components/chat/splash/splash.go | 60 +-
internal/tui/components/dialogs/commands/commands.go | 17
internal/tui/components/dialogs/commands/loader.go | 3
internal/tui/components/dialogs/models/list.go | 18
internal/tui/components/dialogs/models/models.go | 32
internal/tui/page/chat/chat.go | 26
internal/tui/tui.go | 12
main.go | 2
38 files changed, 1,573 insertions(+), 1,203 deletions(-)
@@ -69,7 +69,7 @@ to assist developers in writing, debugging, and understanding code directly from
cwd = c
}
- cfg, err := config.Init(cwd, debug)
+ cfg, err := config.Load(cwd, debug)
if err != nil {
return err
}
@@ -17,12 +17,12 @@ import (
"github.com/charmbracelet/crush/internal/format"
"github.com/charmbracelet/crush/internal/history"
"github.com/charmbracelet/crush/internal/llm/agent"
+ "github.com/charmbracelet/crush/internal/llm/provider"
"github.com/charmbracelet/crush/internal/log"
- "github.com/charmbracelet/crush/internal/pubsub"
-
"github.com/charmbracelet/crush/internal/lsp"
"github.com/charmbracelet/crush/internal/message"
"github.com/charmbracelet/crush/internal/permission"
+ "github.com/charmbracelet/crush/internal/pubsub"
"github.com/charmbracelet/crush/internal/session"
)
@@ -196,7 +196,9 @@ func (app *App) RunNonInteractive(ctx context.Context, prompt string, quiet bool
}
func (app *App) UpdateAgentModel() error {
- return app.CoderAgent.UpdateModel()
+ small := app.config.Models[config.SelectedModelTypeSmall]
+ large := app.config.Models[config.SelectedModelTypeLarge]
+ return app.CoderAgent.UpdateModels(small, large)
}
func (app *App) setupEvents() {
@@ -250,23 +252,32 @@ func setupSubscriber[T any](
}
func (app *App) InitCoderAgent() error {
- coderAgentCfg := app.config.Agents["coder"]
- if coderAgentCfg.ID == "" {
- return fmt.Errorf("coder agent configuration is missing")
- }
var err error
- app.CoderAgent, err = agent.NewAgent(
- coderAgentCfg,
- app.Permissions,
+ providers := map[string]provider.Config{}
+ maps.Insert(providers, app.config.Providers.Seq2())
+ app.CoderAgent, err = agent.NewCoderAgent(
+ app.globalCtx,
+ app.config.WorkingDir(),
+ providers,
+ app.config.Models[config.SelectedModelTypeSmall],
+ app.config.Models[config.SelectedModelTypeLarge],
+ app.config.Options.ContextPaths,
app.Sessions,
app.Messages,
- app.History,
+ app.Permissions,
app.LSPClients,
+ app.History,
+ app.config.MCP,
)
if err != nil {
slog.Error("Failed to create coder agent", "err", err)
return err
}
+ err = app.CoderAgent.WithAgentTool()
+ if err != nil {
+ slog.Error("Failed to create agent tool", "err", err)
+ return err
+ }
setupSubscriber(app.eventsCtx, app.serviceEventsWG, "coderAgent", app.CoderAgent.Subscribe, app.events)
return nil
}
@@ -1,18 +1,16 @@
package config
import (
- "context"
"fmt"
- "log/slog"
- "net/http"
"os"
"slices"
"strings"
- "time"
"github.com/charmbracelet/catwalk/pkg/catwalk"
"github.com/charmbracelet/crush/internal/csync"
- "github.com/charmbracelet/crush/internal/env"
+ "github.com/charmbracelet/crush/internal/llm/agent"
+ "github.com/charmbracelet/crush/internal/llm/provider"
+ "github.com/charmbracelet/crush/internal/resolver"
"github.com/tidwall/sjson"
)
@@ -45,73 +43,6 @@ const (
SelectedModelTypeSmall SelectedModelType = "small"
)
-type SelectedModel struct {
- // The model id as used by the provider API.
- // Required.
- Model string `json:"model"`
- // The model provider, same as the key/id used in the providers config.
- // Required.
- Provider string `json:"provider"`
-
- // Only used by models that use the openai provider and need this set.
- ReasoningEffort string `json:"reasoning_effort,omitempty"`
-
- // Overrides the default model configuration.
- MaxTokens int64 `json:"max_tokens,omitempty"`
-
- // Used by anthropic models that can reason to indicate if the model should think.
- Think bool `json:"think,omitempty"`
-}
-
-type ProviderConfig struct {
- // The provider's id.
- ID string `json:"id,omitempty"`
- // The provider's name, used for display purposes.
- Name string `json:"name,omitempty"`
- // The provider's API endpoint.
- BaseURL string `json:"base_url,omitempty"`
- // The provider type, e.g. "openai", "anthropic", etc. if empty it defaults to openai.
- Type catwalk.Type `json:"type,omitempty"`
- // The provider's API key.
- APIKey string `json:"api_key,omitempty"`
- // Marks the provider as disabled.
- Disable bool `json:"disable,omitempty"`
-
- // Custom system prompt prefix.
- SystemPromptPrefix string `json:"system_prompt_prefix,omitempty"`
-
- // Extra headers to send with each request to the provider.
- ExtraHeaders map[string]string `json:"extra_headers,omitempty"`
- // Extra body
- ExtraBody map[string]any `json:"extra_body,omitempty"`
-
- // Used to pass extra parameters to the provider.
- ExtraParams map[string]string `json:"-"`
-
- // The provider models
- Models []catwalk.Model `json:"models,omitempty"`
-}
-
-type MCPType string
-
-const (
- MCPStdio MCPType = "stdio"
- MCPSse MCPType = "sse"
- MCPHttp MCPType = "http"
-)
-
-type MCPConfig struct {
- Command string `json:"command,omitempty" `
- Env map[string]string `json:"env,omitempty"`
- Args []string `json:"args,omitempty"`
- Type MCPType `json:"type"`
- URL string `json:"url,omitempty"`
- Disabled bool `json:"disabled,omitempty"`
-
- // TODO: maybe make it possible to get the value from the env
- Headers map[string]string `json:"headers,omitempty"`
-}
-
type LSPConfig struct {
Disabled bool `json:"enabled,omitempty"`
Command string `json:"command"`
@@ -138,11 +69,11 @@ type Options struct {
DataDirectory string `json:"data_directory,omitempty"` // Relative to the cwd
}
-type MCPs map[string]MCPConfig
+type MCPs map[string]agent.MCPConfig
type MCP struct {
- Name string `json:"name"`
- MCP MCPConfig `json:"mcp"`
+ Name string `json:"name"`
+ MCP agent.MCPConfig `json:"mcp"`
}
func (m MCPs) Sorted() []MCP {
@@ -180,71 +111,13 @@ func (l LSPs) Sorted() []LSP {
return sorted
}
-func (m MCPConfig) ResolvedEnv() []string {
- resolver := NewShellVariableResolver(env.New())
- for e, v := range m.Env {
- var err error
- m.Env[e], err = resolver.ResolveValue(v)
- if err != nil {
- slog.Error("error resolving environment variable", "error", err, "variable", e, "value", v)
- continue
- }
- }
-
- env := make([]string, 0, len(m.Env))
- for k, v := range m.Env {
- env = append(env, fmt.Sprintf("%s=%s", k, v))
- }
- return env
-}
-
-func (m MCPConfig) ResolvedHeaders() map[string]string {
- resolver := NewShellVariableResolver(env.New())
- for e, v := range m.Headers {
- var err error
- m.Headers[e], err = resolver.ResolveValue(v)
- if err != nil {
- slog.Error("error resolving header variable", "error", err, "variable", e, "value", v)
- continue
- }
- }
- return m.Headers
-}
-
-type Agent struct {
- ID string `json:"id,omitempty"`
- Name string `json:"name,omitempty"`
- Description string `json:"description,omitempty"`
- // This is the id of the system prompt used by the agent
- Disabled bool `json:"disabled,omitempty"`
-
- Model SelectedModelType `json:"model"`
-
- // The available tools for the agent
- // if this is nil, all tools are available
- AllowedTools []string `json:"allowed_tools,omitempty"`
-
- // this tells us which MCPs are available for this agent
- // if this is empty all mcps are available
- // the string array is the list of tools from the AllowedMCP the agent has available
- // if the string array is nil, all tools from the AllowedMCP are available
- AllowedMCP map[string][]string `json:"allowed_mcp,omitempty"`
-
- // The list of LSPs that this agent can use
- // if this is nil, all LSPs are available
- AllowedLSP []string `json:"allowed_lsp,omitempty"`
-
- // Overrides the context paths for this agent
- ContextPaths []string `json:"context_paths,omitempty"`
-}
-
// Config holds the configuration for crush.
type Config struct {
// We currently only support large/small as values here.
- Models map[SelectedModelType]SelectedModel `json:"models,omitempty"`
+ Models map[SelectedModelType]agent.Model `json:"models,omitempty"`
// The providers that are configured
- Providers *csync.Map[string, ProviderConfig] `json:"providers,omitempty"`
+ Providers *csync.Map[string, provider.Config] `json:"providers,omitempty"`
MCP MCPs `json:"mcp,omitempty"`
@@ -256,10 +129,8 @@ type Config struct {
// Internal
workingDir string `json:"-"`
- // TODO: most likely remove this concept when I come back to it
- Agents map[string]Agent `json:"-"`
// TODO: find a better way to do this this should probably not be part of the config
- resolver VariableResolver
+ resolver resolver.Resolver
dataConfigDir string `json:"-"`
knownProviders []catwalk.Provider `json:"-"`
}
@@ -268,8 +139,8 @@ func (c *Config) WorkingDir() string {
return c.workingDir
}
-func (c *Config) EnabledProviders() []ProviderConfig {
- var enabled []ProviderConfig
+func (c *Config) EnabledProviders() []provider.Config {
+ var enabled []provider.Config
for p := range c.Providers.Seq() {
if !p.Disable {
enabled = append(enabled, p)
@@ -294,7 +165,7 @@ func (c *Config) GetModel(provider, model string) *catwalk.Model {
return nil
}
-func (c *Config) GetProviderForModel(modelType SelectedModelType) *ProviderConfig {
+func (c *Config) GetProviderForModel(modelType SelectedModelType) *provider.Config {
model, ok := c.Models[modelType]
if !ok {
return nil
@@ -344,7 +215,7 @@ func (c *Config) Resolve(key string) (string, error) {
return c.resolver.ResolveValue(key)
}
-func (c *Config) UpdatePreferredModel(modelType SelectedModelType, model SelectedModel) error {
+func (c *Config) UpdatePreferredModel(modelType SelectedModelType, model agent.Model) error {
c.Models[modelType] = model
if err := c.SetConfigField(fmt.Sprintf("models.%s", modelType), model); err != nil {
return fmt.Errorf("failed to update preferred model: %w", err)
@@ -397,7 +268,7 @@ func (c *Config) SetProviderAPIKey(providerID, apiKey string) error {
if foundProvider != nil {
// Create new provider config based on known provider
- providerConfig = ProviderConfig{
+ providerConfig = provider.Config{
ID: providerID,
Name: foundProvider.Name,
BaseURL: foundProvider.APIEndpoint,
@@ -416,82 +287,6 @@ func (c *Config) SetProviderAPIKey(providerID, apiKey string) error {
return nil
}
-func (c *Config) SetupAgents() {
- agents := map[string]Agent{
- "coder": {
- ID: "coder",
- Name: "Coder",
- Description: "An agent that helps with executing coding tasks.",
- Model: SelectedModelTypeLarge,
- ContextPaths: c.Options.ContextPaths,
- // All tools allowed
- },
- "task": {
- ID: "task",
- Name: "Task",
- Description: "An agent that helps with searching for context and finding implementation details.",
- Model: SelectedModelTypeLarge,
- ContextPaths: c.Options.ContextPaths,
- AllowedTools: []string{
- "glob",
- "grep",
- "ls",
- "sourcegraph",
- "view",
- },
- // NO MCPs or LSPs by default
- AllowedMCP: map[string][]string{},
- AllowedLSP: []string{},
- },
- }
- c.Agents = agents
-}
-
-func (c *Config) Resolver() VariableResolver {
+func (c *Config) Resolver() resolver.Resolver {
return c.resolver
}
-
-func (c *ProviderConfig) TestConnection(resolver VariableResolver) error {
- testURL := ""
- headers := make(map[string]string)
- apiKey, _ := resolver.ResolveValue(c.APIKey)
- switch c.Type {
- case catwalk.TypeOpenAI:
- baseURL, _ := resolver.ResolveValue(c.BaseURL)
- if baseURL == "" {
- baseURL = "https://api.openai.com/v1"
- }
- testURL = baseURL + "/models"
- headers["Authorization"] = "Bearer " + apiKey
- case catwalk.TypeAnthropic:
- baseURL, _ := resolver.ResolveValue(c.BaseURL)
- if baseURL == "" {
- baseURL = "https://api.anthropic.com/v1"
- }
- testURL = baseURL + "/models"
- headers["x-api-key"] = apiKey
- headers["anthropic-version"] = "2023-06-01"
- }
- ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
- defer cancel()
- client := &http.Client{}
- req, err := http.NewRequestWithContext(ctx, "GET", testURL, nil)
- if err != nil {
- return fmt.Errorf("failed to create request for provider %s: %w", c.ID, err)
- }
- for k, v := range headers {
- req.Header.Set(k, v)
- }
- for k, v := range c.ExtraHeaders {
- req.Header.Set(k, v)
- }
- b, err := client.Do(req)
- if err != nil {
- return fmt.Errorf("failed to create request for provider %s: %w", c.ID, err)
- }
- if b.StatusCode != http.StatusOK {
- return fmt.Errorf("failed to connect to provider %s: %s", c.ID, b.Status)
- }
- _ = b.Body.Close()
- return nil
-}
@@ -5,7 +5,6 @@ import (
"os"
"path/filepath"
"strings"
- "sync/atomic"
)
const (
@@ -16,25 +15,7 @@ type ProjectInitFlag struct {
Initialized bool `json:"initialized"`
}
-// TODO: we need to remove the global config instance keeping it now just until everything is migrated
-var instance atomic.Pointer[Config]
-
-func Init(workingDir string, debug bool) (*Config, error) {
- cfg, err := Load(workingDir, debug)
- if err != nil {
- return nil, err
- }
- instance.Store(cfg)
- return instance.Load(), nil
-}
-
-func Get() *Config {
- cfg := instance.Load()
- return cfg
-}
-
-func ProjectNeedsInitialization() (bool, error) {
- cfg := Get()
+func ProjectNeedsInitialization(cfg *Config) (bool, error) {
if cfg == nil {
return false, fmt.Errorf("config not loaded")
}
@@ -81,8 +62,7 @@ func crushMdExists(dir string) (bool, error) {
return false, nil
}
-func MarkProjectInitialized() error {
- cfg := Get()
+func MarkProjectInitialized(cfg *Config) error {
if cfg == nil {
return fmt.Errorf("config not loaded")
}
@@ -97,10 +77,13 @@ func MarkProjectInitialized() error {
return nil
}
-func HasInitialDataConfig() bool {
+func HasInitialDataConfig(cfg *Config) bool {
+ if cfg == nil {
+ return false
+ }
cfgPath := GlobalConfigData()
if _, err := os.Stat(cfgPath); err != nil {
return false
}
- return Get().IsConfigured()
+ return cfg.IsConfigured()
}
@@ -14,7 +14,10 @@ import (
"github.com/charmbracelet/catwalk/pkg/catwalk"
"github.com/charmbracelet/crush/internal/csync"
"github.com/charmbracelet/crush/internal/env"
+ "github.com/charmbracelet/crush/internal/llm/agent"
+ "github.com/charmbracelet/crush/internal/llm/provider"
"github.com/charmbracelet/crush/internal/log"
+ "github.com/charmbracelet/crush/internal/resolver"
)
const defaultCatwalkURL = "https://catwalk.charm.sh"
@@ -71,7 +74,7 @@ func Load(workingDir string, debug bool) (*Config, error) {
env := env.New()
// Configure providers
- valueResolver := NewShellVariableResolver(env)
+ valueResolver := resolver.NewShellVariableResolver(env)
cfg.resolver = valueResolver
if err := cfg.configureProviders(env, valueResolver, providers); err != nil {
return nil, fmt.Errorf("failed to configure providers: %w", err)
@@ -85,11 +88,10 @@ func Load(workingDir string, debug bool) (*Config, error) {
if err := cfg.configureSelectedModels(providers); err != nil {
return nil, fmt.Errorf("failed to configure selected models: %w", err)
}
- cfg.SetupAgents()
return cfg, nil
}
-func (c *Config) configureProviders(env env.Env, resolver VariableResolver, knownProviders []catwalk.Provider) error {
+func (c *Config) configureProviders(env env.Env, resolver resolver.Resolver, knownProviders []catwalk.Provider) error {
knownProviderNames := make(map[string]bool)
for _, p := range knownProviders {
knownProviderNames[string(p.ID)] = true
@@ -135,7 +137,7 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know
p.Models = models
}
}
- prepared := ProviderConfig{
+ prepared := provider.Config{
ID: string(p.ID),
Name: p.Name,
BaseURL: p.APIEndpoint,
@@ -269,13 +271,13 @@ func (c *Config) setDefaults(workingDir string) {
c.Options.DataDirectory = filepath.Join(workingDir, defaultDataDirectory)
}
if c.Providers == nil {
- c.Providers = csync.NewMap[string, ProviderConfig]()
+ c.Providers = csync.NewMap[string, provider.Config]()
}
if c.Models == nil {
- c.Models = make(map[SelectedModelType]SelectedModel)
+ c.Models = make(map[SelectedModelType]agent.Model)
}
if c.MCP == nil {
- c.MCP = make(map[string]MCPConfig)
+ c.MCP = make(map[string]agent.MCPConfig)
}
if c.LSP == nil {
c.LSP = make(map[string]LSPConfig)
@@ -287,7 +289,7 @@ func (c *Config) setDefaults(workingDir string) {
c.Options.ContextPaths = slices.Compact(c.Options.ContextPaths)
}
-func (c *Config) defaultModelSelection(knownProviders []catwalk.Provider) (largeModel SelectedModel, smallModel SelectedModel, err error) {
+func (c *Config) defaultModelSelection(knownProviders []catwalk.Provider) (largeModel, smallModel agent.Model, err error) {
if len(knownProviders) == 0 && c.Providers.Len() == 0 {
err = fmt.Errorf("no providers configured, please configure at least one provider")
return
@@ -305,7 +307,7 @@ func (c *Config) defaultModelSelection(knownProviders []catwalk.Provider) (large
err = fmt.Errorf("default large model %s not found for provider %s", p.DefaultLargeModelID, p.ID)
return
}
- largeModel = SelectedModel{
+ largeModel = agent.Model{
Provider: string(p.ID),
Model: defaultLargeModel.ID,
MaxTokens: defaultLargeModel.DefaultMaxTokens,
@@ -317,7 +319,7 @@ func (c *Config) defaultModelSelection(knownProviders []catwalk.Provider) (large
err = fmt.Errorf("default small model %s not found for provider %s", p.DefaultSmallModelID, p.ID)
return
}
- smallModel = SelectedModel{
+ smallModel = agent.Model{
Provider: string(p.ID),
Model: defaultSmallModel.ID,
MaxTokens: defaultSmallModel.DefaultMaxTokens,
@@ -327,7 +329,7 @@ func (c *Config) defaultModelSelection(knownProviders []catwalk.Provider) (large
}
enabledProviders := c.EnabledProviders()
- slices.SortFunc(enabledProviders, func(a, b ProviderConfig) int {
+ slices.SortFunc(enabledProviders, func(a, b provider.Config) int {
return strings.Compare(a.ID, b.ID)
})
@@ -342,13 +344,13 @@ func (c *Config) defaultModelSelection(knownProviders []catwalk.Provider) (large
return
}
defaultLargeModel := c.GetModel(providerConfig.ID, providerConfig.Models[0].ID)
- largeModel = SelectedModel{
+ largeModel = agent.Model{
Provider: providerConfig.ID,
Model: defaultLargeModel.ID,
MaxTokens: defaultLargeModel.DefaultMaxTokens,
}
defaultSmallModel := c.GetModel(providerConfig.ID, providerConfig.Models[0].ID)
- smallModel = SelectedModel{
+ smallModel = agent.Model{
Provider: providerConfig.ID,
Model: defaultSmallModel.ID,
MaxTokens: defaultSmallModel.DefaultMaxTokens,
@@ -11,6 +11,9 @@ import (
"github.com/charmbracelet/catwalk/pkg/catwalk"
"github.com/charmbracelet/crush/internal/csync"
"github.com/charmbracelet/crush/internal/env"
+ "github.com/charmbracelet/crush/internal/llm/agent"
+ "github.com/charmbracelet/crush/internal/llm/provider"
+ "github.com/charmbracelet/crush/internal/resolver"
"github.com/stretchr/testify/assert"
)
@@ -72,7 +75,7 @@ func TestConfig_configureProviders(t *testing.T) {
env := env.NewFromMap(map[string]string{
"OPENAI_API_KEY": "test-key",
})
- resolver := NewEnvironmentVariableResolver(env)
+ resolver := resolver.NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, knownProviders)
assert.NoError(t, err)
assert.Equal(t, 1, cfg.Providers.Len())
@@ -95,9 +98,9 @@ func TestConfig_configureProvidersWithOverride(t *testing.T) {
}
cfg := &Config{
- Providers: csync.NewMap[string, ProviderConfig](),
+ Providers: csync.NewMap[string, provider.Config](),
}
- cfg.Providers.Set("openai", ProviderConfig{
+ cfg.Providers.Set("openai", provider.Config{
APIKey: "xyz",
BaseURL: "https://api.openai.com/v2",
Models: []catwalk.Model{
@@ -115,7 +118,7 @@ func TestConfig_configureProvidersWithOverride(t *testing.T) {
env := env.NewFromMap(map[string]string{
"OPENAI_API_KEY": "test-key",
})
- resolver := NewEnvironmentVariableResolver(env)
+ resolver := resolver.NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, knownProviders)
assert.NoError(t, err)
assert.Equal(t, 1, cfg.Providers.Len())
@@ -141,7 +144,7 @@ func TestConfig_configureProvidersWithNewProvider(t *testing.T) {
}
cfg := &Config{
- Providers: csync.NewMapFrom(map[string]ProviderConfig{
+ Providers: csync.NewMapFrom(map[string]provider.Config{
"custom": {
APIKey: "xyz",
BaseURL: "https://api.someendpoint.com/v2",
@@ -157,7 +160,7 @@ func TestConfig_configureProvidersWithNewProvider(t *testing.T) {
env := env.NewFromMap(map[string]string{
"OPENAI_API_KEY": "test-key",
})
- resolver := NewEnvironmentVariableResolver(env)
+ resolver := resolver.NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, knownProviders)
assert.NoError(t, err)
// Should be to because of the env variable
@@ -193,7 +196,7 @@ func TestConfig_configureProvidersBedrockWithCredentials(t *testing.T) {
"AWS_ACCESS_KEY_ID": "test-key-id",
"AWS_SECRET_ACCESS_KEY": "test-secret-key",
})
- resolver := NewEnvironmentVariableResolver(env)
+ resolver := resolver.NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, knownProviders)
assert.NoError(t, err)
assert.Equal(t, cfg.Providers.Len(), 1)
@@ -219,7 +222,7 @@ func TestConfig_configureProvidersBedrockWithoutCredentials(t *testing.T) {
cfg := &Config{}
cfg.setDefaults("/tmp")
env := env.NewFromMap(map[string]string{})
- resolver := NewEnvironmentVariableResolver(env)
+ resolver := resolver.NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, knownProviders)
assert.NoError(t, err)
// Provider should not be configured without credentials
@@ -244,7 +247,7 @@ func TestConfig_configureProvidersBedrockWithoutUnsupportedModel(t *testing.T) {
"AWS_ACCESS_KEY_ID": "test-key-id",
"AWS_SECRET_ACCESS_KEY": "test-secret-key",
})
- resolver := NewEnvironmentVariableResolver(env)
+ resolver := resolver.NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, knownProviders)
assert.Error(t, err)
}
@@ -268,7 +271,7 @@ func TestConfig_configureProvidersVertexAIWithCredentials(t *testing.T) {
"GOOGLE_CLOUD_PROJECT": "test-project",
"GOOGLE_CLOUD_LOCATION": "us-central1",
})
- resolver := NewEnvironmentVariableResolver(env)
+ resolver := resolver.NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, knownProviders)
assert.NoError(t, err)
assert.Equal(t, cfg.Providers.Len(), 1)
@@ -300,7 +303,7 @@ func TestConfig_configureProvidersVertexAIWithoutCredentials(t *testing.T) {
"GOOGLE_CLOUD_PROJECT": "test-project",
"GOOGLE_CLOUD_LOCATION": "us-central1",
})
- resolver := NewEnvironmentVariableResolver(env)
+ resolver := resolver.NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, knownProviders)
assert.NoError(t, err)
// Provider should not be configured without proper credentials
@@ -325,7 +328,7 @@ func TestConfig_configureProvidersVertexAIMissingProject(t *testing.T) {
"GOOGLE_GENAI_USE_VERTEXAI": "true",
"GOOGLE_CLOUD_LOCATION": "us-central1",
})
- resolver := NewEnvironmentVariableResolver(env)
+ resolver := resolver.NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, knownProviders)
assert.NoError(t, err)
// Provider should not be configured without project
@@ -349,7 +352,7 @@ func TestConfig_configureProvidersSetProviderID(t *testing.T) {
env := env.NewFromMap(map[string]string{
"OPENAI_API_KEY": "test-key",
})
- resolver := NewEnvironmentVariableResolver(env)
+ resolver := resolver.NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, knownProviders)
assert.NoError(t, err)
assert.Equal(t, cfg.Providers.Len(), 1)
@@ -362,7 +365,7 @@ func TestConfig_configureProvidersSetProviderID(t *testing.T) {
func TestConfig_EnabledProviders(t *testing.T) {
t.Run("all providers enabled", func(t *testing.T) {
cfg := &Config{
- Providers: csync.NewMapFrom(map[string]ProviderConfig{
+ Providers: csync.NewMapFrom(map[string]provider.Config{
"openai": {
ID: "openai",
APIKey: "key1",
@@ -382,7 +385,7 @@ func TestConfig_EnabledProviders(t *testing.T) {
t.Run("some providers disabled", func(t *testing.T) {
cfg := &Config{
- Providers: csync.NewMapFrom(map[string]ProviderConfig{
+ Providers: csync.NewMapFrom(map[string]provider.Config{
"openai": {
ID: "openai",
APIKey: "key1",
@@ -403,7 +406,7 @@ func TestConfig_EnabledProviders(t *testing.T) {
t.Run("empty providers map", func(t *testing.T) {
cfg := &Config{
- Providers: csync.NewMap[string, ProviderConfig](),
+ Providers: csync.NewMap[string, provider.Config](),
}
enabled := cfg.EnabledProviders()
@@ -414,7 +417,7 @@ func TestConfig_EnabledProviders(t *testing.T) {
func TestConfig_IsConfigured(t *testing.T) {
t.Run("returns true when at least one provider is enabled", func(t *testing.T) {
cfg := &Config{
- Providers: csync.NewMapFrom(map[string]ProviderConfig{
+ Providers: csync.NewMapFrom(map[string]provider.Config{
"openai": {
ID: "openai",
APIKey: "key1",
@@ -428,7 +431,7 @@ func TestConfig_IsConfigured(t *testing.T) {
t.Run("returns false when no providers are configured", func(t *testing.T) {
cfg := &Config{
- Providers: csync.NewMap[string, ProviderConfig](),
+ Providers: csync.NewMap[string, provider.Config](),
}
assert.False(t, cfg.IsConfigured())
@@ -436,7 +439,7 @@ func TestConfig_IsConfigured(t *testing.T) {
t.Run("returns false when all providers are disabled", func(t *testing.T) {
cfg := &Config{
- Providers: csync.NewMapFrom(map[string]ProviderConfig{
+ Providers: csync.NewMapFrom(map[string]provider.Config{
"openai": {
ID: "openai",
APIKey: "key1",
@@ -467,7 +470,7 @@ func TestConfig_configureProvidersWithDisabledProvider(t *testing.T) {
}
cfg := &Config{
- Providers: csync.NewMapFrom(map[string]ProviderConfig{
+ Providers: csync.NewMapFrom(map[string]provider.Config{
"openai": {
Disable: true,
},
@@ -478,7 +481,7 @@ func TestConfig_configureProvidersWithDisabledProvider(t *testing.T) {
env := env.NewFromMap(map[string]string{
"OPENAI_API_KEY": "test-key",
})
- resolver := NewEnvironmentVariableResolver(env)
+ resolver := resolver.NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, knownProviders)
assert.NoError(t, err)
@@ -491,7 +494,7 @@ func TestConfig_configureProvidersWithDisabledProvider(t *testing.T) {
func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
t.Run("custom provider with missing API key is allowed, but not known providers", func(t *testing.T) {
cfg := &Config{
- Providers: csync.NewMapFrom(map[string]ProviderConfig{
+ Providers: csync.NewMapFrom(map[string]provider.Config{
"custom": {
BaseURL: "https://api.custom.com/v1",
Models: []catwalk.Model{{
@@ -506,7 +509,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
cfg.setDefaults("/tmp")
env := env.NewFromMap(map[string]string{})
- resolver := NewEnvironmentVariableResolver(env)
+ resolver := resolver.NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
assert.NoError(t, err)
@@ -517,7 +520,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
t.Run("custom provider with missing BaseURL is removed", func(t *testing.T) {
cfg := &Config{
- Providers: csync.NewMapFrom(map[string]ProviderConfig{
+ Providers: csync.NewMapFrom(map[string]provider.Config{
"custom": {
APIKey: "test-key",
Models: []catwalk.Model{{
@@ -529,7 +532,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
cfg.setDefaults("/tmp")
env := env.NewFromMap(map[string]string{})
- resolver := NewEnvironmentVariableResolver(env)
+ resolver := resolver.NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
assert.NoError(t, err)
@@ -540,7 +543,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
t.Run("custom provider with no models is removed", func(t *testing.T) {
cfg := &Config{
- Providers: csync.NewMapFrom(map[string]ProviderConfig{
+ Providers: csync.NewMapFrom(map[string]provider.Config{
"custom": {
APIKey: "test-key",
BaseURL: "https://api.custom.com/v1",
@@ -551,7 +554,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
cfg.setDefaults("/tmp")
env := env.NewFromMap(map[string]string{})
- resolver := NewEnvironmentVariableResolver(env)
+ resolver := resolver.NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
assert.NoError(t, err)
@@ -562,7 +565,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
t.Run("custom provider with unsupported type is removed", func(t *testing.T) {
cfg := &Config{
- Providers: csync.NewMapFrom(map[string]ProviderConfig{
+ Providers: csync.NewMapFrom(map[string]provider.Config{
"custom": {
APIKey: "test-key",
BaseURL: "https://api.custom.com/v1",
@@ -576,7 +579,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
cfg.setDefaults("/tmp")
env := env.NewFromMap(map[string]string{})
- resolver := NewEnvironmentVariableResolver(env)
+ resolver := resolver.NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
assert.NoError(t, err)
@@ -587,7 +590,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
t.Run("valid custom provider is kept and ID is set", func(t *testing.T) {
cfg := &Config{
- Providers: csync.NewMapFrom(map[string]ProviderConfig{
+ Providers: csync.NewMapFrom(map[string]provider.Config{
"custom": {
APIKey: "test-key",
BaseURL: "https://api.custom.com/v1",
@@ -601,7 +604,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
cfg.setDefaults("/tmp")
env := env.NewFromMap(map[string]string{})
- resolver := NewEnvironmentVariableResolver(env)
+ resolver := resolver.NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
assert.NoError(t, err)
@@ -615,7 +618,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
t.Run("custom anthropic provider is supported", func(t *testing.T) {
cfg := &Config{
- Providers: csync.NewMapFrom(map[string]ProviderConfig{
+ Providers: csync.NewMapFrom(map[string]provider.Config{
"custom-anthropic": {
APIKey: "test-key",
BaseURL: "https://api.anthropic.com/v1",
@@ -629,7 +632,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
cfg.setDefaults("/tmp")
env := env.NewFromMap(map[string]string{})
- resolver := NewEnvironmentVariableResolver(env)
+ resolver := resolver.NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
assert.NoError(t, err)
@@ -644,7 +647,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
t.Run("disabled custom provider is removed", func(t *testing.T) {
cfg := &Config{
- Providers: csync.NewMapFrom(map[string]ProviderConfig{
+ Providers: csync.NewMapFrom(map[string]provider.Config{
"custom": {
APIKey: "test-key",
BaseURL: "https://api.custom.com/v1",
@@ -659,7 +662,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
cfg.setDefaults("/tmp")
env := env.NewFromMap(map[string]string{})
- resolver := NewEnvironmentVariableResolver(env)
+ resolver := resolver.NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
assert.NoError(t, err)
@@ -683,7 +686,7 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) {
}
cfg := &Config{
- Providers: csync.NewMapFrom(map[string]ProviderConfig{
+ Providers: csync.NewMapFrom(map[string]provider.Config{
"vertexai": {
BaseURL: "custom-url",
},
@@ -694,7 +697,7 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) {
env := env.NewFromMap(map[string]string{
"GOOGLE_GENAI_USE_VERTEXAI": "false",
})
- resolver := NewEnvironmentVariableResolver(env)
+ resolver := resolver.NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, knownProviders)
assert.NoError(t, err)
@@ -716,7 +719,7 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) {
}
cfg := &Config{
- Providers: csync.NewMapFrom(map[string]ProviderConfig{
+ Providers: csync.NewMapFrom(map[string]provider.Config{
"bedrock": {
BaseURL: "custom-url",
},
@@ -725,7 +728,7 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) {
cfg.setDefaults("/tmp")
env := env.NewFromMap(map[string]string{})
- resolver := NewEnvironmentVariableResolver(env)
+ resolver := resolver.NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, knownProviders)
assert.NoError(t, err)
@@ -747,7 +750,7 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) {
}
cfg := &Config{
- Providers: csync.NewMapFrom(map[string]ProviderConfig{
+ Providers: csync.NewMapFrom(map[string]provider.Config{
"openai": {
BaseURL: "custom-url",
},
@@ -756,7 +759,7 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) {
cfg.setDefaults("/tmp")
env := env.NewFromMap(map[string]string{})
- resolver := NewEnvironmentVariableResolver(env)
+ resolver := resolver.NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, knownProviders)
assert.NoError(t, err)
@@ -778,7 +781,7 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) {
}
cfg := &Config{
- Providers: csync.NewMapFrom(map[string]ProviderConfig{
+ Providers: csync.NewMapFrom(map[string]provider.Config{
"openai": {
APIKey: "test-key",
},
@@ -789,7 +792,7 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) {
env := env.NewFromMap(map[string]string{
"OPENAI_API_KEY": "test-key",
})
- resolver := NewEnvironmentVariableResolver(env)
+ resolver := resolver.NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, knownProviders)
assert.NoError(t, err)
@@ -823,7 +826,7 @@ func TestConfig_defaultModelSelection(t *testing.T) {
cfg := &Config{}
cfg.setDefaults("/tmp")
env := env.NewFromMap(map[string]string{})
- resolver := NewEnvironmentVariableResolver(env)
+ resolver := resolver.NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, knownProviders)
assert.NoError(t, err)
@@ -859,7 +862,7 @@ func TestConfig_defaultModelSelection(t *testing.T) {
cfg := &Config{}
cfg.setDefaults("/tmp")
env := env.NewFromMap(map[string]string{})
- resolver := NewEnvironmentVariableResolver(env)
+ resolver := resolver.NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, knownProviders)
assert.NoError(t, err)
@@ -889,7 +892,7 @@ func TestConfig_defaultModelSelection(t *testing.T) {
cfg := &Config{}
cfg.setDefaults("/tmp")
env := env.NewFromMap(map[string]string{})
- resolver := NewEnvironmentVariableResolver(env)
+ resolver := resolver.NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, knownProviders)
assert.NoError(t, err)
_, _, err = cfg.defaultModelSelection(knownProviders)
@@ -917,7 +920,7 @@ func TestConfig_defaultModelSelection(t *testing.T) {
}
cfg := &Config{
- Providers: csync.NewMapFrom(map[string]ProviderConfig{
+ Providers: csync.NewMapFrom(map[string]provider.Config{
"custom": {
APIKey: "test-key",
BaseURL: "https://api.custom.com/v1",
@@ -932,7 +935,7 @@ func TestConfig_defaultModelSelection(t *testing.T) {
}
cfg.setDefaults("/tmp")
env := env.NewFromMap(map[string]string{})
- resolver := NewEnvironmentVariableResolver(env)
+ resolver := resolver.NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, knownProviders)
assert.NoError(t, err)
large, small, err := cfg.defaultModelSelection(knownProviders)
@@ -966,7 +969,7 @@ func TestConfig_defaultModelSelection(t *testing.T) {
}
cfg := &Config{
- Providers: csync.NewMapFrom(map[string]ProviderConfig{
+ Providers: csync.NewMapFrom(map[string]provider.Config{
"custom": {
APIKey: "test-key",
BaseURL: "https://api.custom.com/v1",
@@ -976,7 +979,7 @@ func TestConfig_defaultModelSelection(t *testing.T) {
}
cfg.setDefaults("/tmp")
env := env.NewFromMap(map[string]string{})
- resolver := NewEnvironmentVariableResolver(env)
+ resolver := resolver.NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, knownProviders)
assert.NoError(t, err)
_, _, err = cfg.defaultModelSelection(knownProviders)
@@ -1003,7 +1006,7 @@ func TestConfig_defaultModelSelection(t *testing.T) {
}
cfg := &Config{
- Providers: csync.NewMapFrom(map[string]ProviderConfig{
+ Providers: csync.NewMapFrom(map[string]provider.Config{
"custom": {
APIKey: "test-key",
BaseURL: "https://api.custom.com/v1",
@@ -1018,7 +1021,7 @@ func TestConfig_defaultModelSelection(t *testing.T) {
}
cfg.setDefaults("/tmp")
env := env.NewFromMap(map[string]string{})
- resolver := NewEnvironmentVariableResolver(env)
+ resolver := resolver.NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, knownProviders)
assert.NoError(t, err)
large, small, err := cfg.defaultModelSelection(knownProviders)
@@ -1058,7 +1061,7 @@ func TestConfig_configureSelectedModels(t *testing.T) {
}
cfg := &Config{
- Models: map[SelectedModelType]SelectedModel{
+ Models: map[SelectedModelType]agent.Model{
"large": {
Model: "larger-model",
},
@@ -1066,7 +1069,7 @@ func TestConfig_configureSelectedModels(t *testing.T) {
}
cfg.setDefaults("/tmp")
env := env.NewFromMap(map[string]string{})
- resolver := NewEnvironmentVariableResolver(env)
+ resolver := resolver.NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, knownProviders)
assert.NoError(t, err)
@@ -1118,7 +1121,7 @@ func TestConfig_configureSelectedModels(t *testing.T) {
}
cfg := &Config{
- Models: map[SelectedModelType]SelectedModel{
+ Models: map[SelectedModelType]agent.Model{
"small": {
Model: "a-small-model",
Provider: "anthropic",
@@ -1128,7 +1131,7 @@ func TestConfig_configureSelectedModels(t *testing.T) {
}
cfg.setDefaults("/tmp")
env := env.NewFromMap(map[string]string{})
- resolver := NewEnvironmentVariableResolver(env)
+ resolver := resolver.NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, knownProviders)
assert.NoError(t, err)
@@ -1165,7 +1168,7 @@ func TestConfig_configureSelectedModels(t *testing.T) {
}
cfg := &Config{
- Models: map[SelectedModelType]SelectedModel{
+ Models: map[SelectedModelType]agent.Model{
"large": {
MaxTokens: 100,
},
@@ -1173,7 +1176,7 @@ func TestConfig_configureSelectedModels(t *testing.T) {
}
cfg.setDefaults("/tmp")
env := env.NewFromMap(map[string]string{})
- resolver := NewEnvironmentVariableResolver(env)
+ resolver := resolver.NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, knownProviders)
assert.NoError(t, err)
@@ -36,6 +36,32 @@ func (s *LazySlice[K]) Seq() iter.Seq[K] {
}
}
+func (s *LazySlice[K]) Seq2() iter.Seq2[int, K] {
+ s.wg.Wait()
+ return func(yield func(int, K) bool) {
+ for i, v := range s.inner {
+ if !yield(i, v) {
+ return
+ }
+ }
+ }
+}
+
+func (s *LazySlice[K]) Set(index int, item K) bool {
+ s.wg.Wait()
+ if index < 0 || index >= len(s.inner) {
+ return false
+ }
+ s.inner[index] = item
+ return true
+}
+
+func (s *LazySlice[K]) Append(item K) bool {
+ s.wg.Wait()
+ s.inner = append(s.inner, item)
+ return true
+}
+
// Slice is a thread-safe slice implementation that provides concurrent access.
type Slice[T any] struct {
inner []T
@@ -5,19 +5,15 @@ import (
"errors"
"fmt"
"log/slog"
- "slices"
"strings"
"time"
"github.com/charmbracelet/catwalk/pkg/catwalk"
- "github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/crush/internal/csync"
- "github.com/charmbracelet/crush/internal/history"
"github.com/charmbracelet/crush/internal/llm/prompt"
"github.com/charmbracelet/crush/internal/llm/provider"
"github.com/charmbracelet/crush/internal/llm/tools"
"github.com/charmbracelet/crush/internal/log"
- "github.com/charmbracelet/crush/internal/lsp"
"github.com/charmbracelet/crush/internal/message"
"github.com/charmbracelet/crush/internal/permission"
"github.com/charmbracelet/crush/internal/pubsub"
@@ -52,200 +48,126 @@ type AgentEvent struct {
type Service interface {
pubsub.Suscriber[AgentEvent]
- Model() catwalk.Model
Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
Cancel(sessionID string)
CancelAll()
IsSessionBusy(sessionID string) bool
IsBusy() bool
Summarize(ctx context.Context, sessionID string) error
- UpdateModel() error
+ SetDebug(debug bool)
+ UpdateModels(large, small Model) error
+ // for now, not really sure how to handle this better
+ WithAgentTool() error
+
+ ModelConfig() Model
+ Model() *catwalk.Model
+ Provider() *provider.Config
+}
+
+type Model struct {
+ // The model id as used by the provider API.
+ // Required.
+ Model string `json:"model"`
+ // The model provider, same as the key/id used in the providers config.
+ // Required.
+ Provider string `json:"provider"`
+
+ // Only used by models that use the openai provider and need this set.
+ ReasoningEffort string `json:"reasoning_effort,omitempty"`
+
+ // Overrides the default model configuration.
+ MaxTokens int64 `json:"max_tokens,omitempty"`
+
+ // Used by anthropic models that can reason to indicate if the model should think.
+ Think bool `json:"think,omitempty"`
}
type agent struct {
*pubsub.Broker[AgentEvent]
- agentCfg config.Agent
+ ctx context.Context
+ cwd string
+ systemPrompt string
+ providers map[string]provider.Config
+
sessions session.Service
messages message.Service
- tools *csync.LazySlice[tools.BaseTool]
+ toolsRegistry tools.Registry
- provider provider.Provider
- providerID string
-
- titleProvider provider.Provider
- summarizeProvider provider.Provider
- summarizeProviderID string
+ large, small Model
+ provider provider.Provider
+ titleProvider provider.Provider
+ summarizeProvider provider.Provider
activeRequests *csync.Map[string, context.CancelFunc]
-}
-var agentPromptMap = map[string]prompt.PromptID{
- "coder": prompt.PromptCoder,
- "task": prompt.PromptTask,
+ debug bool
}
func NewAgent(
- agentCfg config.Agent,
- // These services are needed in the tools
- permissions permission.Service,
+ ctx context.Context,
+ cwd string,
+ systemPrompt string,
+ toolsRegistry tools.Registry,
+ providers map[string]provider.Config,
+
+ smallModel Model,
+ largeModel Model,
+
sessions session.Service,
messages message.Service,
- history history.Service,
- lspClients map[string]*lsp.Client,
) (Service, error) {
- ctx := context.Background()
- cfg := config.Get()
-
- var agentTool tools.BaseTool
- if agentCfg.ID == "coder" {
- taskAgentCfg := config.Get().Agents["task"]
- if taskAgentCfg.ID == "" {
- return nil, fmt.Errorf("task agent not found in config")
- }
- taskAgent, err := NewAgent(taskAgentCfg, permissions, sessions, messages, history, lspClients)
- if err != nil {
- return nil, fmt.Errorf("failed to create task agent: %w", err)
- }
-
- agentTool = NewAgentTool(taskAgent, sessions, messages)
- }
-
- providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
- if providerCfg == nil {
- return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
- }
- model := config.Get().GetModelByType(agentCfg.Model)
-
- if model == nil {
- return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
- }
+ agent := &agent{
+ Broker: pubsub.NewBroker[AgentEvent](),
+ ctx: ctx,
+ providers: providers,
+ cwd: cwd,
+ systemPrompt: systemPrompt,
+ toolsRegistry: toolsRegistry,
+ small: smallModel,
+ large: largeModel,
+ messages: messages,
+ sessions: sessions,
+ activeRequests: csync.NewMap[string, context.CancelFunc](),
+ }
+
+ err := agent.setProviders()
+ return agent, err
+}
- promptID := agentPromptMap[agentCfg.ID]
- if promptID == "" {
- promptID = prompt.PromptDefault
- }
- opts := []provider.ProviderClientOption{
- provider.WithModel(agentCfg.Model),
- provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
- }
- agentProvider, err := provider.NewProvider(*providerCfg, opts...)
- if err != nil {
- return nil, err
- }
+func (a *agent) ModelConfig() Model {
+ return a.large
+}
- smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
- var smallModelProviderCfg *config.ProviderConfig
- if smallModelCfg.Provider == providerCfg.ID {
- smallModelProviderCfg = providerCfg
- } else {
- smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
+func (a *agent) Model() *catwalk.Model {
+ return a.provider.Model(a.large.Model)
+}
- if smallModelProviderCfg.ID == "" {
- return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
+func (a *agent) Provider() *provider.Config {
+ for _, provider := range a.providers {
+ if provider.ID == a.large.Provider {
+ return &provider
}
}
- smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
- if smallModel.ID == "" {
- return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
- }
-
- titleOpts := []provider.ProviderClientOption{
- provider.WithModel(config.SelectedModelTypeSmall),
- provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
- }
- titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
- if err != nil {
- return nil, err
- }
- summarizeOpts := []provider.ProviderClientOption{
- provider.WithModel(config.SelectedModelTypeSmall),
- provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
- }
- summarizeProvider, err := provider.NewProvider(*smallModelProviderCfg, summarizeOpts...)
- if err != nil {
- return nil, err
- }
-
- toolFn := func() []tools.BaseTool {
- slog.Info("Initializing agent tools", "agent", agentCfg.ID)
- defer func() {
- slog.Info("Initialized agent tools", "agent", agentCfg.ID)
- }()
-
- cwd := cfg.WorkingDir()
- allTools := []tools.BaseTool{
- tools.NewBashTool(permissions, cwd),
- tools.NewDownloadTool(permissions, cwd),
- tools.NewEditTool(lspClients, permissions, history, cwd),
- tools.NewFetchTool(permissions, cwd),
- tools.NewGlobTool(cwd),
- tools.NewGrepTool(cwd),
- tools.NewLsTool(cwd),
- tools.NewSourcegraphTool(),
- tools.NewViewTool(lspClients, cwd),
- tools.NewWriteTool(lspClients, permissions, history, cwd),
- }
-
- mcpTools := GetMCPTools(ctx, permissions, cfg)
- allTools = append(allTools, mcpTools...)
-
- if len(lspClients) > 0 {
- allTools = append(allTools, tools.NewDiagnosticsTool(lspClients))
- }
-
- if agentTool != nil {
- allTools = append(allTools, agentTool)
- }
-
- if agentCfg.AllowedTools == nil {
- return allTools
- }
-
- var filteredTools []tools.BaseTool
- for _, tool := range allTools {
- if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
- filteredTools = append(filteredTools, tool)
- }
- }
- return filteredTools
- }
-
- return &agent{
- Broker: pubsub.NewBroker[AgentEvent](),
- agentCfg: agentCfg,
- provider: agentProvider,
- providerID: string(providerCfg.ID),
- messages: messages,
- sessions: sessions,
- titleProvider: titleProvider,
- summarizeProvider: summarizeProvider,
- summarizeProviderID: string(smallModelProviderCfg.ID),
- activeRequests: csync.NewMap[string, context.CancelFunc](),
- tools: csync.NewLazySlice(toolFn),
- }, nil
-}
-
-func (a *agent) Model() catwalk.Model {
- return *config.Get().GetModelByType(a.agentCfg.Model)
+ return nil
}
func (a *agent) Cancel(sessionID string) {
// Cancel regular requests
- if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
+ if cancel, exists := a.activeRequests.Take(sessionID); exists {
slog.Info("Request cancellation initiated", "session_id", sessionID)
cancel()
}
// Also check for summarize requests
- if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
+ if cancel, exists := a.activeRequests.Take(sessionID + "-summarize"); exists {
slog.Info("Summarize cancellation initiated", "session_id", sessionID)
cancel()
}
}
func (a *agent) IsBusy() bool {
- var busy bool
+ busy := false
for cancelFunc := range a.activeRequests.Seq() {
if cancelFunc != nil {
busy = true
@@ -275,9 +197,9 @@ func (a *agent) generateTitle(ctx context.Context, sessionID string, content str
Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
}}
- // Use streaming approach like summarization
- response := a.titleProvider.StreamResponse(
+ response := a.titleProvider.Stream(
ctx,
+ a.small.Model,
[]message.Message{
{
Role: message.User,
@@ -352,7 +274,6 @@ func (a *agent) Run(ctx context.Context, sessionID string, content string, attac
}
func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
- cfg := config.Get()
// List existing messages; if none, start title generation asynchronously.
msgs, err := a.messages.List(ctx, sessionID)
if err != nil {
@@ -411,7 +332,7 @@ func (a *agent) processGeneration(ctx context.Context, sessionID, content string
}
return a.err(fmt.Errorf("failed to process events: %w", err))
}
- if cfg.Options.Debug {
+ if a.debug {
slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
}
if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
@@ -438,13 +359,13 @@ func (a *agent) createUserMessage(ctx context.Context, sessionID, content string
func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
- eventChan := a.provider.StreamResponse(ctx, msgHistory, slices.Collect(a.tools.Seq()))
+ eventChan := a.provider.Stream(ctx, a.large.Model, msgHistory, a.toolsRegistry.GetAllTools())
assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
Role: message.Assistant,
Parts: []message.ContentPart{},
- Model: a.Model().ID,
- Provider: a.providerID,
+ Model: a.large.Model,
+ Provider: a.large.Provider,
})
if err != nil {
return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
@@ -487,7 +408,7 @@ func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msg
default:
// Continue processing
var tool tools.BaseTool
- for availableTool := range a.tools.Seq() {
+ for _, availableTool := range a.toolsRegistry.GetAllTools() {
if availableTool.Info().Name == toolCall.Name {
tool = availableTool
break
@@ -578,7 +499,8 @@ out:
msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
Role: message.Tool,
Parts: parts,
- Provider: a.providerID,
+ Model: a.large.Model,
+ Provider: a.large.Provider,
})
if err != nil {
return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
@@ -632,7 +554,11 @@ func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg
if err := a.messages.Update(ctx, *assistantMsg); err != nil {
return fmt.Errorf("failed to update message: %w", err)
}
- return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
+ model := a.Model()
+ if model == nil {
+ return nil
+ }
+ return a.TrackUsage(ctx, sessionID, *model, event.Response.Usage)
}
return nil
@@ -734,8 +660,9 @@ func (a *agent) Summarize(ctx context.Context, sessionID string) error {
a.Publish(pubsub.CreatedEvent, event)
// Send the messages to the summarize provider
- response := a.summarizeProvider.StreamResponse(
+ response := a.summarizeProvider.Stream(
summarizeCtx,
+ a.large.Model,
msgsWithPrompt,
nil,
)
@@ -763,7 +690,7 @@ func (a *agent) Summarize(ctx context.Context, sessionID string) error {
a.Publish(pubsub.CreatedEvent, event)
return
}
- shell := shell.GetPersistentShell(config.Get().WorkingDir())
+ shell := shell.GetPersistentShell(a.cwd)
summary += "\n\n**Current working directory of the persistent shell**\n\n" + shell.GetWorkingDir()
event = AgentEvent{
Type: AgentEventTypeSummarize,
@@ -792,8 +719,8 @@ func (a *agent) Summarize(ctx context.Context, sessionID string) error {
Time: time.Now().Unix(),
},
},
- Model: a.summarizeProvider.Model().ID,
- Provider: a.summarizeProviderID,
+ Model: a.large.Model,
+ Provider: a.large.Provider,
})
if err != nil {
event = AgentEvent{
@@ -808,7 +735,7 @@ func (a *agent) Summarize(ctx context.Context, sessionID string) error {
oldSession.SummaryMessageID = msg.ID
oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
oldSession.PromptTokens = 0
- model := a.summarizeProvider.Model()
+ model := a.summarizeProvider.Model(a.large.Model)
usage := finalResponse.Usage
cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
@@ -857,92 +784,101 @@ func (a *agent) CancelAll() {
}
}
-func (a *agent) UpdateModel() error {
- cfg := config.Get()
+func (a *agent) UpdateModels(small, large Model) error {
+ a.small = small
+ a.large = large
+ return a.setProviders()
+}
- // Get current provider configuration
- currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
- if currentProviderCfg == nil || currentProviderCfg.ID == "" {
- return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
+func (a *agent) SetDebug(debug bool) {
+ a.debug = debug
+ if a.provider != nil {
+ a.provider.SetDebug(debug)
}
+ if a.titleProvider != nil {
+ a.titleProvider.SetDebug(debug)
+ }
+ if a.summarizeProvider != nil {
+ a.summarizeProvider.SetDebug(debug)
+ }
+}
- // Check if provider has changed
- if string(currentProviderCfg.ID) != a.providerID {
- // Provider changed, need to recreate the main provider
- model := cfg.GetModelByType(a.agentCfg.Model)
- if model.ID == "" {
- return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
- }
-
- promptID := agentPromptMap[a.agentCfg.ID]
- if promptID == "" {
- promptID = prompt.PromptDefault
- }
-
- opts := []provider.ProviderClientOption{
- provider.WithModel(a.agentCfg.Model),
- provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
- }
-
- newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
- if err != nil {
- return fmt.Errorf("failed to create new provider: %w", err)
- }
+func (a *agent) setProviders() error {
+ opts := []provider.Option{
+ provider.WithSystemMessage(a.systemPrompt),
+ provider.WithThinking(a.large.Think),
+ }
- // Update the provider and provider ID
- a.provider = newProvider
- a.providerID = string(currentProviderCfg.ID)
+ if a.large.MaxTokens > 0 {
+ opts = append(opts, provider.WithMaxTokens(a.large.MaxTokens))
+ }
+ if a.large.ReasoningEffort != "" {
+ opts = append(opts, provider.WithReasoningEffort(a.large.ReasoningEffort))
}
- // Check if small model provider has changed (affects title and summarize providers)
- smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
- var smallModelProviderCfg config.ProviderConfig
+ providerCfg, ok := a.providers[a.large.Provider]
+ if !ok {
+ return fmt.Errorf("provider %s not found in config", a.large.Provider)
+ }
+ var err error
+ a.provider, err = provider.NewProvider(providerCfg, opts...)
+ if err != nil {
+ return fmt.Errorf("failed to create provider: %w", err)
+ }
- for p := range cfg.Providers.Seq() {
- if p.ID == smallModelCfg.Provider {
- smallModelProviderCfg = p
- break
- }
+ titleOpts := []provider.Option{
+ provider.WithSystemMessage(prompt.TitlePrompt()),
+ provider.WithMaxTokens(40),
}
- if smallModelProviderCfg.ID == "" {
- return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
+ titleProviderCfg, ok := a.providers[a.small.Provider]
+ if !ok {
+ return fmt.Errorf("small model provider %s not found in config", a.small.Provider)
}
- // Check if summarize provider has changed
- if string(smallModelProviderCfg.ID) != a.summarizeProviderID {
- smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
- if smallModel == nil {
- return fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
- }
+ a.titleProvider, err = provider.NewProvider(titleProviderCfg, titleOpts...)
+ if err != nil {
+ return err
+ }
+ summarizeOpts := []provider.Option{
+ provider.WithSystemMessage(prompt.SummarizerPrompt()),
+ }
+ a.summarizeProvider, err = provider.NewProvider(providerCfg, summarizeOpts...)
+ if err != nil {
+ return err
+ }
- // Recreate title provider
- titleOpts := []provider.ProviderClientOption{
- provider.WithModel(config.SelectedModelTypeSmall),
- provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
- // We want the title to be short, so we limit the max tokens
- provider.WithMaxTokens(40),
- }
- newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
- if err != nil {
- return fmt.Errorf("failed to create new title provider: %w", err)
- }
+ if _, ok := a.toolsRegistry.GetTool(AgentToolName); ok {
+ // reset the agent tool
+ a.WithAgentTool()
+ }
- // Recreate summarize provider
- summarizeOpts := []provider.ProviderClientOption{
- provider.WithModel(config.SelectedModelTypeSmall),
- provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
- }
- newSummarizeProvider, err := provider.NewProvider(smallModelProviderCfg, summarizeOpts...)
- if err != nil {
- return fmt.Errorf("failed to create new summarize provider: %w", err)
- }
+ a.SetDebug(a.debug)
+ return nil
+}
- // Update the providers and provider ID
- a.titleProvider = newTitleProvider
- a.summarizeProvider = newSummarizeProvider
- a.summarizeProviderID = string(smallModelProviderCfg.ID)
+func (a *agent) WithAgentTool() error {
+ agent, err := NewAgent(
+ a.ctx,
+ a.cwd,
+ prompt.TaskPrompt(a.cwd),
+ NewTaskTools(a.cwd),
+ a.providers,
+ a.small,
+ a.large,
+ a.sessions,
+ a.messages,
+ )
+ if err != nil {
+ return err
}
+ agentTool := NewAgentTool(
+ agent,
+ a.sessions,
+ a.messages,
+ )
+
+ a.toolsRegistry.SetTool(AgentToolName, agentTool)
return nil
}
@@ -0,0 +1,52 @@
+package agent
+
+import (
+ "context"
+
+ "github.com/charmbracelet/crush/internal/history"
+ "github.com/charmbracelet/crush/internal/llm/prompt"
+ "github.com/charmbracelet/crush/internal/llm/provider"
+ "github.com/charmbracelet/crush/internal/lsp"
+ "github.com/charmbracelet/crush/internal/message"
+ "github.com/charmbracelet/crush/internal/permission"
+ "github.com/charmbracelet/crush/internal/session"
+)
+
+func NewCoderAgent(
+ ctx context.Context,
+ cwd string,
+ providers map[string]provider.Config,
+ smallModel Model,
+ largeModel Model,
+ contextFiles []string,
+ sessions session.Service,
+ messages message.Service,
+ permissions permission.Service,
+ lspClients map[string]*lsp.Client,
+ history history.Service,
+ mcps map[string]MCPConfig,
+) (Service, error) {
+ systemPrompt := prompt.CoderPrompt(cwd, contextFiles...)
+ tools := NewCoderTools(
+ ctx,
+ cwd,
+ sessions,
+ messages,
+ permissions,
+ lspClients,
+ history,
+ mcps,
+ )
+
+ return NewAgent(
+ ctx,
+ cwd,
+ systemPrompt,
+ tools,
+ providers,
+ smallModel,
+ largeModel,
+ sessions,
+ messages,
+ )
+}
@@ -8,22 +8,41 @@ import (
"slices"
"sync"
- "github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/crush/internal/csync"
"github.com/charmbracelet/crush/internal/llm/tools"
+ "github.com/charmbracelet/crush/internal/resolver"
+ "github.com/charmbracelet/crush/internal/version"
"github.com/charmbracelet/crush/internal/permission"
- "github.com/charmbracelet/crush/internal/version"
"github.com/mark3labs/mcp-go/client"
"github.com/mark3labs/mcp-go/client/transport"
"github.com/mark3labs/mcp-go/mcp"
)
+type MCPType string
+
+const (
+ MCPStdio MCPType = "stdio"
+ MCPSse MCPType = "sse"
+ MCPHttp MCPType = "http"
+)
+
+type MCPConfig struct {
+ Command string `json:"command,omitempty" `
+ Env map[string]string `json:"env,omitempty"`
+ Args []string `json:"args,omitempty"`
+ Type MCPType `json:"type"`
+ URL string `json:"url,omitempty"`
+ Disabled bool `json:"disabled,omitempty"`
+
+ Headers map[string]string `json:"headers,omitempty"`
+}
+
type mcpTool struct {
mcpName string
tool mcp.Tool
- mcpConfig config.MCPConfig
+ mcpConfig MCPConfig
permissions permission.Service
workingDir string
}
@@ -60,7 +79,7 @@ func runTool(ctx context.Context, c MCPClient, toolName string, input string) (t
initRequest := mcp.InitializeRequest{}
initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
initRequest.Params.ClientInfo = mcp.Implementation{
- Name: "Crush",
+ Name: "crush",
Version: version.Version,
}
@@ -115,7 +134,7 @@ func (b *mcpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolRes
}
switch b.mcpConfig.Type {
- case config.MCPStdio:
+ case MCPStdio:
c, err := client.NewStdioMCPClient(
b.mcpConfig.Command,
b.mcpConfig.ResolvedEnv(),
@@ -125,7 +144,7 @@ func (b *mcpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolRes
return tools.NewTextErrorResponse(err.Error()), nil
}
return runTool(ctx, c, b.tool.Name, params.Input)
- case config.MCPHttp:
+ case MCPHttp:
c, err := client.NewStreamableHttpClient(
b.mcpConfig.URL,
transport.WithHTTPHeaders(b.mcpConfig.ResolvedHeaders()),
@@ -134,7 +153,7 @@ func (b *mcpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolRes
return tools.NewTextErrorResponse(err.Error()), nil
}
return runTool(ctx, c, b.tool.Name, params.Input)
- case config.MCPSse:
+ case MCPSse:
c, err := client.NewSSEMCPClient(
b.mcpConfig.URL,
client.WithHeaders(b.mcpConfig.ResolvedHeaders()),
@@ -148,23 +167,22 @@ func (b *mcpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolRes
return tools.NewTextErrorResponse("invalid mcp type"), nil
}
-func NewMcpTool(name string, tool mcp.Tool, permissions permission.Service, mcpConfig config.MCPConfig, workingDir string) tools.BaseTool {
+func NewMcpTool(name, cwd string, tool mcp.Tool, permissions permission.Service, mcpConfig MCPConfig) tools.BaseTool {
return &mcpTool{
mcpName: name,
tool: tool,
mcpConfig: mcpConfig,
permissions: permissions,
- workingDir: workingDir,
+ workingDir: cwd,
}
}
-func getTools(ctx context.Context, name string, m config.MCPConfig, permissions permission.Service, c MCPClient, workingDir string) []tools.BaseTool {
+func getTools(ctx context.Context, cwd string, name string, m MCPConfig, permissions permission.Service, c MCPClient) []tools.BaseTool {
var stdioTools []tools.BaseTool
initRequest := mcp.InitializeRequest{}
initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
initRequest.Params.ClientInfo = mcp.Implementation{
- Name: "Crush",
- Version: version.Version,
+ Name: "dreamlover",
}
_, err := c.Initialize(ctx, initRequest)
@@ -179,7 +197,7 @@ func getTools(ctx context.Context, name string, m config.MCPConfig, permissions
return stdioTools
}
for _, t := range tools.Tools {
- stdioTools = append(stdioTools, NewMcpTool(name, t, permissions, m, workingDir))
+ stdioTools = append(stdioTools, NewMcpTool(name, cwd, t, permissions, m))
}
defer c.Close()
return stdioTools
@@ -190,26 +208,26 @@ var (
mcpTools []tools.BaseTool
)
-func GetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) []tools.BaseTool {
+func GetMCPTools(ctx context.Context, cwd string, mcps map[string]MCPConfig, permissions permission.Service) []tools.BaseTool {
mcpToolsOnce.Do(func() {
- mcpTools = doGetMCPTools(ctx, permissions, cfg)
+ mcpTools = doGetMCPTools(ctx, cwd, mcps, permissions)
})
return mcpTools
}
-func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) []tools.BaseTool {
+func doGetMCPTools(ctx context.Context, cwd string, mcps map[string]MCPConfig, permissions permission.Service) []tools.BaseTool {
var wg sync.WaitGroup
result := csync.NewSlice[tools.BaseTool]()
- for name, m := range cfg.MCP {
+ for name, m := range mcps {
if m.Disabled {
slog.Debug("skipping disabled mcp", "name", name)
continue
}
wg.Add(1)
- go func(name string, m config.MCPConfig) {
+ go func(name string, m MCPConfig) {
defer wg.Done()
switch m.Type {
- case config.MCPStdio:
+ case MCPStdio:
c, err := client.NewStdioMCPClient(
m.Command,
m.ResolvedEnv(),
@@ -220,8 +238,8 @@ func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *con
return
}
- result.Append(getTools(ctx, name, m, permissions, c, cfg.WorkingDir())...)
- case config.MCPHttp:
+ result.Append(getTools(ctx, cwd, name, m, permissions, c)...)
+ case MCPHttp:
c, err := client.NewStreamableHttpClient(
m.URL,
transport.WithHTTPHeaders(m.ResolvedHeaders()),
@@ -230,8 +248,8 @@ func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *con
slog.Error("error creating mcp client", "error", err)
return
}
- result.Append(getTools(ctx, name, m, permissions, c, cfg.WorkingDir())...)
- case config.MCPSse:
+ result.Append(getTools(ctx, cwd, name, m, permissions, c)...)
+ case MCPSse:
c, err := client.NewSSEMCPClient(
m.URL,
client.WithHeaders(m.ResolvedHeaders()),
@@ -240,10 +258,41 @@ func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *con
slog.Error("error creating mcp client", "error", err)
return
}
- result.Append(getTools(ctx, name, m, permissions, c, cfg.WorkingDir())...)
+ result.Append(getTools(ctx, cwd, name, m, permissions, c)...)
}
}(name, m)
}
wg.Wait()
return slices.Collect(result.Seq())
}
+
+func (m MCPConfig) ResolvedEnv() []string {
+ resolver := resolver.New()
+ for e, v := range m.Env {
+ var err error
+ m.Env[e], err = resolver.ResolveValue(v)
+ if err != nil {
+ slog.Error("error resolving environment variable", "error", err, "variable", e, "value", v)
+ continue
+ }
+ }
+
+ env := make([]string, 0, len(m.Env))
+ for k, v := range m.Env {
+ env = append(env, fmt.Sprintf("%s=%s", k, v))
+ }
+ return env
+}
+
+func (m MCPConfig) ResolvedHeaders() map[string]string {
+ resolver := resolver.New()
+ for e, v := range m.Headers {
+ var err error
+ m.Headers[e], err = resolver.ResolveValue(v)
+ if err != nil {
+ slog.Error("error resolving header variable", "error", err, "variable", e, "value", v)
+ continue
+ }
+ }
+ return m.Headers
+}
@@ -0,0 +1,56 @@
+package agent
+
+import (
+ "context"
+
+ "github.com/charmbracelet/crush/internal/history"
+ "github.com/charmbracelet/crush/internal/llm/tools"
+ "github.com/charmbracelet/crush/internal/lsp"
+ "github.com/charmbracelet/crush/internal/message"
+ "github.com/charmbracelet/crush/internal/permission"
+ "github.com/charmbracelet/crush/internal/session"
+)
+
+func NewCoderTools(
+ ctx context.Context,
+ cwd string,
+ sessions session.Service,
+ messages message.Service,
+ permissions permission.Service,
+ lspClients map[string]*lsp.Client,
+ history history.Service,
+ mcps map[string]MCPConfig,
+) tools.Registry {
+ toolFn := func() []tools.BaseTool {
+ allTools := []tools.BaseTool{
+ tools.NewBashTool(permissions, cwd),
+ tools.NewDownloadTool(permissions, cwd),
+ tools.NewEditTool(lspClients, permissions, history, cwd),
+ tools.NewFetchTool(permissions, cwd),
+ tools.NewGlobTool(cwd),
+ tools.NewGrepTool(cwd),
+ tools.NewLsTool(cwd),
+ tools.NewSourcegraphTool(),
+ tools.NewViewTool(lspClients, cwd),
+ tools.NewWriteTool(lspClients, permissions, history, cwd),
+ }
+ mcpTools := GetMCPTools(ctx, cwd, mcps, permissions)
+ allTools = append(allTools, mcpTools...)
+ if len(lspClients) > 0 {
+ allTools = append(allTools, tools.NewDiagnosticsTool(lspClients))
+ }
+ return allTools
+ }
+ return tools.NewRegistry(toolFn)
+}
+
+func NewTaskTools(cwd string) tools.Registry {
+ return tools.NewRegistryFromTools([]tools.BaseTool{
+ tools.NewGlobTool(cwd),
+ tools.NewGrepTool(cwd),
+ tools.NewLsTool(cwd),
+ tools.NewSourcegraphTool(),
+ // no need for LSP info here
+ tools.NewViewTool(map[string]*lsp.Client{}, cwd),
+ })
+}
@@ -4,26 +4,23 @@ import (
"context"
_ "embed"
"fmt"
- "log/slog"
"os"
"path/filepath"
"runtime"
"time"
- "github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/crush/internal/llm/tools"
)
-func CoderPrompt(p string, contextFiles ...string) string {
+func CoderPrompt(cwd string, contextFiles ...string) string {
var basePrompt string
basePrompt = string(baseCoderPrompt)
- envInfo := getEnvironmentInfo()
+ envInfo := getEnvironmentInfo(cwd)
- basePrompt = fmt.Sprintf("%s\n\n%s\n%s", basePrompt, envInfo, lspInformation())
+ basePrompt = fmt.Sprintf("%s\n\n%s", basePrompt, envInfo)
- contextContent := getContextFromPaths(config.Get().WorkingDir(), contextFiles)
- slog.Debug("Context content", "Context", contextContent)
+ contextContent := getContextFromPaths(cwd, contextFiles)
if contextContent != "" {
return fmt.Sprintf("%s\n\n# Project-Specific Context\n Make sure to follow the instructions in the context below\n%s", basePrompt, contextContent)
}
@@ -33,8 +30,7 @@ func CoderPrompt(p string, contextFiles ...string) string {
//go:embed coder.md
var baseCoderPrompt []byte
-func getEnvironmentInfo() string {
- cwd := config.Get().WorkingDir()
+func getEnvironmentInfo(cwd string) string {
isGit := isGitRepo(cwd)
platform := runtime.GOOS
date := time.Now().Format("1/2/2006")
@@ -60,18 +56,7 @@ func isGitRepo(dir string) bool {
return err == nil
}
-func lspInformation() string {
- cfg := config.Get()
- hasLSP := false
- for _, v := range cfg.LSP {
- if !v.Disabled {
- hasLSP = true
- break
- }
- }
- if !hasLSP {
- return ""
- }
+func LSPInformation() string {
return `# LSP Information
Tools that support it will also include useful diagnostics such as linting and typechecking.
- These diagnostics will be automatically enabled when you run the tool, and will be displayed in the output at the bottom within the <file_diagnostics></file_diagnostics> and <project_diagnostics></project_diagnostics> tags.
@@ -6,9 +6,7 @@ import (
"strings"
"sync"
- "github.com/charmbracelet/crush/internal/config"
- "github.com/charmbracelet/crush/internal/csync"
- "github.com/charmbracelet/crush/internal/env"
+ "github.com/charmbracelet/crush/internal/resolver"
)
type PromptID string
@@ -21,23 +19,6 @@ const (
PromptDefault PromptID = "default"
)
-func GetPrompt(promptID PromptID, provider string, contextPaths ...string) string {
- basePrompt := ""
- switch promptID {
- case PromptCoder:
- basePrompt = CoderPrompt(provider, contextPaths...)
- case PromptTitle:
- basePrompt = TitlePrompt()
- case PromptTask:
- basePrompt = TaskPrompt()
- case PromptSummarizer:
- basePrompt = SummarizerPrompt()
- default:
- basePrompt = "You are a helpful assistant"
- }
- return basePrompt
-}
-
func getContextFromPaths(workingDir string, contextPaths []string) string {
return processContextPaths(workingDir, contextPaths)
}
@@ -59,7 +40,7 @@ func expandPath(path string) string {
// Handle environment variable expansion using the same pattern as config
if strings.HasPrefix(path, "$") {
- resolver := config.NewEnvironmentVariableResolver(env.New())
+ resolver := resolver.New()
if expanded, err := resolver.ResolveValue(path); err == nil {
path = expanded
}
@@ -75,7 +56,8 @@ func processContextPaths(workDir string, paths []string) string {
)
// Track processed files to avoid duplicates
- processedFiles := csync.NewMap[string, bool]()
+ processedFiles := make(map[string]bool)
+ var processedMutex sync.Mutex
for _, path := range paths {
wg.Add(1)
@@ -106,8 +88,14 @@ func processContextPaths(workDir string, paths []string) string {
// Check if we've already processed this file (case-insensitive)
lowerPath := strings.ToLower(path)
- if alreadyProcessed, _ := processedFiles.Get(lowerPath); !alreadyProcessed {
- processedFiles.Set(lowerPath, true)
+ processedMutex.Lock()
+ alreadyProcessed := processedFiles[lowerPath]
+ if !alreadyProcessed {
+ processedFiles[lowerPath] = true
+ }
+ processedMutex.Unlock()
+
+ if !alreadyProcessed {
if result := processFile(path); result != "" {
resultCh <- result
}
@@ -120,8 +108,14 @@ func processContextPaths(workDir string, paths []string) string {
// Check if we've already processed this file (case-insensitive)
lowerPath := strings.ToLower(fullPath)
- if alreadyProcessed, _ := processedFiles.Get(lowerPath); !alreadyProcessed {
- processedFiles.Set(lowerPath, true)
+ processedMutex.Lock()
+ alreadyProcessed := processedFiles[lowerPath]
+ if !alreadyProcessed {
+ processedFiles[lowerPath] = true
+ }
+ processedMutex.Unlock()
+
+ if !alreadyProcessed {
result := processFile(fullPath)
if result != "" {
resultCh <- result
@@ -4,12 +4,12 @@ import (
"fmt"
)
-func TaskPrompt() string {
+func TaskPrompt(cwd string) string {
agentPrompt := `You are an agent for Crush. Given the user's prompt, you should use the tools available to you to answer the user's question.
Notes:
1. IMPORTANT: You should be concise, direct, and to the point, since your responses will be displayed on a command line interface. Answer the user's question directly, without elaboration, explanation, or details. One word answers are best. Avoid introductions, conclusions, and explanations. You MUST avoid text before/after your response, such as "The answer is <answer>.", "Here is the content of the file..." or "Based on the information provided, the answer is..." or "Here is what I will do next...".
2. When relevant, share file names and code snippets relevant to the query
3. Any file paths you return in your final response MUST be absolute. DO NOT use relative paths.`
- return fmt.Sprintf("%s\n%s\n", agentPrompt, getEnvironmentInfo())
+ return fmt.Sprintf("%s\n%s\n", agentPrompt, getEnvironmentInfo(cwd))
}
@@ -16,28 +16,25 @@ import (
"github.com/anthropics/anthropic-sdk-go/bedrock"
"github.com/anthropics/anthropic-sdk-go/option"
"github.com/charmbracelet/catwalk/pkg/catwalk"
- "github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/crush/internal/llm/tools"
"github.com/charmbracelet/crush/internal/message"
)
-type anthropicClient struct {
- providerOptions providerClientOptions
+type anthropicProvider struct {
+ *baseProvider
useBedrock bool
client anthropic.Client
adjustedMaxTokens int // Used when context limit is hit
}
-type AnthropicClient ProviderClient
-
-func newAnthropicClient(opts providerClientOptions, useBedrock bool) AnthropicClient {
- return &anthropicClient{
- providerOptions: opts,
- client: createAnthropicClient(opts, useBedrock),
+func NewAnthropicProvider(base *baseProvider, useBedrock bool) Provider {
+ return &anthropicProvider{
+ baseProvider: base,
+ client: createAnthropicClient(base, useBedrock),
}
}
-func createAnthropicClient(opts providerClientOptions, useBedrock bool) anthropic.Client {
+func createAnthropicClient(opts *baseProvider, useBedrock bool) anthropic.Client {
anthropicClientOptions := []option.RequestOption{}
// Check if Authorization header is provided in extra headers
@@ -76,7 +73,7 @@ func createAnthropicClient(opts providerClientOptions, useBedrock bool) anthropi
return anthropic.NewClient(anthropicClientOptions...)
}
-func (a *anthropicClient) convertMessages(messages []message.Message) (anthropicMessages []anthropic.MessageParam) {
+func (a *anthropicProvider) convertMessages(messages []message.Message) (anthropicMessages []anthropic.MessageParam) {
for i, msg := range messages {
cache := false
if i > len(messages)-3 {
@@ -85,7 +82,7 @@ func (a *anthropicClient) convertMessages(messages []message.Message) (anthropic
switch msg.Role {
case message.User:
content := anthropic.NewTextBlock(msg.Content().String())
- if cache && !a.providerOptions.disableCache {
+ if cache && !a.disableCache {
content.OfText.CacheControl = anthropic.CacheControlEphemeralParam{
Type: "ephemeral",
}
@@ -110,7 +107,7 @@ func (a *anthropicClient) convertMessages(messages []message.Message) (anthropic
if msg.Content().String() != "" {
content := anthropic.NewTextBlock(msg.Content().String())
- if cache && !a.providerOptions.disableCache {
+ if cache && !a.disableCache {
content.OfText.CacheControl = anthropic.CacheControlEphemeralParam{
Type: "ephemeral",
}
@@ -144,7 +141,7 @@ func (a *anthropicClient) convertMessages(messages []message.Message) (anthropic
return
}
-func (a *anthropicClient) convertTools(tools []tools.BaseTool) []anthropic.ToolUnionParam {
+func (a *anthropicProvider) convertTools(tools []tools.BaseTool) []anthropic.ToolUnionParam {
anthropicTools := make([]anthropic.ToolUnionParam, len(tools))
for i, tool := range tools {
@@ -154,11 +151,11 @@ func (a *anthropicClient) convertTools(tools []tools.BaseTool) []anthropic.ToolU
Description: anthropic.String(info.Description),
InputSchema: anthropic.ToolInputSchemaParam{
Properties: info.Parameters,
- // TODO: figure out how we can tell claude the required fields?
+ Required: info.Required,
},
}
- if i == len(tools)-1 && !a.providerOptions.disableCache {
+ if i == len(tools)-1 && !a.disableCache {
toolParam.CacheControl = anthropic.CacheControlEphemeralParam{
Type: "ephemeral",
}
@@ -170,7 +167,7 @@ func (a *anthropicClient) convertTools(tools []tools.BaseTool) []anthropic.ToolU
return anthropicTools
}
-func (a *anthropicClient) finishReason(reason string) message.FinishReason {
+func (a *anthropicProvider) finishReason(reason string) message.FinishReason {
switch reason {
case "end_turn":
return message.FinishReasonEndTurn
@@ -185,37 +182,23 @@ func (a *anthropicClient) finishReason(reason string) message.FinishReason {
}
}
-func (a *anthropicClient) isThinkingEnabled() bool {
- cfg := config.Get()
- modelConfig := cfg.Models[config.SelectedModelTypeLarge]
- if a.providerOptions.modelType == config.SelectedModelTypeSmall {
- modelConfig = cfg.Models[config.SelectedModelTypeSmall]
- }
- return a.Model().CanReason && modelConfig.Think
+func (a *anthropicProvider) isThinkingEnabled(model string) bool {
+ return a.Model(model).CanReason && a.think
}
-func (a *anthropicClient) preparedMessages(messages []anthropic.MessageParam, tools []anthropic.ToolUnionParam) anthropic.MessageNewParams {
- model := a.providerOptions.model(a.providerOptions.modelType)
+func (a *anthropicProvider) preparedMessages(modelID string, messages []anthropic.MessageParam, tools []anthropic.ToolUnionParam) anthropic.MessageNewParams {
+ model := a.Model(modelID)
var thinkingParam anthropic.ThinkingConfigParamUnion
- cfg := config.Get()
- modelConfig := cfg.Models[config.SelectedModelTypeLarge]
- if a.providerOptions.modelType == config.SelectedModelTypeSmall {
- modelConfig = cfg.Models[config.SelectedModelTypeSmall]
- }
temperature := anthropic.Float(0)
maxTokens := model.DefaultMaxTokens
- if modelConfig.MaxTokens > 0 {
- maxTokens = modelConfig.MaxTokens
+ if a.maxTokens > 0 {
+ maxTokens = a.maxTokens
}
- if a.isThinkingEnabled() {
+ if a.isThinkingEnabled(modelID) {
thinkingParam = anthropic.ThinkingConfigParamOfEnabled(int64(float64(maxTokens) * 0.8))
temperature = anthropic.Float(1)
}
- // Override max tokens if set in provider options
- if a.providerOptions.maxTokens > 0 {
- maxTokens = a.providerOptions.maxTokens
- }
// Use adjusted max tokens if context limit was hit
if a.adjustedMaxTokens > 0 {
@@ -225,9 +208,9 @@ func (a *anthropicClient) preparedMessages(messages []anthropic.MessageParam, to
systemBlocks := []anthropic.TextBlockParam{}
// Add custom system prompt prefix if configured
- if a.providerOptions.systemPromptPrefix != "" {
+ if a.systemPromptPrefix != "" {
systemBlocks = append(systemBlocks, anthropic.TextBlockParam{
- Text: a.providerOptions.systemPromptPrefix,
+ Text: a.systemPromptPrefix,
CacheControl: anthropic.CacheControlEphemeralParam{
Type: "ephemeral",
},
@@ -235,7 +218,7 @@ func (a *anthropicClient) preparedMessages(messages []anthropic.MessageParam, to
}
systemBlocks = append(systemBlocks, anthropic.TextBlockParam{
- Text: a.providerOptions.systemMessage,
+ Text: a.systemMessage,
CacheControl: anthropic.CacheControlEphemeralParam{
Type: "ephemeral",
},
@@ -252,21 +235,24 @@ func (a *anthropicClient) preparedMessages(messages []anthropic.MessageParam, to
}
}
-func (a *anthropicClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (response *ProviderResponse, err error) {
- cfg := config.Get()
+func (a *anthropicProvider) Send(ctx context.Context, model string, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
+ messages = a.cleanMessages(messages)
+ return a.send(ctx, model, messages, tools)
+}
+func (a *anthropicProvider) send(ctx context.Context, model string, messages []message.Message, tools []tools.BaseTool) (response *ProviderResponse, err error) {
attempts := 0
for {
attempts++
// Prepare messages on each attempt in case max_tokens was adjusted
- preparedMessages := a.preparedMessages(a.convertMessages(messages), a.convertTools(tools))
- if cfg.Options.Debug {
+ preparedMessages := a.preparedMessages(model, a.convertMessages(messages), a.convertTools(tools))
+ if a.debug {
jsonData, _ := json.Marshal(preparedMessages)
slog.Debug("Prepared messages", "messages", string(jsonData))
}
var opts []option.RequestOption
- if a.isThinkingEnabled() {
+ if a.isThinkingEnabled(model) {
opts = append(opts, option.WithHeaderAdd("anthropic-beta", "interleaved-thinking-2025-05-14"))
}
anthropicResponse, err := a.client.Messages.New(
@@ -308,22 +294,26 @@ func (a *anthropicClient) send(ctx context.Context, messages []message.Message,
}
}
-func (a *anthropicClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
- cfg := config.Get()
+func (a *anthropicProvider) Stream(ctx context.Context, model string, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
+ messages = a.cleanMessages(messages)
+ return a.stream(ctx, model, messages, tools)
+}
+
+func (a *anthropicProvider) stream(ctx context.Context, model string, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
attempts := 0
eventChan := make(chan ProviderEvent)
go func() {
for {
attempts++
// Prepare messages on each attempt in case max_tokens was adjusted
- preparedMessages := a.preparedMessages(a.convertMessages(messages), a.convertTools(tools))
- if cfg.Options.Debug {
+ preparedMessages := a.preparedMessages(model, a.convertMessages(messages), a.convertTools(tools))
+ if a.debug {
jsonData, _ := json.Marshal(preparedMessages)
slog.Debug("Prepared messages", "messages", string(jsonData))
}
var opts []option.RequestOption
- if a.isThinkingEnabled() {
+ if a.isThinkingEnabled(model) {
opts = append(opts, option.WithHeaderAdd("anthropic-beta", "interleaved-thinking-2025-05-14"))
}
@@ -460,7 +450,7 @@ func (a *anthropicClient) stream(ctx context.Context, messages []message.Message
return eventChan
}
-func (a *anthropicClient) shouldRetry(attempts int, err error) (bool, int64, error) {
+func (a *anthropicProvider) shouldRetry(attempts int, err error) (bool, int64, error) {
var apiErr *anthropic.Error
if !errors.As(err, &apiErr) {
return false, 0, err
@@ -471,11 +461,12 @@ func (a *anthropicClient) shouldRetry(attempts int, err error) (bool, int64, err
}
if apiErr.StatusCode == 401 {
- a.providerOptions.apiKey, err = config.Get().Resolve(a.providerOptions.config.APIKey)
+ a.apiKey, err = a.resolver.ResolveValue(a.config.APIKey)
if err != nil {
return false, 0, fmt.Errorf("failed to resolve API key: %w", err)
}
- a.client = createAnthropicClient(a.providerOptions, a.useBedrock)
+
+ a.client = createAnthropicClient(a.baseProvider, a.useBedrock)
return true, 0, nil
}
@@ -508,7 +499,7 @@ func (a *anthropicClient) shouldRetry(attempts int, err error) (bool, int64, err
}
// handleContextLimitError parses context limit error and returns adjusted max_tokens
-func (a *anthropicClient) handleContextLimitError(apiErr *anthropic.Error) (int, bool) {
+func (a *anthropicProvider) handleContextLimitError(apiErr *anthropic.Error) (int, bool) {
// Parse error message like: "input length and max_tokens exceed context limit: 154978 + 50000 > 200000"
errorMsg := apiErr.Error()
@@ -535,7 +526,7 @@ func (a *anthropicClient) handleContextLimitError(apiErr *anthropic.Error) (int,
return safeMaxTokens, true
}
-func (a *anthropicClient) toolCalls(msg anthropic.Message) []message.ToolCall {
+func (a *anthropicProvider) toolCalls(msg anthropic.Message) []message.ToolCall {
var toolCalls []message.ToolCall
for _, block := range msg.Content {
@@ -555,7 +546,7 @@ func (a *anthropicClient) toolCalls(msg anthropic.Message) []message.ToolCall {
return toolCalls
}
-func (a *anthropicClient) usage(msg anthropic.Message) TokenUsage {
+func (a *anthropicProvider) usage(msg anthropic.Message) TokenUsage {
return TokenUsage{
InputTokens: msg.Usage.InputTokens,
OutputTokens: msg.Usage.OutputTokens,
@@ -563,7 +554,3 @@ func (a *anthropicClient) usage(msg anthropic.Message) TokenUsage {
CacheReadTokens: msg.Usage.CacheReadInputTokens,
}
}
-
-func (a *anthropicClient) Model() catwalk.Model {
- return a.providerOptions.model(a.providerOptions.modelType)
-}
@@ -6,27 +6,25 @@ import (
"github.com/openai/openai-go/option"
)
-type azureClient struct {
- *openaiClient
+type azureProvider struct {
+ *openaiProvider
}
-type AzureClient ProviderClient
-
-func newAzureClient(opts providerClientOptions) AzureClient {
- apiVersion := opts.extraParams["apiVersion"]
+func NewAzureProvider(base *baseProvider) Provider {
+ apiVersion := base.extraParams["apiVersion"]
if apiVersion == "" {
apiVersion = "2025-01-01-preview"
}
reqOpts := []option.RequestOption{
- azure.WithEndpoint(opts.baseURL, apiVersion),
+ azure.WithEndpoint(base.baseURL, apiVersion),
}
- reqOpts = append(reqOpts, azure.WithAPIKey(opts.apiKey))
- base := &openaiClient{
- providerOptions: opts,
- client: openai.NewClient(reqOpts...),
+ reqOpts = append(reqOpts, azure.WithAPIKey(base.apiKey))
+ client := &openaiProvider{
+ baseProvider: base,
+ client: openai.NewClient(reqOpts...),
}
- return &azureClient{openaiClient: base}
+ return &azureProvider{openaiProvider: client}
}
@@ -4,90 +4,56 @@ import (
"context"
"errors"
"fmt"
- "strings"
- "github.com/charmbracelet/catwalk/pkg/catwalk"
- "github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/crush/internal/llm/tools"
"github.com/charmbracelet/crush/internal/message"
)
-type bedrockClient struct {
- providerOptions providerClientOptions
- childProvider ProviderClient
+type bedrockProvider struct {
+ *baseProvider
+ region string
+ childProvider Provider
}
-type BedrockClient ProviderClient
-
-func newBedrockClient(opts providerClientOptions) BedrockClient {
+func NewBedrockProvider(base *baseProvider) Provider {
// Get AWS region from environment
- region := opts.extraParams["region"]
+ region := base.extraParams["region"]
if region == "" {
region = "us-east-1" // default region
}
- if len(region) < 2 {
- return &bedrockClient{
- providerOptions: opts,
- childProvider: nil, // Will cause an error when used
- }
- }
-
- opts.model = func(modelType config.SelectedModelType) catwalk.Model {
- model := config.Get().GetModelByType(modelType)
-
- // Prefix the model name with region
- regionPrefix := region[:2]
- modelName := model.ID
- model.ID = fmt.Sprintf("%s.%s", regionPrefix, modelName)
- return *model
- }
-
- model := opts.model(opts.modelType)
-
- // Determine which provider to use based on the model
- if strings.Contains(string(model.ID), "anthropic") {
- // Create Anthropic client with Bedrock configuration
- anthropicOpts := opts
- // TODO: later find a way to check if the AWS account has caching enabled
- opts.disableCache = true // Disable cache for Bedrock
- return &bedrockClient{
- providerOptions: opts,
- childProvider: newAnthropicClient(anthropicOpts, true),
- }
- }
- // Return client with nil childProvider if model is not supported
- // This will cause an error when used
- return &bedrockClient{
- providerOptions: opts,
- childProvider: nil,
+ return &bedrockProvider{
+ baseProvider: base,
+ childProvider: NewAnthropicProvider(base, true),
}
}
-func (b *bedrockClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
- if b.childProvider == nil {
- return nil, errors.New("unsupported model for bedrock provider")
+func (b *bedrockProvider) Send(ctx context.Context, model string, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
+ if len(b.region) < 2 {
+ return nil, errors.New("no region selected")
}
- return b.childProvider.send(ctx, messages, tools)
+ regionPrefix := b.region[:2]
+ modelName := model
+ model = fmt.Sprintf("%s.%s", regionPrefix, modelName)
+ messages = b.cleanMessages(messages)
+ return b.childProvider.Send(ctx, model, messages, tools)
}
-func (b *bedrockClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
- eventChan := make(chan ProviderEvent)
-
- if b.childProvider == nil {
+func (b *bedrockProvider) Stream(ctx context.Context, model string, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
+ if len(b.region) < 2 {
+ eventChan := make(chan ProviderEvent)
go func() {
eventChan <- ProviderEvent{
Type: EventError,
- Error: errors.New("unsupported model for bedrock provider"),
+ Error: errors.New("no region selected"),
}
close(eventChan)
}()
return eventChan
}
-
- return b.childProvider.stream(ctx, messages, tools)
-}
-
-func (b *bedrockClient) Model() catwalk.Model {
- return b.providerOptions.model(b.providerOptions.modelType)
+ regionPrefix := b.region[:2]
+ modelName := model
+ model = fmt.Sprintf("%s.%s", regionPrefix, modelName)
+ messages = b.cleanMessages(messages)
+ return b.childProvider.Stream(ctx, model, messages, tools)
}
@@ -10,43 +10,39 @@ import (
"strings"
"time"
- "github.com/charmbracelet/catwalk/pkg/catwalk"
- "github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/crush/internal/llm/tools"
"github.com/charmbracelet/crush/internal/message"
"github.com/google/uuid"
"google.golang.org/genai"
)
-type geminiClient struct {
- providerOptions providerClientOptions
- client *genai.Client
+type geminiProvider struct {
+ *baseProvider
+ client *genai.Client
}
-type GeminiClient ProviderClient
-
-func newGeminiClient(opts providerClientOptions) GeminiClient {
- client, err := createGeminiClient(opts)
+func NewGeminiProvider(base *baseProvider) Provider {
+ client, err := createGeminiClient(base)
if err != nil {
slog.Error("Failed to create Gemini client", "error", err)
return nil
}
- return &geminiClient{
- providerOptions: opts,
- client: client,
+ return &geminiProvider{
+ baseProvider: base,
+ client: client,
}
}
-func createGeminiClient(opts providerClientOptions) (*genai.Client, error) {
- client, err := genai.NewClient(context.Background(), &genai.ClientConfig{APIKey: opts.apiKey, Backend: genai.BackendGeminiAPI})
+func createGeminiClient(base *baseProvider) (*genai.Client, error) {
+ client, err := genai.NewClient(context.Background(), &genai.ClientConfig{APIKey: base.apiKey, Backend: genai.BackendGeminiAPI})
if err != nil {
return nil, err
}
return client, nil
}
-func (g *geminiClient) convertMessages(messages []message.Message) []*genai.Content {
+func (g *geminiProvider) convertMessages(messages []message.Message) []*genai.Content {
var history []*genai.Content
for _, msg := range messages {
switch msg.Role {
@@ -128,7 +124,7 @@ func (g *geminiClient) convertMessages(messages []message.Message) []*genai.Cont
return history
}
-func (g *geminiClient) convertTools(tools []tools.BaseTool) []*genai.Tool {
+func (g *geminiProvider) convertTools(tools []tools.BaseTool) []*genai.Tool {
geminiTool := &genai.Tool{}
geminiTool.FunctionDeclarations = make([]*genai.FunctionDeclaration, 0, len(tools))
@@ -150,7 +146,7 @@ func (g *geminiClient) convertTools(tools []tools.BaseTool) []*genai.Tool {
return []*genai.Tool{geminiTool}
}
-func (g *geminiClient) finishReason(reason genai.FinishReason) message.FinishReason {
+func (g *geminiProvider) finishReason(reason genai.FinishReason) message.FinishReason {
switch reason {
case genai.FinishReasonStop:
return message.FinishReasonEndTurn
@@ -161,28 +157,27 @@ func (g *geminiClient) finishReason(reason genai.FinishReason) message.FinishRea
}
}
-func (g *geminiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
+func (g *geminiProvider) Send(ctx context.Context, model string, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
+ messages = g.cleanMessages(messages)
+ return g.send(ctx, model, messages, tools)
+}
+
+func (g *geminiProvider) send(ctx context.Context, modelID string, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
// Convert messages
geminiMessages := g.convertMessages(messages)
- model := g.providerOptions.model(g.providerOptions.modelType)
- cfg := config.Get()
- if cfg.Options.Debug {
+ if g.debug {
jsonData, _ := json.Marshal(geminiMessages)
slog.Debug("Prepared messages", "messages", string(jsonData))
}
- modelConfig := cfg.Models[config.SelectedModelTypeLarge]
- if g.providerOptions.modelType == config.SelectedModelTypeSmall {
- modelConfig = cfg.Models[config.SelectedModelTypeSmall]
- }
-
+ model := g.Model(modelID)
maxTokens := model.DefaultMaxTokens
- if modelConfig.MaxTokens > 0 {
- maxTokens = modelConfig.MaxTokens
+ if g.maxTokens > 0 {
+ maxTokens = g.maxTokens
}
- systemMessage := g.providerOptions.systemMessage
- if g.providerOptions.systemPromptPrefix != "" {
- systemMessage = g.providerOptions.systemPromptPrefix + "\n" + systemMessage
+ systemMessage := g.systemMessage
+ if g.systemPromptPrefix != "" {
+ systemMessage = g.systemPromptPrefix + "\n" + systemMessage
}
history := geminiMessages[:len(geminiMessages)-1] // All but last message
lastMsg := geminiMessages[len(geminiMessages)-1]
@@ -260,34 +255,31 @@ func (g *geminiClient) send(ctx context.Context, messages []message.Message, too
}
}
-func (g *geminiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
+func (g *geminiProvider) Stream(ctx context.Context, model string, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
+ messages = g.cleanMessages(messages)
+ return g.stream(ctx, model, messages, tools)
+}
+
+func (g *geminiProvider) stream(ctx context.Context, modelID string, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
// Convert messages
geminiMessages := g.convertMessages(messages)
- model := g.providerOptions.model(g.providerOptions.modelType)
- cfg := config.Get()
- if cfg.Options.Debug {
+ model := g.Model(modelID)
+ if g.debug {
jsonData, _ := json.Marshal(geminiMessages)
slog.Debug("Prepared messages", "messages", string(jsonData))
}
- modelConfig := cfg.Models[config.SelectedModelTypeLarge]
- if g.providerOptions.modelType == config.SelectedModelTypeSmall {
- modelConfig = cfg.Models[config.SelectedModelTypeSmall]
- }
maxTokens := model.DefaultMaxTokens
- if modelConfig.MaxTokens > 0 {
- maxTokens = modelConfig.MaxTokens
+ if g.maxTokens > 0 {
+ maxTokens = g.maxTokens
}
- // Override max tokens if set in provider options
- if g.providerOptions.maxTokens > 0 {
- maxTokens = g.providerOptions.maxTokens
- }
- systemMessage := g.providerOptions.systemMessage
- if g.providerOptions.systemPromptPrefix != "" {
- systemMessage = g.providerOptions.systemPromptPrefix + "\n" + systemMessage
+ systemMessage := g.systemMessage
+ if g.systemPromptPrefix != "" {
+ systemMessage = g.systemPromptPrefix + "\n" + systemMessage
}
+
history := geminiMessages[:len(geminiMessages)-1] // All but last message
lastMsg := geminiMessages[len(geminiMessages)-1]
config := &genai.GenerateContentConfig{
@@ -412,7 +404,7 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t
return eventChan
}
-func (g *geminiClient) shouldRetry(attempts int, err error) (bool, int64, error) {
+func (g *geminiProvider) shouldRetry(attempts int, err error) (bool, int64, error) {
// Check if error is a rate limit error
if attempts > maxRetries {
return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries)
@@ -429,11 +421,11 @@ func (g *geminiClient) shouldRetry(attempts int, err error) (bool, int64, error)
// Check for token expiration (401 Unauthorized)
if contains(errMsg, "unauthorized", "invalid api key", "api key expired") {
- g.providerOptions.apiKey, err = config.Get().Resolve(g.providerOptions.config.APIKey)
+ g.apiKey, err = g.resolver.ResolveValue(g.config.APIKey)
if err != nil {
return false, 0, fmt.Errorf("failed to resolve API key: %w", err)
}
- g.client, err = createGeminiClient(g.providerOptions)
+ g.client, err = createGeminiClient(g.baseProvider)
if err != nil {
return false, 0, fmt.Errorf("failed to create Gemini client after API key refresh: %w", err)
}
@@ -454,7 +446,7 @@ func (g *geminiClient) shouldRetry(attempts int, err error) (bool, int64, error)
return true, int64(retryMs), nil
}
-func (g *geminiClient) usage(resp *genai.GenerateContentResponse) TokenUsage {
+func (g *geminiProvider) usage(resp *genai.GenerateContentResponse) TokenUsage {
if resp == nil || resp.UsageMetadata == nil {
return TokenUsage{}
}
@@ -467,10 +459,6 @@ func (g *geminiClient) usage(resp *genai.GenerateContentResponse) TokenUsage {
}
}
-func (g *geminiClient) Model() catwalk.Model {
- return g.providerOptions.model(g.providerOptions.modelType)
-}
-
// Helper functions
func parseJSONToMap(jsonStr string) (map[string]any, error) {
var result map[string]any
@@ -10,7 +10,6 @@ import (
"time"
"github.com/charmbracelet/catwalk/pkg/catwalk"
- "github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/crush/internal/llm/tools"
"github.com/charmbracelet/crush/internal/message"
"github.com/openai/openai-go"
@@ -18,30 +17,25 @@ import (
"github.com/openai/openai-go/shared"
)
-type openaiClient struct {
- providerOptions providerClientOptions
- client openai.Client
+type openaiProvider struct {
+ *baseProvider
+ client openai.Client
}
-type OpenAIClient ProviderClient
-
-func newOpenAIClient(opts providerClientOptions) OpenAIClient {
- return &openaiClient{
- providerOptions: opts,
- client: createOpenAIClient(opts),
+func NewOpenAIProvider(base *baseProvider) Provider {
+ return &openaiProvider{
+ baseProvider: base,
+ client: createOpenAIClient(base),
}
}
-func createOpenAIClient(opts providerClientOptions) openai.Client {
+func createOpenAIClient(opts *baseProvider) openai.Client {
openaiClientOptions := []option.RequestOption{}
if opts.apiKey != "" {
openaiClientOptions = append(openaiClientOptions, option.WithAPIKey(opts.apiKey))
}
if opts.baseURL != "" {
- resolvedBaseURL, err := config.Get().Resolve(opts.baseURL)
- if err == nil {
- openaiClientOptions = append(openaiClientOptions, option.WithBaseURL(resolvedBaseURL))
- }
+ openaiClientOptions = append(openaiClientOptions, option.WithBaseURL(opts.baseURL))
}
for key, value := range opts.extraHeaders {
@@ -55,11 +49,11 @@ func createOpenAIClient(opts providerClientOptions) openai.Client {
return openai.NewClient(openaiClientOptions...)
}
-func (o *openaiClient) convertMessages(messages []message.Message) (openaiMessages []openai.ChatCompletionMessageParamUnion) {
+func (o *openaiProvider) convertMessages(messages []message.Message) (openaiMessages []openai.ChatCompletionMessageParamUnion) {
// Add system message first
- systemMessage := o.providerOptions.systemMessage
- if o.providerOptions.systemPromptPrefix != "" {
- systemMessage = o.providerOptions.systemPromptPrefix + "\n" + systemMessage
+ systemMessage := o.systemMessage
+ if o.systemPromptPrefix != "" {
+ systemMessage = o.systemPromptPrefix + "\n" + systemMessage
}
openaiMessages = append(openaiMessages, openai.SystemMessage(systemMessage))
@@ -126,7 +120,7 @@ func (o *openaiClient) convertMessages(messages []message.Message) (openaiMessag
return
}
-func (o *openaiClient) convertTools(tools []tools.BaseTool) []openai.ChatCompletionToolParam {
+func (o *openaiProvider) convertTools(tools []tools.BaseTool) []openai.ChatCompletionToolParam {
openaiTools := make([]openai.ChatCompletionToolParam, len(tools))
for i, tool := range tools {
@@ -147,7 +141,7 @@ func (o *openaiClient) convertTools(tools []tools.BaseTool) []openai.ChatComplet
return openaiTools
}
-func (o *openaiClient) finishReason(reason string) message.FinishReason {
+func (o *openaiProvider) finishReason(reason string) message.FinishReason {
switch reason {
case "stop":
return message.FinishReasonEndTurn
@@ -160,17 +154,14 @@ func (o *openaiClient) finishReason(reason string) message.FinishReason {
}
}
-func (o *openaiClient) preparedParams(messages []openai.ChatCompletionMessageParamUnion, tools []openai.ChatCompletionToolParam) openai.ChatCompletionNewParams {
- model := o.providerOptions.model(o.providerOptions.modelType)
- cfg := config.Get()
+func (o *openaiProvider) preparedParams(modelID string, messages []openai.ChatCompletionMessageParamUnion, tools []openai.ChatCompletionToolParam) openai.ChatCompletionNewParams {
+ model := o.Model(modelID)
- modelConfig := cfg.Models[config.SelectedModelTypeLarge]
- if o.providerOptions.modelType == config.SelectedModelTypeSmall {
- modelConfig = cfg.Models[config.SelectedModelTypeSmall]
+ reasoningEffort := o.reasoningEffort
+ if reasoningEffort == "" {
+ reasoningEffort = model.DefaultReasoningEffort
}
- reasoningEffort := modelConfig.ReasoningEffort
-
params := openai.ChatCompletionNewParams{
Model: openai.ChatModel(model.ID),
Messages: messages,
@@ -178,14 +169,10 @@ func (o *openaiClient) preparedParams(messages []openai.ChatCompletionMessagePar
}
maxTokens := model.DefaultMaxTokens
- if modelConfig.MaxTokens > 0 {
- maxTokens = modelConfig.MaxTokens
+ if o.maxTokens > 0 {
+ maxTokens = o.maxTokens
}
- // Override max tokens if set in provider options
- if o.providerOptions.maxTokens > 0 {
- maxTokens = o.providerOptions.maxTokens
- }
if model.CanReason {
params.MaxCompletionTokens = openai.Int(maxTokens)
switch reasoningEffort {
@@ -205,10 +192,14 @@ func (o *openaiClient) preparedParams(messages []openai.ChatCompletionMessagePar
return params
}
-func (o *openaiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (response *ProviderResponse, err error) {
- params := o.preparedParams(o.convertMessages(messages), o.convertTools(tools))
- cfg := config.Get()
- if cfg.Options.Debug {
+func (o *openaiProvider) Send(ctx context.Context, model string, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
+ messages = o.cleanMessages(messages)
+ return o.send(ctx, model, messages, tools)
+}
+
+func (o *openaiProvider) send(ctx context.Context, model string, messages []message.Message, tools []tools.BaseTool) (response *ProviderResponse, err error) {
+ params := o.preparedParams(model, o.convertMessages(messages), o.convertTools(tools))
+ if o.debug {
jsonData, _ := json.Marshal(params)
slog.Debug("Prepared messages", "messages", string(jsonData))
}
@@ -262,14 +253,18 @@ func (o *openaiClient) send(ctx context.Context, messages []message.Message, too
}
}
-func (o *openaiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
- params := o.preparedParams(o.convertMessages(messages), o.convertTools(tools))
+func (o *openaiProvider) Stream(ctx context.Context, model string, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
+ messages = o.cleanMessages(messages)
+ return o.stream(ctx, model, messages, tools)
+}
+
+func (o *openaiProvider) stream(ctx context.Context, model string, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
+ params := o.preparedParams(model, o.convertMessages(messages), o.convertTools(tools))
params.StreamOptions = openai.ChatCompletionStreamOptionsParam{
IncludeUsage: openai.Bool(true),
}
- cfg := config.Get()
- if cfg.Options.Debug {
+ if o.debug {
jsonData, _ := json.Marshal(params)
slog.Debug("Prepared messages", "messages", string(jsonData))
}
@@ -350,7 +345,7 @@ func (o *openaiClient) stream(ctx context.Context, messages []message.Message, t
err := openaiStream.Err()
if err == nil || errors.Is(err, io.EOF) {
- if cfg.Options.Debug {
+ if o.debug {
jsonData, _ := json.Marshal(acc.ChatCompletion)
slog.Debug("Response", "messages", string(jsonData))
}
@@ -421,7 +416,7 @@ func (o *openaiClient) stream(ctx context.Context, messages []message.Message, t
return eventChan
}
-func (o *openaiClient) shouldRetry(attempts int, err error) (bool, int64, error) {
+func (o *openaiProvider) shouldRetry(attempts int, err error) (bool, int64, error) {
var apiErr *openai.Error
if !errors.As(err, &apiErr) {
return false, 0, err
@@ -433,11 +428,11 @@ func (o *openaiClient) shouldRetry(attempts int, err error) (bool, int64, error)
// Check for token expiration (401 Unauthorized)
if apiErr.StatusCode == 401 {
- o.providerOptions.apiKey, err = config.Get().Resolve(o.providerOptions.config.APIKey)
+ o.apiKey, err = o.resolver.ResolveValue(o.config.APIKey)
if err != nil {
return false, 0, fmt.Errorf("failed to resolve API key: %w", err)
}
- o.client = createOpenAIClient(o.providerOptions)
+ o.client = createOpenAIClient(o.baseProvider)
return true, 0, nil
}
@@ -459,7 +454,7 @@ func (o *openaiClient) shouldRetry(attempts int, err error) (bool, int64, error)
return true, int64(retryMs), nil
}
-func (o *openaiClient) toolCalls(completion openai.ChatCompletion) []message.ToolCall {
+func (o *openaiProvider) toolCalls(completion openai.ChatCompletion) []message.ToolCall {
var toolCalls []message.ToolCall
if len(completion.Choices) > 0 && len(completion.Choices[0].Message.ToolCalls) > 0 {
@@ -478,7 +473,7 @@ func (o *openaiClient) toolCalls(completion openai.ChatCompletion) []message.Too
return toolCalls
}
-func (o *openaiClient) usage(completion openai.ChatCompletion) TokenUsage {
+func (o *openaiProvider) usage(completion openai.ChatCompletion) TokenUsage {
cachedTokens := completion.Usage.PromptTokensDetails.CachedTokens
inputTokens := completion.Usage.PromptTokens - cachedTokens
@@ -489,7 +484,3 @@ func (o *openaiClient) usage(completion openai.ChatCompletion) TokenUsage {
CacheReadTokens: cachedTokens,
}
}
-
-func (o *openaiClient) Model() catwalk.Model {
- return o.providerOptions.model(o.providerOptions.modelType)
-}
@@ -1,90 +0,0 @@
-package provider
-
-import (
- "context"
- "encoding/json"
- "net/http"
- "net/http/httptest"
- "os"
- "testing"
- "time"
-
- "github.com/charmbracelet/catwalk/pkg/catwalk"
- "github.com/charmbracelet/crush/internal/config"
- "github.com/charmbracelet/crush/internal/message"
- "github.com/openai/openai-go"
- "github.com/openai/openai-go/option"
-)
-
-func TestMain(m *testing.M) {
- _, err := config.Init(".", true)
- if err != nil {
- panic("Failed to initialize config: " + err.Error())
- }
-
- os.Exit(m.Run())
-}
-
-func TestOpenAIClientStreamChoices(t *testing.T) {
- // Create a mock server that returns Server-Sent Events with empty choices
- // This simulates the 🤡 behavior when a server returns 200 instead of 404
- server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- w.Header().Set("Content-Type", "text/event-stream")
- w.Header().Set("Cache-Control", "no-cache")
- w.Header().Set("Connection", "keep-alive")
- w.WriteHeader(http.StatusOK)
-
- emptyChoicesChunk := map[string]any{
- "id": "chat-completion-test",
- "object": "chat.completion.chunk",
- "created": time.Now().Unix(),
- "model": "test-model",
- "choices": []any{}, // Empty choices array that causes panic
- }
-
- jsonData, _ := json.Marshal(emptyChoicesChunk)
- w.Write([]byte("data: " + string(jsonData) + "\n\n"))
- w.Write([]byte("data: [DONE]\n\n"))
- }))
- defer server.Close()
-
- // Create OpenAI client pointing to our mock server
- client := &openaiClient{
- providerOptions: providerClientOptions{
- modelType: config.SelectedModelTypeLarge,
- apiKey: "test-key",
- systemMessage: "test",
- model: func(config.SelectedModelType) catwalk.Model {
- return catwalk.Model{
- ID: "test-model",
- Name: "test-model",
- }
- },
- },
- client: openai.NewClient(
- option.WithAPIKey("test-key"),
- option.WithBaseURL(server.URL),
- ),
- }
-
- // Create test messages
- messages := []message.Message{
- {
- Role: message.User,
- Parts: []message.ContentPart{message.TextContent{Text: "Hello"}},
- },
- }
-
- ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
- defer cancel()
-
- eventsChan := client.stream(ctx, messages, nil)
-
- // Collect events - this will panic without the bounds check
- for event := range eventsChan {
- t.Logf("Received event: %+v", event)
- if event.Type == EventError || event.Type == EventComplete {
- break
- }
- }
-}
@@ -3,11 +3,13 @@ package provider
import (
"context"
"fmt"
+ "net/http"
+ "time"
"github.com/charmbracelet/catwalk/pkg/catwalk"
- "github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/crush/internal/llm/tools"
"github.com/charmbracelet/crush/internal/message"
+ "github.com/charmbracelet/crush/internal/resolver"
)
type EventType string
@@ -52,159 +54,233 @@ type ProviderEvent struct {
ToolCall *message.ToolCall
Error error
}
+
+type Config struct {
+ // The provider's id.
+ ID string `json:"id,omitempty"`
+ // The provider's name, used for display purposes.
+ Name string `json:"name,omitempty"`
+ // The provider's API endpoint.
+ BaseURL string `json:"base_url,omitempty"`
+ // The provider type, e.g. "openai", "anthropic", etc. if empty it defaults to openai.
+ Type catwalk.Type `json:"type,omitempty"`
+ // The provider's API key.
+ APIKey string `json:"api_key,omitempty"`
+ // Marks the provider as disabled.
+ Disable bool `json:"disable,omitempty"`
+
+ // Custom system prompt prefix.
+ SystemPromptPrefix string `json:"system_prompt_prefix,omitempty"`
+
+ // Extra headers to send with each request to the provider.
+ ExtraHeaders map[string]string `json:"extra_headers,omitempty"`
+ // Extra body
+ ExtraBody map[string]any `json:"extra_body,omitempty"`
+
+ // Used to pass extra parameters to the provider.
+ ExtraParams map[string]string `json:"-"`
+
+ // The provider models
+ Models []catwalk.Model `json:"models,omitempty"`
+}
+
type Provider interface {
- SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error)
+ Send(ctx context.Context, model string, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error)
+
+ Stream(ctx context.Context, model string, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent
- StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent
+ Model(modelID string) *catwalk.Model
- Model() catwalk.Model
+ SetDebug(debug bool)
}
-type providerClientOptions struct {
+type baseProvider struct {
baseURL string
- config config.ProviderConfig
+ debug bool
+ config Config
apiKey string
- modelType config.SelectedModelType
- model func(config.SelectedModelType) catwalk.Model
disableCache bool
systemMessage string
systemPromptPrefix string
maxTokens int64
+ think bool
+ reasoningEffort string
+ resolver resolver.Resolver
extraHeaders map[string]string
extraBody map[string]any
extraParams map[string]string
}
-type ProviderClientOption func(*providerClientOptions)
-
-type ProviderClient interface {
- send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error)
- stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent
-
- Model() catwalk.Model
-}
-
-type baseProvider[C ProviderClient] struct {
- options providerClientOptions
- client C
-}
+type Option func(*baseProvider)
-func (p *baseProvider[C]) cleanMessages(messages []message.Message) (cleaned []message.Message) {
- for _, msg := range messages {
- // The message has no content
- if len(msg.Parts) == 0 {
- continue
- }
- cleaned = append(cleaned, msg)
+func WithDisableCache(disableCache bool) Option {
+ return func(options *baseProvider) {
+ options.disableCache = disableCache
}
- return
}
-func (p *baseProvider[C]) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
- messages = p.cleanMessages(messages)
- return p.client.send(ctx, messages, tools)
+func WithSystemMessage(systemMessage string) Option {
+ return func(options *baseProvider) {
+ options.systemMessage = systemMessage
+ }
}
-func (p *baseProvider[C]) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
- messages = p.cleanMessages(messages)
- return p.client.stream(ctx, messages, tools)
+func WithMaxTokens(maxTokens int64) Option {
+ return func(options *baseProvider) {
+ options.maxTokens = maxTokens
+ }
}
-func (p *baseProvider[C]) Model() catwalk.Model {
- return p.client.Model()
+func WithThinking(think bool) Option {
+ return func(options *baseProvider) {
+ options.think = think
+ }
}
-func WithModel(model config.SelectedModelType) ProviderClientOption {
- return func(options *providerClientOptions) {
- options.modelType = model
+func WithReasoningEffort(reasoningEffort string) Option {
+ return func(options *baseProvider) {
+ options.reasoningEffort = reasoningEffort
}
}
-func WithDisableCache(disableCache bool) ProviderClientOption {
- return func(options *providerClientOptions) {
- options.disableCache = disableCache
+func WithDebug(debug bool) Option {
+ return func(options *baseProvider) {
+ options.debug = debug
}
}
-func WithSystemMessage(systemMessage string) ProviderClientOption {
- return func(options *providerClientOptions) {
- options.systemMessage = systemMessage
+func WithResolver(resolver resolver.Resolver) Option {
+ return func(options *baseProvider) {
+ options.resolver = resolver
}
}
-func WithMaxTokens(maxTokens int64) ProviderClientOption {
- return func(options *providerClientOptions) {
- options.maxTokens = maxTokens
+func newBaseProvider(cfg Config, opts ...Option) (*baseProvider, error) {
+ provider := &baseProvider{
+ baseURL: cfg.BaseURL,
+ config: cfg,
+ apiKey: cfg.APIKey,
+ extraHeaders: cfg.ExtraHeaders,
+ extraBody: cfg.ExtraBody,
+ systemPromptPrefix: cfg.SystemPromptPrefix,
+ resolver: resolver.New(),
+ }
+ for _, o := range opts {
+ o(provider)
}
-}
-func NewProvider(cfg config.ProviderConfig, opts ...ProviderClientOption) (Provider, error) {
- resolvedAPIKey, err := config.Get().Resolve(cfg.APIKey)
+ resolvedAPIKey, err := provider.resolver.ResolveValue(cfg.APIKey)
if err != nil {
return nil, fmt.Errorf("failed to resolve API key for provider %s: %w", cfg.ID, err)
}
+ resolvedBaseURL, err := provider.resolver.ResolveValue(cfg.BaseURL)
+ if err != nil {
+ resolvedBaseURL = ""
+ }
// Resolve extra headers
resolvedExtraHeaders := make(map[string]string)
for key, value := range cfg.ExtraHeaders {
- resolvedValue, err := config.Get().Resolve(value)
+ resolvedValue, err := provider.resolver.ResolveValue(value)
if err != nil {
return nil, fmt.Errorf("failed to resolve extra header %s for provider %s: %w", key, cfg.ID, err)
}
resolvedExtraHeaders[key] = resolvedValue
}
- clientOptions := providerClientOptions{
- baseURL: cfg.BaseURL,
- config: cfg,
- apiKey: resolvedAPIKey,
- extraHeaders: resolvedExtraHeaders,
- extraBody: cfg.ExtraBody,
- systemPromptPrefix: cfg.SystemPromptPrefix,
- model: func(tp config.SelectedModelType) catwalk.Model {
- return *config.Get().GetModelByType(tp)
- },
- }
- for _, o := range opts {
- o(&clientOptions)
+ provider.apiKey = resolvedAPIKey
+ provider.baseURL = resolvedBaseURL
+ provider.extraHeaders = resolvedExtraHeaders
+ return provider, nil
+}
+
+func NewProvider(cfg Config, opts ...Option) (Provider, error) {
+ base, err := newBaseProvider(cfg, opts...)
+ if err != nil {
+ return nil, err
}
switch cfg.Type {
case catwalk.TypeAnthropic:
- return &baseProvider[AnthropicClient]{
- options: clientOptions,
- client: newAnthropicClient(clientOptions, false),
- }, nil
+ return NewAnthropicProvider(base, false), nil
case catwalk.TypeOpenAI:
- return &baseProvider[OpenAIClient]{
- options: clientOptions,
- client: newOpenAIClient(clientOptions),
- }, nil
+ return NewOpenAIProvider(base), nil
case catwalk.TypeGemini:
- return &baseProvider[GeminiClient]{
- options: clientOptions,
- client: newGeminiClient(clientOptions),
- }, nil
+ return NewGeminiProvider(base), nil
case catwalk.TypeBedrock:
- return &baseProvider[BedrockClient]{
- options: clientOptions,
- client: newBedrockClient(clientOptions),
- }, nil
+ return NewBedrockProvider(base), nil
case catwalk.TypeAzure:
- return &baseProvider[AzureClient]{
- options: clientOptions,
- client: newAzureClient(clientOptions),
- }, nil
+ return NewAzureProvider(base), nil
case catwalk.TypeVertexAI:
- return &baseProvider[VertexAIClient]{
- options: clientOptions,
- client: newVertexAIClient(clientOptions),
- }, nil
- case catwalk.TypeXAI:
- clientOptions.baseURL = "https://api.x.ai/v1"
- return &baseProvider[OpenAIClient]{
- options: clientOptions,
- client: newOpenAIClient(clientOptions),
- }, nil
+ return NewVertexAIProvider(base), nil
}
return nil, fmt.Errorf("provider not supported: %s", cfg.Type)
}
+
+func (p *baseProvider) cleanMessages(messages []message.Message) (cleaned []message.Message) {
+ for _, msg := range messages {
+ // The message has no content
+ if len(msg.Parts) == 0 {
+ continue
+ }
+ cleaned = append(cleaned, msg)
+ }
+ return
+}
+
+func (o *baseProvider) Model(model string) *catwalk.Model {
+ for _, m := range o.config.Models {
+ if m.ID == model {
+ return &m
+ }
+ }
+ return nil
+}
+
+func (o *baseProvider) SetDebug(debug bool) {
+ o.debug = debug
+}
+
+func (c *Config) TestConnection(resolver resolver.Resolver) error {
+ testURL := ""
+ headers := make(map[string]string)
+ apiKey, _ := resolver.ResolveValue(c.APIKey)
+ switch c.Type {
+ case catwalk.TypeOpenAI:
+ baseURL, _ := resolver.ResolveValue(c.BaseURL)
+ if baseURL == "" {
+ baseURL = "https://api.openai.com/v1"
+ }
+ testURL = baseURL + "/models"
+ headers["Authorization"] = "Bearer " + apiKey
+ case catwalk.TypeAnthropic:
+ baseURL, _ := resolver.ResolveValue(c.BaseURL)
+ if baseURL == "" {
+ baseURL = "https://api.anthropic.com/v1"
+ }
+ testURL = baseURL + "/models"
+ headers["x-api-key"] = apiKey
+ headers["anthropic-version"] = "2023-06-01"
+ }
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+ defer cancel()
+ client := &http.Client{}
+ req, err := http.NewRequestWithContext(ctx, "GET", testURL, nil)
+ if err != nil {
+ return fmt.Errorf("failed to create request for provider %s: %w", c.ID, err)
+ }
+ for k, v := range headers {
+ req.Header.Set(k, v)
+ }
+ for k, v := range c.ExtraHeaders {
+ req.Header.Set(k, v)
+ }
+ b, err := client.Do(req)
+ if err != nil {
+ return fmt.Errorf("failed to create request for provider %s: %w", c.ID, err)
+ }
+ if b.StatusCode != http.StatusOK {
+ return fmt.Errorf("failed to connect to provider %s: %s", c.ID, b.Status)
+ }
+ _ = b.Body.Close()
+ return nil
+}
@@ -7,11 +7,9 @@ import (
"google.golang.org/genai"
)
-type VertexAIClient ProviderClient
-
-func newVertexAIClient(opts providerClientOptions) VertexAIClient {
- project := opts.extraHeaders["project"]
- location := opts.extraHeaders["location"]
+func NewVertexAIProvider(base *baseProvider) Provider {
+ project := base.extraHeaders["project"]
+ location := base.extraHeaders["location"]
client, err := genai.NewClient(context.Background(), &genai.ClientConfig{
Project: project,
Location: location,
@@ -22,8 +20,8 @@ func newVertexAIClient(opts providerClientOptions) VertexAIClient {
return nil
}
- return &geminiClient{
- providerOptions: opts,
- client: client,
+ return &geminiProvider{
+ baseProvider: base,
+ client: client,
}
}
@@ -3,6 +3,9 @@ package tools
import (
"context"
"encoding/json"
+ "slices"
+
+ "github.com/charmbracelet/crush/internal/csync"
)
type ToolInfo struct {
@@ -83,3 +86,51 @@ func GetContextValues(ctx context.Context) (string, string) {
}
return sessionID.(string), messageID.(string)
}
+
+type Registry interface {
+ GetTool(name string) (BaseTool, bool)
+ SetTool(name string, tool BaseTool)
+ GetAllTools() []BaseTool
+}
+
+type registry struct {
+ tools *csync.LazySlice[BaseTool]
+}
+
+func (r *registry) GetAllTools() []BaseTool {
+ return slices.Collect(r.tools.Seq())
+}
+
+func (r *registry) GetTool(name string) (BaseTool, bool) {
+ for tool := range r.tools.Seq() {
+ if tool.Name() == name {
+ return tool, true
+ }
+ }
+
+ return nil, false
+}
+
+func (r *registry) SetTool(name string, tool BaseTool) {
+ for k, tool := range r.tools.Seq2() {
+ if tool.Name() == name {
+ r.tools.Set(k, tool)
+ return
+ }
+ }
+ r.tools.Append(tool)
+}
+
+type LazyToolsFn func() []BaseTool
+
+func NewRegistry(lazyTools LazyToolsFn) Registry {
+ return ®istry{
+ tools: csync.NewLazySlice(lazyTools),
+ }
+}
+
+func NewRegistryFromTools(tools []BaseTool) Registry {
+ return ®istry{
+ tools: csync.NewLazySlice(func() []BaseTool { return tools }),
+ }
+}
@@ -0,0 +1,188 @@
+package resolver
+
+import (
+ "context"
+ "fmt"
+ "strings"
+ "time"
+
+ "github.com/charmbracelet/crush/internal/env"
+ "github.com/charmbracelet/crush/internal/shell"
+)
+
+type Resolver interface {
+ ResolveValue(value string) (string, error)
+}
+
+type Shell interface {
+ Exec(ctx context.Context, command string) (stdout, stderr string, err error)
+}
+
+type shellVariableResolver struct {
+ shell Shell
+ env env.Env
+}
+
+func NewShellVariableResolver(env env.Env) Resolver {
+ return &shellVariableResolver{
+ env: env,
+ shell: shell.NewShell(
+ &shell.Options{
+ Env: env.Env(),
+ },
+ ),
+ }
+}
+
+// ResolveValue is a method for resolving values, such as environment variables.
+// it will resolve shell-like variable substitution anywhere in the string, including:
+// - $(command) for command substitution
+// - $VAR or ${VAR} for environment variables
+func (r *shellVariableResolver) ResolveValue(value string) (string, error) {
+ // Special case: lone $ is an error (backward compatibility)
+ if value == "$" {
+ return "", fmt.Errorf("invalid value format: %s", value)
+ }
+
+ // If no $ found, return as-is
+ if !strings.Contains(value, "$") {
+ return value, nil
+ }
+
+ result := value
+
+ // Handle command substitution: $(command)
+ for {
+ start := strings.Index(result, "$(")
+ if start == -1 {
+ break
+ }
+
+ // Find matching closing parenthesis
+ depth := 0
+ end := -1
+ for i := start + 2; i < len(result); i++ {
+ if result[i] == '(' {
+ depth++
+ } else if result[i] == ')' {
+ if depth == 0 {
+ end = i
+ break
+ }
+ depth--
+ }
+ }
+
+ if end == -1 {
+ return "", fmt.Errorf("unmatched $( in value: %s", value)
+ }
+
+ command := result[start+2 : end]
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
+
+ stdout, _, err := r.shell.Exec(ctx, command)
+ cancel()
+ if err != nil {
+ return "", fmt.Errorf("command execution failed for '%s': %w", command, err)
+ }
+
+ // Replace the $(command) with the output
+ replacement := strings.TrimSpace(stdout)
+ result = result[:start] + replacement + result[end+1:]
+ }
+
+ // Handle environment variables: $VAR and ${VAR}
+ searchStart := 0
+ for {
+ start := strings.Index(result[searchStart:], "$")
+ if start == -1 {
+ break
+ }
+ start += searchStart // Adjust for the offset
+
+ // Skip if this is part of $( which we already handled
+ if start+1 < len(result) && result[start+1] == '(' {
+ // Skip past this $(...)
+ searchStart = start + 1
+ continue
+ }
+ var varName string
+ var end int
+
+ if start+1 < len(result) && result[start+1] == '{' {
+ // Handle ${VAR} format
+ closeIdx := strings.Index(result[start+2:], "}")
+ if closeIdx == -1 {
+ return "", fmt.Errorf("unmatched ${ in value: %s", value)
+ }
+ varName = result[start+2 : start+2+closeIdx]
+ end = start + 2 + closeIdx + 1
+ } else {
+ // Handle $VAR format - variable names must start with letter or underscore
+ if start+1 >= len(result) {
+ return "", fmt.Errorf("incomplete variable reference at end of string: %s", value)
+ }
+
+ if result[start+1] != '_' &&
+ (result[start+1] < 'a' || result[start+1] > 'z') &&
+ (result[start+1] < 'A' || result[start+1] > 'Z') {
+ return "", fmt.Errorf("invalid variable name starting with '%c' in: %s", result[start+1], value)
+ }
+
+ end = start + 1
+ for end < len(result) && (result[end] == '_' ||
+ (result[end] >= 'a' && result[end] <= 'z') ||
+ (result[end] >= 'A' && result[end] <= 'Z') ||
+ (result[end] >= '0' && result[end] <= '9')) {
+ end++
+ }
+ varName = result[start+1 : end]
+ }
+
+ envValue := r.env.Get(varName)
+ if envValue == "" {
+ return "", fmt.Errorf("environment variable %q not set", varName)
+ }
+
+ result = result[:start] + envValue + result[end:]
+ searchStart = start + len(envValue) // Continue searching after the replacement
+ }
+
+ return result, nil
+}
+
+type environmentVariableResolver struct {
+ env env.Env
+}
+
+func NewEnvironmentVariableResolver(env env.Env) Resolver {
+ return &environmentVariableResolver{
+ env: env,
+ }
+}
+
+// ResolveValue resolves environment variables from the provided env.Env.
+func (r *environmentVariableResolver) ResolveValue(value string) (string, error) {
+ if !strings.HasPrefix(value, "$") {
+ return value, nil
+ }
+
+ varName := strings.TrimPrefix(value, "$")
+ resolvedValue := r.env.Get(varName)
+ if resolvedValue == "" {
+ return "", fmt.Errorf("environment variable %q not set", varName)
+ }
+ return resolvedValue, nil
+}
+
+func New() Resolver {
+ env := env.New()
+ return &shellVariableResolver{
+ env: env,
+ shell: shell.NewShell(
+ &shell.Options{
+ Env: env.Env(),
+ },
+ ),
+ }
+}
@@ -0,0 +1,332 @@
+package resolver
+
+import (
+ "context"
+ "errors"
+ "testing"
+
+ "github.com/charmbracelet/crush/internal/env"
+ "github.com/stretchr/testify/assert"
+)
+
+// mockShell implements the Shell interface for testing
+type mockShell struct {
+ execFunc func(ctx context.Context, command string) (stdout, stderr string, err error)
+}
+
+func (m *mockShell) Exec(ctx context.Context, command string) (stdout, stderr string, err error) {
+ if m.execFunc != nil {
+ return m.execFunc(ctx, command)
+ }
+ return "", "", nil
+}
+
+func TestShellVariableResolver_ResolveValue(t *testing.T) {
+ tests := []struct {
+ name string
+ value string
+ envVars map[string]string
+ shellFunc func(ctx context.Context, command string) (stdout, stderr string, err error)
+ expected string
+ expectError bool
+ }{
+ {
+ name: "non-variable string returns as-is",
+ value: "plain-string",
+ expected: "plain-string",
+ },
+ {
+ name: "environment variable resolution",
+ value: "$HOME",
+ envVars: map[string]string{"HOME": "/home/user"},
+ expected: "/home/user",
+ },
+ {
+ name: "missing environment variable returns error",
+ value: "$MISSING_VAR",
+ envVars: map[string]string{},
+ expectError: true,
+ },
+
+ {
+ name: "shell command with whitespace trimming",
+ value: "$(echo ' spaced ')",
+ shellFunc: func(ctx context.Context, command string) (stdout, stderr string, err error) {
+ if command == "echo ' spaced '" {
+ return " spaced \n", "", nil
+ }
+ return "", "", errors.New("unexpected command")
+ },
+ expected: "spaced",
+ },
+ {
+ name: "shell command execution error",
+ value: "$(false)",
+ shellFunc: func(ctx context.Context, command string) (stdout, stderr string, err error) {
+ return "", "", errors.New("command failed")
+ },
+ expectError: true,
+ },
+ {
+ name: "invalid format returns error",
+ value: "$",
+ expectError: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ testEnv := env.NewFromMap(tt.envVars)
+ resolver := &shellVariableResolver{
+ shell: &mockShell{execFunc: tt.shellFunc},
+ env: testEnv,
+ }
+
+ result, err := resolver.ResolveValue(tt.value)
+
+ if tt.expectError {
+ assert.Error(t, err)
+ } else {
+ assert.NoError(t, err)
+ assert.Equal(t, tt.expected, result)
+ }
+ })
+ }
+}
+
+func TestShellVariableResolver_EnhancedResolveValue(t *testing.T) {
+ tests := []struct {
+ name string
+ value string
+ envVars map[string]string
+ shellFunc func(ctx context.Context, command string) (stdout, stderr string, err error)
+ expected string
+ expectError bool
+ }{
+ {
+ name: "command substitution within string",
+ value: "Bearer $(echo token123)",
+ shellFunc: func(ctx context.Context, command string) (stdout, stderr string, err error) {
+ if command == "echo token123" {
+ return "token123\n", "", nil
+ }
+ return "", "", errors.New("unexpected command")
+ },
+ expected: "Bearer token123",
+ },
+ {
+ name: "environment variable within string",
+ value: "Bearer $TOKEN",
+ envVars: map[string]string{"TOKEN": "sk-ant-123"},
+ expected: "Bearer sk-ant-123",
+ },
+ {
+ name: "environment variable with braces within string",
+ value: "Bearer ${TOKEN}",
+ envVars: map[string]string{"TOKEN": "sk-ant-456"},
+ expected: "Bearer sk-ant-456",
+ },
+ {
+ name: "mixed command and environment substitution",
+ value: "$USER-$(date +%Y)-$HOST",
+ envVars: map[string]string{
+ "USER": "testuser",
+ "HOST": "localhost",
+ },
+ shellFunc: func(ctx context.Context, command string) (stdout, stderr string, err error) {
+ if command == "date +%Y" {
+ return "2024\n", "", nil
+ }
+ return "", "", errors.New("unexpected command")
+ },
+ expected: "testuser-2024-localhost",
+ },
+ {
+ name: "multiple command substitutions",
+ value: "$(echo hello) $(echo world)",
+ shellFunc: func(ctx context.Context, command string) (stdout, stderr string, err error) {
+ switch command {
+ case "echo hello":
+ return "hello\n", "", nil
+ case "echo world":
+ return "world\n", "", nil
+ }
+ return "", "", errors.New("unexpected command")
+ },
+ expected: "hello world",
+ },
+ {
+ name: "nested parentheses in command",
+ value: "$(echo $(echo inner))",
+ shellFunc: func(ctx context.Context, command string) (stdout, stderr string, err error) {
+ if command == "echo $(echo inner)" {
+ return "nested\n", "", nil
+ }
+ return "", "", errors.New("unexpected command")
+ },
+ expected: "nested",
+ },
+ {
+ name: "lone dollar with non-variable chars",
+ value: "prefix$123suffix", // Numbers can't start variable names
+ expectError: true,
+ },
+ {
+ name: "dollar with special chars",
+ value: "a$@b$#c", // Special chars aren't valid in variable names
+ expectError: true,
+ },
+ {
+ name: "empty environment variable substitution",
+ value: "Bearer $EMPTY_VAR",
+ envVars: map[string]string{},
+ expectError: true,
+ },
+ {
+ name: "unmatched command substitution opening",
+ value: "Bearer $(echo test",
+ expectError: true,
+ },
+ {
+ name: "unmatched environment variable braces",
+ value: "Bearer ${TOKEN",
+ expectError: true,
+ },
+ {
+ name: "command substitution with error",
+ value: "Bearer $(false)",
+ shellFunc: func(ctx context.Context, command string) (stdout, stderr string, err error) {
+ return "", "", errors.New("command failed")
+ },
+ expectError: true,
+ },
+ {
+ name: "complex real-world example",
+ value: "Bearer $(cat /tmp/token.txt | base64 -w 0)",
+ shellFunc: func(ctx context.Context, command string) (stdout, stderr string, err error) {
+ if command == "cat /tmp/token.txt | base64 -w 0" {
+ return "c2stYW50LXRlc3Q=\n", "", nil
+ }
+ return "", "", errors.New("unexpected command")
+ },
+ expected: "Bearer c2stYW50LXRlc3Q=",
+ },
+ {
+ name: "environment variable with underscores and numbers",
+ value: "Bearer $API_KEY_V2",
+ envVars: map[string]string{"API_KEY_V2": "sk-test-123"},
+ expected: "Bearer sk-test-123",
+ },
+ {
+ name: "no substitution needed",
+ value: "Bearer sk-ant-static-token",
+ expected: "Bearer sk-ant-static-token",
+ },
+ {
+ name: "incomplete variable at end",
+ value: "Bearer $",
+ expectError: true,
+ },
+ {
+ name: "variable with invalid character",
+ value: "Bearer $VAR-NAME", // Hyphen not allowed in variable names
+ expectError: true,
+ },
+ {
+ name: "multiple invalid variables",
+ value: "$1$2$3",
+ expectError: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ testEnv := env.NewFromMap(tt.envVars)
+ resolver := &shellVariableResolver{
+ shell: &mockShell{execFunc: tt.shellFunc},
+ env: testEnv,
+ }
+
+ result, err := resolver.ResolveValue(tt.value)
+
+ if tt.expectError {
+ assert.Error(t, err)
+ } else {
+ assert.NoError(t, err)
+ assert.Equal(t, tt.expected, result)
+ }
+ })
+ }
+}
+
+func TestEnvironmentVariableResolver_ResolveValue(t *testing.T) {
+ tests := []struct {
+ name string
+ value string
+ envVars map[string]string
+ expected string
+ expectError bool
+ }{
+ {
+ name: "non-variable string returns as-is",
+ value: "plain-string",
+ expected: "plain-string",
+ },
+ {
+ name: "environment variable resolution",
+ value: "$HOME",
+ envVars: map[string]string{"HOME": "/home/user"},
+ expected: "/home/user",
+ },
+ {
+ name: "environment variable with complex value",
+ value: "$PATH",
+ envVars: map[string]string{"PATH": "/usr/bin:/bin:/usr/local/bin"},
+ expected: "/usr/bin:/bin:/usr/local/bin",
+ },
+ {
+ name: "missing environment variable returns error",
+ value: "$MISSING_VAR",
+ envVars: map[string]string{},
+ expectError: true,
+ },
+ {
+ name: "empty environment variable returns error",
+ value: "$EMPTY_VAR",
+ envVars: map[string]string{"EMPTY_VAR": ""},
+ expectError: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ testEnv := env.NewFromMap(tt.envVars)
+ resolver := NewEnvironmentVariableResolver(testEnv)
+
+ result, err := resolver.ResolveValue(tt.value)
+
+ if tt.expectError {
+ assert.Error(t, err)
+ } else {
+ assert.NoError(t, err)
+ assert.Equal(t, tt.expected, result)
+ }
+ })
+ }
+}
+
+func TestNewShellVariableResolver(t *testing.T) {
+ testEnv := env.NewFromMap(map[string]string{"TEST": "value"})
+ resolver := NewShellVariableResolver(testEnv)
+
+ assert.NotNil(t, resolver)
+ assert.Implements(t, (*Resolver)(nil), resolver)
+}
+
+func TestNewEnvironmentVariableResolver(t *testing.T) {
+ testEnv := env.NewFromMap(map[string]string{"TEST": "value"})
+ resolver := NewEnvironmentVariableResolver(testEnv)
+
+ assert.NotNil(t, resolver)
+ assert.Implements(t, (*Resolver)(nil), resolver)
+}
@@ -351,6 +351,7 @@ func (m *messageListCmp) updateAssistantMessageContent(msg message.Message, assi
messages.NewAssistantSection(
msg,
time.Unix(m.lastUserMessageTime, 0),
+ m.app.Config(),
),
)
}
@@ -472,7 +473,7 @@ func (m *messageListCmp) convertMessagesToUI(sessionMessages []message.Message,
case message.Assistant:
uiMessages = append(uiMessages, m.convertAssistantMessage(msg, toolResultMap)...)
if msg.FinishPart() != nil && msg.FinishPart().Reason == message.FinishReasonEndTurn {
- uiMessages = append(uiMessages, messages.NewAssistantSection(msg, time.Unix(m.lastUserMessageTime, 0)))
+ uiMessages = append(uiMessages, messages.NewAssistantSection(msg, time.Unix(m.lastUserMessageTime, 0), m.app.Config()))
}
}
}
@@ -5,9 +5,8 @@ import (
"strings"
tea "github.com/charmbracelet/bubbletea/v2"
- "github.com/charmbracelet/crush/internal/config"
+ "github.com/charmbracelet/crush/internal/app"
"github.com/charmbracelet/crush/internal/fsext"
- "github.com/charmbracelet/crush/internal/lsp"
"github.com/charmbracelet/crush/internal/lsp/protocol"
"github.com/charmbracelet/crush/internal/pubsub"
"github.com/charmbracelet/crush/internal/session"
@@ -27,14 +26,14 @@ type Header interface {
type header struct {
width int
session session.Session
- lspClients map[string]*lsp.Client
+ app *app.App
detailsOpen bool
}
-func New(lspClients map[string]*lsp.Client) Header {
+func New(app *app.App) Header {
return &header{
- lspClients: lspClients,
- width: 0,
+ app: app,
+ width: 0,
}
}
@@ -88,13 +87,13 @@ func (p *header) View() string {
func (h *header) details() string {
t := styles.CurrentTheme()
- cwd := fsext.DirTrim(fsext.PrettyPath(config.Get().WorkingDir()), 4)
+ cwd := fsext.DirTrim(fsext.PrettyPath(h.app.Config().WorkingDir()), 4)
parts := []string{
t.S().Muted.Render(cwd),
}
errorCount := 0
- for _, l := range h.lspClients {
+ for _, l := range h.app.LSPClients {
for _, diagnostics := range l.GetDiagnostics() {
for _, diagnostic := range diagnostics {
if diagnostic.Severity == protocol.SeverityError {
@@ -108,8 +107,7 @@ func (h *header) details() string {
parts = append(parts, t.S().Error.Render(fmt.Sprintf("%s%d", styles.ErrorIcon, errorCount)))
}
- agentCfg := config.Get().Agents["coder"]
- model := config.Get().GetModelByType(agentCfg.Model)
+ model := h.app.CoderAgent.Model()
percentage := (float64(h.session.CompletionTokens+h.session.PromptTokens) / float64(model.ContextWindow)) * 100
formattedPercentage := t.S().Muted.Render(fmt.Sprintf("%d%%", int(percentage)))
parts = append(parts, formattedPercentage)
@@ -352,6 +352,7 @@ type AssistantSection interface {
}
type assistantSectionModel struct {
width int
+ config *config.Config
id string
message message.Message
lastUserMessageTime time.Time
@@ -362,9 +363,10 @@ func (m *assistantSectionModel) ID() string {
return m.id
}
-func NewAssistantSection(message message.Message, lastUserMessageTime time.Time) AssistantSection {
+func NewAssistantSection(message message.Message, lastUserMessageTime time.Time, cfg *config.Config) AssistantSection {
return &assistantSectionModel{
width: 0,
+ config: cfg,
id: uuid.NewString(),
message: message,
lastUserMessageTime: lastUserMessageTime,
@@ -386,7 +388,7 @@ func (m *assistantSectionModel) View() string {
duration := finishTime.Sub(m.lastUserMessageTime)
infoMsg := t.S().Subtle.Render(duration.String())
icon := t.S().Subtle.Render(styles.ModelIcon)
- model := config.Get().GetModel(m.message.Provider, m.message.Model)
+ model := m.config.GetModel(m.message.Provider, m.message.Model)
if model == nil {
// This means the model is not configured anymore
model = &catwalk.Model{
@@ -10,12 +10,11 @@ import (
tea "github.com/charmbracelet/bubbletea/v2"
"github.com/charmbracelet/catwalk/pkg/catwalk"
- "github.com/charmbracelet/crush/internal/config"
+ "github.com/charmbracelet/crush/internal/app"
"github.com/charmbracelet/crush/internal/csync"
"github.com/charmbracelet/crush/internal/diff"
"github.com/charmbracelet/crush/internal/fsext"
"github.com/charmbracelet/crush/internal/history"
- "github.com/charmbracelet/crush/internal/lsp"
"github.com/charmbracelet/crush/internal/lsp/protocol"
"github.com/charmbracelet/crush/internal/pubsub"
"github.com/charmbracelet/crush/internal/session"
@@ -69,16 +68,14 @@ type sidebarCmp struct {
session session.Session
logo string
cwd string
- lspClients map[string]*lsp.Client
compactMode bool
- history history.Service
files *csync.Map[string, SessionFile]
+ app *app.App
}
-func New(history history.Service, lspClients map[string]*lsp.Client, compact bool) Sidebar {
+func New(app *app.App, compact bool) Sidebar {
return &sidebarCmp{
- lspClients: lspClients,
- history: history,
+ app: app,
compactMode: compact,
files: csync.NewMap[string, SessionFile](),
}
@@ -194,7 +191,7 @@ func (m *sidebarCmp) handleFileHistoryEvent(event pubsub.Event[history.File]) te
before := existing.History.initialVersion.Content
after := existing.History.latestVersion.Content
path := existing.History.initialVersion.Path
- cwd := config.Get().WorkingDir()
+ cwd := m.app.Config().WorkingDir()
path = strings.TrimPrefix(path, cwd)
_, additions, deletions := diff.GenerateDiff(before, after, path)
existing.Additions = additions
@@ -221,7 +218,7 @@ func (m *sidebarCmp) handleFileHistoryEvent(event pubsub.Event[history.File]) te
}
func (m *sidebarCmp) loadSessionFiles() tea.Msg {
- files, err := m.history.ListBySession(context.Background(), m.session.ID)
+ files, err := m.app.History.ListBySession(context.Background(), m.session.ID)
if err != nil {
return util.InfoMsg{
Type: util.InfoTypeError,
@@ -247,7 +244,7 @@ func (m *sidebarCmp) loadSessionFiles() tea.Msg {
sessionFiles := make([]SessionFile, 0, len(fileMap))
for path, fh := range fileMap {
- cwd := config.Get().WorkingDir()
+ cwd := m.app.Config().WorkingDir()
path = strings.TrimPrefix(path, cwd)
_, additions, deletions := diff.GenerateDiff(fh.initialVersion.Content, fh.latestVersion.Content, path)
sessionFiles = append(sessionFiles, SessionFile{
@@ -265,7 +262,7 @@ func (m *sidebarCmp) loadSessionFiles() tea.Msg {
func (m *sidebarCmp) SetSize(width, height int) tea.Cmd {
m.logo = m.logoBlock()
- m.cwd = cwd()
+ m.cwd = cwd(m.app.Config().WorkingDir())
m.width = width
m.height = height
return nil
@@ -428,7 +425,7 @@ func (m *sidebarCmp) filesBlockCompact(maxWidth int) string {
}
extraContent := strings.Join(statusParts, " ")
- cwd := config.Get().WorkingDir() + string(os.PathSeparator)
+ cwd := m.app.Config().WorkingDir() + string(os.PathSeparator)
filePath := file.FilePath
filePath = strings.TrimPrefix(filePath, cwd)
filePath = fsext.DirTrim(fsext.PrettyPath(filePath), 2)
@@ -471,7 +468,7 @@ func (m *sidebarCmp) lspBlockCompact(maxWidth int) string {
lspList := []string{section, ""}
- lsp := config.Get().LSP.Sorted()
+ lsp := m.app.Config().LSP.Sorted()
if len(lsp) == 0 {
content := lipgloss.JoinVertical(
lipgloss.Left,
@@ -505,7 +502,7 @@ func (m *sidebarCmp) lspBlockCompact(maxWidth int) string {
protocol.SeverityHint: 0,
protocol.SeverityInformation: 0,
}
- if client, ok := m.lspClients[l.Name]; ok {
+ if client, ok := m.app.LSPClients[l.Name]; ok {
for _, diagnostics := range client.GetDiagnostics() {
for _, diagnostic := range diagnostics {
if severity, ok := lspErrs[diagnostic.Severity]; ok {
@@ -559,7 +556,7 @@ func (m *sidebarCmp) mcpBlockCompact(maxWidth int) string {
mcpList := []string{section, ""}
- mcps := config.Get().MCP.Sorted()
+ mcps := m.app.Config().MCP.Sorted()
if len(mcps) == 0 {
content := lipgloss.JoinVertical(
lipgloss.Left,
@@ -653,7 +650,7 @@ func (m *sidebarCmp) filesBlock() string {
}
extraContent := strings.Join(statusParts, " ")
- cwd := config.Get().WorkingDir() + string(os.PathSeparator)
+ cwd := m.app.Config().WorkingDir() + string(os.PathSeparator)
filePath := file.FilePath
filePath = strings.TrimPrefix(filePath, cwd)
filePath = fsext.DirTrim(fsext.PrettyPath(filePath), 2)
@@ -701,7 +698,7 @@ func (m *sidebarCmp) lspBlock() string {
lspList := []string{section, ""}
- lsp := config.Get().LSP.Sorted()
+ lsp := m.app.Config().LSP.Sorted()
if len(lsp) == 0 {
return lipgloss.JoinVertical(
lipgloss.Left,
@@ -729,7 +726,7 @@ func (m *sidebarCmp) lspBlock() string {
protocol.SeverityHint: 0,
protocol.SeverityInformation: 0,
}
- if client, ok := m.lspClients[l.Name]; ok {
+ if client, ok := m.app.LSPClients[l.Name]; ok {
for _, diagnostics := range client.GetDiagnostics() {
for _, diagnostic := range diagnostics {
if severity, ok := lspErrs[diagnostic.Severity]; ok {
@@ -789,7 +786,7 @@ func (m *sidebarCmp) mcpBlock() string {
mcpList := []string{section, ""}
- mcps := config.Get().MCP.Sorted()
+ mcps := m.app.Config().MCP.Sorted()
if len(mcps) == 0 {
return lipgloss.JoinVertical(
lipgloss.Left,
@@ -876,13 +873,9 @@ func formatTokensAndCost(tokens, contextWindow int64, cost float64) string {
}
func (s *sidebarCmp) currentModelBlock() string {
- cfg := config.Get()
- agentCfg := cfg.Agents["coder"]
-
- selectedModel := cfg.Models[agentCfg.Model]
-
- model := config.Get().GetModelByType(agentCfg.Model)
- modelProvider := config.Get().GetProviderForModel(agentCfg.Model)
+ model := s.app.CoderAgent.Model()
+ selectedModel := s.app.CoderAgent.ModelConfig()
+ modelProvider := s.app.CoderAgent.Provider()
t := styles.CurrentTheme()
@@ -938,8 +931,7 @@ func (m *sidebarCmp) SetCompactMode(compact bool) {
m.compactMode = compact
}
-func cwd() string {
- cwd := config.Get().WorkingDir()
+func cwd(cwd string) string {
t := styles.CurrentTheme()
// Replace home directory with ~, unless we're at the top level of the
// home directory).
@@ -12,7 +12,9 @@ import (
tea "github.com/charmbracelet/bubbletea/v2"
"github.com/charmbracelet/catwalk/pkg/catwalk"
"github.com/charmbracelet/crush/internal/config"
+ "github.com/charmbracelet/crush/internal/llm/agent"
"github.com/charmbracelet/crush/internal/llm/prompt"
+ "github.com/charmbracelet/crush/internal/llm/provider"
"github.com/charmbracelet/crush/internal/tui/components/chat"
"github.com/charmbracelet/crush/internal/tui/components/core"
"github.com/charmbracelet/crush/internal/tui/components/core/layout"
@@ -71,9 +73,10 @@ type splashCmp struct {
selectedModel *models.ModelOption
isAPIKeyValid bool
apiKeyValue string
+ config *config.Config
}
-func New() Splash {
+func New(cfg *config.Config) Splash {
keyMap := DefaultKeyMap()
listKeyMap := list.DefaultKeyMap()
listKeyMap.Down.SetEnabled(false)
@@ -85,12 +88,13 @@ func New() Splash {
listKeyMap.DownOneItem = keyMap.Next
listKeyMap.UpOneItem = keyMap.Previous
- modelList := models.NewModelListComponent(listKeyMap, "Find your fave", false)
+ modelList := models.NewModelListComponent(cfg, listKeyMap, "Find your fave", false)
apiKeyInput := models.NewAPIKeyInput()
return &splashCmp{
width: 0,
height: 0,
+ config: cfg,
keyMap: keyMap,
logoRendered: "",
modelList: modelList,
@@ -214,16 +218,16 @@ func (s *splashCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
return s, nil
}
- provider, err := s.getProvider(s.selectedModel.Provider.ID)
- if err != nil || provider == nil {
+ selectedProvider, err := s.getProvider(s.selectedModel.Provider.ID)
+ if err != nil || selectedProvider == nil {
return s, util.ReportError(fmt.Errorf("provider %s not found", s.selectedModel.Provider.ID))
}
- providerConfig := config.ProviderConfig{
+ providerConfig := provider.Config{
ID: string(s.selectedModel.Provider.ID),
Name: s.selectedModel.Provider.Name,
APIKey: s.apiKeyValue,
- Type: provider.Type,
- BaseURL: provider.APIEndpoint,
+ Type: selectedProvider.Type,
+ BaseURL: selectedProvider.APIEndpoint,
}
return s, tea.Sequence(
util.CmdHandler(models.APIKeyStateChangeMsg{
@@ -231,7 +235,7 @@ func (s *splashCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
}),
func() tea.Msg {
start := time.Now()
- err := providerConfig.TestConnection(config.Get().Resolver())
+ err := providerConfig.TestConnection(s.config.Resolver())
// intentionally wait for at least 750ms to make sure the user sees the spinner
elapsed := time.Since(start)
if elapsed < 750*time.Millisecond {
@@ -320,8 +324,7 @@ func (s *splashCmp) saveAPIKeyAndContinue(apiKey string) tea.Cmd {
return util.ReportError(fmt.Errorf("no model selected"))
}
- cfg := config.Get()
- err := cfg.SetProviderAPIKey(string(s.selectedModel.Provider.ID), apiKey)
+ err := s.config.SetProviderAPIKey(string(s.selectedModel.Provider.ID), apiKey)
if err != nil {
return util.ReportError(fmt.Errorf("failed to save API key: %w", err))
}
@@ -338,7 +341,7 @@ func (s *splashCmp) saveAPIKeyAndContinue(apiKey string) tea.Cmd {
func (s *splashCmp) initializeProject() tea.Cmd {
s.needsProjectInit = false
- if err := config.MarkProjectInitialized(); err != nil {
+ if err := config.MarkProjectInitialized(s.config); err != nil {
return util.ReportError(err)
}
var cmds []tea.Cmd
@@ -356,20 +359,19 @@ func (s *splashCmp) initializeProject() tea.Cmd {
}
func (s *splashCmp) setPreferredModel(selectedItem models.ModelOption) tea.Cmd {
- cfg := config.Get()
- model := cfg.GetModel(string(selectedItem.Provider.ID), selectedItem.Model.ID)
+ model := s.config.GetModel(string(selectedItem.Provider.ID), selectedItem.Model.ID)
if model == nil {
return util.ReportError(fmt.Errorf("model %s not found for provider %s", selectedItem.Model.ID, selectedItem.Provider.ID))
}
- selectedModel := config.SelectedModel{
+ selectedModel := agent.Model{
Model: selectedItem.Model.ID,
Provider: string(selectedItem.Provider.ID),
ReasoningEffort: model.DefaultReasoningEffort,
MaxTokens: model.DefaultMaxTokens,
}
- err := cfg.UpdatePreferredModel(config.SelectedModelTypeLarge, selectedModel)
+ err := s.config.UpdatePreferredModel(config.SelectedModelTypeLarge, selectedModel)
if err != nil {
return util.ReportError(err)
}
@@ -381,33 +383,32 @@ func (s *splashCmp) setPreferredModel(selectedItem models.ModelOption) tea.Cmd {
}
if knownProvider == nil {
// for local provider we just use the same model
- err = cfg.UpdatePreferredModel(config.SelectedModelTypeSmall, selectedModel)
+ err = s.config.UpdatePreferredModel(config.SelectedModelTypeSmall, selectedModel)
if err != nil {
return util.ReportError(err)
}
} else {
smallModel := knownProvider.DefaultSmallModelID
- model := cfg.GetModel(string(selectedItem.Provider.ID), smallModel)
+ model := s.config.GetModel(string(selectedItem.Provider.ID), smallModel)
// should never happen
if model == nil {
- err = cfg.UpdatePreferredModel(config.SelectedModelTypeSmall, selectedModel)
+ err = s.config.UpdatePreferredModel(config.SelectedModelTypeSmall, selectedModel)
if err != nil {
return util.ReportError(err)
}
return nil
}
- smallSelectedModel := config.SelectedModel{
+ smallSelectedModel := agent.Model{
Model: smallModel,
Provider: string(selectedItem.Provider.ID),
ReasoningEffort: model.DefaultReasoningEffort,
MaxTokens: model.DefaultMaxTokens,
}
- err = cfg.UpdatePreferredModel(config.SelectedModelTypeSmall, smallSelectedModel)
+ err = s.config.UpdatePreferredModel(config.SelectedModelTypeSmall, smallSelectedModel)
if err != nil {
return util.ReportError(err)
}
}
- cfg.SetupAgents()
return nil
}
@@ -425,8 +426,7 @@ func (s *splashCmp) getProvider(providerID catwalk.InferenceProvider) (*catwalk.
}
func (s *splashCmp) isProviderConfigured(providerID string) bool {
- cfg := config.Get()
- if _, ok := cfg.Providers.Get(providerID); ok {
+ if _, ok := s.config.Providers.Get(providerID); ok {
return true
}
return false
@@ -652,7 +652,7 @@ func (s *splashCmp) getMaxInfoWidth() int {
}
func (s *splashCmp) cwd() string {
- cwd := config.Get().WorkingDir()
+ cwd := s.config.WorkingDir()
t := styles.CurrentTheme()
homeDir, err := os.UserHomeDir()
if err == nil && cwd != homeDir {
@@ -662,10 +662,10 @@ func (s *splashCmp) cwd() string {
return t.S().Muted.Width(maxWidth).Render(cwd)
}
-func LSPList(maxWidth int) []string {
+func LSPList(cfg *config.Config, maxWidth int) []string {
t := styles.CurrentTheme()
lspList := []string{}
- lsp := config.Get().LSP.Sorted()
+ lsp := cfg.LSP.Sorted()
if len(lsp) == 0 {
return []string{t.S().Base.Foreground(t.Border).Render("None")}
}
@@ -692,7 +692,7 @@ func (s *splashCmp) lspBlock() string {
t := styles.CurrentTheme()
maxWidth := s.getMaxInfoWidth() / 2
section := t.S().Subtle.Render("LSPs")
- lspList := append([]string{section, ""}, LSPList(maxWidth-1)...)
+ lspList := append([]string{section, ""}, LSPList(s.config, maxWidth-1)...)
return t.S().Base.Width(maxWidth).PaddingRight(1).Render(
lipgloss.JoinVertical(
lipgloss.Left,
@@ -701,10 +701,10 @@ func (s *splashCmp) lspBlock() string {
)
}
-func MCPList(maxWidth int) []string {
+func MCPList(cfg *config.Config, maxWidth int) []string {
t := styles.CurrentTheme()
mcpList := []string{}
- mcps := config.Get().MCP.Sorted()
+ mcps := cfg.MCP.Sorted()
if len(mcps) == 0 {
return []string{t.S().Base.Foreground(t.Border).Render("None")}
}
@@ -731,7 +731,7 @@ func (s *splashCmp) mcpBlock() string {
t := styles.CurrentTheme()
maxWidth := s.getMaxInfoWidth() / 2
section := t.S().Subtle.Render("MCPs")
- mcpList := append([]string{section, ""}, MCPList(maxWidth-1)...)
+ mcpList := append([]string{section, ""}, MCPList(s.config, maxWidth-1)...)
return t.S().Base.Width(maxWidth).PaddingRight(1).Render(
lipgloss.JoinVertical(
lipgloss.Left,
@@ -7,7 +7,7 @@ import (
"github.com/charmbracelet/catwalk/pkg/catwalk"
"github.com/charmbracelet/lipgloss/v2"
- "github.com/charmbracelet/crush/internal/config"
+ "github.com/charmbracelet/crush/internal/app"
"github.com/charmbracelet/crush/internal/llm/prompt"
"github.com/charmbracelet/crush/internal/tui/components/chat"
"github.com/charmbracelet/crush/internal/tui/components/core"
@@ -49,6 +49,7 @@ type commandDialogCmp struct {
wWidth int // Width of the terminal window
wHeight int // Height of the terminal window
+ app *app.App
commandList listModel
keyMap CommandsDialogKeyMap
help help.Model
@@ -67,7 +68,7 @@ type (
}
)
-func NewCommandDialog(sessionID string) CommandsDialog {
+func NewCommandDialog(app *app.App, sessionID string) CommandsDialog {
keyMap := DefaultCommandsDialogKeyMap()
listKeyMap := list.DefaultKeyMap()
listKeyMap.Down.SetEnabled(false)
@@ -89,6 +90,7 @@ func NewCommandDialog(sessionID string) CommandsDialog {
help := help.New()
help.Styles = t.S().Help
return &commandDialogCmp{
+ app: app,
commandList: commandList,
width: defaultWidth,
keyMap: DefaultCommandsDialogKeyMap(),
@@ -99,7 +101,7 @@ func NewCommandDialog(sessionID string) CommandsDialog {
}
func (c *commandDialogCmp) Init() tea.Cmd {
- commands, err := LoadCustomCommands()
+ commands, err := LoadCustomCommands(c.app.Config())
if err != nil {
return util.ReportError(err)
}
@@ -274,13 +276,12 @@ func (c *commandDialogCmp) defaultCommands() []Command {
}
// Only show thinking toggle for Anthropic models that can reason
- cfg := config.Get()
- if agentCfg, ok := cfg.Agents["coder"]; ok {
- providerCfg := cfg.GetProviderForModel(agentCfg.Model)
- model := cfg.GetModelByType(agentCfg.Model)
+ if c.app.CoderAgent != nil {
+ providerCfg := c.app.CoderAgent.Provider()
+ model := c.app.CoderAgent.Model()
if providerCfg != nil && model != nil &&
providerCfg.Type == catwalk.TypeAnthropic && model.CanReason {
- selectedModel := cfg.Models[agentCfg.Model]
+ selectedModel := c.app.CoderAgent.ModelConfig()
status := "Enable"
if selectedModel.Think {
status = "Disable"
@@ -29,8 +29,7 @@ type commandSource struct {
prefix string
}
-func LoadCustomCommands() ([]Command, error) {
- cfg := config.Get()
+func LoadCustomCommands(cfg *config.Config) ([]Command, error) {
if cfg == nil {
return nil, fmt.Errorf("config not loaded")
}
@@ -7,6 +7,7 @@ import (
tea "github.com/charmbracelet/bubbletea/v2"
"github.com/charmbracelet/catwalk/pkg/catwalk"
"github.com/charmbracelet/crush/internal/config"
+ "github.com/charmbracelet/crush/internal/llm/agent"
"github.com/charmbracelet/crush/internal/tui/exp/list"
"github.com/charmbracelet/crush/internal/tui/styles"
"github.com/charmbracelet/crush/internal/tui/util"
@@ -18,9 +19,10 @@ type ModelListComponent struct {
list listModel
modelType int
providers []catwalk.Provider
+ config *config.Config
}
-func NewModelListComponent(keyMap list.KeyMap, inputPlaceholder string, shouldResize bool) *ModelListComponent {
+func NewModelListComponent(cfg *config.Config, keyMap list.KeyMap, inputPlaceholder string, shouldResize bool) *ModelListComponent {
t := styles.CurrentTheme()
inputStyle := t.S().Base.PaddingLeft(1).PaddingBottom(1)
options := []list.ListOption{
@@ -42,6 +44,7 @@ func NewModelListComponent(keyMap list.KeyMap, inputPlaceholder string, shouldRe
return &ModelListComponent{
list: modelList,
modelType: LargeModelType,
+ config: cfg,
}
}
@@ -94,12 +97,11 @@ func (m *ModelListComponent) SetModelType(modelType int) tea.Cmd {
// first none section
selectedItemID := ""
- cfg := config.Get()
- var currentModel config.SelectedModel
+ var currentModel agent.Model
if m.modelType == LargeModelType {
- currentModel = cfg.Models[config.SelectedModelTypeLarge]
+ currentModel = m.config.Models[config.SelectedModelTypeLarge]
} else {
- currentModel = cfg.Models[config.SelectedModelTypeSmall]
+ currentModel = m.config.Models[config.SelectedModelTypeSmall]
}
configuredIcon := t.S().Base.Foreground(t.Success).Render(styles.CheckIcon)
@@ -114,7 +116,7 @@ func (m *ModelListComponent) SetModelType(modelType int) tea.Cmd {
if err != nil {
return util.ReportError(err)
}
- for providerID, providerConfig := range cfg.Providers.Seq2() {
+ for providerID, providerConfig := range m.config.Providers.Seq2() {
if providerConfig.Disable {
continue
}
@@ -185,7 +187,7 @@ func (m *ModelListComponent) SetModelType(modelType int) tea.Cmd {
}
// Check if this provider is configured and not disabled
- if providerConfig, exists := cfg.Providers.Get(string(provider.ID)); exists && providerConfig.Disable {
+ if providerConfig, exists := m.config.Providers.Get(string(provider.ID)); exists && providerConfig.Disable {
continue
}
@@ -195,7 +197,7 @@ func (m *ModelListComponent) SetModelType(modelType int) tea.Cmd {
}
section := list.NewItemSection(name)
- if _, ok := cfg.Providers.Get(string(provider.ID)); ok {
+ if _, ok := m.config.Providers.Get(string(provider.ID)); ok {
section.SetInfo(configured)
}
group := list.Group[list.CompletionItem[ModelOption]]{
@@ -10,6 +10,8 @@ import (
tea "github.com/charmbracelet/bubbletea/v2"
"github.com/charmbracelet/catwalk/pkg/catwalk"
"github.com/charmbracelet/crush/internal/config"
+ "github.com/charmbracelet/crush/internal/llm/agent"
+ "github.com/charmbracelet/crush/internal/llm/provider"
"github.com/charmbracelet/crush/internal/tui/components/core"
"github.com/charmbracelet/crush/internal/tui/components/dialogs"
"github.com/charmbracelet/crush/internal/tui/exp/list"
@@ -34,7 +36,7 @@ const (
// ModelSelectedMsg is sent when a model is selected
type ModelSelectedMsg struct {
- Model config.SelectedModel
+ Model agent.Model
ModelType config.SelectedModelType
}
@@ -56,6 +58,7 @@ type modelDialogCmp struct {
wWidth int
wHeight int
+ config *config.Config
modelList *ModelListComponent
keyMap KeyMap
help help.Model
@@ -69,7 +72,7 @@ type modelDialogCmp struct {
apiKeyValue string
}
-func NewModelDialogCmp() ModelDialog {
+func NewModelDialogCmp(cfg *config.Config) ModelDialog {
keyMap := DefaultKeyMap()
listKeyMap := list.DefaultKeyMap()
@@ -79,7 +82,7 @@ func NewModelDialogCmp() ModelDialog {
listKeyMap.UpOneItem = keyMap.Previous
t := styles.CurrentTheme()
- modelList := NewModelListComponent(listKeyMap, "Choose a model for large, complex tasks", true)
+ modelList := NewModelListComponent(cfg, listKeyMap, "Choose a model for large, complex tasks", true)
apiKeyInput := NewAPIKeyInput()
apiKeyInput.SetShowTitle(false)
help := help.New()
@@ -91,6 +94,7 @@ func NewModelDialogCmp() ModelDialog {
width: defaultWidth,
keyMap: DefaultKeyMap(),
help: help,
+ config: cfg,
}
}
@@ -119,16 +123,16 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
if m.needsAPIKey {
// Handle API key submission
m.apiKeyValue = m.apiKeyInput.Value()
- provider, err := m.getProvider(m.selectedModel.Provider.ID)
- if err != nil || provider == nil {
+ selectedProvider, err := m.getProvider(m.selectedModel.Provider.ID)
+ if err != nil || selectedProvider == nil {
return m, util.ReportError(fmt.Errorf("provider %s not found", m.selectedModel.Provider.ID))
}
- providerConfig := config.ProviderConfig{
+ providerConfig := provider.Config{
ID: string(m.selectedModel.Provider.ID),
Name: m.selectedModel.Provider.Name,
APIKey: m.apiKeyValue,
- Type: provider.Type,
- BaseURL: provider.APIEndpoint,
+ Type: selectedProvider.Type,
+ BaseURL: selectedProvider.APIEndpoint,
}
return m, tea.Sequence(
util.CmdHandler(APIKeyStateChangeMsg{
@@ -136,7 +140,7 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
}),
func() tea.Msg {
start := time.Now()
- err := providerConfig.TestConnection(config.Get().Resolver())
+ err := providerConfig.TestConnection(m.config.Resolver())
// intentionally wait for at least 750ms to make sure the user sees the spinner
elapsed := time.Since(start)
if elapsed < 750*time.Millisecond {
@@ -169,7 +173,7 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
return m, tea.Sequence(
util.CmdHandler(dialogs.CloseDialogMsg{}),
util.CmdHandler(ModelSelectedMsg{
- Model: config.SelectedModel{
+ Model: agent.Model{
Model: selectedItem.Model.ID,
Provider: string(selectedItem.Provider.ID),
},
@@ -342,8 +346,7 @@ func (m *modelDialogCmp) modelTypeRadio() string {
}
func (m *modelDialogCmp) isProviderConfigured(providerID string) bool {
- cfg := config.Get()
- if _, ok := cfg.Providers.Get(providerID); ok {
+ if _, ok := m.config.Providers.Get(providerID); ok {
return true
}
return false
@@ -367,8 +370,7 @@ func (m *modelDialogCmp) saveAPIKeyAndContinue(apiKey string) tea.Cmd {
return util.ReportError(fmt.Errorf("no model selected"))
}
- cfg := config.Get()
- err := cfg.SetProviderAPIKey(string(m.selectedModel.Provider.ID), apiKey)
+ err := m.config.SetProviderAPIKey(string(m.selectedModel.Provider.ID), apiKey)
if err != nil {
return util.ReportError(fmt.Errorf("failed to save API key: %w", err))
}
@@ -378,7 +380,7 @@ func (m *modelDialogCmp) saveAPIKeyAndContinue(apiKey string) tea.Cmd {
return tea.Sequence(
util.CmdHandler(dialogs.CloseDialogMsg{}),
util.CmdHandler(ModelSelectedMsg{
- Model: config.SelectedModel{
+ Model: agent.Model{
Model: selectedModel.Model.ID,
Provider: string(selectedModel.Provider.ID),
},
@@ -117,29 +117,28 @@ func New(app *app.App) ChatPage {
return &chatPage{
app: app,
keyMap: DefaultKeyMap(),
- header: header.New(app.LSPClients),
- sidebar: sidebar.New(app.History, app.LSPClients, false),
+ header: header.New(app),
+ sidebar: sidebar.New(app, false),
chat: chat.New(app),
editor: editor.New(app),
- splash: splash.New(),
+ splash: splash.New(app.Config()),
focusedPane: PanelTypeSplash,
}
}
func (p *chatPage) Init() tea.Cmd {
- cfg := config.Get()
- compact := cfg.Options.TUI.CompactMode
+ compact := p.app.Config().Options.TUI.CompactMode
p.compact = compact
p.forceCompact = compact
p.sidebar.SetCompactMode(p.compact)
// Set splash state based on config
- if !config.HasInitialDataConfig() {
+ if !config.HasInitialDataConfig(p.app.Config()) {
// First-time setup: show model selection
p.splash.SetOnboarding(true)
p.isOnboarding = true
p.splashFullScreen = true
- } else if b, _ := config.ProjectNeedsInitialization(); b {
+ } else if b, _ := config.ProjectNeedsInitialization(p.app.Config()); b {
// Project needs CRUSH.md initialization
p.splash.SetProjectInit(true)
p.isProjectInit = true
@@ -275,7 +274,7 @@ func (p *chatPage) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
}
case splash.OnboardingCompleteMsg:
p.splashFullScreen = false
- if b, _ := config.ProjectNeedsInitialization(); b {
+ if b, _ := config.ProjectNeedsInitialization(p.app.Config()); b {
p.splash.SetProjectInit(true)
p.splashFullScreen = true
return p, p.SetSize(p.width, p.height)
@@ -296,8 +295,7 @@ func (p *chatPage) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
}
return p, p.newSession()
case key.Matches(msg, p.keyMap.AddAttachment):
- agentCfg := config.Get().Agents["coder"]
- model := config.Get().GetModelByType(agentCfg.Model)
+ model := p.app.CoderAgent.Model()
if model.SupportsImages {
return p, util.CmdHandler(OpenFilePickerMsg{})
} else {
@@ -441,7 +439,7 @@ func (p *chatPage) View() string {
func (p *chatPage) updateCompactConfig(compact bool) tea.Cmd {
return func() tea.Msg {
- err := config.Get().SetCompactMode(compact)
+ err := p.app.Config().SetCompactMode(compact)
if err != nil {
return util.InfoMsg{
Type: util.InfoTypeError,
@@ -454,13 +452,11 @@ func (p *chatPage) updateCompactConfig(compact bool) tea.Cmd {
func (p *chatPage) toggleThinking() tea.Cmd {
return func() tea.Msg {
- cfg := config.Get()
- agentCfg := cfg.Agents["coder"]
- currentModel := cfg.Models[agentCfg.Model]
+ currentModel := p.app.CoderAgent.ModelConfig()
// Toggle the thinking mode
currentModel.Think = !currentModel.Think
- cfg.Models[agentCfg.Model] = currentModel
+ p.app.Config().Models[config.SelectedModelTypeLarge] = currentModel
// Update the agent with the new configuration
if err := p.app.UpdateAgentModel(); err != nil {
@@ -93,7 +93,7 @@ func (a appModel) Init() tea.Cmd {
func (a *appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
var cmds []tea.Cmd
var cmd tea.Cmd
- a.isConfigured = config.HasInitialDataConfig()
+ a.isConfigured = config.HasInitialDataConfig(a.app.Config())
switch msg := msg.(type) {
case tea.KeyboardEnhancementsMsg:
@@ -162,7 +162,7 @@ func (a *appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
case commands.SwitchModelMsg:
return a, util.CmdHandler(
dialogs.OpenDialogMsg{
- Model: models.NewModelDialogCmp(),
+ Model: models.NewModelDialogCmp(a.app.Config()),
},
)
// Compact
@@ -173,7 +173,7 @@ func (a *appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
// Model Switch
case models.ModelSelectedMsg:
- config.Get().UpdatePreferredModel(msg.ModelType, msg.Model)
+ a.app.Config().UpdatePreferredModel(msg.ModelType, msg.Model)
// Update the agent with the new model/provider configuration
if err := a.app.UpdateAgentModel(); err != nil {
@@ -234,7 +234,7 @@ func (a *appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
model := a.app.CoderAgent.Model()
contextWindow := model.ContextWindow
tokens := session.CompletionTokens + session.PromptTokens
- if (tokens >= int64(float64(contextWindow)*0.95)) && !config.Get().Options.DisableAutoSummarize { // Show compact confirmation dialog
+ if (tokens >= int64(float64(contextWindow)*0.95)) && !a.app.Config().Options.DisableAutoSummarize { // Show compact confirmation dialog
cmds = append(cmds, util.CmdHandler(dialogs.OpenDialogMsg{
Model: compact.NewCompactDialogCmp(a.app.CoderAgent, a.selectedSessionID, false),
}))
@@ -244,7 +244,7 @@ func (a *appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
return a, tea.Batch(cmds...)
case splash.OnboardingCompleteMsg:
- a.isConfigured = config.HasInitialDataConfig()
+ a.isConfigured = config.HasInitialDataConfig(a.app.Config())
updated, pageCmd := a.pages[a.currentPage].Update(msg)
a.pages[a.currentPage] = updated.(util.Model)
cmds = append(cmds, pageCmd)
@@ -348,7 +348,7 @@ func (a *appModel) handleKeyPressMsg(msg tea.KeyPressMsg) tea.Cmd {
return nil
}
return util.CmdHandler(dialogs.OpenDialogMsg{
- Model: commands.NewCommandDialog(a.selectedSessionID),
+ Model: commands.NewCommandDialog(a.app, a.selectedSessionID),
})
case key.Matches(msg, a.keyMap.Sessions):
// if the app is not configured show no sessions
@@ -9,7 +9,7 @@ import (
_ "github.com/joho/godotenv/autoload" // automatically load .env files
- "github.com/charmbracelet/crush/internal/cmd"
+ "github.com/charmbracelet/crush/cmd"
"github.com/charmbracelet/crush/internal/log"
)