diff --git a/internal/agent/tools/mcp/init_test.go b/internal/agent/tools/mcp/init_test.go index 5e906054f649aa90083fcbb120cff556d51dd06c..e8473124ba3f006031c716df54ff998ab997226e 100644 --- a/internal/agent/tools/mcp/init_test.go +++ b/internal/agent/tools/mcp/init_test.go @@ -376,7 +376,7 @@ func TestCreateTransport_HeadersResolution(t *testing.T) { t.Run("sse unset var header drops silently", func(t *testing.T) { t.Parallel() - // Pinning test for design decision #18 + lenient nounset: + // Pinning test for empty-header drop + lenient nounset: // a header whose value resolves to "" (here because the // bare $VAR is unset) is omitted from the round tripper // rather than sent as "X-Header:". Guards against a @@ -488,10 +488,9 @@ func TestCreateSession_ResolutionFailureUpdatesState(t *testing.T) { }, { // Bare $MISSING in a header resolves to "" silently - // and is then dropped (design decision #18). The - // "header Authorization" wrap only surfaces on a - // $(cmd) failure; that is what this subtest now - // pins for the SSE path. + // and is then dropped. The "header Authorization" + // wrap only surfaces on a $(cmd) failure; that is + // what this subtest now pins for the SSE path. name: "sse header failure", mcpName: "test-sse-header-fail", cfg: config.MCPConfig{ diff --git a/internal/client/config.go b/internal/client/config.go index 64c45ecd91cfabbe6727695071e7a62e5fe435ba..e882464969eab7bfdbc428c0281fb12e38ab7347 100644 --- a/internal/client/config.go +++ b/internal/client/config.go @@ -8,6 +8,7 @@ import ( "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/oauth" + "github.com/charmbracelet/crush/internal/proto" ) // SetConfigField sets a config key/value pair on the server. @@ -76,13 +77,43 @@ func (c *Client) SetCompactMode(ctx context.Context, id string, scope config.Sco return nil } -// SetProviderAPIKey sets a provider API key on the server. +// SetProviderAPIKey sets a provider API key on the server. The wire +// format tags the credential with an explicit Kind so the server can +// decode it back into the right Go type — JSON's `any` loses that +// information across the socket. func (c *Client) SetProviderAPIKey(ctx context.Context, id string, scope config.Scope, providerID string, apiKey any) error { - rsp, err := c.post(ctx, fmt.Sprintf("/workspaces/%s/config/provider-key", id), nil, jsonBody(struct { - Scope config.Scope `json:"scope"` - ProviderID string `json:"provider_id"` - APIKey any `json:"api_key"` - }{Scope: scope, ProviderID: providerID, APIKey: apiKey}), http.Header{"Content-Type": []string{"application/json"}}) + var ( + kind proto.APIKeyKind + raw json.RawMessage + ) + switch v := apiKey.(type) { + case string: + kind = proto.APIKeyKindString + b, err := json.Marshal(v) + if err != nil { + return fmt.Errorf("failed to marshal api key string: %w", err) + } + raw = b + case *oauth.Token: + if v == nil { + return fmt.Errorf("oauth token is nil") + } + kind = proto.APIKeyKindOAuth + b, err := json.Marshal(v) + if err != nil { + return fmt.Errorf("failed to marshal oauth token: %w", err) + } + raw = b + default: + return fmt.Errorf("unsupported api key type %T", apiKey) + } + + rsp, err := c.post(ctx, fmt.Sprintf("/workspaces/%s/config/provider-key", id), nil, jsonBody(proto.ConfigProviderKeyRequest{ + Scope: scope, + ProviderID: providerID, + Kind: kind, + APIKey: raw, + }), http.Header{"Content-Type": []string{"application/json"}}) if err != nil { return fmt.Errorf("failed to set provider API key: %w", err) } diff --git a/internal/client/config_test.go b/internal/client/config_test.go new file mode 100644 index 0000000000000000000000000000000000000000..bd07d5f9649002a18efb1c85b51e1197cda1e066 --- /dev/null +++ b/internal/client/config_test.go @@ -0,0 +1,101 @@ +package client + +import ( + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "net/url" + "testing" + + "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/oauth" + "github.com/charmbracelet/crush/internal/proto" + "github.com/stretchr/testify/require" +) + +// captureClient returns a Client that talks to the given test server, +// plus a channel receiving the parsed request body for each call. +func captureClient(t *testing.T, srv *httptest.Server) *Client { + t.Helper() + u, err := url.Parse(srv.URL) + require.NoError(t, err) + c, err := NewClient(t.TempDir(), "tcp", u.Host) + require.NoError(t, err) + return c +} + +func TestSetProviderAPIKeyStringSendsKind(t *testing.T) { + t.Parallel() + + var got proto.ConfigProviderKeyRequest + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, err := io.ReadAll(r.Body) + require.NoError(t, err) + require.NoError(t, json.Unmarshal(body, &got)) + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + c := captureClient(t, srv) + require.NoError(t, c.SetProviderAPIKey(context.Background(), "ws1", config.ScopeGlobal, "openai", "sk-xyz")) + + require.Equal(t, proto.APIKeyKindString, got.Kind) + require.Equal(t, "openai", got.ProviderID) + require.Equal(t, config.ScopeGlobal, got.Scope) + decoded, err := got.DecodeAPIKey() + require.NoError(t, err) + require.Equal(t, "sk-xyz", decoded) +} + +func TestSetProviderAPIKeyOAuthSendsKind(t *testing.T) { + t.Parallel() + + var got proto.ConfigProviderKeyRequest + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, err := io.ReadAll(r.Body) + require.NoError(t, err) + require.NoError(t, json.Unmarshal(body, &got)) + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + tok := &oauth.Token{AccessToken: "a", RefreshToken: "r", ExpiresIn: 60, ExpiresAt: 1234567890} + c := captureClient(t, srv) + require.NoError(t, c.SetProviderAPIKey(context.Background(), "ws1", config.ScopeGlobal, "hyper", tok)) + + require.Equal(t, proto.APIKeyKindOAuth, got.Kind) + decoded, err := got.DecodeAPIKey() + require.NoError(t, err) + require.Equal(t, tok, decoded.(*oauth.Token)) +} + +func TestSetProviderAPIKeyUnsupportedTypeFailsLocally(t *testing.T) { + t.Parallel() + + called := false + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + called = true + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + c := captureClient(t, srv) + err := c.SetProviderAPIKey(context.Background(), "ws1", config.ScopeGlobal, "x", 42) + require.Error(t, err) + require.Contains(t, err.Error(), "unsupported api key type") + require.False(t, called, "server should not have been reached") +} + +func TestSetProviderAPIKeyNilOAuthFailsLocally(t *testing.T) { + t.Parallel() + + c := captureClient(t, httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + }))) + + var tok *oauth.Token + err := c.SetProviderAPIKey(context.Background(), "ws1", config.ScopeGlobal, "x", tok) + require.Error(t, err) +} diff --git a/internal/config/config.go b/internal/config/config.go index 6620120fa07404a3858bf1502e2c8109bf5510a3..002595b6f774ff4784863f180c4b8a7f3d478095 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -113,8 +113,7 @@ type ProviderConfig struct { // $(cmd) work the same way they do in MCP headers. A header whose // value resolves to the empty string (unset bare $VAR under // lenient nounset, $(echo), or literal "") is omitted from the - // outgoing request rather than sent as "Header:". See PLAN.md - // Phase 2 design decision #18. + // outgoing request rather than sent as "Header:". ExtraHeaders map[string]string `json:"extra_headers,omitempty" jsonschema:"description=Additional HTTP headers to send with requests"` // ExtraBody is merged verbatim into OpenAI-compatible request // bodies. String values are NOT shell-expanded: this is a plain @@ -123,7 +122,7 @@ type ProviderConfig struct { // recursive walker guessing at intent. If you need an env-var- // driven value at request time, put it in extra_headers, or in // the provider's top-level api_key / base_url, all of which do - // expand. See PLAN.md Phase 2 design decision #16. + // expand. ExtraBody map[string]any `json:"extra_body,omitempty" jsonschema:"description=Additional fields to include in request bodies\\, only works with openai-compatible providers"` ProviderOptions map[string]any `json:"provider_options,omitempty" jsonschema:"description=Additional provider-specific options for this provider"` @@ -196,7 +195,7 @@ type MCPConfig struct { // work. A header whose value resolves to the empty string (unset // bare $VAR under lenient nounset, $(echo), or literal "") is // omitted from the outgoing request rather than sent as - // "Header:". See PLAN.md Phase 2 design decision #18. + // "Header:". Headers map[string]string `json:"headers,omitempty" jsonschema:"description=HTTP headers for HTTP/SSE MCP servers"` } @@ -398,8 +397,7 @@ func (m MCPConfig) ResolvedURL(r VariableResolver) (string, error) { // under lenient nounset, $(echo), or literal "") is omitted from the // returned map — sending "X-Auth:" with an empty value is rejected by // some providers and the user's intent in "optional, env-gated -// header" is clearly "absent when the var isn't set." See PLAN.md -// Phase 2 design decision #18. +// header" is clearly "absent when the var isn't set." // // See ResolvedEnv for guidance on picking a resolver. func (m MCPConfig) ResolvedHeaders(r VariableResolver) (map[string]string, error) { @@ -435,13 +433,11 @@ func (m MCPConfig) ResolvedHeaders(r VariableResolver) (map[string]string, error // errors.Is/As continues to work. // // Empty resolved values are kept (a deliberate "empty positional arg" -// like --flag "" is sometimes valid), matching MCPConfig.ResolvedArgs; -// see PLAN.md Phase 2 design decision #18. +// like --flag "" is sometimes valid), matching MCPConfig.ResolvedArgs. // // The resolver choice matters: in server mode pass the shell resolver // so $VAR / $(cmd) expand; in client mode pass IdentityResolver so the -// template is forwarded verbatim. See PLAN.md Phase 2 design decision -// #13. +// template is forwarded verbatim. func (l LSPConfig) ResolvedArgs(r VariableResolver) ([]string, error) { if len(l.Args) == 0 { return nil, nil @@ -465,15 +461,13 @@ func (l LSPConfig) ResolvedArgs(r VariableResolver) ([]string, error) { // continues to work. // // Empty resolved values are kept ("FOO=" is a legitimate request; -// opt out via ${VAR:+...}), matching MCPConfig.ResolvedEnv; see -// PLAN.md Phase 2 design decision #18. +// opt out via ${VAR:+...}), matching MCPConfig.ResolvedEnv. // // Shape note: this returns map[string]string rather than the []string // shape MCPConfig.ResolvedEnv uses because the consumer // (powernap.ClientConfig.Environment in internal/lsp/client.go) takes // a map directly — returning a []string here would only force a -// round-trip back to a map at the call site. See PLAN.md Phase 2 -// design decision #13. +// round-trip back to a map at the call site. // // See ResolvedArgs for guidance on picking a resolver. func (l LSPConfig) ResolvedEnv(r VariableResolver) (map[string]string, error) { diff --git a/internal/config/load.go b/internal/config/load.go index 2b7358662393b42a4a2ee2db730403b019bdbd01..f816d2692e14c9baf2e281d852003088ffbe8a5d 100644 --- a/internal/config/load.go +++ b/internal/config/load.go @@ -230,8 +230,7 @@ func (c *Config) configureProviders(store *ConfigStore, env env.Env, resolver Va // a failing $(...) aborts the provider load with a clear // message, and a header that resolves to the empty string // (unset bare $VAR under lenient nounset, $(echo), or literal - // "") is dropped from the outgoing request. See PLAN.md - // Phase 2 design decisions #14 and #18. + // "") is dropped from the outgoing request. for k, v := range headers { resolved, err := resolver.ResolveValue(v) if err != nil { @@ -390,8 +389,7 @@ func (c *Config) configureProviders(store *ConfigStore, env env.Env, resolver Va } // Custom-provider headers share the MCP error contract; see - // the known-provider loop above and PLAN.md Phase 2 design - // decisions #14 and #18. + // the known-provider loop above. for k, v := range providerConfig.ExtraHeaders { resolved, err := resolver.ResolveValue(v) if err != nil { diff --git a/internal/config/load_test.go b/internal/config/load_test.go index cb69976f7f8130e6023fde36629790a86d5138f1..8f7dd60b3189c25caea05448e046142e908cefb0 100644 --- a/internal/config/load_test.go +++ b/internal/config/load_test.go @@ -1828,12 +1828,10 @@ func TestConfig_configureProviders_HyperAPIKeyFromConfigOverrides(t *testing.T) require.Equal(t, "env-api-key", pc.APIKey) } -// TestConfig_configureProviders_ProviderHeaderResolveError pins -// Phase 2 design decision #14: a failing $(cmd) in a provider header -// must fail the provider load with a clear message that names the -// offending header. The Phase 1 log-and-continue divergence at -// load.go:225 is gone; provider headers now share the MCP error -// contract. +// TestConfig_configureProviders_ProviderHeaderResolveError verifies +// that a failing $(cmd) in a provider header fails the provider load +// with a clear message that names the offending header. Provider +// headers share the MCP error contract. func TestConfig_configureProviders_ProviderHeaderResolveError(t *testing.T) { knownProviders := []catwalk.Provider{ { @@ -1868,12 +1866,11 @@ func TestConfig_configureProviders_ProviderHeaderResolveError(t *testing.T) { require.Contains(t, err.Error(), "X-Broken", "error must name the offending header") } -// TestConfig_configureProviders_CatwalkDefaultWithUnsetVarLoads pins -// Phase 2 design decisions #11 and #18 from the provider angle: a -// Catwalk-style default header like -// "OpenAI-Organization": "$OPENAI_ORG_ID" must load cleanly under -// lenient nounset (unset → "" → header dropped), not fail the load -// and not leave the literal template on the wire. +// TestConfig_configureProviders_CatwalkDefaultWithUnsetVarLoads +// verifies that a Catwalk-style default header like +// "OpenAI-Organization": "$OPENAI_ORG_ID" loads cleanly under lenient +// nounset (unset → "" → header dropped), and does not fail the load +// or leave the literal template on the wire. func TestConfig_configureProviders_CatwalkDefaultWithUnsetVarLoads(t *testing.T) { knownProviders := []catwalk.Provider{ { @@ -1979,11 +1976,11 @@ func TestConfig_configureProviders_EchoEmptyHeaderDropped(t *testing.T) { require.Equal(t, "present", pc.ExtraHeaders["X-Kept"]) } -// TestConfig_configureProviders_UnsetAPIKeySkipsProvider pins Phase 2 -// Step 12 / design decision #15: under the lenient-nounset shell -// resolver, $UNSET_API_KEY expands to ("", nil) rather than ("", err), -// and the existing `v == "" || err != nil` skip path at load.go:331 -// still drops the provider. The slog.Warn line is emitted on the same +// TestConfig_configureProviders_UnsetAPIKeySkipsProvider verifies that +// under the lenient-nounset shell resolver, $UNSET_API_KEY expands to +// ("", nil) rather than ("", err), and the existing +// `v == "" || err != nil` skip path at load.go:331 still drops the +// provider. The slog.Warn line is emitted on the same // path but is not asserted here — internal/config/load_test.go's // TestMain replaces the default slog handler with an io.Discard // writer, so capturing that log line would require mid-test handler diff --git a/internal/config/mcp_resolved_url_test.go b/internal/config/mcp_resolved_url_test.go index a1f262f2224b53ed8e7dd1f1044152b853ba90db..44d75b0fc6b29ef1a8885308672c0c7e949b82e4 100644 --- a/internal/config/mcp_resolved_url_test.go +++ b/internal/config/mcp_resolved_url_test.go @@ -46,7 +46,7 @@ func TestMCPConfig_ResolvedURL(t *testing.T) { t.Run("unset var expands to empty under lenient default", func(t *testing.T) { t.Parallel() - // Phase 2 defaults to nounset-off: bare $VAR on an unset + // The default is nounset-off: bare $VAR on an unset // variable expands to "" rather than erroring. Here the // host collapses to empty, so the caller sees a malformed // URL rather than a resolver error; that's the expected diff --git a/internal/config/resolve_real_test.go b/internal/config/resolve_real_test.go index 64cecc5769c6574975ae38ecdafd982eadbbe252..ab107c0e22f88e0155d0a3da9c2e5736d0403e89 100644 --- a/internal/config/resolve_real_test.go +++ b/internal/config/resolve_real_test.go @@ -159,21 +159,20 @@ func TestResolvedEnv_RealShell_Deterministic(t *testing.T) { require.True(t, slices.IsSorted(got), "env slice must be sorted; got %v", got) } -// TestResolvedEnv_RealShell_UnsetExpandsEmpty pins Phase 2's lenient +// TestResolvedEnv_RealShell_UnsetExpandsEmpty pins the lenient // default: an unset bare $VAR expands to the empty string, matching -// bash. The silent-empty-credential class of bug that motivated Phase -// 1's nounset-on default is already prevented by the pure-function -// error-returning contract of ResolvedEnv, so we no longer rely on -// nounset to catch typo'd variable names. Users who want strict -// behaviour for a required credential opt in per-reference with -// ${VAR:?msg}; see TestResolvedEnv_RealShell_ColonQuestionIsStrict. +// bash. The silent-empty-credential class of bug is prevented by the +// pure-function error-returning contract of ResolvedEnv, so we don't +// rely on nounset to catch typo'd variable names. Users who want +// strict behaviour for a required credential opt in per-reference +// with ${VAR:?msg}; see TestResolvedEnv_RealShell_ColonQuestionIsStrict. func TestResolvedEnv_RealShell_UnsetExpandsEmpty(t *testing.T) { t.Parallel() m := MCPConfig{Env: map[string]string{ - // Intentional typo: user meant $FORGEJO_TOKEN. Under Phase 2 - // defaults this expands to "" rather than erroring, matching - // bash's behaviour on bare $VAR. + // Intentional typo: user meant $FORGEJO_TOKEN. Under the + // lenient default this expands to "" rather than erroring, + // matching bash's behaviour on bare $VAR. "FORGEJO_ACCESS_TOKEN": "$FORGJO_TOKEN", }} got, err := m.ResolvedEnv(realShellResolver(nil)) @@ -259,13 +258,12 @@ func TestResolvedHeaders_RealShell_FailurePreservesOriginal(t *testing.T) { require.Equal(t, orig, m.Headers, "receiver Headers must be preserved") } -// TestResolvedHeaders_RealShell_DropEmpty pins Phase 2 design -// decision #18 on the MCP side: a header whose value resolves to the -// empty string is omitted from the returned map. Covers the three -// ways a value can legitimately land on empty — unset bare $VAR -// under lenient nounset, a literal "", and a non-failing command -// whose stdout is empty — and also pins that a failing $(cmd) still -// errors rather than silently dropping. +// TestResolvedHeaders_RealShell_DropEmpty verifies that a header +// whose value resolves to the empty string is omitted from the +// returned map. Covers the three ways a value can legitimately land +// on empty — unset bare $VAR under lenient nounset, a literal "", +// and a non-failing command whose stdout is empty — and also pins +// that a failing $(cmd) still errors rather than silently dropping. func TestResolvedHeaders_RealShell_DropEmpty(t *testing.T) { t.Parallel() diff --git a/internal/proto/requests.go b/internal/proto/requests.go index fe5327079b4f6514c23b34ce5c8e00666c75ba43..e66807def95ee816e731b62af78e058872e574d4 100644 --- a/internal/proto/requests.go +++ b/internal/proto/requests.go @@ -1,6 +1,12 @@ package proto -import "github.com/charmbracelet/crush/internal/config" +import ( + "encoding/json" + "fmt" + + "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/oauth" +) // ConfigSetRequest represents a request to set a config field. type ConfigSetRequest struct { @@ -28,11 +34,50 @@ type ConfigCompactRequest struct { Enabled bool `json:"enabled"` } -// ConfigProviderKeyRequest represents a request to set a provider API key. +// APIKeyKind discriminates the kind of credential carried in a +// ConfigProviderKeyRequest. JSON's `any` loses Go type information, so +// the wire format names the kind explicitly and the server decodes +// APIKey accordingly. +type APIKeyKind string + +const ( + // APIKeyKindString is a plain string API key. + APIKeyKindString APIKeyKind = "string" + // APIKeyKindOAuth is an oauth.Token credential. + APIKeyKindOAuth APIKeyKind = "oauth" +) + +// ConfigProviderKeyRequest represents a request to set a provider API +// key. APIKey is the raw JSON for the credential; Kind selects the +// concrete Go type APIKey should be decoded into via DecodeAPIKey. type ConfigProviderKeyRequest struct { - Scope config.Scope `json:"scope"` - ProviderID string `json:"provider_id"` - APIKey any `json:"api_key"` + Scope config.Scope `json:"scope"` + ProviderID string `json:"provider_id"` + Kind APIKeyKind `json:"kind"` + APIKey json.RawMessage `json:"api_key"` +} + +// DecodeAPIKey decodes APIKey into the Go type indicated by Kind. It +// returns a string for APIKeyKindString and a *oauth.Token for +// APIKeyKindOAuth. An unknown kind or malformed payload is reported +// as an error. +func (r ConfigProviderKeyRequest) DecodeAPIKey() (any, error) { + switch r.Kind { + case APIKeyKindString: + var s string + if err := json.Unmarshal(r.APIKey, &s); err != nil { + return nil, fmt.Errorf("decode api key string: %w", err) + } + return s, nil + case APIKeyKindOAuth: + var tok oauth.Token + if err := json.Unmarshal(r.APIKey, &tok); err != nil { + return nil, fmt.Errorf("decode api key oauth token: %w", err) + } + return &tok, nil + default: + return nil, fmt.Errorf("unsupported api key kind %q", r.Kind) + } } // ConfigRefreshOAuthRequest represents a request to refresh an OAuth token. diff --git a/internal/proto/requests_test.go b/internal/proto/requests_test.go new file mode 100644 index 0000000000000000000000000000000000000000..8ca718004b90b374df26d7d91ef1900534de8d49 --- /dev/null +++ b/internal/proto/requests_test.go @@ -0,0 +1,106 @@ +package proto_test + +import ( + "encoding/json" + "testing" + + "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/oauth" + "github.com/charmbracelet/crush/internal/proto" + "github.com/stretchr/testify/require" +) + +func TestConfigProviderKeyRequestStringRoundTrip(t *testing.T) { + t.Parallel() + + apiKey, err := json.Marshal("sk-test-123") + require.NoError(t, err) + + src := proto.ConfigProviderKeyRequest{ + Scope: config.ScopeGlobal, + ProviderID: "openai", + Kind: proto.APIKeyKindString, + APIKey: apiKey, + } + b, err := json.Marshal(src) + require.NoError(t, err) + + var got proto.ConfigProviderKeyRequest + require.NoError(t, json.Unmarshal(b, &got)) + require.Equal(t, proto.APIKeyKindString, got.Kind) + + decoded, err := got.DecodeAPIKey() + require.NoError(t, err) + s, ok := decoded.(string) + require.True(t, ok, "expected string, got %T", decoded) + require.Equal(t, "sk-test-123", s) +} + +func TestConfigProviderKeyRequestOAuthRoundTrip(t *testing.T) { + t.Parallel() + + tok := &oauth.Token{ + AccessToken: "access", + RefreshToken: "refresh", + ExpiresIn: 60, + ExpiresAt: 1234567890, + } + apiKey, err := json.Marshal(tok) + require.NoError(t, err) + + src := proto.ConfigProviderKeyRequest{ + Scope: config.ScopeGlobal, + ProviderID: "hyper", + Kind: proto.APIKeyKindOAuth, + APIKey: apiKey, + } + b, err := json.Marshal(src) + require.NoError(t, err) + + var got proto.ConfigProviderKeyRequest + require.NoError(t, json.Unmarshal(b, &got)) + require.Equal(t, proto.APIKeyKindOAuth, got.Kind) + + decoded, err := got.DecodeAPIKey() + require.NoError(t, err) + gotTok, ok := decoded.(*oauth.Token) + require.True(t, ok, "expected *oauth.Token, got %T", decoded) + require.Equal(t, tok, gotTok) +} + +func TestConfigProviderKeyRequestUnknownKind(t *testing.T) { + t.Parallel() + + req := proto.ConfigProviderKeyRequest{ + Kind: proto.APIKeyKind("bogus"), + APIKey: json.RawMessage(`"x"`), + } + _, err := req.DecodeAPIKey() + require.Error(t, err) + require.Contains(t, err.Error(), "bogus") +} + +func TestConfigProviderKeyRequestMalformedPayload(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + kind proto.APIKeyKind + raw string + }{ + {"string kind with object payload", proto.APIKeyKindString, `{"foo":"bar"}`}, + {"oauth kind with string payload", proto.APIKeyKindOAuth, `"not-a-token"`}, + {"oauth kind with invalid json", proto.APIKeyKindOAuth, `{`}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + req := proto.ConfigProviderKeyRequest{ + Kind: tc.kind, + APIKey: json.RawMessage(tc.raw), + } + _, err := req.DecodeAPIKey() + require.Error(t, err) + }) + } +} diff --git a/internal/server/config.go b/internal/server/config.go index b5277d08e32935a59ee0748809b0868f3dc3b5a9..cd96c3603fc94a41aa0d3ae4607e54fb487531ba 100644 --- a/internal/server/config.go +++ b/internal/server/config.go @@ -145,7 +145,14 @@ func (c *controllerV1) handlePostWorkspaceConfigProviderKey(w http.ResponseWrite return } - if err := c.backend.SetProviderAPIKey(id, req.Scope, req.ProviderID, req.APIKey); err != nil { + apiKey, err := req.DecodeAPIKey() + if err != nil { + c.server.logError(r, "Failed to decode api key", "error", err, "kind", req.Kind) + jsonError(w, http.StatusBadRequest, err.Error()) + return + } + + if err := c.backend.SetProviderAPIKey(id, req.Scope, req.ProviderID, apiKey); err != nil { c.handleError(w, r, err) return } diff --git a/internal/server/events_test.go b/internal/server/events_test.go new file mode 100644 index 0000000000000000000000000000000000000000..80b9428c104651bc3a372bf614403177fe2ab5d7 --- /dev/null +++ b/internal/server/events_test.go @@ -0,0 +1,46 @@ +package server + +import ( + "testing" + + "github.com/charmbracelet/crush/internal/message" + "github.com/charmbracelet/crush/internal/proto" + "github.com/stretchr/testify/require" +) + +// TestMessageToProtoToolResult ensures that ToolResult metadata, +// data, and MIME type survive the conversion to proto. Without these +// fields the TUI cannot render rich tool output (e.g. syntax- +// highlighted code from view, diffs from edit, images, etc.) and +// falls back to the raw LLM-facing string. +func TestMessageToProtoToolResult(t *testing.T) { + t.Parallel() + + src := message.Message{ + ID: "m1", + Role: message.Tool, + Parts: []message.ContentPart{ + message.ToolResult{ + ToolCallID: "call-1", + Name: "view", + Content: "\n 1| hi\n", + Data: "base64data", + MIMEType: "image/png", + Metadata: `{"file_path":"/tmp/x","content":"hi"}`, + IsError: false, + }, + }, + } + + got := messageToProto(src) + require.Len(t, got.Parts, 1) + tr, ok := got.Parts[0].(proto.ToolResult) + require.True(t, ok, "expected proto.ToolResult, got %T", got.Parts[0]) + require.Equal(t, "call-1", tr.ToolCallID) + require.Equal(t, "view", tr.Name) + require.Equal(t, "\n 1| hi\n", tr.Content) + require.Equal(t, "base64data", tr.Data) + require.Equal(t, "image/png", tr.MIMEType) + require.Equal(t, `{"file_path":"/tmp/x","content":"hi"}`, tr.Metadata) + require.False(t, tr.IsError) +} diff --git a/internal/shell/expand.go b/internal/shell/expand.go index 5b31c9009fae156de00819718f2d341fcc9cb880..bef83310bfb1546af05bf8b4317ed2df9694fc9c 100644 --- a/internal/shell/expand.go +++ b/internal/shell/expand.go @@ -31,9 +31,6 @@ const maxInnerStderrBytes = 512 // memory model regardless of test-level happens-before reasoning. The // atomic load on the hot path is negligible against the cost of parsing // and running through mvdan. -// -// See PLAN.md Phase 2 design decisions #11 and #12 for the full -// rationale. var NoUnset atomic.Bool // ExpandValue expands shell-style substitutions in a single config value. diff --git a/internal/swagger/docs.go b/internal/swagger/docs.go index ec106954f54e055d1ebe9ff64db9c63454a644e2..3492035a038d63bac6a80703d00d37b5fa85c8b9 100644 --- a/internal/swagger/docs.go +++ b/internal/swagger/docs.go @@ -1418,6 +1418,74 @@ const docTemplate = `{ } } }, + "/workspaces/{id}/mcp/docker/disable": { + "post": { + "tags": [ + "mcp" + ], + "summary": "Disable Docker MCP", + "parameters": [ + { + "type": "string", + "description": "Workspace ID", + "name": "id", + "in": "path", + "required": true + } + ], + "responses": { + "200": { + "description": "OK" + }, + "404": { + "description": "Not Found", + "schema": { + "$ref": "#/definitions/proto.Error" + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "$ref": "#/definitions/proto.Error" + } + } + } + } + }, + "/workspaces/{id}/mcp/docker/enable": { + "post": { + "tags": [ + "mcp" + ], + "summary": "Enable Docker MCP", + "parameters": [ + { + "type": "string", + "description": "Workspace ID", + "name": "id", + "in": "path", + "required": true + } + ], + "responses": { + "200": { + "description": "OK" + }, + "404": { + "description": "Not Found", + "schema": { + "$ref": "#/definitions/proto.Error" + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "$ref": "#/definitions/proto.Error" + } + } + } + } + }, "/workspaces/{id}/mcp/get-prompt": { "post": { "consumes": [ @@ -2618,6 +2686,23 @@ const docTemplate = `{ } } }, + "config.HookConfig": { + "type": "object", + "properties": { + "command": { + "description": "Shell command to execute.", + "type": "string" + }, + "matcher": { + "description": "Regex pattern tested against the tool name. Empty means match all.", + "type": "string" + }, + "timeout": { + "description": "Timeout in seconds. Default 30.", + "type": "integer" + } + } + }, "config.LSPConfig": { "type": "object", "properties": { @@ -2691,6 +2776,12 @@ const docTemplate = `{ "type": "string" } }, + "enabled_tools": { + "type": "array", + "items": { + "type": "string" + } + }, "env": { "type": "object", "additionalProperties": { @@ -2698,7 +2789,7 @@ const docTemplate = `{ } }, "headers": { - "description": "TODO: maybe make it possible to get the value from the env", + "description": "Headers are HTTP headers for HTTP/SSE MCP servers. Values run\nthrough shell expansion at MCP startup, so $VAR and $(cmd)\nwork. A header whose value resolves to the empty string (unset\nbare $VAR under lenient nounset, $(echo), or literal \"\") is\nomitted from the outgoing request rather than sent as\n\"Header:\".", "type": "object", "additionalProperties": { "type": "string" @@ -2881,6 +2972,15 @@ const docTemplate = `{ "$schema": { "type": "string" }, + "hooks": { + "type": "object", + "additionalProperties": { + "type": "array", + "items": { + "$ref": "#/definitions/config.HookConfig" + } + } + }, "lsp": { "$ref": "#/definitions/config.LSPs" }, @@ -2939,7 +3039,7 @@ const docTemplate = `{ } }, "data_directory": { - "description": "DataDirectory is where Crush keeps per-project state such as the SQLite database and workspace overrides. Relative paths are resolved against the working directory; absolute paths are used verbatim. After defaulting the stored value is always absolute.", + "description": "DataDirectory is where Crush keeps per-project state such as\nthe SQLite database and workspace overrides. Relative paths are\nresolved against the working directory; absolute paths are used\nverbatim. After defaulting the stored value is always absolute.", "type": "string" }, "debug": { @@ -2963,6 +3063,12 @@ const docTemplate = `{ "disable_provider_auto_update": { "type": "boolean" }, + "disabled_skills": { + "type": "array", + "items": { + "type": "string" + } + }, "disabled_tools": { "type": "array", "items": { @@ -3035,6 +3141,17 @@ const docTemplate = `{ "StateDisabled" ] }, + "proto.APIKeyKind": { + "type": "string", + "enum": [ + "string", + "oauth" + ], + "x-enum-varnames": [ + "APIKeyKindString", + "APIKeyKindOAuth" + ] + }, "proto.AgentInfo": { "type": "object", "properties": { @@ -3155,7 +3272,15 @@ const docTemplate = `{ "proto.ConfigProviderKeyRequest": { "type": "object", "properties": { - "api_key": {}, + "api_key": { + "type": "array", + "items": { + "type": "integer" + } + }, + "kind": { + "$ref": "#/definitions/proto.APIKeyKind" + }, "provider_id": { "type": "string" }, @@ -3497,6 +3622,9 @@ const docTemplate = `{ "proto.VersionInfo": { "type": "object", "properties": { + "build_id": { + "type": "string" + }, "commit": { "type": "string" }, diff --git a/internal/swagger/swagger.json b/internal/swagger/swagger.json index b3ccbe22b783b78f508fa1a6b05f38d2f45f8612..8e333ade22bf6aadab57d2aaf2ab77fa6dd0885b 100644 --- a/internal/swagger/swagger.json +++ b/internal/swagger/swagger.json @@ -1411,6 +1411,74 @@ } } }, + "/workspaces/{id}/mcp/docker/disable": { + "post": { + "tags": [ + "mcp" + ], + "summary": "Disable Docker MCP", + "parameters": [ + { + "type": "string", + "description": "Workspace ID", + "name": "id", + "in": "path", + "required": true + } + ], + "responses": { + "200": { + "description": "OK" + }, + "404": { + "description": "Not Found", + "schema": { + "$ref": "#/definitions/proto.Error" + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "$ref": "#/definitions/proto.Error" + } + } + } + } + }, + "/workspaces/{id}/mcp/docker/enable": { + "post": { + "tags": [ + "mcp" + ], + "summary": "Enable Docker MCP", + "parameters": [ + { + "type": "string", + "description": "Workspace ID", + "name": "id", + "in": "path", + "required": true + } + ], + "responses": { + "200": { + "description": "OK" + }, + "404": { + "description": "Not Found", + "schema": { + "$ref": "#/definitions/proto.Error" + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "$ref": "#/definitions/proto.Error" + } + } + } + } + }, "/workspaces/{id}/mcp/get-prompt": { "post": { "consumes": [ @@ -2611,6 +2679,23 @@ } } }, + "config.HookConfig": { + "type": "object", + "properties": { + "command": { + "description": "Shell command to execute.", + "type": "string" + }, + "matcher": { + "description": "Regex pattern tested against the tool name. Empty means match all.", + "type": "string" + }, + "timeout": { + "description": "Timeout in seconds. Default 30.", + "type": "integer" + } + } + }, "config.LSPConfig": { "type": "object", "properties": { @@ -2684,6 +2769,12 @@ "type": "string" } }, + "enabled_tools": { + "type": "array", + "items": { + "type": "string" + } + }, "env": { "type": "object", "additionalProperties": { @@ -2691,7 +2782,7 @@ } }, "headers": { - "description": "TODO: maybe make it possible to get the value from the env", + "description": "Headers are HTTP headers for HTTP/SSE MCP servers. Values run\nthrough shell expansion at MCP startup, so $VAR and $(cmd)\nwork. A header whose value resolves to the empty string (unset\nbare $VAR under lenient nounset, $(echo), or literal \"\") is\nomitted from the outgoing request rather than sent as\n\"Header:\".", "type": "object", "additionalProperties": { "type": "string" @@ -2874,6 +2965,15 @@ "$schema": { "type": "string" }, + "hooks": { + "type": "object", + "additionalProperties": { + "type": "array", + "items": { + "$ref": "#/definitions/config.HookConfig" + } + } + }, "lsp": { "$ref": "#/definitions/config.LSPs" }, @@ -2932,7 +3032,7 @@ } }, "data_directory": { - "description": "DataDirectory is where Crush keeps per-project state such as the SQLite database and workspace overrides. Relative paths are resolved against the working directory; absolute paths are used verbatim. After defaulting the stored value is always absolute.", + "description": "DataDirectory is where Crush keeps per-project state such as\nthe SQLite database and workspace overrides. Relative paths are\nresolved against the working directory; absolute paths are used\nverbatim. After defaulting the stored value is always absolute.", "type": "string" }, "debug": { @@ -2956,6 +3056,12 @@ "disable_provider_auto_update": { "type": "boolean" }, + "disabled_skills": { + "type": "array", + "items": { + "type": "string" + } + }, "disabled_tools": { "type": "array", "items": { @@ -3028,6 +3134,17 @@ "StateDisabled" ] }, + "proto.APIKeyKind": { + "type": "string", + "enum": [ + "string", + "oauth" + ], + "x-enum-varnames": [ + "APIKeyKindString", + "APIKeyKindOAuth" + ] + }, "proto.AgentInfo": { "type": "object", "properties": { @@ -3148,7 +3265,15 @@ "proto.ConfigProviderKeyRequest": { "type": "object", "properties": { - "api_key": {}, + "api_key": { + "type": "array", + "items": { + "type": "integer" + } + }, + "kind": { + "$ref": "#/definitions/proto.APIKeyKind" + }, "provider_id": { "type": "string" }, @@ -3490,6 +3615,9 @@ "proto.VersionInfo": { "type": "object", "properties": { + "build_id": { + "type": "string" + }, "commit": { "type": "string" }, diff --git a/internal/swagger/swagger.yaml b/internal/swagger/swagger.yaml index d1265256c785236bac581401c7c6e44838ff5dae..edd8d0a020931b2e0b4262887098dfa2a0832a53 100644 --- a/internal/swagger/swagger.yaml +++ b/internal/swagger/swagger.yaml @@ -63,6 +63,19 @@ definitions: max_items: type: integer type: object + config.HookConfig: + properties: + command: + description: Shell command to execute. + type: string + matcher: + description: Regex pattern tested against the tool name. Empty means match + all. + type: string + timeout: + description: Timeout in seconds. Default 30. + type: integer + type: object config.LSPConfig: properties: args: @@ -112,6 +125,10 @@ definitions: items: type: string type: array + enabled_tools: + items: + type: string + type: array env: additionalProperties: type: string @@ -119,7 +136,13 @@ definitions: headers: additionalProperties: type: string - description: 'TODO: maybe make it possible to get the value from the env' + description: |- + Headers are HTTP headers for HTTP/SSE MCP servers. Values run + through shell expansion at MCP startup, so $VAR and $(cmd) + work. A header whose value resolves to the empty string (unset + bare $VAR under lenient nounset, $(echo), or literal "") is + omitted from the outgoing request rather than sent as + "Header:". type: object timeout: type: integer @@ -249,6 +272,12 @@ definitions: properties: $schema: type: string + hooks: + additionalProperties: + items: + $ref: '#/definitions/config.HookConfig' + type: array + type: object lsp: $ref: '#/definitions/config.LSPs' mcp: @@ -288,10 +317,10 @@ definitions: type: array data_directory: description: |- - DataDirectory is where Crush keeps per-project state such as the SQLite - database and workspace overrides. Relative paths are resolved against - the working directory; absolute paths are used verbatim. After - defaulting the stored value is always absolute. + DataDirectory is where Crush keeps per-project state such as + the SQLite database and workspace overrides. Relative paths are + resolved against the working directory; absolute paths are used + verbatim. After defaulting the stored value is always absolute. type: string debug: type: boolean @@ -307,6 +336,10 @@ definitions: type: boolean disable_provider_auto_update: type: boolean + disabled_skills: + items: + type: string + type: array disabled_tools: items: type: string @@ -358,6 +391,14 @@ definitions: - StateError - StateStopped - StateDisabled + proto.APIKeyKind: + enum: + - string + - oauth + type: string + x-enum-varnames: + - APIKeyKindString + - APIKeyKindOAuth proto.AgentInfo: properties: is_busy: @@ -436,7 +477,12 @@ definitions: type: object proto.ConfigProviderKeyRequest: properties: - api_key: {} + api_key: + items: + type: integer + type: array + kind: + $ref: '#/definitions/proto.APIKeyKind' provider_id: type: string scope: @@ -664,6 +710,8 @@ definitions: type: object proto.VersionInfo: properties: + build_id: + type: string commit: type: string go_version: @@ -1639,6 +1687,50 @@ paths: summary: Stop all LSP servers tags: - lsp + /workspaces/{id}/mcp/docker/disable: + post: + parameters: + - description: Workspace ID + in: path + name: id + required: true + type: string + responses: + "200": + description: OK + "404": + description: Not Found + schema: + $ref: '#/definitions/proto.Error' + "500": + description: Internal Server Error + schema: + $ref: '#/definitions/proto.Error' + summary: Disable Docker MCP + tags: + - mcp + /workspaces/{id}/mcp/docker/enable: + post: + parameters: + - description: Workspace ID + in: path + name: id + required: true + type: string + responses: + "200": + description: OK + "404": + description: Not Found + schema: + $ref: '#/definitions/proto.Error' + "500": + description: Internal Server Error + schema: + $ref: '#/definitions/proto.Error' + summary: Enable Docker MCP + tags: + - mcp /workspaces/{id}/mcp/get-prompt: post: consumes: diff --git a/internal/workspace/client_workspace_test.go b/internal/workspace/client_workspace_test.go new file mode 100644 index 0000000000000000000000000000000000000000..43d7e3a0b0554d8028541e91f952797338c3038f --- /dev/null +++ b/internal/workspace/client_workspace_test.go @@ -0,0 +1,46 @@ +package workspace + +import ( + "testing" + + "github.com/charmbracelet/crush/internal/message" + "github.com/charmbracelet/crush/internal/proto" + "github.com/stretchr/testify/require" +) + +// TestProtoToMessageToolResult ensures that ToolResult metadata, +// data, and MIME type survive the conversion from proto on the +// client. Without these fields the TUI cannot render rich tool +// output (e.g. syntax-highlighted code from view, diffs from edit, +// images, etc.) and falls back to the raw LLM-facing string. +func TestProtoToMessageToolResult(t *testing.T) { + t.Parallel() + + src := proto.Message{ + ID: "m1", + Role: proto.Tool, + Parts: []proto.ContentPart{ + proto.ToolResult{ + ToolCallID: "call-1", + Name: "view", + Content: "\n 1| hi\n", + Data: "base64data", + MIMEType: "image/png", + Metadata: `{"file_path":"/tmp/x","content":"hi"}`, + IsError: false, + }, + }, + } + + got := protoToMessage(src) + require.Len(t, got.Parts, 1) + tr, ok := got.Parts[0].(message.ToolResult) + require.True(t, ok, "expected message.ToolResult, got %T", got.Parts[0]) + require.Equal(t, "call-1", tr.ToolCallID) + require.Equal(t, "view", tr.Name) + require.Equal(t, "\n 1| hi\n", tr.Content) + require.Equal(t, "base64data", tr.Data) + require.Equal(t, "image/png", tr.MIMEType) + require.Equal(t, `{"file_path":"/tmp/x","content":"hi"}`, tr.Metadata) + require.False(t, tr.IsError) +}