Version the zrpc protocol using a `X-ZRPC-VERSION` header

Antonio Scandurra created

Change summary

server/src/rpc.rs |  5 ++++-
zed/src/rpc.rs    | 49 ++++++++++++++++++++++++++++++-------------------
zrpc/src/lib.rs   |  2 ++
3 files changed, 36 insertions(+), 20 deletions(-)

Detailed changes

server/src/rpc.rs 🔗

@@ -894,8 +894,11 @@ pub fn add_routes(app: &mut tide::Server<Arc<AppState>>, rpc: &Arc<Peer>) {
             let connection_upgrade = header_contains_ignore_case(&request, CONNECTION, "upgrade");
             let upgrade_to_websocket = header_contains_ignore_case(&request, UPGRADE, "websocket");
             let upgrade_requested = connection_upgrade && upgrade_to_websocket;
+            let client_protocol_version: Option<u32> = request
+                .header("X-ZRPC-VERSION")
+                .and_then(|v| v.as_str().parse().ok());
 
-            if !upgrade_requested {
+            if !upgrade_requested || client_protocol_version != Some(zrpc::VERSION) {
                 return Ok(Response::new(StatusCode::UpgradeRequired));
             }
 

zed/src/rpc.rs 🔗

@@ -55,6 +55,8 @@ pub struct Client {
 
 #[derive(Error, Debug)]
 pub enum EstablishConnectionError {
+    #[error("upgrade required")]
+    UpgradeRequired,
     #[error("unauthorized")]
     Unauthorized,
     #[error("{0}")]
@@ -68,8 +70,10 @@ pub enum EstablishConnectionError {
 impl From<WebsocketError> for EstablishConnectionError {
     fn from(error: WebsocketError) -> Self {
         if let WebsocketError::Http(response) = &error {
-            if response.status() == StatusCode::UNAUTHORIZED {
-                return EstablishConnectionError::Unauthorized;
+            match response.status() {
+                StatusCode::UNAUTHORIZED => return EstablishConnectionError::Unauthorized,
+                StatusCode::UPGRADE_REQUIRED => return EstablishConnectionError::UpgradeRequired,
+                _ => {}
             }
         }
         EstablishConnectionError::Other(error.into())
@@ -85,6 +89,7 @@ impl EstablishConnectionError {
 #[derive(Copy, Clone, Debug)]
 pub enum Status {
     SignedOut,
+    UpgradeRequired,
     Authenticating,
     Connecting,
     ConnectionError,
@@ -227,7 +232,7 @@ impl Client {
                     }
                 }));
             }
-            Status::SignedOut => {
+            Status::SignedOut | Status::UpgradeRequired => {
                 state._maintain_connection.take();
             }
             _ => {}
@@ -346,6 +351,7 @@ impl Client {
             | Status::Reconnecting { .. }
             | Status::Authenticating
             | Status::Reauthenticating => return Ok(()),
+            Status::UpgradeRequired => return Err(EstablishConnectionError::UpgradeRequired)?,
         };
 
         if was_disconnected {
@@ -388,22 +394,25 @@ impl Client {
                 self.set_connection(conn, cx).await;
                 Ok(())
             }
-            Err(err) => {
-                if matches!(err, EstablishConnectionError::Unauthorized) {
-                    self.state.write().credentials.take();
-                    if used_keychain {
-                        cx.platform().delete_credentials(&ZED_SERVER_URL).log_err();
-                        self.set_status(Status::SignedOut, cx);
-                        self.authenticate_and_connect(cx).await
-                    } else {
-                        self.set_status(Status::ConnectionError, cx);
-                        Err(err)?
-                    }
+            Err(EstablishConnectionError::Unauthorized) => {
+                self.state.write().credentials.take();
+                if used_keychain {
+                    cx.platform().delete_credentials(&ZED_SERVER_URL).log_err();
+                    self.set_status(Status::SignedOut, cx);
+                    self.authenticate_and_connect(cx).await
                 } else {
                     self.set_status(Status::ConnectionError, cx);
-                    Err(err)?
+                    Err(EstablishConnectionError::Unauthorized)?
                 }
             }
+            Err(EstablishConnectionError::UpgradeRequired) => {
+                self.set_status(Status::UpgradeRequired, cx);
+                Err(EstablishConnectionError::UpgradeRequired)?
+            }
+            Err(error) => {
+                self.set_status(Status::ConnectionError, cx);
+                Err(error)?
+            }
         }
     }
 
@@ -489,10 +498,12 @@ impl Client {
         credentials: &Credentials,
         cx: &AsyncAppContext,
     ) -> Task<Result<Connection, EstablishConnectionError>> {
-        let request = Request::builder().header(
-            "Authorization",
-            format!("{} {}", credentials.user_id, credentials.access_token),
-        );
+        let request = Request::builder()
+            .header(
+                "Authorization",
+                format!("{} {}", credentials.user_id, credentials.access_token),
+            )
+            .header("X-ZRPC-VERSION", zrpc::VERSION);
         cx.background().spawn(async move {
             if let Some(host) = ZED_SERVER_URL.strip_prefix("https://") {
                 let stream = smol::net::TcpStream::connect(host).await?;

zrpc/src/lib.rs 🔗

@@ -4,3 +4,5 @@ mod peer;
 pub mod proto;
 pub use conn::Connection;
 pub use peer::*;
+
+pub const VERSION: u32 = 0;