proto_test.go

  1package client
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"fmt"
  7	"net/http"
  8	"net/http/httptest"
  9	"testing"
 10	"time"
 11
 12	"github.com/charmbracelet/crush/internal/proto"
 13	"github.com/charmbracelet/crush/internal/pubsub"
 14	"github.com/stretchr/testify/require"
 15)
 16
 17func TestSendEventAfterContextCancelIsIdempotent(t *testing.T) {
 18	t.Parallel()
 19
 20	ctx, cancel := context.WithCancel(context.Background())
 21	cancel()
 22
 23	events := make(chan any, 1)
 24	require.False(t, sendEvent(ctx, events, "one"))
 25	require.False(t, sendEvent(ctx, events, "two"))
 26
 27	select {
 28	case ev := <-events:
 29		require.Failf(t, "unexpected event", "event: %v", ev)
 30	default:
 31	}
 32}
 33
 34func TestSubscribeEventsContextCancelClosesEvents(t *testing.T) {
 35	t.Parallel()
 36
 37	payload := marshalSSEPayload(t)
 38	firstEventSent := make(chan struct{})
 39	writeSecondEvent := make(chan struct{})
 40
 41	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
 42		w.Header().Set("Content-Type", "text/event-stream")
 43		flusher, ok := w.(http.Flusher)
 44		require.True(t, ok)
 45
 46		_, err := fmt.Fprintf(w, "data: %s\n\n", payload)
 47		require.NoError(t, err)
 48		flusher.Flush()
 49		close(firstEventSent)
 50
 51		select {
 52		case <-writeSecondEvent:
 53		case <-time.After(5 * time.Second):
 54			return
 55		}
 56		_, _ = fmt.Fprintf(w, "data: %s\n\n", payload)
 57		flusher.Flush()
 58	}))
 59	defer srv.Close()
 60
 61	ctx, cancel := context.WithCancel(context.Background())
 62	defer cancel()
 63
 64	c := captureClient(t, srv)
 65	events, err := c.SubscribeEvents(ctx, "ws1")
 66	require.NoError(t, err)
 67
 68	select {
 69	case <-firstEventSent:
 70	case <-time.After(5 * time.Second):
 71		require.Fail(t, "timed out waiting for server event")
 72	}
 73
 74	select {
 75	case <-events:
 76	case <-time.After(5 * time.Second):
 77		require.Fail(t, "timed out waiting for first event")
 78	}
 79
 80	cancel()
 81	close(writeSecondEvent)
 82
 83	select {
 84	case _, ok := <-events:
 85		require.False(t, ok)
 86	case <-time.After(5 * time.Second):
 87		require.Fail(t, "timed out waiting for event channel close")
 88	}
 89}
 90
 91func TestSendMessageAcceptsStatusAccepted(t *testing.T) {
 92	t.Parallel()
 93
 94	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
 95		w.WriteHeader(http.StatusAccepted)
 96	}))
 97	defer srv.Close()
 98
 99	c := captureClient(t, srv)
100	require.NoError(t, c.SendMessage(context.Background(), "ws1", "sess1", "", "hello"))
101}
102
103func TestSendMessageAcceptsStatusOK(t *testing.T) {
104	t.Parallel()
105
106	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
107		w.WriteHeader(http.StatusOK)
108	}))
109	defer srv.Close()
110
111	c := captureClient(t, srv)
112	require.NoError(t, c.SendMessage(context.Background(), "ws1", "sess1", "", "hello"))
113}
114
115func TestSendMessageDecodesErrorBody(t *testing.T) {
116	t.Parallel()
117
118	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
119		w.WriteHeader(http.StatusBadRequest)
120		_ = json.NewEncoder(w).Encode(proto.Error{Message: "session id is required"})
121	}))
122	defer srv.Close()
123
124	c := captureClient(t, srv)
125	err := c.SendMessage(context.Background(), "ws1", "", "", "hello")
126	require.Error(t, err)
127	require.Contains(t, err.Error(), "status code 400")
128	require.Contains(t, err.Error(), "session id is required")
129}
130
131func TestSendMessageFallsBackOnMalformedErrorBody(t *testing.T) {
132	t.Parallel()
133
134	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
135		w.WriteHeader(http.StatusInternalServerError)
136		_, _ = w.Write([]byte("not json"))
137	}))
138	defer srv.Close()
139
140	c := captureClient(t, srv)
141	err := c.SendMessage(context.Background(), "ws1", "sess1", "", "hello")
142	require.Error(t, err)
143	require.Contains(t, err.Error(), "status code 500")
144	require.NotContains(t, err.Error(), "not json")
145}
146
147func TestSendMessageFallsBackOnEmptyErrorBody(t *testing.T) {
148	t.Parallel()
149
150	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
151		w.WriteHeader(http.StatusInternalServerError)
152	}))
153	defer srv.Close()
154
155	c := captureClient(t, srv)
156	err := c.SendMessage(context.Background(), "ws1", "sess1", "", "hello")
157	require.Error(t, err)
158	require.Contains(t, err.Error(), "status code 500")
159}
160
161func marshalSSEPayload(t *testing.T) []byte {
162	t.Helper()
163
164	eventPayload, err := json.Marshal(pubsub.Event[proto.AgentEvent]{
165		Type: pubsub.CreatedEvent,
166		Payload: proto.AgentEvent{
167			Type: proto.AgentEventTypeResponse,
168		},
169	})
170	require.NoError(t, err)
171
172	payload, err := json.Marshal(pubsub.Payload{
173		Type:    pubsub.PayloadTypeAgentEvent,
174		Payload: eventPayload,
175	})
176	require.NoError(t, err)
177	return payload
178}