rpc.rs

  1use anyhow::{anyhow, Context, Result};
  2use async_tungstenite::tungstenite::http::Request;
  3use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
  4use gpui::{AsyncAppContext, Entity, ModelContext, Task};
  5use lazy_static::lazy_static;
  6use parking_lot::RwLock;
  7use postage::prelude::Stream;
  8use postage::sink::Sink;
  9use postage::watch;
 10use std::any::TypeId;
 11use std::collections::HashMap;
 12use std::sync::Weak;
 13use std::time::{Duration, Instant};
 14use std::{convert::TryFrom, future::Future, sync::Arc};
 15use surf::Url;
 16use zrpc::proto::{AnyTypedEnvelope, EntityMessage};
 17pub use zrpc::{proto, ConnectionId, PeerId, TypedEnvelope};
 18use zrpc::{
 19    proto::{EnvelopedMessage, RequestMessage},
 20    Peer, Receipt,
 21};
 22
 23lazy_static! {
 24    static ref ZED_SERVER_URL: String =
 25        std::env::var("ZED_SERVER_URL").unwrap_or("https://zed.dev:443".to_string());
 26}
 27
 28pub struct Client {
 29    peer: Arc<Peer>,
 30    state: RwLock<ClientState>,
 31}
 32
 33struct ClientState {
 34    connection_id: Option<ConnectionId>,
 35    user_id: (watch::Sender<Option<u64>>, watch::Receiver<Option<u64>>),
 36    entity_id_extractors: HashMap<TypeId, Box<dyn Send + Sync + Fn(&dyn AnyTypedEnvelope) -> u64>>,
 37    model_handlers: HashMap<
 38        (TypeId, u64),
 39        Box<dyn Send + Sync + FnMut(Box<dyn AnyTypedEnvelope>, &mut AsyncAppContext)>,
 40    >,
 41}
 42
 43impl Default for ClientState {
 44    fn default() -> Self {
 45        Self {
 46            connection_id: Default::default(),
 47            user_id: watch::channel(),
 48            entity_id_extractors: Default::default(),
 49            model_handlers: Default::default(),
 50        }
 51    }
 52}
 53
 54pub struct Subscription {
 55    client: Weak<Client>,
 56    id: (TypeId, u64),
 57}
 58
 59impl Drop for Subscription {
 60    fn drop(&mut self) {
 61        if let Some(client) = self.client.upgrade() {
 62            drop(
 63                client
 64                    .state
 65                    .write()
 66                    .model_handlers
 67                    .remove(&self.id)
 68                    .unwrap(),
 69            );
 70        }
 71    }
 72}
 73
 74impl Client {
 75    pub fn new() -> Arc<Self> {
 76        Arc::new(Self {
 77            peer: Peer::new(),
 78            state: Default::default(),
 79        })
 80    }
 81
 82    pub fn user_id(&self) -> watch::Receiver<Option<u64>> {
 83        self.state.read().user_id.1.clone()
 84    }
 85
 86    pub fn subscribe_from_model<T, M, F>(
 87        self: &Arc<Self>,
 88        remote_id: u64,
 89        cx: &mut ModelContext<M>,
 90        mut handler: F,
 91    ) -> Subscription
 92    where
 93        T: EntityMessage,
 94        M: Entity,
 95        F: 'static
 96            + Send
 97            + Sync
 98            + FnMut(&mut M, TypedEnvelope<T>, Arc<Self>, &mut ModelContext<M>) -> Result<()>,
 99    {
100        let subscription_id = (TypeId::of::<T>(), remote_id);
101        let client = self.clone();
102        let mut state = self.state.write();
103        let model = cx.handle().downgrade();
104        state
105            .entity_id_extractors
106            .entry(subscription_id.0)
107            .or_insert_with(|| {
108                Box::new(|envelope| {
109                    let envelope = envelope
110                        .as_any()
111                        .downcast_ref::<TypedEnvelope<T>>()
112                        .unwrap();
113                    envelope.payload.remote_entity_id()
114                })
115            });
116        let prev_handler = state.model_handlers.insert(
117            subscription_id,
118            Box::new(move |envelope, cx| {
119                if let Some(model) = model.upgrade(cx) {
120                    let envelope = envelope.into_any().downcast::<TypedEnvelope<T>>().unwrap();
121                    model.update(cx, |model, cx| {
122                        if let Err(error) = handler(model, *envelope, client.clone(), cx) {
123                            log::error!("error handling message: {}", error)
124                        }
125                    });
126                }
127            }),
128        );
129        if prev_handler.is_some() {
130            panic!("registered a handler for the same entity twice")
131        }
132
133        Subscription {
134            client: Arc::downgrade(self),
135            id: subscription_id,
136        }
137    }
138
139    pub async fn authenticate_and_connect(
140        self: &Arc<Self>,
141        cx: AsyncAppContext,
142    ) -> anyhow::Result<()> {
143        if self.state.read().connection_id.is_some() {
144            return Ok(());
145        }
146
147        let (user_id, access_token) = Self::login(cx.platform(), &cx.background()).await?;
148        let user_id = user_id.parse::<u64>()?;
149        let request =
150            Request::builder().header("Authorization", format!("{} {}", user_id, access_token));
151
152        if let Some(host) = ZED_SERVER_URL.strip_prefix("https://") {
153            let stream = smol::net::TcpStream::connect(host).await?;
154            let request = request.uri(format!("wss://{}/rpc", host)).body(())?;
155            let (stream, _) = async_tungstenite::async_tls::client_async_tls(request, stream)
156                .await
157                .context("websocket handshake")?;
158            self.add_connection(user_id, stream, cx).await?;
159        } else if let Some(host) = ZED_SERVER_URL.strip_prefix("http://") {
160            let stream = smol::net::TcpStream::connect(host).await?;
161            let request = request.uri(format!("ws://{}/rpc", host)).body(())?;
162            let (stream, _) = async_tungstenite::client_async(request, stream)
163                .await
164                .context("websocket handshake")?;
165            self.add_connection(user_id, stream, cx).await?;
166        } else {
167            return Err(anyhow!("invalid server url: {}", *ZED_SERVER_URL))?;
168        };
169
170        log::info!("connected to rpc address {}", *ZED_SERVER_URL);
171        Ok(())
172    }
173
174    pub async fn add_connection<Conn>(
175        self: &Arc<Self>,
176        user_id: u64,
177        conn: Conn,
178        cx: AsyncAppContext,
179    ) -> anyhow::Result<()>
180    where
181        Conn: 'static
182            + futures::Sink<WebSocketMessage, Error = WebSocketError>
183            + futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>
184            + Unpin
185            + Send,
186    {
187        let (connection_id, handle_io, mut incoming) = self.peer.add_connection(conn).await;
188        {
189            let mut cx = cx.clone();
190            let this = self.clone();
191            cx.foreground()
192                .spawn(async move {
193                    while let Some(message) = incoming.recv().await {
194                        let mut state = this.state.write();
195                        if let Some(extract_entity_id) =
196                            state.entity_id_extractors.get(&message.payload_type_id())
197                        {
198                            let entity_id = (extract_entity_id)(message.as_ref());
199                            if let Some(handler) = state
200                                .model_handlers
201                                .get_mut(&(message.payload_type_id(), entity_id))
202                            {
203                                let start_time = Instant::now();
204                                log::info!("RPC client message {}", message.payload_type_name());
205                                (handler)(message, &mut cx);
206                                log::info!(
207                                    "RPC message handled. duration:{:?}",
208                                    start_time.elapsed()
209                                );
210                            } else {
211                                log::info!("unhandled message {}", message.payload_type_name());
212                            }
213                        } else {
214                            log::info!("unhandled message {}", message.payload_type_name());
215                        }
216                    }
217                })
218                .detach();
219        }
220        cx.background()
221            .spawn(async move {
222                if let Err(error) = handle_io.await {
223                    log::error!("connection error: {:?}", error);
224                }
225            })
226            .detach();
227        let mut state = self.state.write();
228        state.connection_id = Some(connection_id);
229        state.user_id.0.send(Some(user_id)).await?;
230        Ok(())
231    }
232
233    pub fn login(
234        platform: Arc<dyn gpui::Platform>,
235        executor: &Arc<gpui::executor::Background>,
236    ) -> Task<Result<(String, String)>> {
237        let executor = executor.clone();
238        executor.clone().spawn(async move {
239            if let Some((user_id, access_token)) = platform.read_credentials(&ZED_SERVER_URL) {
240                log::info!("already signed in. user_id: {}", user_id);
241                return Ok((user_id, String::from_utf8(access_token).unwrap()));
242            }
243
244            // Generate a pair of asymmetric encryption keys. The public key will be used by the
245            // zed server to encrypt the user's access token, so that it can'be intercepted by
246            // any other app running on the user's device.
247            let (public_key, private_key) =
248                zrpc::auth::keypair().expect("failed to generate keypair for auth");
249            let public_key_string =
250                String::try_from(public_key).expect("failed to serialize public key for auth");
251
252            // Start an HTTP server to receive the redirect from Zed's sign-in page.
253            let server = tiny_http::Server::http("127.0.0.1:0").expect("failed to find open port");
254            let port = server.server_addr().port();
255
256            // Open the Zed sign-in page in the user's browser, with query parameters that indicate
257            // that the user is signing in from a Zed app running on the same device.
258            platform.open_url(&format!(
259                "{}/sign_in?native_app_port={}&native_app_public_key={}",
260                *ZED_SERVER_URL, port, public_key_string
261            ));
262
263            // Receive the HTTP request from the user's browser. Retrieve the user id and encrypted
264            // access token from the query params.
265            //
266            // TODO - Avoid ever starting more than one HTTP server. Maybe switch to using a
267            // custom URL scheme instead of this local HTTP server.
268            let (user_id, access_token) = executor
269                .spawn(async move {
270                    if let Some(req) = server.recv_timeout(Duration::from_secs(10 * 60))? {
271                        let path = req.url();
272                        let mut user_id = None;
273                        let mut access_token = None;
274                        let url = Url::parse(&format!("http://example.com{}", path))
275                            .context("failed to parse login notification url")?;
276                        for (key, value) in url.query_pairs() {
277                            if key == "access_token" {
278                                access_token = Some(value.to_string());
279                            } else if key == "user_id" {
280                                user_id = Some(value.to_string());
281                            }
282                        }
283                        req.respond(
284                            tiny_http::Response::from_string(LOGIN_RESPONSE).with_header(
285                                tiny_http::Header::from_bytes("Content-Type", "text/html").unwrap(),
286                            ),
287                        )
288                        .context("failed to respond to login http request")?;
289                        Ok((
290                            user_id.ok_or_else(|| anyhow!("missing user_id parameter"))?,
291                            access_token
292                                .ok_or_else(|| anyhow!("missing access_token parameter"))?,
293                        ))
294                    } else {
295                        Err(anyhow!("didn't receive login redirect"))
296                    }
297                })
298                .await?;
299
300            let access_token = private_key
301                .decrypt_string(&access_token)
302                .context("failed to decrypt access token")?;
303            platform.activate(true);
304            platform.write_credentials(&ZED_SERVER_URL, &user_id, access_token.as_bytes());
305            Ok((user_id.to_string(), access_token))
306        })
307    }
308
309    pub async fn disconnect(&self) -> Result<()> {
310        let conn_id = self.connection_id()?;
311        self.peer.disconnect(conn_id).await;
312        Ok(())
313    }
314
315    fn connection_id(&self) -> Result<ConnectionId> {
316        self.state
317            .read()
318            .connection_id
319            .ok_or_else(|| anyhow!("not connected"))
320    }
321
322    pub async fn send<T: EnvelopedMessage>(&self, message: T) -> Result<()> {
323        self.peer.send(self.connection_id()?, message).await
324    }
325
326    pub async fn request<T: RequestMessage>(&self, request: T) -> Result<T::Response> {
327        self.peer.request(self.connection_id()?, request).await
328    }
329
330    pub fn respond<T: RequestMessage>(
331        &self,
332        receipt: Receipt<T>,
333        response: T::Response,
334    ) -> impl Future<Output = Result<()>> {
335        self.peer.respond(receipt, response)
336    }
337}
338
339pub trait MessageHandler<'a, M: proto::EnvelopedMessage>: Clone {
340    type Output: 'a + Future<Output = anyhow::Result<()>>;
341
342    fn handle(
343        &self,
344        message: TypedEnvelope<M>,
345        rpc: &'a Client,
346        cx: &'a mut gpui::AsyncAppContext,
347    ) -> Self::Output;
348}
349
350impl<'a, M, F, Fut> MessageHandler<'a, M> for F
351where
352    M: proto::EnvelopedMessage,
353    F: Clone + Fn(TypedEnvelope<M>, &'a Client, &'a mut gpui::AsyncAppContext) -> Fut,
354    Fut: 'a + Future<Output = anyhow::Result<()>>,
355{
356    type Output = Fut;
357
358    fn handle(
359        &self,
360        message: TypedEnvelope<M>,
361        rpc: &'a Client,
362        cx: &'a mut gpui::AsyncAppContext,
363    ) -> Self::Output {
364        (self)(message, rpc, cx)
365    }
366}
367
/// Scheme and path prefix shared by all worktree URLs.
const WORKTREE_URL_PREFIX: &str = "zed://worktrees/";

