Detailed changes
@@ -5,10 +5,7 @@ use super::{
};
use anyhow::anyhow;
use async_std::{sync::RwLock, task};
-use async_tungstenite::{
- tungstenite::{protocol::Role, Error as WebSocketError, Message as WebSocketMessage},
- WebSocketStream,
-};
+use async_tungstenite::{tungstenite::protocol::Role, WebSocketStream};
use futures::{future::BoxFuture, FutureExt};
use postage::{mpsc, prelude::Sink as _, prelude::Stream as _};
use sha1::{Digest as _, Sha1};
@@ -30,7 +27,7 @@ use time::OffsetDateTime;
use zrpc::{
auth::random_token,
proto::{self, AnyTypedEnvelope, EnvelopedMessage},
- ConnectionId, Peer, TypedEnvelope,
+ Conn, ConnectionId, Peer, TypedEnvelope,
};
type ReplicaId = u16;
@@ -133,19 +130,12 @@ impl Server {
self
}
- pub fn handle_connection<Conn>(
+ pub fn handle_connection(
self: &Arc<Self>,
connection: Conn,
addr: String,
user_id: UserId,
- ) -> impl Future<Output = ()>
- where
- Conn: 'static
- + futures::Sink<WebSocketMessage, Error = WebSocketError>
- + futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>
- + Send
- + Unpin,
- {
+ ) -> impl Future<Output = ()> {
let this = self.clone();
async move {
let (connection_id, handle_io, mut incoming_rx) =
@@ -974,8 +964,7 @@ pub fn add_routes(app: &mut tide::Server<Arc<AppState>>, rpc: &Arc<Peer>) {
let user_id = user_id.ok_or_else(|| anyhow!("user_id is not present on request. ensure auth::VerifyToken middleware is present"))?;
task::spawn(async move {
if let Some(stream) = upgrade_receiver.await {
- let stream = WebSocketStream::from_raw_socket(stream, Role::Server, None).await;
- server.handle_connection(stream, addr, user_id).await;
+ server.handle_connection(Conn::new(WebSocketStream::from_raw_socket(stream, Role::Server, None).await), addr, user_id).await;
}
});
@@ -1019,7 +1008,7 @@ mod tests {
fs::{FakeFs, Fs as _},
language::LanguageRegistry,
rpc::Client,
- settings, test,
+ settings,
user::UserStore,
worktree::Worktree,
};
@@ -1706,7 +1695,7 @@ mod tests {
) -> (UserId, Arc<Client>) {
let user_id = self.app_state.db.create_user(name, false).await.unwrap();
let client = Client::new();
- let (client_conn, server_conn) = test::Channel::bidirectional();
+ let (client_conn, server_conn) = Conn::in_memory();
cx.background()
.spawn(
self.server
@@ -445,12 +445,13 @@ mod tests {
use super::*;
use crate::test::FakeServer;
use gpui::TestAppContext;
+ use std::time::Duration;
#[gpui::test]
async fn test_channel_messages(mut cx: TestAppContext) {
let user_id = 5;
let client = Client::new();
- let mut server = FakeServer::for_client(user_id, &client, &cx).await;
+ let server = FakeServer::for_client(user_id, &client, &cx).await;
let user_store = Arc::new(UserStore::new(client.clone()));
let channel_list = cx.add_model(|cx| ChannelList::new(user_store, client.clone(), cx));
@@ -1,8 +1,6 @@
use crate::util::ResultExt;
use anyhow::{anyhow, Context, Result};
-use async_tungstenite::tungstenite::{
- http::Request, Error as WebSocketError, Message as WebSocketMessage,
-};
+use async_tungstenite::tungstenite::http::Request;
use gpui::{AsyncAppContext, Entity, ModelContext, Task};
use lazy_static::lazy_static;
use parking_lot::RwLock;
@@ -19,7 +17,7 @@ use surf::Url;
pub use zrpc::{proto, ConnectionId, PeerId, TypedEnvelope};
use zrpc::{
proto::{AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage},
- Peer, Receipt,
+ Conn, Peer, Receipt,
};
lazy_static! {
@@ -106,6 +104,7 @@ impl Client {
fn set_status(self: &Arc<Self>, status: Status, cx: &AsyncAppContext) {
let mut state = self.state.write();
*state.status.0.borrow_mut() = status;
+
match status {
Status::Connected { .. } => {
let heartbeat_interval = state.heartbeat_interval;
@@ -193,75 +192,46 @@ impl Client {
) -> anyhow::Result<()> {
if matches!(
*self.status().borrow(),
- Status::Connecting | Status::Connected { .. }
+ Status::Connecting { .. } | Status::Connected { .. }
) {
return Ok(());
}
- let (user_id, access_token) = Self::login(cx.platform(), &cx.background()).await?;
- let user_id = user_id.parse::<u64>()?;
+ let (user_id, access_token) = match self.authenticate(&cx).await {
+ Ok(result) => result,
+ Err(err) => {
+ self.set_status(Status::ConnectionError, cx);
+ return Err(err);
+ }
+ };
self.set_status(Status::Connecting, cx);
- match self.connect(user_id, &access_token, cx).await {
- Ok(()) => {
- log::info!("connected to rpc address {}", *ZED_SERVER_URL);
- Ok(())
- }
+
+ let conn = match self.connect(user_id, &access_token, cx).await {
+ Ok(conn) => conn,
Err(err) => {
self.set_status(Status::ConnectionError, cx);
- Err(err)
+ return Err(err);
}
- }
- }
+ };
- async fn connect(
- self: &Arc<Self>,
- user_id: u64,
- access_token: &str,
- cx: &AsyncAppContext,
- ) -> Result<()> {
- let request =
- Request::builder().header("Authorization", format!("{} {}", user_id, access_token));
- if let Some(host) = ZED_SERVER_URL.strip_prefix("https://") {
- let stream = smol::net::TcpStream::connect(host).await?;
- let request = request.uri(format!("wss://{}/rpc", host)).body(())?;
- let (stream, _) = async_tungstenite::async_tls::client_async_tls(request, stream)
- .await
- .context("websocket handshake")?;
- self.set_connection(user_id, stream, cx).await?;
- Ok(())
- } else if let Some(host) = ZED_SERVER_URL.strip_prefix("http://") {
- let stream = smol::net::TcpStream::connect(host).await?;
- let request = request.uri(format!("ws://{}/rpc", host)).body(())?;
- let (stream, _) = async_tungstenite::client_async(request, stream)
- .await
- .context("websocket handshake")?;
- self.set_connection(user_id, stream, cx).await?;
- Ok(())
- } else {
- return Err(anyhow!("invalid server url: {}", *ZED_SERVER_URL));
- }
+ self.set_connection(user_id, conn, cx).await?;
+ log::info!("connected to rpc address {}", *ZED_SERVER_URL);
+ Ok(())
}
- pub async fn set_connection<Conn>(
+ pub async fn set_connection(
self: &Arc<Self>,
user_id: u64,
conn: Conn,
cx: &AsyncAppContext,
- ) -> Result<()>
- where
- Conn: 'static
- + futures::Sink<WebSocketMessage, Error = WebSocketError>
- + futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>
- + Unpin
- + Send,
- {
+ ) -> Result<()> {
let (connection_id, handle_io, mut incoming) = self.peer.add_connection(conn).await;
- {
- let mut cx = cx.clone();
- let this = self.clone();
- cx.foreground()
- .spawn(async move {
+ cx.foreground()
+ .spawn({
+ let mut cx = cx.clone();
+ let this = self.clone();
+ async move {
while let Some(message) = incoming.recv().await {
let mut state = this.state.write();
if let Some(extract_entity_id) =
@@ -286,9 +256,9 @@ impl Client {
log::info!("unhandled message {}", message.payload_type_name());
}
}
- })
- .detach();
- }
+ }
+ })
+ .detach();
self.set_status(
Status::Connected {
@@ -315,11 +285,38 @@ impl Client {
Ok(())
}
- pub fn login(
- platform: Arc<dyn gpui::Platform>,
- executor: &Arc<gpui::executor::Background>,
- ) -> Task<Result<(String, String)>> {
- let executor = executor.clone();
+ fn connect(
+ self: &Arc<Self>,
+ user_id: u64,
+ access_token: &str,
+ cx: &AsyncAppContext,
+ ) -> Task<Result<Conn>> {
+ let request =
+ Request::builder().header("Authorization", format!("{} {}", user_id, access_token));
+ cx.background().spawn(async move {
+ if let Some(host) = ZED_SERVER_URL.strip_prefix("https://") {
+ let stream = smol::net::TcpStream::connect(host).await?;
+ let request = request.uri(format!("wss://{}/rpc", host)).body(())?;
+ let (stream, _) = async_tungstenite::async_tls::client_async_tls(request, stream)
+ .await
+ .context("websocket handshake")?;
+ Ok(Conn::new(stream))
+ } else if let Some(host) = ZED_SERVER_URL.strip_prefix("http://") {
+ let stream = smol::net::TcpStream::connect(host).await?;
+ let request = request.uri(format!("ws://{}/rpc", host)).body(())?;
+ let (stream, _) = async_tungstenite::client_async(request, stream)
+ .await
+ .context("websocket handshake")?;
+ Ok(Conn::new(stream))
+ } else {
+ Err(anyhow!("invalid server url: {}", *ZED_SERVER_URL))
+ }
+ })
+ }
+
+ pub fn authenticate(self: &Arc<Self>, cx: &AsyncAppContext) -> Task<Result<(u64, String)>> {
+ let platform = cx.platform();
+ let executor = cx.background();
executor.clone().spawn(async move {
if let Some((user_id, access_token)) = platform
.read_credentials(&ZED_SERVER_URL)
@@ -327,7 +324,7 @@ impl Client {
.flatten()
{
log::info!("already signed in. user_id: {}", user_id);
- return Ok((user_id, String::from_utf8(access_token).unwrap()));
+ return Ok((user_id.parse()?, String::from_utf8(access_token).unwrap()));
}
// Generate a pair of asymmetric encryption keys. The public key will be used by the
@@ -393,7 +390,7 @@ impl Client {
platform
.write_credentials(&ZED_SERVER_URL, &user_id, access_token.as_bytes())
.log_err();
- Ok((user_id.to_string(), access_token))
+ Ok((user_id.parse()?, access_token))
})
}
@@ -492,7 +489,7 @@ mod tests {
async fn test_heartbeat(cx: TestAppContext) {
let user_id = 5;
let client = Client::new();
- let mut server = FakeServer::for_client(user_id, &client, &cx).await;
+ let server = FakeServer::for_client(user_id, &client, &cx).await;
cx.foreground().advance_clock(Duration::from_secs(10));
let ping = server.receive::<proto::Ping>().await.unwrap();
@@ -10,7 +10,7 @@ use crate::{
AppState,
};
use anyhow::{anyhow, Result};
-use gpui::{Entity, ModelHandle, MutableAppContext, TestAppContext};
+use gpui::{AsyncAppContext, Entity, ModelHandle, MutableAppContext, TestAppContext};
use parking_lot::Mutex;
use postage::{mpsc, prelude::Stream as _};
use smol::channel;
@@ -20,10 +20,7 @@ use std::{
sync::Arc,
};
use tempdir::TempDir;
-use zrpc::{proto, ConnectionId, Peer, Receipt, TypedEnvelope};
-
-#[cfg(feature = "test-support")]
-pub use zrpc::test::Channel;
+use zrpc::{proto, Conn, ConnectionId, Peer, Receipt, TypedEnvelope};
#[cfg(test)]
#[ctor::ctor]
@@ -201,40 +198,64 @@ impl<T: Entity> Observer<T> {
pub struct FakeServer {
peer: Arc<Peer>,
- incoming: mpsc::Receiver<Box<dyn proto::AnyTypedEnvelope>>,
- connection_id: ConnectionId,
+ incoming: Mutex<Option<mpsc::Receiver<Box<dyn proto::AnyTypedEnvelope>>>>,
+ connection_id: Mutex<Option<ConnectionId>>,
}
impl FakeServer {
- pub async fn for_client(user_id: u64, client: &Arc<Client>, cx: &TestAppContext) -> Self {
- let (client_conn, server_conn) = zrpc::test::Channel::bidirectional();
- let peer = Peer::new();
- let (connection_id, io, incoming) = peer.add_connection(server_conn).await;
- cx.background().spawn(io).detach();
+ pub async fn for_client(user_id: u64, client: &Arc<Client>, cx: &TestAppContext) -> Arc<Self> {
+ let result = Arc::new(Self {
+ peer: Peer::new(),
+ incoming: Default::default(),
+ connection_id: Default::default(),
+ });
+ let conn = result.connect(&cx.to_async()).await;
client
- .set_connection(user_id, client_conn, &cx.to_async())
+ .set_connection(user_id, conn, &cx.to_async())
.await
.unwrap();
+ result
+ }
- Self {
- peer,
- incoming,
- connection_id,
- }
+ pub async fn disconnect(&self) {
+ self.peer.disconnect(self.connection_id()).await;
+ self.connection_id.lock().take();
+ self.incoming.lock().take();
+ }
+
+ async fn connect(&self, cx: &AsyncAppContext) -> Conn {
+ let (client_conn, server_conn) = Conn::in_memory();
+ let (connection_id, io, incoming) = self.peer.add_connection(server_conn).await;
+ cx.background().spawn(io).detach();
+ *self.incoming.lock() = Some(incoming);
+ *self.connection_id.lock() = Some(connection_id);
+ client_conn
}
pub async fn send<T: proto::EnvelopedMessage>(&self, message: T) {
- self.peer.send(self.connection_id, message).await.unwrap();
+ self.peer.send(self.connection_id(), message).await.unwrap();
}
- pub async fn receive<M: proto::EnvelopedMessage>(&mut self) -> Result<TypedEnvelope<M>> {
+ pub async fn receive<M: proto::EnvelopedMessage>(&self) -> Result<TypedEnvelope<M>> {
let message = self
.incoming
+ .lock()
+ .as_mut()
+ .expect("not connected")
.recv()
.await
.ok_or_else(|| anyhow!("other half hung up"))?;
- Ok(*message.into_any().downcast::<TypedEnvelope<M>>().unwrap())
+ let type_name = message.payload_type_name();
+ Ok(*message
+ .into_any()
+ .downcast::<TypedEnvelope<M>>()
+ .unwrap_or_else(|_| {
+ panic!(
+ "fake server received unexpected message type: {:?}",
+ type_name
+ );
+ }))
}
pub async fn respond<T: proto::RequestMessage>(
@@ -244,4 +265,8 @@ impl FakeServer {
) {
self.peer.respond(receipt, response).await.unwrap()
}
+
+ fn connection_id(&self) -> ConnectionId {
+ self.connection_id.lock().expect("not connected")
+ }
}
@@ -0,0 +1,54 @@
+use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
+use futures::{SinkExt as _, StreamExt as _};
+
+pub struct Conn {
+ pub(crate) tx:
+ Box<dyn 'static + Send + Unpin + futures::Sink<WebSocketMessage, Error = WebSocketError>>,
+ pub(crate) rx: Box<
+ dyn 'static
+ + Send
+ + Unpin
+ + futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>,
+ >,
+}
+
+impl Conn {
+ pub fn new<S>(stream: S) -> Self
+ where
+ S: 'static
+ + Send
+ + Unpin
+ + futures::Sink<WebSocketMessage, Error = WebSocketError>
+ + futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>,
+ {
+ let (tx, rx) = stream.split();
+ Self {
+ tx: Box::new(tx),
+ rx: Box::new(rx),
+ }
+ }
+
+ pub async fn send(&mut self, message: WebSocketMessage) -> Result<(), WebSocketError> {
+ self.tx.send(message).await
+ }
+
+ #[cfg(any(test, feature = "test-support"))]
+ pub fn in_memory() -> (Self, Self) {
+ use futures::SinkExt as _;
+ use futures::StreamExt as _;
+ use std::io::{Error, ErrorKind};
+
+ let (a_tx, a_rx) = futures::channel::mpsc::unbounded::<WebSocketMessage>();
+ let (b_tx, b_rx) = futures::channel::mpsc::unbounded::<WebSocketMessage>();
+ (
+ Self {
+ tx: Box::new(a_tx.sink_map_err(|e| Error::new(ErrorKind::Other, e).into())),
+ rx: Box::new(b_rx.map(Ok)),
+ },
+ Self {
+ tx: Box::new(b_tx.sink_map_err(|e| Error::new(ErrorKind::Other, e).into())),
+ rx: Box::new(a_rx.map(Ok)),
+ },
+ )
+ }
+}
@@ -1,7 +1,6 @@
pub mod auth;
+mod conn;
mod peer;
pub mod proto;
-#[cfg(any(test, feature = "test-support"))]
-pub mod test;
-
+pub use conn::Conn;
pub use peer::*;
@@ -1,8 +1,8 @@
-use crate::proto::{self, AnyTypedEnvelope, EnvelopedMessage, MessageStream, RequestMessage};
+use super::proto::{self, AnyTypedEnvelope, EnvelopedMessage, MessageStream, RequestMessage};
+use super::Conn;
use anyhow::{anyhow, Context, Result};
use async_lock::{Mutex, RwLock};
-use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
-use futures::{FutureExt, StreamExt};
+use futures::FutureExt as _;
use postage::{
mpsc,
prelude::{Sink as _, Stream as _},
@@ -98,21 +98,14 @@ impl Peer {
})
}
- pub async fn add_connection<Conn>(
+ pub async fn add_connection(
self: &Arc<Self>,
conn: Conn,
) -> (
ConnectionId,
impl Future<Output = anyhow::Result<()>> + Send,
mpsc::Receiver<Box<dyn AnyTypedEnvelope>>,
- )
- where
- Conn: futures::Sink<WebSocketMessage, Error = WebSocketError>
- + futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>
- + Send
- + Unpin,
- {
- let (tx, rx) = conn.split();
+ ) {
let connection_id = ConnectionId(
self.next_connection_id
.fetch_add(1, atomic::Ordering::SeqCst),
@@ -124,8 +117,8 @@ impl Peer {
next_message_id: Default::default(),
response_channels: Default::default(),
};
- let mut writer = MessageStream::new(tx);
- let mut reader = MessageStream::new(rx);
+ let mut writer = MessageStream::new(conn.tx);
+ let mut reader = MessageStream::new(conn.rx);
let this = self.clone();
let response_channels = connection.response_channels.clone();
@@ -347,7 +340,9 @@ impl Peer {
#[cfg(test)]
mod tests {
use super::*;
- use crate::{test, TypedEnvelope};
+ use crate::TypedEnvelope;
+ use async_tungstenite::tungstenite::Message as WebSocketMessage;
+ use futures::StreamExt as _;
#[test]
fn test_request_response() {
@@ -357,12 +352,12 @@ mod tests {
let client1 = Peer::new();
let client2 = Peer::new();
- let (client1_to_server_conn, server_to_client_1_conn) = test::Channel::bidirectional();
+ let (client1_to_server_conn, server_to_client_1_conn) = Conn::in_memory();
let (client1_conn_id, io_task1, _) =
client1.add_connection(client1_to_server_conn).await;
let (_, io_task2, incoming1) = server.add_connection(server_to_client_1_conn).await;
- let (client2_to_server_conn, server_to_client_2_conn) = test::Channel::bidirectional();
+ let (client2_to_server_conn, server_to_client_2_conn) = Conn::in_memory();
let (client2_conn_id, io_task3, _) =
client2.add_connection(client2_to_server_conn).await;
let (_, io_task4, incoming2) = server.add_connection(server_to_client_2_conn).await;
@@ -497,7 +492,7 @@ mod tests {
#[test]
fn test_disconnect() {
smol::block_on(async move {
- let (client_conn, mut server_conn) = test::Channel::bidirectional();
+ let (client_conn, mut server_conn) = Conn::in_memory();
let client = Peer::new();
let (connection_id, io_handler, mut incoming) =
@@ -521,18 +516,17 @@ mod tests {
io_ended_rx.recv().await;
messages_ended_rx.recv().await;
- assert!(
- futures::SinkExt::send(&mut server_conn, WebSocketMessage::Binary(vec![]))
- .await
- .is_err()
- );
+ assert!(server_conn
+ .send(WebSocketMessage::Binary(vec![]))
+ .await
+ .is_err());
});
}
#[test]
fn test_io_error() {
smol::block_on(async move {
- let (client_conn, server_conn) = test::Channel::bidirectional();
+ let (client_conn, server_conn) = Conn::in_memory();
drop(server_conn);
let client = Peer::new();
@@ -247,30 +247,3 @@ impl From<SystemTime> for Timestamp {
}
}
}
-
-#[cfg(test)]
-mod tests {
- use super::*;
- use crate::test;
-
- #[test]
- fn test_round_trip_message() {
- smol::block_on(async {
- let stream = test::Channel::new();
- let message1 = Ping { id: 5 }.into_envelope(3, None, None);
- let message2 = OpenBuffer {
- worktree_id: 0,
- path: "some/path".to_string(),
- }
- .into_envelope(5, None, None);
-
- let mut message_stream = MessageStream::new(stream);
- message_stream.write_message(&message1).await.unwrap();
- message_stream.write_message(&message2).await.unwrap();
- let decoded_message1 = message_stream.read_message().await.unwrap();
- let decoded_message2 = message_stream.read_message().await.unwrap();
- assert_eq!(decoded_message1, message1);
- assert_eq!(decoded_message2, message2);
- });
- }
-}
@@ -1,64 +0,0 @@
-use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
-use std::{
- io,
- pin::Pin,
- task::{Context, Poll},
-};
-
-pub struct Channel {
- tx: futures::channel::mpsc::UnboundedSender<WebSocketMessage>,
- rx: futures::channel::mpsc::UnboundedReceiver<WebSocketMessage>,
-}
-
-impl Channel {
- pub fn new() -> Self {
- let (tx, rx) = futures::channel::mpsc::unbounded();
- Self { tx, rx }
- }
-
- pub fn bidirectional() -> (Self, Self) {
- let (a_tx, a_rx) = futures::channel::mpsc::unbounded();
- let (b_tx, b_rx) = futures::channel::mpsc::unbounded();
- let a = Self { tx: a_tx, rx: b_rx };
- let b = Self { tx: b_tx, rx: a_rx };
- (a, b)
- }
-}
-
-impl futures::Sink<WebSocketMessage> for Channel {
- type Error = WebSocketError;
-
- fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
- Pin::new(&mut self.tx)
- .poll_ready(cx)
- .map_err(|err| io::Error::new(io::ErrorKind::Other, err).into())
- }
-
- fn start_send(mut self: Pin<&mut Self>, item: WebSocketMessage) -> Result<(), Self::Error> {
- Pin::new(&mut self.tx)
- .start_send(item)
- .map_err(|err| io::Error::new(io::ErrorKind::Other, err).into())
- }
-
- fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
- Pin::new(&mut self.tx)
- .poll_flush(cx)
- .map_err(|err| io::Error::new(io::ErrorKind::Other, err).into())
- }
-
- fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
- Pin::new(&mut self.tx)
- .poll_close(cx)
- .map_err(|err| io::Error::new(io::ErrorKind::Other, err).into())
- }
-}
-
-impl futures::Stream for Channel {
- type Item = Result<WebSocketMessage, WebSocketError>;
-
- fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
- Pin::new(&mut self.rx)
- .poll_next(cx)
- .map(|i| i.map(|i| Ok(i)))
- }
-}