@@ -894,8 +894,11 @@ pub fn add_routes(app: &mut tide::Server<Arc<AppState>>, rpc: &Arc<Peer>) {
let connection_upgrade = header_contains_ignore_case(&request, CONNECTION, "upgrade");
let upgrade_to_websocket = header_contains_ignore_case(&request, UPGRADE, "websocket");
let upgrade_requested = connection_upgrade && upgrade_to_websocket;
+ let client_protocol_version: Option<u32> = request
+ .header("X-ZRPC-VERSION")
+ .and_then(|v| v.as_str().parse().ok());
- if !upgrade_requested {
+ if !upgrade_requested || client_protocol_version != Some(zrpc::VERSION) {
return Ok(Response::new(StatusCode::UpgradeRequired));
}
@@ -55,6 +55,8 @@ pub struct Client {
#[derive(Error, Debug)]
pub enum EstablishConnectionError {
+ #[error("upgrade required")]
+ UpgradeRequired,
#[error("unauthorized")]
Unauthorized,
#[error("{0}")]
@@ -68,8 +70,10 @@ pub enum EstablishConnectionError {
impl From<WebsocketError> for EstablishConnectionError {
fn from(error: WebsocketError) -> Self {
if let WebsocketError::Http(response) = &error {
- if response.status() == StatusCode::UNAUTHORIZED {
- return EstablishConnectionError::Unauthorized;
+ match response.status() {
+ StatusCode::UNAUTHORIZED => return EstablishConnectionError::Unauthorized,
+ StatusCode::UPGRADE_REQUIRED => return EstablishConnectionError::UpgradeRequired,
+ _ => {}
}
}
EstablishConnectionError::Other(error.into())
@@ -85,6 +89,7 @@ impl EstablishConnectionError {
#[derive(Copy, Clone, Debug)]
pub enum Status {
SignedOut,
+ UpgradeRequired,
Authenticating,
Connecting,
ConnectionError,
@@ -227,7 +232,7 @@ impl Client {
}
}));
}
- Status::SignedOut => {
+ Status::SignedOut | Status::UpgradeRequired => {
state._maintain_connection.take();
}
_ => {}
@@ -346,6 +351,7 @@ impl Client {
| Status::Reconnecting { .. }
| Status::Authenticating
| Status::Reauthenticating => return Ok(()),
+ Status::UpgradeRequired => return Err(EstablishConnectionError::UpgradeRequired)?,
};
if was_disconnected {
@@ -388,22 +394,25 @@ impl Client {
self.set_connection(conn, cx).await;
Ok(())
}
- Err(err) => {
- if matches!(err, EstablishConnectionError::Unauthorized) {
- self.state.write().credentials.take();
- if used_keychain {
- cx.platform().delete_credentials(&ZED_SERVER_URL).log_err();
- self.set_status(Status::SignedOut, cx);
- self.authenticate_and_connect(cx).await
- } else {
- self.set_status(Status::ConnectionError, cx);
- Err(err)?
- }
+ Err(EstablishConnectionError::Unauthorized) => {
+ self.state.write().credentials.take();
+ if used_keychain {
+ cx.platform().delete_credentials(&ZED_SERVER_URL).log_err();
+ self.set_status(Status::SignedOut, cx);
+ self.authenticate_and_connect(cx).await
} else {
self.set_status(Status::ConnectionError, cx);
- Err(err)?
+ Err(EstablishConnectionError::Unauthorized)?
}
}
+ Err(EstablishConnectionError::UpgradeRequired) => {
+ self.set_status(Status::UpgradeRequired, cx);
+ Err(EstablishConnectionError::UpgradeRequired)?
+ }
+ Err(error) => {
+ self.set_status(Status::ConnectionError, cx);
+ Err(error)?
+ }
}
}
@@ -489,10 +498,12 @@ impl Client {
credentials: &Credentials,
cx: &AsyncAppContext,
) -> Task<Result<Connection, EstablishConnectionError>> {
- let request = Request::builder().header(
- "Authorization",
- format!("{} {}", credentials.user_id, credentials.access_token),
- );
+ let request = Request::builder()
+ .header(
+ "Authorization",
+ format!("{} {}", credentials.user_id, credentials.access_token),
+ )
+ .header("X-ZRPC-VERSION", zrpc::VERSION);
cx.background().spawn(async move {
if let Some(host) = ZED_SERVER_URL.strip_prefix("https://") {
let stream = smol::net::TcpStream::connect(host).await?;