diff --git a/internal/client/proto.go b/internal/client/proto.go index 62a43b5884e01ae8fcd3242c68e95d1f76251c42..d07e46dc84bf09dccffbd609784f92c7ae9a9c67 100644 --- a/internal/client/proto.go +++ b/internal/client/proto.go @@ -423,12 +423,28 @@ func (c *Client) SendMessage(ctx context.Context, id string, sessionID, runID, p return fmt.Errorf("failed to send message to agent: %w", err) } defer rsp.Body.Close() - if rsp.StatusCode != http.StatusOK { + if rsp.StatusCode != http.StatusOK && rsp.StatusCode != http.StatusAccepted { + if msg := decodeErrorMessage(rsp.Body); msg != "" { + return fmt.Errorf("failed to send message to agent: status code %d: %s", rsp.StatusCode, msg) + } return fmt.Errorf("failed to send message to agent: status code %d", rsp.StatusCode) } return nil } +// decodeErrorMessage attempts to decode the response body as a +// proto.Error and returns its message. It returns an empty string +// when the body is empty or cannot be decoded into a proto.Error +// with a non-empty message, letting callers fall back to a +// status-only error. +func decodeErrorMessage(body io.Reader) string { + var e proto.Error + if err := json.NewDecoder(body).Decode(&e); err != nil { + return "" + } + return e.Message +} + // GetAgentSessionInfo retrieves the agent session info for a workspace. func (c *Client) GetAgentSessionInfo(ctx context.Context, id string, sessionID string) (*proto.AgentSession, error) { rsp, err := c.get(ctx, fmt.Sprintf("/workspaces/%s/agent/sessions/%s", id, sessionID), nil, nil) diff --git a/internal/client/proto_test.go b/internal/client/proto_test.go index b5739ccc91c16b2bb0fc3c3f6dc2281687bd8e65..c7abd3e03d4ae6f575079c7c938369d6cb7cc30b 100644 --- a/internal/client/proto_test.go +++ b/internal/client/proto_test.go @@ -88,6 +88,76 @@ func TestSubscribeEventsContextCancelClosesEvents(t *testing.T) { } } +func TestSendMessageAcceptsStatusAccepted(t *testing.T) { + t.Parallel() + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusAccepted) + })) + defer srv.Close() + + c := captureClient(t, srv) + require.NoError(t, c.SendMessage(context.Background(), "ws1", "sess1", "", "hello")) +} + +func TestSendMessageAcceptsStatusOK(t *testing.T) { + t.Parallel() + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + c := captureClient(t, srv) + require.NoError(t, c.SendMessage(context.Background(), "ws1", "sess1", "", "hello")) +} + +func TestSendMessageDecodesErrorBody(t *testing.T) { + t.Parallel() + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusBadRequest) + _ = json.NewEncoder(w).Encode(proto.Error{Message: "session id is required"}) + })) + defer srv.Close() + + c := captureClient(t, srv) + err := c.SendMessage(context.Background(), "ws1", "", "", "hello") + require.Error(t, err) + require.Contains(t, err.Error(), "status code 400") + require.Contains(t, err.Error(), "session id is required") +} + +func TestSendMessageFallsBackOnMalformedErrorBody(t *testing.T) { + t.Parallel() + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte("not json")) + })) + defer srv.Close() + + c := captureClient(t, srv) + err := c.SendMessage(context.Background(), "ws1", "sess1", "", "hello") + require.Error(t, err) + require.Contains(t, err.Error(), "status code 500") + require.NotContains(t, err.Error(), "not json") +} + +func TestSendMessageFallsBackOnEmptyErrorBody(t *testing.T) { + t.Parallel() + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer srv.Close() + + c := captureClient(t, srv) + err := c.SendMessage(context.Background(), "ws1", "sess1", "", "hello") + require.Error(t, err) + require.Contains(t, err.Error(), "status code 500") +} + func marshalSSEPayload(t *testing.T) []byte { t.Helper()