rpc.rs

  1use super::util::SurfResultExt as _;
  2use crate::{language::LanguageRegistry, worktree::Worktree};
  3use anyhow::{anyhow, Context, Result};
  4use gpui::executor::Background;
  5use gpui::{AsyncAppContext, ModelHandle, Task, WeakModelHandle};
  6use lazy_static::lazy_static;
  7use postage::prelude::Stream;
  8use smol::lock::RwLock;
  9use std::collections::HashMap;
 10use std::time::Duration;
 11use std::{convert::TryFrom, future::Future, sync::Arc};
 12use surf::Url;
 13pub use zed_rpc::{proto, ConnectionId, PeerId, TypedEnvelope};
 14use zed_rpc::{
 15    proto::{EnvelopedMessage, RequestMessage},
 16    rest, Peer, Receipt,
 17};
 18
 19lazy_static! {
 20    static ref ZED_SERVER_URL: String =
 21        std::env::var("ZED_SERVER_URL").unwrap_or("https://zed.dev".to_string());
 22}
 23
/// Client for the Zed RPC server.
///
/// Wraps a shared [`Peer`] together with lock-protected [`ClientState`]. Both fields are `Arc`s,
/// so cloning a `Client` is cheap and every clone shares the same peer and state.
#[derive(Clone)]
pub struct Client {
    peer: Arc<Peer>,
    pub state: Arc<RwLock<ClientState>>,
}
 29
