1#[cfg(any(test, feature = "test-support"))]
2pub mod test;
3
4pub mod channel;
5pub mod http;
6pub mod user;
7
8use anyhow::{anyhow, Context, Result};
9use async_recursion::async_recursion;
10use async_tungstenite::tungstenite::{
11 error::Error as WebsocketError,
12 http::{Request, StatusCode},
13};
14use futures::{future::LocalBoxFuture, FutureExt, StreamExt};
15use gpui::{
16 action, AnyModelHandle, AnyWeakModelHandle, AsyncAppContext, Entity, ModelContext, ModelHandle,
17 MutableAppContext, Task,
18};
19use http::HttpClient;
20use lazy_static::lazy_static;
21use parking_lot::RwLock;
22use postage::watch;
23use rand::prelude::*;
24use rpc::proto::{AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage};
25use std::{
26 any::TypeId,
27 collections::HashMap,
28 convert::TryFrom,
29 fmt::Write as _,
30 future::Future,
31 sync::{
32 atomic::{AtomicUsize, Ordering},
33 Arc, Weak,
34 },
35 time::{Duration, Instant},
36};
37use surf::{http::Method, Url};
38use thiserror::Error;
39use util::{ResultExt, TryFutureExt};
40
41pub use channel::*;
42pub use rpc::*;
43pub use user::*;
44
45lazy_static! {
46 static ref ZED_SERVER_URL: String =
47 std::env::var("ZED_SERVER_URL").unwrap_or("https://zed.dev".to_string());
48 static ref IMPERSONATE_LOGIN: Option<String> = std::env::var("ZED_IMPERSONATE")
49 .ok()
50 .and_then(|s| if s.is_empty() { None } else { Some(s) });
51}
52
53action!(Authenticate);
54
55pub fn init(rpc: Arc<Client>, cx: &mut MutableAppContext) {
56 cx.add_global_action(move |_: &Authenticate, cx| {
57 let rpc = rpc.clone();
58 cx.spawn(|cx| async move { rpc.authenticate_and_connect(&cx).log_err().await })
59 .detach();
60 });
61}
62
63pub struct Client {
64 id: usize,
65 peer: Arc<Peer>,
66 http: Arc<dyn HttpClient>,
67 state: RwLock<ClientState>,
68 authenticate:
69 Option<Box<dyn 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>>>,
70 establish_connection: Option<
71 Box<
72 dyn 'static
73 + Send
74 + Sync
75 + Fn(
76 &Credentials,
77 &AsyncAppContext,
78 ) -> Task<Result<Connection, EstablishConnectionError>>,
79 >,
80 >,
81}
82
83#[derive(Error, Debug)]
84pub enum EstablishConnectionError {
85 #[error("upgrade required")]
86 UpgradeRequired,
87 #[error("unauthorized")]
88 Unauthorized,
89 #[error("{0}")]
90 Other(#[from] anyhow::Error),
91 #[error("{0}")]
92 Io(#[from] std::io::Error),
93 #[error("{0}")]
94 Http(#[from] async_tungstenite::tungstenite::http::Error),
95}
96
97impl From<WebsocketError> for EstablishConnectionError {
98 fn from(error: WebsocketError) -> Self {
99 if let WebsocketError::Http(response) = &error {
100 match response.status() {
101 StatusCode::UNAUTHORIZED => return EstablishConnectionError::Unauthorized,
102 StatusCode::UPGRADE_REQUIRED => return EstablishConnectionError::UpgradeRequired,
103 _ => {}
104 }
105 }
106 EstablishConnectionError::Other(error.into())
107 }
108}
109
110impl EstablishConnectionError {
111 pub fn other(error: impl Into<anyhow::Error> + Send + Sync) -> Self {
112 Self::Other(error.into())
113 }
114}
115
116#[derive(Copy, Clone, Debug)]
117pub enum Status {
118 SignedOut,
119 UpgradeRequired,
120 Authenticating,
121 Connecting,
122 ConnectionError,
123 Connected { connection_id: ConnectionId },
124 ConnectionLost,
125 Reauthenticating,
126 Reconnecting,
127 ReconnectionError { next_reconnection: Instant },
128}
129
130struct ClientState {
131 credentials: Option<Credentials>,
132 status: (watch::Sender<Status>, watch::Receiver<Status>),
133 entity_id_extractors: HashMap<TypeId, Box<dyn Send + Sync + Fn(&dyn AnyTypedEnvelope) -> u64>>,
134 _maintain_connection: Option<Task<()>>,
135 heartbeat_interval: Duration,
136
137 pending_messages: HashMap<(TypeId, u64), Vec<Box<dyn AnyTypedEnvelope>>>,
138 models_by_entity_type_and_remote_id: HashMap<(TypeId, u64), AnyWeakModelHandle>,
139 models_by_message_type: HashMap<TypeId, AnyModelHandle>,
140 model_types_by_message_type: HashMap<TypeId, TypeId>,
141 message_handlers: HashMap<
142 TypeId,
143 Arc<
144 dyn Send
145 + Sync
146 + Fn(
147 AnyModelHandle,
148 Box<dyn AnyTypedEnvelope>,
149 AsyncAppContext,
150 ) -> LocalBoxFuture<'static, Result<()>>,
151 >,
152 >,
153}
154
/// User id and access token sent to the server when connecting.
#[derive(Clone, Debug)]
pub struct Credentials {
    /// Numeric id of the user.
    pub user_id: u64,
    /// Access token for that user.
    pub access_token: String,
}
160
161impl Default for ClientState {
162 fn default() -> Self {
163 Self {
164 credentials: None,
165 status: watch::channel_with(Status::SignedOut),
166 entity_id_extractors: Default::default(),
167 _maintain_connection: None,
168 heartbeat_interval: Duration::from_secs(5),
169 models_by_message_type: Default::default(),
170 models_by_entity_type_and_remote_id: Default::default(),
171 model_types_by_message_type: Default::default(),
172 pending_messages: Default::default(),
173 message_handlers: Default::default(),
174 }
175 }
176}
177
178pub enum Subscription {
179 Entity {
180 client: Weak<Client>,
181 id: (TypeId, u64),
182 },
183 Message {
184 client: Weak<Client>,
185 id: TypeId,
186 },
187}
188
189impl Drop for Subscription {
190 fn drop(&mut self) {
191 match self {
192 Subscription::Entity { client, id } => {
193 if let Some(client) = client.upgrade() {
194 let mut state = client.state.write();
195 let _ = state.models_by_entity_type_and_remote_id.remove(id);
196 }
197 }
198 Subscription::Message { client, id } => {
199 if let Some(client) = client.upgrade() {
200 let mut state = client.state.write();
201 let _ = state.model_types_by_message_type.remove(id);
202 let _ = state.message_handlers.remove(id);
203 }
204 }
205 }
206 }
207}
208
209impl Client {
210 pub fn new(http: Arc<dyn HttpClient>) -> Arc<Self> {
211 lazy_static! {
212 static ref NEXT_CLIENT_ID: AtomicUsize = AtomicUsize::default();
213 }
214
215 Arc::new(Self {
216 id: NEXT_CLIENT_ID.fetch_add(1, Ordering::SeqCst),
217 peer: Peer::new(),
218 http,
219 state: Default::default(),
220 authenticate: None,
221 establish_connection: None,
222 })
223 }
224
/// Test support: replaces the authentication step with `authenticate`.
#[cfg(any(test, feature = "test-support"))]
pub fn override_authenticate<F>(&mut self, authenticate: F) -> &mut Self
where
    F: 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>,
{
    let callback: Box<
        dyn Fn(&AsyncAppContext) -> Task<Result<Credentials>> + Send + Sync + 'static,
    > = Box::new(authenticate);
    self.authenticate = Some(callback);
    self
}
233
/// Test support: replaces the connection step with `connect`.
#[cfg(any(test, feature = "test-support"))]
pub fn override_establish_connection<F>(&mut self, connect: F) -> &mut Self
where
    F: 'static
        + Send
        + Sync
        + Fn(&Credentials, &AsyncAppContext) -> Task<Result<Connection, EstablishConnectionError>>,
{
    let callback: Box<
        dyn Fn(&Credentials, &AsyncAppContext) -> Task<Result<Connection, EstablishConnectionError>>
            + Send
            + Sync
            + 'static,
    > = Box::new(connect);
    self.establish_connection = Some(callback);
    self
}
245
246 pub fn user_id(&self) -> Option<u64> {
247 self.state
248 .read()
249 .credentials
250 .as_ref()
251 .map(|credentials| credentials.user_id)
252 }
253
254 pub fn status(&self) -> watch::Receiver<Status> {
255 self.state.read().status.1.clone()
256 }
257
258 fn set_status(self: &Arc<Self>, status: Status, cx: &AsyncAppContext) {
259 let mut state = self.state.write();
260 *state.status.0.borrow_mut() = status;
261
262 match status {
263 Status::Connected { .. } => {
264 let heartbeat_interval = state.heartbeat_interval;
265 let this = self.clone();
266 let foreground = cx.foreground();
267 state._maintain_connection = Some(cx.foreground().spawn(async move {
268 loop {
269 foreground.timer(heartbeat_interval).await;
270 let _ = this.request(proto::Ping {}).await;
271 }
272 }));
273 }
274 Status::ConnectionLost => {
275 let this = self.clone();
276 let foreground = cx.foreground();
277 let heartbeat_interval = state.heartbeat_interval;
278 state._maintain_connection = Some(cx.spawn(|cx| async move {
279 let mut rng = StdRng::from_entropy();
280 let mut delay = Duration::from_millis(100);
281 while let Err(error) = this.authenticate_and_connect(&cx).await {
282 log::error!("failed to connect {}", error);
283 this.set_status(
284 Status::ReconnectionError {
285 next_reconnection: Instant::now() + delay,
286 },
287 &cx,
288 );
289 foreground.timer(delay).await;
290 delay = delay
291 .mul_f32(rng.gen_range(1.0..=2.0))
292 .min(heartbeat_interval);
293 }
294 }));
295 }
296 Status::SignedOut | Status::UpgradeRequired => {
297 state._maintain_connection.take();
298 }
299 _ => {}
300 }
301 }
302
303 pub fn add_model_for_remote_entity<T: Entity>(
304 self: &Arc<Self>,
305 remote_id: u64,
306 cx: &mut ModelContext<T>,
307 ) -> Subscription {
308 let handle = AnyModelHandle::from(cx.handle());
309 let mut state = self.state.write();
310 let id = (TypeId::of::<T>(), remote_id);
311 state
312 .models_by_entity_type_and_remote_id
313 .insert(id, handle.downgrade());
314 let pending_messages = state.pending_messages.remove(&id);
315 drop(state);
316
317 let client_id = self.id;
318 for message in pending_messages.into_iter().flatten() {
319 let type_id = message.payload_type_id();
320 let type_name = message.payload_type_name();
321 let state = self.state.read();
322 if let Some(handler) = state.message_handlers.get(&type_id).cloned() {
323 let future = (handler)(handle.clone(), message, cx.to_async());
324 drop(state);
325 log::debug!(
326 "deferred rpc message received. client_id:{}, name:{}",
327 client_id,
328 type_name
329 );
330 cx.foreground()
331 .spawn(async move {
332 match future.await {
333 Ok(()) => {
334 log::debug!(
335 "deferred rpc message handled. client_id:{}, name:{}",
336 client_id,
337 type_name
338 );
339 }
340 Err(error) => {
341 log::error!(
342 "error handling deferred message. client_id:{}, name:{}, {}",
343 client_id,
344 type_name,
345 error
346 );
347 }
348 }
349 })
350 .detach();
351 }
352 }
353
354 Subscription::Entity {
355 client: Arc::downgrade(self),
356 id,
357 }
358 }
359
360 pub fn add_message_handler<M, E, H, F>(
361 self: &Arc<Self>,
362 model: ModelHandle<E>,
363 handler: H,
364 ) -> Subscription
365 where
366 M: EnvelopedMessage,
367 E: Entity,
368 H: 'static
369 + Send
370 + Sync
371 + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
372 F: 'static + Future<Output = Result<()>>,
373 {
374 let message_type_id = TypeId::of::<M>();
375
376 let client = self.clone();
377 let mut state = self.state.write();
378 state
379 .models_by_message_type
380 .insert(message_type_id, model.into());
381
382 let prev_handler = state.message_handlers.insert(
383 message_type_id,
384 Arc::new(move |handle, envelope, cx| {
385 let model = handle.downcast::<E>().unwrap();
386 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
387 handler(model, *envelope, client.clone(), cx).boxed_local()
388 }),
389 );
390 if prev_handler.is_some() {
391 panic!("registered handler for the same message twice");
392 }
393
394 Subscription::Message {
395 client: Arc::downgrade(self),
396 id: message_type_id,
397 }
398 }
399
400 pub fn add_entity_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
401 where
402 M: EntityMessage,
403 E: Entity,
404 H: 'static
405 + Send
406 + Sync
407 + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
408 F: 'static + Future<Output = Result<()>>,
409 {
410 let model_type_id = TypeId::of::<E>();
411 let message_type_id = TypeId::of::<M>();
412
413 let client = self.clone();
414 let mut state = self.state.write();
415 state
416 .model_types_by_message_type
417 .insert(message_type_id, model_type_id);
418 state
419 .entity_id_extractors
420 .entry(message_type_id)
421 .or_insert_with(|| {
422 Box::new(|envelope| {
423 let envelope = envelope
424 .as_any()
425 .downcast_ref::<TypedEnvelope<M>>()
426 .unwrap();
427 envelope.payload.remote_entity_id()
428 })
429 });
430
431 let prev_handler = state.message_handlers.insert(
432 message_type_id,
433 Arc::new(move |handle, envelope, cx| {
434 let model = handle.downcast::<E>().unwrap();
435 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
436 handler(model, *envelope, client.clone(), cx).boxed_local()
437 }),
438 );
439 if prev_handler.is_some() {
440 panic!("registered handler for the same message twice");
441 }
442 }
443
444 pub fn add_entity_request_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
445 where
446 M: EntityMessage + RequestMessage,
447 E: Entity,
448 H: 'static
449 + Send
450 + Sync
451 + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
452 F: 'static + Future<Output = Result<M::Response>>,
453 {
454 self.add_entity_message_handler(move |model, envelope, client, cx| {
455 let receipt = envelope.receipt();
456 let response = handler(model, envelope, client.clone(), cx);
457 async move {
458 match response.await {
459 Ok(response) => {
460 client.respond(receipt, response)?;
461 Ok(())
462 }
463 Err(error) => {
464 client.respond_with_error(
465 receipt,
466 proto::Error {
467 message: error.to_string(),
468 },
469 )?;
470 Err(error)
471 }
472 }
473 }
474 })
475 }
476
477 pub fn has_keychain_credentials(&self, cx: &AsyncAppContext) -> bool {
478 read_credentials_from_keychain(cx).is_some()
479 }
480
481 #[async_recursion(?Send)]
482 pub async fn authenticate_and_connect(
483 self: &Arc<Self>,
484 cx: &AsyncAppContext,
485 ) -> anyhow::Result<()> {
486 let was_disconnected = match *self.status().borrow() {
487 Status::SignedOut => true,
488 Status::ConnectionError | Status::ConnectionLost | Status::ReconnectionError { .. } => {
489 false
490 }
491 Status::Connected { .. }
492 | Status::Connecting { .. }
493 | Status::Reconnecting { .. }
494 | Status::Authenticating
495 | Status::Reauthenticating => return Ok(()),
496 Status::UpgradeRequired => return Err(EstablishConnectionError::UpgradeRequired)?,
497 };
498
499 if was_disconnected {
500 self.set_status(Status::Authenticating, cx);
501 } else {
502 self.set_status(Status::Reauthenticating, cx)
503 }
504
505 let mut used_keychain = false;
506 let credentials = self.state.read().credentials.clone();
507 let credentials = if let Some(credentials) = credentials {
508 credentials
509 } else if let Some(credentials) = read_credentials_from_keychain(cx) {
510 used_keychain = true;
511 credentials
512 } else {
513 let credentials = match self.authenticate(&cx).await {
514 Ok(credentials) => credentials,
515 Err(err) => {
516 self.set_status(Status::ConnectionError, cx);
517 return Err(err);
518 }
519 };
520 credentials
521 };
522
523 if was_disconnected {
524 self.set_status(Status::Connecting, cx);
525 } else {
526 self.set_status(Status::Reconnecting, cx);
527 }
528
529 match self.establish_connection(&credentials, cx).await {
530 Ok(conn) => {
531 self.state.write().credentials = Some(credentials.clone());
532 if !used_keychain && IMPERSONATE_LOGIN.is_none() {
533 write_credentials_to_keychain(&credentials, cx).log_err();
534 }
535 self.set_connection(conn, cx).await;
536 Ok(())
537 }
538 Err(EstablishConnectionError::Unauthorized) => {
539 self.state.write().credentials.take();
540 if used_keychain {
541 cx.platform().delete_credentials(&ZED_SERVER_URL).log_err();
542 self.set_status(Status::SignedOut, cx);
543 self.authenticate_and_connect(cx).await
544 } else {
545 self.set_status(Status::ConnectionError, cx);
546 Err(EstablishConnectionError::Unauthorized)?
547 }
548 }
549 Err(EstablishConnectionError::UpgradeRequired) => {
550 self.set_status(Status::UpgradeRequired, cx);
551 Err(EstablishConnectionError::UpgradeRequired)?
552 }
553 Err(error) => {
554 self.set_status(Status::ConnectionError, cx);
555 Err(error)?
556 }
557 }
558 }
559
560 async fn set_connection(self: &Arc<Self>, conn: Connection, cx: &AsyncAppContext) {
561 let (connection_id, handle_io, mut incoming) = self.peer.add_connection(conn).await;
562 cx.foreground()
563 .spawn({
564 let cx = cx.clone();
565 let this = self.clone();
566 async move {
567 while let Some(message) = incoming.next().await {
568 let mut state = this.state.write();
569 let payload_type_id = message.payload_type_id();
570 let type_name = message.payload_type_name();
571 let model_type_id = state
572 .model_types_by_message_type
573 .get(&payload_type_id)
574 .copied();
575 let entity_id = state
576 .entity_id_extractors
577 .get(&message.payload_type_id())
578 .map(|extract_entity_id| (extract_entity_id)(message.as_ref()));
579
580 let model = state
581 .models_by_message_type
582 .get(&payload_type_id)
583 .cloned()
584 .or_else(|| {
585 let model_type_id = model_type_id?;
586 let entity_id = entity_id?;
587 let model = state
588 .models_by_entity_type_and_remote_id
589 .get(&(model_type_id, entity_id))?;
590 if let Some(model) = model.upgrade(&cx) {
591 Some(model)
592 } else {
593 state
594 .models_by_entity_type_and_remote_id
595 .remove(&(model_type_id, entity_id));
596 None
597 }
598 });
599
600 let model = if let Some(model) = model {
601 model
602 } else {
603 log::info!("unhandled message {}", type_name);
604 if let Some((model_type_id, entity_id)) = model_type_id.zip(entity_id) {
605 state
606 .pending_messages
607 .entry((model_type_id, entity_id))
608 .or_default()
609 .push(message);
610 }
611
612 continue;
613 };
614
615 if let Some(handler) = state.message_handlers.get(&payload_type_id).cloned()
616 {
617 drop(state); // Avoid deadlocks if the handler interacts with rpc::Client
618 let future = handler(model, message, cx.clone());
619
620 let client_id = this.id;
621 log::debug!(
622 "rpc message received. client_id:{}, name:{}",
623 client_id,
624 type_name
625 );
626 cx.foreground()
627 .spawn(async move {
628 match future.await {
629 Ok(()) => {
630 log::debug!(
631 "rpc message handled. client_id:{}, name:{}",
632 client_id,
633 type_name
634 );
635 }
636 Err(error) => {
637 log::error!(
638 "error handling message. client_id:{}, name:{}, {}",
639 client_id,
640 type_name,
641 error
642 );
643 }
644 }
645 })
646 .detach();
647 } else {
648 log::info!("unhandled message {}", type_name);
649 }
650 }
651 }
652 })
653 .detach();
654
655 self.set_status(Status::Connected { connection_id }, cx);
656
657 let handle_io = cx.background().spawn(handle_io);
658 let this = self.clone();
659 let cx = cx.clone();
660 cx.foreground()
661 .spawn(async move {
662 match handle_io.await {
663 Ok(()) => this.set_status(Status::SignedOut, &cx),
664 Err(err) => {
665 log::error!("connection error: {:?}", err);
666 this.set_status(Status::ConnectionLost, &cx);
667 }
668 }
669 })
670 .detach();
671 }
672
673 fn authenticate(self: &Arc<Self>, cx: &AsyncAppContext) -> Task<Result<Credentials>> {
674 if let Some(callback) = self.authenticate.as_ref() {
675 callback(cx)
676 } else {
677 self.authenticate_with_browser(cx)
678 }
679 }
680
681 fn establish_connection(
682 self: &Arc<Self>,
683 credentials: &Credentials,
684 cx: &AsyncAppContext,
685 ) -> Task<Result<Connection, EstablishConnectionError>> {
686 if let Some(callback) = self.establish_connection.as_ref() {
687 callback(credentials, cx)
688 } else {
689 self.establish_websocket_connection(credentials, cx)
690 }
691 }
692
693 fn establish_websocket_connection(
694 self: &Arc<Self>,
695 credentials: &Credentials,
696 cx: &AsyncAppContext,
697 ) -> Task<Result<Connection, EstablishConnectionError>> {
698 let request = Request::builder()
699 .header(
700 "Authorization",
701 format!("{} {}", credentials.user_id, credentials.access_token),
702 )
703 .header("X-Zed-Protocol-Version", rpc::PROTOCOL_VERSION);
704
705 let http = self.http.clone();
706 cx.background().spawn(async move {
707 let mut rpc_url = format!("{}/rpc", *ZED_SERVER_URL);
708 let rpc_request = surf::Request::new(
709 Method::Get,
710 surf::Url::parse(&rpc_url).context("invalid ZED_SERVER_URL")?,
711 );
712 let rpc_response = http.send(rpc_request).await?;
713
714 if rpc_response.status().is_redirection() {
715 rpc_url = rpc_response
716 .header("Location")
717 .ok_or_else(|| anyhow!("missing location header in /rpc response"))?
718 .as_str()
719 .to_string();
720 }
721 // Until we switch the zed.dev domain to point to the new Next.js app, there
722 // will be no redirect required, and the app will connect directly to
723 // wss://zed.dev/rpc.
724 else if rpc_response.status() != surf::StatusCode::UpgradeRequired {
725 Err(anyhow!(
726 "unexpected /rpc response status {}",
727 rpc_response.status()
728 ))?
729 }
730
731 let mut rpc_url = surf::Url::parse(&rpc_url).context("invalid rpc url")?;
732 let rpc_host = rpc_url
733 .host_str()
734 .zip(rpc_url.port_or_known_default())
735 .ok_or_else(|| anyhow!("missing host in rpc url"))?;
736 let stream = smol::net::TcpStream::connect(rpc_host).await?;
737
738 log::info!("connected to rpc endpoint {}", rpc_url);
739
740 match rpc_url.scheme() {
741 "https" => {
742 rpc_url.set_scheme("wss").unwrap();
743 let request = request.uri(rpc_url.as_str()).body(())?;
744 let (stream, _) =
745 async_tungstenite::async_tls::client_async_tls(request, stream).await?;
746 Ok(Connection::new(stream))
747 }
748 "http" => {
749 rpc_url.set_scheme("ws").unwrap();
750 let request = request.uri(rpc_url.as_str()).body(())?;
751 let (stream, _) = async_tungstenite::client_async(request, stream).await?;
752 Ok(Connection::new(stream))
753 }
754 _ => Err(anyhow!("invalid rpc url: {}", rpc_url))?,
755 }
756 })
757 }
758
759 pub fn authenticate_with_browser(
760 self: &Arc<Self>,
761 cx: &AsyncAppContext,
762 ) -> Task<Result<Credentials>> {
763 let platform = cx.platform();
764 let executor = cx.background();
765 executor.clone().spawn(async move {
766 // Generate a pair of asymmetric encryption keys. The public key will be used by the
767 // zed server to encrypt the user's access token, so that it can'be intercepted by
768 // any other app running on the user's device.
769 let (public_key, private_key) =
770 rpc::auth::keypair().expect("failed to generate keypair for auth");
771 let public_key_string =
772 String::try_from(public_key).expect("failed to serialize public key for auth");
773
774 // Start an HTTP server to receive the redirect from Zed's sign-in page.
775 let server = tiny_http::Server::http("127.0.0.1:0").expect("failed to find open port");
776 let port = server.server_addr().port();
777
778 // Open the Zed sign-in page in the user's browser, with query parameters that indicate
779 // that the user is signing in from a Zed app running on the same device.
780 let mut url = format!(
781 "{}/native_app_signin?native_app_port={}&native_app_public_key={}",
782 *ZED_SERVER_URL, port, public_key_string
783 );
784
785 if let Some(impersonate_login) = IMPERSONATE_LOGIN.as_ref() {
786 log::info!("impersonating user @{}", impersonate_login);
787 write!(&mut url, "&impersonate={}", impersonate_login).unwrap();
788 }
789
790 platform.open_url(&url);
791
792 // Receive the HTTP request from the user's browser. Retrieve the user id and encrypted
793 // access token from the query params.
794 //
795 // TODO - Avoid ever starting more than one HTTP server. Maybe switch to using a
796 // custom URL scheme instead of this local HTTP server.
797 let (user_id, access_token) = executor
798 .spawn(async move {
799 if let Some(req) = server.recv_timeout(Duration::from_secs(10 * 60))? {
800 let path = req.url();
801 let mut user_id = None;
802 let mut access_token = None;
803 let url = Url::parse(&format!("http://example.com{}", path))
804 .context("failed to parse login notification url")?;
805 for (key, value) in url.query_pairs() {
806 if key == "access_token" {
807 access_token = Some(value.to_string());
808 } else if key == "user_id" {
809 user_id = Some(value.to_string());
810 }
811 }
812
813 let post_auth_url =
814 format!("{}/native_app_signin_succeeded", *ZED_SERVER_URL);
815 req.respond(
816 tiny_http::Response::empty(302).with_header(
817 tiny_http::Header::from_bytes(
818 &b"Location"[..],
819 post_auth_url.as_bytes(),
820 )
821 .unwrap(),
822 ),
823 )
824 .context("failed to respond to login http request")?;
825 Ok((
826 user_id.ok_or_else(|| anyhow!("missing user_id parameter"))?,
827 access_token
828 .ok_or_else(|| anyhow!("missing access_token parameter"))?,
829 ))
830 } else {
831 Err(anyhow!("didn't receive login redirect"))
832 }
833 })
834 .await?;
835
836 let access_token = private_key
837 .decrypt_string(&access_token)
838 .context("failed to decrypt access token")?;
839 platform.activate(true);
840
841 Ok(Credentials {
842 user_id: user_id.parse()?,
843 access_token,
844 })
845 })
846 }
847
848 pub fn disconnect(self: &Arc<Self>, cx: &AsyncAppContext) -> Result<()> {
849 let conn_id = self.connection_id()?;
850 self.peer.disconnect(conn_id);
851 self.set_status(Status::SignedOut, cx);
852 Ok(())
853 }
854
855 fn connection_id(&self) -> Result<ConnectionId> {
856 if let Status::Connected { connection_id, .. } = *self.status().borrow() {
857 Ok(connection_id)
858 } else {
859 Err(anyhow!("not connected"))
860 }
861 }
862
863 pub fn send<T: EnvelopedMessage>(&self, message: T) -> Result<()> {
864 log::debug!("rpc send. client_id:{}, name:{}", self.id, T::NAME);
865 self.peer.send(self.connection_id()?, message)
866 }
867
868 pub async fn request<T: RequestMessage>(&self, request: T) -> Result<T::Response> {
869 log::debug!(
870 "rpc request start. client_id: {}. name:{}",
871 self.id,
872 T::NAME
873 );
874 let response = self.peer.request(self.connection_id()?, request).await;
875 log::debug!(
876 "rpc request finish. client_id: {}. name:{}",
877 self.id,
878 T::NAME
879 );
880 response
881 }
882
883 fn respond<T: RequestMessage>(&self, receipt: Receipt<T>, response: T::Response) -> Result<()> {
884 log::debug!("rpc respond. client_id: {}. name:{}", self.id, T::NAME);
885 self.peer.respond(receipt, response)
886 }
887
888 fn respond_with_error<T: RequestMessage>(
889 &self,
890 receipt: Receipt<T>,
891 error: proto::Error,
892 ) -> Result<()> {
893 log::debug!("rpc respond. client_id: {}. name:{}", self.id, T::NAME);
894 self.peer.respond_with_error(receipt, error)
895 }
896}
897
898fn read_credentials_from_keychain(cx: &AsyncAppContext) -> Option<Credentials> {
899 if IMPERSONATE_LOGIN.is_some() {
900 return None;
901 }
902
903 let (user_id, access_token) = cx
904 .platform()
905 .read_credentials(&ZED_SERVER_URL)
906 .log_err()
907 .flatten()?;
908 Some(Credentials {
909 user_id: user_id.parse().ok()?,
910 access_token: String::from_utf8(access_token).ok()?,
911 })
912}
913
914fn write_credentials_to_keychain(credentials: &Credentials, cx: &AsyncAppContext) -> Result<()> {
915 cx.platform().write_credentials(
916 &ZED_SERVER_URL,
917 &credentials.user_id.to_string(),
918 credentials.access_token.as_bytes(),
919 )
920}
921
/// Prefix shared by every worktree URL.
const WORKTREE_URL_PREFIX: &str = "zed://worktrees/";

/// Builds a worktree URL of the form `zed://worktrees/<id>/<access_token>`.
pub fn encode_worktree_url(id: u64, access_token: &str) -> String {
    let mut url = String::from(WORKTREE_URL_PREFIX);
    url.push_str(&id.to_string());
    url.push('/');
    url.push_str(access_token);
    url
}

/// Parses a worktree URL produced by `encode_worktree_url`.
///
/// Surrounding whitespace is ignored. Returns `None` when the prefix is missing, the id is
/// not a valid `u64`, or the access token is missing or empty. Path segments after the access
/// token are ignored.
pub fn decode_worktree_url(url: &str) -> Option<(u64, String)> {
    let rest = url.trim().strip_prefix(WORKTREE_URL_PREFIX)?;
    let mut segments = rest.split('/');
    match (segments.next(), segments.next()) {
        (Some(id), Some(token)) if !token.is_empty() => {
            Some((id.parse::<u64>().ok()?, token.to_string()))
        }
        _ => None,
    }
}
938
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test::{FakeHttpClient, FakeServer};
    use gpui::TestAppContext;

    // While connected the client keeps pinging the server; after `disconnect` the server can
    // no longer receive pings.
    #[gpui::test(iterations = 10)]
    async fn test_heartbeat(cx: TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let mut rpc_client = Client::new(FakeHttpClient::with_404_response());
        let fake_server = FakeServer::for_client(user_id, &mut rpc_client, &cx).await;

        // Two consecutive heartbeats, each acknowledged by the server.
        for _ in 0..2 {
            cx.foreground().advance_clock(Duration::from_secs(10));
            let ping = fake_server.receive::<proto::Ping>().await.unwrap();
            fake_server.respond(ping.receipt(), proto::Ack {}).await;
        }

        rpc_client.disconnect(&cx.to_async()).unwrap();
        assert!(fake_server.receive::<proto::Ping>().await.is_err());
    }

    // The client reconnects after losing the connection, reusing cached credentials while
    // they are valid and re-authenticating once the server has rolled the access token.
    #[gpui::test(iterations = 10)]
    async fn test_reconnection(cx: TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let mut rpc_client = Client::new(FakeHttpClient::with_404_response());
        let fake_server = FakeServer::for_client(user_id, &mut rpc_client, &cx).await;
        let mut status_rx = rpc_client.status();
        assert!(matches!(
            status_rx.next().await,
            Some(Status::Connected { .. })
        ));
        assert_eq!(fake_server.auth_count(), 1);

        fake_server.forbid_connections();
        fake_server.disconnect();
        while !matches!(
            status_rx.next().await,
            Some(Status::ReconnectionError { .. })
        ) {}

        fake_server.allow_connections();
        cx.foreground().advance_clock(Duration::from_secs(10));
        while !matches!(status_rx.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(fake_server.auth_count(), 1); // Client reused the cached credentials when reconnecting

        fake_server.forbid_connections();
        fake_server.disconnect();
        while !matches!(
            status_rx.next().await,
            Some(Status::ReconnectionError { .. })
        ) {}

        // Clear cached credentials after authentication fails
        fake_server.roll_access_token();
        fake_server.allow_connections();
        cx.foreground().advance_clock(Duration::from_secs(10));
        assert_eq!(fake_server.auth_count(), 1);
        cx.foreground().advance_clock(Duration::from_secs(10));
        while !matches!(status_rx.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(fake_server.auth_count(), 2); // Client re-authenticated due to an invalid token
    }

    // Worktree URLs round-trip, tolerate surrounding whitespace, and reject foreign schemes.
    #[test]
    fn test_encode_and_decode_worktree_url() {
        let expected = Some((5, "deadbeef".to_string()));
        let encoded = encode_worktree_url(5, "deadbeef");
        assert_eq!(decode_worktree_url(&encoded), expected);
        assert_eq!(decode_worktree_url(&format!("\n {}\t", encoded)), expected);
        assert_eq!(decode_worktree_url("not://the-right-format"), None);
    }

    // Entity messages are routed to the model registered for the matching remote id.
    #[gpui::test]
    async fn test_subscribing_to_entity(mut cx: TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let mut rpc_client = Client::new(FakeHttpClient::with_404_response());
        let fake_server = FakeServer::for_client(user_id, &mut rpc_client, &cx).await;

        let (done_tx1, mut done_rx1) = smol::channel::unbounded();
        let (done_tx2, mut done_rx2) = smol::channel::unbounded();
        rpc_client.add_entity_message_handler(
            move |model: ModelHandle<Model>, _: TypedEnvelope<proto::UnshareProject>, _, cx| {
                match model.read_with(&cx, |model, _| model.id) {
                    1 => done_tx1.try_send(()).unwrap(),
                    2 => done_tx2.try_send(()).unwrap(),
                    _ => unreachable!(),
                }
                async { Ok(()) }
            },
        );
        let first_model = cx.add_model(|_| Model {
            id: 1,
            subscription: None,
        });
        let second_model = cx.add_model(|_| Model {
            id: 2,
            subscription: None,
        });
        let third_model = cx.add_model(|_| Model {
            id: 3,
            subscription: None,
        });

        let _subscription1 = first_model.update(&mut cx, |_, cx| {
            rpc_client.add_model_for_remote_entity(1, cx)
        });
        let _subscription2 = second_model.update(&mut cx, |_, cx| {
            rpc_client.add_model_for_remote_entity(2, cx)
        });
        // Ensure dropping a subscription for the same entity type still allows receiving of
        // messages for other entity IDs of the same type.
        let subscription3 = third_model.update(&mut cx, |_, cx| {
            rpc_client.add_model_for_remote_entity(3, cx)
        });
        drop(subscription3);

        fake_server.send(proto::UnshareProject { project_id: 1 });
        fake_server.send(proto::UnshareProject { project_id: 2 });
        done_rx1.next().await.unwrap();
        done_rx2.next().await.unwrap();
    }

    // A message type can be subscribed to again once its previous subscription was dropped.
    #[gpui::test]
    async fn test_subscribing_after_dropping_subscription(mut cx: TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let mut rpc_client = Client::new(FakeHttpClient::with_404_response());
        let fake_server = FakeServer::for_client(user_id, &mut rpc_client, &cx).await;

        let model = cx.add_model(|_| Model::default());
        let (done_tx1, _done_rx1) = smol::channel::unbounded();
        let (done_tx2, mut done_rx2) = smol::channel::unbounded();
        let first_subscription = rpc_client.add_message_handler(
            model.clone(),
            move |_, _: TypedEnvelope<proto::Ping>, _, _| {
                done_tx1.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        drop(first_subscription);
        let _subscription2 = rpc_client.add_message_handler(
            model,
            move |_, _: TypedEnvelope<proto::Ping>, _, _| {
                done_tx2.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        fake_server.send(proto::Ping {});
        done_rx2.next().await.unwrap();
    }

    // A handler may drop its own subscription while it is running.
    #[gpui::test]
    async fn test_dropping_subscription_in_handler(mut cx: TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let mut rpc_client = Client::new(FakeHttpClient::with_404_response());
        let fake_server = FakeServer::for_client(user_id, &mut rpc_client, &cx).await;

        let model = cx.add_model(|_| Model::default());
        let (done_tx, mut done_rx) = smol::channel::unbounded();
        let own_subscription = rpc_client.add_message_handler(
            model.clone(),
            move |model, _: TypedEnvelope<proto::Ping>, _, mut cx| {
                model.update(&mut cx, |model, _| model.subscription.take());
                done_tx.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        model.update(&mut cx, |model, _| {
            model.subscription = Some(own_subscription);
        });
        fake_server.send(proto::Ping {});
        done_rx.next().await.unwrap();
    }

    // Minimal entity for the tests above: an id plus an optional subscription it owns.
    #[derive(Default)]
    struct Model {
        id: usize,
        subscription: Option<Subscription>,
    }

    impl Entity for Model {
        type Event = ();
    }
}