From a6d56fc095849a43963f0d807b013c5865c0ae81 Mon Sep 17 00:00:00 2001 From: Christian Rocha Date: Mon, 27 Apr 2026 22:10:51 -0400 Subject: [PATCH] fix(config): fix and add API key validatoin for various providers Basically, the /models endpoint is public on many openai-compat providers, so validation was always passing with any string. This likely broke in 7d14abb9 and cce8edf9. Validation has been restored for: - AiHubMix - Avian - Cortecs - HuggingFace - io.net - Quiniu Cloud - Synthetic - Venice New validation added: - Minimax Cannot be validated yet: - Chutes - Neuralwatt --- internal/config/config.go | 340 ++++++++-- internal/config/config_validate_test.go | 776 +++++++++++++++++++++++ internal/ui/dialog/api_key_input.go | 59 +- internal/ui/dialog/api_key_input_test.go | 162 +++++ 4 files changed, 1265 insertions(+), 72 deletions(-) create mode 100644 internal/config/config_validate_test.go create mode 100644 internal/ui/dialog/api_key_input_test.go diff --git a/internal/config/config.go b/internal/config/config.go index 1b8156c68f30b7eddf75d78c414f03662d94f5f2..bd08b5b56ead98ab8b0ce56145ef7b65490950d1 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -1,10 +1,12 @@ package config import ( + "bytes" "cmp" "context" "errors" "fmt" + "io" "log/slog" "maps" "net/http" @@ -571,100 +573,320 @@ func (c *Config) SetupAgents() { c.Agents = agents } -func (c *ProviderConfig) TestConnection(resolver VariableResolver) error { - var ( - providerID = catwalk.InferenceProvider(c.ID) - testURL = "" - headers = make(map[string]string) - apiKey, _ = resolver.ResolveValue(c.APIKey) - ) +// ErrValidationUnsupported is returned from [ProviderConfig.TestConnection] +// when the provider does not expose a deterministic endpoint that proves API +// key authentication without performing inference. Callers should treat this +// as "saved but not verified" rather than as a validation failure. +var ErrValidationUnsupported = errors.New("provider does not expose a deterministic validation probe") + +// validationProbe describes a single HTTP request used to prove authentication +// for a given provider configuration. +type validationProbe struct { + method string + url string + headers map[string]string + body []byte + classify func(statusCode int) error +} + +// classifyAuthGated treats the probe endpoint as one that is expected to +// return 200 with a valid key and 401/403 with an invalid one. Any other +// status is considered non-deterministic and reported as unsupported so the +// UI can show "not verified" instead of a misleading "invalid key". +func classifyAuthGated(c *ProviderConfig) func(int) error { + return func(status int) error { + switch status { + case http.StatusOK: + return nil + case http.StatusUnauthorized, http.StatusForbidden: + return fmt.Errorf("failed to connect to provider %s: %s", c.ID, http.StatusText(status)) + default: + return ErrValidationUnsupported + } + } +} - switch providerID { - case catwalk.InferenceProviderMiniMax, catwalk.InferenceProviderMiniMaxChina: - // NOTE: MiniMax has no good endpoint we can use to validate the API key. - return nil +// classifyOpenAIChatMalformed classifies responses from a deliberately +// malformed POST {baseURL}/chat/completions probe. On most OpenAI-compatible +// gateways authentication happens before schema validation, so 401/403 means +// the key is bad while 400/422 means the key was accepted and only the body +// was rejected. Anything else is treated as unsupported / transient. +func classifyOpenAIChatMalformed(c *ProviderConfig) func(int) error { + return func(status int) error { + switch status { + case http.StatusUnauthorized, http.StatusForbidden: + return fmt.Errorf("failed to connect to provider %s: %s", c.ID, http.StatusText(status)) + case http.StatusBadRequest, http.StatusUnprocessableEntity: + return nil + default: + return ErrValidationUnsupported + } } +} - switch c.Type { - case catwalk.TypeOpenAI, catwalk.TypeOpenAICompat, catwalk.TypeOpenRouter: - baseURL, _ := resolver.ResolveValue(c.BaseURL) - baseURL = cmp.Or(baseURL, "https://api.openai.com/v1") - - switch providerID { - case catwalk.InferenceProviderOpenRouter: - testURL = baseURL + "/credits" - case catwalk.InferenceProviderOpenCodeGo: - testURL = strings.Replace(baseURL, "/go", "", 1) + "/models" +// classifyGoogleModels classifies responses from Google's +// `/v1beta/models?key=…` probe. Google returns 400 INVALID_ARGUMENT for a +// malformed or unknown API key, so 400/401/403 all indicate an invalid key. +func classifyGoogleModels(c *ProviderConfig) func(int) error { + return func(status int) error { + switch status { + case http.StatusOK: + return nil + case http.StatusBadRequest, http.StatusUnauthorized, http.StatusForbidden: + return fmt.Errorf("failed to connect to provider %s: %s", c.ID, http.StatusText(status)) default: - testURL = baseURL + "/models" + return ErrValidationUnsupported } + } +} - headers["Authorization"] = "Bearer " + apiKey - case catwalk.TypeAnthropic: - baseURL, _ := resolver.ResolveValue(c.BaseURL) - baseURL = cmp.Or(baseURL, "https://api.anthropic.com/v1") +// classifyZAIModels preserves the historical ZAI-specific behaviour: the +// `/models` endpoint returns a variety of non-200 statuses even with a valid +// key, but reliably returns 401 when the key is bad. Treat 401 as invalid +// and anything else as valid (the endpoint is authoritative about bad keys +// but noisy about everything else). +func classifyZAIModels(c *ProviderConfig) func(int) error { + return func(status int) error { + if status == http.StatusUnauthorized { + return fmt.Errorf("failed to connect to provider %s: %s", c.ID, http.StatusText(status)) + } + return nil + } +} - switch providerID { - case catwalk.InferenceKimiCoding: - testURL = baseURL + "/v1/models" - default: - testURL = baseURL + "/models" +// openaiCompatModelsAllowlist lists openai-compat providers whose `/models` +// endpoint is known to authenticate the caller (i.e. return 401/403 for a +// bad key rather than 200 with a public listing). New openai-compat +// providers should NOT be added here unless their `/models` behaviour has +// been confirmed to gate on auth — otherwise they should use the malformed +// chat-completions probe or return [ErrValidationUnsupported]. +var openaiCompatModelsAllowlist = map[catwalk.InferenceProvider]struct{}{ + "deepseek": {}, + catwalk.InferenceProviderGROQ: {}, + catwalk.InferenceProviderXAI: {}, + catwalk.InferenceProviderZhipu: {}, + catwalk.InferenceProviderZhipuCoding: {}, + catwalk.InferenceProviderCerebras: {}, + catwalk.InferenceProviderNebius: {}, + catwalk.InferenceProviderCopilot: {}, +} + +// openaiCompatChatProbe builds a malformed-body POST /chat/completions probe +// for OpenAI-compatible providers whose chat-completions endpoint is known to +// gate on auth before validating the request body. +func openaiCompatChatProbe(c *ProviderConfig, baseURL, apiKey string) (*validationProbe, error) { + if baseURL == "" { + return nil, ErrValidationUnsupported + } + return &validationProbe{ + method: http.MethodPost, + url: baseURL + "/chat/completions", + headers: map[string]string{ + "Authorization": "Bearer " + apiKey, + "Content-Type": "application/json", + }, + // Intentionally malformed: required fields missing so the gateway + // rejects the payload after authenticating the caller. + body: []byte(`{"__crush_probe__": true}`), + classify: classifyOpenAIChatMalformed(c), + }, nil +} + +// buildValidationProbe returns the probe to use for this provider, or a +// sentinel error if verification is impossible without performing inference. +// A nil probe with a nil error means "the key is valid by virtue of its +// format and no network probe is necessary" (e.g. Bedrock/Vercel prefix +// checks). +func (c *ProviderConfig) buildValidationProbe(resolver VariableResolver) (*validationProbe, error) { + providerID := catwalk.InferenceProvider(c.ID) + apiKey, _ := resolver.ResolveValue(c.APIKey) + baseURL, _ := resolver.ResolveValue(c.BaseURL) + + // Provider-ID-specific probes take precedence over type-based defaults. + switch providerID { + case catwalk.InferenceProviderMiniMax, catwalk.InferenceProviderMiniMaxChina: + base := cmp.Or(baseURL, "https://api.minimax.io/anthropic") + return &validationProbe{ + method: http.MethodGet, + url: base + "/v1/models", + headers: map[string]string{ + "x-api-key": apiKey, + "anthropic-version": "2023-06-01", + }, + classify: classifyAuthGated(c), + }, nil + case catwalk.InferenceProviderVenice: + base := cmp.Or(baseURL, "https://api.venice.ai/api/v1") + return &validationProbe{ + method: http.MethodGet, + url: base + "/api_keys/rate_limits", + headers: map[string]string{ + "Authorization": "Bearer " + apiKey, + }, + classify: classifyAuthGated(c), + }, nil + case catwalk.InferenceAIHubMix, + catwalk.InferenceProviderAvian, + catwalk.InferenceProviderCortecs, + catwalk.InferenceProviderHuggingFace, + catwalk.InferenceProviderIoNet, + catwalk.InferenceProviderOpenCodeGo, + catwalk.InferenceProviderOpenCodeZen, + catwalk.InferenceProviderQiniuCloud, + catwalk.InferenceProviderSynthetic: + return openaiCompatChatProbe(c, baseURL, apiKey) + case catwalk.InferenceProviderChutes, catwalk.InferenceProviderNeuralwatt: + // These providers have been observed to return ambiguous responses + // for unauthenticated requests, so we cannot safely validate. + return nil, ErrValidationUnsupported + case catwalk.InferenceProviderZAI: + // ZAI's `/models` endpoint is authoritative about bad keys (always + // 401) but returns assorted non-200 statuses for valid keys, so it + // needs its own classifier. + base := baseURL + if base == "" { + return nil, ErrValidationUnsupported } + return &validationProbe{ + method: http.MethodGet, + url: base + "/models", + headers: map[string]string{ + "Authorization": "Bearer " + apiKey, + }, + classify: classifyZAIModels(c), + }, nil + } - headers["x-api-key"] = apiKey - headers["anthropic-version"] = "2023-06-01" + // Type-based defaults for providers without an explicit override. + switch c.Type { + case catwalk.TypeOpenAI: + base := cmp.Or(baseURL, "https://api.openai.com/v1") + return &validationProbe{ + method: http.MethodGet, + url: base + "/models", + headers: map[string]string{ + "Authorization": "Bearer " + apiKey, + }, + classify: classifyAuthGated(c), + }, nil + case catwalk.TypeOpenRouter: + base := cmp.Or(baseURL, "https://openrouter.ai/api/v1") + return &validationProbe{ + method: http.MethodGet, + url: base + "/credits", + headers: map[string]string{ + "Authorization": "Bearer " + apiKey, + }, + classify: classifyAuthGated(c), + }, nil + case catwalk.TypeAnthropic: + base := cmp.Or(baseURL, "https://api.anthropic.com/v1") + testURL := base + "/models" + if providerID == catwalk.InferenceKimiCoding { + testURL = base + "/v1/models" + } + return &validationProbe{ + method: http.MethodGet, + url: testURL, + headers: map[string]string{ + "x-api-key": apiKey, + "anthropic-version": "2023-06-01", + }, + classify: classifyAuthGated(c), + }, nil case catwalk.TypeGoogle: - baseURL, _ := resolver.ResolveValue(c.BaseURL) - baseURL = cmp.Or(baseURL, "https://generativelanguage.googleapis.com") - testURL = baseURL + "/v1beta/models?key=" + url.QueryEscape(apiKey) + base := cmp.Or(baseURL, "https://generativelanguage.googleapis.com") + return &validationProbe{ + method: http.MethodGet, + url: base + "/v1beta/models?key=" + url.QueryEscape(apiKey), + classify: classifyGoogleModels(c), + }, nil case catwalk.TypeBedrock: // NOTE: Bedrock has a `/foundation-models` endpoint that we could in // theory use, but apparently the authorization is region-specific, - // so it's not so trivial. - if strings.HasPrefix(apiKey, "ABSK") { // Bedrock API keys - return nil + // so it's not so trivial. Fall back to a prefix check. + if strings.HasPrefix(apiKey, "ABSK") { + return nil, nil } - return errors.New("not a valid bedrock api key") + return nil, errors.New("not a valid bedrock api key") case catwalk.TypeVercel: // NOTE: Vercel does not validate API keys on the `/models` endpoint. - if strings.HasPrefix(apiKey, "vck_") { // Vercel API keys - return nil + if strings.HasPrefix(apiKey, "vck_") { + return nil, nil } - return errors.New("not a valid vercel api key") + return nil, errors.New("not a valid vercel api key") + case catwalk.TypeOpenAICompat: + // Generic openai-compat providers often expose a public /models + // endpoint, so hitting it proves nothing about the caller's key. + // Only providers we've confirmed to gate /models on auth use the + // /models probe; everyone else needs an explicit override above or + // returns ErrValidationUnsupported. + if _, ok := openaiCompatModelsAllowlist[providerID]; !ok { + return nil, ErrValidationUnsupported + } + if baseURL == "" { + return nil, ErrValidationUnsupported + } + return &validationProbe{ + method: http.MethodGet, + url: baseURL + "/models", + headers: map[string]string{ + "Authorization": "Bearer " + apiKey, + }, + classify: classifyAuthGated(c), + }, nil + } + + return nil, ErrValidationUnsupported +} + +// TestConnection attempts to prove that the configured API key authenticates +// with the provider. It returns nil on confirmed success, [ErrValidationUnsupported] +// when the provider has no deterministic validation probe, or a non-nil error +// describing the validation failure. +func (c *ProviderConfig) TestConnection(resolver VariableResolver) error { + probe, err := c.buildValidationProbe(resolver) + if err != nil { + return err + } + if probe == nil { + // A nil probe with no error means the configuration was accepted + // without needing a network round-trip (e.g. Bedrock/Vercel prefix + // checks). + return nil + } + if probe.url == "" { + return ErrValidationUnsupported } ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - client := &http.Client{} - req, err := http.NewRequestWithContext(ctx, "GET", testURL, nil) + var body io.Reader + if len(probe.body) > 0 { + body = bytes.NewReader(probe.body) + } + req, err := http.NewRequestWithContext(ctx, probe.method, probe.url, body) if err != nil { - return fmt.Errorf("failed to create request for provider %s: %w", c.ID, err) + // Probe construction failures shouldn't surface as low-signal user + // errors; treat them as "cannot verify" instead. + return ErrValidationUnsupported } - for k, v := range headers { + for k, v := range probe.headers { req.Header.Set(k, v) } for k, v := range c.ExtraHeaders { req.Header.Set(k, v) } + client := &http.Client{} resp, err := client.Do(req) if err != nil { - return fmt.Errorf("failed to create request for provider %s: %w", c.ID, err) + return fmt.Errorf("failed to connect to provider %s: %w", c.ID, err) } defer resp.Body.Close() - switch providerID { - case catwalk.InferenceProviderZAI: - if resp.StatusCode == http.StatusUnauthorized { - return fmt.Errorf("failed to connect to provider %s: %s", c.ID, resp.Status) - } - default: - if resp.StatusCode != http.StatusOK { - return fmt.Errorf("failed to connect to provider %s: %s", c.ID, resp.Status) - } - } - return nil + return probe.classify(resp.StatusCode) } func resolveEnvs(envs map[string]string) []string { diff --git a/internal/config/config_validate_test.go b/internal/config/config_validate_test.go new file mode 100644 index 0000000000000000000000000000000000000000..244919c739b700e70c50c09827eba773059c37f0 --- /dev/null +++ b/internal/config/config_validate_test.go @@ -0,0 +1,776 @@ +package config + +import ( + "errors" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "charm.land/catwalk/pkg/catwalk" + "github.com/stretchr/testify/require" +) + +type capturedRequest struct { + method string + path string + query string + headers http.Header + body []byte +} + +func newCaptureServer(t *testing.T, status int) (*httptest.Server, *capturedRequest) { + t.Helper() + captured := &capturedRequest{} + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + captured.method = r.Method + captured.path = r.URL.Path + captured.query = r.URL.RawQuery + captured.headers = r.Header.Clone() + captured.body, _ = io.ReadAll(r.Body) + w.WriteHeader(status) + })) + t.Cleanup(srv.Close) + return srv, captured +} + +func TestTestConnectionMiniMaxProbe(t *testing.T) { + t.Parallel() + + for _, id := range []catwalk.InferenceProvider{ + catwalk.InferenceProviderMiniMax, + catwalk.InferenceProviderMiniMaxChina, + } { + t.Run(string(id), func(t *testing.T) { + t.Parallel() + for name, tc := range map[string]struct { + status int + wantErr error + wantNil bool + }{ + "valid": {status: http.StatusOK, wantNil: true}, + "invalid401": {status: http.StatusUnauthorized}, + "invalid403": {status: http.StatusForbidden}, + "unsupported": {status: http.StatusTeapot, wantErr: ErrValidationUnsupported}, + } { + t.Run(name, func(t *testing.T) { + t.Parallel() + srv, captured := newCaptureServer(t, tc.status) + c := &ProviderConfig{ + ID: string(id), + Type: catwalk.TypeAnthropic, + BaseURL: srv.URL, + APIKey: "key-abc", + } + err := c.TestConnection(IdentityResolver()) + switch { + case tc.wantNil: + require.NoError(t, err) + case tc.wantErr != nil: + require.ErrorIs(t, err, tc.wantErr) + default: + require.Error(t, err) + require.NotErrorIs(t, err, ErrValidationUnsupported) + } + require.Equal(t, http.MethodGet, captured.method) + require.Equal(t, "/v1/models", captured.path) + require.Equal(t, "key-abc", captured.headers.Get("x-api-key")) + require.Equal(t, "2023-06-01", captured.headers.Get("anthropic-version")) + }) + } + }) + } +} + +func TestTestConnectionVeniceProbe(t *testing.T) { + t.Parallel() + + tests := map[string]struct { + status int + wantErr error + wantNil bool + }{ + "valid": {status: http.StatusOK, wantNil: true}, + "invalid401": {status: http.StatusUnauthorized}, + "invalid403": {status: http.StatusForbidden}, + "rateLimited": {status: http.StatusTooManyRequests, wantErr: ErrValidationUnsupported}, + "paymentReq": {status: http.StatusPaymentRequired, wantErr: ErrValidationUnsupported}, + "transient500": {status: http.StatusInternalServerError, wantErr: ErrValidationUnsupported}, + } + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + t.Parallel() + srv, captured := newCaptureServer(t, tc.status) + c := &ProviderConfig{ + ID: string(catwalk.InferenceProviderVenice), + Type: catwalk.TypeOpenAICompat, + BaseURL: srv.URL, + APIKey: "sk-venice", + } + err := c.TestConnection(IdentityResolver()) + switch { + case tc.wantNil: + require.NoError(t, err) + case tc.wantErr != nil: + require.ErrorIs(t, err, tc.wantErr) + default: + require.Error(t, err) + require.NotErrorIs(t, err, ErrValidationUnsupported) + } + require.Equal(t, http.MethodGet, captured.method) + require.Equal(t, "/api_keys/rate_limits", captured.path) + require.Equal(t, "Bearer sk-venice", captured.headers.Get("Authorization")) + }) + } +} + +func TestTestConnectionOpenAICompatChatProbe(t *testing.T) { + t.Parallel() + + providers := []catwalk.InferenceProvider{ + catwalk.InferenceAIHubMix, + catwalk.InferenceProviderAvian, + catwalk.InferenceProviderCortecs, + catwalk.InferenceProviderHuggingFace, + catwalk.InferenceProviderIoNet, + catwalk.InferenceProviderOpenCodeGo, + catwalk.InferenceProviderOpenCodeZen, + catwalk.InferenceProviderQiniuCloud, + catwalk.InferenceProviderSynthetic, + } + for _, id := range providers { + t.Run(string(id), func(t *testing.T) { + t.Parallel() + cases := map[string]struct { + status int + wantErr error + wantNil bool + }{ + "authPassed400": {status: http.StatusBadRequest, wantNil: true}, + "authPassed422": {status: http.StatusUnprocessableEntity, wantNil: true}, + "invalid401": {status: http.StatusUnauthorized}, + "invalid403": {status: http.StatusForbidden}, + "transient500": {status: http.StatusInternalServerError, wantErr: ErrValidationUnsupported}, + "unexpected200": {status: http.StatusOK, wantErr: ErrValidationUnsupported}, + "unexpectedOther": {status: http.StatusTeapot, wantErr: ErrValidationUnsupported}, + } + for name, tc := range cases { + t.Run(name, func(t *testing.T) { + t.Parallel() + srv, captured := newCaptureServer(t, tc.status) + c := &ProviderConfig{ + ID: string(id), + Type: catwalk.TypeOpenAICompat, + BaseURL: srv.URL, + APIKey: "sk-test", + } + err := c.TestConnection(IdentityResolver()) + switch { + case tc.wantNil: + require.NoError(t, err) + case tc.wantErr != nil: + require.ErrorIs(t, err, tc.wantErr) + default: + require.Error(t, err) + require.NotErrorIs(t, err, ErrValidationUnsupported) + } + require.Equal(t, http.MethodPost, captured.method) + require.Equal(t, "/chat/completions", captured.path) + require.Equal(t, "Bearer sk-test", captured.headers.Get("Authorization")) + require.Equal(t, "application/json", captured.headers.Get("Content-Type")) + require.NotEmpty(t, captured.body) + }) + } + }) + } +} + +func TestTestConnectionUnsupportedProviders(t *testing.T) { + t.Parallel() + + for _, id := range []catwalk.InferenceProvider{ + catwalk.InferenceProviderChutes, + catwalk.InferenceProviderNeuralwatt, + } { + t.Run(string(id), func(t *testing.T) { + t.Parallel() + c := &ProviderConfig{ + ID: string(id), + Type: catwalk.TypeOpenAICompat, + BaseURL: "https://example.invalid", + APIKey: "sk-test", + } + err := c.TestConnection(IdentityResolver()) + require.ErrorIs(t, err, ErrValidationUnsupported) + }) + } +} + +func TestTestConnectionUnknownOpenAICompatIsUnsupported(t *testing.T) { + t.Parallel() + + c := &ProviderConfig{ + ID: "some-new-openai-compat-provider", + Type: catwalk.TypeOpenAICompat, + BaseURL: "https://example.invalid", + APIKey: "sk-test", + } + err := c.TestConnection(IdentityResolver()) + require.ErrorIs(t, err, ErrValidationUnsupported) +} + +func TestTestConnectionEmptyProbeURLIsUnsupported(t *testing.T) { + t.Parallel() + + // Chutes has a provider override that returns ErrValidationUnsupported + // regardless of configured base URL; this also guards the empty-URL path. + c := &ProviderConfig{ + ID: string(catwalk.InferenceProviderChutes), + Type: catwalk.TypeOpenAICompat, + APIKey: "sk-test", + } + err := c.TestConnection(IdentityResolver()) + require.ErrorIs(t, err, ErrValidationUnsupported) +} + +func TestTestConnectionExtraHeadersAreApplied(t *testing.T) { + t.Parallel() + + srv, captured := newCaptureServer(t, http.StatusBadRequest) + c := &ProviderConfig{ + ID: string(catwalk.InferenceProviderSynthetic), + Type: catwalk.TypeOpenAICompat, + BaseURL: srv.URL, + APIKey: "sk-test", + ExtraHeaders: map[string]string{ + "X-Custom-Header": "custom-value", + "Authorization": "overridden", + }, + } + err := c.TestConnection(IdentityResolver()) + require.NoError(t, err) + require.Equal(t, "custom-value", captured.headers.Get("X-Custom-Header")) + // ExtraHeaders are applied after the probe headers, so callers can + // override per-provider defaults if necessary. + require.Equal(t, "overridden", captured.headers.Get("Authorization")) +} + +func TestTestConnectionOpenAITypeProbesModelsEndpoint(t *testing.T) { + t.Parallel() + + srv, captured := newCaptureServer(t, http.StatusOK) + c := &ProviderConfig{ + ID: string(catwalk.InferenceProviderOpenAI), + Type: catwalk.TypeOpenAI, + BaseURL: srv.URL, + APIKey: "sk-openai", + } + err := c.TestConnection(IdentityResolver()) + require.NoError(t, err) + require.Equal(t, http.MethodGet, captured.method) + require.Equal(t, "/models", captured.path) + require.Equal(t, "Bearer sk-openai", captured.headers.Get("Authorization")) +} + +func TestTestConnectionOpenRouterProbesCreditsEndpoint(t *testing.T) { + t.Parallel() + + srv, captured := newCaptureServer(t, http.StatusOK) + c := &ProviderConfig{ + ID: string(catwalk.InferenceProviderOpenRouter), + Type: catwalk.TypeOpenRouter, + BaseURL: srv.URL, + APIKey: "sk-or", + } + err := c.TestConnection(IdentityResolver()) + require.NoError(t, err) + require.Equal(t, "/credits", captured.path) +} + +func TestTestConnectionAnthropicTypeProbesModels(t *testing.T) { + t.Parallel() + + srv, captured := newCaptureServer(t, http.StatusOK) + c := &ProviderConfig{ + ID: string(catwalk.InferenceProviderAnthropic), + Type: catwalk.TypeAnthropic, + BaseURL: srv.URL, + APIKey: "ak-test", + } + err := c.TestConnection(IdentityResolver()) + require.NoError(t, err) + require.Equal(t, "/models", captured.path) + require.Equal(t, "ak-test", captured.headers.Get("x-api-key")) +} + +func TestTestConnectionKimiCodingUsesV1Models(t *testing.T) { + t.Parallel() + + srv, captured := newCaptureServer(t, http.StatusOK) + c := &ProviderConfig{ + ID: string(catwalk.InferenceKimiCoding), + Type: catwalk.TypeAnthropic, + BaseURL: srv.URL, + APIKey: "ak-kimi", + } + err := c.TestConnection(IdentityResolver()) + require.NoError(t, err) + require.Equal(t, "/v1/models", captured.path) +} + +func TestTestConnectionGoogleIncludesKeyQueryParam(t *testing.T) { + t.Parallel() + + srv, captured := newCaptureServer(t, http.StatusOK) + c := &ProviderConfig{ + ID: string(catwalk.InferenceProviderGemini), + Type: catwalk.TypeGoogle, + BaseURL: srv.URL, + APIKey: "google-key", + } + err := c.TestConnection(IdentityResolver()) + require.NoError(t, err) + require.Equal(t, "/v1beta/models", captured.path) + require.Contains(t, captured.query, "key=google-key") +} + +// TestTestConnectionGoogleBadKeyIs400 locks in the fact that Google returns +// 400 INVALID_ARGUMENT (not 401) for an unknown API key, so 400 must map to +// "invalid" and never to [ErrValidationUnsupported]. +func TestTestConnectionGoogleBadKeyIs400(t *testing.T) { + t.Parallel() + + for name, tc := range map[string]struct { + status int + wantNil bool + wantErr error + }{ + "badKey400": {status: http.StatusBadRequest}, + "unauth401": {status: http.StatusUnauthorized}, + "forbidden403": {status: http.StatusForbidden}, + "ok200": {status: http.StatusOK, wantNil: true}, + "transient500": {status: http.StatusInternalServerError, wantErr: ErrValidationUnsupported}, + } { + t.Run(name, func(t *testing.T) { + t.Parallel() + srv, _ := newCaptureServer(t, tc.status) + c := &ProviderConfig{ + ID: string(catwalk.InferenceProviderGemini), + Type: catwalk.TypeGoogle, + BaseURL: srv.URL, + APIKey: "bad-key", + } + err := c.TestConnection(IdentityResolver()) + switch { + case tc.wantNil: + require.NoError(t, err) + case tc.wantErr != nil: + require.ErrorIs(t, err, tc.wantErr) + default: + require.Error(t, err) + require.NotErrorIs(t, err, ErrValidationUnsupported) + } + }) + } +} + +// TestTestConnectionOpenAICompatAllowlistUsesModelsProbe locks in the +// `/models` probe for openai-compat providers whose /models is known to be +// auth-gated. These providers must not fall through to +// [ErrValidationUnsupported]. +func TestTestConnectionOpenAICompatAllowlistUsesModelsProbe(t *testing.T) { + t.Parallel() + + providers := []catwalk.InferenceProvider{ + "deepseek", + catwalk.InferenceProviderGROQ, + catwalk.InferenceProviderXAI, + catwalk.InferenceProviderZhipu, + catwalk.InferenceProviderZhipuCoding, + catwalk.InferenceProviderCerebras, + catwalk.InferenceProviderNebius, + catwalk.InferenceProviderCopilot, + } + for _, id := range providers { + t.Run(string(id), func(t *testing.T) { + t.Parallel() + t.Run("valid", func(t *testing.T) { + t.Parallel() + srv, captured := newCaptureServer(t, http.StatusOK) + c := &ProviderConfig{ + ID: string(id), + Type: catwalk.TypeOpenAICompat, + BaseURL: srv.URL, + APIKey: "sk-good", + } + require.NoError(t, c.TestConnection(IdentityResolver())) + require.Equal(t, http.MethodGet, captured.method) + require.Equal(t, "/models", captured.path) + require.Equal(t, "Bearer sk-good", captured.headers.Get("Authorization")) + }) + t.Run("invalid", func(t *testing.T) { + t.Parallel() + srv, _ := newCaptureServer(t, http.StatusUnauthorized) + c := &ProviderConfig{ + ID: string(id), + Type: catwalk.TypeOpenAICompat, + BaseURL: srv.URL, + APIKey: "sk-bad", + } + err := c.TestConnection(IdentityResolver()) + require.Error(t, err) + require.NotErrorIs(t, err, ErrValidationUnsupported) + }) + }) + } +} + +// TestTestConnectionZAIUsesZAIClassifier pins ZAI's historical quirk: /models +// returns non-200 for valid keys but always 401 for bad keys. +func TestTestConnectionZAIUsesZAIClassifier(t *testing.T) { + t.Parallel() + + for name, tc := range map[string]struct { + status int + wantNil bool + }{ + "ok200": {status: http.StatusOK, wantNil: true}, + "other400": {status: http.StatusBadRequest, wantNil: true}, + "other500": {status: http.StatusInternalServerError, wantNil: true}, + "badKey401": {status: http.StatusUnauthorized}, + } { + t.Run(name, func(t *testing.T) { + t.Parallel() + srv, captured := newCaptureServer(t, tc.status) + c := &ProviderConfig{ + ID: string(catwalk.InferenceProviderZAI), + Type: catwalk.TypeOpenAICompat, + BaseURL: srv.URL, + APIKey: "sk-zai", + } + err := c.TestConnection(IdentityResolver()) + if tc.wantNil { + require.NoError(t, err) + } else { + require.Error(t, err) + require.NotErrorIs(t, err, ErrValidationUnsupported) + } + require.Equal(t, "/models", captured.path) + require.Equal(t, "Bearer sk-zai", captured.headers.Get("Authorization")) + }) + } +} + +func TestTestConnectionBedrockPrefix(t *testing.T) { + t.Parallel() + + t.Run("valid", func(t *testing.T) { + t.Parallel() + c := &ProviderConfig{ + ID: string(catwalk.InferenceProviderBedrock), + Type: catwalk.TypeBedrock, + APIKey: "ABSK-secret", + } + require.NoError(t, c.TestConnection(IdentityResolver())) + }) + t.Run("invalid", func(t *testing.T) { + t.Parallel() + c := &ProviderConfig{ + ID: string(catwalk.InferenceProviderBedrock), + Type: catwalk.TypeBedrock, + APIKey: "nope", + } + err := c.TestConnection(IdentityResolver()) + require.Error(t, err) + require.NotErrorIs(t, err, ErrValidationUnsupported) + }) +} + +func TestTestConnectionVercelPrefix(t *testing.T) { + t.Parallel() + + t.Run("valid", func(t *testing.T) { + t.Parallel() + c := &ProviderConfig{ + ID: string(catwalk.InferenceProviderVercel), + Type: catwalk.TypeVercel, + APIKey: "vck_abc", + } + require.NoError(t, c.TestConnection(IdentityResolver())) + }) + t.Run("invalid", func(t *testing.T) { + t.Parallel() + c := &ProviderConfig{ + ID: string(catwalk.InferenceProviderVercel), + Type: catwalk.TypeVercel, + APIKey: "nope", + } + err := c.TestConnection(IdentityResolver()) + require.Error(t, err) + require.NotErrorIs(t, err, ErrValidationUnsupported) + }) +} + +// TestTestConnectionPublicModelsAuthGatedChatRegression locks in the core +// regression from the 2025-10-20 expansion of generic /models validation to +// openai-compat: a provider whose /models is intentionally public would +// report any key as "validated" even though /chat/completions actually +// gates on auth. For every provider we currently mark "validated" via the +// malformed-body chat probe, this test simulates both endpoints and asserts +// that: +// +// 1. A bad key (401 on /chat/completions) is reported as invalid, not as +// "validated" — even when /models returns 200 unauthenticated. +// 2. A good key (400/422 on /chat/completions) is reported as valid. +// 3. The probe never hits /models for these providers. +func TestTestConnectionPublicModelsAuthGatedChatRegression(t *testing.T) { + t.Parallel() + + providers := []catwalk.InferenceProvider{ + catwalk.InferenceAIHubMix, + catwalk.InferenceProviderAvian, + catwalk.InferenceProviderCortecs, + catwalk.InferenceProviderHuggingFace, + catwalk.InferenceProviderIoNet, + catwalk.InferenceProviderOpenCodeGo, + catwalk.InferenceProviderOpenCodeZen, + catwalk.InferenceProviderQiniuCloud, + catwalk.InferenceProviderSynthetic, + } + for _, id := range providers { + t.Run(string(id), func(t *testing.T) { + t.Parallel() + + type hits struct { + models int + chat int + } + for name, tc := range map[string]struct { + chatStatus int + wantErr error + wantNil bool + }{ + "badKeyIsInvalidNotValidated": { + chatStatus: http.StatusUnauthorized, + }, + "goodKeyIsValidated": { + chatStatus: http.StatusBadRequest, + wantNil: true, + }, + "forbiddenKeyIsInvalid": { + chatStatus: http.StatusForbidden, + }, + "schemaFailure422IsValidated": { + chatStatus: http.StatusUnprocessableEntity, + wantNil: true, + }, + } { + t.Run(name, func(t *testing.T) { + t.Parallel() + h := &hits{} + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/models": + // Simulate a public /models endpoint that + // returns 200 regardless of the provided key. + h.models++ + w.WriteHeader(http.StatusOK) + case "/chat/completions": + h.chat++ + w.WriteHeader(tc.chatStatus) + default: + w.WriteHeader(http.StatusNotFound) + } + })) + t.Cleanup(srv.Close) + + c := &ProviderConfig{ + ID: string(id), + Type: catwalk.TypeOpenAICompat, + BaseURL: srv.URL, + APIKey: "sk-test", + } + err := c.TestConnection(IdentityResolver()) + + if tc.wantNil { + require.NoError(t, err, "expected %s to validate on %d", id, tc.chatStatus) + } else { + require.Error(t, err, "expected %s to reject on %d", id, tc.chatStatus) + require.NotErrorIs(t, err, ErrValidationUnsupported) + } + require.Equal(t, 0, h.models, "probe must not rely on public /models for %s", id) + require.Equal(t, 1, h.chat, "probe must hit /chat/completions for %s", id) + }) + } + }) + } +} + +// TestTestConnectionOpenAICompatProviderAudit is an audit table that pins the +// full set of openai-compat providers currently exposed as "validated" (i.e. +// TestConnection can return nil on some response) and documents the exact +// probe each uses. Adding a new openai-compat provider to the validated set +// MUST update this table; this prevents silent drift back into the +// "assume /models proves auth" bug class. +// +// Providers not listed here either: +// - use a different Type (TypeOpenAI / TypeAnthropic / TypeGoogle / ...); +// - are explicitly gated behind ErrValidationUnsupported (chutes, neuralwatt, +// and every unknown openai-compat provider). +func TestTestConnectionOpenAICompatProviderAudit(t *testing.T) { + t.Parallel() + + audit := map[catwalk.InferenceProvider]auditCase{ + catwalk.InferenceProviderVenice: { + method: http.MethodGet, + path: "/api_keys/rate_limits", + validStatus: http.StatusOK, + invalidStatus: http.StatusUnauthorized, + authHeader: "Authorization", + authValue: "Bearer sk-test", + }, + catwalk.InferenceAIHubMix: openaiCompatAuditCase(), + catwalk.InferenceProviderAvian: openaiCompatAuditCase(), + catwalk.InferenceProviderCortecs: openaiCompatAuditCase(), + catwalk.InferenceProviderHuggingFace: openaiCompatAuditCase(), + catwalk.InferenceProviderIoNet: openaiCompatAuditCase(), + catwalk.InferenceProviderOpenCodeGo: openaiCompatAuditCase(), + catwalk.InferenceProviderOpenCodeZen: openaiCompatAuditCase(), + catwalk.InferenceProviderQiniuCloud: openaiCompatAuditCase(), + catwalk.InferenceProviderSynthetic: openaiCompatAuditCase(), + // openai-compat providers with auth-gated /models (allowlist). + "deepseek": openaiCompatModelsAuditCase(), + catwalk.InferenceProviderGROQ: openaiCompatModelsAuditCase(), + catwalk.InferenceProviderXAI: openaiCompatModelsAuditCase(), + catwalk.InferenceProviderZhipu: openaiCompatModelsAuditCase(), + catwalk.InferenceProviderZhipuCoding: openaiCompatModelsAuditCase(), + catwalk.InferenceProviderCerebras: openaiCompatModelsAuditCase(), + catwalk.InferenceProviderNebius: openaiCompatModelsAuditCase(), + catwalk.InferenceProviderCopilot: openaiCompatModelsAuditCase(), + // ZAI uses the /models endpoint but with its own classifier that + // only treats 401 as invalid. Its valid path must therefore be 200 + // here for the audit's generic "valid -> nil" check to hold. + catwalk.InferenceProviderZAI: { + method: http.MethodGet, + path: "/models", + validStatus: http.StatusOK, + invalidStatus: http.StatusUnauthorized, + authHeader: "Authorization", + authValue: "Bearer sk-test", + }, + } + + for id, tc := range audit { + t.Run(string(id), func(t *testing.T) { + t.Parallel() + + // 1) Valid path. + srv, captured := newCaptureServer(t, tc.validStatus) + c := &ProviderConfig{ + ID: string(id), + Type: catwalk.TypeOpenAICompat, + BaseURL: srv.URL, + APIKey: "sk-test", + } + require.NoError(t, c.TestConnection(IdentityResolver())) + require.Equal(t, tc.method, captured.method, "audit: wrong method for %s", id) + require.Equal(t, tc.path, captured.path, "audit: wrong path for %s", id) + require.Equal(t, tc.authValue, captured.headers.Get(tc.authHeader), + "audit: wrong auth header for %s", id) + + // 2) Invalid path. + srv2, _ := newCaptureServer(t, tc.invalidStatus) + c2 := &ProviderConfig{ + ID: string(id), + Type: catwalk.TypeOpenAICompat, + BaseURL: srv2.URL, + APIKey: "sk-test", + } + err := c2.TestConnection(IdentityResolver()) + require.Error(t, err, "audit: %s must reject %d as invalid", id, tc.invalidStatus) + require.NotErrorIs(t, err, ErrValidationUnsupported, + "audit: %s must not leak ErrValidationUnsupported on %d", id, tc.invalidStatus) + }) + } + + // Sanity: every provider that currently enters the openai-compat chat + // probe path must appear in the audit. This guards against a future + // refactor silently adding a provider without test coverage. + chatProbeProviders := []catwalk.InferenceProvider{ + catwalk.InferenceAIHubMix, + catwalk.InferenceProviderAvian, + catwalk.InferenceProviderCortecs, + catwalk.InferenceProviderHuggingFace, + catwalk.InferenceProviderIoNet, + catwalk.InferenceProviderOpenCodeGo, + catwalk.InferenceProviderOpenCodeZen, + catwalk.InferenceProviderQiniuCloud, + catwalk.InferenceProviderSynthetic, + } + for _, id := range chatProbeProviders { + _, ok := audit[id] + require.True(t, ok, "audit table missing entry for %s", id) + } +} + +// auditCase pins the expected probe shape for a given provider. +type auditCase struct { + method string + path string + // validStatus is a response code the probe must translate to + // "validated" (nil error). + validStatus int + // invalidStatus is a response code the probe must translate to an + // invalid-key error (not ErrValidationUnsupported). + invalidStatus int + // authHeader is the name of the header the probe uses to present + // the key. + authHeader string + authValue string +} + +func openaiCompatAuditCase() auditCase { + return auditCase{ + method: http.MethodPost, + path: "/chat/completions", + validStatus: http.StatusBadRequest, + invalidStatus: http.StatusUnauthorized, + authHeader: "Authorization", + authValue: "Bearer sk-test", + } +} + +func openaiCompatModelsAuditCase() auditCase { + return auditCase{ + method: http.MethodGet, + path: "/models", + validStatus: http.StatusOK, + invalidStatus: http.StatusUnauthorized, + authHeader: "Authorization", + authValue: "Bearer sk-test", + } +} + +func TestTestConnectionNetworkErrorIsNotInvalidKey(t *testing.T) { + t.Parallel() + + // Start and immediately close a server so the next request fails at the + // TCP layer. That should produce a non-nil error that is *not* + // ErrValidationUnsupported (transport errors still surface). + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + srv.Close() + c := &ProviderConfig{ + ID: string(catwalk.InferenceProviderOpenAI), + Type: catwalk.TypeOpenAI, + BaseURL: srv.URL, + APIKey: "sk-test", + } + err := c.TestConnection(IdentityResolver()) + require.Error(t, err) + // The error message should mention the provider so users see a useful + // hint, even though we can't classify the status code. + require.True(t, strings.Contains(err.Error(), "openai") || errors.Is(err, ErrValidationUnsupported)) +} diff --git a/internal/ui/dialog/api_key_input.go b/internal/ui/dialog/api_key_input.go index 1bb232a23fa61147ae291fc3ece6a17c402e5afb..8938d1d2861b9fc2022a1a921d617e1c57612d9a 100644 --- a/internal/ui/dialog/api_key_input.go +++ b/internal/ui/dialog/api_key_input.go @@ -1,7 +1,9 @@ package dialog import ( + "errors" "fmt" + "maps" "strings" "time" @@ -25,6 +27,10 @@ const ( APIKeyInputStateInitial APIKeyInputState = iota APIKeyInputStateVerifying APIKeyInputStateVerified + // APIKeyInputStateUnverified indicates the key was saved but the + // provider does not expose a deterministic validation probe, so + // authentication could not be proven. + APIKeyInputStateUnverified APIKeyInputStateError ) @@ -128,7 +134,7 @@ func (m *APIKeyInput) HandleMsg(msg tea.Msg) Action { // do nothing case key.Matches(msg, m.keyMap.Close): switch m.state { - case APIKeyInputStateVerified: + case APIKeyInputStateVerified, APIKeyInputStateUnverified: return m.saveKeyAndContinue() default: return ActionClose{} @@ -137,7 +143,7 @@ func (m *APIKeyInput) HandleMsg(msg tea.Msg) Action { switch m.state { case APIKeyInputStateInitial, APIKeyInputStateError: return ActionChangeAPIKeyState{State: APIKeyInputStateVerifying} - case APIKeyInputStateVerified: + case APIKeyInputStateVerified, APIKeyInputStateUnverified: return m.saveKeyAndContinue() } default: @@ -219,6 +225,8 @@ func (m *APIKeyInput) dialogTitle() string { return textStyle.Render("Verifying your ") + accentStyle.Render(fmt.Sprintf("%s Key", m.provider.Name)) + textStyle.Render("...") case APIKeyInputStateVerified: return accentStyle.Render(fmt.Sprintf("%s Key", m.provider.Name)) + textStyle.Render(" validated.") + case APIKeyInputStateUnverified: + return accentStyle.Render(fmt.Sprintf("%s Key", m.provider.Name)) + textStyle.Render(" saved (not verified).") case APIKeyInputStateError: return errorStyle.Render("Invalid ") + accentStyle.Render(fmt.Sprintf("%s Key", m.provider.Name)) + errorStyle.Render(". Try again?") } @@ -240,7 +248,7 @@ func (m *APIKeyInput) inputView() string { m.input.Prompt = m.spinner.View() m.input.SetStyles(ts) m.input.Blur() - case APIKeyInputStateVerified: + case APIKeyInputStateVerified, APIKeyInputStateUnverified: ts := t.TextInput ts.Blurred.Prompt = ts.Focused.Prompt @@ -284,13 +292,7 @@ func (m *APIKeyInput) ShortHelp() []key.Binding { func (m *APIKeyInput) verifyAPIKey() tea.Msg { start := time.Now() - providerConfig := config.ProviderConfig{ - ID: string(m.provider.ID), - Name: m.provider.Name, - APIKey: m.input.Value(), - Type: m.provider.Type, - BaseURL: m.provider.APIEndpoint, - } + providerConfig := providerConfigForVerify(m.provider, m.input.Value()) err := providerConfig.TestConnection(m.com.Workspace.Resolver()) // intentionally wait for at least 750ms to make sure the user sees the spinner @@ -300,10 +302,41 @@ func (m *APIKeyInput) verifyAPIKey() tea.Msg { time.Sleep(minimum - elapsed) } - if err == nil { - return ActionChangeAPIKeyState{APIKeyInputStateVerified} + return ActionChangeAPIKeyState{State: apiKeyStateForVerifyErr(err)} +} + +// providerConfigForVerify builds the [config.ProviderConfig] used to probe a +// provider's API key from the dialog. In particular it propagates the +// provider's [catwalk.Provider.DefaultHeaders] into [config.ProviderConfig.ExtraHeaders] +// so validation probes carry any headers the provider expects (e.g. routing +// or tenant hints) — matching the behaviour used for real inference traffic. +func providerConfigForVerify(provider catwalk.Provider, apiKey string) config.ProviderConfig { + cfg := config.ProviderConfig{ + ID: string(provider.ID), + Name: provider.Name, + APIKey: apiKey, + Type: provider.Type, + BaseURL: provider.APIEndpoint, + } + if len(provider.DefaultHeaders) > 0 { + cfg.ExtraHeaders = make(map[string]string, len(provider.DefaultHeaders)) + maps.Copy(cfg.ExtraHeaders, provider.DefaultHeaders) + } + return cfg +} + +// apiKeyStateForVerifyErr maps a [config.ProviderConfig.TestConnection] error +// to the dialog state the UI should transition into. Extracted so the +// mapping can be unit-tested without spinning up a full [common.Common]. +func apiKeyStateForVerifyErr(err error) APIKeyInputState { + switch { + case err == nil: + return APIKeyInputStateVerified + case errors.Is(err, config.ErrValidationUnsupported): + return APIKeyInputStateUnverified + default: + return APIKeyInputStateError } - return ActionChangeAPIKeyState{APIKeyInputStateError} } func (m *APIKeyInput) saveKeyAndContinue() Action { diff --git a/internal/ui/dialog/api_key_input_test.go b/internal/ui/dialog/api_key_input_test.go new file mode 100644 index 0000000000000000000000000000000000000000..bc54a51019dc9552915f3246b0c9e281060565b8 --- /dev/null +++ b/internal/ui/dialog/api_key_input_test.go @@ -0,0 +1,162 @@ +package dialog + +import ( + "errors" + "fmt" + "io" + "net/http" + "net/http/httptest" + "testing" + + "charm.land/catwalk/pkg/catwalk" + "github.com/charmbracelet/crush/internal/config" + "github.com/stretchr/testify/require" +) + +// TestAPIKeyStateForVerifyErr pins the mapping between TestConnection errors +// and the dialog state the UI should transition into. In particular, an +// [config.ErrValidationUnsupported] error must yield the unverified state +// (so the UI shows "saved (not verified)" instead of "invalid"). +func TestAPIKeyStateForVerifyErr(t *testing.T) { + t.Parallel() + + tests := map[string]struct { + err error + want APIKeyInputState + }{ + "nilIsVerified": { + err: nil, + want: APIKeyInputStateVerified, + }, + "unsupportedIsUnverified": { + err: config.ErrValidationUnsupported, + want: APIKeyInputStateUnverified, + }, + "wrappedUnsupportedIsUnverified": { + err: fmt.Errorf("probing provider: %w", config.ErrValidationUnsupported), + want: APIKeyInputStateUnverified, + }, + "plainErrorIsError": { + err: errors.New("bad key"), + want: APIKeyInputStateError, + }, + } + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + t.Parallel() + require.Equal(t, tc.want, apiKeyStateForVerifyErr(tc.err)) + }) + } +} + +// TestProviderConfigForVerifyPropagatesDefaultHeaders locks in the UI +// contract that any catwalk-declared DefaultHeaders are copied into the +// ProviderConfig.ExtraHeaders used for the validation probe. Without this, +// providers that require routing/tenant headers (e.g. DefaultHeaders +// supplied by the catwalk provider definition) would probe with a stripped +// header set and potentially be misclassified as "not verified". +func TestProviderConfigForVerifyPropagatesDefaultHeaders(t *testing.T) { + t.Parallel() + + provider := catwalk.Provider{ + ID: "test-provider", + Name: "Test Provider", + Type: catwalk.TypeOpenAICompat, + APIEndpoint: "https://example.invalid", + DefaultHeaders: map[string]string{ + "X-Tenant": "acme", + "X-Route": "primary", + "X-Shared": "from-default", + "User-Agent": "crush-test", + }, + } + cfg := providerConfigForVerify(provider, "sk-test") + + require.Equal(t, string(provider.ID), cfg.ID) + require.Equal(t, provider.Name, cfg.Name) + require.Equal(t, provider.Type, cfg.Type) + require.Equal(t, provider.APIEndpoint, cfg.BaseURL) + require.Equal(t, "sk-test", cfg.APIKey) + require.Equal(t, provider.DefaultHeaders, cfg.ExtraHeaders, + "DefaultHeaders must be propagated to ExtraHeaders") + + // Mutating the returned config must not leak back into the provider + // definition (the dialog reuses the provider value across retries). + cfg.ExtraHeaders["X-Tenant"] = "mutated" + require.Equal(t, "acme", provider.DefaultHeaders["X-Tenant"], + "providerConfigForVerify must copy DefaultHeaders, not alias them") +} + +func TestProviderConfigForVerifyWithNoDefaultHeaders(t *testing.T) { + t.Parallel() + + provider := catwalk.Provider{ + ID: "test-provider", + Type: catwalk.TypeOpenAICompat, + APIEndpoint: "https://example.invalid", + } + cfg := providerConfigForVerify(provider, "sk-test") + require.Nil(t, cfg.ExtraHeaders, + "no DefaultHeaders should yield no ExtraHeaders allocation") +} + +// TestProviderConfigForVerifyHeadersHitTheWire is an end-to-end UI-level +// check: after providerConfigForVerify builds the probe config, calling +// TestConnection against a local server must deliver the DefaultHeaders on +// the outbound request. This guards against silent regressions where the +// header map is dropped between the dialog and the HTTP layer. +func TestProviderConfigForVerifyHeadersHitTheWire(t *testing.T) { + t.Parallel() + + var captured http.Header + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + captured = r.Header.Clone() + _, _ = io.Copy(io.Discard, r.Body) + // 400 on the malformed-body chat-completions probe means "auth + // passed, schema rejected" for the Synthetic-style override. + w.WriteHeader(http.StatusBadRequest) + })) + t.Cleanup(srv.Close) + + provider := catwalk.Provider{ + ID: catwalk.InferenceProviderSynthetic, + Name: "Synthetic", + Type: catwalk.TypeOpenAICompat, + APIEndpoint: srv.URL, + DefaultHeaders: map[string]string{ + "X-Tenant": "acme", + "X-Route": "primary", + }, + } + cfg := providerConfigForVerify(provider, "sk-test") + require.NoError(t, cfg.TestConnection(config.IdentityResolver())) + + require.NotNil(t, captured, "probe must have reached the test server") + require.Equal(t, "acme", captured.Get("X-Tenant")) + require.Equal(t, "primary", captured.Get("X-Route")) + // Probe-defined headers should still be present alongside the + // DefaultHeaders. + require.Equal(t, "Bearer sk-test", captured.Get("Authorization")) +} + +// TestAPIKeyInputStatesAreDistinct guards against someone accidentally making +// APIKeyInputStateUnverified equal to one of the other states (which would +// silently collapse the "saved (not verified)" path onto "validated" or +// "invalid"). +func TestAPIKeyInputStatesAreDistinct(t *testing.T) { + t.Parallel() + + states := []APIKeyInputState{ + APIKeyInputStateInitial, + APIKeyInputStateVerifying, + APIKeyInputStateVerified, + APIKeyInputStateUnverified, + APIKeyInputStateError, + } + seen := map[APIKeyInputState]struct{}{} + for _, s := range states { + _, dup := seen[s] + require.False(t, dup, "state %d declared twice", s) + seen[s] = struct{}{} + } +}