rpc.rs

  1use super::util::SurfResultExt as _;
  2use crate::{language::LanguageRegistry, worktree::Worktree};
  3use anyhow::{anyhow, Context, Result};
  4use gpui::executor::Background;
  5use gpui::{AsyncAppContext, ModelHandle, Task, WeakModelHandle};
  6use lazy_static::lazy_static;
  7use postage::prelude::Stream;
  8use smol::lock::Mutex;
  9use std::collections::HashMap;
 10use std::time::Duration;
 11use std::{convert::TryFrom, future::Future, sync::Arc};
 12use surf::Url;
 13pub use zed_rpc::{proto, ConnectionId, PeerId, TypedEnvelope};
 14use zed_rpc::{
 15    proto::{EnvelopedMessage, RequestMessage},
 16    rest, Peer, Receipt,
 17};
 18
 19lazy_static! {
 20    static ref ZED_SERVER_URL: String =
 21        std::env::var("ZED_SERVER_URL").unwrap_or("https://zed.dev".to_string());
 22}
 23
 24#[derive(Clone)]
 25pub struct Client {
 26    peer: Arc<Peer>,
 27    pub state: Arc<Mutex<ClientState>>,
 28}
 29
 30pub struct ClientState {
 31    connection_id: Option<ConnectionId>,
 32    pub shared_worktrees: HashMap<u64, WeakModelHandle<Worktree>>,
 33    pub languages: Arc<LanguageRegistry>,
 34}
 35
 36impl ClientState {
 37    pub fn shared_worktree(
 38        &mut self,
 39        id: u64,
 40        cx: &mut AsyncAppContext,
 41    ) -> Result<ModelHandle<Worktree>> {
 42        if let Some(worktree) = self.shared_worktrees.get(&id) {
 43            if let Some(worktree) = cx.read(|cx| worktree.upgrade(cx)) {
 44                Ok(worktree)
 45            } else {
 46                self.shared_worktrees.remove(&id);
 47                Err(anyhow!("worktree {} was dropped", id))
 48            }
 49        } else {
 50            Err(anyhow!("worktree {} does not exist", id))
 51        }
 52    }
 53}
 54
 55impl Client {
 56    pub fn new(languages: Arc<LanguageRegistry>) -> Self {
 57        Self {
 58            peer: Peer::new(),
 59            state: Arc::new(Mutex::new(ClientState {
 60                connection_id: None,
 61                shared_worktrees: Default::default(),
 62                languages,
 63            })),
 64        }
 65    }
 66
 67    pub fn on_message<H, M>(&self, handler: H, cx: &mut gpui::MutableAppContext)
 68    where
 69        H: 'static + for<'a> MessageHandler<'a, M>,
 70        M: proto::EnvelopedMessage,
 71    {
 72        let this = self.clone();
 73        let mut messages = smol::block_on(this.peer.add_message_handler::<M>());
 74        cx.spawn(|mut cx| async move {
 75            while let Some(message) = messages.recv().await {
 76                if let Err(err) = handler.handle(message, &this, &mut cx).await {
 77                    log::error!("error handling message: {:?}", err);
 78                }
 79            }
 80        })
 81        .detach();
 82    }
 83
 84    pub async fn log_in_and_connect(&self, cx: &AsyncAppContext) -> surf::Result<()> {
 85        if self.state.lock().await.connection_id.is_some() {
 86            return Ok(());
 87        }
 88
 89        let (user_id, access_token) = Self::login(cx.platform(), &cx.background()).await?;
 90
 91        let mut response = surf::get(format!(
 92            "{}{}",
 93            *ZED_SERVER_URL,
 94            &rest::GET_RPC_ADDRESS_PATH
 95        ))
 96        .header(
 97            "Authorization",
 98            http_auth_basic::Credentials::new(&user_id, &access_token).as_http_header(),
 99        )
100        .await
101        .context("rpc address request failed")?;
102
103        let rest::GetRpcAddressResponse { address } = response
104            .body_json()
105            .await
106            .context("failed to parse rpc address response")?;
107
108        self.connect(&address, user_id.parse()?, access_token, &cx.background())
109            .await
110    }
111
112    pub async fn connect(
113        &self,
114        address: &str,
115        user_id: i32,
116        access_token: String,
117        executor: &Arc<Background>,
118    ) -> surf::Result<()> {
119        // TODO - If the `ZED_SERVER_URL` uses https, then wrap this stream in
120        // a TLS stream using `native-tls`.
121        let stream = smol::net::TcpStream::connect(&address).await?;
122        log::info!("connected to rpc address {}", address);
123
124        let (connection_id, handler) = self.peer.add_connection(stream).await;
125        executor.spawn(handler.run()).detach();
126
127        let auth_response = self
128            .peer
129            .request(
130                connection_id,
131                proto::Auth {
132                    user_id,
133                    access_token,
134                },
135            )
136            .await
137            .context("rpc auth request failed")?;
138        if !auth_response.credentials_valid {
139            Err(anyhow!("failed to authenticate with RPC server"))?;
140        }
141
142        self.state.lock().await.connection_id = Some(connection_id);
143        Ok(())
144    }
145
146    pub fn login(
147        platform: Arc<dyn gpui::Platform>,
148        executor: &Arc<gpui::executor::Background>,
149    ) -> Task<Result<(String, String)>> {
150        let executor = executor.clone();
151        executor.clone().spawn(async move {
152            if let Some((user_id, access_token)) = platform.read_credentials(&ZED_SERVER_URL) {
153                log::info!("already signed in. user_id: {}", user_id);
154                return Ok((user_id, String::from_utf8(access_token).unwrap()));
155            }
156
157            // Generate a pair of asymmetric encryption keys. The public key will be used by the
158            // zed server to encrypt the user's access token, so that it can'be intercepted by
159            // any other app running on the user's device.
160            let (public_key, private_key) =
161                zed_rpc::auth::keypair().expect("failed to generate keypair for auth");
162            let public_key_string =
163                String::try_from(public_key).expect("failed to serialize public key for auth");
164
165            // Start an HTTP server to receive the redirect from Zed's sign-in page.
166            let server = tiny_http::Server::http("127.0.0.1:0").expect("failed to find open port");
167            let port = server.server_addr().port();
168
169            // Open the Zed sign-in page in the user's browser, with query parameters that indicate
170            // that the user is signing in from a Zed app running on the same device.
171            platform.open_url(&format!(
172                "{}/sign_in?native_app_port={}&native_app_public_key={}",
173                *ZED_SERVER_URL, port, public_key_string
174            ));
175
176            // Receive the HTTP request from the user's browser. Retrieve the user id and encrypted
177            // access token from the query params.
178            //
179            // TODO - Avoid ever starting more than one HTTP server. Maybe switch to using a
180            // custom URL scheme instead of this local HTTP server.
181            let (user_id, access_token) = executor
182                .spawn(async move {
183                    if let Some(req) = server.recv_timeout(Duration::from_secs(10 * 60))? {
184                        let path = req.url();
185                        let mut user_id = None;
186                        let mut access_token = None;
187                        let url = Url::parse(&format!("http://example.com{}", path))
188                            .context("failed to parse login notification url")?;
189                        for (key, value) in url.query_pairs() {
190                            if key == "access_token" {
191                                access_token = Some(value.to_string());
192                            } else if key == "user_id" {
193                                user_id = Some(value.to_string());
194                            }
195                        }
196                        req.respond(
197                            tiny_http::Response::from_string(LOGIN_RESPONSE).with_header(
198                                tiny_http::Header::from_bytes("Content-Type", "text/html").unwrap(),
199                            ),
200                        )
201                        .context("failed to respond to login http request")?;
202                        Ok((
203                            user_id.ok_or_else(|| anyhow!("missing user_id parameter"))?,
204                            access_token
205                                .ok_or_else(|| anyhow!("missing access_token parameter"))?,
206                        ))
207                    } else {
208                        Err(anyhow!("didn't receive login redirect"))
209                    }
210                })
211                .await?;
212
213            let access_token = private_key
214                .decrypt_string(&access_token)
215                .context("failed to decrypt access token")?;
216            platform.activate(true);
217            platform.write_credentials(&ZED_SERVER_URL, &user_id, access_token.as_bytes());
218            Ok((user_id.to_string(), access_token))
219        })
220    }
221
222    async fn connection_id(&self) -> Result<ConnectionId> {
223        self.state
224            .lock()
225            .await
226            .connection_id
227            .ok_or_else(|| anyhow!("not connected"))
228    }
229
230    pub async fn send<T: EnvelopedMessage>(&self, message: T) -> Result<()> {
231        self.peer.send(self.connection_id().await?, message).await
232    }
233
234    pub async fn request<T: RequestMessage>(&self, request: T) -> Result<T::Response> {
235        self.peer
236            .request(self.connection_id().await?, request)
237            .await
238    }
239
240    pub fn respond<T: RequestMessage>(
241        &self,
242        receipt: Receipt<T>,
243        response: T::Response,
244    ) -> impl Future<Output = Result<()>> {
245        self.peer.respond(receipt, response)
246    }
247}
248
249pub trait MessageHandler<'a, M: proto::EnvelopedMessage> {
250    type Output: 'a + Future<Output = anyhow::Result<()>>;
251
252    fn handle(
253        &self,
254        message: TypedEnvelope<M>,
255        rpc: &'a Client,
256        cx: &'a mut gpui::AsyncAppContext,
257    ) -> Self::Output;
258}
259
260impl<'a, M, F, Fut> MessageHandler<'a, M> for F
261where
262    M: proto::EnvelopedMessage,
263    F: Fn(TypedEnvelope<M>, &'a Client, &'a mut gpui::AsyncAppContext) -> Fut,
264    Fut: 'a + Future<Output = anyhow::Result<()>>,
265{
266    type Output = Fut;
267
268    fn handle(
269        &self,
270        message: TypedEnvelope<M>,
271        rpc: &'a Client,
272        cx: &'a mut gpui::AsyncAppContext,
273    ) -> Self::Output {
274        (self)(message, rpc, cx)
275    }
276}
277
/// Scheme-and-path prefix shared by every worktree URL.
const WORKTREE_URL_PREFIX: &str = "zed://worktrees/";

