rpc.rs

  1use crate::worktree::{File, Worktree};
  2
  3use super::util::SurfResultExt as _;
  4use anyhow::{anyhow, Context, Result};
  5use gpui::executor::Background;
  6use gpui::{AsyncAppContext, ModelHandle, Task};
  7use lazy_static::lazy_static;
  8use postage::prelude::Stream;
  9use smol::lock::Mutex;
 10use std::collections::HashMap;
 11use std::time::Duration;
 12use std::{convert::TryFrom, future::Future, sync::Arc};
 13use surf::Url;
 14use zed_rpc::{
 15    proto::{EnvelopedMessage, RequestMessage},
 16    rest, Peer, Receipt, TypedEnvelope,
 17};
 18
 19pub use zed_rpc::{proto, ConnectionId, PeerId};
 20
 21lazy_static! {
 22    static ref ZED_SERVER_URL: String =
 23        std::env::var("ZED_SERVER_URL").unwrap_or("https://zed.dev".to_string());
 24}
 25
/// Handle to the Zed RPC layer.
///
/// Cloning is cheap: every clone shares the same [`Peer`] and [`ClientState`]
/// through `Arc`s.
#[derive(Clone)]
pub struct Client {
    // Underlying RPC peer used to send, request, respond, and dispatch messages.
    peer: Arc<Peer>,
    // Mutable client state, guarded by an async (`smol`) mutex.
    pub state: Arc<Mutex<ClientState>>,
}
 31
