1package client
2
3import (
4 "context"
5 "encoding/json"
6 "io"
7 "net/http"
8 "net/http/httptest"
9 "net/url"
10 "testing"
11
12 "github.com/charmbracelet/crush/internal/config"
13 "github.com/charmbracelet/crush/internal/oauth"
14 "github.com/charmbracelet/crush/internal/proto"
15 "github.com/stretchr/testify/require"
16)
17
18// captureClient returns a Client that talks to the given test server,
19// plus a channel receiving the parsed request body for each call.
20func captureClient(t *testing.T, srv *httptest.Server) *Client {
21 t.Helper()
22 u, err := url.Parse(srv.URL)
23 require.NoError(t, err)
24 c, err := NewClient(t.TempDir(), "tcp", u.Host)
25 require.NoError(t, err)
26 return c
27}
28
29func TestSetProviderAPIKeyStringSendsKind(t *testing.T) {
30 t.Parallel()
31
32 var got proto.ConfigProviderKeyRequest
33 srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
34 body, err := io.ReadAll(r.Body)
35 require.NoError(t, err)
36 require.NoError(t, json.Unmarshal(body, &got))
37 w.WriteHeader(http.StatusOK)
38 }))
39 defer srv.Close()
40
41 c := captureClient(t, srv)
42 require.NoError(t, c.SetProviderAPIKey(context.Background(), "ws1", config.ScopeGlobal, "openai", "sk-xyz"))
43
44 require.Equal(t, proto.APIKeyKindString, got.Kind)
45 require.Equal(t, "openai", got.ProviderID)
46 require.Equal(t, config.ScopeGlobal, got.Scope)
47 decoded, err := got.DecodeAPIKey()
48 require.NoError(t, err)
49 require.Equal(t, "sk-xyz", decoded)
50}
51
52func TestSetProviderAPIKeyOAuthSendsKind(t *testing.T) {
53 t.Parallel()
54
55 var got proto.ConfigProviderKeyRequest
56 srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
57 body, err := io.ReadAll(r.Body)
58 require.NoError(t, err)
59 require.NoError(t, json.Unmarshal(body, &got))
60 w.WriteHeader(http.StatusOK)
61 }))
62 defer srv.Close()
63
64 tok := &oauth.Token{AccessToken: "a", RefreshToken: "r", ExpiresIn: 60, ExpiresAt: 1234567890}
65 c := captureClient(t, srv)
66 require.NoError(t, c.SetProviderAPIKey(context.Background(), "ws1", config.ScopeGlobal, "hyper", tok))
67
68 require.Equal(t, proto.APIKeyKindOAuth, got.Kind)
69 decoded, err := got.DecodeAPIKey()
70 require.NoError(t, err)
71 require.Equal(t, tok, decoded.(*oauth.Token))
72}
73
74func TestSetProviderAPIKeyUnsupportedTypeFailsLocally(t *testing.T) {
75 t.Parallel()
76
77 called := false
78 srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
79 called = true
80 w.WriteHeader(http.StatusOK)
81 }))
82 defer srv.Close()
83
84 c := captureClient(t, srv)
85 err := c.SetProviderAPIKey(context.Background(), "ws1", config.ScopeGlobal, "x", 42)
86 require.Error(t, err)
87 require.Contains(t, err.Error(), "unsupported api key type")
88 require.False(t, called, "server should not have been reached")
89}
90
91func TestSetProviderAPIKeyNilOAuthFailsLocally(t *testing.T) {
92 t.Parallel()
93
94 c := captureClient(t, httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
95 w.WriteHeader(http.StatusOK)
96 })))
97
98 var tok *oauth.Token
99 err := c.SetProviderAPIKey(context.Background(), "ws1", config.ScopeGlobal, "x", tok)
100 require.Error(t, err)
101}