Detailed changes
@@ -36,7 +36,7 @@ func (app *App) createAndStartLSPClient(ctx context.Context, name string, lspCfg
app.updateLSPState(name, lsp.StateStarting, nil, 0)
// Create LSP client.
- lspClient, err := lsp.New(ctx, app.config, name, lspCfg, app.config.Resolver())
+ lspClient, err := lsp.New(ctx, app.config, name, lspCfg, config.OsShellResolver)
if err != nil {
slog.Error("Failed to create LSP client for", name, err)
app.updateLSPState(name, lsp.StateError, err, 0)
@@ -220,6 +220,7 @@ func setupApp(cmd *cobra.Command, hostURL *url.URL) (*client.Client, *proto.Inst
DataDir: dataDir,
Debug: debug,
YOLO: yolo,
+ Env: os.Environ(),
})
if err != nil {
return nil, nil, fmt.Errorf("failed to create or connect to instance: %v", err)
@@ -16,6 +16,12 @@ import (
"github.com/tidwall/sjson"
)
+// OsShellResolver is a shell resolver that uses the current process environment
+// variables.
+//
+// Deprecated: pass resolver explicitly instead.
+var OsShellResolver = NewShellVariableResolver(os.Environ())
+
const (
defaultDataDirectory = ".crush"
)
@@ -267,7 +273,6 @@ type Config struct {
// TODO: most likely remove this concept when I come back to it
Agents map[string]Agent `json:"-"`
// TODO: find a better way to do this this should probably not be part of the config
- resolver VariableResolver
dataConfigDir string `json:"-"`
knownProviders []catwalk.Provider `json:"-"`
}
@@ -345,13 +350,6 @@ func (c *Config) SetCompactMode(enabled bool) error {
return c.SetConfigField("options.tui.compact_mode", enabled)
}
-func (c *Config) Resolve(key string) (string, error) {
- if c.resolver == nil {
- return "", fmt.Errorf("no variable resolver configured")
- }
- return c.resolver.ResolveValue(key)
-}
-
func (c *Config) UpdatePreferredModel(modelType SelectedModelType, model SelectedModel) error {
c.Models[modelType] = model
if err := c.SetConfigField(fmt.Sprintf("models.%s", modelType), model); err != nil {
@@ -494,10 +492,6 @@ func (c *Config) SetupAgents() {
c.Agents = agents
}
-func (c *Config) Resolver() VariableResolver {
- return c.resolver
-}
-
func (c *ProviderConfig) TestConnection(resolver VariableResolver) error {
testURL := ""
headers := make(map[string]string)
@@ -67,7 +67,6 @@ func Load(workingDir, dataDir string, debug bool, envs []string) (*Config, error
// Configure providers
valueResolver := NewShellVariableResolver(envs)
- cfg.resolver = valueResolver
if err := cfg.configureProviders(envs, valueResolver, cfg.knownProviders); err != nil {
return nil, fmt.Errorf("failed to configure providers: %w", err)
}
@@ -17,13 +17,13 @@ type Shell interface {
Exec(ctx context.Context, command string) (stdout, stderr string, err error)
}
-type shellVariableResolver struct {
+type ShellVariableResolver struct {
shell Shell
env []string
}
-func NewShellVariableResolver(env []string) VariableResolver {
- return &shellVariableResolver{
+func NewShellVariableResolver(env []string) *ShellVariableResolver {
+ return &ShellVariableResolver{
env: env,
shell: shell.NewShell(
&shell.Options{
@@ -38,7 +38,7 @@ func NewShellVariableResolver(env []string) VariableResolver {
// - $(command) for command substitution
// - $VAR or ${VAR} for environment variables
// TODO: can we replace this with [os.Expand](https://pkg.go.dev/os#Expand) somehow?
-func (r *shellVariableResolver) ResolveValue(value string) (string, error) {
+func (r *ShellVariableResolver) ResolveValue(value string) (string, error) {
// Special case: lone $ is an error (backward compatibility)
if value == "$" {
return "", fmt.Errorf("invalid value format: %s", value)
@@ -76,7 +76,7 @@ func TestShellVariableResolver_ResolveValue(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
testEnv := environ(tt.envVars)
- resolver := &shellVariableResolver{
+ resolver := &ShellVariableResolver{
shell: &mockShell{execFunc: tt.shellFunc},
env: testEnv,
}
@@ -241,7 +241,7 @@ func TestShellVariableResolver_EnhancedResolveValue(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
testEnv := environ(tt.envVars)
- resolver := &shellVariableResolver{
+ resolver := &ShellVariableResolver{
shell: &mockShell{execFunc: tt.shellFunc},
env: testEnv,
}
@@ -136,7 +136,7 @@ func getOrRenewClient(ctx context.Context, cfg *config.Config, name string) (*cl
}
updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, state.ToolCount)
- c, err = createAndInitializeClient(ctx, name, m, cfg.Resolver())
+ c, err = createAndInitializeClient(ctx, name, m, config.OsShellResolver)
if err != nil {
return nil, err
}
@@ -288,7 +288,7 @@ func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *con
ctx, cancel := context.WithTimeout(ctx, mcpTimeout(m))
defer cancel()
- c, err := createAndInitializeClient(ctx, name, m, cfg.Resolver())
+ c, err := createAndInitializeClient(ctx, name, m, config.OsShellResolver)
if err != nil {
return
}
@@ -81,7 +81,7 @@ func createAnthropicClient(opts providerClientOptions, tp AnthropicClientType) a
}
if opts.baseURL != "" {
- resolvedBaseURL, err := opts.cfg.Resolve(opts.baseURL)
+ resolvedBaseURL, err := opts.resolver.ResolveValue(opts.baseURL)
if err == nil && resolvedBaseURL != "" {
anthropicClientOptions = append(anthropicClientOptions, option.WithBaseURL(resolvedBaseURL))
}
@@ -496,7 +496,7 @@ func (a *anthropicClient) shouldRetry(attempts int, err error) (bool, int64, err
if apiErr.StatusCode == http.StatusUnauthorized {
prev := a.providerOptions.apiKey
// in case the key comes from a script, we try to re-evaluate it.
- a.providerOptions.apiKey, err = a.providerOptions.cfg.Resolve(a.providerOptions.config.APIKey)
+ a.providerOptions.apiKey, err = a.providerOptions.resolver.ResolveValue(a.providerOptions.config.APIKey)
if err != nil {
return false, 0, fmt.Errorf("failed to resolve API key: %w", err)
}
@@ -438,7 +438,7 @@ func (g *geminiClient) shouldRetry(attempts int, err error) (bool, int64, error)
if contains(errMsg, "unauthorized", "invalid api key", "api key expired") {
prev := g.providerOptions.apiKey
// in case the key comes from a script, we try to re-evaluate it.
- g.providerOptions.apiKey, err = g.providerOptions.cfg.Resolve(g.providerOptions.config.APIKey)
+ g.providerOptions.apiKey, err = g.providerOptions.resolver.ResolveValue(g.providerOptions.config.APIKey)
if err != nil {
return false, 0, fmt.Errorf("failed to resolve API key: %w", err)
}
@@ -43,7 +43,7 @@ func createOpenAIClient(opts providerClientOptions) openai.Client {
openaiClientOptions = append(openaiClientOptions, option.WithAPIKey(opts.apiKey))
}
if opts.baseURL != "" {
- resolvedBaseURL, err := opts.cfg.Resolve(opts.baseURL)
+ resolvedBaseURL, err := opts.resolver.ResolveValue(opts.baseURL)
if err == nil && resolvedBaseURL != "" {
openaiClientOptions = append(openaiClientOptions, option.WithBaseURL(resolvedBaseURL))
}
@@ -517,7 +517,7 @@ func (o *openaiClient) shouldRetry(attempts int, err error) (bool, int64, error)
if apiErr.StatusCode == http.StatusUnauthorized {
prev := o.providerOptions.apiKey
// in case the key comes from a script, we try to re-evaluate it.
- o.providerOptions.apiKey, err = o.providerOptions.cfg.Resolve(o.providerOptions.config.APIKey)
+ o.providerOptions.apiKey, err = o.providerOptions.resolver.ResolveValue(o.providerOptions.config.APIKey)
if err != nil {
return false, 0, fmt.Errorf("failed to resolve API key: %w", err)
}
@@ -64,6 +64,8 @@ type Provider interface {
type providerClientOptions struct {
cfg *config.Config
+ resolver config.VariableResolver
+
baseURL string
config config.ProviderConfig
apiKey string
@@ -141,30 +143,17 @@ func WithMaxTokens(maxTokens int64) ProviderClientOption {
}
}
-func NewProvider(cfg *config.Config, pcfg config.ProviderConfig, opts ...ProviderClientOption) (Provider, error) {
- restore := config.PushPopCrushEnv()
- defer restore()
- resolvedAPIKey, err := cfg.Resolve(pcfg.APIKey)
- if err != nil {
- return nil, fmt.Errorf("failed to resolve API key for provider %s: %w", pcfg.ID, err)
- }
-
- // Resolve extra headers
- resolvedExtraHeaders := make(map[string]string)
- for key, value := range pcfg.ExtraHeaders {
- resolvedValue, err := cfg.Resolve(value)
- if err != nil {
- return nil, fmt.Errorf("failed to resolve extra header %s for provider %s: %w", key, pcfg.ID, err)
- }
- resolvedExtraHeaders[key] = resolvedValue
+func WithResolver(resolver config.VariableResolver) ProviderClientOption {
+ return func(options *providerClientOptions) {
+ options.resolver = resolver
}
+}
+func NewProvider(cfg *config.Config, pcfg config.ProviderConfig, opts ...ProviderClientOption) (Provider, error) {
clientOptions := providerClientOptions{
cfg: cfg,
baseURL: pcfg.BaseURL,
config: pcfg,
- apiKey: resolvedAPIKey,
- extraHeaders: resolvedExtraHeaders,
extraBody: pcfg.ExtraBody,
extraParams: pcfg.ExtraParams,
systemPromptPrefix: pcfg.SystemPromptPrefix,
@@ -175,6 +164,30 @@ func NewProvider(cfg *config.Config, pcfg config.ProviderConfig, opts ...Provide
for _, o := range opts {
o(&clientOptions)
}
+ if clientOptions.resolver == nil {
+ clientOptions.resolver = config.OsShellResolver
+ }
+
+ restore := config.PushPopCrushEnv()
+ defer restore()
+ resolvedAPIKey, err := clientOptions.resolver.ResolveValue(pcfg.APIKey)
+ if err != nil {
+ return nil, fmt.Errorf("failed to resolve API key for provider %s: %w", pcfg.ID, err)
+ }
+
+ // Resolve extra headers
+ resolvedExtraHeaders := make(map[string]string)
+ for key, value := range pcfg.ExtraHeaders {
+ resolvedValue, err := clientOptions.resolver.ResolveValue(value)
+ if err != nil {
+ return nil, fmt.Errorf("failed to resolve extra header %s for provider %s: %w", key, pcfg.ID, err)
+ }
+ resolvedExtraHeaders[key] = resolvedValue
+ }
+
+ clientOptions.apiKey = resolvedAPIKey
+ clientOptions.extraHeaders = resolvedExtraHeaders
+
switch pcfg.Type {
case catwalk.TypeAnthropic:
return &baseProvider[AnthropicClient]{
@@ -17,6 +17,13 @@ type Instance struct {
Debug bool `json:"debug,omitempty"`
DataDir string `json:"data_dir,omitempty"`
Config *config.Config `json:"config,omitempty"`
+ Env []string `json:"env,omitempty"`
+}
+
+// ShellResolver returns a new [config.ShellResolver] based on the instance's
+// environment variables.
+func (i Instance) ShellResolver() *config.ShellVariableResolver {
+ return config.NewShellVariableResolver(i.Env)
}
// Error represents an error response.
@@ -460,6 +460,19 @@ func (c *controllerV1) handleGetInstancePermissionsSkip(w http.ResponseWriter, r
jsonEncode(w, proto.PermissionSkipRequest{Skip: skip})
}
+func (c *controllerV1) handleGetInstanceProviders(w http.ResponseWriter, r *http.Request) {
+ id := r.PathValue("id")
+ ins, ok := c.instances.Get(id)
+ if !ok {
+ c.logError(r, "instance not found", "id", id)
+ jsonError(w, http.StatusNotFound, "instance not found")
+ return
+ }
+
+ providers, _ := config.Providers(ins.cfg)
+ jsonEncode(w, providers)
+}
+
func (c *controllerV1) handleGetInstanceEvents(w http.ResponseWriter, r *http.Request) {
flusher := http.NewResponseController(w)
id := r.PathValue("id")
@@ -587,6 +600,7 @@ func (c *controllerV1) handlePostInstances(w http.ResponseWriter, r *http.Reques
id: id,
path: args.Path,
cfg: cfg,
+ env: args.Env,
}
c.instances.Set(id, ins)
@@ -597,6 +611,7 @@ func (c *controllerV1) handlePostInstances(w http.ResponseWriter, r *http.Reques
Debug: cfg.Options.Debug,
YOLO: cfg.Permissions.SkipRequests,
Config: cfg,
+ Env: args.Env,
})
}
@@ -40,6 +40,7 @@ type Instance struct {
cfg *config.Config
id string
path string
+ env []string
}
// ParseHostURL parses a host URL into a [url.URL].
@@ -134,6 +135,7 @@ func NewServer(cfg *config.Config, network, address string) *Server {
mux.HandleFunc("GET /v1/instances/{id}", c.handleGetInstance)
mux.HandleFunc("GET /v1/instances/{id}/config", c.handleGetInstanceConfig)
mux.HandleFunc("GET /v1/instances/{id}/events", c.handleGetInstanceEvents)
+ mux.HandleFunc("GET /v1/instances/{id}/providers", c.handleGetInstanceProviders)
mux.HandleFunc("GET /v1/instances/{id}/sessions", c.handleGetInstanceSessions)
mux.HandleFunc("POST /v1/instances/{id}/sessions", c.handlePostInstanceSessions)
mux.HandleFunc("GET /v1/instances/{id}/sessions/{sid}", c.handleGetInstanceSession)
@@ -218,7 +218,7 @@ func (s *splashCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
}),
func() tea.Msg {
start := time.Now()
- err := providerConfig.TestConnection(s.ins.Config.Resolver())
+ err := providerConfig.TestConnection(s.ins.ShellResolver())
// intentionally wait for at least 750ms to make sure the user sees the spinner
elapsed := time.Since(start)
if elapsed < 750*time.Millisecond {
@@ -10,6 +10,7 @@ import (
tea "github.com/charmbracelet/bubbletea/v2"
"github.com/charmbracelet/catwalk/pkg/catwalk"
"github.com/charmbracelet/crush/internal/config"
+ "github.com/charmbracelet/crush/internal/proto"
"github.com/charmbracelet/crush/internal/tui/components/core"
"github.com/charmbracelet/crush/internal/tui/components/dialogs"
"github.com/charmbracelet/crush/internal/tui/exp/list"
@@ -68,10 +69,10 @@ type modelDialogCmp struct {
isAPIKeyValid bool
apiKeyValue string
- cfg *config.Config
+ ins *proto.Instance
}
-func NewModelDialogCmp(cfg *config.Config) ModelDialog {
+func NewModelDialogCmp(ins *proto.Instance) ModelDialog {
keyMap := DefaultKeyMap()
listKeyMap := list.DefaultKeyMap()
@@ -81,7 +82,7 @@ func NewModelDialogCmp(cfg *config.Config) ModelDialog {
listKeyMap.UpOneItem = keyMap.Previous
t := styles.CurrentTheme()
- modelList := NewModelListComponent(cfg, listKeyMap, largeModelInputPlaceholder, true)
+ modelList := NewModelListComponent(ins.Config, listKeyMap, largeModelInputPlaceholder, true)
apiKeyInput := NewAPIKeyInput()
apiKeyInput.SetShowTitle(false)
help := help.New()
@@ -93,7 +94,7 @@ func NewModelDialogCmp(cfg *config.Config) ModelDialog {
width: defaultWidth,
keyMap: DefaultKeyMap(),
help: help,
- cfg: cfg,
+ ins: ins,
}
}
@@ -139,7 +140,7 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
}),
func() tea.Msg {
start := time.Now()
- err := providerConfig.TestConnection(m.cfg.Resolver())
+ err := providerConfig.TestConnection(m.ins.ShellResolver())
// intentionally wait for at least 750ms to make sure the user sees the spinner
elapsed := time.Since(start)
if elapsed < 750*time.Millisecond {
@@ -347,14 +348,14 @@ func (m *modelDialogCmp) modelTypeRadio() string {
}
func (m *modelDialogCmp) isProviderConfigured(providerID string) bool {
- if _, ok := m.cfg.Providers.Get(providerID); ok {
+ if _, ok := m.ins.Config.Providers.Get(providerID); ok {
return true
}
return false
}
func (m *modelDialogCmp) getProvider(providerID catwalk.InferenceProvider) (*catwalk.Provider, error) {
- providers, err := config.Providers(m.cfg)
+ providers, err := config.Providers(m.ins.Config)
if err != nil {
return nil, err
}
@@ -371,7 +372,7 @@ func (m *modelDialogCmp) saveAPIKeyAndContinue(apiKey string) tea.Cmd {
return util.ReportError(fmt.Errorf("no model selected"))
}
- err := m.cfg.SetProviderAPIKey(string(m.selectedModel.Provider.ID), apiKey)
+ err := m.ins.Config.SetProviderAPIKey(string(m.selectedModel.Provider.ID), apiKey)
if err != nil {
return util.ReportError(fmt.Errorf("failed to save API key: %w", err))
}
@@ -178,7 +178,7 @@ func (a *appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
case commands.SwitchModelMsg:
return a, util.CmdHandler(
dialogs.OpenDialogMsg{
- Model: models.NewModelDialogCmp(a.ins.Config),
+ Model: models.NewModelDialogCmp(a.ins),
},
)
// Compact