/// Mutable state shared by all clones of a [`Client`]; every field starts empty.
#[derive(Default)]
pub struct ClientState {
    // Cached server connection id; `Client::connect_to_server` returns it early when `Some`.
    connection_id: Option<ConnectionId>,
    // Worktrees shared by this client, keyed by a `u64` id
    // (presumably the server-assigned worktree id — confirm against callers).
    pub shared_worktrees: HashMap<u64, ModelHandle<Worktree>>,
    // For each shared file, a `usize` per peer.
    // NOTE(review): likely a per-peer open/reference count — confirm against callers.
    pub shared_files: HashMap<File, HashMap<PeerId, usize>>,
}
 38
 39impl Client {
 40    pub fn new() -> Self {
 41        Self {
 42            peer: Peer::new(),
 43            state: Default::default(),
 44        }
 45    }
 46
 47    pub fn on_message<H, M>(&self, handler: H, cx: &mut gpui::MutableAppContext)
 48    where
 49        H: 'static + for<'a> MessageHandler<'a, M>,
 50        M: proto::EnvelopedMessage,
 51    {
 52        let this = self.clone();
 53        let mut messages = smol::block_on(this.peer.add_message_handler::<M>());
 54        cx.spawn(|mut cx| async move {
 55            while let Some(message) = messages.recv().await {
 56                if let Err(err) = handler.handle(message, &this, &mut cx).await {
 57                    log::error!("error handling message: {:?}", err);
 58                }
 59            }
 60        })
 61        .detach();
 62    }
 63
 64    pub async fn connect_to_server(
 65        &self,
 66        cx: &AsyncAppContext,
 67        executor: &Arc<Background>,
 68    ) -> surf::Result<ConnectionId> {
 69        if let Some(connection_id) = self.state.lock().await.connection_id {
 70            return Ok(connection_id);
 71        }
 72
 73        let (user_id, access_token) = Self::login(cx.platform(), executor).await?;
 74
 75        let mut response = surf::get(format!(
 76            "{}{}",
 77            *ZED_SERVER_URL,
 78            &rest::GET_RPC_ADDRESS_PATH
 79        ))
 80        .header(
 81            "Authorization",
 82            http_auth_basic::Credentials::new(&user_id, &access_token).as_http_header(),
 83        )
 84        .await
 85        .context("rpc address request failed")?;
 86
 87        let rest::GetRpcAddressResponse { address } = response
 88            .body_json()
 89            .await
 90            .context("failed to parse rpc address response")?;
 91
 92        // TODO - If the `ZED_SERVER_URL` uses https, then wrap this stream in
 93        // a TLS stream using `native-tls`.
 94        let stream = smol::net::TcpStream::connect(&address).await?;
 95        log::info!("connected to rpc address {}", address);
 96
 97        let connection_id = self.peer.add_connection(stream).await;
 98        executor
 99            .spawn(self.peer.handle_messages(connection_id))
100            .detach();
101
102        let auth_response = self
103            .peer
104            .request(
105                connection_id,
106                proto::Auth {
107                    user_id: user_id.parse()?,
108                    access_token,
109                },
110            )
111            .await
112            .context("rpc auth request failed")?;
113        if !auth_response.credentials_valid {
114            Err(anyhow!("failed to authenticate with RPC server"))?;
115        }
116
117        Ok(connection_id)
118    }
119
120    pub fn login(
121        platform: Arc<dyn gpui::Platform>,
122        executor: &Arc<gpui::executor::Background>,
123    ) -> Task<Result<(String, String)>> {
124        let executor = executor.clone();
125        executor.clone().spawn(async move {
126            if let Some((user_id, access_token)) = platform.read_credentials(&ZED_SERVER_URL) {
127                log::info!("already signed in. user_id: {}", user_id);
128                return Ok((user_id, String::from_utf8(access_token).unwrap()));
129            }
130
131            // Generate a pair of asymmetric encryption keys. The public key will be used by the
132            // zed server to encrypt the user's access token, so that it can'be intercepted by
133            // any other app running on the user's device.
134            let (public_key, private_key) =
135                zed_rpc::auth::keypair().expect("failed to generate keypair for auth");
136            let public_key_string =
137                String::try_from(public_key).expect("failed to serialize public key for auth");
138
139            // Start an HTTP server to receive the redirect from Zed's sign-in page.
140            let server = tiny_http::Server::http("127.0.0.1:0").expect("failed to find open port");
141            let port = server.server_addr().port();
142
143            // Open the Zed sign-in page in the user's browser, with query parameters that indicate
144            // that the user is signing in from a Zed app running on the same device.
145            platform.open_url(&format!(
146                "{}/sign_in?native_app_port={}&native_app_public_key={}",
147                *ZED_SERVER_URL, port, public_key_string
148            ));
149
150            // Receive the HTTP request from the user's browser. Retrieve the user id and encrypted
151            // access token from the query params.
152            //
153            // TODO - Avoid ever starting more than one HTTP server. Maybe switch to using a
154            // custom URL scheme instead of this local HTTP server.
155            let (user_id, access_token) = executor
156                .spawn(async move {
157                    if let Some(req) = server.recv_timeout(Duration::from_secs(10 * 60))? {
158                        let path = req.url();
159                        let mut user_id = None;
160                        let mut access_token = None;
161                        let url = Url::parse(&format!("http://example.com{}", path))
162                            .context("failed to parse login notification url")?;
163                        for (key, value) in url.query_pairs() {
164                            if key == "access_token" {
165                                access_token = Some(value.to_string());
166                            } else if key == "user_id" {
167                                user_id = Some(value.to_string());
168                            }
169                        }
170                        req.respond(
171                            tiny_http::Response::from_string(LOGIN_RESPONSE).with_header(
172                                tiny_http::Header::from_bytes("Content-Type", "text/html").unwrap(),
173                            ),
174                        )
175                        .context("failed to respond to login http request")?;
176                        Ok((
177                            user_id.ok_or_else(|| anyhow!("missing user_id parameter"))?,
178                            access_token
179                                .ok_or_else(|| anyhow!("missing access_token parameter"))?,
180                        ))
181                    } else {
182                        Err(anyhow!("didn't receive login redirect"))
183                    }
184                })
185                .await?;
186
187            let access_token = private_key
188                .decrypt_string(&access_token)
189                .context("failed to decrypt access token")?;
190            platform.activate(true);
191            platform.write_credentials(&ZED_SERVER_URL, &user_id, access_token.as_bytes());
192            Ok((user_id.to_string(), access_token))
193        })
194    }
195
196    pub fn send<T: EnvelopedMessage>(
197        &self,
198        connection_id: ConnectionId,
199        message: T,
200    ) -> impl Future<Output = Result<()>> {
201        self.peer.send(connection_id, message)
202    }
203
204    pub fn request<T: RequestMessage>(
205        &self,
206        connection_id: ConnectionId,
207        request: T,
208    ) -> impl Future<Output = Result<T::Response>> {
209        self.peer.request(connection_id, request)
210    }
211
212    pub fn respond<T: RequestMessage>(
213        &self,
214        receipt: Receipt<T>,
215        response: T::Response,
216    ) -> impl Future<Output = Result<()>> {
217        self.peer.respond(receipt, response)
218    }
219}
220
221pub trait MessageHandler<'a, M: proto::EnvelopedMessage> {
222    type Output: 'a + Future<Output = anyhow::Result<()>>;
223
224    fn handle(
225        &self,
226        message: TypedEnvelope<M>,
227        rpc: &'a Client,
228        cx: &'a mut gpui::AsyncAppContext,
229    ) -> Self::Output;
230}
231
232impl<'a, M, F, Fut> MessageHandler<'a, M> for F
233where
234    M: proto::EnvelopedMessage,
235    F: Fn(TypedEnvelope<M>, &'a Client, &'a mut gpui::AsyncAppContext) -> Fut,
236    Fut: 'a + Future<Output = anyhow::Result<()>>,
237{
238    type Output = Fut;
239
240    fn handle(
241        &self,
242        message: TypedEnvelope<M>,
243        rpc: &'a Client,
244        cx: &'a mut gpui::AsyncAppContext,
245    ) -> Self::Output {
246        (self)(message, rpc, cx)
247    }
248}
249
/// Prefix shared by every worktree URL; the id and access token follow it.
const WORKTREE_URL_PREFIX: &str = "zed://worktrees/";

