1use super::util::SurfResultExt as _;
2use crate::{language::LanguageRegistry, worktree::Worktree};
3use anyhow::{anyhow, Context, Result};
4use gpui::executor::Background;
5use gpui::{AsyncAppContext, ModelHandle, Task, WeakModelHandle};
6use lazy_static::lazy_static;
7use postage::prelude::Stream;
8use smol::lock::RwLock;
9use std::collections::HashMap;
10use std::time::Duration;
11use std::{convert::TryFrom, future::Future, sync::Arc};
12use surf::Url;
13pub use zed_rpc::{proto, ConnectionId, PeerId, TypedEnvelope};
14use zed_rpc::{
15 proto::{EnvelopedMessage, RequestMessage},
16 rest, Peer, Receipt,
17};
18
19lazy_static! {
20 static ref ZED_SERVER_URL: String =
21 std::env::var("ZED_SERVER_URL").unwrap_or("https://zed.dev".to_string());
22}
23
/// Client for the Zed RPC server.
///
/// Cloning is cheap: every clone shares the same `Peer` and `ClientState`
/// through `Arc`s.
#[derive(Clone)]
pub struct Client {
    /// Peer used to open connections and to send, request, and respond to
    /// RPC messages.
    peer: Arc<Peer>,
    /// Connection and sharing state, guarded by an async `smol::lock::RwLock`
    /// (acquired with `.await`).
    pub state: Arc<RwLock<ClientState>>,
}
29
/// Mutable state shared by all clones of a [`Client`].
pub struct ClientState {
    /// Id of the active RPC connection; `None` until `Client::connect` has
    /// completed authentication.
    connection_id: Option<ConnectionId>,
    /// Shared worktrees keyed by id. Handles are weak, so an entry may outlive
    /// its worktree; `shared_worktree` reports that case as an error.
    pub shared_worktrees: HashMap<u64, WeakModelHandle<Worktree>>,
    /// Language registry supplied to `Client::new`.
    pub languages: Arc<LanguageRegistry>,
}
35
36impl ClientState {
37 pub fn shared_worktree(
38 &self,
39 id: u64,
40 cx: &mut AsyncAppContext,
41 ) -> Result<ModelHandle<Worktree>> {
42 if let Some(worktree) = self.shared_worktrees.get(&id) {
43 if let Some(worktree) = cx.read(|cx| worktree.upgrade(cx)) {
44 Ok(worktree)
45 } else {
46 Err(anyhow!("worktree {} was dropped", id))
47 }
48 } else {
49 Err(anyhow!("worktree {} does not exist", id))
50 }
51 }
52}
53
54impl Client {
55 pub fn new(languages: Arc<LanguageRegistry>) -> Self {
56 Self {
57 peer: Peer::new(),
58 state: Arc::new(RwLock::new(ClientState {
59 connection_id: None,
60 shared_worktrees: Default::default(),
61 languages,
62 })),
63 }
64 }
65
66 pub fn on_message<H, M>(&self, handler: H, cx: &mut gpui::MutableAppContext)
67 where
68 H: 'static + for<'a> MessageHandler<'a, M>,
69 M: proto::EnvelopedMessage,
70 {
71 let this = self.clone();
72 let mut messages = smol::block_on(this.peer.add_message_handler::<M>());
73 cx.spawn(|mut cx| async move {
74 while let Some(message) = messages.recv().await {
75 if let Err(err) = handler.handle(message, &this, &mut cx).await {
76 log::error!("error handling message: {:?}", err);
77 }
78 }
79 })
80 .detach();
81 }
82
83 pub async fn log_in_and_connect(&self, cx: &AsyncAppContext) -> surf::Result<()> {
84 if self.state.read().await.connection_id.is_some() {
85 return Ok(());
86 }
87
88 let (user_id, access_token) = Self::login(cx.platform(), &cx.background()).await?;
89
90 let mut response = surf::get(format!(
91 "{}{}",
92 *ZED_SERVER_URL,
93 &rest::GET_RPC_ADDRESS_PATH
94 ))
95 .header(
96 "Authorization",
97 http_auth_basic::Credentials::new(&user_id, &access_token).as_http_header(),
98 )
99 .await
100 .context("rpc address request failed")?;
101
102 let rest::GetRpcAddressResponse { address } = response
103 .body_json()
104 .await
105 .context("failed to parse rpc address response")?;
106
107 self.connect(&address, user_id.parse()?, access_token, &cx.background())
108 .await
109 }
110
111 pub async fn connect(
112 &self,
113 address: &str,
114 user_id: i32,
115 access_token: String,
116 executor: &Arc<Background>,
117 ) -> surf::Result<()> {
118 // TODO - If the `ZED_SERVER_URL` uses https, then wrap this stream in
119 // a TLS stream using `native-tls`.
120 let stream = smol::net::TcpStream::connect(&address).await?;
121 log::info!("connected to rpc address {}", address);
122
123 let (connection_id, handler) = self.peer.add_connection(stream).await;
124 executor.spawn(handler.run()).detach();
125
126 let auth_response = self
127 .peer
128 .request(
129 connection_id,
130 proto::Auth {
131 user_id,
132 access_token,
133 },
134 )
135 .await
136 .context("rpc auth request failed")?;
137 if !auth_response.credentials_valid {
138 Err(anyhow!("failed to authenticate with RPC server"))?;
139 }
140
141 self.state.write().await.connection_id = Some(connection_id);
142 Ok(())
143 }
144
145 pub fn login(
146 platform: Arc<dyn gpui::Platform>,
147 executor: &Arc<gpui::executor::Background>,
148 ) -> Task<Result<(String, String)>> {
149 let executor = executor.clone();
150 executor.clone().spawn(async move {
151 if let Some((user_id, access_token)) = platform.read_credentials(&ZED_SERVER_URL) {
152 log::info!("already signed in. user_id: {}", user_id);
153 return Ok((user_id, String::from_utf8(access_token).unwrap()));
154 }
155
156 // Generate a pair of asymmetric encryption keys. The public key will be used by the
157 // zed server to encrypt the user's access token, so that it can'be intercepted by
158 // any other app running on the user's device.
159 let (public_key, private_key) =
160 zed_rpc::auth::keypair().expect("failed to generate keypair for auth");
161 let public_key_string =
162 String::try_from(public_key).expect("failed to serialize public key for auth");
163
164 // Start an HTTP server to receive the redirect from Zed's sign-in page.
165 let server = tiny_http::Server::http("127.0.0.1:0").expect("failed to find open port");
166 let port = server.server_addr().port();
167
168 // Open the Zed sign-in page in the user's browser, with query parameters that indicate
169 // that the user is signing in from a Zed app running on the same device.
170 platform.open_url(&format!(
171 "{}/sign_in?native_app_port={}&native_app_public_key={}",
172 *ZED_SERVER_URL, port, public_key_string
173 ));
174
175 // Receive the HTTP request from the user's browser. Retrieve the user id and encrypted
176 // access token from the query params.
177 //
178 // TODO - Avoid ever starting more than one HTTP server. Maybe switch to using a
179 // custom URL scheme instead of this local HTTP server.
180 let (user_id, access_token) = executor
181 .spawn(async move {
182 if let Some(req) = server.recv_timeout(Duration::from_secs(10 * 60))? {
183 let path = req.url();
184 let mut user_id = None;
185 let mut access_token = None;
186 let url = Url::parse(&format!("http://example.com{}", path))
187 .context("failed to parse login notification url")?;
188 for (key, value) in url.query_pairs() {
189 if key == "access_token" {
190 access_token = Some(value.to_string());
191 } else if key == "user_id" {
192 user_id = Some(value.to_string());
193 }
194 }
195 req.respond(
196 tiny_http::Response::from_string(LOGIN_RESPONSE).with_header(
197 tiny_http::Header::from_bytes("Content-Type", "text/html").unwrap(),
198 ),
199 )
200 .context("failed to respond to login http request")?;
201 Ok((
202 user_id.ok_or_else(|| anyhow!("missing user_id parameter"))?,
203 access_token
204 .ok_or_else(|| anyhow!("missing access_token parameter"))?,
205 ))
206 } else {
207 Err(anyhow!("didn't receive login redirect"))
208 }
209 })
210 .await?;
211
212 let access_token = private_key
213 .decrypt_string(&access_token)
214 .context("failed to decrypt access token")?;
215 platform.activate(true);
216 platform.write_credentials(&ZED_SERVER_URL, &user_id, access_token.as_bytes());
217 Ok((user_id.to_string(), access_token))
218 })
219 }
220
221 pub async fn disconnect(&self) -> Result<()> {
222 let conn_id = self.connection_id().await?;
223 self.peer.disconnect(conn_id).await;
224 Ok(())
225 }
226
227 async fn connection_id(&self) -> Result<ConnectionId> {
228 self.state
229 .read()
230 .await
231 .connection_id
232 .ok_or_else(|| anyhow!("not connected"))
233 }
234
235 pub async fn send<T: EnvelopedMessage>(&self, message: T) -> Result<()> {
236 self.peer.send(self.connection_id().await?, message).await
237 }
238
239 pub async fn request<T: RequestMessage>(&self, request: T) -> Result<T::Response> {
240 self.peer
241 .request(self.connection_id().await?, request)
242 .await
243 }
244
245 pub fn respond<T: RequestMessage>(
246 &self,
247 receipt: Receipt<T>,
248 response: T::Response,
249 ) -> impl Future<Output = Result<()>> {
250 self.peer.respond(receipt, response)
251 }
252}
253
254pub trait MessageHandler<'a, M: proto::EnvelopedMessage> {
255 type Output: 'a + Future<Output = anyhow::Result<()>>;
256
257 fn handle(
258 &self,
259 message: TypedEnvelope<M>,
260 rpc: &'a Client,
261 cx: &'a mut gpui::AsyncAppContext,
262 ) -> Self::Output;
263}
264
265impl<'a, M, F, Fut> MessageHandler<'a, M> for F
266where
267 M: proto::EnvelopedMessage,
268 F: Fn(TypedEnvelope<M>, &'a Client, &'a mut gpui::AsyncAppContext) -> Fut,
269 Fut: 'a + Future<Output = anyhow::Result<()>>,
270{
271 type Output = Fut;
272
273 fn handle(
274 &self,
275 message: TypedEnvelope<M>,
276 rpc: &'a Client,
277 cx: &'a mut gpui::AsyncAppContext,
278 ) -> Self::Output {
279 (self)(message, rpc, cx)
280 }
281}
282
/// Scheme and path prefix shared by [`encode_worktree_url`] and
/// [`decode_worktree_url`].
const WORKTREE_URL_PREFIX: &str = "zed://worktrees/";