/// Builds a worktree URL of the form `zed://worktrees/{id}/{access_token}`.
pub fn encode_worktree_url(id: u64, access_token: &str) -> String {
    format!("{}{}/{}", WORKTREE_URL_PREFIX, id, access_token)
}

/// Parses a URL produced by [`encode_worktree_url`] back into `(id, access_token)`.
///
/// Returns `None` when any of the following hold:
/// - the prefix is missing;
/// - the id is not a valid `u64`;
/// - the access token is absent or empty.
///
/// Everything after the first `/` that follows the id is taken as the token. Tokens that
/// themselves contain `/` therefore survive the round trip.
pub fn decode_worktree_url(url: &str) -> Option<(u64, String)> {
    let path = url.strip_prefix(WORKTREE_URL_PREFIX)?;
    // Split only once: the token is opaque and may itself contain '/'.
    let mut parts = path.splitn(2, '/');
    let id = parts.next()?.parse::<u64>().ok()?;
    let access_token = parts.next()?;
    if access_token.is_empty() {
        return None;
    }
    Some((id, access_token.to_string()))
}
294
/// HTML page returned to the browser once the sign-in redirect reaches the local server.
///
/// Browsers only honor `window.close()` for windows opened by a script. This tab is opened via
/// `platform.open_url`, so the page also shows a message in case it stays open.
const LOGIN_RESPONSE: &str = "
<!DOCTYPE html>
<html>
<body>
<p>You can close this window and return to Zed.</p>
<script>window.close();</script>
</body>
</html>
";
301
#[test]
fn test_encode_and_decode_worktree_url() {
    let url = encode_worktree_url(5, "deadbeef");
    assert_eq!(url, "zed://worktrees/5/deadbeef");
    assert_eq!(decode_worktree_url(&url), Some((5, "deadbeef".to_string())));

    // Malformed URLs are rejected: wrong prefix, missing token, empty token, non-numeric id.
    assert_eq!(decode_worktree_url("not://the-right-format"), None);
    assert_eq!(decode_worktree_url("zed://worktrees/5"), None);
    assert_eq!(decode_worktree_url("zed://worktrees/5/"), None);
    assert_eq!(decode_worktree_url("zed://worktrees/five/deadbeef"), None);
}