requests_test.go

  1package proto_test
  2
  3import (
  4	"encoding/json"
  5	"testing"
  6
  7	"github.com/charmbracelet/crush/internal/config"
  8	"github.com/charmbracelet/crush/internal/oauth"
  9	"github.com/charmbracelet/crush/internal/proto"
 10	"github.com/stretchr/testify/require"
 11)
 12
 13func TestConfigProviderKeyRequestStringRoundTrip(t *testing.T) {
 14	t.Parallel()
 15
 16	apiKey, err := json.Marshal("sk-test-123")
 17	require.NoError(t, err)
 18
 19	src := proto.ConfigProviderKeyRequest{
 20		Scope:      config.ScopeGlobal,
 21		ProviderID: "openai",
 22		Kind:       proto.APIKeyKindString,
 23		APIKey:     apiKey,
 24	}
 25	b, err := json.Marshal(src)
 26	require.NoError(t, err)
 27
 28	var got proto.ConfigProviderKeyRequest
 29	require.NoError(t, json.Unmarshal(b, &got))
 30	require.Equal(t, proto.APIKeyKindString, got.Kind)
 31
 32	decoded, err := got.DecodeAPIKey()
 33	require.NoError(t, err)
 34	s, ok := decoded.(string)
 35	require.True(t, ok, "expected string, got %T", decoded)
 36	require.Equal(t, "sk-test-123", s)
 37}
 38
 39func TestConfigProviderKeyRequestOAuthRoundTrip(t *testing.T) {
 40	t.Parallel()
 41
 42	tok := &oauth.Token{
 43		AccessToken:  "access",
 44		RefreshToken: "refresh",
 45		ExpiresIn:    60,
 46		ExpiresAt:    1234567890,
 47	}
 48	apiKey, err := json.Marshal(tok)
 49	require.NoError(t, err)
 50
 51	src := proto.ConfigProviderKeyRequest{
 52		Scope:      config.ScopeGlobal,
 53		ProviderID: "hyper",
 54		Kind:       proto.APIKeyKindOAuth,
 55		APIKey:     apiKey,
 56	}
 57	b, err := json.Marshal(src)
 58	require.NoError(t, err)
 59
 60	var got proto.ConfigProviderKeyRequest
 61	require.NoError(t, json.Unmarshal(b, &got))
 62	require.Equal(t, proto.APIKeyKindOAuth, got.Kind)
 63
 64	decoded, err := got.DecodeAPIKey()
 65	require.NoError(t, err)
 66	gotTok, ok := decoded.(*oauth.Token)
 67	require.True(t, ok, "expected *oauth.Token, got %T", decoded)
 68	require.Equal(t, tok, gotTok)
 69}
 70
 71func TestConfigProviderKeyRequestUnknownKind(t *testing.T) {
 72	t.Parallel()
 73
 74	req := proto.ConfigProviderKeyRequest{
 75		Kind:   proto.APIKeyKind("bogus"),
 76		APIKey: json.RawMessage(`"x"`),
 77	}
 78	_, err := req.DecodeAPIKey()
 79	require.Error(t, err)
 80	require.Contains(t, err.Error(), "bogus")
 81}
 82
 83func TestConfigProviderKeyRequestMalformedPayload(t *testing.T) {
 84	t.Parallel()
 85
 86	cases := []struct {
 87		name string
 88		kind proto.APIKeyKind
 89		raw  string
 90	}{
 91		{"string kind with object payload", proto.APIKeyKindString, `{"foo":"bar"}`},
 92		{"oauth kind with string payload", proto.APIKeyKindOAuth, `"not-a-token"`},
 93		{"oauth kind with invalid json", proto.APIKeyKindOAuth, `{`},
 94	}
 95	for _, tc := range cases {
 96		t.Run(tc.name, func(t *testing.T) {
 97			t.Parallel()
 98			req := proto.ConfigProviderKeyRequest{
 99				Kind:   tc.kind,
100				APIKey: json.RawMessage(tc.raw),
101			}
102			_, err := req.DecodeAPIKey()
103			require.Error(t, err)
104		})
105	}
106}