diff --git a/server/src/rpc.rs b/server/src/rpc.rs index fec6182fcc1ff02cf696fed5cfcba32f41564af4..52ebd513960e1660c2b381c6f0170f50fec649ea 100644 --- a/server/src/rpc.rs +++ b/server/src/rpc.rs @@ -894,8 +894,11 @@ pub fn add_routes(app: &mut tide::Server>, rpc: &Arc) { let connection_upgrade = header_contains_ignore_case(&request, CONNECTION, "upgrade"); let upgrade_to_websocket = header_contains_ignore_case(&request, UPGRADE, "websocket"); let upgrade_requested = connection_upgrade && upgrade_to_websocket; + let client_protocol_version: Option = request + .header("X-ZRPC-VERSION") + .and_then(|v| v.as_str().parse().ok()); - if !upgrade_requested { + if !upgrade_requested || client_protocol_version != Some(zrpc::VERSION) { return Ok(Response::new(StatusCode::UpgradeRequired)); } diff --git a/zed/src/rpc.rs b/zed/src/rpc.rs index 5dc2b49b76d9ac8ca958f6dc3903a8da3603dbcd..5b18ff310bcdaafcc84ecfd45fe83302bd9f54e2 100644 --- a/zed/src/rpc.rs +++ b/zed/src/rpc.rs @@ -55,6 +55,8 @@ pub struct Client { #[derive(Error, Debug)] pub enum EstablishConnectionError { + #[error("upgrade required")] + UpgradeRequired, #[error("unauthorized")] Unauthorized, #[error("{0}")] @@ -68,8 +70,10 @@ pub enum EstablishConnectionError { impl From for EstablishConnectionError { fn from(error: WebsocketError) -> Self { if let WebsocketError::Http(response) = &error { - if response.status() == StatusCode::UNAUTHORIZED { - return EstablishConnectionError::Unauthorized; + match response.status() { + StatusCode::UNAUTHORIZED => return EstablishConnectionError::Unauthorized, + StatusCode::UPGRADE_REQUIRED => return EstablishConnectionError::UpgradeRequired, + _ => {} } } EstablishConnectionError::Other(error.into()) @@ -85,6 +89,7 @@ impl EstablishConnectionError { #[derive(Copy, Clone, Debug)] pub enum Status { SignedOut, + UpgradeRequired, Authenticating, Connecting, ConnectionError, @@ -227,7 +232,7 @@ impl Client { } })); } - Status::SignedOut => { + Status::SignedOut | Status::UpgradeRequired => { state._maintain_connection.take(); } _ => {} @@ -346,6 +351,7 @@ impl Client { | Status::Reconnecting { .. } | Status::Authenticating | Status::Reauthenticating => return Ok(()), + Status::UpgradeRequired => return Err(EstablishConnectionError::UpgradeRequired)?, }; if was_disconnected { @@ -388,22 +394,25 @@ impl Client { self.set_connection(conn, cx).await; Ok(()) } - Err(err) => { - if matches!(err, EstablishConnectionError::Unauthorized) { - self.state.write().credentials.take(); - if used_keychain { - cx.platform().delete_credentials(&ZED_SERVER_URL).log_err(); - self.set_status(Status::SignedOut, cx); - self.authenticate_and_connect(cx).await - } else { - self.set_status(Status::ConnectionError, cx); - Err(err)? - } + Err(EstablishConnectionError::Unauthorized) => { + self.state.write().credentials.take(); + if used_keychain { + cx.platform().delete_credentials(&ZED_SERVER_URL).log_err(); + self.set_status(Status::SignedOut, cx); + self.authenticate_and_connect(cx).await } else { self.set_status(Status::ConnectionError, cx); - Err(err)? + Err(EstablishConnectionError::Unauthorized)? } } + Err(EstablishConnectionError::UpgradeRequired) => { + self.set_status(Status::UpgradeRequired, cx); + Err(EstablishConnectionError::UpgradeRequired)? + } + Err(error) => { + self.set_status(Status::ConnectionError, cx); + Err(error)? + } } } @@ -489,10 +498,12 @@ impl Client { credentials: &Credentials, cx: &AsyncAppContext, ) -> Task> { - let request = Request::builder().header( - "Authorization", - format!("{} {}", credentials.user_id, credentials.access_token), - ); + let request = Request::builder() + .header( + "Authorization", + format!("{} {}", credentials.user_id, credentials.access_token), + ) + .header("X-ZRPC-VERSION", zrpc::VERSION); cx.background().spawn(async move { if let Some(host) = ZED_SERVER_URL.strip_prefix("https://") { let stream = smol::net::TcpStream::connect(host).await?; diff --git a/zrpc/src/lib.rs b/zrpc/src/lib.rs index a7bb44774b8e700443f753e3fb47c1176ef80142..607cd252f0a1dce3934bdc025d5cfcad46e62d47 100644 --- a/zrpc/src/lib.rs +++ b/zrpc/src/lib.rs @@ -4,3 +4,5 @@ mod peer; pub mod proto; pub use conn::Connection; pub use peer::*; + +pub const VERSION: u32 = 0;