1use super::util::SurfResultExt as _;
2use crate::{language::LanguageRegistry, worktree::Worktree};
3use anyhow::{anyhow, Context, Result};
4use gpui::executor::Background;
5use gpui::{AsyncAppContext, ModelHandle, Task, WeakModelHandle};
6use lazy_static::lazy_static;
7use postage::prelude::Stream;
8use smol::lock::Mutex;
9use std::collections::HashMap;
10use std::time::Duration;
11use std::{convert::TryFrom, future::Future, sync::Arc};
12use surf::Url;
13pub use zed_rpc::{proto, ConnectionId, PeerId, TypedEnvelope};
14use zed_rpc::{
15 proto::{EnvelopedMessage, RequestMessage},
16 rest, Peer, Receipt,
17};
18
19lazy_static! {
20 static ref ZED_SERVER_URL: String =
21 std::env::var("ZED_SERVER_URL").unwrap_or("https://zed.dev".to_string());
22}
23
/// RPC client for the Zed server.
///
/// Cloning is cheap: both fields are `Arc`s, so every clone uses the same
/// peer and observes the same shared state (including the connection id).
#[derive(Clone)]
pub struct Client {
    /// RPC peer through which connections are added and messages are sent.
    peer: Arc<Peer>,
    /// State shared between all clones, guarded by an async mutex.
    pub state: Arc<Mutex<ClientState>>,
}
29
30pub struct ClientState {
31 connection_id: Option<ConnectionId>,
32 pub shared_worktrees: HashMap<u64, WeakModelHandle<Worktree>>,
33 pub languages: Arc<LanguageRegistry>,
34}
35
36impl ClientState {
37 pub fn shared_worktree(
38 &mut self,
39 id: u64,
40 cx: &mut AsyncAppContext,
41 ) -> Result<ModelHandle<Worktree>> {
42 if let Some(worktree) = self.shared_worktrees.get(&id) {
43 if let Some(worktree) = cx.read(|cx| worktree.upgrade(cx)) {
44 Ok(worktree)
45 } else {
46 self.shared_worktrees.remove(&id);
47 Err(anyhow!("worktree {} was dropped", id))
48 }
49 } else {
50 Err(anyhow!("worktree {} does not exist", id))
51 }
52 }
53}
54
55impl Client {
56 pub fn new(languages: Arc<LanguageRegistry>) -> Self {
57 Self {
58 peer: Peer::new(),
59 state: Arc::new(Mutex::new(ClientState {
60 connection_id: None,
61 shared_worktrees: Default::default(),
62 languages,
63 })),
64 }
65 }
66
67 pub fn on_message<H, M>(&self, handler: H, cx: &mut gpui::MutableAppContext)
68 where
69 H: 'static + for<'a> MessageHandler<'a, M>,
70 M: proto::EnvelopedMessage,
71 {
72 let this = self.clone();
73 let mut messages = smol::block_on(this.peer.add_message_handler::<M>());
74 cx.spawn(|mut cx| async move {
75 while let Some(message) = messages.recv().await {
76 if let Err(err) = handler.handle(message, &this, &mut cx).await {
77 log::error!("error handling message: {:?}", err);
78 }
79 }
80 })
81 .detach();
82 }
83
84 pub async fn log_in_and_connect(&self, cx: &AsyncAppContext) -> surf::Result<()> {
85 if self.state.lock().await.connection_id.is_some() {
86 return Ok(());
87 }
88
89 let (user_id, access_token) = Self::login(cx.platform(), &cx.background()).await?;
90
91 let mut response = surf::get(format!(
92 "{}{}",
93 *ZED_SERVER_URL,
94 &rest::GET_RPC_ADDRESS_PATH
95 ))
96 .header(
97 "Authorization",
98 http_auth_basic::Credentials::new(&user_id, &access_token).as_http_header(),
99 )
100 .await
101 .context("rpc address request failed")?;
102
103 let rest::GetRpcAddressResponse { address } = response
104 .body_json()
105 .await
106 .context("failed to parse rpc address response")?;
107
108 self.connect(&address, user_id.parse()?, access_token, &cx.background())
109 .await
110 }
111
112 pub async fn connect(
113 &self,
114 address: &str,
115 user_id: i32,
116 access_token: String,
117 executor: &Arc<Background>,
118 ) -> surf::Result<()> {
119 // TODO - If the `ZED_SERVER_URL` uses https, then wrap this stream in
120 // a TLS stream using `native-tls`.
121 let stream = smol::net::TcpStream::connect(&address).await?;
122 log::info!("connected to rpc address {}", address);
123
124 let (connection_id, handler) = self.peer.add_connection(stream).await;
125 executor.spawn(handler.run()).detach();
126
127 let auth_response = self
128 .peer
129 .request(
130 connection_id,
131 proto::Auth {
132 user_id,
133 access_token,
134 },
135 )
136 .await
137 .context("rpc auth request failed")?;
138 if !auth_response.credentials_valid {
139 Err(anyhow!("failed to authenticate with RPC server"))?;
140 }
141
142 self.state.lock().await.connection_id = Some(connection_id);
143 Ok(())
144 }
145
146 pub fn login(
147 platform: Arc<dyn gpui::Platform>,
148 executor: &Arc<gpui::executor::Background>,
149 ) -> Task<Result<(String, String)>> {
150 let executor = executor.clone();
151 executor.clone().spawn(async move {
152 if let Some((user_id, access_token)) = platform.read_credentials(&ZED_SERVER_URL) {
153 log::info!("already signed in. user_id: {}", user_id);
154 return Ok((user_id, String::from_utf8(access_token).unwrap()));
155 }
156
157 // Generate a pair of asymmetric encryption keys. The public key will be used by the
158 // zed server to encrypt the user's access token, so that it can'be intercepted by
159 // any other app running on the user's device.
160 let (public_key, private_key) =
161 zed_rpc::auth::keypair().expect("failed to generate keypair for auth");
162 let public_key_string =
163 String::try_from(public_key).expect("failed to serialize public key for auth");
164
165 // Start an HTTP server to receive the redirect from Zed's sign-in page.
166 let server = tiny_http::Server::http("127.0.0.1:0").expect("failed to find open port");
167 let port = server.server_addr().port();
168
169 // Open the Zed sign-in page in the user's browser, with query parameters that indicate
170 // that the user is signing in from a Zed app running on the same device.
171 platform.open_url(&format!(
172 "{}/sign_in?native_app_port={}&native_app_public_key={}",
173 *ZED_SERVER_URL, port, public_key_string
174 ));
175
176 // Receive the HTTP request from the user's browser. Retrieve the user id and encrypted
177 // access token from the query params.
178 //
179 // TODO - Avoid ever starting more than one HTTP server. Maybe switch to using a
180 // custom URL scheme instead of this local HTTP server.
181 let (user_id, access_token) = executor
182 .spawn(async move {
183 if let Some(req) = server.recv_timeout(Duration::from_secs(10 * 60))? {
184 let path = req.url();
185 let mut user_id = None;
186 let mut access_token = None;
187 let url = Url::parse(&format!("http://example.com{}", path))
188 .context("failed to parse login notification url")?;
189 for (key, value) in url.query_pairs() {
190 if key == "access_token" {
191 access_token = Some(value.to_string());
192 } else if key == "user_id" {
193 user_id = Some(value.to_string());
194 }
195 }
196 req.respond(
197 tiny_http::Response::from_string(LOGIN_RESPONSE).with_header(
198 tiny_http::Header::from_bytes("Content-Type", "text/html").unwrap(),
199 ),
200 )
201 .context("failed to respond to login http request")?;
202 Ok((
203 user_id.ok_or_else(|| anyhow!("missing user_id parameter"))?,
204 access_token
205 .ok_or_else(|| anyhow!("missing access_token parameter"))?,
206 ))
207 } else {
208 Err(anyhow!("didn't receive login redirect"))
209 }
210 })
211 .await?;
212
213 let access_token = private_key
214 .decrypt_string(&access_token)
215 .context("failed to decrypt access token")?;
216 platform.activate(true);
217 platform.write_credentials(&ZED_SERVER_URL, &user_id, access_token.as_bytes());
218 Ok((user_id.to_string(), access_token))
219 })
220 }
221
222 async fn connection_id(&self) -> Result<ConnectionId> {
223 self.state
224 .lock()
225 .await
226 .connection_id
227 .ok_or_else(|| anyhow!("not connected"))
228 }
229
230 pub async fn send<T: EnvelopedMessage>(&self, message: T) -> Result<()> {
231 self.peer.send(self.connection_id().await?, message).await
232 }
233
234 pub async fn request<T: RequestMessage>(&self, request: T) -> Result<T::Response> {
235 self.peer
236 .request(self.connection_id().await?, request)
237 .await
238 }
239
240 pub fn respond<T: RequestMessage>(
241 &self,
242 receipt: Receipt<T>,
243 response: T::Response,
244 ) -> impl Future<Output = Result<()>> {
245 self.peer.respond(receipt, response)
246 }
247}
248
/// Handles incoming RPC messages of type `M`.
///
/// The future returned by [`MessageHandler::handle`] may borrow the `Client`
/// and the app context for `'a`. Closures of the shape
/// `Fn(TypedEnvelope<M>, &Client, &mut AsyncAppContext) -> impl Future<Output = Result<()>>`
/// implement this trait through the blanket impl below.
pub trait MessageHandler<'a, M: proto::EnvelopedMessage> {
    /// Future produced by a single invocation of the handler.
    type Output: 'a + Future<Output = anyhow::Result<()>>;

    /// Processes one `message`, using `rpc` to reply or send further messages.
    fn handle(
        &self,
        message: TypedEnvelope<M>,
        rpc: &'a Client,
        cx: &'a mut gpui::AsyncAppContext,
    ) -> Self::Output;
}
259
260impl<'a, M, F, Fut> MessageHandler<'a, M> for F
261where
262 M: proto::EnvelopedMessage,
263 F: Fn(TypedEnvelope<M>, &'a Client, &'a mut gpui::AsyncAppContext) -> Fut,
264 Fut: 'a + Future<Output = anyhow::Result<()>>,
265{
266 type Output = Fut;
267
268 fn handle(
269 &self,
270 message: TypedEnvelope<M>,
271 rpc: &'a Client,
272 cx: &'a mut gpui::AsyncAppContext,
273 ) -> Self::Output {
274 (self)(message, rpc, cx)
275 }
276}
277
/// Scheme-and-path prefix shared by every worktree URL.
const WORKTREE_URL_PREFIX: &str = "zed://worktrees/";