/// Builds a worktree URL of the form `zed://worktrees/{id}/{access_token}`.
pub fn encode_worktree_url(id: u64, access_token: &str) -> String {
    format!("{}{}/{}", WORKTREE_URL_PREFIX, id, access_token)
}

/// Parses a URL produced by [`encode_worktree_url`] back into `(id, access_token)`.
///
/// Returns `None` if:
///
/// - the URL does not start with [`WORKTREE_URL_PREFIX`];
/// - the id segment is not a valid `u64`;
/// - the access-token segment is missing or empty.
///
/// Any path segments after the access token are ignored.
pub fn decode_worktree_url(url: &str) -> Option<(u64, String)> {
    let path = url.strip_prefix(WORKTREE_URL_PREFIX)?;
    let mut parts = path.split('/');
    let id = parts.next()?.parse::<u64>().ok()?;
    let access_token = parts.next().filter(|token| !token.is_empty())?;
    Some((id, access_token.to_string()))
}
266
/// HTML returned to the browser once the sign-in redirect has been received.
///
/// Its script calls `window.close()`.
const LOGIN_RESPONSE: &str = "
<!DOCTYPE html>
<html>
<script>window.close();</script>
</html>
";
273
/// Round-trips a worktree URL and checks that malformed URLs are rejected.
#[test]
fn test_encode_and_decode_worktree_url() {
    let url = encode_worktree_url(5, "deadbeef");
    assert_eq!(decode_worktree_url(&url), Some((5, "deadbeef".to_string())));
    assert_eq!(decode_worktree_url("not://the-right-format"), None);

    // Rejection paths: empty token, missing token, non-numeric id, bare prefix.
    assert_eq!(decode_worktree_url("zed://worktrees/5/"), None);
    assert_eq!(decode_worktree_url("zed://worktrees/5"), None);
    assert_eq!(decode_worktree_url("zed://worktrees/abc/deadbeef"), None);
    assert_eq!(decode_worktree_url("zed://worktrees/"), None);

    // Segments after the access token are ignored.
    assert_eq!(
        decode_worktree_url("zed://worktrees/7/token/extra"),
        Some((7, "token".to_string()))
    );
}
279}