@@ -1,10 +1,12 @@
package config
import (
+ "bytes"
"cmp"
"context"
"errors"
"fmt"
+ "io"
"log/slog"
"maps"
"net/http"
@@ -571,100 +573,320 @@ func (c *Config) SetupAgents() {
c.Agents = agents
}
-func (c *ProviderConfig) TestConnection(resolver VariableResolver) error {
- var (
- providerID = catwalk.InferenceProvider(c.ID)
- testURL = ""
- headers = make(map[string]string)
- apiKey, _ = resolver.ResolveValue(c.APIKey)
- )
+// ErrValidationUnsupported is returned from [ProviderConfig.TestConnection]
+// when the provider does not expose a deterministic endpoint that proves API
+// key authentication without performing inference. Callers should treat this
+// as "saved but not verified" rather than as a validation failure.
+var ErrValidationUnsupported = errors.New("provider does not expose a deterministic validation probe")
+
+// validationProbe describes a single HTTP request used to prove authentication
+// for a given provider configuration.
+type validationProbe struct {
+ method string
+ url string
+ headers map[string]string
+ body []byte
+ classify func(statusCode int) error
+}
+
+// classifyAuthGated treats the probe endpoint as one that is expected to
+// return 200 with a valid key and 401/403 with an invalid one. Any other
+// status is considered non-deterministic and reported as unsupported so the
+// UI can show "not verified" instead of a misleading "invalid key".
+func classifyAuthGated(c *ProviderConfig) func(int) error {
+ return func(status int) error {
+ switch status {
+ case http.StatusOK:
+ return nil
+ case http.StatusUnauthorized, http.StatusForbidden:
+ return fmt.Errorf("failed to connect to provider %s: %s", c.ID, http.StatusText(status))
+ default:
+ return ErrValidationUnsupported
+ }
+ }
+}
- switch providerID {
- case catwalk.InferenceProviderMiniMax, catwalk.InferenceProviderMiniMaxChina:
- // NOTE: MiniMax has no good endpoint we can use to validate the API key.
- return nil
+// classifyOpenAIChatMalformed classifies responses from a deliberately
+// malformed POST {baseURL}/chat/completions probe. On most OpenAI-compatible
+// gateways authentication happens before schema validation, so 401/403 means
+// the key is bad while 400/422 means the key was accepted and only the body
+// was rejected. Anything else is treated as unsupported / transient.
+func classifyOpenAIChatMalformed(c *ProviderConfig) func(int) error {
+ return func(status int) error {
+ switch status {
+ case http.StatusUnauthorized, http.StatusForbidden:
+ return fmt.Errorf("failed to connect to provider %s: %s", c.ID, http.StatusText(status))
+ case http.StatusBadRequest, http.StatusUnprocessableEntity:
+ return nil
+ default:
+ return ErrValidationUnsupported
+ }
}
+}
- switch c.Type {
- case catwalk.TypeOpenAI, catwalk.TypeOpenAICompat, catwalk.TypeOpenRouter:
- baseURL, _ := resolver.ResolveValue(c.BaseURL)
- baseURL = cmp.Or(baseURL, "https://api.openai.com/v1")
-
- switch providerID {
- case catwalk.InferenceProviderOpenRouter:
- testURL = baseURL + "/credits"
- case catwalk.InferenceProviderOpenCodeGo:
- testURL = strings.Replace(baseURL, "/go", "", 1) + "/models"
+// classifyGoogleModels classifies responses from Google's
+// `/v1beta/models?key=…` probe. Google returns 400 INVALID_ARGUMENT for a
+// malformed or unknown API key, so 400/401/403 all indicate an invalid key.
+func classifyGoogleModels(c *ProviderConfig) func(int) error {
+ return func(status int) error {
+ switch status {
+ case http.StatusOK:
+ return nil
+ case http.StatusBadRequest, http.StatusUnauthorized, http.StatusForbidden:
+ return fmt.Errorf("failed to connect to provider %s: %s", c.ID, http.StatusText(status))
default:
- testURL = baseURL + "/models"
+ return ErrValidationUnsupported
}
+ }
+}
- headers["Authorization"] = "Bearer " + apiKey
- case catwalk.TypeAnthropic:
- baseURL, _ := resolver.ResolveValue(c.BaseURL)
- baseURL = cmp.Or(baseURL, "https://api.anthropic.com/v1")
+// classifyZAIModels preserves the historical ZAI-specific behaviour: the
+// `/models` endpoint returns a variety of non-200 statuses even with a valid
+// key, but reliably returns 401 when the key is bad. Treat 401 as invalid
+// and anything else as valid (the endpoint is authoritative about bad keys
+// but noisy about everything else).
+func classifyZAIModels(c *ProviderConfig) func(int) error {
+ return func(status int) error {
+ if status == http.StatusUnauthorized {
+ return fmt.Errorf("failed to connect to provider %s: %s", c.ID, http.StatusText(status))
+ }
+ return nil
+ }
+}
- switch providerID {
- case catwalk.InferenceKimiCoding:
- testURL = baseURL + "/v1/models"
- default:
- testURL = baseURL + "/models"
+// openaiCompatModelsAllowlist lists openai-compat providers whose `/models`
+// endpoint is known to authenticate the caller (i.e. return 401/403 for a
+// bad key rather than 200 with a public listing). New openai-compat
+// providers should NOT be added here unless their `/models` behaviour has
+// been confirmed to gate on auth — otherwise they should use the malformed
+// chat-completions probe or return [ErrValidationUnsupported].
+var openaiCompatModelsAllowlist = map[catwalk.InferenceProvider]struct{}{
+ "deepseek": {},
+ catwalk.InferenceProviderGROQ: {},
+ catwalk.InferenceProviderXAI: {},
+ catwalk.InferenceProviderZhipu: {},
+ catwalk.InferenceProviderZhipuCoding: {},
+ catwalk.InferenceProviderCerebras: {},
+ catwalk.InferenceProviderNebius: {},
+ catwalk.InferenceProviderCopilot: {},
+}
+
+// openaiCompatChatProbe builds a malformed-body POST /chat/completions probe
+// for OpenAI-compatible providers whose chat-completions endpoint is known to
+// gate on auth before validating the request body.
+func openaiCompatChatProbe(c *ProviderConfig, baseURL, apiKey string) (*validationProbe, error) {
+ if baseURL == "" {
+ return nil, ErrValidationUnsupported
+ }
+ return &validationProbe{
+ method: http.MethodPost,
+ url: baseURL + "/chat/completions",
+ headers: map[string]string{
+ "Authorization": "Bearer " + apiKey,
+ "Content-Type": "application/json",
+ },
+ // Intentionally malformed: required fields missing so the gateway
+ // rejects the payload after authenticating the caller.
+ body: []byte(`{"__crush_probe__": true}`),
+ classify: classifyOpenAIChatMalformed(c),
+ }, nil
+}
+
+// buildValidationProbe returns the probe to use for this provider, or a
+// sentinel error if verification is impossible without performing inference.
+// A nil probe with a nil error means "the key is valid by virtue of its
+// format and no network probe is necessary" (e.g. Bedrock/Vercel prefix
+// checks).
+func (c *ProviderConfig) buildValidationProbe(resolver VariableResolver) (*validationProbe, error) {
+ providerID := catwalk.InferenceProvider(c.ID)
+ apiKey, _ := resolver.ResolveValue(c.APIKey)
+ baseURL, _ := resolver.ResolveValue(c.BaseURL)
+
+ // Provider-ID-specific probes take precedence over type-based defaults.
+ switch providerID {
+ case catwalk.InferenceProviderMiniMax, catwalk.InferenceProviderMiniMaxChina:
+ base := cmp.Or(baseURL, "https://api.minimax.io/anthropic")
+ return &validationProbe{
+ method: http.MethodGet,
+ url: base + "/v1/models",
+ headers: map[string]string{
+ "x-api-key": apiKey,
+ "anthropic-version": "2023-06-01",
+ },
+ classify: classifyAuthGated(c),
+ }, nil
+ case catwalk.InferenceProviderVenice:
+ base := cmp.Or(baseURL, "https://api.venice.ai/api/v1")
+ return &validationProbe{
+ method: http.MethodGet,
+ url: base + "/api_keys/rate_limits",
+ headers: map[string]string{
+ "Authorization": "Bearer " + apiKey,
+ },
+ classify: classifyAuthGated(c),
+ }, nil
+ case catwalk.InferenceAIHubMix,
+ catwalk.InferenceProviderAvian,
+ catwalk.InferenceProviderCortecs,
+ catwalk.InferenceProviderHuggingFace,
+ catwalk.InferenceProviderIoNet,
+ catwalk.InferenceProviderOpenCodeGo,
+ catwalk.InferenceProviderOpenCodeZen,
+ catwalk.InferenceProviderQiniuCloud,
+ catwalk.InferenceProviderSynthetic:
+ return openaiCompatChatProbe(c, baseURL, apiKey)
+ case catwalk.InferenceProviderChutes, catwalk.InferenceProviderNeuralwatt:
+ // These providers have been observed to return ambiguous responses
+ // for unauthenticated requests, so we cannot safely validate.
+ return nil, ErrValidationUnsupported
+ case catwalk.InferenceProviderZAI:
+ // ZAI's `/models` endpoint is authoritative about bad keys (always
+ // 401) but returns assorted non-200 statuses for valid keys, so it
+ // needs its own classifier.
+ base := baseURL
+ if base == "" {
+ return nil, ErrValidationUnsupported
}
+ return &validationProbe{
+ method: http.MethodGet,
+ url: base + "/models",
+ headers: map[string]string{
+ "Authorization": "Bearer " + apiKey,
+ },
+ classify: classifyZAIModels(c),
+ }, nil
+ }
- headers["x-api-key"] = apiKey
- headers["anthropic-version"] = "2023-06-01"
+ // Type-based defaults for providers without an explicit override.
+ switch c.Type {
+ case catwalk.TypeOpenAI:
+ base := cmp.Or(baseURL, "https://api.openai.com/v1")
+ return &validationProbe{
+ method: http.MethodGet,
+ url: base + "/models",
+ headers: map[string]string{
+ "Authorization": "Bearer " + apiKey,
+ },
+ classify: classifyAuthGated(c),
+ }, nil
+ case catwalk.TypeOpenRouter:
+ base := cmp.Or(baseURL, "https://openrouter.ai/api/v1")
+ return &validationProbe{
+ method: http.MethodGet,
+ url: base + "/credits",
+ headers: map[string]string{
+ "Authorization": "Bearer " + apiKey,
+ },
+ classify: classifyAuthGated(c),
+ }, nil
+ case catwalk.TypeAnthropic:
+ base := cmp.Or(baseURL, "https://api.anthropic.com/v1")
+ testURL := base + "/models"
+ if providerID == catwalk.InferenceKimiCoding {
+ testURL = base + "/v1/models"
+ }
+ return &validationProbe{
+ method: http.MethodGet,
+ url: testURL,
+ headers: map[string]string{
+ "x-api-key": apiKey,
+ "anthropic-version": "2023-06-01",
+ },
+ classify: classifyAuthGated(c),
+ }, nil
case catwalk.TypeGoogle:
- baseURL, _ := resolver.ResolveValue(c.BaseURL)
- baseURL = cmp.Or(baseURL, "https://generativelanguage.googleapis.com")
- testURL = baseURL + "/v1beta/models?key=" + url.QueryEscape(apiKey)
+ base := cmp.Or(baseURL, "https://generativelanguage.googleapis.com")
+ return &validationProbe{
+ method: http.MethodGet,
+ url: base + "/v1beta/models?key=" + url.QueryEscape(apiKey),
+ classify: classifyGoogleModels(c),
+ }, nil
case catwalk.TypeBedrock:
// NOTE: Bedrock has a `/foundation-models` endpoint that we could in
// theory use, but apparently the authorization is region-specific,
- // so it's not so trivial.
- if strings.HasPrefix(apiKey, "ABSK") { // Bedrock API keys
- return nil
+ // so it's not so trivial. Fall back to a prefix check.
+ if strings.HasPrefix(apiKey, "ABSK") {
+ return nil, nil
}
- return errors.New("not a valid bedrock api key")
+ return nil, errors.New("not a valid bedrock api key")
case catwalk.TypeVercel:
// NOTE: Vercel does not validate API keys on the `/models` endpoint.
- if strings.HasPrefix(apiKey, "vck_") { // Vercel API keys
- return nil
+ if strings.HasPrefix(apiKey, "vck_") {
+ return nil, nil
}
- return errors.New("not a valid vercel api key")
+ return nil, errors.New("not a valid vercel api key")
+ case catwalk.TypeOpenAICompat:
+ // Generic openai-compat providers often expose a public /models
+ // endpoint, so hitting it proves nothing about the caller's key.
+ // Only providers we've confirmed to gate /models on auth use the
+ // /models probe; everyone else needs an explicit override above or
+ // returns ErrValidationUnsupported.
+ if _, ok := openaiCompatModelsAllowlist[providerID]; !ok {
+ return nil, ErrValidationUnsupported
+ }
+ if baseURL == "" {
+ return nil, ErrValidationUnsupported
+ }
+ return &validationProbe{
+ method: http.MethodGet,
+ url: baseURL + "/models",
+ headers: map[string]string{
+ "Authorization": "Bearer " + apiKey,
+ },
+ classify: classifyAuthGated(c),
+ }, nil
+ }
+
+ return nil, ErrValidationUnsupported
+}
+
+// TestConnection attempts to prove that the configured API key authenticates
+// with the provider. It returns nil on confirmed success, [ErrValidationUnsupported]
+// when the provider has no deterministic validation probe, or a non-nil error
+// describing the validation failure.
+func (c *ProviderConfig) TestConnection(resolver VariableResolver) error {
+ probe, err := c.buildValidationProbe(resolver)
+ if err != nil {
+ return err
+ }
+ if probe == nil {
+ // A nil probe with no error means the configuration was accepted
+ // without needing a network round-trip (e.g. Bedrock/Vercel prefix
+ // checks).
+ return nil
+ }
+ if probe.url == "" {
+ return ErrValidationUnsupported
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
- client := &http.Client{}
- req, err := http.NewRequestWithContext(ctx, "GET", testURL, nil)
+ var body io.Reader
+ if len(probe.body) > 0 {
+ body = bytes.NewReader(probe.body)
+ }
+ req, err := http.NewRequestWithContext(ctx, probe.method, probe.url, body)
if err != nil {
- return fmt.Errorf("failed to create request for provider %s: %w", c.ID, err)
+ // Probe construction failures shouldn't surface as low-signal user
+ // errors; treat them as "cannot verify" instead.
+ return ErrValidationUnsupported
}
- for k, v := range headers {
+ for k, v := range probe.headers {
req.Header.Set(k, v)
}
for k, v := range c.ExtraHeaders {
req.Header.Set(k, v)
}
+ client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
- return fmt.Errorf("failed to create request for provider %s: %w", c.ID, err)
+ return fmt.Errorf("failed to connect to provider %s: %w", c.ID, err)
}
defer resp.Body.Close()
- switch providerID {
- case catwalk.InferenceProviderZAI:
- if resp.StatusCode == http.StatusUnauthorized {
- return fmt.Errorf("failed to connect to provider %s: %s", c.ID, resp.Status)
- }
- default:
- if resp.StatusCode != http.StatusOK {
- return fmt.Errorf("failed to connect to provider %s: %s", c.ID, resp.Status)
- }
- }
- return nil
+ return probe.classify(resp.StatusCode)
}
func resolveEnvs(envs map[string]string) []string {
@@ -0,0 +1,776 @@
+package config
+
+import (
+ "errors"
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "testing"
+
+ "charm.land/catwalk/pkg/catwalk"
+ "github.com/stretchr/testify/require"
+)
+
+type capturedRequest struct {
+ method string
+ path string
+ query string
+ headers http.Header
+ body []byte
+}
+
+func newCaptureServer(t *testing.T, status int) (*httptest.Server, *capturedRequest) {
+ t.Helper()
+ captured := &capturedRequest{}
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ captured.method = r.Method
+ captured.path = r.URL.Path
+ captured.query = r.URL.RawQuery
+ captured.headers = r.Header.Clone()
+ captured.body, _ = io.ReadAll(r.Body)
+ w.WriteHeader(status)
+ }))
+ t.Cleanup(srv.Close)
+ return srv, captured
+}
+
+func TestTestConnectionMiniMaxProbe(t *testing.T) {
+ t.Parallel()
+
+ for _, id := range []catwalk.InferenceProvider{
+ catwalk.InferenceProviderMiniMax,
+ catwalk.InferenceProviderMiniMaxChina,
+ } {
+ t.Run(string(id), func(t *testing.T) {
+ t.Parallel()
+ for name, tc := range map[string]struct {
+ status int
+ wantErr error
+ wantNil bool
+ }{
+ "valid": {status: http.StatusOK, wantNil: true},
+ "invalid401": {status: http.StatusUnauthorized},
+ "invalid403": {status: http.StatusForbidden},
+ "unsupported": {status: http.StatusTeapot, wantErr: ErrValidationUnsupported},
+ } {
+ t.Run(name, func(t *testing.T) {
+ t.Parallel()
+ srv, captured := newCaptureServer(t, tc.status)
+ c := &ProviderConfig{
+ ID: string(id),
+ Type: catwalk.TypeAnthropic,
+ BaseURL: srv.URL,
+ APIKey: "key-abc",
+ }
+ err := c.TestConnection(IdentityResolver())
+ switch {
+ case tc.wantNil:
+ require.NoError(t, err)
+ case tc.wantErr != nil:
+ require.ErrorIs(t, err, tc.wantErr)
+ default:
+ require.Error(t, err)
+ require.NotErrorIs(t, err, ErrValidationUnsupported)
+ }
+ require.Equal(t, http.MethodGet, captured.method)
+ require.Equal(t, "/v1/models", captured.path)
+ require.Equal(t, "key-abc", captured.headers.Get("x-api-key"))
+ require.Equal(t, "2023-06-01", captured.headers.Get("anthropic-version"))
+ })
+ }
+ })
+ }
+}
+
+func TestTestConnectionVeniceProbe(t *testing.T) {
+ t.Parallel()
+
+ tests := map[string]struct {
+ status int
+ wantErr error
+ wantNil bool
+ }{
+ "valid": {status: http.StatusOK, wantNil: true},
+ "invalid401": {status: http.StatusUnauthorized},
+ "invalid403": {status: http.StatusForbidden},
+ "rateLimited": {status: http.StatusTooManyRequests, wantErr: ErrValidationUnsupported},
+ "paymentReq": {status: http.StatusPaymentRequired, wantErr: ErrValidationUnsupported},
+ "transient500": {status: http.StatusInternalServerError, wantErr: ErrValidationUnsupported},
+ }
+ for name, tc := range tests {
+ t.Run(name, func(t *testing.T) {
+ t.Parallel()
+ srv, captured := newCaptureServer(t, tc.status)
+ c := &ProviderConfig{
+ ID: string(catwalk.InferenceProviderVenice),
+ Type: catwalk.TypeOpenAICompat,
+ BaseURL: srv.URL,
+ APIKey: "sk-venice",
+ }
+ err := c.TestConnection(IdentityResolver())
+ switch {
+ case tc.wantNil:
+ require.NoError(t, err)
+ case tc.wantErr != nil:
+ require.ErrorIs(t, err, tc.wantErr)
+ default:
+ require.Error(t, err)
+ require.NotErrorIs(t, err, ErrValidationUnsupported)
+ }
+ require.Equal(t, http.MethodGet, captured.method)
+ require.Equal(t, "/api_keys/rate_limits", captured.path)
+ require.Equal(t, "Bearer sk-venice", captured.headers.Get("Authorization"))
+ })
+ }
+}
+
+func TestTestConnectionOpenAICompatChatProbe(t *testing.T) {
+ t.Parallel()
+
+ providers := []catwalk.InferenceProvider{
+ catwalk.InferenceAIHubMix,
+ catwalk.InferenceProviderAvian,
+ catwalk.InferenceProviderCortecs,
+ catwalk.InferenceProviderHuggingFace,
+ catwalk.InferenceProviderIoNet,
+ catwalk.InferenceProviderOpenCodeGo,
+ catwalk.InferenceProviderOpenCodeZen,
+ catwalk.InferenceProviderQiniuCloud,
+ catwalk.InferenceProviderSynthetic,
+ }
+ for _, id := range providers {
+ t.Run(string(id), func(t *testing.T) {
+ t.Parallel()
+ cases := map[string]struct {
+ status int
+ wantErr error
+ wantNil bool
+ }{
+ "authPassed400": {status: http.StatusBadRequest, wantNil: true},
+ "authPassed422": {status: http.StatusUnprocessableEntity, wantNil: true},
+ "invalid401": {status: http.StatusUnauthorized},
+ "invalid403": {status: http.StatusForbidden},
+ "transient500": {status: http.StatusInternalServerError, wantErr: ErrValidationUnsupported},
+ "unexpected200": {status: http.StatusOK, wantErr: ErrValidationUnsupported},
+ "unexpectedOther": {status: http.StatusTeapot, wantErr: ErrValidationUnsupported},
+ }
+ for name, tc := range cases {
+ t.Run(name, func(t *testing.T) {
+ t.Parallel()
+ srv, captured := newCaptureServer(t, tc.status)
+ c := &ProviderConfig{
+ ID: string(id),
+ Type: catwalk.TypeOpenAICompat,
+ BaseURL: srv.URL,
+ APIKey: "sk-test",
+ }
+ err := c.TestConnection(IdentityResolver())
+ switch {
+ case tc.wantNil:
+ require.NoError(t, err)
+ case tc.wantErr != nil:
+ require.ErrorIs(t, err, tc.wantErr)
+ default:
+ require.Error(t, err)
+ require.NotErrorIs(t, err, ErrValidationUnsupported)
+ }
+ require.Equal(t, http.MethodPost, captured.method)
+ require.Equal(t, "/chat/completions", captured.path)
+ require.Equal(t, "Bearer sk-test", captured.headers.Get("Authorization"))
+ require.Equal(t, "application/json", captured.headers.Get("Content-Type"))
+ require.NotEmpty(t, captured.body)
+ })
+ }
+ })
+ }
+}
+
+func TestTestConnectionUnsupportedProviders(t *testing.T) {
+ t.Parallel()
+
+ for _, id := range []catwalk.InferenceProvider{
+ catwalk.InferenceProviderChutes,
+ catwalk.InferenceProviderNeuralwatt,
+ } {
+ t.Run(string(id), func(t *testing.T) {
+ t.Parallel()
+ c := &ProviderConfig{
+ ID: string(id),
+ Type: catwalk.TypeOpenAICompat,
+ BaseURL: "https://example.invalid",
+ APIKey: "sk-test",
+ }
+ err := c.TestConnection(IdentityResolver())
+ require.ErrorIs(t, err, ErrValidationUnsupported)
+ })
+ }
+}
+
+func TestTestConnectionUnknownOpenAICompatIsUnsupported(t *testing.T) {
+ t.Parallel()
+
+ c := &ProviderConfig{
+ ID: "some-new-openai-compat-provider",
+ Type: catwalk.TypeOpenAICompat,
+ BaseURL: "https://example.invalid",
+ APIKey: "sk-test",
+ }
+ err := c.TestConnection(IdentityResolver())
+ require.ErrorIs(t, err, ErrValidationUnsupported)
+}
+
+func TestTestConnectionEmptyProbeURLIsUnsupported(t *testing.T) {
+ t.Parallel()
+
+ // Chutes has a provider override that returns ErrValidationUnsupported
+ // regardless of configured base URL; this also guards the empty-URL path.
+ c := &ProviderConfig{
+ ID: string(catwalk.InferenceProviderChutes),
+ Type: catwalk.TypeOpenAICompat,
+ APIKey: "sk-test",
+ }
+ err := c.TestConnection(IdentityResolver())
+ require.ErrorIs(t, err, ErrValidationUnsupported)
+}
+
+func TestTestConnectionExtraHeadersAreApplied(t *testing.T) {
+ t.Parallel()
+
+ srv, captured := newCaptureServer(t, http.StatusBadRequest)
+ c := &ProviderConfig{
+ ID: string(catwalk.InferenceProviderSynthetic),
+ Type: catwalk.TypeOpenAICompat,
+ BaseURL: srv.URL,
+ APIKey: "sk-test",
+ ExtraHeaders: map[string]string{
+ "X-Custom-Header": "custom-value",
+ "Authorization": "overridden",
+ },
+ }
+ err := c.TestConnection(IdentityResolver())
+ require.NoError(t, err)
+ require.Equal(t, "custom-value", captured.headers.Get("X-Custom-Header"))
+ // ExtraHeaders are applied after the probe headers, so callers can
+ // override per-provider defaults if necessary.
+ require.Equal(t, "overridden", captured.headers.Get("Authorization"))
+}
+
+func TestTestConnectionOpenAITypeProbesModelsEndpoint(t *testing.T) {
+ t.Parallel()
+
+ srv, captured := newCaptureServer(t, http.StatusOK)
+ c := &ProviderConfig{
+ ID: string(catwalk.InferenceProviderOpenAI),
+ Type: catwalk.TypeOpenAI,
+ BaseURL: srv.URL,
+ APIKey: "sk-openai",
+ }
+ err := c.TestConnection(IdentityResolver())
+ require.NoError(t, err)
+ require.Equal(t, http.MethodGet, captured.method)
+ require.Equal(t, "/models", captured.path)
+ require.Equal(t, "Bearer sk-openai", captured.headers.Get("Authorization"))
+}
+
+func TestTestConnectionOpenRouterProbesCreditsEndpoint(t *testing.T) {
+ t.Parallel()
+
+ srv, captured := newCaptureServer(t, http.StatusOK)
+ c := &ProviderConfig{
+ ID: string(catwalk.InferenceProviderOpenRouter),
+ Type: catwalk.TypeOpenRouter,
+ BaseURL: srv.URL,
+ APIKey: "sk-or",
+ }
+ err := c.TestConnection(IdentityResolver())
+ require.NoError(t, err)
+ require.Equal(t, "/credits", captured.path)
+}
+
+func TestTestConnectionAnthropicTypeProbesModels(t *testing.T) {
+ t.Parallel()
+
+ srv, captured := newCaptureServer(t, http.StatusOK)
+ c := &ProviderConfig{
+ ID: string(catwalk.InferenceProviderAnthropic),
+ Type: catwalk.TypeAnthropic,
+ BaseURL: srv.URL,
+ APIKey: "ak-test",
+ }
+ err := c.TestConnection(IdentityResolver())
+ require.NoError(t, err)
+ require.Equal(t, "/models", captured.path)
+ require.Equal(t, "ak-test", captured.headers.Get("x-api-key"))
+}
+
+func TestTestConnectionKimiCodingUsesV1Models(t *testing.T) {
+ t.Parallel()
+
+ srv, captured := newCaptureServer(t, http.StatusOK)
+ c := &ProviderConfig{
+ ID: string(catwalk.InferenceKimiCoding),
+ Type: catwalk.TypeAnthropic,
+ BaseURL: srv.URL,
+ APIKey: "ak-kimi",
+ }
+ err := c.TestConnection(IdentityResolver())
+ require.NoError(t, err)
+ require.Equal(t, "/v1/models", captured.path)
+}
+
+func TestTestConnectionGoogleIncludesKeyQueryParam(t *testing.T) {
+ t.Parallel()
+
+ srv, captured := newCaptureServer(t, http.StatusOK)
+ c := &ProviderConfig{
+ ID: string(catwalk.InferenceProviderGemini),
+ Type: catwalk.TypeGoogle,
+ BaseURL: srv.URL,
+ APIKey: "google-key",
+ }
+ err := c.TestConnection(IdentityResolver())
+ require.NoError(t, err)
+ require.Equal(t, "/v1beta/models", captured.path)
+ require.Contains(t, captured.query, "key=google-key")
+}
+
+// TestTestConnectionGoogleBadKeyIs400 locks in the fact that Google returns
+// 400 INVALID_ARGUMENT (not 401) for an unknown API key, so 400 must map to
+// "invalid" and never to [ErrValidationUnsupported].
+func TestTestConnectionGoogleBadKeyIs400(t *testing.T) {
+ t.Parallel()
+
+ for name, tc := range map[string]struct {
+ status int
+ wantNil bool
+ wantErr error
+ }{
+ "badKey400": {status: http.StatusBadRequest},
+ "unauth401": {status: http.StatusUnauthorized},
+ "forbidden403": {status: http.StatusForbidden},
+ "ok200": {status: http.StatusOK, wantNil: true},
+ "transient500": {status: http.StatusInternalServerError, wantErr: ErrValidationUnsupported},
+ } {
+ t.Run(name, func(t *testing.T) {
+ t.Parallel()
+ srv, _ := newCaptureServer(t, tc.status)
+ c := &ProviderConfig{
+ ID: string(catwalk.InferenceProviderGemini),
+ Type: catwalk.TypeGoogle,
+ BaseURL: srv.URL,
+ APIKey: "bad-key",
+ }
+ err := c.TestConnection(IdentityResolver())
+ switch {
+ case tc.wantNil:
+ require.NoError(t, err)
+ case tc.wantErr != nil:
+ require.ErrorIs(t, err, tc.wantErr)
+ default:
+ require.Error(t, err)
+ require.NotErrorIs(t, err, ErrValidationUnsupported)
+ }
+ })
+ }
+}
+
+// TestTestConnectionOpenAICompatAllowlistUsesModelsProbe locks in the
+// `/models` probe for openai-compat providers whose /models is known to be
+// auth-gated. These providers must not fall through to
+// [ErrValidationUnsupported].
+func TestTestConnectionOpenAICompatAllowlistUsesModelsProbe(t *testing.T) {
+ t.Parallel()
+
+ providers := []catwalk.InferenceProvider{
+ "deepseek",
+ catwalk.InferenceProviderGROQ,
+ catwalk.InferenceProviderXAI,
+ catwalk.InferenceProviderZhipu,
+ catwalk.InferenceProviderZhipuCoding,
+ catwalk.InferenceProviderCerebras,
+ catwalk.InferenceProviderNebius,
+ catwalk.InferenceProviderCopilot,
+ }
+ for _, id := range providers {
+ t.Run(string(id), func(t *testing.T) {
+ t.Parallel()
+ t.Run("valid", func(t *testing.T) {
+ t.Parallel()
+ srv, captured := newCaptureServer(t, http.StatusOK)
+ c := &ProviderConfig{
+ ID: string(id),
+ Type: catwalk.TypeOpenAICompat,
+ BaseURL: srv.URL,
+ APIKey: "sk-good",
+ }
+ require.NoError(t, c.TestConnection(IdentityResolver()))
+ require.Equal(t, http.MethodGet, captured.method)
+ require.Equal(t, "/models", captured.path)
+ require.Equal(t, "Bearer sk-good", captured.headers.Get("Authorization"))
+ })
+ t.Run("invalid", func(t *testing.T) {
+ t.Parallel()
+ srv, _ := newCaptureServer(t, http.StatusUnauthorized)
+ c := &ProviderConfig{
+ ID: string(id),
+ Type: catwalk.TypeOpenAICompat,
+ BaseURL: srv.URL,
+ APIKey: "sk-bad",
+ }
+ err := c.TestConnection(IdentityResolver())
+ require.Error(t, err)
+ require.NotErrorIs(t, err, ErrValidationUnsupported)
+ })
+ })
+ }
+}
+
+// TestTestConnectionZAIUsesZAIClassifier pins ZAI's historical quirk: /models
+// returns non-200 for valid keys but always 401 for bad keys.
+func TestTestConnectionZAIUsesZAIClassifier(t *testing.T) {
+ t.Parallel()
+
+ for name, tc := range map[string]struct {
+ status int
+ wantNil bool
+ }{
+ "ok200": {status: http.StatusOK, wantNil: true},
+ "other400": {status: http.StatusBadRequest, wantNil: true},
+ "other500": {status: http.StatusInternalServerError, wantNil: true},
+ "badKey401": {status: http.StatusUnauthorized},
+ } {
+ t.Run(name, func(t *testing.T) {
+ t.Parallel()
+ srv, captured := newCaptureServer(t, tc.status)
+ c := &ProviderConfig{
+ ID: string(catwalk.InferenceProviderZAI),
+ Type: catwalk.TypeOpenAICompat,
+ BaseURL: srv.URL,
+ APIKey: "sk-zai",
+ }
+ err := c.TestConnection(IdentityResolver())
+ if tc.wantNil {
+ require.NoError(t, err)
+ } else {
+ require.Error(t, err)
+ require.NotErrorIs(t, err, ErrValidationUnsupported)
+ }
+ require.Equal(t, "/models", captured.path)
+ require.Equal(t, "Bearer sk-zai", captured.headers.Get("Authorization"))
+ })
+ }
+}
+
+func TestTestConnectionBedrockPrefix(t *testing.T) {
+ t.Parallel()
+
+ t.Run("valid", func(t *testing.T) {
+ t.Parallel()
+ c := &ProviderConfig{
+ ID: string(catwalk.InferenceProviderBedrock),
+ Type: catwalk.TypeBedrock,
+ APIKey: "ABSK-secret",
+ }
+ require.NoError(t, c.TestConnection(IdentityResolver()))
+ })
+ t.Run("invalid", func(t *testing.T) {
+ t.Parallel()
+ c := &ProviderConfig{
+ ID: string(catwalk.InferenceProviderBedrock),
+ Type: catwalk.TypeBedrock,
+ APIKey: "nope",
+ }
+ err := c.TestConnection(IdentityResolver())
+ require.Error(t, err)
+ require.NotErrorIs(t, err, ErrValidationUnsupported)
+ })
+}
+
+func TestTestConnectionVercelPrefix(t *testing.T) {
+ t.Parallel()
+
+ t.Run("valid", func(t *testing.T) {
+ t.Parallel()
+ c := &ProviderConfig{
+ ID: string(catwalk.InferenceProviderVercel),
+ Type: catwalk.TypeVercel,
+ APIKey: "vck_abc",
+ }
+ require.NoError(t, c.TestConnection(IdentityResolver()))
+ })
+ t.Run("invalid", func(t *testing.T) {
+ t.Parallel()
+ c := &ProviderConfig{
+ ID: string(catwalk.InferenceProviderVercel),
+ Type: catwalk.TypeVercel,
+ APIKey: "nope",
+ }
+ err := c.TestConnection(IdentityResolver())
+ require.Error(t, err)
+ require.NotErrorIs(t, err, ErrValidationUnsupported)
+ })
+}
+
+// TestTestConnectionPublicModelsAuthGatedChatRegression locks in the core
+// regression from the 2025-10-20 expansion of generic /models validation to
+// openai-compat: a provider whose /models is intentionally public would
+// report any key as "validated" even though /chat/completions actually
+// gates on auth. For every provider we currently mark "validated" via the
+// malformed-body chat probe, this test simulates both endpoints and asserts
+// that:
+//
+// 1. A bad key (401 on /chat/completions) is reported as invalid, not as
+// "validated" — even when /models returns 200 unauthenticated.
+// 2. A good key (400/422 on /chat/completions) is reported as valid.
+// 3. The probe never hits /models for these providers.
+func TestTestConnectionPublicModelsAuthGatedChatRegression(t *testing.T) {
+ t.Parallel()
+
+ providers := []catwalk.InferenceProvider{
+ catwalk.InferenceAIHubMix,
+ catwalk.InferenceProviderAvian,
+ catwalk.InferenceProviderCortecs,
+ catwalk.InferenceProviderHuggingFace,
+ catwalk.InferenceProviderIoNet,
+ catwalk.InferenceProviderOpenCodeGo,
+ catwalk.InferenceProviderOpenCodeZen,
+ catwalk.InferenceProviderQiniuCloud,
+ catwalk.InferenceProviderSynthetic,
+ }
+ for _, id := range providers {
+ t.Run(string(id), func(t *testing.T) {
+ t.Parallel()
+
+ type hits struct {
+ models int
+ chat int
+ }
+ for name, tc := range map[string]struct {
+ chatStatus int
+ wantErr error
+ wantNil bool
+ }{
+ "badKeyIsInvalidNotValidated": {
+ chatStatus: http.StatusUnauthorized,
+ },
+ "goodKeyIsValidated": {
+ chatStatus: http.StatusBadRequest,
+ wantNil: true,
+ },
+ "forbiddenKeyIsInvalid": {
+ chatStatus: http.StatusForbidden,
+ },
+ "schemaFailure422IsValidated": {
+ chatStatus: http.StatusUnprocessableEntity,
+ wantNil: true,
+ },
+ } {
+ t.Run(name, func(t *testing.T) {
+ t.Parallel()
+ h := &hits{}
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ switch r.URL.Path {
+ case "/models":
+ // Simulate a public /models endpoint that
+ // returns 200 regardless of the provided key.
+ h.models++
+ w.WriteHeader(http.StatusOK)
+ case "/chat/completions":
+ h.chat++
+ w.WriteHeader(tc.chatStatus)
+ default:
+ w.WriteHeader(http.StatusNotFound)
+ }
+ }))
+ t.Cleanup(srv.Close)
+
+ c := &ProviderConfig{
+ ID: string(id),
+ Type: catwalk.TypeOpenAICompat,
+ BaseURL: srv.URL,
+ APIKey: "sk-test",
+ }
+ err := c.TestConnection(IdentityResolver())
+
+ if tc.wantNil {
+ require.NoError(t, err, "expected %s to validate on %d", id, tc.chatStatus)
+ } else {
+ require.Error(t, err, "expected %s to reject on %d", id, tc.chatStatus)
+ require.NotErrorIs(t, err, ErrValidationUnsupported)
+ }
+ require.Equal(t, 0, h.models, "probe must not rely on public /models for %s", id)
+ require.Equal(t, 1, h.chat, "probe must hit /chat/completions for %s", id)
+ })
+ }
+ })
+ }
+}
+
+// TestTestConnectionOpenAICompatProviderAudit is an audit table that pins the
+// full set of openai-compat providers currently exposed as "validated" (i.e.
+// TestConnection can return nil on some response) and documents the exact
+// probe each uses. Adding a new openai-compat provider to the validated set
+// MUST update this table; this prevents silent drift back into the
+// "assume /models proves auth" bug class.
+//
+// Providers not listed here either:
+// - use a different Type (TypeOpenAI / TypeAnthropic / TypeGoogle / ...);
+// - are explicitly gated behind ErrValidationUnsupported (chutes, neuralwatt,
+// and every unknown openai-compat provider).
+func TestTestConnectionOpenAICompatProviderAudit(t *testing.T) {
+ t.Parallel()
+
+ audit := map[catwalk.InferenceProvider]auditCase{
+ catwalk.InferenceProviderVenice: {
+ method: http.MethodGet,
+ path: "/api_keys/rate_limits",
+ validStatus: http.StatusOK,
+ invalidStatus: http.StatusUnauthorized,
+ authHeader: "Authorization",
+ authValue: "Bearer sk-test",
+ },
+ catwalk.InferenceAIHubMix: openaiCompatAuditCase(),
+ catwalk.InferenceProviderAvian: openaiCompatAuditCase(),
+ catwalk.InferenceProviderCortecs: openaiCompatAuditCase(),
+ catwalk.InferenceProviderHuggingFace: openaiCompatAuditCase(),
+ catwalk.InferenceProviderIoNet: openaiCompatAuditCase(),
+ catwalk.InferenceProviderOpenCodeGo: openaiCompatAuditCase(),
+ catwalk.InferenceProviderOpenCodeZen: openaiCompatAuditCase(),
+ catwalk.InferenceProviderQiniuCloud: openaiCompatAuditCase(),
+ catwalk.InferenceProviderSynthetic: openaiCompatAuditCase(),
+ // openai-compat providers with auth-gated /models (allowlist).
+ "deepseek": openaiCompatModelsAuditCase(),
+ catwalk.InferenceProviderGROQ: openaiCompatModelsAuditCase(),
+ catwalk.InferenceProviderXAI: openaiCompatModelsAuditCase(),
+ catwalk.InferenceProviderZhipu: openaiCompatModelsAuditCase(),
+ catwalk.InferenceProviderZhipuCoding: openaiCompatModelsAuditCase(),
+ catwalk.InferenceProviderCerebras: openaiCompatModelsAuditCase(),
+ catwalk.InferenceProviderNebius: openaiCompatModelsAuditCase(),
+ catwalk.InferenceProviderCopilot: openaiCompatModelsAuditCase(),
+ // ZAI uses the /models endpoint but with its own classifier that
+ // only treats 401 as invalid. Its valid path must therefore be 200
+ // here for the audit's generic "valid -> nil" check to hold.
+ catwalk.InferenceProviderZAI: {
+ method: http.MethodGet,
+ path: "/models",
+ validStatus: http.StatusOK,
+ invalidStatus: http.StatusUnauthorized,
+ authHeader: "Authorization",
+ authValue: "Bearer sk-test",
+ },
+ }
+
+ for id, tc := range audit {
+ t.Run(string(id), func(t *testing.T) {
+ t.Parallel()
+
+ // 1) Valid path.
+ srv, captured := newCaptureServer(t, tc.validStatus)
+ c := &ProviderConfig{
+ ID: string(id),
+ Type: catwalk.TypeOpenAICompat,
+ BaseURL: srv.URL,
+ APIKey: "sk-test",
+ }
+ require.NoError(t, c.TestConnection(IdentityResolver()))
+ require.Equal(t, tc.method, captured.method, "audit: wrong method for %s", id)
+ require.Equal(t, tc.path, captured.path, "audit: wrong path for %s", id)
+ require.Equal(t, tc.authValue, captured.headers.Get(tc.authHeader),
+ "audit: wrong auth header for %s", id)
+
+ // 2) Invalid path.
+ srv2, _ := newCaptureServer(t, tc.invalidStatus)
+ c2 := &ProviderConfig{
+ ID: string(id),
+ Type: catwalk.TypeOpenAICompat,
+ BaseURL: srv2.URL,
+ APIKey: "sk-test",
+ }
+ err := c2.TestConnection(IdentityResolver())
+ require.Error(t, err, "audit: %s must reject %d as invalid", id, tc.invalidStatus)
+ require.NotErrorIs(t, err, ErrValidationUnsupported,
+ "audit: %s must not leak ErrValidationUnsupported on %d", id, tc.invalidStatus)
+ })
+ }
+
+ // Sanity: every provider that currently enters the openai-compat chat
+ // probe path must appear in the audit. This guards against a future
+ // refactor silently adding a provider without test coverage.
+ chatProbeProviders := []catwalk.InferenceProvider{
+ catwalk.InferenceAIHubMix,
+ catwalk.InferenceProviderAvian,
+ catwalk.InferenceProviderCortecs,
+ catwalk.InferenceProviderHuggingFace,
+ catwalk.InferenceProviderIoNet,
+ catwalk.InferenceProviderOpenCodeGo,
+ catwalk.InferenceProviderOpenCodeZen,
+ catwalk.InferenceProviderQiniuCloud,
+ catwalk.InferenceProviderSynthetic,
+ }
+ for _, id := range chatProbeProviders {
+ _, ok := audit[id]
+ require.True(t, ok, "audit table missing entry for %s", id)
+ }
+}
+
+// auditCase pins the expected probe shape for a given provider.
+type auditCase struct {
+ method string
+ path string
+ // validStatus is a response code the probe must translate to
+ // "validated" (nil error).
+ validStatus int
+ // invalidStatus is a response code the probe must translate to an
+ // invalid-key error (not ErrValidationUnsupported).
+ invalidStatus int
+ // authHeader is the name of the header the probe uses to present
+ // the key.
+ authHeader string
+ authValue string
+}
+
+func openaiCompatAuditCase() auditCase {
+ return auditCase{
+ method: http.MethodPost,
+ path: "/chat/completions",
+ validStatus: http.StatusBadRequest,
+ invalidStatus: http.StatusUnauthorized,
+ authHeader: "Authorization",
+ authValue: "Bearer sk-test",
+ }
+}
+
+func openaiCompatModelsAuditCase() auditCase {
+ return auditCase{
+ method: http.MethodGet,
+ path: "/models",
+ validStatus: http.StatusOK,
+ invalidStatus: http.StatusUnauthorized,
+ authHeader: "Authorization",
+ authValue: "Bearer sk-test",
+ }
+}
+
+func TestTestConnectionNetworkErrorIsNotInvalidKey(t *testing.T) {
+ t.Parallel()
+
+ // Start and immediately close a server so the next request fails at the
+ // TCP layer. That should produce a non-nil error that is *not*
+ // ErrValidationUnsupported (transport errors still surface).
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusOK)
+ }))
+ srv.Close()
+ c := &ProviderConfig{
+ ID: string(catwalk.InferenceProviderOpenAI),
+ Type: catwalk.TypeOpenAI,
+ BaseURL: srv.URL,
+ APIKey: "sk-test",
+ }
+ err := c.TestConnection(IdentityResolver())
+ require.Error(t, err)
+ // The error message should mention the provider so users see a useful
+ // hint, even though we can't classify the status code.
+ require.True(t, strings.Contains(err.Error(), "openai") || errors.Is(err, ErrValidationUnsupported))
+}
@@ -0,0 +1,162 @@
+package dialog
+
+import (
+ "errors"
+ "fmt"
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "charm.land/catwalk/pkg/catwalk"
+ "github.com/charmbracelet/crush/internal/config"
+ "github.com/stretchr/testify/require"
+)
+
+// TestAPIKeyStateForVerifyErr pins the mapping between TestConnection errors
+// and the dialog state the UI should transition into. In particular, an
+// [config.ErrValidationUnsupported] error must yield the unverified state
+// (so the UI shows "saved (not verified)" instead of "invalid").
+func TestAPIKeyStateForVerifyErr(t *testing.T) {
+ t.Parallel()
+
+ tests := map[string]struct {
+ err error
+ want APIKeyInputState
+ }{
+ "nilIsVerified": {
+ err: nil,
+ want: APIKeyInputStateVerified,
+ },
+ "unsupportedIsUnverified": {
+ err: config.ErrValidationUnsupported,
+ want: APIKeyInputStateUnverified,
+ },
+ "wrappedUnsupportedIsUnverified": {
+ err: fmt.Errorf("probing provider: %w", config.ErrValidationUnsupported),
+ want: APIKeyInputStateUnverified,
+ },
+ "plainErrorIsError": {
+ err: errors.New("bad key"),
+ want: APIKeyInputStateError,
+ },
+ }
+ for name, tc := range tests {
+ t.Run(name, func(t *testing.T) {
+ t.Parallel()
+ require.Equal(t, tc.want, apiKeyStateForVerifyErr(tc.err))
+ })
+ }
+}
+
+// TestProviderConfigForVerifyPropagatesDefaultHeaders locks in the UI
+// contract that any catwalk-declared DefaultHeaders are copied into the
+// ProviderConfig.ExtraHeaders used for the validation probe. Without this,
+// providers that require routing/tenant headers (e.g. DefaultHeaders
+// supplied by the catwalk provider definition) would probe with a stripped
+// header set and potentially be misclassified as "not verified".
+func TestProviderConfigForVerifyPropagatesDefaultHeaders(t *testing.T) {
+ t.Parallel()
+
+ provider := catwalk.Provider{
+ ID: "test-provider",
+ Name: "Test Provider",
+ Type: catwalk.TypeOpenAICompat,
+ APIEndpoint: "https://example.invalid",
+ DefaultHeaders: map[string]string{
+ "X-Tenant": "acme",
+ "X-Route": "primary",
+ "X-Shared": "from-default",
+ "User-Agent": "crush-test",
+ },
+ }
+ cfg := providerConfigForVerify(provider, "sk-test")
+
+ require.Equal(t, string(provider.ID), cfg.ID)
+ require.Equal(t, provider.Name, cfg.Name)
+ require.Equal(t, provider.Type, cfg.Type)
+ require.Equal(t, provider.APIEndpoint, cfg.BaseURL)
+ require.Equal(t, "sk-test", cfg.APIKey)
+ require.Equal(t, provider.DefaultHeaders, cfg.ExtraHeaders,
+ "DefaultHeaders must be propagated to ExtraHeaders")
+
+ // Mutating the returned config must not leak back into the provider
+ // definition (the dialog reuses the provider value across retries).
+ cfg.ExtraHeaders["X-Tenant"] = "mutated"
+ require.Equal(t, "acme", provider.DefaultHeaders["X-Tenant"],
+ "providerConfigForVerify must copy DefaultHeaders, not alias them")
+}
+
+func TestProviderConfigForVerifyWithNoDefaultHeaders(t *testing.T) {
+ t.Parallel()
+
+ provider := catwalk.Provider{
+ ID: "test-provider",
+ Type: catwalk.TypeOpenAICompat,
+ APIEndpoint: "https://example.invalid",
+ }
+ cfg := providerConfigForVerify(provider, "sk-test")
+ require.Nil(t, cfg.ExtraHeaders,
+ "no DefaultHeaders should yield no ExtraHeaders allocation")
+}
+
+// TestProviderConfigForVerifyHeadersHitTheWire is an end-to-end UI-level
+// check: after providerConfigForVerify builds the probe config, calling
+// TestConnection against a local server must deliver the DefaultHeaders on
+// the outbound request. This guards against silent regressions where the
+// header map is dropped between the dialog and the HTTP layer.
+func TestProviderConfigForVerifyHeadersHitTheWire(t *testing.T) {
+ t.Parallel()
+
+ var captured http.Header
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ captured = r.Header.Clone()
+ _, _ = io.Copy(io.Discard, r.Body)
+ // 400 on the malformed-body chat-completions probe means "auth
+ // passed, schema rejected" for the Synthetic-style override.
+ w.WriteHeader(http.StatusBadRequest)
+ }))
+ t.Cleanup(srv.Close)
+
+ provider := catwalk.Provider{
+ ID: catwalk.InferenceProviderSynthetic,
+ Name: "Synthetic",
+ Type: catwalk.TypeOpenAICompat,
+ APIEndpoint: srv.URL,
+ DefaultHeaders: map[string]string{
+ "X-Tenant": "acme",
+ "X-Route": "primary",
+ },
+ }
+ cfg := providerConfigForVerify(provider, "sk-test")
+ require.NoError(t, cfg.TestConnection(config.IdentityResolver()))
+
+ require.NotNil(t, captured, "probe must have reached the test server")
+ require.Equal(t, "acme", captured.Get("X-Tenant"))
+ require.Equal(t, "primary", captured.Get("X-Route"))
+ // Probe-defined headers should still be present alongside the
+ // DefaultHeaders.
+ require.Equal(t, "Bearer sk-test", captured.Get("Authorization"))
+}
+
+// TestAPIKeyInputStatesAreDistinct guards against someone accidentally making
+// APIKeyInputStateUnverified equal to one of the other states (which would
+// silently collapse the "saved (not verified)" path onto "validated" or
+// "invalid").
+func TestAPIKeyInputStatesAreDistinct(t *testing.T) {
+ t.Parallel()
+
+ states := []APIKeyInputState{
+ APIKeyInputStateInitial,
+ APIKeyInputStateVerifying,
+ APIKeyInputStateVerified,
+ APIKeyInputStateUnverified,
+ APIKeyInputStateError,
+ }
+ seen := map[APIKeyInputState]struct{}{}
+ for _, s := range states {
+ _, dup := seen[s]
+ require.False(t, dup, "state %d declared twice", s)
+ seen[s] = struct{}{}
+ }
+}