config_test.go

  1package client
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"io"
  7	"net/http"
  8	"net/http/httptest"
  9	"net/url"
 10	"testing"
 11
 12	"github.com/charmbracelet/crush/internal/config"
 13	"github.com/charmbracelet/crush/internal/oauth"
 14	"github.com/charmbracelet/crush/internal/proto"
 15	"github.com/stretchr/testify/require"
 16)
 17
 18// captureClient returns a Client that talks to the given test server,
 19// plus a channel receiving the parsed request body for each call.
 20func captureClient(t *testing.T, srv *httptest.Server) *Client {
 21	t.Helper()
 22	u, err := url.Parse(srv.URL)
 23	require.NoError(t, err)
 24	c, err := NewClient(t.TempDir(), "tcp", u.Host)
 25	require.NoError(t, err)
 26	return c
 27}
 28
 29func TestSetProviderAPIKeyStringSendsKind(t *testing.T) {
 30	t.Parallel()
 31
 32	var got proto.ConfigProviderKeyRequest
 33	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 34		body, err := io.ReadAll(r.Body)
 35		require.NoError(t, err)
 36		require.NoError(t, json.Unmarshal(body, &got))
 37		w.WriteHeader(http.StatusOK)
 38	}))
 39	defer srv.Close()
 40
 41	c := captureClient(t, srv)
 42	require.NoError(t, c.SetProviderAPIKey(context.Background(), "ws1", config.ScopeGlobal, "openai", "sk-xyz"))
 43
 44	require.Equal(t, proto.APIKeyKindString, got.Kind)
 45	require.Equal(t, "openai", got.ProviderID)
 46	require.Equal(t, config.ScopeGlobal, got.Scope)
 47	decoded, err := got.DecodeAPIKey()
 48	require.NoError(t, err)
 49	require.Equal(t, "sk-xyz", decoded)
 50}
 51
 52func TestSetProviderAPIKeyOAuthSendsKind(t *testing.T) {
 53	t.Parallel()
 54
 55	var got proto.ConfigProviderKeyRequest
 56	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 57		body, err := io.ReadAll(r.Body)
 58		require.NoError(t, err)
 59		require.NoError(t, json.Unmarshal(body, &got))
 60		w.WriteHeader(http.StatusOK)
 61	}))
 62	defer srv.Close()
 63
 64	tok := &oauth.Token{AccessToken: "a", RefreshToken: "r", ExpiresIn: 60, ExpiresAt: 1234567890}
 65	c := captureClient(t, srv)
 66	require.NoError(t, c.SetProviderAPIKey(context.Background(), "ws1", config.ScopeGlobal, "hyper", tok))
 67
 68	require.Equal(t, proto.APIKeyKindOAuth, got.Kind)
 69	decoded, err := got.DecodeAPIKey()
 70	require.NoError(t, err)
 71	require.Equal(t, tok, decoded.(*oauth.Token))
 72}
 73
 74func TestSetProviderAPIKeyUnsupportedTypeFailsLocally(t *testing.T) {
 75	t.Parallel()
 76
 77	called := false
 78	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
 79		called = true
 80		w.WriteHeader(http.StatusOK)
 81	}))
 82	defer srv.Close()
 83
 84	c := captureClient(t, srv)
 85	err := c.SetProviderAPIKey(context.Background(), "ws1", config.ScopeGlobal, "x", 42)
 86	require.Error(t, err)
 87	require.Contains(t, err.Error(), "unsupported api key type")
 88	require.False(t, called, "server should not have been reached")
 89}
 90
 91func TestSetProviderAPIKeyNilOAuthFailsLocally(t *testing.T) {
 92	t.Parallel()
 93
 94	c := captureClient(t, httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
 95		w.WriteHeader(http.StatusOK)
 96	})))
 97
 98	var tok *oauth.Token
 99	err := c.SetProviderAPIKey(context.Background(), "ws1", config.ScopeGlobal, "x", tok)
100	require.Error(t, err)
101}