1use crate::util::ResultExt;
2use anyhow::{anyhow, Context, Result};
3use async_tungstenite::tungstenite::http::Request;
4use gpui::{AsyncAppContext, Entity, ModelContext, Task};
5use lazy_static::lazy_static;
6use parking_lot::RwLock;
7use postage::{prelude::Stream, watch};
8use std::{
9 any::TypeId,
10 collections::HashMap,
11 convert::TryFrom,
12 future::Future,
13 sync::{Arc, Weak},
14 time::{Duration, Instant},
15};
16use surf::Url;
17pub use zrpc::{proto, ConnectionId, PeerId, TypedEnvelope};
18use zrpc::{
19 proto::{AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage},
20 Conn, Peer, Receipt,
21};
22
23lazy_static! {
24 static ref ZED_SERVER_URL: String =
25 std::env::var("ZED_SERVER_URL").unwrap_or("https://zed.dev:443".to_string());
26}
27
/// RPC client: owns the peer used to exchange messages and the mutable
/// connection state (status, message handlers, maintenance task).
pub struct Client {
    /// Peer through which messages are sent, requested and responded to.
    peer: Arc<Peer>,
    /// Mutable client state, guarded by a read/write lock.
    state: RwLock<ClientState>,
    /// When set, used by `authenticate` instead of the browser-based sign-in.
    /// Yields the user id and access token.
    auth_callback: Option<
        Box<dyn 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<(u64, String)>>>,
    >,
    /// When set, used by `connect` instead of opening a websocket. Receives the
    /// user id and access token and yields the established connection.
    connect_callback: Option<
        Box<dyn 'static + Send + Sync + Fn(u64, &str, &AsyncAppContext) -> Task<Result<Conn>>>,
    >,
}
38
/// Connection status published through the client's watch channel.
#[derive(Copy, Clone, Debug)]
pub enum Status {
    /// No connection; this is the initial status of a new `ClientState`.
    Disconnected,
    /// Authentication is in progress, starting from `Disconnected`.
    Authenticating,
    /// Authentication succeeded and a connection is being opened, starting
    /// from `Disconnected`.
    Connecting {
        user_id: u64,
    },
    /// Authentication or connection establishment failed.
    ConnectionError,
    /// A connection has been registered with the peer.
    Connected {
        connection_id: ConnectionId,
        user_id: u64,
    },
    /// The connection's I/O task finished with an error.
    ConnectionLost,
    /// Authentication is in progress after a previous error or lost connection.
    Reauthenticating,
    /// A connection is being opened after a previous error or lost connection.
    Reconnecting {
        user_id: u64,
    },
    /// An automatic reconnection attempt failed; another attempt is scheduled.
    ReconnectionError {
        /// Instant at which the next reconnection attempt is due.
        next_reconnection: Instant,
    },
}
60
/// Mutable state of a `Client`, kept behind `Client::state`.
struct ClientState {
    /// Sender/receiver pair of the watch channel carrying the current status.
    status: (watch::Sender<Status>, watch::Receiver<Status>),
    /// Per message type: extracts the remote entity id from an incoming envelope.
    entity_id_extractors: HashMap<TypeId, Box<dyn Send + Sync + Fn(&dyn AnyTypedEnvelope) -> u64>>,
    /// Message handlers keyed by (message type, remote entity id).
    model_handlers: HashMap<
        (TypeId, u64),
        Box<dyn Send + Sync + FnMut(Box<dyn AnyTypedEnvelope>, &mut AsyncAppContext)>,
    >,
    /// Task installed by `set_status`: the heartbeat loop while connected, or
    /// the reconnection loop after the connection is lost.
    _maintain_connection: Option<Task<()>>,
    /// Time between two heartbeat pings.
    heartbeat_interval: Duration,
}
71
72impl Default for ClientState {
73 fn default() -> Self {
74 Self {
75 status: watch::channel_with(Status::Disconnected),
76 entity_id_extractors: Default::default(),
77 model_handlers: Default::default(),
78 _maintain_connection: None,
79 heartbeat_interval: Duration::from_secs(5),
80 }
81 }
82}
83
/// Handle for a message handler registered with a `Client`; holds a weak
/// reference to the client and the key under which the handler is stored.
pub struct Subscription {
    /// Client that owns the handler; weak so the handle does not keep it alive.
    client: Weak<Client>,
    /// Key of the handler: (message type, remote entity id).
    id: (TypeId, u64),
}
88
89impl Drop for Subscription {
90 fn drop(&mut self) {
91 if let Some(client) = self.client.upgrade() {
92 drop(
93 client
94 .state
95 .write()
96 .model_handlers
97 .remove(&self.id)
98 .unwrap(),
99 );
100 }
101 }
102}
103
104impl Client {
105 pub fn new() -> Arc<Self> {
106 Arc::new(Self {
107 peer: Peer::new(),
108 state: Default::default(),
109 auth_callback: None,
110 connect_callback: None,
111 })
112 }
113
114 #[cfg(any(test, feature = "test-support"))]
115 pub fn set_login_and_connect_callbacks<Login, Connect>(
116 &mut self,
117 login: Login,
118 connect: Connect,
119 ) where
120 Login: 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<(u64, String)>>,
121 Connect: 'static + Send + Sync + Fn(u64, &str, &AsyncAppContext) -> Task<Result<Conn>>,
122 {
123 self.auth_callback = Some(Box::new(login));
124 self.connect_callback = Some(Box::new(connect));
125 }
126
127 pub fn status(&self) -> watch::Receiver<Status> {
128 self.state.read().status.1.clone()
129 }
130
131 fn set_status(self: &Arc<Self>, status: Status, cx: &AsyncAppContext) {
132 let mut state = self.state.write();
133 *state.status.0.borrow_mut() = status;
134
135 match status {
136 Status::Connected { .. } => {
137 let heartbeat_interval = state.heartbeat_interval;
138 let this = self.clone();
139 let foreground = cx.foreground();
140 state._maintain_connection = Some(cx.foreground().spawn(async move {
141 let mut next_ping_id = 0;
142 loop {
143 foreground.timer(heartbeat_interval).await;
144 this.request(proto::Ping { id: next_ping_id })
145 .await
146 .unwrap();
147 next_ping_id += 1;
148 }
149 }));
150 }
151 Status::ConnectionLost => {
152 let this = self.clone();
153 let foreground = cx.foreground();
154 state._maintain_connection = Some(cx.spawn(|cx| async move {
155 let mut delay_seconds = 5;
156 while let Err(error) = this.authenticate_and_connect(&cx).await {
157 log::error!("failed to connect {}", error);
158 let delay = Duration::from_secs(delay_seconds);
159 this.set_status(
160 Status::ReconnectionError {
161 next_reconnection: Instant::now() + delay,
162 },
163 &cx,
164 );
165 foreground.timer(delay).await;
166 delay_seconds = (delay_seconds * 2).min(300);
167 }
168 }));
169 }
170 Status::Disconnected => {
171 state._maintain_connection.take();
172 }
173 _ => {}
174 }
175 }
176
177 pub fn subscribe_from_model<T, M, F>(
178 self: &Arc<Self>,
179 remote_id: u64,
180 cx: &mut ModelContext<M>,
181 mut handler: F,
182 ) -> Subscription
183 where
184 T: EntityMessage,
185 M: Entity,
186 F: 'static
187 + Send
188 + Sync
189 + FnMut(&mut M, TypedEnvelope<T>, Arc<Self>, &mut ModelContext<M>) -> Result<()>,
190 {
191 let subscription_id = (TypeId::of::<T>(), remote_id);
192 let client = self.clone();
193 let mut state = self.state.write();
194 let model = cx.handle().downgrade();
195 state
196 .entity_id_extractors
197 .entry(subscription_id.0)
198 .or_insert_with(|| {
199 Box::new(|envelope| {
200 let envelope = envelope
201 .as_any()
202 .downcast_ref::<TypedEnvelope<T>>()
203 .unwrap();
204 envelope.payload.remote_entity_id()
205 })
206 });
207 let prev_handler = state.model_handlers.insert(
208 subscription_id,
209 Box::new(move |envelope, cx| {
210 if let Some(model) = model.upgrade(cx) {
211 let envelope = envelope.into_any().downcast::<TypedEnvelope<T>>().unwrap();
212 model.update(cx, |model, cx| {
213 if let Err(error) = handler(model, *envelope, client.clone(), cx) {
214 log::error!("error handling message: {}", error)
215 }
216 });
217 }
218 }),
219 );
220 if prev_handler.is_some() {
221 panic!("registered a handler for the same entity twice")
222 }
223
224 Subscription {
225 client: Arc::downgrade(self),
226 id: subscription_id,
227 }
228 }
229
230 pub async fn authenticate_and_connect(
231 self: &Arc<Self>,
232 cx: &AsyncAppContext,
233 ) -> anyhow::Result<()> {
234 let was_disconnected = match *self.status().borrow() {
235 Status::Disconnected => true,
236 Status::ConnectionError | Status::ConnectionLost | Status::ReconnectionError { .. } => {
237 false
238 }
239 Status::Connected { .. }
240 | Status::Connecting { .. }
241 | Status::Reconnecting { .. }
242 | Status::Authenticating
243 | Status::Reauthenticating => return Ok(()),
244 };
245
246 if was_disconnected {
247 self.set_status(Status::Authenticating, cx);
248 } else {
249 self.set_status(Status::Reauthenticating, cx)
250 }
251
252 let (user_id, access_token) = match self.authenticate(&cx).await {
253 Ok(result) => result,
254 Err(err) => {
255 self.set_status(Status::ConnectionError, cx);
256 return Err(err);
257 }
258 };
259
260 if was_disconnected {
261 self.set_status(Status::Connecting { user_id }, cx);
262 } else {
263 self.set_status(Status::Reconnecting { user_id }, cx);
264 }
265 match self.connect(user_id, &access_token, cx).await {
266 Ok(conn) => {
267 log::info!("connected to rpc address {}", *ZED_SERVER_URL);
268 self.set_connection(user_id, conn, cx).await;
269 Ok(())
270 }
271 Err(err) => {
272 self.set_status(Status::ConnectionError, cx);
273 Err(err)
274 }
275 }
276 }
277
278 async fn set_connection(self: &Arc<Self>, user_id: u64, conn: Conn, cx: &AsyncAppContext) {
279 let (connection_id, handle_io, mut incoming) = self.peer.add_connection(conn).await;
280 cx.foreground()
281 .spawn({
282 let mut cx = cx.clone();
283 let this = self.clone();
284 async move {
285 while let Some(message) = incoming.recv().await {
286 let mut state = this.state.write();
287 if let Some(extract_entity_id) =
288 state.entity_id_extractors.get(&message.payload_type_id())
289 {
290 let entity_id = (extract_entity_id)(message.as_ref());
291 if let Some(handler) = state
292 .model_handlers
293 .get_mut(&(message.payload_type_id(), entity_id))
294 {
295 let start_time = Instant::now();
296 log::info!("RPC client message {}", message.payload_type_name());
297 (handler)(message, &mut cx);
298 log::info!(
299 "RPC message handled. duration:{:?}",
300 start_time.elapsed()
301 );
302 } else {
303 log::info!("unhandled message {}", message.payload_type_name());
304 }
305 } else {
306 log::info!("unhandled message {}", message.payload_type_name());
307 }
308 }
309 }
310 })
311 .detach();
312
313 self.set_status(
314 Status::Connected {
315 connection_id,
316 user_id,
317 },
318 cx,
319 );
320
321 let handle_io = cx.background().spawn(handle_io);
322 let this = self.clone();
323 let cx = cx.clone();
324 cx.foreground()
325 .spawn(async move {
326 match handle_io.await {
327 Ok(()) => this.set_status(Status::Disconnected, &cx),
328 Err(err) => {
329 log::error!("connection error: {:?}", err);
330 this.set_status(Status::ConnectionLost, &cx);
331 }
332 }
333 })
334 .detach();
335 }
336
337 fn authenticate(self: &Arc<Self>, cx: &AsyncAppContext) -> Task<Result<(u64, String)>> {
338 if let Some(callback) = self.auth_callback.as_ref() {
339 callback(cx)
340 } else {
341 self.authenticate_with_browser(cx)
342 }
343 }
344
345 fn connect(
346 self: &Arc<Self>,
347 user_id: u64,
348 access_token: &str,
349 cx: &AsyncAppContext,
350 ) -> Task<Result<Conn>> {
351 if let Some(callback) = self.connect_callback.as_ref() {
352 callback(user_id, access_token, cx)
353 } else {
354 self.connect_with_websocket(user_id, access_token, cx)
355 }
356 }
357
358 fn connect_with_websocket(
359 self: &Arc<Self>,
360 user_id: u64,
361 access_token: &str,
362 cx: &AsyncAppContext,
363 ) -> Task<Result<Conn>> {
364 let request =
365 Request::builder().header("Authorization", format!("{} {}", user_id, access_token));
366 cx.background().spawn(async move {
367 if let Some(host) = ZED_SERVER_URL.strip_prefix("https://") {
368 let stream = smol::net::TcpStream::connect(host).await?;
369 let request = request.uri(format!("wss://{}/rpc", host)).body(())?;
370 let (stream, _) = async_tungstenite::async_tls::client_async_tls(request, stream)
371 .await
372 .context("websocket handshake")?;
373 Ok(Conn::new(stream))
374 } else if let Some(host) = ZED_SERVER_URL.strip_prefix("http://") {
375 let stream = smol::net::TcpStream::connect(host).await?;
376 let request = request.uri(format!("ws://{}/rpc", host)).body(())?;
377 let (stream, _) = async_tungstenite::client_async(request, stream)
378 .await
379 .context("websocket handshake")?;
380 Ok(Conn::new(stream))
381 } else {
382 Err(anyhow!("invalid server url: {}", *ZED_SERVER_URL))
383 }
384 })
385 }
386
387 pub fn authenticate_with_browser(
388 self: &Arc<Self>,
389 cx: &AsyncAppContext,
390 ) -> Task<Result<(u64, String)>> {
391 let platform = cx.platform();
392 let executor = cx.background();
393 executor.clone().spawn(async move {
394 if let Some((user_id, access_token)) = platform
395 .read_credentials(&ZED_SERVER_URL)
396 .log_err()
397 .flatten()
398 {
399 log::info!("already signed in. user_id: {}", user_id);
400 return Ok((user_id.parse()?, String::from_utf8(access_token).unwrap()));
401 }
402
403 // Generate a pair of asymmetric encryption keys. The public key will be used by the
404 // zed server to encrypt the user's access token, so that it can'be intercepted by
405 // any other app running on the user's device.
406 let (public_key, private_key) =
407 zrpc::auth::keypair().expect("failed to generate keypair for auth");
408 let public_key_string =
409 String::try_from(public_key).expect("failed to serialize public key for auth");
410
411 // Start an HTTP server to receive the redirect from Zed's sign-in page.
412 let server = tiny_http::Server::http("127.0.0.1:0").expect("failed to find open port");
413 let port = server.server_addr().port();
414
415 // Open the Zed sign-in page in the user's browser, with query parameters that indicate
416 // that the user is signing in from a Zed app running on the same device.
417 platform.open_url(&format!(
418 "{}/sign_in?native_app_port={}&native_app_public_key={}",
419 *ZED_SERVER_URL, port, public_key_string
420 ));
421
422 // Receive the HTTP request from the user's browser. Retrieve the user id and encrypted
423 // access token from the query params.
424 //
425 // TODO - Avoid ever starting more than one HTTP server. Maybe switch to using a
426 // custom URL scheme instead of this local HTTP server.
427 let (user_id, access_token) = executor
428 .spawn(async move {
429 if let Some(req) = server.recv_timeout(Duration::from_secs(10 * 60))? {
430 let path = req.url();
431 let mut user_id = None;
432 let mut access_token = None;
433 let url = Url::parse(&format!("http://example.com{}", path))
434 .context("failed to parse login notification url")?;
435 for (key, value) in url.query_pairs() {
436 if key == "access_token" {
437 access_token = Some(value.to_string());
438 } else if key == "user_id" {
439 user_id = Some(value.to_string());
440 }
441 }
442 req.respond(
443 tiny_http::Response::from_string(LOGIN_RESPONSE).with_header(
444 tiny_http::Header::from_bytes("Content-Type", "text/html").unwrap(),
445 ),
446 )
447 .context("failed to respond to login http request")?;
448 Ok((
449 user_id.ok_or_else(|| anyhow!("missing user_id parameter"))?,
450 access_token
451 .ok_or_else(|| anyhow!("missing access_token parameter"))?,
452 ))
453 } else {
454 Err(anyhow!("didn't receive login redirect"))
455 }
456 })
457 .await?;
458
459 let access_token = private_key
460 .decrypt_string(&access_token)
461 .context("failed to decrypt access token")?;
462 platform.activate(true);
463 platform
464 .write_credentials(&ZED_SERVER_URL, &user_id, access_token.as_bytes())
465 .log_err();
466 Ok((user_id.parse()?, access_token))
467 })
468 }
469
470 pub async fn disconnect(self: &Arc<Self>, cx: &AsyncAppContext) -> Result<()> {
471 let conn_id = self.connection_id()?;
472 self.peer.disconnect(conn_id).await;
473 self.set_status(Status::Disconnected, cx);
474 Ok(())
475 }
476
477 fn connection_id(&self) -> Result<ConnectionId> {
478 if let Status::Connected { connection_id, .. } = *self.status().borrow() {
479 Ok(connection_id)
480 } else {
481 Err(anyhow!("not connected"))
482 }
483 }
484
485 pub async fn send<T: EnvelopedMessage>(&self, message: T) -> Result<()> {
486 self.peer.send(self.connection_id()?, message).await
487 }
488
489 pub async fn request<T: RequestMessage>(&self, request: T) -> Result<T::Response> {
490 self.peer.request(self.connection_id()?, request).await
491 }
492
493 pub fn respond<T: RequestMessage>(
494 &self,
495 receipt: Receipt<T>,
496 response: T::Response,
497 ) -> impl Future<Output = Result<()>> {
498 self.peer.respond(receipt, response)
499 }
500}
501
/// Prefix shared by every worktree URL.
const WORKTREE_URL_PREFIX: &str = "zed://worktrees/";

