1package client
2
3import (
4 "context"
5 "encoding/json"
6 "fmt"
7 "net/http"
8 "net/http/httptest"
9 "testing"
10 "time"
11
12 "github.com/charmbracelet/crush/internal/proto"
13 "github.com/charmbracelet/crush/internal/pubsub"
14 "github.com/stretchr/testify/require"
15)
16
17func TestSendEventAfterContextCancelIsIdempotent(t *testing.T) {
18 t.Parallel()
19
20 ctx, cancel := context.WithCancel(context.Background())
21 cancel()
22
23 events := make(chan any, 1)
24 require.False(t, sendEvent(ctx, events, "one"))
25 require.False(t, sendEvent(ctx, events, "two"))
26
27 select {
28 case ev := <-events:
29 require.Failf(t, "unexpected event", "event: %v", ev)
30 default:
31 }
32}
33
34func TestSubscribeEventsContextCancelClosesEvents(t *testing.T) {
35 t.Parallel()
36
37 payload := marshalSSEPayload(t)
38 firstEventSent := make(chan struct{})
39 writeSecondEvent := make(chan struct{})
40
41 srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
42 w.Header().Set("Content-Type", "text/event-stream")
43 flusher, ok := w.(http.Flusher)
44 require.True(t, ok)
45
46 _, err := fmt.Fprintf(w, "data: %s\n\n", payload)
47 require.NoError(t, err)
48 flusher.Flush()
49 close(firstEventSent)
50
51 select {
52 case <-writeSecondEvent:
53 case <-time.After(5 * time.Second):
54 return
55 }
56 _, _ = fmt.Fprintf(w, "data: %s\n\n", payload)
57 flusher.Flush()
58 }))
59 defer srv.Close()
60
61 ctx, cancel := context.WithCancel(context.Background())
62 defer cancel()
63
64 c := captureClient(t, srv)
65 events, err := c.SubscribeEvents(ctx, "ws1")
66 require.NoError(t, err)
67
68 select {
69 case <-firstEventSent:
70 case <-time.After(5 * time.Second):
71 require.Fail(t, "timed out waiting for server event")
72 }
73
74 select {
75 case <-events:
76 case <-time.After(5 * time.Second):
77 require.Fail(t, "timed out waiting for first event")
78 }
79
80 cancel()
81 close(writeSecondEvent)
82
83 select {
84 case _, ok := <-events:
85 require.False(t, ok)
86 case <-time.After(5 * time.Second):
87 require.Fail(t, "timed out waiting for event channel close")
88 }
89}
90
91func marshalSSEPayload(t *testing.T) []byte {
92 t.Helper()
93
94 eventPayload, err := json.Marshal(pubsub.Event[proto.AgentEvent]{
95 Type: pubsub.CreatedEvent,
96 Payload: proto.AgentEvent{
97 Type: proto.AgentEventTypeResponse,
98 },
99 })
100 require.NoError(t, err)
101
102 payload, err := json.Marshal(pubsub.Payload{
103 Type: pubsub.PayloadTypeAgentEvent,
104 Payload: eventPayload,
105 })
106 require.NoError(t, err)
107 return payload
108}