@@ -1023,7 +1023,7 @@ mod tests {
editor::{Editor, Insert},
fs::{FakeFs, Fs as _},
language::LanguageRegistry,
- rpc::{self, Client},
+ rpc::{self, Client, Credentials},
settings,
test::FakeHttpClient,
user::UserStore,
@@ -1922,39 +1922,40 @@ mod tests {
let forbid_connections = self.forbid_connections.clone();
Arc::get_mut(&mut client)
.unwrap()
- .set_login_and_connect_callbacks(
- move |cx| {
- cx.spawn(|_| async move {
- let access_token = "the-token".to_string();
- Ok((client_user_id.0 as u64, access_token))
+ .override_authenticate(move |cx| {
+ cx.spawn(|_| async move {
+ let access_token = "the-token".to_string();
+ Ok(Credentials {
+ user_id: client_user_id.0 as u64,
+ access_token,
})
- },
- move |user_id, access_token, cx| {
- assert_eq!(user_id, client_user_id.0 as u64);
- assert_eq!(access_token, "the-token");
-
- let server = server.clone();
- let connection_killers = connection_killers.clone();
- let forbid_connections = forbid_connections.clone();
- let client_name = client_name.clone();
- cx.spawn(move |cx| async move {
- if forbid_connections.load(SeqCst) {
- Err(anyhow!("server is forbidding connections"))
- } else {
- let (client_conn, server_conn, kill_conn) = Conn::in_memory();
- connection_killers.lock().insert(client_user_id, kill_conn);
- cx.background()
- .spawn(server.handle_connection(
- server_conn,
- client_name,
- client_user_id,
- ))
- .detach();
- Ok(client_conn)
- }
- })
- },
- );
+ })
+ })
+ .override_establish_connection(move |credentials, cx| {
+ assert_eq!(credentials.user_id, client_user_id.0 as u64);
+ assert_eq!(credentials.access_token, "the-token");
+
+ let server = server.clone();
+ let connection_killers = connection_killers.clone();
+ let forbid_connections = forbid_connections.clone();
+ let client_name = client_name.clone();
+ cx.spawn(move |cx| async move {
+ if forbid_connections.load(SeqCst) {
+ Err(anyhow!("server is forbidding connections"))
+ } else {
+ let (client_conn, server_conn, kill_conn) = Conn::in_memory();
+ connection_killers.lock().insert(client_user_id, kill_conn);
+ cx.background()
+ .spawn(server.handle_connection(
+ server_conn,
+ client_name,
+ client_user_id,
+ ))
+ .detach();
+ Ok(client_conn)
+ }
+ })
+ });
client
.authenticate_and_connect(&cx.to_async())
@@ -29,11 +29,10 @@ lazy_static! {
pub struct Client {
peer: Arc<Peer>,
state: RwLock<ClientState>,
- auth_callback: Option<
- Box<dyn 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<(u64, String)>>>,
- >,
- connect_callback: Option<
- Box<dyn 'static + Send + Sync + Fn(u64, &str, &AsyncAppContext) -> Task<Result<Conn>>>,
+ authenticate:
+ Option<Box<dyn 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>>>,
+ establish_connection: Option<
+ Box<dyn 'static + Send + Sync + Fn(&Credentials, &AsyncAppContext) -> Task<Result<Conn>>>,
>,
}
@@ -41,25 +40,17 @@ pub struct Client {
pub enum Status {
SignedOut,
Authenticating,
- Connecting {
- user_id: u64,
- },
+ Connecting,
ConnectionError,
- Connected {
- connection_id: ConnectionId,
- user_id: u64,
- },
+ Connected { connection_id: ConnectionId },
ConnectionLost,
Reauthenticating,
- Reconnecting {
- user_id: u64,
- },
- ReconnectionError {
- next_reconnection: Instant,
- },
+ Reconnecting,
+ ReconnectionError { next_reconnection: Instant },
}
struct ClientState {
+ credentials: Option<Credentials>,
status: (watch::Sender<Status>, watch::Receiver<Status>),
entity_id_extractors: HashMap<TypeId, Box<dyn Send + Sync + Fn(&dyn AnyTypedEnvelope) -> u64>>,
model_handlers: HashMap<
@@ -70,9 +61,16 @@ struct ClientState {
heartbeat_interval: Duration,
}
+#[derive(Clone)]
+pub struct Credentials {
+ pub user_id: u64,
+ pub access_token: String,
+}
+
impl Default for ClientState {
fn default() -> Self {
Self {
+ credentials: None,
status: watch::channel_with(Status::SignedOut),
entity_id_extractors: Default::default(),
model_handlers: Default::default(),
@@ -107,22 +105,35 @@ impl Client {
Arc::new(Self {
peer: Peer::new(),
state: Default::default(),
- auth_callback: None,
- connect_callback: None,
+ authenticate: None,
+ establish_connection: None,
})
}
#[cfg(any(test, feature = "test-support"))]
- pub fn set_login_and_connect_callbacks<Login, Connect>(
- &mut self,
- login: Login,
- connect: Connect,
- ) where
- Login: 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<(u64, String)>>,
- Connect: 'static + Send + Sync + Fn(u64, &str, &AsyncAppContext) -> Task<Result<Conn>>,
+ pub fn override_authenticate<F>(&mut self, authenticate: F) -> &mut Self
+ where
+ F: 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>,
{
- self.auth_callback = Some(Box::new(login));
- self.connect_callback = Some(Box::new(connect));
+ self.authenticate = Some(Box::new(authenticate));
+ self
+ }
+
+ #[cfg(any(test, feature = "test-support"))]
+ pub fn override_establish_connection<F>(&mut self, connect: F) -> &mut Self
+ where
+ F: 'static + Send + Sync + Fn(&Credentials, &AsyncAppContext) -> Task<Result<Conn>>,
+ {
+ self.establish_connection = Some(Box::new(connect));
+ self
+ }
+
+ pub fn user_id(&self) -> Option<u64> {
+ self.state
+ .read()
+ .credentials
+ .as_ref()
+ .map(|credentials| credentials.user_id)
}
pub fn status(&self) -> watch::Receiver<Status> {
@@ -249,23 +260,31 @@ impl Client {
self.set_status(Status::Reauthenticating, cx)
}
- let (user_id, access_token) = match self.authenticate(&cx).await {
- Ok(result) => result,
- Err(err) => {
- self.set_status(Status::ConnectionError, cx);
- return Err(err);
- }
+ let credentials = self.state.read().credentials.clone();
+ let credentials = if let Some(credentials) = credentials {
+ credentials
+ } else {
+ let credentials = match self.authenticate(&cx).await {
+ Ok(credentials) => credentials,
+ Err(err) => {
+ self.set_status(Status::ConnectionError, cx);
+ return Err(err);
+ }
+ };
+ self.state.write().credentials = Some(credentials.clone());
+ credentials
};
if was_disconnected {
- self.set_status(Status::Connecting { user_id }, cx);
+ self.set_status(Status::Connecting, cx);
} else {
- self.set_status(Status::Reconnecting { user_id }, cx);
+ self.set_status(Status::Reconnecting, cx);
}
- match self.connect(user_id, &access_token, cx).await {
+
+ match self.establish_connection(&credentials, cx).await {
Ok(conn) => {
log::info!("connected to rpc address {}", *ZED_SERVER_URL);
- self.set_connection(user_id, conn, cx).await;
+ self.set_connection(conn, cx).await;
Ok(())
}
Err(err) => {
@@ -275,7 +294,7 @@ impl Client {
}
}
- async fn set_connection(self: &Arc<Self>, user_id: u64, conn: Conn, cx: &AsyncAppContext) {
+ async fn set_connection(self: &Arc<Self>, conn: Conn, cx: &AsyncAppContext) {
let (connection_id, handle_io, mut incoming) = self.peer.add_connection(conn).await;
cx.foreground()
.spawn({
@@ -310,13 +329,7 @@ impl Client {
})
.detach();
- self.set_status(
- Status::Connected {
- connection_id,
- user_id,
- },
- cx,
- );
+ self.set_status(Status::Connected { connection_id }, cx);
let handle_io = cx.background().spawn(handle_io);
let this = self.clone();
@@ -334,35 +347,35 @@ impl Client {
.detach();
}
- fn authenticate(self: &Arc<Self>, cx: &AsyncAppContext) -> Task<Result<(u64, String)>> {
- if let Some(callback) = self.auth_callback.as_ref() {
+ fn authenticate(self: &Arc<Self>, cx: &AsyncAppContext) -> Task<Result<Credentials>> {
+ if let Some(callback) = self.authenticate.as_ref() {
callback(cx)
} else {
self.authenticate_with_browser(cx)
}
}
- fn connect(
+ fn establish_connection(
self: &Arc<Self>,
- user_id: u64,
- access_token: &str,
+ credentials: &Credentials,
cx: &AsyncAppContext,
) -> Task<Result<Conn>> {
- if let Some(callback) = self.connect_callback.as_ref() {
- callback(user_id, access_token, cx)
+ if let Some(callback) = self.establish_connection.as_ref() {
+ callback(credentials, cx)
} else {
- self.connect_with_websocket(user_id, access_token, cx)
+ self.establish_websocket_connection(credentials, cx)
}
}
- fn connect_with_websocket(
+ fn establish_websocket_connection(
self: &Arc<Self>,
- user_id: u64,
- access_token: &str,
+ credentials: &Credentials,
cx: &AsyncAppContext,
) -> Task<Result<Conn>> {
- let request =
- Request::builder().header("Authorization", format!("{} {}", user_id, access_token));
+ let request = Request::builder().header(
+ "Authorization",
+ format!("{} {}", credentials.user_id, credentials.access_token),
+ );
cx.background().spawn(async move {
if let Some(host) = ZED_SERVER_URL.strip_prefix("https://") {
let stream = smol::net::TcpStream::connect(host).await?;
@@ -387,7 +400,7 @@ impl Client {
pub fn authenticate_with_browser(
self: &Arc<Self>,
cx: &AsyncAppContext,
- ) -> Task<Result<(u64, String)>> {
+ ) -> Task<Result<Credentials>> {
let platform = cx.platform();
let executor = cx.background();
executor.clone().spawn(async move {
@@ -397,7 +410,10 @@ impl Client {
.flatten()
{
log::info!("already signed in. user_id: {}", user_id);
- return Ok((user_id.parse()?, String::from_utf8(access_token).unwrap()));
+ return Ok(Credentials {
+ user_id: user_id.parse()?,
+ access_token: String::from_utf8(access_token).unwrap(),
+ });
}
// Generate a pair of asymmetric encryption keys. The public key will be used by the
@@ -463,7 +479,11 @@ impl Client {
platform
.write_credentials(&ZED_SERVER_URL, &user_id, access_token.as_bytes())
.log_err();
- Ok((user_id.parse()?, access_token))
+
+ Ok(Credentials {
+ user_id: user_id.parse()?,
+ access_token,
+ })
})
}
@@ -4,7 +4,7 @@ use crate::{
fs::RealFs,
http::{HttpClient, Request, Response, ServerResponse},
language::LanguageRegistry,
- rpc::{self, Client},
+ rpc::{self, Client, Credentials},
settings::{self, ThemeRegistry},
time::ReplicaId,
user::UserStore,
@@ -226,25 +226,26 @@ impl FakeServer {
Arc::get_mut(client)
.unwrap()
- .set_login_and_connect_callbacks(
- move |cx| {
- cx.spawn(|_| async move {
- let access_token = "the-token".to_string();
- Ok((client_user_id, access_token))
+ .override_authenticate(move |cx| {
+ cx.spawn(|_| async move {
+ let access_token = "the-token".to_string();
+ Ok(Credentials {
+ user_id: client_user_id,
+ access_token,
})
- },
- {
- let server = result.clone();
- move |user_id, access_token, cx| {
- assert_eq!(user_id, client_user_id);
- assert_eq!(access_token, "the-token");
- cx.spawn({
- let server = server.clone();
- move |cx| async move { server.connect(&cx).await }
- })
- }
- },
- );
+ })
+ })
+ .override_establish_connection({
+ let server = result.clone();
+ move |credentials, cx| {
+ assert_eq!(credentials.user_id, client_user_id);
+ assert_eq!(credentials.access_token, "the-token");
+ cx.spawn({
+ let server = server.clone();
+ move |cx| async move { server.connect(&cx).await }
+ })
+ }
+ });
client
.authenticate_and_connect(&cx.to_async())