fix(client): prevent event subscription panic on cancellation

Christian Rocha and Charm Crush created

Co-Authored-By: Charm Crush <crush@charm.land>

Change summary

internal/client/proto.go      |  59 +++++++++++++++----
internal/client/proto_test.go | 108 +++++++++++++++++++++++++++++++++++++
2 files changed, 153 insertions(+), 14 deletions(-)

Detailed changes

internal/client/proto.go 🔗

@@ -131,6 +131,7 @@ func (c *Client) SubscribeEvents(ctx context.Context, id string) (<-chan any, er
 
 	go func() {
 		defer rsp.Body.Close()
+		defer close(events)
 
 		scr := bufio.NewReader(rsp.Body)
 		for {
@@ -139,8 +140,15 @@ func (c *Client) SubscribeEvents(ctx context.Context, id string) (<-chan any, er
 				break
 			}
 			if err != nil {
+				if ctx.Err() != nil {
+					return
+				}
 				slog.Error("Reading from events stream", "error", err)
-				time.Sleep(time.Second * 2)
+				select {
+				case <-time.After(time.Second * 2):
+				case <-ctx.Done():
+					return
+				}
 				continue
 			}
 			line = bytes.TrimSpace(line)
@@ -166,43 +174,63 @@ func (c *Client) SubscribeEvents(ctx context.Context, id string) (<-chan any, er
 			case pubsub.PayloadTypeLSPEvent:
 				var e pubsub.Event[proto.LSPEvent]
 				_ = json.Unmarshal(p.Payload, &e)
-				sendEvent(ctx, events, e)
+				if !sendEvent(ctx, events, e) {
+					return
+				}
 			case pubsub.PayloadTypeMCPEvent:
 				var e pubsub.Event[proto.MCPEvent]
 				_ = json.Unmarshal(p.Payload, &e)
-				sendEvent(ctx, events, e)
+				if !sendEvent(ctx, events, e) {
+					return
+				}
 			case pubsub.PayloadTypePermissionRequest:
 				var e pubsub.Event[proto.PermissionRequest]
 				_ = json.Unmarshal(p.Payload, &e)
-				sendEvent(ctx, events, e)
+				if !sendEvent(ctx, events, e) {
+					return
+				}
 			case pubsub.PayloadTypePermissionNotification:
 				var e pubsub.Event[proto.PermissionNotification]
 				_ = json.Unmarshal(p.Payload, &e)
-				sendEvent(ctx, events, e)
+				if !sendEvent(ctx, events, e) {
+					return
+				}
 			case pubsub.PayloadTypeMessage:
 				var e pubsub.Event[proto.Message]
 				_ = json.Unmarshal(p.Payload, &e)
-				sendEvent(ctx, events, e)
+				if !sendEvent(ctx, events, e) {
+					return
+				}
 			case pubsub.PayloadTypeSession:
 				var e pubsub.Event[proto.Session]
 				_ = json.Unmarshal(p.Payload, &e)
-				sendEvent(ctx, events, e)
+				if !sendEvent(ctx, events, e) {
+					return
+				}
 			case pubsub.PayloadTypeFile:
 				var e pubsub.Event[proto.File]
 				_ = json.Unmarshal(p.Payload, &e)
-				sendEvent(ctx, events, e)
+				if !sendEvent(ctx, events, e) {
+					return
+				}
 			case pubsub.PayloadTypeAgentEvent:
 				var e pubsub.Event[proto.AgentEvent]
 				_ = json.Unmarshal(p.Payload, &e)
-				sendEvent(ctx, events, e)
+				if !sendEvent(ctx, events, e) {
+					return
+				}
 			case pubsub.PayloadTypeConfigChanged:
 				var e pubsub.Event[proto.ConfigChanged]
 				_ = json.Unmarshal(p.Payload, &e)
-				sendEvent(ctx, events, e)
+				if !sendEvent(ctx, events, e) {
+					return
+				}
 			case pubsub.PayloadTypeSkillsEvent:
 				var e pubsub.Event[proto.SkillsEvent]
 				_ = json.Unmarshal(p.Payload, &e)
-				sendEvent(ctx, events, e)
+				if !sendEvent(ctx, events, e) {
+					return
+				}
 			default:
 				slog.Warn("Unknown event type", "type", p.Type)
 				continue
@@ -213,12 +241,15 @@ func (c *Client) SubscribeEvents(ctx context.Context, id string) (<-chan any, er
 	return events, nil
 }
 
-func sendEvent(ctx context.Context, evc chan any, ev any) {
+func sendEvent(ctx context.Context, evc chan any, ev any) bool {
+	if ctx.Err() != nil {
+		return false
+	}
 	select {
 	case evc <- ev:
+		return true
 	case <-ctx.Done():
-		close(evc)
-		return
+		return false
 	}
 }
 

internal/client/proto_test.go 🔗

@@ -0,0 +1,108 @@
+package client
+
+import (
+	"context"
+	"encoding/json"
+	"fmt"
+	"net/http"
+	"net/http/httptest"
+	"testing"
+	"time"
+
+	"github.com/charmbracelet/crush/internal/proto"
+	"github.com/charmbracelet/crush/internal/pubsub"
+	"github.com/stretchr/testify/require"
+)
+
+func TestSendEventAfterContextCancelIsIdempotent(t *testing.T) {
+	t.Parallel()
+
+	ctx, cancel := context.WithCancel(context.Background())
+	cancel()
+
+	events := make(chan any, 1)
+	require.False(t, sendEvent(ctx, events, "one"))
+	require.False(t, sendEvent(ctx, events, "two"))
+
+	select {
+	case ev := <-events:
+		require.Failf(t, "unexpected event", "event: %v", ev)
+	default:
+	}
+}
+
+func TestSubscribeEventsContextCancelClosesEvents(t *testing.T) {
+	t.Parallel()
+
+	payload := marshalSSEPayload(t)
+	firstEventSent := make(chan struct{})
+	writeSecondEvent := make(chan struct{})
+
+	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+		w.Header().Set("Content-Type", "text/event-stream")
+		flusher, ok := w.(http.Flusher)
+		require.True(t, ok)
+
+		_, err := fmt.Fprintf(w, "data: %s\n\n", payload)
+		require.NoError(t, err)
+		flusher.Flush()
+		close(firstEventSent)
+
+		select {
+		case <-writeSecondEvent:
+		case <-time.After(5 * time.Second):
+			return
+		}
+		_, _ = fmt.Fprintf(w, "data: %s\n\n", payload)
+		flusher.Flush()
+	}))
+	defer srv.Close()
+
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	c := captureClient(t, srv)
+	events, err := c.SubscribeEvents(ctx, "ws1")
+	require.NoError(t, err)
+
+	select {
+	case <-firstEventSent:
+	case <-time.After(5 * time.Second):
+		require.Fail(t, "timed out waiting for server event")
+	}
+
+	select {
+	case <-events:
+	case <-time.After(5 * time.Second):
+		require.Fail(t, "timed out waiting for first event")
+	}
+
+	cancel()
+	close(writeSecondEvent)
+
+	select {
+	case _, ok := <-events:
+		require.False(t, ok)
+	case <-time.After(5 * time.Second):
+		require.Fail(t, "timed out waiting for event channel close")
+	}
+}
+
+func marshalSSEPayload(t *testing.T) []byte {
+	t.Helper()
+
+	eventPayload, err := json.Marshal(pubsub.Event[proto.AgentEvent]{
+		Type: pubsub.CreatedEvent,
+		Payload: proto.AgentEvent{
+			Type: proto.AgentEventTypeResponse,
+		},
+	})
+	require.NoError(t, err)
+
+	payload, err := json.Marshal(pubsub.Payload{
+		Type:    pubsub.PayloadTypeAgentEvent,
+		Payload: eventPayload,
+	})
+	require.NoError(t, err)
+	return payload
+}