/// Builds the worktree URL `zed://worktrees/<id>/<access_token>`.
pub fn encode_worktree_url(id: u64, access_token: &str) -> String {
    let mut url = String::from(WORKTREE_URL_PREFIX);
    url.push_str(&id.to_string());
    url.push('/');
    url.push_str(access_token);
    url
}

/// Parses a worktree URL back into its id and access token.
///
/// Surrounding whitespace is ignored. Returns `None` when the prefix is
/// missing, the id is not a `u64`, or the access token segment is absent or
/// empty. Path segments after the access token are ignored.
pub fn decode_worktree_url(url: &str) -> Option<(u64, String)> {
    let path = url.trim().strip_prefix(WORKTREE_URL_PREFIX)?;
    let mut segments = path.split('/');
    match (segments.next(), segments.next()) {
        (Some(id), Some(access_token)) if !access_token.is_empty() => {
            let id = id.parse::<u64>().ok()?;
            Some((id, access_token.to_string()))
        }
        _ => None,
    }
}
518
/// HTML page sent in response to the sign-in redirect; its script asks the
/// browser to close the window.
const LOGIN_RESPONSE: &'static str = "
<!DOCTYPE html>
<html>
<script>window.close();</script>
</html>
";
525
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test::FakeServer;
    use gpui::TestAppContext;

    /// A connected client sends pings with increasing ids and stops once it
    /// has been disconnected.
    #[gpui::test(iterations = 10)]
    async fn test_heartbeat(cx: TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let mut client = Client::new();
        let server = FakeServer::for_client(user_id, &mut client, &cx).await;

        // Advance the fake clock past the heartbeat interval (5 seconds by
        // default) so the first ping is sent.
        cx.foreground().advance_clock(Duration::from_secs(10));
        let ping = server.receive::<proto::Ping>().await.unwrap();
        assert_eq!(ping.payload.id, 0);
        server.respond(ping.receipt(), proto::Pong { id: 0 }).await;

        // The second ping carries the next id.
        cx.foreground().advance_clock(Duration::from_secs(10));
        let ping = server.receive::<proto::Ping>().await.unwrap();
        assert_eq!(ping.payload.id, 1);
        server.respond(ping.receipt(), proto::Pong { id: 1 }).await;

        // After disconnecting, no further ping reaches the server.
        client.disconnect(&cx.to_async()).await.unwrap();
        assert!(server.receive::<proto::Ping>().await.is_err());
    }

    /// A client whose connection is dropped reports `ReconnectionError` while
    /// the server refuses connections and reaches `Connected` again once they
    /// are allowed.
    #[gpui::test(iterations = 10)]
    async fn test_reconnection(cx: TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let mut client = Client::new();
        let server = FakeServer::for_client(user_id, &mut client, &cx).await;
        let mut status = client.status();
        assert!(matches!(
            status.recv().await,
            Some(Status::Connected { .. })
        ));

        // Drop the connection while new connections are refused, then wait
        // for the failed reconnection attempt to be reported.
        server.forbid_connections();
        server.disconnect().await;
        while !matches!(status.recv().await, Some(Status::ReconnectionError { .. })) {}

        // Allow connections again and advance the fake clock past the first
        // retry delay (5 seconds).
        server.allow_connections();
        cx.foreground().advance_clock(Duration::from_secs(10));
        while !matches!(status.recv().await, Some(Status::Connected { .. })) {}
    }

    /// Worktree URLs round-trip, tolerate surrounding whitespace and reject
    /// foreign schemes.
    #[test]
    fn test_encode_and_decode_worktree_url() {
        let url = encode_worktree_url(5, "deadbeef");
        assert_eq!(decode_worktree_url(&url), Some((5, "deadbeef".to_string())));
        assert_eq!(
            decode_worktree_url(&format!("\n {}\t", url)),
            Some((5, "deadbeef".to_string()))
        );
        assert_eq!(decode_worktree_url("not://the-right-format"), None);
    }
}