/// Mutable state shared by all clones of a [`Client`] (guarded by an async `RwLock`).
pub struct ClientState {
    // `None` after `Client::new`; set by `Client::connect` once authentication succeeds.
    connection_id: Option<ConnectionId>,
    /// Worktrees registered as shared, keyed by a `u64` id. Held as weak handles, so a lookup
    /// must upgrade them and may find the worktree already dropped
    /// (see `ClientState::shared_worktree`).
    pub shared_worktrees: HashMap<u64, WeakModelHandle<Worktree>>,
    /// Language registry supplied to `Client::new`.
    pub languages: Arc<LanguageRegistry>,
}
 35
 36impl ClientState {
 37    pub fn shared_worktree(
 38        &self,
 39        id: u64,
 40        cx: &mut AsyncAppContext,
 41    ) -> Result<ModelHandle<Worktree>> {
 42        if let Some(worktree) = self.shared_worktrees.get(&id) {
 43            if let Some(worktree) = cx.read(|cx| worktree.upgrade(cx)) {
 44                Ok(worktree)
 45            } else {
 46                Err(anyhow!("worktree {} was dropped", id))
 47            }
 48        } else {
 49            Err(anyhow!("worktree {} does not exist", id))
 50        }
 51    }
 52}
 53
 54impl Client {
 55    pub fn new(languages: Arc<LanguageRegistry>) -> Self {
 56        Self {
 57            peer: Peer::new(),
 58            state: Arc::new(RwLock::new(ClientState {
 59                connection_id: None,
 60                shared_worktrees: Default::default(),
 61                languages,
 62            })),
 63        }
 64    }
 65
 66    pub fn on_message<H, M>(&self, handler: H, cx: &mut gpui::MutableAppContext)
 67    where
 68        H: 'static + for<'a> MessageHandler<'a, M>,
 69        M: proto::EnvelopedMessage,
 70    {
 71        let this = self.clone();
 72        let mut messages = smol::block_on(this.peer.add_message_handler::<M>());
 73        cx.spawn(|mut cx| async move {
 74            while let Some(message) = messages.recv().await {
 75                if let Err(err) = handler.handle(message, &this, &mut cx).await {
 76                    log::error!("error handling message: {:?}", err);
 77                }
 78            }
 79        })
 80        .detach();
 81    }
 82
 83    pub async fn log_in_and_connect(&self, cx: &AsyncAppContext) -> surf::Result<()> {
 84        if self.state.read().await.connection_id.is_some() {
 85            return Ok(());
 86        }
 87
 88        let (user_id, access_token) = Self::login(cx.platform(), &cx.background()).await?;
 89
 90        let mut response = surf::get(format!(
 91            "{}{}",
 92            *ZED_SERVER_URL,
 93            &rest::GET_RPC_ADDRESS_PATH
 94        ))
 95        .header(
 96            "Authorization",
 97            http_auth_basic::Credentials::new(&user_id, &access_token).as_http_header(),
 98        )
 99        .await
100        .context("rpc address request failed")?;
101
102        let rest::GetRpcAddressResponse { address } = response
103            .body_json()
104            .await
105            .context("failed to parse rpc address response")?;
106
107        self.connect(&address, user_id.parse()?, access_token, &cx.background())
108            .await
109    }
110
111    pub async fn connect(
112        &self,
113        address: &str,
114        user_id: i32,
115        access_token: String,
116        executor: &Arc<Background>,
117    ) -> surf::Result<()> {
118        // TODO - If the `ZED_SERVER_URL` uses https, then wrap this stream in
119        // a TLS stream using `native-tls`.
120        let stream = smol::net::TcpStream::connect(&address).await?;
121        log::info!("connected to rpc address {}", address);
122
123        let (connection_id, handler) = self.peer.add_connection(stream).await;
124        executor.spawn(handler.run()).detach();
125
126        let auth_response = self
127            .peer
128            .request(
129                connection_id,
130                proto::Auth {
131                    user_id,
132                    access_token,
133                },
134            )
135            .await
136            .context("rpc auth request failed")?;
137        if !auth_response.credentials_valid {
138            Err(anyhow!("failed to authenticate with RPC server"))?;
139        }
140
141        self.state.write().await.connection_id = Some(connection_id);
142        Ok(())
143    }
144
145    pub fn login(
146        platform: Arc<dyn gpui::Platform>,
147        executor: &Arc<gpui::executor::Background>,
148    ) -> Task<Result<(String, String)>> {
149        let executor = executor.clone();
150        executor.clone().spawn(async move {
151            if let Some((user_id, access_token)) = platform.read_credentials(&ZED_SERVER_URL) {
152                log::info!("already signed in. user_id: {}", user_id);
153                return Ok((user_id, String::from_utf8(access_token).unwrap()));
154            }
155
156            // Generate a pair of asymmetric encryption keys. The public key will be used by the
157            // zed server to encrypt the user's access token, so that it can'be intercepted by
158            // any other app running on the user's device.
159            let (public_key, private_key) =
160                zed_rpc::auth::keypair().expect("failed to generate keypair for auth");
161            let public_key_string =
162                String::try_from(public_key).expect("failed to serialize public key for auth");
163
164            // Start an HTTP server to receive the redirect from Zed's sign-in page.
165            let server = tiny_http::Server::http("127.0.0.1:0").expect("failed to find open port");
166            let port = server.server_addr().port();
167
168            // Open the Zed sign-in page in the user's browser, with query parameters that indicate
169            // that the user is signing in from a Zed app running on the same device.
170            platform.open_url(&format!(
171                "{}/sign_in?native_app_port={}&native_app_public_key={}",
172                *ZED_SERVER_URL, port, public_key_string
173            ));
174
175            // Receive the HTTP request from the user's browser. Retrieve the user id and encrypted
176            // access token from the query params.
177            //
178            // TODO - Avoid ever starting more than one HTTP server. Maybe switch to using a
179            // custom URL scheme instead of this local HTTP server.
180            let (user_id, access_token) = executor
181                .spawn(async move {
182                    if let Some(req) = server.recv_timeout(Duration::from_secs(10 * 60))? {
183                        let path = req.url();
184                        let mut user_id = None;
185                        let mut access_token = None;
186                        let url = Url::parse(&format!("http://example.com{}", path))
187                            .context("failed to parse login notification url")?;
188                        for (key, value) in url.query_pairs() {
189                            if key == "access_token" {
190                                access_token = Some(value.to_string());
191                            } else if key == "user_id" {
192                                user_id = Some(value.to_string());
193                            }
194                        }
195                        req.respond(
196                            tiny_http::Response::from_string(LOGIN_RESPONSE).with_header(
197                                tiny_http::Header::from_bytes("Content-Type", "text/html").unwrap(),
198                            ),
199                        )
200                        .context("failed to respond to login http request")?;
201                        Ok((
202                            user_id.ok_or_else(|| anyhow!("missing user_id parameter"))?,
203                            access_token
204                                .ok_or_else(|| anyhow!("missing access_token parameter"))?,
205                        ))
206                    } else {
207                        Err(anyhow!("didn't receive login redirect"))
208                    }
209                })
210                .await?;
211
212            let access_token = private_key
213                .decrypt_string(&access_token)
214                .context("failed to decrypt access token")?;
215            platform.activate(true);
216            platform.write_credentials(&ZED_SERVER_URL, &user_id, access_token.as_bytes());
217            Ok((user_id.to_string(), access_token))
218        })
219    }
220
221    pub async fn disconnect(&self) -> Result<()> {
222        let conn_id = self.connection_id().await?;
223        self.peer.disconnect(conn_id).await;
224        Ok(())
225    }
226
227    async fn connection_id(&self) -> Result<ConnectionId> {
228        self.state
229            .read()
230            .await
231            .connection_id
232            .ok_or_else(|| anyhow!("not connected"))
233    }
234
235    pub async fn send<T: EnvelopedMessage>(&self, message: T) -> Result<()> {
236        self.peer.send(self.connection_id().await?, message).await
237    }
238
239    pub async fn request<T: RequestMessage>(&self, request: T) -> Result<T::Response> {
240        self.peer
241            .request(self.connection_id().await?, request)
242            .await
243    }
244
245    pub fn respond<T: RequestMessage>(
246        &self,
247        receipt: Receipt<T>,
248        response: T::Response,
249    ) -> impl Future<Output = Result<()>> {
250        self.peer.respond(receipt, response)
251    }
252}
253
254pub trait MessageHandler<'a, M: proto::EnvelopedMessage> {
255    type Output: 'a + Future<Output = anyhow::Result<()>>;
256
257    fn handle(
258        &self,
259        message: TypedEnvelope<M>,
260        rpc: &'a Client,
261        cx: &'a mut gpui::AsyncAppContext,
262    ) -> Self::Output;
263}
264
265impl<'a, M, F, Fut> MessageHandler<'a, M> for F
266where
267    M: proto::EnvelopedMessage,
268    F: Fn(TypedEnvelope<M>, &'a Client, &'a mut gpui::AsyncAppContext) -> Fut,
269    Fut: 'a + Future<Output = anyhow::Result<()>>,
270{
271    type Output = Fut;
272
273    fn handle(
274        &self,
275        message: TypedEnvelope<M>,
276        rpc: &'a Client,
277        cx: &'a mut gpui::AsyncAppContext,
278    ) -> Self::Output {
279        (self)(message, rpc, cx)
280    }
281}
282
/// Scheme and path prefix shared by every worktree URL.
const WORKTREE_URL_PREFIX: &str = "zed://worktrees/";

