@@ -423,12 +423,28 @@ func (c *Client) SendMessage(ctx context.Context, id string, sessionID, runID, p
return fmt.Errorf("failed to send message to agent: %w", err)
}
defer rsp.Body.Close()
- if rsp.StatusCode != http.StatusOK {
+ if rsp.StatusCode != http.StatusOK && rsp.StatusCode != http.StatusAccepted {
+ if msg := decodeErrorMessage(rsp.Body); msg != "" {
+ return fmt.Errorf("failed to send message to agent: status code %d: %s", rsp.StatusCode, msg)
+ }
return fmt.Errorf("failed to send message to agent: status code %d", rsp.StatusCode)
}
return nil
}
+// decodeErrorMessage attempts to decode the response body as a
+// proto.Error and returns its message. It returns an empty string
+// when the body is empty or cannot be decoded into a proto.Error
+// with a non-empty message, letting callers fall back to a
+// status-only error.
+func decodeErrorMessage(body io.Reader) string {
+ var e proto.Error
+ if err := json.NewDecoder(body).Decode(&e); err != nil {
+ return ""
+ }
+ return e.Message
+}
+
// GetAgentSessionInfo retrieves the agent session info for a workspace.
func (c *Client) GetAgentSessionInfo(ctx context.Context, id string, sessionID string) (*proto.AgentSession, error) {
rsp, err := c.get(ctx, fmt.Sprintf("/workspaces/%s/agent/sessions/%s", id, sessionID), nil, nil)
@@ -88,6 +88,76 @@ func TestSubscribeEventsContextCancelClosesEvents(t *testing.T) {
}
}
+func TestSendMessageAcceptsStatusAccepted(t *testing.T) {
+ t.Parallel()
+
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+ w.WriteHeader(http.StatusAccepted)
+ }))
+ defer srv.Close()
+
+ c := captureClient(t, srv)
+ require.NoError(t, c.SendMessage(context.Background(), "ws1", "sess1", "", "hello"))
+}
+
+func TestSendMessageAcceptsStatusOK(t *testing.T) {
+ t.Parallel()
+
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+ w.WriteHeader(http.StatusOK)
+ }))
+ defer srv.Close()
+
+ c := captureClient(t, srv)
+ require.NoError(t, c.SendMessage(context.Background(), "ws1", "sess1", "", "hello"))
+}
+
+func TestSendMessageDecodesErrorBody(t *testing.T) {
+ t.Parallel()
+
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+ w.WriteHeader(http.StatusBadRequest)
+ _ = json.NewEncoder(w).Encode(proto.Error{Message: "session id is required"})
+ }))
+ defer srv.Close()
+
+ c := captureClient(t, srv)
+ err := c.SendMessage(context.Background(), "ws1", "", "", "hello")
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "status code 400")
+ require.Contains(t, err.Error(), "session id is required")
+}
+
+func TestSendMessageFallsBackOnMalformedErrorBody(t *testing.T) {
+ t.Parallel()
+
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+ w.WriteHeader(http.StatusInternalServerError)
+ _, _ = w.Write([]byte("not json"))
+ }))
+ defer srv.Close()
+
+ c := captureClient(t, srv)
+ err := c.SendMessage(context.Background(), "ws1", "sess1", "", "hello")
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "status code 500")
+ require.NotContains(t, err.Error(), "not json")
+}
+
+func TestSendMessageFallsBackOnEmptyErrorBody(t *testing.T) {
+ t.Parallel()
+
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+ w.WriteHeader(http.StatusInternalServerError)
+ }))
+ defer srv.Close()
+
+ c := captureClient(t, srv)
+ err := c.SendMessage(context.Background(), "ws1", "sess1", "", "hello")
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "status code 500")
+}
+
func marshalSSEPayload(t *testing.T) []byte {
t.Helper()