1package config
2
3import (
4 "errors"
5 "io"
6 "net/http"
7 "net/http/httptest"
8 "strings"
9 "testing"
10
11 "charm.land/catwalk/pkg/catwalk"
12 "github.com/stretchr/testify/require"
13)
14
// capturedRequest is a snapshot of the most recent request handled by a
// server created with newCaptureServer.
type capturedRequest struct {
	// method is the HTTP verb of the request.
	method string
	// path is the URL path, without the query string.
	path string
	// query is the raw (still encoded) query string.
	query string
	// headers is a clone of the request headers.
	headers http.Header
	// body holds the bytes read from the request body.
	body []byte
}

// newCaptureServer starts an httptest server that records each incoming
// request into the returned capturedRequest and then replies with status and
// no body. Every request overwrites the previous snapshot. The server is shut
// down through t.Cleanup.
//
// A failure to read the request body is reported on t instead of being
// dropped, so a broken upload shows up as itself rather than as a later
// "body is empty" assertion failure.
func newCaptureServer(t *testing.T, status int) (*httptest.Server, *capturedRequest) {
	t.Helper()
	captured := &capturedRequest{}
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		body, err := io.ReadAll(r.Body)
		if err != nil {
			// Errorf, unlike FailNow, may be called from the handler goroutine.
			t.Errorf("capture server: reading request body: %v", err)
		}
		captured.method = r.Method
		captured.path = r.URL.Path
		captured.query = r.URL.RawQuery
		captured.headers = r.Header.Clone()
		captured.body = body
		w.WriteHeader(status)
	}))
	t.Cleanup(srv.Close)
	return srv, captured
}
37
38func TestTestConnectionMiniMaxProbe(t *testing.T) {
39 t.Parallel()
40
41 for _, id := range []catwalk.InferenceProvider{
42 catwalk.InferenceProviderMiniMax,
43 catwalk.InferenceProviderMiniMaxChina,
44 } {
45 t.Run(string(id), func(t *testing.T) {
46 t.Parallel()
47 for name, tc := range map[string]struct {
48 status int
49 wantErr error
50 wantNil bool
51 }{
52 "valid": {status: http.StatusOK, wantNil: true},
53 "invalid401": {status: http.StatusUnauthorized},
54 "invalid403": {status: http.StatusForbidden},
55 "unsupported": {status: http.StatusTeapot, wantErr: ErrValidationUnsupported},
56 } {
57 t.Run(name, func(t *testing.T) {
58 t.Parallel()
59 srv, captured := newCaptureServer(t, tc.status)
60 c := &ProviderConfig{
61 ID: string(id),
62 Type: catwalk.TypeAnthropic,
63 BaseURL: srv.URL,
64 APIKey: "key-abc",
65 }
66 err := c.TestConnection(IdentityResolver())
67 switch {
68 case tc.wantNil:
69 require.NoError(t, err)
70 case tc.wantErr != nil:
71 require.ErrorIs(t, err, tc.wantErr)
72 default:
73 require.Error(t, err)
74 require.NotErrorIs(t, err, ErrValidationUnsupported)
75 }
76 require.Equal(t, http.MethodGet, captured.method)
77 require.Equal(t, "/v1/models", captured.path)
78 require.Equal(t, "key-abc", captured.headers.Get("x-api-key"))
79 require.Equal(t, "2023-06-01", captured.headers.Get("anthropic-version"))
80 })
81 }
82 })
83 }
84}
85
86func TestTestConnectionVeniceProbe(t *testing.T) {
87 t.Parallel()
88
89 tests := map[string]struct {
90 status int
91 wantErr error
92 wantNil bool
93 }{
94 "valid": {status: http.StatusOK, wantNil: true},
95 "invalid401": {status: http.StatusUnauthorized},
96 "invalid403": {status: http.StatusForbidden},
97 "rateLimited": {status: http.StatusTooManyRequests, wantErr: ErrValidationUnsupported},
98 "paymentReq": {status: http.StatusPaymentRequired, wantErr: ErrValidationUnsupported},
99 "transient500": {status: http.StatusInternalServerError, wantErr: ErrValidationUnsupported},
100 }
101 for name, tc := range tests {
102 t.Run(name, func(t *testing.T) {
103 t.Parallel()
104 srv, captured := newCaptureServer(t, tc.status)
105 c := &ProviderConfig{
106 ID: string(catwalk.InferenceProviderVenice),
107 Type: catwalk.TypeOpenAICompat,
108 BaseURL: srv.URL,
109 APIKey: "sk-venice",
110 }
111 err := c.TestConnection(IdentityResolver())
112 switch {
113 case tc.wantNil:
114 require.NoError(t, err)
115 case tc.wantErr != nil:
116 require.ErrorIs(t, err, tc.wantErr)
117 default:
118 require.Error(t, err)
119 require.NotErrorIs(t, err, ErrValidationUnsupported)
120 }
121 require.Equal(t, http.MethodGet, captured.method)
122 require.Equal(t, "/api_keys/rate_limits", captured.path)
123 require.Equal(t, "Bearer sk-venice", captured.headers.Get("Authorization"))
124 })
125 }
126}
127
128func TestTestConnectionOpenAICompatChatProbe(t *testing.T) {
129 t.Parallel()
130
131 providers := []catwalk.InferenceProvider{
132 catwalk.InferenceAIHubMix,
133 catwalk.InferenceProviderAvian,
134 catwalk.InferenceProviderCortecs,
135 catwalk.InferenceProviderHuggingFace,
136 catwalk.InferenceProviderIoNet,
137 catwalk.InferenceProviderOpenCodeGo,
138 catwalk.InferenceProviderOpenCodeZen,
139 catwalk.InferenceProviderQiniuCloud,
140 catwalk.InferenceProviderSynthetic,
141 }
142 for _, id := range providers {
143 t.Run(string(id), func(t *testing.T) {
144 t.Parallel()
145 cases := map[string]struct {
146 status int
147 wantErr error
148 wantNil bool
149 }{
150 "authPassed400": {status: http.StatusBadRequest, wantNil: true},
151 "authPassed422": {status: http.StatusUnprocessableEntity, wantNil: true},
152 "invalid401": {status: http.StatusUnauthorized},
153 "invalid403": {status: http.StatusForbidden},
154 "transient500": {status: http.StatusInternalServerError, wantErr: ErrValidationUnsupported},
155 "unexpected200": {status: http.StatusOK, wantErr: ErrValidationUnsupported},
156 "unexpectedOther": {status: http.StatusTeapot, wantErr: ErrValidationUnsupported},
157 }
158 for name, tc := range cases {
159 t.Run(name, func(t *testing.T) {
160 t.Parallel()
161 srv, captured := newCaptureServer(t, tc.status)
162 c := &ProviderConfig{
163 ID: string(id),
164 Type: catwalk.TypeOpenAICompat,
165 BaseURL: srv.URL,
166 APIKey: "sk-test",
167 }
168 err := c.TestConnection(IdentityResolver())
169 switch {
170 case tc.wantNil:
171 require.NoError(t, err)
172 case tc.wantErr != nil:
173 require.ErrorIs(t, err, tc.wantErr)
174 default:
175 require.Error(t, err)
176 require.NotErrorIs(t, err, ErrValidationUnsupported)
177 }
178 require.Equal(t, http.MethodPost, captured.method)
179 require.Equal(t, "/chat/completions", captured.path)
180 require.Equal(t, "Bearer sk-test", captured.headers.Get("Authorization"))
181 require.Equal(t, "application/json", captured.headers.Get("Content-Type"))
182 require.NotEmpty(t, captured.body)
183 })
184 }
185 })
186 }
187}
188
189func TestTestConnectionUnsupportedProviders(t *testing.T) {
190 t.Parallel()
191
192 for _, id := range []catwalk.InferenceProvider{
193 catwalk.InferenceProviderChutes,
194 catwalk.InferenceProviderNeuralwatt,
195 } {
196 t.Run(string(id), func(t *testing.T) {
197 t.Parallel()
198 c := &ProviderConfig{
199 ID: string(id),
200 Type: catwalk.TypeOpenAICompat,
201 BaseURL: "https://example.invalid",
202 APIKey: "sk-test",
203 }
204 err := c.TestConnection(IdentityResolver())
205 require.ErrorIs(t, err, ErrValidationUnsupported)
206 })
207 }
208}
209
210func TestTestConnectionUnknownOpenAICompatIsUnsupported(t *testing.T) {
211 t.Parallel()
212
213 c := &ProviderConfig{
214 ID: "some-new-openai-compat-provider",
215 Type: catwalk.TypeOpenAICompat,
216 BaseURL: "https://example.invalid",
217 APIKey: "sk-test",
218 }
219 err := c.TestConnection(IdentityResolver())
220 require.ErrorIs(t, err, ErrValidationUnsupported)
221}
222
223func TestTestConnectionEmptyProbeURLIsUnsupported(t *testing.T) {
224 t.Parallel()
225
226 // Chutes has a provider override that returns ErrValidationUnsupported
227 // regardless of configured base URL; this also guards the empty-URL path.
228 c := &ProviderConfig{
229 ID: string(catwalk.InferenceProviderChutes),
230 Type: catwalk.TypeOpenAICompat,
231 APIKey: "sk-test",
232 }
233 err := c.TestConnection(IdentityResolver())
234 require.ErrorIs(t, err, ErrValidationUnsupported)
235}
236
237func TestTestConnectionExtraHeadersAreApplied(t *testing.T) {
238 t.Parallel()
239
240 srv, captured := newCaptureServer(t, http.StatusBadRequest)
241 c := &ProviderConfig{
242 ID: string(catwalk.InferenceProviderSynthetic),
243 Type: catwalk.TypeOpenAICompat,
244 BaseURL: srv.URL,
245 APIKey: "sk-test",
246 ExtraHeaders: map[string]string{
247 "X-Custom-Header": "custom-value",
248 "Authorization": "overridden",
249 },
250 }
251 err := c.TestConnection(IdentityResolver())
252 require.NoError(t, err)
253 require.Equal(t, "custom-value", captured.headers.Get("X-Custom-Header"))
254 // ExtraHeaders are applied after the probe headers, so callers can
255 // override per-provider defaults if necessary.
256 require.Equal(t, "overridden", captured.headers.Get("Authorization"))
257}
258
259func TestTestConnectionOpenAITypeProbesModelsEndpoint(t *testing.T) {
260 t.Parallel()
261
262 srv, captured := newCaptureServer(t, http.StatusOK)
263 c := &ProviderConfig{
264 ID: string(catwalk.InferenceProviderOpenAI),
265 Type: catwalk.TypeOpenAI,
266 BaseURL: srv.URL,
267 APIKey: "sk-openai",
268 }
269 err := c.TestConnection(IdentityResolver())
270 require.NoError(t, err)
271 require.Equal(t, http.MethodGet, captured.method)
272 require.Equal(t, "/models", captured.path)
273 require.Equal(t, "Bearer sk-openai", captured.headers.Get("Authorization"))
274}
275
276func TestTestConnectionOpenRouterProbesCreditsEndpoint(t *testing.T) {
277 t.Parallel()
278
279 srv, captured := newCaptureServer(t, http.StatusOK)
280 c := &ProviderConfig{
281 ID: string(catwalk.InferenceProviderOpenRouter),
282 Type: catwalk.TypeOpenRouter,
283 BaseURL: srv.URL,
284 APIKey: "sk-or",
285 }
286 err := c.TestConnection(IdentityResolver())
287 require.NoError(t, err)
288 require.Equal(t, "/credits", captured.path)
289}
290
291func TestTestConnectionAnthropicTypeProbesModels(t *testing.T) {
292 t.Parallel()
293
294 srv, captured := newCaptureServer(t, http.StatusOK)
295 c := &ProviderConfig{
296 ID: string(catwalk.InferenceProviderAnthropic),
297 Type: catwalk.TypeAnthropic,
298 BaseURL: srv.URL,
299 APIKey: "ak-test",
300 }
301 err := c.TestConnection(IdentityResolver())
302 require.NoError(t, err)
303 require.Equal(t, "/models", captured.path)
304 require.Equal(t, "ak-test", captured.headers.Get("x-api-key"))
305}
306
307func TestTestConnectionKimiCodingUsesV1Models(t *testing.T) {
308 t.Parallel()
309
310 srv, captured := newCaptureServer(t, http.StatusOK)
311 c := &ProviderConfig{
312 ID: string(catwalk.InferenceKimiCoding),
313 Type: catwalk.TypeAnthropic,
314 BaseURL: srv.URL,
315 APIKey: "ak-kimi",
316 }
317 err := c.TestConnection(IdentityResolver())
318 require.NoError(t, err)
319 require.Equal(t, "/v1/models", captured.path)
320}
321
322func TestTestConnectionGoogleIncludesKeyQueryParam(t *testing.T) {
323 t.Parallel()
324
325 srv, captured := newCaptureServer(t, http.StatusOK)
326 c := &ProviderConfig{
327 ID: string(catwalk.InferenceProviderGemini),
328 Type: catwalk.TypeGoogle,
329 BaseURL: srv.URL,
330 APIKey: "google-key",
331 }
332 err := c.TestConnection(IdentityResolver())
333 require.NoError(t, err)
334 require.Equal(t, "/v1beta/models", captured.path)
335 require.Contains(t, captured.query, "key=google-key")
336}
337
338// TestTestConnectionGoogleBadKeyIs400 locks in the fact that Google returns
339// 400 INVALID_ARGUMENT (not 401) for an unknown API key, so 400 must map to
340// "invalid" and never to [ErrValidationUnsupported].
341func TestTestConnectionGoogleBadKeyIs400(t *testing.T) {
342 t.Parallel()
343
344 for name, tc := range map[string]struct {
345 status int
346 wantNil bool
347 wantErr error
348 }{
349 "badKey400": {status: http.StatusBadRequest},
350 "unauth401": {status: http.StatusUnauthorized},
351 "forbidden403": {status: http.StatusForbidden},
352 "ok200": {status: http.StatusOK, wantNil: true},
353 "transient500": {status: http.StatusInternalServerError, wantErr: ErrValidationUnsupported},
354 } {
355 t.Run(name, func(t *testing.T) {
356 t.Parallel()
357 srv, _ := newCaptureServer(t, tc.status)
358 c := &ProviderConfig{
359 ID: string(catwalk.InferenceProviderGemini),
360 Type: catwalk.TypeGoogle,
361 BaseURL: srv.URL,
362 APIKey: "bad-key",
363 }
364 err := c.TestConnection(IdentityResolver())
365 switch {
366 case tc.wantNil:
367 require.NoError(t, err)
368 case tc.wantErr != nil:
369 require.ErrorIs(t, err, tc.wantErr)
370 default:
371 require.Error(t, err)
372 require.NotErrorIs(t, err, ErrValidationUnsupported)
373 }
374 })
375 }
376}
377
378// TestTestConnectionOpenAICompatAllowlistUsesModelsProbe locks in the
379// `/models` probe for openai-compat providers whose /models is known to be
380// auth-gated. These providers must not fall through to
381// [ErrValidationUnsupported].
382func TestTestConnectionOpenAICompatAllowlistUsesModelsProbe(t *testing.T) {
383 t.Parallel()
384
385 providers := []catwalk.InferenceProvider{
386 "deepseek",
387 catwalk.InferenceProviderGROQ,
388 catwalk.InferenceProviderXAI,
389 catwalk.InferenceProviderZhipu,
390 catwalk.InferenceProviderZhipuCoding,
391 catwalk.InferenceProviderCerebras,
392 catwalk.InferenceProviderNebius,
393 catwalk.InferenceProviderCopilot,
394 }
395 for _, id := range providers {
396 t.Run(string(id), func(t *testing.T) {
397 t.Parallel()
398 t.Run("valid", func(t *testing.T) {
399 t.Parallel()
400 srv, captured := newCaptureServer(t, http.StatusOK)
401 c := &ProviderConfig{
402 ID: string(id),
403 Type: catwalk.TypeOpenAICompat,
404 BaseURL: srv.URL,
405 APIKey: "sk-good",
406 }
407 require.NoError(t, c.TestConnection(IdentityResolver()))
408 require.Equal(t, http.MethodGet, captured.method)
409 require.Equal(t, "/models", captured.path)
410 require.Equal(t, "Bearer sk-good", captured.headers.Get("Authorization"))
411 })
412 t.Run("invalid", func(t *testing.T) {
413 t.Parallel()
414 srv, _ := newCaptureServer(t, http.StatusUnauthorized)
415 c := &ProviderConfig{
416 ID: string(id),
417 Type: catwalk.TypeOpenAICompat,
418 BaseURL: srv.URL,
419 APIKey: "sk-bad",
420 }
421 err := c.TestConnection(IdentityResolver())
422 require.Error(t, err)
423 require.NotErrorIs(t, err, ErrValidationUnsupported)
424 })
425 })
426 }
427}
428
429// TestTestConnectionZAIUsesZAIClassifier pins ZAI's historical quirk: /models
430// returns non-200 for valid keys but always 401 for bad keys.
431func TestTestConnectionZAIUsesZAIClassifier(t *testing.T) {
432 t.Parallel()
433
434 for name, tc := range map[string]struct {
435 status int
436 wantNil bool
437 }{
438 "ok200": {status: http.StatusOK, wantNil: true},
439 "other400": {status: http.StatusBadRequest, wantNil: true},
440 "other500": {status: http.StatusInternalServerError, wantNil: true},
441 "badKey401": {status: http.StatusUnauthorized},
442 } {
443 t.Run(name, func(t *testing.T) {
444 t.Parallel()
445 srv, captured := newCaptureServer(t, tc.status)
446 c := &ProviderConfig{
447 ID: string(catwalk.InferenceProviderZAI),
448 Type: catwalk.TypeOpenAICompat,
449 BaseURL: srv.URL,
450 APIKey: "sk-zai",
451 }
452 err := c.TestConnection(IdentityResolver())
453 if tc.wantNil {
454 require.NoError(t, err)
455 } else {
456 require.Error(t, err)
457 require.NotErrorIs(t, err, ErrValidationUnsupported)
458 }
459 require.Equal(t, "/models", captured.path)
460 require.Equal(t, "Bearer sk-zai", captured.headers.Get("Authorization"))
461 })
462 }
463}
464
465func TestTestConnectionBedrockPrefix(t *testing.T) {
466 t.Parallel()
467
468 t.Run("valid", func(t *testing.T) {
469 t.Parallel()
470 c := &ProviderConfig{
471 ID: string(catwalk.InferenceProviderBedrock),
472 Type: catwalk.TypeBedrock,
473 APIKey: "ABSK-secret",
474 }
475 require.NoError(t, c.TestConnection(IdentityResolver()))
476 })
477 t.Run("invalid", func(t *testing.T) {
478 t.Parallel()
479 c := &ProviderConfig{
480 ID: string(catwalk.InferenceProviderBedrock),
481 Type: catwalk.TypeBedrock,
482 APIKey: "nope",
483 }
484 err := c.TestConnection(IdentityResolver())
485 require.Error(t, err)
486 require.NotErrorIs(t, err, ErrValidationUnsupported)
487 })
488}
489
490func TestTestConnectionVercelPrefix(t *testing.T) {
491 t.Parallel()
492
493 t.Run("valid", func(t *testing.T) {
494 t.Parallel()
495 c := &ProviderConfig{
496 ID: string(catwalk.InferenceProviderVercel),
497 Type: catwalk.TypeVercel,
498 APIKey: "vck_abc",
499 }
500 require.NoError(t, c.TestConnection(IdentityResolver()))
501 })
502 t.Run("invalid", func(t *testing.T) {
503 t.Parallel()
504 c := &ProviderConfig{
505 ID: string(catwalk.InferenceProviderVercel),
506 Type: catwalk.TypeVercel,
507 APIKey: "nope",
508 }
509 err := c.TestConnection(IdentityResolver())
510 require.Error(t, err)
511 require.NotErrorIs(t, err, ErrValidationUnsupported)
512 })
513}
514
515// TestTestConnectionPublicModelsAuthGatedChatRegression locks in the core
516// regression from the 2025-10-20 expansion of generic /models validation to
517// openai-compat: a provider whose /models is intentionally public would
518// report any key as "validated" even though /chat/completions actually
519// gates on auth. For every provider we currently mark "validated" via the
520// malformed-body chat probe, this test simulates both endpoints and asserts
521// that:
522//
523// 1. A bad key (401 on /chat/completions) is reported as invalid, not as
524// "validated" — even when /models returns 200 unauthenticated.
525// 2. A good key (400/422 on /chat/completions) is reported as valid.
526// 3. The probe never hits /models for these providers.
527func TestTestConnectionPublicModelsAuthGatedChatRegression(t *testing.T) {
528 t.Parallel()
529
530 providers := []catwalk.InferenceProvider{
531 catwalk.InferenceAIHubMix,
532 catwalk.InferenceProviderAvian,
533 catwalk.InferenceProviderCortecs,
534 catwalk.InferenceProviderHuggingFace,
535 catwalk.InferenceProviderIoNet,
536 catwalk.InferenceProviderOpenCodeGo,
537 catwalk.InferenceProviderOpenCodeZen,
538 catwalk.InferenceProviderQiniuCloud,
539 catwalk.InferenceProviderSynthetic,
540 }
541 for _, id := range providers {
542 t.Run(string(id), func(t *testing.T) {
543 t.Parallel()
544
545 type hits struct {
546 models int
547 chat int
548 }
549 for name, tc := range map[string]struct {
550 chatStatus int
551 wantErr error
552 wantNil bool
553 }{
554 "badKeyIsInvalidNotValidated": {
555 chatStatus: http.StatusUnauthorized,
556 },
557 "goodKeyIsValidated": {
558 chatStatus: http.StatusBadRequest,
559 wantNil: true,
560 },
561 "forbiddenKeyIsInvalid": {
562 chatStatus: http.StatusForbidden,
563 },
564 "schemaFailure422IsValidated": {
565 chatStatus: http.StatusUnprocessableEntity,
566 wantNil: true,
567 },
568 } {
569 t.Run(name, func(t *testing.T) {
570 t.Parallel()
571 h := &hits{}
572 srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
573 switch r.URL.Path {
574 case "/models":
575 // Simulate a public /models endpoint that
576 // returns 200 regardless of the provided key.
577 h.models++
578 w.WriteHeader(http.StatusOK)
579 case "/chat/completions":
580 h.chat++
581 w.WriteHeader(tc.chatStatus)
582 default:
583 w.WriteHeader(http.StatusNotFound)
584 }
585 }))
586 t.Cleanup(srv.Close)
587
588 c := &ProviderConfig{
589 ID: string(id),
590 Type: catwalk.TypeOpenAICompat,
591 BaseURL: srv.URL,
592 APIKey: "sk-test",
593 }
594 err := c.TestConnection(IdentityResolver())
595
596 if tc.wantNil {
597 require.NoError(t, err, "expected %s to validate on %d", id, tc.chatStatus)
598 } else {
599 require.Error(t, err, "expected %s to reject on %d", id, tc.chatStatus)
600 require.NotErrorIs(t, err, ErrValidationUnsupported)
601 }
602 require.Equal(t, 0, h.models, "probe must not rely on public /models for %s", id)
603 require.Equal(t, 1, h.chat, "probe must hit /chat/completions for %s", id)
604 })
605 }
606 })
607 }
608}
609
610// TestTestConnectionOpenAICompatProviderAudit is an audit table that pins the
611// full set of openai-compat providers currently exposed as "validated" (i.e.
612// TestConnection can return nil on some response) and documents the exact
613// probe each uses. Adding a new openai-compat provider to the validated set
614// MUST update this table; this prevents silent drift back into the
615// "assume /models proves auth" bug class.
616//
617// Providers not listed here either:
618// - use a different Type (TypeOpenAI / TypeAnthropic / TypeGoogle / ...);
619// - are explicitly gated behind ErrValidationUnsupported (chutes, neuralwatt,
620// and every unknown openai-compat provider).
621func TestTestConnectionOpenAICompatProviderAudit(t *testing.T) {
622 t.Parallel()
623
624 audit := map[catwalk.InferenceProvider]auditCase{
625 catwalk.InferenceProviderVenice: {
626 method: http.MethodGet,
627 path: "/api_keys/rate_limits",
628 validStatus: http.StatusOK,
629 invalidStatus: http.StatusUnauthorized,
630 authHeader: "Authorization",
631 authValue: "Bearer sk-test",
632 },
633 catwalk.InferenceAIHubMix: openaiCompatAuditCase(),
634 catwalk.InferenceProviderAvian: openaiCompatAuditCase(),
635 catwalk.InferenceProviderCortecs: openaiCompatAuditCase(),
636 catwalk.InferenceProviderHuggingFace: openaiCompatAuditCase(),
637 catwalk.InferenceProviderIoNet: openaiCompatAuditCase(),
638 catwalk.InferenceProviderOpenCodeGo: openaiCompatAuditCase(),
639 catwalk.InferenceProviderOpenCodeZen: openaiCompatAuditCase(),
640 catwalk.InferenceProviderQiniuCloud: openaiCompatAuditCase(),
641 catwalk.InferenceProviderSynthetic: openaiCompatAuditCase(),
642 // openai-compat providers with auth-gated /models (allowlist).
643 "deepseek": openaiCompatModelsAuditCase(),
644 catwalk.InferenceProviderGROQ: openaiCompatModelsAuditCase(),
645 catwalk.InferenceProviderXAI: openaiCompatModelsAuditCase(),
646 catwalk.InferenceProviderZhipu: openaiCompatModelsAuditCase(),
647 catwalk.InferenceProviderZhipuCoding: openaiCompatModelsAuditCase(),
648 catwalk.InferenceProviderCerebras: openaiCompatModelsAuditCase(),
649 catwalk.InferenceProviderNebius: openaiCompatModelsAuditCase(),
650 catwalk.InferenceProviderCopilot: openaiCompatModelsAuditCase(),
651 // ZAI uses the /models endpoint but with its own classifier that
652 // only treats 401 as invalid. Its valid path must therefore be 200
653 // here for the audit's generic "valid -> nil" check to hold.
654 catwalk.InferenceProviderZAI: {
655 method: http.MethodGet,
656 path: "/models",
657 validStatus: http.StatusOK,
658 invalidStatus: http.StatusUnauthorized,
659 authHeader: "Authorization",
660 authValue: "Bearer sk-test",
661 },
662 }
663
664 for id, tc := range audit {
665 t.Run(string(id), func(t *testing.T) {
666 t.Parallel()
667
668 // 1) Valid path.
669 srv, captured := newCaptureServer(t, tc.validStatus)
670 c := &ProviderConfig{
671 ID: string(id),
672 Type: catwalk.TypeOpenAICompat,
673 BaseURL: srv.URL,
674 APIKey: "sk-test",
675 }
676 require.NoError(t, c.TestConnection(IdentityResolver()))
677 require.Equal(t, tc.method, captured.method, "audit: wrong method for %s", id)
678 require.Equal(t, tc.path, captured.path, "audit: wrong path for %s", id)
679 require.Equal(t, tc.authValue, captured.headers.Get(tc.authHeader),
680 "audit: wrong auth header for %s", id)
681
682 // 2) Invalid path.
683 srv2, _ := newCaptureServer(t, tc.invalidStatus)
684 c2 := &ProviderConfig{
685 ID: string(id),
686 Type: catwalk.TypeOpenAICompat,
687 BaseURL: srv2.URL,
688 APIKey: "sk-test",
689 }
690 err := c2.TestConnection(IdentityResolver())
691 require.Error(t, err, "audit: %s must reject %d as invalid", id, tc.invalidStatus)
692 require.NotErrorIs(t, err, ErrValidationUnsupported,
693 "audit: %s must not leak ErrValidationUnsupported on %d", id, tc.invalidStatus)
694 })
695 }
696
697 // Sanity: every provider that currently enters the openai-compat chat
698 // probe path must appear in the audit. This guards against a future
699 // refactor silently adding a provider without test coverage.
700 chatProbeProviders := []catwalk.InferenceProvider{
701 catwalk.InferenceAIHubMix,
702 catwalk.InferenceProviderAvian,
703 catwalk.InferenceProviderCortecs,
704 catwalk.InferenceProviderHuggingFace,
705 catwalk.InferenceProviderIoNet,
706 catwalk.InferenceProviderOpenCodeGo,
707 catwalk.InferenceProviderOpenCodeZen,
708 catwalk.InferenceProviderQiniuCloud,
709 catwalk.InferenceProviderSynthetic,
710 }
711 for _, id := range chatProbeProviders {
712 _, ok := audit[id]
713 require.True(t, ok, "audit table missing entry for %s", id)
714 }
715}
716
// auditCase describes the probe a provider is expected to issue and how two
// representative response codes must be classified.
type auditCase struct {
	// method and path identify the request the probe is expected to send.
	method, path string
	// validStatus is a response code that must be treated as "validated",
	// i.e. produce a nil error.
	validStatus int
	// invalidStatus is a response code that must produce an invalid-key
	// error, as opposed to ErrValidationUnsupported.
	invalidStatus int
	// authHeader names the header in which the probe presents the key;
	// authValue is the value expected in that header.
	authHeader, authValue string
}
732
733func openaiCompatAuditCase() auditCase {
734 return auditCase{
735 method: http.MethodPost,
736 path: "/chat/completions",
737 validStatus: http.StatusBadRequest,
738 invalidStatus: http.StatusUnauthorized,
739 authHeader: "Authorization",
740 authValue: "Bearer sk-test",
741 }
742}
743
744func openaiCompatModelsAuditCase() auditCase {
745 return auditCase{
746 method: http.MethodGet,
747 path: "/models",
748 validStatus: http.StatusOK,
749 invalidStatus: http.StatusUnauthorized,
750 authHeader: "Authorization",
751 authValue: "Bearer sk-test",
752 }
753}
754
755func TestTestConnectionNetworkErrorIsNotInvalidKey(t *testing.T) {
756 t.Parallel()
757
758 // Start and immediately close a server so the next request fails at the
759 // TCP layer. That should produce a non-nil error that is *not*
760 // ErrValidationUnsupported (transport errors still surface).
761 srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
762 w.WriteHeader(http.StatusOK)
763 }))
764 srv.Close()
765 c := &ProviderConfig{
766 ID: string(catwalk.InferenceProviderOpenAI),
767 Type: catwalk.TypeOpenAI,
768 BaseURL: srv.URL,
769 APIKey: "sk-test",
770 }
771 err := c.TestConnection(IdentityResolver())
772 require.Error(t, err)
773 // The error message should mention the provider so users see a useful
774 // hint, even though we can't classify the status code.
775 require.True(t, strings.Contains(err.Error(), "openai") || errors.Is(err, ErrValidationUnsupported))
776}