@@ -131,6 +131,7 @@ func (c *Client) SubscribeEvents(ctx context.Context, id string) (<-chan any, er
go func() {
defer rsp.Body.Close()
+ defer close(events)
scr := bufio.NewReader(rsp.Body)
for {
@@ -139,8 +140,15 @@ func (c *Client) SubscribeEvents(ctx context.Context, id string) (<-chan any, er
break
}
if err != nil {
+ if ctx.Err() != nil {
+ return
+ }
slog.Error("Reading from events stream", "error", err)
- time.Sleep(time.Second * 2)
+ select {
+ case <-time.After(time.Second * 2):
+ case <-ctx.Done():
+ return
+ }
continue
}
line = bytes.TrimSpace(line)
@@ -166,43 +174,63 @@ func (c *Client) SubscribeEvents(ctx context.Context, id string) (<-chan any, er
case pubsub.PayloadTypeLSPEvent:
var e pubsub.Event[proto.LSPEvent]
_ = json.Unmarshal(p.Payload, &e)
- sendEvent(ctx, events, e)
+ if !sendEvent(ctx, events, e) {
+ return
+ }
case pubsub.PayloadTypeMCPEvent:
var e pubsub.Event[proto.MCPEvent]
_ = json.Unmarshal(p.Payload, &e)
- sendEvent(ctx, events, e)
+ if !sendEvent(ctx, events, e) {
+ return
+ }
case pubsub.PayloadTypePermissionRequest:
var e pubsub.Event[proto.PermissionRequest]
_ = json.Unmarshal(p.Payload, &e)
- sendEvent(ctx, events, e)
+ if !sendEvent(ctx, events, e) {
+ return
+ }
case pubsub.PayloadTypePermissionNotification:
var e pubsub.Event[proto.PermissionNotification]
_ = json.Unmarshal(p.Payload, &e)
- sendEvent(ctx, events, e)
+ if !sendEvent(ctx, events, e) {
+ return
+ }
case pubsub.PayloadTypeMessage:
var e pubsub.Event[proto.Message]
_ = json.Unmarshal(p.Payload, &e)
- sendEvent(ctx, events, e)
+ if !sendEvent(ctx, events, e) {
+ return
+ }
case pubsub.PayloadTypeSession:
var e pubsub.Event[proto.Session]
_ = json.Unmarshal(p.Payload, &e)
- sendEvent(ctx, events, e)
+ if !sendEvent(ctx, events, e) {
+ return
+ }
case pubsub.PayloadTypeFile:
var e pubsub.Event[proto.File]
_ = json.Unmarshal(p.Payload, &e)
- sendEvent(ctx, events, e)
+ if !sendEvent(ctx, events, e) {
+ return
+ }
case pubsub.PayloadTypeAgentEvent:
var e pubsub.Event[proto.AgentEvent]
_ = json.Unmarshal(p.Payload, &e)
- sendEvent(ctx, events, e)
+ if !sendEvent(ctx, events, e) {
+ return
+ }
case pubsub.PayloadTypeConfigChanged:
var e pubsub.Event[proto.ConfigChanged]
_ = json.Unmarshal(p.Payload, &e)
- sendEvent(ctx, events, e)
+ if !sendEvent(ctx, events, e) {
+ return
+ }
case pubsub.PayloadTypeSkillsEvent:
var e pubsub.Event[proto.SkillsEvent]
_ = json.Unmarshal(p.Payload, &e)
- sendEvent(ctx, events, e)
+ if !sendEvent(ctx, events, e) {
+ return
+ }
default:
slog.Warn("Unknown event type", "type", p.Type)
continue
@@ -213,12 +241,15 @@ func (c *Client) SubscribeEvents(ctx context.Context, id string) (<-chan any, er
return events, nil
}
-func sendEvent(ctx context.Context, evc chan any, ev any) {
+func sendEvent(ctx context.Context, evc chan any, ev any) bool {
+ if ctx.Err() != nil {
+ return false
+ }
select {
case evc <- ev:
+ return true
case <-ctx.Done():
- close(evc)
- return
+ return false
}
}
@@ -0,0 +1,108 @@
+package client
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+ "time"
+
+ "github.com/charmbracelet/crush/internal/proto"
+ "github.com/charmbracelet/crush/internal/pubsub"
+ "github.com/stretchr/testify/require"
+)
+
+func TestSendEventAfterContextCancelIsIdempotent(t *testing.T) {
+ t.Parallel()
+
+ ctx, cancel := context.WithCancel(context.Background())
+ cancel()
+
+ events := make(chan any, 1)
+ require.False(t, sendEvent(ctx, events, "one"))
+ require.False(t, sendEvent(ctx, events, "two"))
+
+ select {
+ case ev := <-events:
+ require.Failf(t, "unexpected event", "event: %v", ev)
+ default:
+ }
+}
+
+func TestSubscribeEventsContextCancelClosesEvents(t *testing.T) {
+ t.Parallel()
+
+ payload := marshalSSEPayload(t)
+ firstEventSent := make(chan struct{})
+ writeSecondEvent := make(chan struct{})
+
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+ w.Header().Set("Content-Type", "text/event-stream")
+ flusher, ok := w.(http.Flusher)
+ require.True(t, ok)
+
+ _, err := fmt.Fprintf(w, "data: %s\n\n", payload)
+ require.NoError(t, err)
+ flusher.Flush()
+ close(firstEventSent)
+
+ select {
+ case <-writeSecondEvent:
+ case <-time.After(5 * time.Second):
+ return
+ }
+ _, _ = fmt.Fprintf(w, "data: %s\n\n", payload)
+ flusher.Flush()
+ }))
+ defer srv.Close()
+
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
+ c := captureClient(t, srv)
+ events, err := c.SubscribeEvents(ctx, "ws1")
+ require.NoError(t, err)
+
+ select {
+ case <-firstEventSent:
+ case <-time.After(5 * time.Second):
+ require.Fail(t, "timed out waiting for server event")
+ }
+
+ select {
+ case <-events:
+ case <-time.After(5 * time.Second):
+ require.Fail(t, "timed out waiting for first event")
+ }
+
+ cancel()
+ close(writeSecondEvent)
+
+ select {
+ case _, ok := <-events:
+ require.False(t, ok)
+ case <-time.After(5 * time.Second):
+ require.Fail(t, "timed out waiting for event channel close")
+ }
+}
+
+func marshalSSEPayload(t *testing.T) []byte {
+ t.Helper()
+
+ eventPayload, err := json.Marshal(pubsub.Event[proto.AgentEvent]{
+ Type: pubsub.CreatedEvent,
+ Payload: proto.AgentEvent{
+ Type: proto.AgentEventTypeResponse,
+ },
+ })
+ require.NoError(t, err)
+
+ payload, err := json.Marshal(pubsub.Payload{
+ Type: pubsub.PayloadTypeAgentEvent,
+ Payload: eventPayload,
+ })
+ require.NoError(t, err)
+ return payload
+}