/// Builds a worktree URL of the form `zed://worktrees/{id}/{access_token}`.
pub fn encode_worktree_url(id: u64, access_token: &str) -> String {
    format!("{}{}/{}", WORKTREE_URL_PREFIX, id, access_token)
}

/// Parses a URL produced by [`encode_worktree_url`] back into `(id, access_token)`.
///
/// Surrounding whitespace is ignored. Everything after the `/` that follows the id is taken as
/// the access token, so tokens that themselves contain `/` round-trip intact.
///
/// Returns `None` if:
/// - the prefix is missing,
/// - the id is not a valid `u64`, or
/// - the access token is absent or empty.
pub fn decode_worktree_url(url: &str) -> Option<(u64, String)> {
    let path = url.trim().strip_prefix(WORKTREE_URL_PREFIX)?;
    let mut parts = path.splitn(2, '/');
    let id = parts.next()?.parse::<u64>().ok()?;
    let access_token = parts.next()?;
    if access_token.is_empty() {
        return None;
    }
    Some((id, access_token.to_string()))
}
384
/// HTML returned to the browser once it delivers the login redirect; its script calls
/// `window.close()`.
const LOGIN_RESPONSE: &str = concat!(
    "\n",
    "<!DOCTYPE html>\n",
    "<html>\n",
    "<script>window.close();</script>\n",
    "</html>\n"
);
391
#[test]
fn test_encode_and_decode_worktree_url() {
    let encoded = encode_worktree_url(5, "deadbeef");
    let expected = Some((5, String::from("deadbeef")));

    // The URL decodes back to its parts, with or without surrounding whitespace.
    assert_eq!(decode_worktree_url(&encoded), expected);
    let padded = format!("\n {}\t", encoded);
    assert_eq!(decode_worktree_url(&padded), expected);

    // Anything without the worktree prefix is rejected.
    assert_eq!(decode_worktree_url("not://the-right-format"), None);
}