@@ -28,12 +28,21 @@ lazy_static! {
pub struct Client {
peer: Arc<Peer>,
state: RwLock<ClientState>,
+ auth_callback: Option<
+ Box<dyn 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<(u64, String)>>>,
+ >,
+ connect_callback: Option<
+ Box<dyn 'static + Send + Sync + Fn(u64, &str, &AsyncAppContext) -> Task<Result<Conn>>>,
+ >,
}
#[derive(Copy, Clone, Debug)]
pub enum Status {
Disconnected,
- Connecting,
+ Authenticating,
+ Connecting {
+ user_id: u64,
+ },
ConnectionError,
Connected {
connection_id: ConnectionId,
@@ -94,9 +103,24 @@ impl Client {
Arc::new(Self {
peer: Peer::new(),
state: Default::default(),
+ auth_callback: None,
+ connect_callback: None,
})
}
+ #[cfg(any(test, feature = "test-support"))]
+ pub fn set_login_and_connect_callbacks<Login, Connect>(
+ &mut self,
+ login: Login,
+ connect: Connect,
+ ) where
+ Login: 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<(u64, String)>>,
+ Connect: 'static + Send + Sync + Fn(u64, &str, &AsyncAppContext) -> Task<Result<Conn>>,
+ {
+ self.auth_callback = Some(Box::new(login));
+ self.connect_callback = Some(Box::new(connect));
+ }
+
pub fn status(&self) -> watch::Receiver<Status> {
self.state.read().status.1.clone()
}
@@ -192,11 +216,13 @@ impl Client {
) -> anyhow::Result<()> {
if matches!(
*self.status().borrow(),
- Status::Connecting { .. } | Status::Connected { .. }
+ Status::Authenticating | Status::Connecting { .. } | Status::Connected { .. }
) {
return Ok(());
}
+ self.set_status(Status::Authenticating, cx);
+
let (user_id, access_token) = match self.authenticate(&cx).await {
Ok(result) => result,
Err(err) => {
@@ -205,7 +231,7 @@ impl Client {
}
};
- self.set_status(Status::Connecting, cx);
+ self.set_status(Status::Connecting { user_id }, cx);
let conn = match self.connect(user_id, &access_token, cx).await {
Ok(conn) => conn,
@@ -285,11 +311,32 @@ impl Client {
Ok(())
}
+ fn authenticate(self: &Arc<Self>, cx: &AsyncAppContext) -> Task<Result<(u64, String)>> {
+ if let Some(callback) = self.auth_callback.as_ref() {
+ callback(cx)
+ } else {
+ self.authenticate_with_browser(cx)
+ }
+ }
+
fn connect(
self: &Arc<Self>,
user_id: u64,
access_token: &str,
cx: &AsyncAppContext,
+ ) -> Task<Result<Conn>> {
+ if let Some(callback) = self.connect_callback.as_ref() {
+ callback(user_id, access_token, cx)
+ } else {
+ self.connect_with_websocket(user_id, access_token, cx)
+ }
+ }
+
+ fn connect_with_websocket(
+ self: &Arc<Self>,
+ user_id: u64,
+ access_token: &str,
+ cx: &AsyncAppContext,
) -> Task<Result<Conn>> {
let request =
Request::builder().header("Authorization", format!("{} {}", user_id, access_token));
@@ -314,7 +361,10 @@ impl Client {
})
}
- pub fn authenticate(self: &Arc<Self>, cx: &AsyncAppContext) -> Task<Result<(u64, String)>> {
+ pub fn authenticate_with_browser(
+ self: &Arc<Self>,
+ cx: &AsyncAppContext,
+ ) -> Task<Result<(u64, String)>> {
let platform = cx.platform();
let executor = cx.background();
executor.clone().spawn(async move {
@@ -488,8 +538,8 @@ mod tests {
#[gpui::test(iterations = 10)]
async fn test_heartbeat(cx: TestAppContext) {
let user_id = 5;
- let client = Client::new();
- let server = FakeServer::for_client(user_id, &client, &cx).await;
+ let mut client = Client::new();
+ let server = FakeServer::for_client(user_id, &mut client, &cx).await;
cx.foreground().advance_clock(Duration::from_secs(10));
let ping = server.receive::<proto::Ping>().await.unwrap();
@@ -203,16 +203,42 @@ pub struct FakeServer {
}
impl FakeServer {
- pub async fn for_client(user_id: u64, client: &Arc<Client>, cx: &TestAppContext) -> Arc<Self> {
+ pub async fn for_client(
+ client_user_id: u64,
+ client: &mut Arc<Client>,
+ cx: &TestAppContext,
+ ) -> Arc<Self> {
let result = Arc::new(Self {
peer: Peer::new(),
incoming: Default::default(),
connection_id: Default::default(),
});
+ Arc::get_mut(client)
+ .unwrap()
+ .set_login_and_connect_callbacks(
+ move |cx| {
+ cx.spawn(|_| async move {
+ let access_token = "the-token".to_string();
+ Ok((client_user_id, access_token))
+ })
+ },
+ {
+ let server = result.clone();
+ move |user_id, access_token, cx| {
+ assert_eq!(user_id, client_user_id);
+ assert_eq!(access_token, "the-token");
+ cx.spawn({
+ let server = server.clone();
+ move |cx| async move { Ok(server.connect(&cx).await) }
+ })
+ }
+ },
+ );
+
let conn = result.connect(&cx.to_async()).await;
client
- .set_connection(user_id, conn, &cx.to_async())
+ .set_connection(client_user_id, conn, &cx.to_async())
.await
.unwrap();
result