1package proto_test
2
3import (
4 "encoding/json"
5 "testing"
6
7 "github.com/charmbracelet/crush/internal/config"
8 "github.com/charmbracelet/crush/internal/oauth"
9 "github.com/charmbracelet/crush/internal/proto"
10 "github.com/stretchr/testify/require"
11)
12
// TestConfigProviderKeyRequestStringRoundTrip checks that a request carrying a
// plain string API key survives a JSON marshal/unmarshal cycle: the scope,
// provider ID and kind are preserved, and DecodeAPIKey yields the original
// string.
func TestConfigProviderKeyRequestStringRoundTrip(t *testing.T) {
	t.Parallel()

	// APIKey carries the key as raw JSON, so the string is JSON-encoded first.
	apiKey, err := json.Marshal("sk-test-123")
	require.NoError(t, err)

	src := proto.ConfigProviderKeyRequest{
		Scope:      config.ScopeGlobal,
		ProviderID: "openai",
		Kind:       proto.APIKeyKindString,
		APIKey:     apiKey,
	}
	b, err := json.Marshal(src)
	require.NoError(t, err)

	var got proto.ConfigProviderKeyRequest
	require.NoError(t, json.Unmarshal(b, &got))
	// A round trip must preserve every field, not only the key material.
	require.Equal(t, src.Scope, got.Scope)
	require.Equal(t, src.ProviderID, got.ProviderID)
	require.Equal(t, proto.APIKeyKindString, got.Kind)

	decoded, err := got.DecodeAPIKey()
	require.NoError(t, err)
	s, ok := decoded.(string)
	require.True(t, ok, "expected string, got %T", decoded)
	require.Equal(t, "sk-test-123", s)
}
38
// TestConfigProviderKeyRequestOAuthRoundTrip checks that a request carrying an
// OAuth token survives a JSON marshal/unmarshal cycle: the scope, provider ID
// and kind are preserved, and DecodeAPIKey yields a *oauth.Token equal to the
// one that was encoded.
func TestConfigProviderKeyRequestOAuthRoundTrip(t *testing.T) {
	t.Parallel()

	tok := &oauth.Token{
		AccessToken:  "access",
		RefreshToken: "refresh",
		ExpiresIn:    60,
		ExpiresAt:    1234567890,
	}
	// APIKey carries the key as raw JSON, so the token is JSON-encoded first.
	apiKey, err := json.Marshal(tok)
	require.NoError(t, err)

	src := proto.ConfigProviderKeyRequest{
		Scope:      config.ScopeGlobal,
		ProviderID: "hyper",
		Kind:       proto.APIKeyKindOAuth,
		APIKey:     apiKey,
	}
	b, err := json.Marshal(src)
	require.NoError(t, err)

	var got proto.ConfigProviderKeyRequest
	require.NoError(t, json.Unmarshal(b, &got))
	// A round trip must preserve every field, not only the key material.
	require.Equal(t, src.Scope, got.Scope)
	require.Equal(t, src.ProviderID, got.ProviderID)
	require.Equal(t, proto.APIKeyKindOAuth, got.Kind)

	decoded, err := got.DecodeAPIKey()
	require.NoError(t, err)
	gotTok, ok := decoded.(*oauth.Token)
	require.True(t, ok, "expected *oauth.Token, got %T", decoded)
	// require.Equal compares the pointed-to tokens, not the pointer addresses.
	require.Equal(t, tok, gotTok)
}
70
71func TestConfigProviderKeyRequestUnknownKind(t *testing.T) {
72 t.Parallel()
73
74 req := proto.ConfigProviderKeyRequest{
75 Kind: proto.APIKeyKind("bogus"),
76 APIKey: json.RawMessage(`"x"`),
77 }
78 _, err := req.DecodeAPIKey()
79 require.Error(t, err)
80 require.Contains(t, err.Error(), "bogus")
81}
82
// TestConfigProviderKeyRequestMalformedPayload verifies that DecodeAPIKey
// returns an error when the raw payload does not match the declared kind or is
// not valid JSON. Both kinds are exercised with a wrong-shaped payload and with
// syntactically broken JSON.
func TestConfigProviderKeyRequestMalformedPayload(t *testing.T) {
	t.Parallel()

	cases := []struct {
		name string
		kind proto.APIKeyKind
		raw  string
	}{
		{name: "string kind with object payload", kind: proto.APIKeyKindString, raw: `{"foo":"bar"}`},
		{name: "string kind with number payload", kind: proto.APIKeyKindString, raw: `123`},
		{name: "string kind with invalid json", kind: proto.APIKeyKindString, raw: `"unterminated`},
		{name: "oauth kind with string payload", kind: proto.APIKeyKindOAuth, raw: `"not-a-token"`},
		{name: "oauth kind with invalid json", kind: proto.APIKeyKindOAuth, raw: `{`},
	}
	for _, tc := range cases {
		// The parallel subtests below capture tc; this relies on Go 1.22+
		// per-iteration loop variables (older toolchains would share one tc).
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			req := proto.ConfigProviderKeyRequest{
				Kind:   tc.kind,
				APIKey: json.RawMessage(tc.raw),
			}
			_, err := req.DecodeAPIKey()
			require.Error(t, err)
		})
	}
}