rpc.rs

  1use crate::{language::LanguageRegistry, worktree::Worktree};
  2use anyhow::{anyhow, Context, Result};
  3use async_tungstenite::tungstenite::http::Request;
  4use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
  5use futures::Stream;
  6use gpui::{AsyncAppContext, ModelHandle, Task, WeakModelHandle};
  7use lazy_static::lazy_static;
  8use smol::lock::RwLock;
  9use std::collections::HashMap;
 10use std::time::Duration;
 11use std::{convert::TryFrom, future::Future, sync::Arc};
 12use surf::Url;
 13pub use zrpc::{proto, ConnectionId, PeerId, TypedEnvelope};
 14use zrpc::{
 15    proto::{EnvelopedMessage, RequestMessage},
 16    ForegroundRouter, Peer, Receipt,
 17};
 18
 19lazy_static! {
 20    static ref ZED_SERVER_URL: String =
 21        std::env::var("ZED_SERVER_URL").unwrap_or("https://zed.dev:443".to_string());
 22}
 23
 24#[derive(Clone)]
 25pub struct Client {
 26    peer: Arc<Peer>,
 27    pub state: Arc<RwLock<ClientState>>,
 28}
 29
 30pub struct ClientState {
 31    connection_id: Option<ConnectionId>,
 32    pub shared_worktrees: HashMap<u64, WeakModelHandle<Worktree>>,
 33    pub languages: Arc<LanguageRegistry>,
 34}
 35
 36impl ClientState {
 37    pub fn shared_worktree(
 38        &self,
 39        id: u64,
 40        cx: &mut AsyncAppContext,
 41    ) -> Result<ModelHandle<Worktree>> {
 42        if let Some(worktree) = self.shared_worktrees.get(&id) {
 43            if let Some(worktree) = cx.read(|cx| worktree.upgrade(cx)) {
 44                Ok(worktree)
 45            } else {
 46                Err(anyhow!("worktree {} was dropped", id))
 47            }
 48        } else {
 49            Err(anyhow!("worktree {} does not exist", id))
 50        }
 51    }
 52}
 53
 54impl Client {
 55    pub fn new(languages: Arc<LanguageRegistry>) -> Self {
 56        Self {
 57            peer: Peer::new(),
 58            state: Arc::new(RwLock::new(ClientState {
 59                connection_id: None,
 60                shared_worktrees: Default::default(),
 61                languages,
 62            })),
 63        }
 64    }
 65
 66    pub fn on_message<H, M>(
 67        &self,
 68        router: &mut ForegroundRouter,
 69        handler: H,
 70        cx: &mut gpui::MutableAppContext,
 71    ) where
 72        H: 'static + Clone + for<'a> MessageHandler<'a, M>,
 73        M: proto::EnvelopedMessage,
 74    {
 75        let this = self.clone();
 76        let cx = cx.to_async();
 77        router.add_message_handler(move |message| {
 78            let this = this.clone();
 79            let mut cx = cx.clone();
 80            let handler = handler.clone();
 81            async move { handler.handle(message, &this, &mut cx).await }
 82        });
 83    }
 84
 85    pub fn subscribe<T: EnvelopedMessage>(&self) -> impl Stream<Item = Arc<TypedEnvelope<T>>> {
 86        self.peer.subscribe()
 87    }
 88
 89    pub async fn log_in_and_connect(
 90        &self,
 91        router: Arc<ForegroundRouter>,
 92        cx: AsyncAppContext,
 93    ) -> surf::Result<()> {
 94        if self.state.read().await.connection_id.is_some() {
 95            return Ok(());
 96        }
 97
 98        let (user_id, access_token) = Self::login(cx.platform(), &cx.background()).await?;
 99        let user_id: i32 = user_id.parse()?;
100        let request =
101            Request::builder().header("Authorization", format!("{} {}", user_id, access_token));
102
103        if let Some(host) = ZED_SERVER_URL.strip_prefix("https://") {
104            let stream = smol::net::TcpStream::connect(host).await?;
105            let request = request.uri(format!("wss://{}/rpc", host)).body(())?;
106            let (stream, _) = async_tungstenite::async_tls::client_async_tls(request, stream)
107                .await
108                .context("websocket handshake")?;
109            log::info!("connected to rpc address {}", *ZED_SERVER_URL);
110            self.add_connection(stream, router, cx).await?;
111        } else if let Some(host) = ZED_SERVER_URL.strip_prefix("http://") {
112            let stream = smol::net::TcpStream::connect(host).await?;
113            let request = request.uri(format!("ws://{}/rpc", host)).body(())?;
114            let (stream, _) = async_tungstenite::client_async(request, stream).await?;
115            log::info!("connected to rpc address {}", *ZED_SERVER_URL);
116            self.add_connection(stream, router, cx).await?;
117        } else {
118            return Err(anyhow!("invalid server url: {}", *ZED_SERVER_URL))?;
119        };
120
121        Ok(())
122    }
123
124    pub async fn add_connection<Conn>(
125        &self,
126        conn: Conn,
127        router: Arc<ForegroundRouter>,
128        cx: AsyncAppContext,
129    ) -> surf::Result<()>
130    where
131        Conn: 'static
132            + futures::Sink<WebSocketMessage, Error = WebSocketError>
133            + futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>
134            + Unpin
135            + Send,
136    {
137        let (connection_id, handle_io, handle_messages) =
138            self.peer.add_connection(conn, router).await;
139        cx.foreground().spawn(handle_messages).detach();
140        cx.background()
141            .spawn(async move {
142                if let Err(error) = handle_io.await {
143                    log::error!("connection error: {:?}", error);
144                }
145            })
146            .detach();
147        self.state.write().await.connection_id = Some(connection_id);
148        Ok(())
149    }
150
151    pub fn login(
152        platform: Arc<dyn gpui::Platform>,
153        executor: &Arc<gpui::executor::Background>,
154    ) -> Task<Result<(String, String)>> {
155        let executor = executor.clone();
156        executor.clone().spawn(async move {
157            if let Some((user_id, access_token)) = platform.read_credentials(&ZED_SERVER_URL) {
158                log::info!("already signed in. user_id: {}", user_id);
159                return Ok((user_id, String::from_utf8(access_token).unwrap()));
160            }
161
162            // Generate a pair of asymmetric encryption keys. The public key will be used by the
163            // zed server to encrypt the user's access token, so that it can'be intercepted by
164            // any other app running on the user's device.
165            let (public_key, private_key) =
166                zrpc::auth::keypair().expect("failed to generate keypair for auth");
167            let public_key_string =
168                String::try_from(public_key).expect("failed to serialize public key for auth");
169
170            // Start an HTTP server to receive the redirect from Zed's sign-in page.
171            let server = tiny_http::Server::http("127.0.0.1:0").expect("failed to find open port");
172            let port = server.server_addr().port();
173
174            // Open the Zed sign-in page in the user's browser, with query parameters that indicate
175            // that the user is signing in from a Zed app running on the same device.
176            platform.open_url(&format!(
177                "{}/sign_in?native_app_port={}&native_app_public_key={}",
178                *ZED_SERVER_URL, port, public_key_string
179            ));
180
181            // Receive the HTTP request from the user's browser. Retrieve the user id and encrypted
182            // access token from the query params.
183            //
184            // TODO - Avoid ever starting more than one HTTP server. Maybe switch to using a
185            // custom URL scheme instead of this local HTTP server.
186            let (user_id, access_token) = executor
187                .spawn(async move {
188                    if let Some(req) = server.recv_timeout(Duration::from_secs(10 * 60))? {
189                        let path = req.url();
190                        let mut user_id = None;
191                        let mut access_token = None;
192                        let url = Url::parse(&format!("http://example.com{}", path))
193                            .context("failed to parse login notification url")?;
194                        for (key, value) in url.query_pairs() {
195                            if key == "access_token" {
196                                access_token = Some(value.to_string());
197                            } else if key == "user_id" {
198                                user_id = Some(value.to_string());
199                            }
200                        }
201                        req.respond(
202                            tiny_http::Response::from_string(LOGIN_RESPONSE).with_header(
203                                tiny_http::Header::from_bytes("Content-Type", "text/html").unwrap(),
204                            ),
205                        )
206                        .context("failed to respond to login http request")?;
207                        Ok((
208                            user_id.ok_or_else(|| anyhow!("missing user_id parameter"))?,
209                            access_token
210                                .ok_or_else(|| anyhow!("missing access_token parameter"))?,
211                        ))
212                    } else {
213                        Err(anyhow!("didn't receive login redirect"))
214                    }
215                })
216                .await?;
217
218            let access_token = private_key
219                .decrypt_string(&access_token)
220                .context("failed to decrypt access token")?;
221            platform.activate(true);
222            platform.write_credentials(&ZED_SERVER_URL, &user_id, access_token.as_bytes());
223            Ok((user_id.to_string(), access_token))
224        })
225    }
226
227    pub async fn disconnect(&self) -> Result<()> {
228        let conn_id = self.connection_id().await?;
229        self.peer.disconnect(conn_id).await;
230        Ok(())
231    }
232
233    async fn connection_id(&self) -> Result<ConnectionId> {
234        self.state
235            .read()
236            .await
237            .connection_id
238            .ok_or_else(|| anyhow!("not connected"))
239    }
240
241    pub async fn send<T: EnvelopedMessage>(&self, message: T) -> Result<()> {
242        self.peer.send(self.connection_id().await?, message).await
243    }
244
245    pub async fn request<T: RequestMessage>(&self, request: T) -> Result<T::Response> {
246        self.peer
247            .request(self.connection_id().await?, request)
248            .await
249    }
250
251    pub fn respond<T: RequestMessage>(
252        &self,
253        receipt: Receipt<T>,
254        response: T::Response,
255    ) -> impl Future<Output = Result<()>> {
256        self.peer.respond(receipt, response)
257    }
258}
259
260pub trait MessageHandler<'a, M: proto::EnvelopedMessage>: Clone {
261    type Output: 'a + Future<Output = anyhow::Result<()>>;
262
263    fn handle(
264        &self,
265        message: TypedEnvelope<M>,
266        rpc: &'a Client,
267        cx: &'a mut gpui::AsyncAppContext,
268    ) -> Self::Output;
269}
270
271impl<'a, M, F, Fut> MessageHandler<'a, M> for F
272where
273    M: proto::EnvelopedMessage,
274    F: Clone + Fn(TypedEnvelope<M>, &'a Client, &'a mut gpui::AsyncAppContext) -> Fut,
275    Fut: 'a + Future<Output = anyhow::Result<()>>,
276{
277    type Output = Fut;
278
279    fn handle(
280        &self,
281        message: TypedEnvelope<M>,
282        rpc: &'a Client,
283        cx: &'a mut gpui::AsyncAppContext,
284    ) -> Self::Output {
285        (self)(message, rpc, cx)
286    }
287}
288
/// Prefix shared by every worktree URL produced by `encode_worktree_url`.
const WORKTREE_URL_PREFIX: &str = "zed://worktrees/";
290
291pub fn encode_worktree_url(id: u64, access_token: &str) -> String {
292    format!("{}{}/{}", WORKTREE_URL_PREFIX, id, access_token)
293}
294
295pub fn decode_worktree_url(url: &str) -> Option<(u64, String)> {
296    let path = url.trim().strip_prefix(WORKTREE_URL_PREFIX)?;
297    let mut parts = path.split('/');
298    let id = parts.next()?.parse::<u64>().ok()?;
299    let access_token = parts.next()?;
300    if access_token.is_empty() {
301        return None;
302    }
303    Some((id, access_token.to_string()))
304}
305
/// HTML page served to the browser after the sign-in redirect; its script closes the window.
const LOGIN_RESPONSE: &str =
    "\n<!DOCTYPE html>\n<html>\n<script>window.close();</script>\n</html>\n";
312
/// Checks that worktree URLs round-trip through encode/decode and that malformed URLs
/// are rejected.
#[test]
fn test_encode_and_decode_worktree_url() {
    let url = encode_worktree_url(5, "deadbeef");
    assert_eq!(decode_worktree_url(&url), Some((5, "deadbeef".to_string())));

    // Surrounding whitespace is trimmed before decoding.
    assert_eq!(
        decode_worktree_url(&format!("\n {}\t", url)),
        Some((5, "deadbeef".to_string()))
    );

    // Wrong scheme.
    assert_eq!(decode_worktree_url("not://the-right-format"), None);

    // Missing or empty access token.
    assert_eq!(decode_worktree_url("zed://worktrees/5"), None);
    assert_eq!(decode_worktree_url("zed://worktrees/5/"), None);

    // The id must be a valid u64.
    assert_eq!(decode_worktree_url("zed://worktrees/five/deadbeef"), None);
}