/// Builds a worktree URL of the form `zed://worktrees/{id}/{access_token}`.
pub fn encode_worktree_url(id: u64, access_token: &str) -> String {
    format!("{}{}/{}", WORKTREE_URL_PREFIX, id, access_token)
}

/// Parses a URL produced by [`encode_worktree_url`] back into
/// `(id, access_token)`.
///
/// Everything after the first `/` that follows the id is the access token.
/// This lets tokens that themselves contain `/` round-trip through
/// [`encode_worktree_url`].
///
/// Returns `None` when:
/// - the prefix is missing,
/// - the id is not a valid `u64`, or
/// - the token is missing or empty.
pub fn decode_worktree_url(url: &str) -> Option<(u64, String)> {
    let path = url.strip_prefix(WORKTREE_URL_PREFIX)?;
    // Split only once: `/` inside the token must not truncate it.
    let mut parts = path.splitn(2, '/');
    let id = parts.next()?.parse::<u64>().ok()?;
    let access_token = parts.next()?;
    if access_token.is_empty() {
        return None;
    }
    Some((id, access_token.to_string()))
}
299
/// HTML served to the browser after the sign-in redirect; it closes the tab.
const LOGIN_RESPONSE: &str =
    "\n<!DOCTYPE html>\n<html>\n<script>window.close();</script>\n</html>\n";
306
/// Round-trips a worktree URL and checks that malformed URLs are rejected.
#[test]
fn test_encode_and_decode_worktree_url() {
    let url = encode_worktree_url(5, "deadbeef");
    assert_eq!(url, "zed://worktrees/5/deadbeef");
    assert_eq!(decode_worktree_url(&url), Some((5, "deadbeef".to_string())));
    assert_eq!(decode_worktree_url("not://the-right-format"), None);
    // Malformed worktree URLs are rejected rather than partially parsed.
    assert_eq!(decode_worktree_url("zed://worktrees/5"), None);
    assert_eq!(decode_worktree_url("zed://worktrees/5/"), None);
    assert_eq!(decode_worktree_url("zed://worktrees/five/deadbeef"), None);
}