1#[cfg(any(test, feature = "test-support"))]
2pub mod test;
3
4pub mod channel;
5pub mod http;
6pub mod user;
7
8use anyhow::{anyhow, Context, Result};
9use async_recursion::async_recursion;
10use async_tungstenite::tungstenite::{
11 error::Error as WebsocketError,
12 http::{Request, StatusCode},
13};
14use futures::{future::LocalBoxFuture, FutureExt, SinkExt, StreamExt, TryStreamExt};
15use gpui::{
16 actions, AnyModelHandle, AnyViewHandle, AnyWeakModelHandle, AnyWeakViewHandle, AsyncAppContext,
17 Entity, ModelContext, ModelHandle, MutableAppContext, Task, View, ViewContext, ViewHandle,
18};
19use http::HttpClient;
20use lazy_static::lazy_static;
21use parking_lot::RwLock;
22use postage::watch;
23use rand::prelude::*;
24use rpc::proto::{AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage};
25use std::{
26 any::TypeId,
27 collections::HashMap,
28 convert::TryFrom,
29 fmt::Write as _,
30 future::Future,
31 sync::{
32 atomic::{AtomicUsize, Ordering},
33 Arc, Weak,
34 },
35 time::{Duration, Instant},
36};
37use thiserror::Error;
38use url::Url;
39use util::{ResultExt, TryFutureExt};
40
41pub use channel::*;
42pub use rpc::*;
43pub use user::*;
44
45lazy_static! {
46 pub static ref ZED_SERVER_URL: String =
47 std::env::var("ZED_SERVER_URL").unwrap_or("https://zed.dev".to_string());
48 pub static ref IMPERSONATE_LOGIN: Option<String> = std::env::var("ZED_IMPERSONATE")
49 .ok()
50 .and_then(|s| if s.is_empty() { None } else { Some(s) });
51}
52
/// Fixed token string exported by this crate (its consumers live outside this file).
pub const ZED_SECRET_CLIENT_TOKEN: &str = "618033988749894";
54
// Declares the `Authenticate` action in the `client` namespace; `init` below installs the
// global handler that performs the sign-in.
actions!(client, [Authenticate]);
56
57pub fn init(rpc: Arc<Client>, cx: &mut MutableAppContext) {
58 cx.add_global_action(move |_: &Authenticate, cx| {
59 let rpc = rpc.clone();
60 cx.spawn(|cx| async move { rpc.authenticate_and_connect(true, &cx).log_err().await })
61 .detach();
62 });
63}
64
/// Client for the Zed server's RPC protocol. It bundles the RPC peer, an HTTP client, and the
/// mutable connection / handler-registration state.
pub struct Client {
    /// Process-unique identifier assigned in `Client::new`; included in RPC log lines.
    id: usize,
    peer: Arc<Peer>,
    http: Arc<dyn HttpClient>,
    state: RwLock<ClientState>,
    /// Optional replacement for the browser sign-in flow (set via `override_authenticate`,
    /// which is only compiled for tests / `test-support`).
    authenticate:
        Option<Box<dyn 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>>>,
    /// Optional replacement for opening the websocket connection (set via
    /// `override_establish_connection`, only compiled for tests / `test-support`).
    establish_connection: Option<
        Box<
            dyn 'static
                + Send
                + Sync
                + Fn(
                    &Credentials,
                    &AsyncAppContext,
                ) -> Task<Result<Connection, EstablishConnectionError>>,
        >,
    >,
}
84
/// Failure modes when opening a connection to the server.
#[derive(Error, Debug)]
pub enum EstablishConnectionError {
    /// The server answered the websocket handshake with HTTP 426 (Upgrade Required).
    #[error("upgrade required")]
    UpgradeRequired,
    /// The server answered the websocket handshake with HTTP 401 (Unauthorized).
    #[error("unauthorized")]
    Unauthorized,
    /// Any other failure, including websocket errors that are not 401/426.
    #[error("{0}")]
    Other(#[from] anyhow::Error),
    #[error("{0}")]
    Http(#[from] http::Error),
    #[error("{0}")]
    Io(#[from] std::io::Error),
    /// Failure while building the websocket handshake request.
    #[error("{0}")]
    Websocket(#[from] async_tungstenite::tungstenite::http::Error),
}
100
101impl From<WebsocketError> for EstablishConnectionError {
102 fn from(error: WebsocketError) -> Self {
103 if let WebsocketError::Http(response) = &error {
104 match response.status() {
105 StatusCode::UNAUTHORIZED => return EstablishConnectionError::Unauthorized,
106 StatusCode::UPGRADE_REQUIRED => return EstablishConnectionError::UpgradeRequired,
107 _ => {}
108 }
109 }
110 EstablishConnectionError::Other(error.into())
111 }
112}
113
114impl EstablishConnectionError {
115 pub fn other(error: impl Into<anyhow::Error> + Send + Sync) -> Self {
116 Self::Other(error.into())
117 }
118}
119
/// Connection state of a `Client`, published through a watch channel (`Client::status`).
#[derive(Copy, Clone, Debug)]
pub enum Status {
    /// Initial state; also set after a clean disconnect or an explicit `disconnect`.
    SignedOut,
    /// The server rejected the connection with HTTP 426; no automatic reconnection follows.
    UpgradeRequired,
    /// Obtaining credentials after having been signed out.
    Authenticating,
    /// Opening a connection after having been signed out.
    Connecting,
    /// Authentication or connection failed.
    ConnectionError,
    /// Connected; `connection_id` identifies the connection within the RPC peer.
    Connected { connection_id: ConnectionId },
    /// The connection's I/O failed; setting this status starts the reconnection task.
    ConnectionLost,
    /// Obtaining credentials while recovering from an error or lost connection.
    Reauthenticating,
    /// Opening a connection while recovering from an error or lost connection.
    Reconnecting,
    /// A reconnection attempt failed; the next attempt is scheduled around `next_reconnection`.
    ReconnectionError { next_reconnection: Instant },
}
133
134impl Status {
135 pub fn is_connected(&self) -> bool {
136 matches!(self, Self::Connected { .. })
137 }
138}
139
/// Mutable state of a `Client`, guarded by its `RwLock`.
struct ClientState {
    /// Credentials of the last successful connection; cleared when the server answers 401.
    credentials: Option<Credentials>,
    /// Sender/receiver pair of the status watch channel.
    status: (watch::Sender<Status>, watch::Receiver<Status>),
    /// Per message type, a function that extracts the remote entity id from an envelope.
    entity_id_extractors: HashMap<TypeId, fn(&dyn AnyTypedEnvelope) -> u64>,
    /// Task retrying the connection after `Status::ConnectionLost`; dropping it stops retries.
    _reconnect_task: Option<Task<()>>,
    /// Upper bound on the randomized back-off delay between reconnection attempts.
    reconnect_interval: Duration,
    /// Weak handles of entities registered for (entity type, remote id).
    entities_by_type_and_remote_id: HashMap<(TypeId, u64), AnyWeakEntityHandle>,
    /// Model receiving each message type registered through `add_message_handler`.
    models_by_message_type: HashMap<TypeId, AnyWeakModelHandle>,
    /// Entity type targeted by each message type registered through an entity handler.
    entity_types_by_message_type: HashMap<TypeId, TypeId>,
    /// Type-erased handler for each message type.
    message_handlers: HashMap<
        TypeId,
        Arc<
            dyn Send
                + Sync
                + Fn(
                    AnyEntityHandle,
                    Box<dyn AnyTypedEnvelope>,
                    &Arc<Client>,
                    AsyncAppContext,
                ) -> LocalBoxFuture<'static, Result<()>>,
        >,
    >,
}
163
/// Weak reference to either a model or a view registered for a remote entity.
enum AnyWeakEntityHandle {
    Model(AnyWeakModelHandle),
    View(AnyWeakViewHandle),
}
168
/// Strong reference to either a model or a view; passed to message handlers.
enum AnyEntityHandle {
    Model(AnyModelHandle),
    View(AnyViewHandle),
}
173
/// User id and access token used to authenticate the RPC connection.
///
/// `Debug` is implemented by hand so that the access token is never written to logs or
/// panic messages; only the user id is shown.
#[derive(Clone)]
pub struct Credentials {
    pub user_id: u64,
    pub access_token: String,
}

impl std::fmt::Debug for Credentials {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Credentials")
            .field("user_id", &self.user_id)
            .field("access_token", &"<redacted>")
            .finish()
    }
}
179
180impl Default for ClientState {
181 fn default() -> Self {
182 Self {
183 credentials: None,
184 status: watch::channel_with(Status::SignedOut),
185 entity_id_extractors: Default::default(),
186 _reconnect_task: None,
187 reconnect_interval: Duration::from_secs(5),
188 models_by_message_type: Default::default(),
189 entities_by_type_and_remote_id: Default::default(),
190 entity_types_by_message_type: Default::default(),
191 message_handlers: Default::default(),
192 }
193 }
194}
195
/// Registration handle returned when subscribing an entity or a message handler.
/// Dropping it removes the registration from the client (see its `Drop` impl); it holds only a
/// weak reference, so it does not keep the client alive.
pub enum Subscription {
    /// Registration of an entity under `(entity TypeId, remote id)`.
    Entity {
        client: Weak<Client>,
        id: (TypeId, u64),
    },
    /// Registration of a handler for the message type `id`.
    Message {
        client: Weak<Client>,
        id: TypeId,
    },
}
206
207impl Drop for Subscription {
208 fn drop(&mut self) {
209 match self {
210 Subscription::Entity { client, id } => {
211 if let Some(client) = client.upgrade() {
212 let mut state = client.state.write();
213 let _ = state.entities_by_type_and_remote_id.remove(id);
214 }
215 }
216 Subscription::Message { client, id } => {
217 if let Some(client) = client.upgrade() {
218 let mut state = client.state.write();
219 let _ = state.entity_types_by_message_type.remove(id);
220 let _ = state.message_handlers.remove(id);
221 }
222 }
223 }
224 }
225}
226
227impl Client {
228 pub fn new(http: Arc<dyn HttpClient>) -> Arc<Self> {
229 lazy_static! {
230 static ref NEXT_CLIENT_ID: AtomicUsize = AtomicUsize::default();
231 }
232
233 Arc::new(Self {
234 id: NEXT_CLIENT_ID.fetch_add(1, Ordering::SeqCst),
235 peer: Peer::new(),
236 http,
237 state: Default::default(),
238 authenticate: None,
239 establish_connection: None,
240 })
241 }
242
243 pub fn id(&self) -> usize {
244 self.id
245 }
246
247 pub fn http_client(&self) -> Arc<dyn HttpClient> {
248 self.http.clone()
249 }
250
251 #[cfg(any(test, feature = "test-support"))]
252 pub fn tear_down(&self) {
253 let mut state = self.state.write();
254 state._reconnect_task.take();
255 state.message_handlers.clear();
256 state.models_by_message_type.clear();
257 state.entities_by_type_and_remote_id.clear();
258 state.entity_id_extractors.clear();
259 self.peer.reset();
260 }
261
262 #[cfg(any(test, feature = "test-support"))]
263 pub fn override_authenticate<F>(&mut self, authenticate: F) -> &mut Self
264 where
265 F: 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>,
266 {
267 self.authenticate = Some(Box::new(authenticate));
268 self
269 }
270
271 #[cfg(any(test, feature = "test-support"))]
272 pub fn override_establish_connection<F>(&mut self, connect: F) -> &mut Self
273 where
274 F: 'static
275 + Send
276 + Sync
277 + Fn(&Credentials, &AsyncAppContext) -> Task<Result<Connection, EstablishConnectionError>>,
278 {
279 self.establish_connection = Some(Box::new(connect));
280 self
281 }
282
283 pub fn user_id(&self) -> Option<u64> {
284 self.state
285 .read()
286 .credentials
287 .as_ref()
288 .map(|credentials| credentials.user_id)
289 }
290
291 pub fn status(&self) -> watch::Receiver<Status> {
292 self.state.read().status.1.clone()
293 }
294
295 fn set_status(self: &Arc<Self>, status: Status, cx: &AsyncAppContext) {
296 let mut state = self.state.write();
297 *state.status.0.borrow_mut() = status;
298
299 match status {
300 Status::Connected { .. } => {
301 state._reconnect_task = None;
302 }
303 Status::ConnectionLost => {
304 let this = self.clone();
305 let reconnect_interval = state.reconnect_interval;
306 state._reconnect_task = Some(cx.spawn(|cx| async move {
307 let mut rng = StdRng::from_entropy();
308 let mut delay = Duration::from_millis(100);
309 while let Err(error) = this.authenticate_and_connect(true, &cx).await {
310 log::error!("failed to connect {}", error);
311 this.set_status(
312 Status::ReconnectionError {
313 next_reconnection: Instant::now() + delay,
314 },
315 &cx,
316 );
317 cx.background().timer(delay).await;
318 delay = delay
319 .mul_f32(rng.gen_range(1.0..=2.0))
320 .min(reconnect_interval);
321 }
322 }));
323 }
324 Status::SignedOut | Status::UpgradeRequired => {
325 state._reconnect_task.take();
326 }
327 _ => {}
328 }
329 }
330
331 pub fn add_view_for_remote_entity<T: View>(
332 self: &Arc<Self>,
333 remote_id: u64,
334 cx: &mut ViewContext<T>,
335 ) -> Subscription {
336 let id = (TypeId::of::<T>(), remote_id);
337 self.state
338 .write()
339 .entities_by_type_and_remote_id
340 .insert(id, AnyWeakEntityHandle::View(cx.weak_handle().into()));
341 Subscription::Entity {
342 client: Arc::downgrade(self),
343 id,
344 }
345 }
346
347 pub fn add_model_for_remote_entity<T: Entity>(
348 self: &Arc<Self>,
349 remote_id: u64,
350 cx: &mut ModelContext<T>,
351 ) -> Subscription {
352 let id = (TypeId::of::<T>(), remote_id);
353 self.state
354 .write()
355 .entities_by_type_and_remote_id
356 .insert(id, AnyWeakEntityHandle::Model(cx.weak_handle().into()));
357 Subscription::Entity {
358 client: Arc::downgrade(self),
359 id,
360 }
361 }
362
363 pub fn add_message_handler<M, E, H, F>(
364 self: &Arc<Self>,
365 model: ModelHandle<E>,
366 handler: H,
367 ) -> Subscription
368 where
369 M: EnvelopedMessage,
370 E: Entity,
371 H: 'static
372 + Send
373 + Sync
374 + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
375 F: 'static + Future<Output = Result<()>>,
376 {
377 let message_type_id = TypeId::of::<M>();
378
379 let mut state = self.state.write();
380 state
381 .models_by_message_type
382 .insert(message_type_id, model.downgrade().into());
383
384 let prev_handler = state.message_handlers.insert(
385 message_type_id,
386 Arc::new(move |handle, envelope, client, cx| {
387 let handle = if let AnyEntityHandle::Model(handle) = handle {
388 handle
389 } else {
390 unreachable!();
391 };
392 let model = handle.downcast::<E>().unwrap();
393 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
394 handler(model, *envelope, client.clone(), cx).boxed_local()
395 }),
396 );
397 if prev_handler.is_some() {
398 panic!("registered handler for the same message twice");
399 }
400
401 Subscription::Message {
402 client: Arc::downgrade(self),
403 id: message_type_id,
404 }
405 }
406
407 pub fn add_view_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
408 where
409 M: EntityMessage,
410 E: View,
411 H: 'static
412 + Send
413 + Sync
414 + Fn(ViewHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
415 F: 'static + Future<Output = Result<()>>,
416 {
417 self.add_entity_message_handler::<M, E, _, _>(move |handle, message, client, cx| {
418 if let AnyEntityHandle::View(handle) = handle {
419 handler(handle.downcast::<E>().unwrap(), message, client, cx)
420 } else {
421 unreachable!();
422 }
423 })
424 }
425
426 pub fn add_model_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
427 where
428 M: EntityMessage,
429 E: Entity,
430 H: 'static
431 + Send
432 + Sync
433 + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
434 F: 'static + Future<Output = Result<()>>,
435 {
436 self.add_entity_message_handler::<M, E, _, _>(move |handle, message, client, cx| {
437 if let AnyEntityHandle::Model(handle) = handle {
438 handler(handle.downcast::<E>().unwrap(), message, client, cx)
439 } else {
440 unreachable!();
441 }
442 })
443 }
444
445 fn add_entity_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
446 where
447 M: EntityMessage,
448 E: Entity,
449 H: 'static
450 + Send
451 + Sync
452 + Fn(AnyEntityHandle, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
453 F: 'static + Future<Output = Result<()>>,
454 {
455 let model_type_id = TypeId::of::<E>();
456 let message_type_id = TypeId::of::<M>();
457
458 let mut state = self.state.write();
459 state
460 .entity_types_by_message_type
461 .insert(message_type_id, model_type_id);
462 state
463 .entity_id_extractors
464 .entry(message_type_id)
465 .or_insert_with(|| {
466 |envelope| {
467 envelope
468 .as_any()
469 .downcast_ref::<TypedEnvelope<M>>()
470 .unwrap()
471 .payload
472 .remote_entity_id()
473 }
474 });
475 let prev_handler = state.message_handlers.insert(
476 message_type_id,
477 Arc::new(move |handle, envelope, client, cx| {
478 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
479 handler(handle, *envelope, client.clone(), cx).boxed_local()
480 }),
481 );
482 if prev_handler.is_some() {
483 panic!("registered handler for the same message twice");
484 }
485 }
486
487 pub fn add_model_request_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
488 where
489 M: EntityMessage + RequestMessage,
490 E: Entity,
491 H: 'static
492 + Send
493 + Sync
494 + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
495 F: 'static + Future<Output = Result<M::Response>>,
496 {
497 self.add_model_message_handler(move |entity, envelope, client, cx| {
498 Self::respond_to_request::<M, _>(
499 envelope.receipt(),
500 handler(entity, envelope, client.clone(), cx),
501 client,
502 )
503 })
504 }
505
506 pub fn add_view_request_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
507 where
508 M: EntityMessage + RequestMessage,
509 E: View,
510 H: 'static
511 + Send
512 + Sync
513 + Fn(ViewHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
514 F: 'static + Future<Output = Result<M::Response>>,
515 {
516 self.add_view_message_handler(move |entity, envelope, client, cx| {
517 Self::respond_to_request::<M, _>(
518 envelope.receipt(),
519 handler(entity, envelope, client.clone(), cx),
520 client,
521 )
522 })
523 }
524
525 async fn respond_to_request<T: RequestMessage, F: Future<Output = Result<T::Response>>>(
526 receipt: Receipt<T>,
527 response: F,
528 client: Arc<Self>,
529 ) -> Result<()> {
530 match response.await {
531 Ok(response) => {
532 client.respond(receipt, response)?;
533 Ok(())
534 }
535 Err(error) => {
536 client.respond_with_error(
537 receipt,
538 proto::Error {
539 message: error.to_string(),
540 },
541 )?;
542 Err(error)
543 }
544 }
545 }
546
547 pub fn has_keychain_credentials(&self, cx: &AsyncAppContext) -> bool {
548 read_credentials_from_keychain(cx).is_some()
549 }
550
551 #[async_recursion(?Send)]
552 pub async fn authenticate_and_connect(
553 self: &Arc<Self>,
554 try_keychain: bool,
555 cx: &AsyncAppContext,
556 ) -> anyhow::Result<()> {
557 let was_disconnected = match *self.status().borrow() {
558 Status::SignedOut => true,
559 Status::ConnectionError | Status::ConnectionLost | Status::ReconnectionError { .. } => {
560 false
561 }
562 Status::Connected { .. }
563 | Status::Connecting { .. }
564 | Status::Reconnecting { .. }
565 | Status::Authenticating
566 | Status::Reauthenticating => return Ok(()),
567 Status::UpgradeRequired => return Err(EstablishConnectionError::UpgradeRequired)?,
568 };
569
570 if was_disconnected {
571 self.set_status(Status::Authenticating, cx);
572 } else {
573 self.set_status(Status::Reauthenticating, cx)
574 }
575
576 let mut read_from_keychain = false;
577 let mut credentials = self.state.read().credentials.clone();
578 if credentials.is_none() && try_keychain {
579 credentials = read_credentials_from_keychain(cx);
580 read_from_keychain = credentials.is_some();
581 }
582 if credentials.is_none() {
583 credentials = Some(match self.authenticate(&cx).await {
584 Ok(credentials) => credentials,
585 Err(err) => {
586 self.set_status(Status::ConnectionError, cx);
587 return Err(err);
588 }
589 });
590 }
591 let credentials = credentials.unwrap();
592
593 if was_disconnected {
594 self.set_status(Status::Connecting, cx);
595 } else {
596 self.set_status(Status::Reconnecting, cx);
597 }
598
599 match self.establish_connection(&credentials, cx).await {
600 Ok(conn) => {
601 self.state.write().credentials = Some(credentials.clone());
602 if !read_from_keychain && IMPERSONATE_LOGIN.is_none() {
603 write_credentials_to_keychain(&credentials, cx).log_err();
604 }
605 self.set_connection(conn, cx).await;
606 Ok(())
607 }
608 Err(EstablishConnectionError::Unauthorized) => {
609 self.state.write().credentials.take();
610 if read_from_keychain {
611 cx.platform().delete_credentials(&ZED_SERVER_URL).log_err();
612 self.set_status(Status::SignedOut, cx);
613 self.authenticate_and_connect(false, cx).await
614 } else {
615 self.set_status(Status::ConnectionError, cx);
616 Err(EstablishConnectionError::Unauthorized)?
617 }
618 }
619 Err(EstablishConnectionError::UpgradeRequired) => {
620 self.set_status(Status::UpgradeRequired, cx);
621 Err(EstablishConnectionError::UpgradeRequired)?
622 }
623 Err(error) => {
624 self.set_status(Status::ConnectionError, cx);
625 Err(error)?
626 }
627 }
628 }
629
630 async fn set_connection(self: &Arc<Self>, conn: Connection, cx: &AsyncAppContext) {
631 let executor = cx.background();
632 let (connection_id, handle_io, mut incoming) = self
633 .peer
634 .add_connection(conn, move |duration| executor.timer(duration))
635 .await;
636 cx.foreground()
637 .spawn({
638 let cx = cx.clone();
639 let this = self.clone();
640 async move {
641 let mut message_id = 0_usize;
642 while let Some(message) = incoming.next().await {
643 let mut state = this.state.write();
644 message_id += 1;
645 let type_name = message.payload_type_name();
646 let payload_type_id = message.payload_type_id();
647 let sender_id = message.original_sender_id().map(|id| id.0);
648
649 let model = state
650 .models_by_message_type
651 .get(&payload_type_id)
652 .and_then(|model| model.upgrade(&cx))
653 .map(AnyEntityHandle::Model)
654 .or_else(|| {
655 let entity_type_id =
656 *state.entity_types_by_message_type.get(&payload_type_id)?;
657 let entity_id = state
658 .entity_id_extractors
659 .get(&message.payload_type_id())
660 .map(|extract_entity_id| {
661 (extract_entity_id)(message.as_ref())
662 })?;
663
664 let entity = state
665 .entities_by_type_and_remote_id
666 .get(&(entity_type_id, entity_id))?;
667 if let Some(entity) = entity.upgrade(&cx) {
668 Some(entity)
669 } else {
670 state
671 .entities_by_type_and_remote_id
672 .remove(&(entity_type_id, entity_id));
673 None
674 }
675 });
676
677 let model = if let Some(model) = model {
678 model
679 } else {
680 log::info!("unhandled message {}", type_name);
681 continue;
682 };
683
684 if let Some(handler) = state.message_handlers.get(&payload_type_id).cloned()
685 {
686 drop(state); // Avoid deadlocks if the handler interacts with rpc::Client
687 let future = handler(model, message, &this, cx.clone());
688
689 let client_id = this.id;
690 log::debug!(
691 "rpc message received. client_id:{}, message_id:{}, sender_id:{:?}, type:{}",
692 client_id,
693 message_id,
694 sender_id,
695 type_name
696 );
697 cx.foreground()
698 .spawn(async move {
699 match future.await {
700 Ok(()) => {
701 log::debug!(
702 "rpc message handled. client_id:{}, message_id:{}, sender_id:{:?}, type:{}",
703 client_id,
704 message_id,
705 sender_id,
706 type_name
707 );
708 }
709 Err(error) => {
710 log::error!(
711 "error handling message. client_id:{}, message_id:{}, sender_id:{:?}, type:{}, error:{:?}",
712 client_id,
713 message_id,
714 sender_id,
715 type_name,
716 error
717 );
718 }
719 }
720 })
721 .detach();
722 } else {
723 log::info!("unhandled message {}", type_name);
724 }
725
726 // Don't starve the main thread when receiving lots of messages at once.
727 smol::future::yield_now().await;
728 }
729 }
730 })
731 .detach();
732
733 self.set_status(Status::Connected { connection_id }, cx);
734
735 let handle_io = cx.background().spawn(handle_io);
736 let this = self.clone();
737 let cx = cx.clone();
738 cx.foreground()
739 .spawn(async move {
740 match handle_io.await {
741 Ok(()) => this.set_status(Status::SignedOut, &cx),
742 Err(err) => {
743 log::error!("connection error: {:?}", err);
744 this.set_status(Status::ConnectionLost, &cx);
745 }
746 }
747 })
748 .detach();
749 }
750
751 fn authenticate(self: &Arc<Self>, cx: &AsyncAppContext) -> Task<Result<Credentials>> {
752 if let Some(callback) = self.authenticate.as_ref() {
753 callback(cx)
754 } else {
755 self.authenticate_with_browser(cx)
756 }
757 }
758
759 fn establish_connection(
760 self: &Arc<Self>,
761 credentials: &Credentials,
762 cx: &AsyncAppContext,
763 ) -> Task<Result<Connection, EstablishConnectionError>> {
764 if let Some(callback) = self.establish_connection.as_ref() {
765 callback(credentials, cx)
766 } else {
767 self.establish_websocket_connection(credentials, cx)
768 }
769 }
770
771 fn establish_websocket_connection(
772 self: &Arc<Self>,
773 credentials: &Credentials,
774 cx: &AsyncAppContext,
775 ) -> Task<Result<Connection, EstablishConnectionError>> {
776 let request = Request::builder()
777 .header(
778 "Authorization",
779 format!("{} {}", credentials.user_id, credentials.access_token),
780 )
781 .header("x-zed-protocol-version", rpc::PROTOCOL_VERSION);
782
783 let http = self.http.clone();
784 cx.background().spawn(async move {
785 let mut rpc_url = format!("{}/rpc", *ZED_SERVER_URL);
786 let rpc_response = http.get(&rpc_url, Default::default(), false).await?;
787 if rpc_response.status().is_redirection() {
788 rpc_url = rpc_response
789 .headers()
790 .get("Location")
791 .ok_or_else(|| anyhow!("missing location header in /rpc response"))?
792 .to_str()
793 .map_err(|error| EstablishConnectionError::other(error))?
794 .to_string();
795 }
796 // Until we switch the zed.dev domain to point to the new Next.js app, there
797 // will be no redirect required, and the app will connect directly to
798 // wss://zed.dev/rpc.
799 else if rpc_response.status() != StatusCode::UPGRADE_REQUIRED {
800 Err(anyhow!(
801 "unexpected /rpc response status {}",
802 rpc_response.status()
803 ))?
804 }
805
806 let mut rpc_url = Url::parse(&rpc_url).context("invalid rpc url")?;
807 let rpc_host = rpc_url
808 .host_str()
809 .zip(rpc_url.port_or_known_default())
810 .ok_or_else(|| anyhow!("missing host in rpc url"))?;
811 let stream = smol::net::TcpStream::connect(rpc_host).await?;
812
813 log::info!("connected to rpc endpoint {}", rpc_url);
814
815 match rpc_url.scheme() {
816 "https" => {
817 rpc_url.set_scheme("wss").unwrap();
818 let request = request.uri(rpc_url.as_str()).body(())?;
819 let (stream, _) =
820 async_tungstenite::async_tls::client_async_tls(request, stream).await?;
821 Ok(Connection::new(
822 stream
823 .map_err(|error| anyhow!(error))
824 .sink_map_err(|error| anyhow!(error)),
825 ))
826 }
827 "http" => {
828 rpc_url.set_scheme("ws").unwrap();
829 let request = request.uri(rpc_url.as_str()).body(())?;
830 let (stream, _) = async_tungstenite::client_async(request, stream).await?;
831 Ok(Connection::new(
832 stream
833 .map_err(|error| anyhow!(error))
834 .sink_map_err(|error| anyhow!(error)),
835 ))
836 }
837 _ => Err(anyhow!("invalid rpc url: {}", rpc_url))?,
838 }
839 })
840 }
841
842 pub fn authenticate_with_browser(
843 self: &Arc<Self>,
844 cx: &AsyncAppContext,
845 ) -> Task<Result<Credentials>> {
846 let platform = cx.platform();
847 let executor = cx.background();
848 executor.clone().spawn(async move {
849 // Generate a pair of asymmetric encryption keys. The public key will be used by the
850 // zed server to encrypt the user's access token, so that it can'be intercepted by
851 // any other app running on the user's device.
852 let (public_key, private_key) =
853 rpc::auth::keypair().expect("failed to generate keypair for auth");
854 let public_key_string =
855 String::try_from(public_key).expect("failed to serialize public key for auth");
856
857 // Start an HTTP server to receive the redirect from Zed's sign-in page.
858 let server = tiny_http::Server::http("127.0.0.1:0").expect("failed to find open port");
859 let port = server.server_addr().port();
860
861 // Open the Zed sign-in page in the user's browser, with query parameters that indicate
862 // that the user is signing in from a Zed app running on the same device.
863 let mut url = format!(
864 "{}/native_app_signin?native_app_port={}&native_app_public_key={}",
865 *ZED_SERVER_URL, port, public_key_string
866 );
867
868 if let Some(impersonate_login) = IMPERSONATE_LOGIN.as_ref() {
869 log::info!("impersonating user @{}", impersonate_login);
870 write!(&mut url, "&impersonate={}", impersonate_login).unwrap();
871 }
872
873 platform.open_url(&url);
874
875 // Receive the HTTP request from the user's browser. Retrieve the user id and encrypted
876 // access token from the query params.
877 //
878 // TODO - Avoid ever starting more than one HTTP server. Maybe switch to using a
879 // custom URL scheme instead of this local HTTP server.
880 let (user_id, access_token) = executor
881 .spawn(async move {
882 if let Some(req) = server.recv_timeout(Duration::from_secs(10 * 60))? {
883 let path = req.url();
884 let mut user_id = None;
885 let mut access_token = None;
886 let url = Url::parse(&format!("http://example.com{}", path))
887 .context("failed to parse login notification url")?;
888 for (key, value) in url.query_pairs() {
889 if key == "access_token" {
890 access_token = Some(value.to_string());
891 } else if key == "user_id" {
892 user_id = Some(value.to_string());
893 }
894 }
895
896 let post_auth_url =
897 format!("{}/native_app_signin_succeeded", *ZED_SERVER_URL);
898 req.respond(
899 tiny_http::Response::empty(302).with_header(
900 tiny_http::Header::from_bytes(
901 &b"Location"[..],
902 post_auth_url.as_bytes(),
903 )
904 .unwrap(),
905 ),
906 )
907 .context("failed to respond to login http request")?;
908 Ok((
909 user_id.ok_or_else(|| anyhow!("missing user_id parameter"))?,
910 access_token
911 .ok_or_else(|| anyhow!("missing access_token parameter"))?,
912 ))
913 } else {
914 Err(anyhow!("didn't receive login redirect"))
915 }
916 })
917 .await?;
918
919 let access_token = private_key
920 .decrypt_string(&access_token)
921 .context("failed to decrypt access token")?;
922 platform.activate(true);
923
924 Ok(Credentials {
925 user_id: user_id.parse()?,
926 access_token,
927 })
928 })
929 }
930
931 pub fn disconnect(self: &Arc<Self>, cx: &AsyncAppContext) -> Result<()> {
932 let conn_id = self.connection_id()?;
933 self.peer.disconnect(conn_id);
934 self.set_status(Status::SignedOut, cx);
935 Ok(())
936 }
937
938 fn connection_id(&self) -> Result<ConnectionId> {
939 if let Status::Connected { connection_id, .. } = *self.status().borrow() {
940 Ok(connection_id)
941 } else {
942 Err(anyhow!("not connected"))
943 }
944 }
945
946 pub fn send<T: EnvelopedMessage>(&self, message: T) -> Result<()> {
947 log::debug!("rpc send. client_id:{}, name:{}", self.id, T::NAME);
948 self.peer.send(self.connection_id()?, message)
949 }
950
951 pub fn request<T: RequestMessage>(
952 &self,
953 request: T,
954 ) -> impl Future<Output = Result<T::Response>> {
955 let client_id = self.id;
956 log::debug!(
957 "rpc request start. client_id:{}. name:{}",
958 client_id,
959 T::NAME
960 );
961 let response = self
962 .connection_id()
963 .map(|conn_id| self.peer.request(conn_id, request));
964 async move {
965 let response = response?.await;
966 log::debug!(
967 "rpc request finish. client_id:{}. name:{}",
968 client_id,
969 T::NAME
970 );
971 response
972 }
973 }
974
975 fn respond<T: RequestMessage>(&self, receipt: Receipt<T>, response: T::Response) -> Result<()> {
976 log::debug!("rpc respond. client_id:{}. name:{}", self.id, T::NAME);
977 self.peer.respond(receipt, response)
978 }
979
980 fn respond_with_error<T: RequestMessage>(
981 &self,
982 receipt: Receipt<T>,
983 error: proto::Error,
984 ) -> Result<()> {
985 log::debug!("rpc respond. client_id:{}. name:{}", self.id, T::NAME);
986 self.peer.respond_with_error(receipt, error)
987 }
988}
989
990impl AnyWeakEntityHandle {
991 fn upgrade(&self, cx: &AsyncAppContext) -> Option<AnyEntityHandle> {
992 match self {
993 AnyWeakEntityHandle::Model(handle) => handle.upgrade(cx).map(AnyEntityHandle::Model),
994 AnyWeakEntityHandle::View(handle) => handle.upgrade(cx).map(AnyEntityHandle::View),
995 }
996 }
997}
998
999fn read_credentials_from_keychain(cx: &AsyncAppContext) -> Option<Credentials> {
1000 if IMPERSONATE_LOGIN.is_some() {
1001 return None;
1002 }
1003
1004 let (user_id, access_token) = cx
1005 .platform()
1006 .read_credentials(&ZED_SERVER_URL)
1007 .log_err()
1008 .flatten()?;
1009 Some(Credentials {
1010 user_id: user_id.parse().ok()?,
1011 access_token: String::from_utf8(access_token).ok()?,
1012 })
1013}
1014
1015fn write_credentials_to_keychain(credentials: &Credentials, cx: &AsyncAppContext) -> Result<()> {
1016 cx.platform().write_credentials(
1017 &ZED_SERVER_URL,
1018 &credentials.user_id.to_string(),
1019 credentials.access_token.as_bytes(),
1020 )
1021}
1022
/// Scheme and path prefix shared by all worktree URLs.
const WORKTREE_URL_PREFIX: &str = "zed://worktrees/";

/// Builds `zed://worktrees/{id}/{access_token}`.
pub fn encode_worktree_url(id: u64, access_token: &str) -> String {
    [WORKTREE_URL_PREFIX, &id.to_string(), "/", access_token].concat()
}

/// Parses a URL produced by `encode_worktree_url`, ignoring surrounding whitespace.
///
/// Returns `None` if the prefix is missing, the id is not a `u64`, or the token segment is
/// missing or empty. Any path segments after the token are ignored.
pub fn decode_worktree_url(url: &str) -> Option<(u64, String)> {
    let rest = url.trim().strip_prefix(WORKTREE_URL_PREFIX)?;
    let mut segments = rest.split('/');
    let id: u64 = segments.next()?.parse().ok()?;
    match segments.next()? {
        "" => None,
        token => Some((id, token.to_owned())),
    }
}
1039
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test::{FakeHttpClient, FakeServer};
    use gpui::TestAppContext;

    // Verifies that a lost connection is retried after the back-off, that cached credentials
    // are reused on reconnect, and that a rejected token triggers a fresh authentication.
    #[gpui::test(iterations = 10)]
    async fn test_reconnection(cx: &mut TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let mut client = Client::new(FakeHttpClient::with_404_response());
        let server = FakeServer::for_client(user_id, &mut client, &cx).await;
        let mut status = client.status();
        assert!(matches!(
            status.next().await,
            Some(Status::Connected { .. })
        ));
        assert_eq!(server.auth_count(), 1);

        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        server.allow_connections();
        // Move past the reconnection back-off so the retry loop runs again.
        cx.foreground().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 1); // Client reused the cached credentials when reconnecting

        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Clear cached credentials after authentication fails
        server.roll_access_token();
        server.allow_connections();
        cx.foreground().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 2); // Client re-authenticated due to an invalid token
    }

    // Round-trips a worktree URL, including surrounding whitespace, and rejects a foreign URL.
    #[test]
    fn test_encode_and_decode_worktree_url() {
        let url = encode_worktree_url(5, "deadbeef");
        assert_eq!(decode_worktree_url(&url), Some((5, "deadbeef".to_string())));
        assert_eq!(
            decode_worktree_url(&format!("\n {}\t", url)),
            Some((5, "deadbeef".to_string()))
        );
        assert_eq!(decode_worktree_url("not://the-right-format"), None);
    }

    // Entity messages are routed to the model registered for the message's remote id.
    #[gpui::test]
    async fn test_subscribing_to_entity(cx: &mut TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let mut client = Client::new(FakeHttpClient::with_404_response());
        let server = FakeServer::for_client(user_id, &mut client, &cx).await;

        let (done_tx1, mut done_rx1) = smol::channel::unbounded();
        let (done_tx2, mut done_rx2) = smol::channel::unbounded();
        client.add_model_message_handler(
            move |model: ModelHandle<Model>, _: TypedEnvelope<proto::UnshareProject>, _, cx| {
                match model.read_with(&cx, |model, _| model.id) {
                    1 => done_tx1.try_send(()).unwrap(),
                    2 => done_tx2.try_send(()).unwrap(),
                    _ => unreachable!(),
                }
                async { Ok(()) }
            },
        );
        let model1 = cx.add_model(|_| Model {
            id: 1,
            subscription: None,
        });
        let model2 = cx.add_model(|_| Model {
            id: 2,
            subscription: None,
        });
        let model3 = cx.add_model(|_| Model {
            id: 3,
            subscription: None,
        });

        let _subscription1 = model1.update(cx, |_, cx| client.add_model_for_remote_entity(1, cx));
        let _subscription2 = model2.update(cx, |_, cx| client.add_model_for_remote_entity(2, cx));
        // Ensure dropping a subscription for the same entity type still allows receiving of
        // messages for other entity IDs of the same type.
        let subscription3 = model3.update(cx, |_, cx| client.add_model_for_remote_entity(3, cx));
        drop(subscription3);

        server.send(proto::UnshareProject { project_id: 1 });
        server.send(proto::UnshareProject { project_id: 2 });
        done_rx1.next().await.unwrap();
        done_rx2.next().await.unwrap();
    }

    // After dropping a message subscription, a new handler for the same message type can be
    // registered and receives the next message.
    #[gpui::test]
    async fn test_subscribing_after_dropping_subscription(cx: &mut TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let mut client = Client::new(FakeHttpClient::with_404_response());
        let server = FakeServer::for_client(user_id, &mut client, &cx).await;

        let model = cx.add_model(|_| Model::default());
        let (done_tx1, _done_rx1) = smol::channel::unbounded();
        let (done_tx2, mut done_rx2) = smol::channel::unbounded();
        let subscription1 = client.add_message_handler(
            model.clone(),
            move |_, _: TypedEnvelope<proto::Ping>, _, _| {
                done_tx1.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        drop(subscription1);
        let _subscription2 =
            client.add_message_handler(model, move |_, _: TypedEnvelope<proto::Ping>, _, _| {
                done_tx2.try_send(()).unwrap();
                async { Ok(()) }
            });
        server.send(proto::Ping {});
        done_rx2.next().await.unwrap();
    }

    // A handler may drop its own subscription while running. This requires that the client
    // state lock is not held when the handler is invoked.
    #[gpui::test]
    async fn test_dropping_subscription_in_handler(cx: &mut TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let mut client = Client::new(FakeHttpClient::with_404_response());
        let server = FakeServer::for_client(user_id, &mut client, &cx).await;

        let model = cx.add_model(|_| Model::default());
        let (done_tx, mut done_rx) = smol::channel::unbounded();
        let subscription = client.add_message_handler(
            model.clone(),
            move |model, _: TypedEnvelope<proto::Ping>, _, mut cx| {
                model.update(&mut cx, |model, _| model.subscription.take());
                done_tx.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        model.update(cx, |model, _| {
            model.subscription = Some(subscription);
        });
        server.send(proto::Ping {});
        done_rx.next().await.unwrap();
    }

    /// Minimal model used as a message target in these tests.
    #[derive(Default)]
    struct Model {
        id: usize,
        subscription: Option<Subscription>,
    }

    impl Entity for Model {
        type Event = ();
    }
}