diff --git a/internal/agent/notify/notify.go b/internal/agent/notify/notify.go index 1a217a6d00650fe1134b24d9d779821015513063..22e9f17769b5585302a195049bb3abca919f9a91 100644 --- a/internal/agent/notify/notify.go +++ b/internal/agent/notify/notify.go @@ -23,6 +23,12 @@ type Notification struct { SessionTitle string Type Type ProviderID string + // RunID, when non-empty, is the caller-supplied correlator + // (proto.AgentMessage.RunID) for the run that produced this + // notification. It lets observers attribute a TypeAgentError to a + // specific request rather than to any in-flight run on the + // session. Empty when no caller set one. + RunID string // Message carries the error text for TypeAgentError. Other // notification types ignore it. Message string diff --git a/internal/backend/agent.go b/internal/backend/agent.go index 2dd0479d3236d55e3919bdef1f16bb593fe5684e..4af3b8f0d2f88ad5daff41f40664b303c948b263 100644 --- a/internal/backend/agent.go +++ b/internal/backend/agent.go @@ -85,6 +85,7 @@ func (b *Backend) runAgent(ws *Workspace, msg proto.AgentMessage, accept *agent. ws.AgentNotifications().Publish(pubsub.CreatedEvent, notify.Notification{ SessionID: msg.SessionID, + RunID: msg.RunID, Type: notify.TypeAgentError, Message: err.Error(), }) diff --git a/internal/cmd/run.go b/internal/cmd/run.go index 87fa32606674847741a9d028a26375fb98935fc4..2feeba78e6f4862e453fcb790e428b3e08ab0505 100644 --- a/internal/cmd/run.go +++ b/internal/cmd/run.go @@ -409,11 +409,27 @@ func (s *runStream) handle(ev any, stopSpinner func()) (done bool, err error) { return true, nil case pubsub.Event[proto.AgentEvent]: - if e.Payload.Error != nil { - stop() - return true, fmt.Errorf("agent error: %w", e.Payload.Error) + if e.Payload.Error == nil { + return false, nil } - return false, nil + // Attribute the error to our run before treating it as + // fatal. Async errors from an unrelated workspace run share + // this channel, so a foreign failure must not abort us: + // - if the event carries a RunID, it is the authoritative + // correlator: it must match our run exactly, otherwise it + // belongs to a different request and we ignore it. + // - if the event carries no RunID (older server), fall back + // to SessionID: it must be present and match our session, + // otherwise we ignore it. + if e.Payload.RunID != "" { + if e.Payload.RunID != s.runID { + return false, nil + } + } else if e.Payload.SessionID == "" || e.Payload.SessionID != s.sessionID { + return false, nil + } + stop() + return true, fmt.Errorf("agent error: %w", e.Payload.Error) } return false, nil } diff --git a/internal/cmd/run_stream_test.go b/internal/cmd/run_stream_test.go index ac168fa77045aa6aa5761b6f9c657f066c952734..028eb03baa0dc7a55a0037e67f033b708ff9634e 100644 --- a/internal/cmd/run_stream_test.go +++ b/internal/cmd/run_stream_test.go @@ -2,6 +2,7 @@ package cmd import ( "bytes" + "errors" "testing" "time" @@ -307,6 +308,93 @@ func TestRunStream_RunIDSuppressesLiveMessagesAndPrintsRunComplete(t *testing.T) require.Equal(t, "streamed prefix final", buf.String()) } +// TestRunStream_AgentErrorRunIDFiltersForeign verifies that an async +// agent error carrying a non-empty RunID is fatal only when it matches +// our run. A foreign RunID is ignored regardless of the event's +// SessionID, because RunID is the authoritative correlator and async +// errors share the agent event channel: without strict RunID matching +// an unrelated workspace failure would abort our run. +func TestRunStream_AgentErrorRunIDFiltersForeign(t *testing.T) { + t.Parallel() + + // Foreign RunID with a matching session is still foreign. + s := &runStream{sessionID: "S", runID: "run-mine", out: &bytes.Buffer{}, read: map[string]int{}} + done, err := s.handle(pubsub.Event[proto.AgentEvent]{Payload: proto.AgentEvent{ + Type: proto.AgentEventTypeError, + SessionID: "S", + RunID: "run-other", + Error: errors.New("foreign boom"), + }}, nil) + require.NoError(t, err, "foreign RunID error must not abort our run") + require.False(t, done) + + // Foreign RunID with a different session is ignored. + done, err = s.handle(pubsub.Event[proto.AgentEvent]{Payload: proto.AgentEvent{ + Type: proto.AgentEventTypeError, + SessionID: "other", + RunID: "run-other", + Error: errors.New("foreign boom"), + }}, nil) + require.NoError(t, err, "foreign RunID error must not abort our run") + require.False(t, done) + + // Foreign RunID with a missing session is ignored. + done, err = s.handle(pubsub.Event[proto.AgentEvent]{Payload: proto.AgentEvent{ + Type: proto.AgentEventTypeError, + RunID: "run-other", + Error: errors.New("foreign boom"), + }}, nil) + require.NoError(t, err, "foreign RunID error must not abort our run") + require.False(t, done) + + // Matching RunID is fatal. + done, err = s.handle(pubsub.Event[proto.AgentEvent]{Payload: proto.AgentEvent{ + Type: proto.AgentEventTypeError, + SessionID: "S", + RunID: "run-mine", + Error: errors.New("my boom"), + }}, nil) + require.Error(t, err, "matching RunID error must be fatal") + require.True(t, done) +} + +// TestRunStream_AgentErrorNoRunIDFiltersBySession verifies the +// compatibility fallback: when the event carries no RunID, attribution +// falls back to SessionID. An error for another session or with an +// empty session is ignored, while an error for our own session is fatal +// so a real failure is never dropped. +func TestRunStream_AgentErrorNoRunIDFiltersBySession(t *testing.T) { + t.Parallel() + + s := &runStream{sessionID: "S", out: &bytes.Buffer{}, read: map[string]int{}} + + // Empty RunID for another session is ignored. + done, err := s.handle(pubsub.Event[proto.AgentEvent]{Payload: proto.AgentEvent{ + Type: proto.AgentEventTypeError, + SessionID: "other", + Error: errors.New("foreign boom"), + }}, nil) + require.NoError(t, err, "error for another session must not abort our run") + require.False(t, done) + + // Empty RunID with an empty session is ignored. + done, err = s.handle(pubsub.Event[proto.AgentEvent]{Payload: proto.AgentEvent{ + Type: proto.AgentEventTypeError, + Error: errors.New("foreign boom"), + }}, nil) + require.NoError(t, err, "error with no session must not abort our run") + require.False(t, done) + + // Empty RunID with a matching session is fatal. + done, err = s.handle(pubsub.Event[proto.AgentEvent]{Payload: proto.AgentEvent{ + Type: proto.AgentEventTypeError, + SessionID: "S", + Error: errors.New("my boom"), + }}, nil) + require.Error(t, err, "error for our own session must be fatal") + require.True(t, done) +} + // TestRunStream_NoRunIDFallsBackToSessionID preserves the older // behaviour for callers (and tests) that don't supply a RunID: // SessionID-only matching still terminates the stream on the diff --git a/internal/proto/agent.go b/internal/proto/agent.go index e5266e52614a5bc43065ff62cf18b16f8ee7401f..2c85923e547b6755357479218f9ff4815e491527 100644 --- a/internal/proto/agent.go +++ b/internal/proto/agent.go @@ -31,6 +31,13 @@ type AgentEvent struct { Message Message `json:"message"` Error error `json:"error,omitempty"` + // RunID echoes the caller-supplied AgentMessage.RunID for the run + // that produced this event. It lets observers (notably + // `crush run`) attribute an error event to a specific request + // instead of to any in-flight run on the session. Empty when no + // caller set one. + RunID string `json:"run_id,omitempty"` + // When summarizing. SessionID string `json:"session_id,omitempty"` SessionTitle string `json:"session_title,omitempty"` diff --git a/internal/server/events.go b/internal/server/events.go index fd085c5a415c0ef0fc402673ad23fff8435f1db6..526f9e195009cd70c453958778fb98887aae4a37 100644 --- a/internal/server/events.go +++ b/internal/server/events.go @@ -89,6 +89,7 @@ func wrapEvent(ev any) *pubsub.Payload { payload := proto.AgentEvent{ SessionID: e.Payload.SessionID, SessionTitle: e.Payload.SessionTitle, + RunID: e.Payload.RunID, Type: proto.AgentEventType(e.Payload.Type), } if e.Payload.Type == notify.TypeAgentError { diff --git a/internal/server/events_test.go b/internal/server/events_test.go index 432bc42f910b4acec675baea46754b81defab9f6..e4238a05eb3abf50e13329acfaabd2cb77dd464c 100644 --- a/internal/server/events_test.go +++ b/internal/server/events_test.go @@ -123,6 +123,38 @@ func TestRunCompleteToProto_RoundTrip(t *testing.T) { require.False(t, decoded.Payload.Cancelled) } +// TestAgentErrorToProto_PreservesRunID verifies that an async agent +// error notification carries its originating RunID (and SessionID) +// through the SSE envelope. Without these correlators, `crush run` +// cannot tell whether an error event belongs to its own run and +// would abort on any unrelated workspace failure. +func TestAgentErrorToProto_PreservesRunID(t *testing.T) { + t.Parallel() + + src := pubsub.Event[notify.Notification]{ + Type: pubsub.CreatedEvent, + Payload: notify.Notification{ + SessionID: "S", + RunID: "run-99", + Type: notify.TypeAgentError, + Message: "boom", + }, + } + + env := wrapEvent(src) + require.NotNil(t, env) + require.Equal(t, pubsub.PayloadTypeAgentEvent, env.Type) + + var decoded pubsub.Event[proto.AgentEvent] + require.NoError(t, json.Unmarshal(env.Payload, &decoded)) + require.Equal(t, proto.AgentEventTypeError, decoded.Payload.Type) + require.Equal(t, "S", decoded.Payload.SessionID) + require.Equal(t, "run-99", decoded.Payload.RunID, + "RunID must survive so observers can attribute the error to its run") + require.NotNil(t, decoded.Payload.Error) + require.Equal(t, "boom", decoded.Payload.Error.Error()) +} + // TestRunCompleteToProto_Error verifies that error- and cancel-shaped // RunComplete events round-trip cleanly so clients can distinguish // "agent failed" (returns non-zero from `crush run`) from "agent diff --git a/internal/workspace/client_workspace.go b/internal/workspace/client_workspace.go index a6a43731675698083671cae95f983d7a3a724a5d..609a9145bd3a5374c6fbaf96b3a7549187b146d5 100644 --- a/internal/workspace/client_workspace.go +++ b/internal/workspace/client_workspace.go @@ -706,6 +706,7 @@ func (w *ClientWorkspace) translateEvent(ev any) tea.Msg { n := notify.Notification{ SessionID: e.Payload.SessionID, SessionTitle: e.Payload.SessionTitle, + RunID: e.Payload.RunID, Type: notify.Type(e.Payload.Type), } if e.Payload.Error != nil {