1#[cfg(any(test, feature = "test-support"))]
2pub mod test;
3
4pub mod channel;
5pub mod http;
6pub mod user;
7
8use anyhow::{anyhow, Context, Result};
9use async_recursion::async_recursion;
10use async_tungstenite::tungstenite::{
11 error::Error as WebsocketError,
12 http::{Request, StatusCode},
13};
14use futures::{future::LocalBoxFuture, FutureExt, SinkExt, StreamExt, TryStreamExt};
15use gpui::{
16 actions, AnyModelHandle, AnyViewHandle, AnyWeakModelHandle, AnyWeakViewHandle, AsyncAppContext,
17 Entity, ModelContext, ModelHandle, MutableAppContext, Task, View, ViewContext, ViewHandle,
18};
19use http::HttpClient;
20use lazy_static::lazy_static;
21use parking_lot::RwLock;
22use postage::watch;
23use rand::prelude::*;
24use rpc::proto::{AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage};
25use std::{
26 any::TypeId,
27 collections::HashMap,
28 convert::TryFrom,
29 fmt::Write as _,
30 future::Future,
31 sync::{
32 atomic::{AtomicUsize, Ordering},
33 Arc, Weak,
34 },
35 time::{Duration, Instant},
36};
37use thiserror::Error;
38use url::Url;
39use util::{ResultExt, TryFutureExt};
40
41pub use channel::*;
42pub use rpc::*;
43pub use user::*;
44
45lazy_static! {
46 pub static ref ZED_SERVER_URL: String =
47 std::env::var("ZED_SERVER_URL").unwrap_or("https://zed.dev".to_string());
48 pub static ref IMPERSONATE_LOGIN: Option<String> = std::env::var("ZED_IMPERSONATE")
49 .ok()
50 .and_then(|s| if s.is_empty() { None } else { Some(s) });
51}
52
/// Client token constant exported by this crate.
// NOTE(review): not referenced within this file; confirm its consumers and purpose.
pub const ZED_SECRET_CLIENT_TOKEN: &str = "618033988749894";
54
// Declares the `Authenticate` action in the `client` namespace; `init` registers its handler.
actions!(client, [Authenticate]);
56
57pub fn init(rpc: Arc<Client>, cx: &mut MutableAppContext) {
58 cx.add_global_action(move |_: &Authenticate, cx| {
59 let rpc = rpc.clone();
60 cx.spawn(|cx| async move { rpc.authenticate_and_connect(true, &cx).log_err().await })
61 .detach();
62 });
63}
64
/// Client for the Zed server.
///
/// Owns the RPC peer, the HTTP client, and the lock-protected `ClientState` (credentials,
/// connection status, and the tables that route incoming messages to handlers).
pub struct Client {
    /// Process-unique id assigned by `Client::new`; included in log messages.
    id: usize,
    peer: Arc<Peer>,
    http: Arc<dyn HttpClient>,
    state: RwLock<ClientState>,
    /// Optional replacement for browser sign-in, installed by `override_authenticate`
    /// (test builds only). When `None`, `authenticate_with_browser` is used.
    authenticate:
        Option<Box<dyn 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>>>,
    /// Optional replacement for opening the websocket, installed by
    /// `override_establish_connection` (test builds only). When `None`,
    /// `establish_websocket_connection` is used.
    establish_connection: Option<
        Box<
            dyn 'static
                + Send
                + Sync
                + Fn(
                    &Credentials,
                    &AsyncAppContext,
                ) -> Task<Result<Connection, EstablishConnectionError>>,
        >,
    >,
}
84
/// Reasons a connection to the server could not be established.
#[derive(Error, Debug)]
pub enum EstablishConnectionError {
    /// The server answered with HTTP `426 Upgrade Required`.
    #[error("upgrade required")]
    UpgradeRequired,
    /// The server answered with HTTP `401 Unauthorized` (credentials rejected).
    #[error("unauthorized")]
    Unauthorized,
    /// Any other failure.
    #[error("{0}")]
    Other(#[from] anyhow::Error),
    /// An error from the `http` module's error type.
    #[error("{0}")]
    Http(#[from] http::Error),
    /// An I/O error (e.g. while opening the TCP stream).
    #[error("{0}")]
    Io(#[from] std::io::Error),
    /// An error building the websocket handshake request.
    #[error("{0}")]
    Websocket(#[from] async_tungstenite::tungstenite::http::Error),
}
100
101impl From<WebsocketError> for EstablishConnectionError {
102 fn from(error: WebsocketError) -> Self {
103 if let WebsocketError::Http(response) = &error {
104 match response.status() {
105 StatusCode::UNAUTHORIZED => return EstablishConnectionError::Unauthorized,
106 StatusCode::UPGRADE_REQUIRED => return EstablishConnectionError::UpgradeRequired,
107 _ => {}
108 }
109 }
110 EstablishConnectionError::Other(error.into())
111 }
112}
113
114impl EstablishConnectionError {
115 pub fn other(error: impl Into<anyhow::Error> + Send + Sync) -> Self {
116 Self::Other(error.into())
117 }
118}
119
/// Connection lifecycle of a [`Client`], published through its status watch channel.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum Status {
    /// Not connected. This is the initial state; it is also entered after `disconnect` and
    /// when the connection's I/O finishes without error.
    SignedOut,
    /// The server answered with `426 Upgrade Required`; `authenticate_and_connect` refuses
    /// to retry from this state.
    UpgradeRequired,
    /// Obtaining credentials for a connection from the `SignedOut` state.
    Authenticating,
    /// Opening a connection from the `SignedOut` state.
    Connecting,
    /// The last authentication or connection attempt failed.
    ConnectionError,
    /// Connected; `connection_id` identifies the connection within the peer.
    Connected { connection_id: ConnectionId },
    /// The connection's I/O failed; a reconnect task is spawned on entering this state.
    ConnectionLost,
    /// Obtaining credentials while recovering from an error or lost connection.
    Reauthenticating,
    /// Reconnecting after an error or lost connection.
    Reconnecting,
    /// A reconnect attempt failed; the next attempt is scheduled for `next_reconnection`.
    ReconnectionError { next_reconnection: Instant },
}
133
impl Status {
    /// Returns `true` only for `Status::Connected`.
    pub fn is_connected(&self) -> bool {
        matches!(self, Self::Connected { .. })
    }
}
139
/// Mutable state of a [`Client`], guarded by its `RwLock`.
struct ClientState {
    /// Credentials of the last successful connection; cleared when the server answers
    /// `Unauthorized`.
    credentials: Option<Credentials>,
    /// Watch channel: `set_status` writes through the sender, `Client::status` clones the
    /// receiver.
    status: (watch::Sender<Status>, watch::Receiver<Status>),
    /// Per message type, a function that extracts the remote entity id from an envelope.
    entity_id_extractors: HashMap<TypeId, fn(&dyn AnyTypedEnvelope) -> u64>,
    /// Reconnect loop spawned on `ConnectionLost`; cleared on `Connected`, `SignedOut`, and
    /// `UpgradeRequired`.
    _reconnect_task: Option<Task<()>>,
    /// Upper bound on the randomized reconnect back-off delay.
    reconnect_interval: Duration,
    /// Weak entity handles keyed by (entity `TypeId`, remote entity id).
    entities_by_type_and_remote_id: HashMap<(TypeId, u64), AnyWeakEntityHandle>,
    /// Model registered via `add_message_handler`, keyed by message `TypeId`.
    models_by_message_type: HashMap<TypeId, AnyWeakModelHandle>,
    /// Entity `TypeId` that receives each entity message type.
    entity_types_by_message_type: HashMap<TypeId, TypeId>,
    /// Type-erased handlers keyed by message `TypeId`; at most one per message type.
    message_handlers: HashMap<
        TypeId,
        Arc<
            dyn Send
                + Sync
                + Fn(
                    AnyEntityHandle,
                    Box<dyn AnyTypedEnvelope>,
                    &Arc<Client>,
                    AsyncAppContext,
                ) -> LocalBoxFuture<'static, Result<()>>,
        >,
    >,
}
163
/// Weak, type-erased handle to either a model or a view registered for remote messages.
enum AnyWeakEntityHandle {
    Model(AnyWeakModelHandle),
    View(AnyWeakViewHandle),
}
168
/// Strong, type-erased handle to a model or view; passed to message handlers.
enum AnyEntityHandle {
    Model(AnyModelHandle),
    View(AnyViewHandle),
}
173
/// User id and access token sent in the `Authorization` header when connecting.
#[derive(Clone, Debug)]
pub struct Credentials {
    pub user_id: u64,
    pub access_token: String,
}
179
180impl Default for ClientState {
181 fn default() -> Self {
182 Self {
183 credentials: None,
184 status: watch::channel_with(Status::SignedOut),
185 entity_id_extractors: Default::default(),
186 _reconnect_task: None,
187 reconnect_interval: Duration::from_secs(5),
188 models_by_message_type: Default::default(),
189 entities_by_type_and_remote_id: Default::default(),
190 entity_types_by_message_type: Default::default(),
191 message_handlers: Default::default(),
192 }
193 }
194}
195
/// Registration guard returned by the `add_*` methods of [`Client`]; dropping it
/// unregisters what it refers to (see the `Drop` impl).
pub enum Subscription {
    /// An entity registered under (entity `TypeId`, remote id).
    Entity {
        client: Weak<Client>,
        id: (TypeId, u64),
    },
    /// A message handler registered under the message `TypeId`.
    Message {
        client: Weak<Client>,
        id: TypeId,
    },
}
206
207impl Drop for Subscription {
208 fn drop(&mut self) {
209 match self {
210 Subscription::Entity { client, id } => {
211 if let Some(client) = client.upgrade() {
212 let mut state = client.state.write();
213 let _ = state.entities_by_type_and_remote_id.remove(id);
214 }
215 }
216 Subscription::Message { client, id } => {
217 if let Some(client) = client.upgrade() {
218 let mut state = client.state.write();
219 let _ = state.entity_types_by_message_type.remove(id);
220 let _ = state.message_handlers.remove(id);
221 }
222 }
223 }
224 }
225}
226
227impl Client {
228 pub fn new(http: Arc<dyn HttpClient>) -> Arc<Self> {
229 lazy_static! {
230 static ref NEXT_CLIENT_ID: AtomicUsize = AtomicUsize::default();
231 }
232
233 Arc::new(Self {
234 id: NEXT_CLIENT_ID.fetch_add(1, Ordering::SeqCst),
235 peer: Peer::new(),
236 http,
237 state: Default::default(),
238 authenticate: None,
239 establish_connection: None,
240 })
241 }
242
    /// Process-unique id of this client (assigned in `new`).
    pub fn id(&self) -> usize {
        self.id
    }
246
247 pub fn http_client(&self) -> Arc<dyn HttpClient> {
248 self.http.clone()
249 }
250
251 #[cfg(any(test, feature = "test-support"))]
252 pub fn tear_down(&self) {
253 let mut state = self.state.write();
254 state._reconnect_task.take();
255 state.message_handlers.clear();
256 state.models_by_message_type.clear();
257 state.entities_by_type_and_remote_id.clear();
258 state.entity_id_extractors.clear();
259 self.peer.reset();
260 }
261
    /// Test-only: replaces browser sign-in with `authenticate`, which is called whenever
    /// credentials are needed.
    #[cfg(any(test, feature = "test-support"))]
    pub fn override_authenticate<F>(&mut self, authenticate: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>,
    {
        self.authenticate = Some(Box::new(authenticate));
        self
    }
270
    /// Test-only: replaces the websocket connection with `connect`, which is called with
    /// the credentials on every connection attempt.
    #[cfg(any(test, feature = "test-support"))]
    pub fn override_establish_connection<F>(&mut self, connect: F) -> &mut Self
    where
        F: 'static
            + Send
            + Sync
            + Fn(&Credentials, &AsyncAppContext) -> Task<Result<Connection, EstablishConnectionError>>,
    {
        self.establish_connection = Some(Box::new(connect));
        self
    }
282
283 pub fn user_id(&self) -> Option<u64> {
284 self.state
285 .read()
286 .credentials
287 .as_ref()
288 .map(|credentials| credentials.user_id)
289 }
290
291 pub fn status(&self) -> watch::Receiver<Status> {
292 self.state.read().status.1.clone()
293 }
294
    /// Publishes `status` and updates the reconnect task accordingly.
    ///
    /// - `Connected`: drops any pending reconnect task.
    /// - `ConnectionLost`: spawns a task that calls `authenticate_and_connect` (keychain
    ///   allowed) until it succeeds. After each failure it publishes `ReconnectionError`
    ///   with the time of the next attempt, waits, then multiplies the delay by a random
    ///   factor in `[1, 2]`, capped at `reconnect_interval`. The delay starts at 100ms.
    /// - `SignedOut` / `UpgradeRequired`: drops any pending reconnect task.
    ///
    /// The state write lock is held while the status is published and the task is spawned.
    fn set_status(self: &Arc<Self>, status: Status, cx: &AsyncAppContext) {
        log::info!("set status on client {}: {:?}", self.id, status);
        let mut state = self.state.write();
        *state.status.0.borrow_mut() = status;

        match status {
            Status::Connected { .. } => {
                state._reconnect_task = None;
            }
            Status::ConnectionLost => {
                let this = self.clone();
                let reconnect_interval = state.reconnect_interval;
                state._reconnect_task = Some(cx.spawn(|cx| async move {
                    let mut rng = StdRng::from_entropy();
                    let mut delay = Duration::from_millis(100);
                    while let Err(error) = this.authenticate_and_connect(true, &cx).await {
                        log::error!("failed to connect {}", error);
                        this.set_status(
                            Status::ReconnectionError {
                                next_reconnection: Instant::now() + delay,
                            },
                            &cx,
                        );
                        cx.background().timer(delay).await;
                        // Randomized exponential back-off, capped at `reconnect_interval`.
                        delay = delay
                            .mul_f32(rng.gen_range(1.0..=2.0))
                            .min(reconnect_interval);
                    }
                }));
            }
            Status::SignedOut | Status::UpgradeRequired => {
                state._reconnect_task.take();
            }
            _ => {}
        }
    }
331
332 pub fn add_view_for_remote_entity<T: View>(
333 self: &Arc<Self>,
334 remote_id: u64,
335 cx: &mut ViewContext<T>,
336 ) -> Subscription {
337 let id = (TypeId::of::<T>(), remote_id);
338 self.state
339 .write()
340 .entities_by_type_and_remote_id
341 .insert(id, AnyWeakEntityHandle::View(cx.weak_handle().into()));
342 Subscription::Entity {
343 client: Arc::downgrade(self),
344 id,
345 }
346 }
347
348 pub fn add_model_for_remote_entity<T: Entity>(
349 self: &Arc<Self>,
350 remote_id: u64,
351 cx: &mut ModelContext<T>,
352 ) -> Subscription {
353 let id = (TypeId::of::<T>(), remote_id);
354 self.state
355 .write()
356 .entities_by_type_and_remote_id
357 .insert(id, AnyWeakEntityHandle::Model(cx.weak_handle().into()));
358 Subscription::Entity {
359 client: Arc::downgrade(self),
360 id,
361 }
362 }
363
    /// Registers `handler` for every incoming message of type `M`, delivering it together
    /// with `model` (held weakly). The registration is removed when the returned
    /// [`Subscription`] is dropped.
    ///
    /// # Panics
    /// Panics if a handler for `M` is already registered.
    pub fn add_message_handler<M, E, H, F>(
        self: &Arc<Self>,
        model: ModelHandle<E>,
        handler: H,
    ) -> Subscription
    where
        M: EnvelopedMessage,
        E: Entity,
        H: 'static
            + Send
            + Sync
            + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
        F: 'static + Future<Output = Result<()>>,
    {
        let message_type_id = TypeId::of::<M>();

        let mut state = self.state.write();
        state
            .models_by_message_type
            .insert(message_type_id, model.downgrade().into());

        let prev_handler = state.message_handlers.insert(
            message_type_id,
            // Type-erased wrapper that recovers the concrete model and envelope types. It
            // expects the `Model` handle looked up through `models_by_message_type`.
            Arc::new(move |handle, envelope, client, cx| {
                let handle = if let AnyEntityHandle::Model(handle) = handle {
                    handle
                } else {
                    unreachable!();
                };
                let model = handle.downcast::<E>().unwrap();
                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
                handler(model, *envelope, client.clone(), cx).boxed_local()
            }),
        );
        if prev_handler.is_some() {
            panic!("registered handler for the same message twice");
        }

        Subscription::Message {
            client: Arc::downgrade(self),
            id: message_type_id,
        }
    }
407
    /// Registers `handler` for entity messages of type `M` targeting views of type `E`.
    /// The target view is resolved per message from its remote entity id (see
    /// `add_view_for_remote_entity`). The handler stays registered for the client's lifetime.
    ///
    /// # Panics
    /// Panics if a handler for `M` is already registered.
    pub fn add_view_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
    where
        M: EntityMessage,
        E: View,
        H: 'static
            + Send
            + Sync
            + Fn(ViewHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
        F: 'static + Future<Output = Result<()>>,
    {
        self.add_entity_message_handler::<M, E, _, _>(move |handle, message, client, cx| {
            if let AnyEntityHandle::View(handle) = handle {
                handler(handle.downcast::<E>().unwrap(), message, client, cx)
            } else {
                unreachable!();
            }
        })
    }
426
    /// Registers `handler` for entity messages of type `M` targeting models of type `E`.
    /// The target model is resolved per message from its remote entity id (see
    /// `add_model_for_remote_entity`). The handler stays registered for the client's lifetime.
    ///
    /// # Panics
    /// Panics if a handler for `M` is already registered.
    pub fn add_model_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
    where
        M: EntityMessage,
        E: Entity,
        H: 'static
            + Send
            + Sync
            + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
        F: 'static + Future<Output = Result<()>>,
    {
        self.add_entity_message_handler::<M, E, _, _>(move |handle, message, client, cx| {
            if let AnyEntityHandle::Model(handle) = handle {
                handler(handle.downcast::<E>().unwrap(), message, client, cx)
            } else {
                unreachable!();
            }
        })
    }
445
    /// Shared implementation of the model/view entity handlers. It records that messages
    /// of type `M` target entities of type `E`, installs an extractor for the message's
    /// remote entity id, and stores a type-erased wrapper around `handler`.
    ///
    /// # Panics
    /// Panics if a handler for `M` is already registered.
    fn add_entity_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
    where
        M: EntityMessage,
        E: Entity,
        H: 'static
            + Send
            + Sync
            + Fn(AnyEntityHandle, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
        F: 'static + Future<Output = Result<()>>,
    {
        let model_type_id = TypeId::of::<E>();
        let message_type_id = TypeId::of::<M>();

        let mut state = self.state.write();
        state
            .entity_types_by_message_type
            .insert(message_type_id, model_type_id);
        state
            .entity_id_extractors
            .entry(message_type_id)
            .or_insert_with(|| {
                // Non-capturing closure, so it coerces to the stored `fn` pointer type.
                |envelope| {
                    envelope
                        .as_any()
                        .downcast_ref::<TypedEnvelope<M>>()
                        .unwrap()
                        .payload
                        .remote_entity_id()
                }
            });
        let prev_handler = state.message_handlers.insert(
            message_type_id,
            Arc::new(move |handle, envelope, client, cx| {
                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
                handler(handle, *envelope, client.clone(), cx).boxed_local()
            }),
        );
        if prev_handler.is_some() {
            panic!("registered handler for the same message twice");
        }
    }
487
    /// Like `add_model_message_handler`, but for request messages. The handler's result is
    /// sent back as the response, or as a `proto::Error` if it fails.
    pub fn add_model_request_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
    where
        M: EntityMessage + RequestMessage,
        E: Entity,
        H: 'static
            + Send
            + Sync
            + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
        F: 'static + Future<Output = Result<M::Response>>,
    {
        self.add_model_message_handler(move |entity, envelope, client, cx| {
            Self::respond_to_request::<M, _>(
                // Take the receipt before `envelope` is moved into the handler.
                envelope.receipt(),
                handler(entity, envelope, client.clone(), cx),
                client,
            )
        })
    }
506
    /// Like `add_view_message_handler`, but for request messages. The handler's result is
    /// sent back as the response, or as a `proto::Error` if it fails.
    pub fn add_view_request_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
    where
        M: EntityMessage + RequestMessage,
        E: View,
        H: 'static
            + Send
            + Sync
            + Fn(ViewHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
        F: 'static + Future<Output = Result<M::Response>>,
    {
        self.add_view_message_handler(move |entity, envelope, client, cx| {
            Self::respond_to_request::<M, _>(
                // Take the receipt before `envelope` is moved into the handler.
                envelope.receipt(),
                handler(entity, envelope, client.clone(), cx),
                client,
            )
        })
    }
525
526 async fn respond_to_request<T: RequestMessage, F: Future<Output = Result<T::Response>>>(
527 receipt: Receipt<T>,
528 response: F,
529 client: Arc<Self>,
530 ) -> Result<()> {
531 match response.await {
532 Ok(response) => {
533 client.respond(receipt, response)?;
534 Ok(())
535 }
536 Err(error) => {
537 client.respond_with_error(
538 receipt,
539 proto::Error {
540 message: error.to_string(),
541 },
542 )?;
543 Err(error)
544 }
545 }
546 }
547
    /// Returns `true` if usable credentials can be read from the platform keychain. This is
    /// always `false` while impersonating (`ZED_IMPERSONATE`).
    pub fn has_keychain_credentials(&self, cx: &AsyncAppContext) -> bool {
        read_credentials_from_keychain(cx).is_some()
    }
551
    /// Obtains credentials if needed and connects to the server.
    ///
    /// Returns `Ok(())` immediately if already connected or an attempt is in progress, and
    /// fails immediately in the `UpgradeRequired` state. Credentials are taken from the
    /// in-memory cache, then (when `try_keychain`) from the keychain, then from
    /// `authenticate`. On success they are cached, and written to the keychain unless they
    /// came from it or a login is being impersonated.
    ///
    /// On `Unauthorized` the cached credentials are discarded. If they came from the
    /// keychain, they are deleted there too and the flow is retried once without the keychain.
    #[async_recursion(?Send)]
    pub async fn authenticate_and_connect(
        self: &Arc<Self>,
        try_keychain: bool,
        cx: &AsyncAppContext,
    ) -> anyhow::Result<()> {
        // `true`: first connection; `false`: recovering from an error or lost connection.
        let was_disconnected = match *self.status().borrow() {
            Status::SignedOut => true,
            Status::ConnectionError | Status::ConnectionLost | Status::ReconnectionError { .. } => {
                false
            }
            Status::Connected { .. }
            | Status::Connecting { .. }
            | Status::Reconnecting { .. }
            | Status::Authenticating
            | Status::Reauthenticating => return Ok(()),
            Status::UpgradeRequired => return Err(EstablishConnectionError::UpgradeRequired)?,
        };

        if was_disconnected {
            self.set_status(Status::Authenticating, cx);
        } else {
            self.set_status(Status::Reauthenticating, cx)
        }

        let mut read_from_keychain = false;
        let mut credentials = self.state.read().credentials.clone();
        if credentials.is_none() && try_keychain {
            credentials = read_credentials_from_keychain(cx);
            read_from_keychain = credentials.is_some();
        }
        if credentials.is_none() {
            credentials = Some(match self.authenticate(&cx).await {
                Ok(credentials) => credentials,
                Err(err) => {
                    self.set_status(Status::ConnectionError, cx);
                    return Err(err);
                }
            });
        }
        let credentials = credentials.unwrap();

        if was_disconnected {
            self.set_status(Status::Connecting, cx);
        } else {
            self.set_status(Status::Reconnecting, cx);
        }

        match self.establish_connection(&credentials, cx).await {
            Ok(conn) => {
                self.state.write().credentials = Some(credentials.clone());
                if !read_from_keychain && IMPERSONATE_LOGIN.is_none() {
                    write_credentials_to_keychain(&credentials, cx).log_err();
                }
                self.set_connection(conn, cx).await;
                Ok(())
            }
            Err(EstablishConnectionError::Unauthorized) => {
                self.state.write().credentials.take();
                if read_from_keychain {
                    // Stale keychain entry: delete it and retry without the keychain. The
                    // recursive call cannot recurse again because it never reads the keychain.
                    cx.platform().delete_credentials(&ZED_SERVER_URL).log_err();
                    self.set_status(Status::SignedOut, cx);
                    self.authenticate_and_connect(false, cx).await
                } else {
                    self.set_status(Status::ConnectionError, cx);
                    Err(EstablishConnectionError::Unauthorized)?
                }
            }
            Err(EstablishConnectionError::UpgradeRequired) => {
                self.set_status(Status::UpgradeRequired, cx);
                Err(EstablishConnectionError::UpgradeRequired)?
            }
            Err(error) => {
                self.set_status(Status::ConnectionError, cx);
                Err(error)?
            }
        }
    }
630
631 async fn set_connection(self: &Arc<Self>, conn: Connection, cx: &AsyncAppContext) {
632 let executor = cx.background();
633 log::info!("add connection to peer");
634 let (connection_id, handle_io, mut incoming) = self
635 .peer
636 .add_connection(conn, move |duration| executor.timer(duration))
637 .await;
638 log::info!("set status to connected {}", connection_id);
639 self.set_status(Status::Connected { connection_id }, cx);
640 cx.foreground()
641 .spawn({
642 let cx = cx.clone();
643 let this = self.clone();
644 async move {
645 let mut message_id = 0_usize;
646 while let Some(message) = incoming.next().await {
647 let mut state = this.state.write();
648 message_id += 1;
649 let type_name = message.payload_type_name();
650 let payload_type_id = message.payload_type_id();
651 let sender_id = message.original_sender_id().map(|id| id.0);
652
653 let model = state
654 .models_by_message_type
655 .get(&payload_type_id)
656 .and_then(|model| model.upgrade(&cx))
657 .map(AnyEntityHandle::Model)
658 .or_else(|| {
659 let entity_type_id =
660 *state.entity_types_by_message_type.get(&payload_type_id)?;
661 let entity_id = state
662 .entity_id_extractors
663 .get(&message.payload_type_id())
664 .map(|extract_entity_id| {
665 (extract_entity_id)(message.as_ref())
666 })?;
667
668 let entity = state
669 .entities_by_type_and_remote_id
670 .get(&(entity_type_id, entity_id))?;
671 if let Some(entity) = entity.upgrade(&cx) {
672 Some(entity)
673 } else {
674 state
675 .entities_by_type_and_remote_id
676 .remove(&(entity_type_id, entity_id));
677 None
678 }
679 });
680
681 let model = if let Some(model) = model {
682 model
683 } else {
684 log::info!("unhandled message {}", type_name);
685 continue;
686 };
687
688 if let Some(handler) = state.message_handlers.get(&payload_type_id).cloned()
689 {
690 drop(state); // Avoid deadlocks if the handler interacts with rpc::Client
691 let future = handler(model, message, &this, cx.clone());
692
693 let client_id = this.id;
694 log::debug!(
695 "rpc message received. client_id:{}, message_id:{}, sender_id:{:?}, type:{}",
696 client_id,
697 message_id,
698 sender_id,
699 type_name
700 );
701 cx.foreground()
702 .spawn(async move {
703 match future.await {
704 Ok(()) => {
705 log::debug!(
706 "rpc message handled. client_id:{}, message_id:{}, sender_id:{:?}, type:{}",
707 client_id,
708 message_id,
709 sender_id,
710 type_name
711 );
712 }
713 Err(error) => {
714 log::error!(
715 "error handling message. client_id:{}, message_id:{}, sender_id:{:?}, type:{}, error:{:?}",
716 client_id,
717 message_id,
718 sender_id,
719 type_name,
720 error
721 );
722 }
723 }
724 })
725 .detach();
726 } else {
727 log::info!("unhandled message {}", type_name);
728 }
729
730 // Don't starve the main thread when receiving lots of messages at once.
731 smol::future::yield_now().await;
732 }
733 }
734 })
735 .detach();
736
737 let handle_io = cx.background().spawn(handle_io);
738 let this = self.clone();
739 let cx = cx.clone();
740 cx.foreground()
741 .spawn(async move {
742 match handle_io.await {
743 Ok(()) => {
744 if *this.status().borrow() == (Status::Connected { connection_id }) {
745 this.set_status(Status::SignedOut, &cx);
746 }
747 }
748 Err(err) => {
749 log::error!("connection error: {:?}", err);
750 this.set_status(Status::ConnectionLost, &cx);
751 }
752 }
753 })
754 .detach();
755 }
756
757 fn authenticate(self: &Arc<Self>, cx: &AsyncAppContext) -> Task<Result<Credentials>> {
758 if let Some(callback) = self.authenticate.as_ref() {
759 callback(cx)
760 } else {
761 self.authenticate_with_browser(cx)
762 }
763 }
764
765 fn establish_connection(
766 self: &Arc<Self>,
767 credentials: &Credentials,
768 cx: &AsyncAppContext,
769 ) -> Task<Result<Connection, EstablishConnectionError>> {
770 if let Some(callback) = self.establish_connection.as_ref() {
771 callback(credentials, cx)
772 } else {
773 self.establish_websocket_connection(credentials, cx)
774 }
775 }
776
    /// Opens a websocket connection to the server's RPC endpoint.
    ///
    /// The handshake request carries an `Authorization: {user_id} {access_token}` header and
    /// the protocol version in `x-zed-protocol-version`. The RPC URL is discovered with an
    /// HTTP GET to `{ZED_SERVER_URL}/rpc`:
    ///
    /// - a redirect's `Location` header becomes the URL;
    /// - `426 Upgrade Required` means the original URL is used;
    /// - any other status is an error.
    ///
    /// A TCP stream is then opened to the URL's host and port. The handshake uses TLS for
    /// `https` (rewritten to `wss`) and plain text for `http` (rewritten to `ws`). Other
    /// schemes are rejected.
    fn establish_websocket_connection(
        self: &Arc<Self>,
        credentials: &Credentials,
        cx: &AsyncAppContext,
    ) -> Task<Result<Connection, EstablishConnectionError>> {
        let request = Request::builder()
            .header(
                "Authorization",
                format!("{} {}", credentials.user_id, credentials.access_token),
            )
            .header("x-zed-protocol-version", rpc::PROTOCOL_VERSION);

        let http = self.http.clone();
        cx.background().spawn(async move {
            let mut rpc_url = format!("{}/rpc", *ZED_SERVER_URL);
            let rpc_response = http.get(&rpc_url, Default::default(), false).await?;
            if rpc_response.status().is_redirection() {
                rpc_url = rpc_response
                    .headers()
                    .get("Location")
                    .ok_or_else(|| anyhow!("missing location header in /rpc response"))?
                    .to_str()
                    .map_err(|error| EstablishConnectionError::other(error))?
                    .to_string();
            }
            // Until we switch the zed.dev domain to point to the new Next.js app, there
            // will be no redirect required, and the app will connect directly to
            // wss://zed.dev/rpc.
            else if rpc_response.status() != StatusCode::UPGRADE_REQUIRED {
                Err(anyhow!(
                    "unexpected /rpc response status {}",
                    rpc_response.status()
                ))?
            }

            let mut rpc_url = Url::parse(&rpc_url).context("invalid rpc url")?;
            let rpc_host = rpc_url
                .host_str()
                .zip(rpc_url.port_or_known_default())
                .ok_or_else(|| anyhow!("missing host in rpc url"))?;
            let stream = smol::net::TcpStream::connect(rpc_host).await?;

            log::info!("connected to rpc endpoint {}", rpc_url);

            match rpc_url.scheme() {
                "https" => {
                    // Both schemes are "special" in the URL standard, so this cannot fail.
                    rpc_url.set_scheme("wss").unwrap();
                    let request = request.uri(rpc_url.as_str()).body(())?;
                    let (stream, _) =
                        async_tungstenite::async_tls::client_async_tls(request, stream).await?;
                    Ok(Connection::new(
                        stream
                            .map_err(|error| anyhow!(error))
                            .sink_map_err(|error| anyhow!(error)),
                    ))
                }
                "http" => {
                    // Both schemes are "special" in the URL standard, so this cannot fail.
                    rpc_url.set_scheme("ws").unwrap();
                    let request = request.uri(rpc_url.as_str()).body(())?;
                    let (stream, _) = async_tungstenite::client_async(request, stream).await?;
                    Ok(Connection::new(
                        stream
                            .map_err(|error| anyhow!(error))
                            .sink_map_err(|error| anyhow!(error)),
                    ))
                }
                _ => Err(anyhow!("invalid rpc url: {}", rpc_url))?,
            }
        })
    }
847
848 pub fn authenticate_with_browser(
849 self: &Arc<Self>,
850 cx: &AsyncAppContext,
851 ) -> Task<Result<Credentials>> {
852 let platform = cx.platform();
853 let executor = cx.background();
854 executor.clone().spawn(async move {
855 // Generate a pair of asymmetric encryption keys. The public key will be used by the
856 // zed server to encrypt the user's access token, so that it can'be intercepted by
857 // any other app running on the user's device.
858 let (public_key, private_key) =
859 rpc::auth::keypair().expect("failed to generate keypair for auth");
860 let public_key_string =
861 String::try_from(public_key).expect("failed to serialize public key for auth");
862
863 // Start an HTTP server to receive the redirect from Zed's sign-in page.
864 let server = tiny_http::Server::http("127.0.0.1:0").expect("failed to find open port");
865 let port = server.server_addr().port();
866
867 // Open the Zed sign-in page in the user's browser, with query parameters that indicate
868 // that the user is signing in from a Zed app running on the same device.
869 let mut url = format!(
870 "{}/native_app_signin?native_app_port={}&native_app_public_key={}",
871 *ZED_SERVER_URL, port, public_key_string
872 );
873
874 if let Some(impersonate_login) = IMPERSONATE_LOGIN.as_ref() {
875 log::info!("impersonating user @{}", impersonate_login);
876 write!(&mut url, "&impersonate={}", impersonate_login).unwrap();
877 }
878
879 platform.open_url(&url);
880
881 // Receive the HTTP request from the user's browser. Retrieve the user id and encrypted
882 // access token from the query params.
883 //
884 // TODO - Avoid ever starting more than one HTTP server. Maybe switch to using a
885 // custom URL scheme instead of this local HTTP server.
886 let (user_id, access_token) = executor
887 .spawn(async move {
888 if let Some(req) = server.recv_timeout(Duration::from_secs(10 * 60))? {
889 let path = req.url();
890 let mut user_id = None;
891 let mut access_token = None;
892 let url = Url::parse(&format!("http://example.com{}", path))
893 .context("failed to parse login notification url")?;
894 for (key, value) in url.query_pairs() {
895 if key == "access_token" {
896 access_token = Some(value.to_string());
897 } else if key == "user_id" {
898 user_id = Some(value.to_string());
899 }
900 }
901
902 let post_auth_url =
903 format!("{}/native_app_signin_succeeded", *ZED_SERVER_URL);
904 req.respond(
905 tiny_http::Response::empty(302).with_header(
906 tiny_http::Header::from_bytes(
907 &b"Location"[..],
908 post_auth_url.as_bytes(),
909 )
910 .unwrap(),
911 ),
912 )
913 .context("failed to respond to login http request")?;
914 Ok((
915 user_id.ok_or_else(|| anyhow!("missing user_id parameter"))?,
916 access_token
917 .ok_or_else(|| anyhow!("missing access_token parameter"))?,
918 ))
919 } else {
920 Err(anyhow!("didn't receive login redirect"))
921 }
922 })
923 .await?;
924
925 let access_token = private_key
926 .decrypt_string(&access_token)
927 .context("failed to decrypt access token")?;
928 platform.activate(true);
929
930 Ok(Credentials {
931 user_id: user_id.parse()?,
932 access_token,
933 })
934 })
935 }
936
    /// Disconnects the current connection from the peer and publishes `SignedOut`.
    ///
    /// # Errors
    /// Returns an error (and changes nothing) if the client is not connected.
    pub fn disconnect(self: &Arc<Self>, cx: &AsyncAppContext) -> Result<()> {
        let conn_id = self.connection_id()?;
        self.peer.disconnect(conn_id);
        self.set_status(Status::SignedOut, cx);
        Ok(())
    }
943
944 fn connection_id(&self) -> Result<ConnectionId> {
945 if let Status::Connected { connection_id, .. } = *self.status().borrow() {
946 Ok(connection_id)
947 } else {
948 Err(anyhow!("not connected"))
949 }
950 }
951
    /// Sends a one-way `message` over the current connection.
    ///
    /// # Errors
    /// Fails if the client is not connected or the peer rejects the send.
    pub fn send<T: EnvelopedMessage>(&self, message: T) -> Result<()> {
        log::debug!("rpc send. client_id:{}, name:{}", self.id, T::NAME);
        self.peer.send(self.connection_id()?, message)
    }
956
    /// Sends `request` over the current connection and resolves to the server's response.
    ///
    /// The connection is resolved eagerly, when this method is called. If the client is not
    /// connected, the returned future resolves to an error.
    pub fn request<T: RequestMessage>(
        &self,
        request: T,
    ) -> impl Future<Output = Result<T::Response>> {
        let client_id = self.id;
        log::debug!(
            "rpc request start. client_id:{}. name:{}",
            client_id,
            T::NAME
        );
        let response = self
            .connection_id()
            .map(|conn_id| self.peer.request(conn_id, request));
        async move {
            let response = response?.await;
            log::debug!(
                "rpc request finish. client_id:{}. name:{}",
                client_id,
                T::NAME
            );
            response
        }
    }
980
    /// Sends `response` as the answer to the request identified by `receipt`.
    fn respond<T: RequestMessage>(&self, receipt: Receipt<T>, response: T::Response) -> Result<()> {
        log::debug!("rpc respond. client_id:{}. name:{}", self.id, T::NAME);
        self.peer.respond(receipt, response)
    }
985
    /// Answers the request identified by `receipt` with `error` instead of a response.
    fn respond_with_error<T: RequestMessage>(
        &self,
        receipt: Receipt<T>,
        error: proto::Error,
    ) -> Result<()> {
        log::debug!("rpc respond. client_id:{}. name:{}", self.id, T::NAME);
        self.peer.respond_with_error(receipt, error)
    }
994}
995
impl AnyWeakEntityHandle {
    /// Upgrades to a strong handle of the same kind; returns `None` if the model or view
    /// is no longer alive.
    fn upgrade(&self, cx: &AsyncAppContext) -> Option<AnyEntityHandle> {
        match self {
            AnyWeakEntityHandle::Model(handle) => handle.upgrade(cx).map(AnyEntityHandle::Model),
            AnyWeakEntityHandle::View(handle) => handle.upgrade(cx).map(AnyEntityHandle::View),
        }
    }
}
1004
1005fn read_credentials_from_keychain(cx: &AsyncAppContext) -> Option<Credentials> {
1006 if IMPERSONATE_LOGIN.is_some() {
1007 return None;
1008 }
1009
1010 let (user_id, access_token) = cx
1011 .platform()
1012 .read_credentials(&ZED_SERVER_URL)
1013 .log_err()
1014 .flatten()?;
1015 Some(Credentials {
1016 user_id: user_id.parse().ok()?,
1017 access_token: String::from_utf8(access_token).ok()?,
1018 })
1019}
1020
/// Stores `credentials` in the platform keychain under `ZED_SERVER_URL`. The user id is
/// written as the account name and the access token as the secret bytes.
fn write_credentials_to_keychain(credentials: &Credentials, cx: &AsyncAppContext) -> Result<()> {
    cx.platform().write_credentials(
        &ZED_SERVER_URL,
        &credentials.user_id.to_string(),
        credentials.access_token.as_bytes(),
    )
}
1028
const WORKTREE_URL_PREFIX: &str = "zed://worktrees/";

/// Builds a worktree URL of the form `zed://worktrees/{id}/{access_token}`.
pub fn encode_worktree_url(id: u64, access_token: &str) -> String {
    let mut url = String::from(WORKTREE_URL_PREFIX);
    url.push_str(&id.to_string());
    url.push('/');
    url.push_str(access_token);
    url
}

/// Parses a URL produced by [`encode_worktree_url`], ignoring surrounding whitespace.
///
/// Returns `None` unless the URL has the worktree prefix, a numeric id, and a non-empty
/// access token. Anything after a further `/` following the token is ignored.
pub fn decode_worktree_url(url: &str) -> Option<(u64, String)> {
    let rest = url.trim().strip_prefix(WORKTREE_URL_PREFIX)?;
    let (raw_id, tail) = rest.split_once('/')?;
    let id: u64 = raw_id.parse().ok()?;
    let access_token = tail.split('/').next()?;
    if access_token.is_empty() {
        None
    } else {
        Some((id, access_token.to_string()))
    }
}
1045
#[cfg(test)]
mod tests {
    // Tests for reconnection, URL encoding, and message/entity subscription routing, run
    // against `FakeServer` / `FakeHttpClient` from the crate's `test` module.
    use super::*;
    use crate::test::{FakeHttpClient, FakeServer};
    use gpui::TestAppContext;

    // Reconnecting after a dropped connection reuses cached credentials while they remain
    // valid, and re-authenticates once the server rolls its access token.
    #[gpui::test(iterations = 10)]
    async fn test_reconnection(cx: &mut TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let mut client = Client::new(FakeHttpClient::with_404_response());
        let server = FakeServer::for_client(user_id, &mut client, &cx).await;
        let mut status = client.status();
        assert!(matches!(
            status.next().await,
            Some(Status::Connected { .. })
        ));
        assert_eq!(server.auth_count(), 1);

        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        server.allow_connections();
        cx.foreground().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 1); // Client reused the cached credentials when reconnecting

        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Clear cached credentials after authentication fails
        server.roll_access_token();
        server.allow_connections();
        cx.foreground().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 2); // Client re-authenticated due to an invalid token
    }

    // Encoding then decoding a worktree URL round-trips, tolerates surrounding whitespace,
    // and rejects foreign URLs.
    #[test]
    fn test_encode_and_decode_worktree_url() {
        let url = encode_worktree_url(5, "deadbeef");
        assert_eq!(decode_worktree_url(&url), Some((5, "deadbeef".to_string())));
        assert_eq!(
            decode_worktree_url(&format!("\n {}\t", url)),
            Some((5, "deadbeef".to_string()))
        );
        assert_eq!(decode_worktree_url("not://the-right-format"), None);
    }

    // Entity messages reach the model registered for their remote id, even after another
    // entity of the same type has dropped its subscription.
    #[gpui::test]
    async fn test_subscribing_to_entity(cx: &mut TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let mut client = Client::new(FakeHttpClient::with_404_response());
        let server = FakeServer::for_client(user_id, &mut client, &cx).await;

        let (done_tx1, mut done_rx1) = smol::channel::unbounded();
        let (done_tx2, mut done_rx2) = smol::channel::unbounded();
        client.add_model_message_handler(
            move |model: ModelHandle<Model>, _: TypedEnvelope<proto::JoinProject>, _, cx| {
                match model.read_with(&cx, |model, _| model.id) {
                    1 => done_tx1.try_send(()).unwrap(),
                    2 => done_tx2.try_send(()).unwrap(),
                    _ => unreachable!(),
                }
                async { Ok(()) }
            },
        );
        let model1 = cx.add_model(|_| Model {
            id: 1,
            subscription: None,
        });
        let model2 = cx.add_model(|_| Model {
            id: 2,
            subscription: None,
        });
        let model3 = cx.add_model(|_| Model {
            id: 3,
            subscription: None,
        });

        let _subscription1 = model1.update(cx, |_, cx| client.add_model_for_remote_entity(1, cx));
        let _subscription2 = model2.update(cx, |_, cx| client.add_model_for_remote_entity(2, cx));
        // Ensure dropping a subscription for the same entity type still allows receiving of
        // messages for other entity IDs of the same type.
        let subscription3 = model3.update(cx, |_, cx| client.add_model_for_remote_entity(3, cx));
        drop(subscription3);

        server.send(proto::JoinProject { project_id: 1 });
        server.send(proto::JoinProject { project_id: 2 });
        done_rx1.next().await.unwrap();
        done_rx2.next().await.unwrap();
    }

    // Dropping a message subscription frees the message type so a new handler can be
    // registered for it without panicking; only the new handler runs.
    #[gpui::test]
    async fn test_subscribing_after_dropping_subscription(cx: &mut TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let mut client = Client::new(FakeHttpClient::with_404_response());
        let server = FakeServer::for_client(user_id, &mut client, &cx).await;

        let model = cx.add_model(|_| Model::default());
        let (done_tx1, _done_rx1) = smol::channel::unbounded();
        let (done_tx2, mut done_rx2) = smol::channel::unbounded();
        let subscription1 = client.add_message_handler(
            model.clone(),
            move |_, _: TypedEnvelope<proto::Ping>, _, _| {
                done_tx1.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        drop(subscription1);
        let _subscription2 =
            client.add_message_handler(model, move |_, _: TypedEnvelope<proto::Ping>, _, _| {
                done_tx2.try_send(()).unwrap();
                async { Ok(()) }
            });
        server.send(proto::Ping {});
        done_rx2.next().await.unwrap();
    }

    // A handler may drop its own subscription while running without deadlocking, because
    // the client releases its state lock before invoking handlers.
    #[gpui::test]
    async fn test_dropping_subscription_in_handler(cx: &mut TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let mut client = Client::new(FakeHttpClient::with_404_response());
        let server = FakeServer::for_client(user_id, &mut client, &cx).await;

        let model = cx.add_model(|_| Model::default());
        let (done_tx, mut done_rx) = smol::channel::unbounded();
        let subscription = client.add_message_handler(
            model.clone(),
            move |model, _: TypedEnvelope<proto::Ping>, _, mut cx| {
                model.update(&mut cx, |model, _| model.subscription.take());
                done_tx.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        model.update(cx, |model, _| {
            model.subscription = Some(subscription);
        });
        server.send(proto::Ping {});
        done_rx.next().await.unwrap();
    }

    // Minimal model used by the tests: an id to tell instances apart, plus an optional
    // subscription the model can own (and drop).
    #[derive(Default)]
    struct Model {
        id: usize,
        subscription: Option<Subscription>,
    }

    impl Entity for Model {
        type Event = ();
    }
}