api_key_input_test.go

  1package dialog
  2
import (
	"errors"
	"fmt"
	"io"
	"net/http"
	"net/http/httptest"
	"sync"
	"testing"

	"charm.land/catwalk/pkg/catwalk"
	"github.com/charmbracelet/crush/internal/config"
	"github.com/stretchr/testify/require"
)
 15
 16// TestAPIKeyStateForVerifyErr pins the mapping between TestConnection errors
 17// and the dialog state the UI should transition into. In particular, an
 18// [config.ErrValidationUnsupported] error must yield the unverified state
 19// (so the UI shows "saved (not verified)" instead of "invalid").
 20func TestAPIKeyStateForVerifyErr(t *testing.T) {
 21	t.Parallel()
 22
 23	tests := map[string]struct {
 24		err  error
 25		want APIKeyInputState
 26	}{
 27		"nilIsVerified": {
 28			err:  nil,
 29			want: APIKeyInputStateVerified,
 30		},
 31		"unsupportedIsUnverified": {
 32			err:  config.ErrValidationUnsupported,
 33			want: APIKeyInputStateUnverified,
 34		},
 35		"wrappedUnsupportedIsUnverified": {
 36			err:  fmt.Errorf("probing provider: %w", config.ErrValidationUnsupported),
 37			want: APIKeyInputStateUnverified,
 38		},
 39		"plainErrorIsError": {
 40			err:  errors.New("bad key"),
 41			want: APIKeyInputStateError,
 42		},
 43	}
 44	for name, tc := range tests {
 45		t.Run(name, func(t *testing.T) {
 46			t.Parallel()
 47			require.Equal(t, tc.want, apiKeyStateForVerifyErr(tc.err))
 48		})
 49	}
 50}
 51
 52// TestProviderConfigForVerifyPropagatesDefaultHeaders locks in the UI
 53// contract that any catwalk-declared DefaultHeaders are copied into the
 54// ProviderConfig.ExtraHeaders used for the validation probe. Without this,
 55// providers that require routing/tenant headers (e.g. DefaultHeaders
 56// supplied by the catwalk provider definition) would probe with a stripped
 57// header set and potentially be misclassified as "not verified".
 58func TestProviderConfigForVerifyPropagatesDefaultHeaders(t *testing.T) {
 59	t.Parallel()
 60
 61	provider := catwalk.Provider{
 62		ID:          "test-provider",
 63		Name:        "Test Provider",
 64		Type:        catwalk.TypeOpenAICompat,
 65		APIEndpoint: "https://example.invalid",
 66		DefaultHeaders: map[string]string{
 67			"X-Tenant":   "acme",
 68			"X-Route":    "primary",
 69			"X-Shared":   "from-default",
 70			"User-Agent": "crush-test",
 71		},
 72	}
 73	cfg := providerConfigForVerify(provider, "sk-test")
 74
 75	require.Equal(t, string(provider.ID), cfg.ID)
 76	require.Equal(t, provider.Name, cfg.Name)
 77	require.Equal(t, provider.Type, cfg.Type)
 78	require.Equal(t, provider.APIEndpoint, cfg.BaseURL)
 79	require.Equal(t, "sk-test", cfg.APIKey)
 80	require.Equal(t, provider.DefaultHeaders, cfg.ExtraHeaders,
 81		"DefaultHeaders must be propagated to ExtraHeaders")
 82
 83	// Mutating the returned config must not leak back into the provider
 84	// definition (the dialog reuses the provider value across retries).
 85	cfg.ExtraHeaders["X-Tenant"] = "mutated"
 86	require.Equal(t, "acme", provider.DefaultHeaders["X-Tenant"],
 87		"providerConfigForVerify must copy DefaultHeaders, not alias them")
 88}
 89
 90func TestProviderConfigForVerifyWithNoDefaultHeaders(t *testing.T) {
 91	t.Parallel()
 92
 93	provider := catwalk.Provider{
 94		ID:          "test-provider",
 95		Type:        catwalk.TypeOpenAICompat,
 96		APIEndpoint: "https://example.invalid",
 97	}
 98	cfg := providerConfigForVerify(provider, "sk-test")
 99	require.Nil(t, cfg.ExtraHeaders,
100		"no DefaultHeaders should yield no ExtraHeaders allocation")
101}
102
103// TestProviderConfigForVerifyHeadersHitTheWire is an end-to-end UI-level
104// check: after providerConfigForVerify builds the probe config, calling
105// TestConnection against a local server must deliver the DefaultHeaders on
106// the outbound request. This guards against silent regressions where the
107// header map is dropped between the dialog and the HTTP layer.
108func TestProviderConfigForVerifyHeadersHitTheWire(t *testing.T) {
109	t.Parallel()
110
111	var captured http.Header
112	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
113		captured = r.Header.Clone()
114		_, _ = io.Copy(io.Discard, r.Body)
115		// 400 on the malformed-body chat-completions probe means "auth
116		// passed, schema rejected" for the Synthetic-style override.
117		w.WriteHeader(http.StatusBadRequest)
118	}))
119	t.Cleanup(srv.Close)
120
121	provider := catwalk.Provider{
122		ID:          catwalk.InferenceProviderSynthetic,
123		Name:        "Synthetic",
124		Type:        catwalk.TypeOpenAICompat,
125		APIEndpoint: srv.URL,
126		DefaultHeaders: map[string]string{
127			"X-Tenant": "acme",
128			"X-Route":  "primary",
129		},
130	}
131	cfg := providerConfigForVerify(provider, "sk-test")
132	require.NoError(t, cfg.TestConnection(config.IdentityResolver()))
133
134	require.NotNil(t, captured, "probe must have reached the test server")
135	require.Equal(t, "acme", captured.Get("X-Tenant"))
136	require.Equal(t, "primary", captured.Get("X-Route"))
137	// Probe-defined headers should still be present alongside the
138	// DefaultHeaders.
139	require.Equal(t, "Bearer sk-test", captured.Get("Authorization"))
140}
141
142// TestAPIKeyInputStatesAreDistinct guards against someone accidentally making
143// APIKeyInputStateUnverified equal to one of the other states (which would
144// silently collapse the "saved (not verified)" path onto "validated" or
145// "invalid").
146func TestAPIKeyInputStatesAreDistinct(t *testing.T) {
147	t.Parallel()
148
149	states := []APIKeyInputState{
150		APIKeyInputStateInitial,
151		APIKeyInputStateVerifying,
152		APIKeyInputStateVerified,
153		APIKeyInputStateUnverified,
154		APIKeyInputStateError,
155	}
156	seen := map[APIKeyInputState]struct{}{}
157	for _, s := range states {
158		_, dup := seen[s]
159		require.False(t, dup, "state %d declared twice", s)
160		seen[s] = struct{}{}
161	}
162}