/// Builds a worktree URL of the form `zed://worktrees/{id}/{access_token}`.
///
/// `access_token` is inserted verbatim (no escaping).
pub fn encode_worktree_url(id: u64, access_token: &str) -> String {
    format!(
        "{prefix}{id}/{token}",
        prefix = WORKTREE_URL_PREFIX,
        id = id,
        token = access_token
    )
}

/// Parses a URL produced by [`encode_worktree_url`] back into `(id, access_token)`.
///
/// Returns `None` if the prefix is missing, the id is not a `u64`, or the token segment is
/// absent or empty. Only the first `/`-separated segment after the id is used as the token;
/// any further segments are ignored.
pub fn decode_worktree_url(url: &str) -> Option<(u64, String)> {
    let mut segments = url.strip_prefix(WORKTREE_URL_PREFIX)?.split('/');
    let id: u64 = segments.next()?.parse().ok()?;
    match segments.next()? {
        "" => None,
        token => Some((id, token.to_owned())),
    }
}
299
/// HTML page sent back to the browser after the sign-in redirect is received; its script asks
/// the browser to close the window.
const LOGIN_RESPONSE: &str = concat!(
    "\n",
    "<!DOCTYPE html>\n",
    "<html>\n",
    "<script>window.close();</script>\n",
    "</html>\n"
);
306
/// Round-trips worktree URLs and checks that malformed URLs are rejected.
#[test]
fn test_encode_and_decode_worktree_url() {
    // Exact encoded form and round trip.
    let url = encode_worktree_url(5, "deadbeef");
    assert_eq!(url, "zed://worktrees/5/deadbeef");
    assert_eq!(decode_worktree_url(&url), Some((5, "deadbeef".to_string())));

    // Largest id survives the round trip.
    let url = encode_worktree_url(u64::MAX, "token");
    assert_eq!(decode_worktree_url(&url), Some((u64::MAX, "token".to_string())));

    // Malformed inputs.
    assert_eq!(decode_worktree_url("not://the-right-format"), None);
    assert_eq!(decode_worktree_url("zed://worktrees/5/"), None);
    assert_eq!(decode_worktree_url("zed://worktrees/5"), None);
    assert_eq!(decode_worktree_url("zed://worktrees/abc/deadbeef"), None);
}