/// Builds a `zed://worktrees/<id>/<access_token>` URL for a shared worktree.
pub fn encode_worktree_url(id: u64, access_token: &str) -> String {
    format!("{}{}/{}", WORKTREE_URL_PREFIX, id, access_token)
}

/// Parses a URL produced by [`encode_worktree_url`] back into its id and
/// access token.
///
/// Everything after the first `/` that follows the id is taken as the token.
/// Tokens that themselves contain `/` therefore survive a round trip instead
/// of being silently truncated.
///
/// Returns `None` if the prefix is missing, the id is not a valid `u64`, or
/// the token is absent or empty.
pub fn decode_worktree_url(url: &str) -> Option<(u64, String)> {
    let path = url.strip_prefix(WORKTREE_URL_PREFIX)?;
    let mut parts = path.splitn(2, '/');
    let id = parts.next()?.parse::<u64>().ok()?;
    let access_token = parts.next()?;
    if access_token.is_empty() {
        return None;
    }
    Some((id, access_token.to_string()))
}
294
/// HTML page returned to the browser after the sign-in redirect; its script
/// asks the browser to close the window.
const LOGIN_RESPONSE: &str = "
<!DOCTYPE html>
<html>
<script>window.close();</script>
</html>
";
301
/// Round-trips a worktree URL and checks that malformed URLs are rejected.
#[test]
fn test_encode_and_decode_worktree_url() {
    let url = encode_worktree_url(5, "deadbeef");
    assert_eq!(decode_worktree_url(&url), Some((5, "deadbeef".to_string())));
    assert_eq!(decode_worktree_url("not://the-right-format"), None);

    // Missing, empty, or non-numeric components must not decode.
    assert_eq!(decode_worktree_url("zed://worktrees/5"), None);
    assert_eq!(decode_worktree_url("zed://worktrees/5/"), None);
    assert_eq!(decode_worktree_url("zed://worktrees/abc/deadbeef"), None);
    assert_eq!(decode_worktree_url("zed://worktrees//deadbeef"), None);
}