From ed870ecb91c53a98a9114d930cf62fa9c9170993 Mon Sep 17 00:00:00 2001 From: Christian Rocha Date: Tue, 26 May 2026 19:39:32 -0400 Subject: [PATCH] fix(client): prevent event subscription panic on cancellation Co-Authored-By: Charm Crush --- internal/client/proto.go | 59 ++++++++++++++----- internal/client/proto_test.go | 108 ++++++++++++++++++++++++++++++++++ 2 files changed, 153 insertions(+), 14 deletions(-) create mode 100644 internal/client/proto_test.go diff --git a/internal/client/proto.go b/internal/client/proto.go index 5a57262679df0a7edc3f9269dc763fc124e778cd..f5f2f4273aba69526fdb66c6a1e0230c554be456 100644 --- a/internal/client/proto.go +++ b/internal/client/proto.go @@ -131,6 +131,7 @@ func (c *Client) SubscribeEvents(ctx context.Context, id string) (<-chan any, er go func() { defer rsp.Body.Close() + defer close(events) scr := bufio.NewReader(rsp.Body) for { @@ -139,8 +140,15 @@ func (c *Client) SubscribeEvents(ctx context.Context, id string) (<-chan any, er break } if err != nil { + if ctx.Err() != nil { + return + } slog.Error("Reading from events stream", "error", err) - time.Sleep(time.Second * 2) + select { + case <-time.After(time.Second * 2): + case <-ctx.Done(): + return + } continue } line = bytes.TrimSpace(line) @@ -166,43 +174,63 @@ func (c *Client) SubscribeEvents(ctx context.Context, id string) (<-chan any, er case pubsub.PayloadTypeLSPEvent: var e pubsub.Event[proto.LSPEvent] _ = json.Unmarshal(p.Payload, &e) - sendEvent(ctx, events, e) + if !sendEvent(ctx, events, e) { + return + } case pubsub.PayloadTypeMCPEvent: var e pubsub.Event[proto.MCPEvent] _ = json.Unmarshal(p.Payload, &e) - sendEvent(ctx, events, e) + if !sendEvent(ctx, events, e) { + return + } case pubsub.PayloadTypePermissionRequest: var e pubsub.Event[proto.PermissionRequest] _ = json.Unmarshal(p.Payload, &e) - sendEvent(ctx, events, e) + if !sendEvent(ctx, events, e) { + return + } case pubsub.PayloadTypePermissionNotification: var e pubsub.Event[proto.PermissionNotification] _ = json.Unmarshal(p.Payload, &e) - sendEvent(ctx, events, e) + if !sendEvent(ctx, events, e) { + return + } case pubsub.PayloadTypeMessage: var e pubsub.Event[proto.Message] _ = json.Unmarshal(p.Payload, &e) - sendEvent(ctx, events, e) + if !sendEvent(ctx, events, e) { + return + } case pubsub.PayloadTypeSession: var e pubsub.Event[proto.Session] _ = json.Unmarshal(p.Payload, &e) - sendEvent(ctx, events, e) + if !sendEvent(ctx, events, e) { + return + } case pubsub.PayloadTypeFile: var e pubsub.Event[proto.File] _ = json.Unmarshal(p.Payload, &e) - sendEvent(ctx, events, e) + if !sendEvent(ctx, events, e) { + return + } case pubsub.PayloadTypeAgentEvent: var e pubsub.Event[proto.AgentEvent] _ = json.Unmarshal(p.Payload, &e) - sendEvent(ctx, events, e) + if !sendEvent(ctx, events, e) { + return + } case pubsub.PayloadTypeConfigChanged: var e pubsub.Event[proto.ConfigChanged] _ = json.Unmarshal(p.Payload, &e) - sendEvent(ctx, events, e) + if !sendEvent(ctx, events, e) { + return + } case pubsub.PayloadTypeSkillsEvent: var e pubsub.Event[proto.SkillsEvent] _ = json.Unmarshal(p.Payload, &e) - sendEvent(ctx, events, e) + if !sendEvent(ctx, events, e) { + return + } default: slog.Warn("Unknown event type", "type", p.Type) continue @@ -213,12 +241,15 @@ func (c *Client) SubscribeEvents(ctx context.Context, id string) (<-chan any, er return events, nil } -func sendEvent(ctx context.Context, evc chan any, ev any) { +func sendEvent(ctx context.Context, evc chan any, ev any) bool { + if ctx.Err() != nil { + return false + } select { case evc <- ev: + return true case <-ctx.Done(): - close(evc) - return + return false } } diff --git a/internal/client/proto_test.go b/internal/client/proto_test.go new file mode 100644 index 0000000000000000000000000000000000000000..b5739ccc91c16b2bb0fc3c3f6dc2281687bd8e65 --- /dev/null +++ b/internal/client/proto_test.go @@ -0,0 +1,108 @@ +package client + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/charmbracelet/crush/internal/proto" + "github.com/charmbracelet/crush/internal/pubsub" + "github.com/stretchr/testify/require" +) + +func TestSendEventAfterContextCancelIsIdempotent(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + events := make(chan any, 1) + require.False(t, sendEvent(ctx, events, "one")) + require.False(t, sendEvent(ctx, events, "two")) + + select { + case ev := <-events: + require.Failf(t, "unexpected event", "event: %v", ev) + default: + } +} + +func TestSubscribeEventsContextCancelClosesEvents(t *testing.T) { + t.Parallel() + + payload := marshalSSEPayload(t) + firstEventSent := make(chan struct{}) + writeSecondEvent := make(chan struct{}) + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + flusher, ok := w.(http.Flusher) + require.True(t, ok) + + _, err := fmt.Fprintf(w, "data: %s\n\n", payload) + require.NoError(t, err) + flusher.Flush() + close(firstEventSent) + + select { + case <-writeSecondEvent: + case <-time.After(5 * time.Second): + return + } + _, _ = fmt.Fprintf(w, "data: %s\n\n", payload) + flusher.Flush() + })) + defer srv.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + c := captureClient(t, srv) + events, err := c.SubscribeEvents(ctx, "ws1") + require.NoError(t, err) + + select { + case <-firstEventSent: + case <-time.After(5 * time.Second): + require.Fail(t, "timed out waiting for server event") + } + + select { + case <-events: + case <-time.After(5 * time.Second): + require.Fail(t, "timed out waiting for first event") + } + + cancel() + close(writeSecondEvent) + + select { + case _, ok := <-events: + require.False(t, ok) + case <-time.After(5 * time.Second): + require.Fail(t, "timed out waiting for event channel close") + } +} + +func marshalSSEPayload(t *testing.T) []byte { + t.Helper() + + eventPayload, err := json.Marshal(pubsub.Event[proto.AgentEvent]{ + Type: pubsub.CreatedEvent, + Payload: proto.AgentEvent{ + Type: proto.AgentEventTypeResponse, + }, + }) + require.NoError(t, err) + + payload, err := json.Marshal(pubsub.Payload{ + Type: pubsub.PayloadTypeAgentEvent, + Payload: eventPayload, + }) + require.NoError(t, err) + return payload +}