1use super::util::SurfResultExt as _;
2use crate::{language::LanguageRegistry, worktree::Worktree};
3use anyhow::{anyhow, Context, Result};
4use gpui::executor::Background;
5use gpui::{AsyncAppContext, ModelHandle, Task, WeakModelHandle};
6use lazy_static::lazy_static;
7use postage::prelude::Stream;
8use smol::lock::Mutex;
9use std::collections::HashMap;
10use std::time::Duration;
11use std::{convert::TryFrom, future::Future, sync::Arc};
12use surf::Url;
13pub use zed_rpc::{proto, ConnectionId, PeerId, TypedEnvelope};
14use zed_rpc::{
15 proto::{EnvelopedMessage, RequestMessage},
16 rest, Peer, Receipt,
17};
18
19lazy_static! {
20 static ref ZED_SERVER_URL: String =
21 std::env::var("ZED_SERVER_URL").unwrap_or("https://zed.dev".to_string());
22}
23
/// Client for the Zed RPC server.
///
/// Both fields are `Arc`s, so a derived clone shares the same peer and the same state.
#[derive(Clone)]
pub struct Client {
    /// RPC peer used to open connections and send/receive messages.
    peer: Arc<Peer>,
    /// Connection id, shared worktrees and language registry, guarded by an async mutex.
    pub state: Arc<Mutex<ClientState>>,
}
29
/// Mutable state shared by every clone of a [`Client`].
pub struct ClientState {
    /// Id of the authenticated connection; `None` until `Client::connect` succeeds.
    connection_id: Option<ConnectionId>,
    /// Weak handles to worktrees, keyed by worktree id (see `ClientState::shared_worktree`).
    pub shared_worktrees: HashMap<u64, WeakModelHandle<Worktree>>,
    /// Language registry passed to `Client::new`.
    pub languages: Arc<LanguageRegistry>,
}
35
36impl ClientState {
37 pub fn shared_worktree(
38 &mut self,
39 id: u64,
40 cx: &mut AsyncAppContext,
41 ) -> Result<ModelHandle<Worktree>> {
42 if let Some(worktree) = self.shared_worktrees.get(&id) {
43 if let Some(worktree) = cx.read(|cx| worktree.upgrade(cx)) {
44 Ok(worktree)
45 } else {
46 self.shared_worktrees.remove(&id);
47 Err(anyhow!("worktree {} was dropped", id))
48 }
49 } else {
50 Err(anyhow!("worktree {} does not exist", id))
51 }
52 }
53}
54
55impl Client {
56 pub fn new(languages: Arc<LanguageRegistry>) -> Self {
57 Self {
58 peer: Peer::new(),
59 state: Arc::new(Mutex::new(ClientState {
60 connection_id: None,
61 shared_worktrees: Default::default(),
62 languages,
63 })),
64 }
65 }
66
67 pub fn on_message<H, M>(&self, handler: H, cx: &mut gpui::MutableAppContext)
68 where
69 H: 'static + for<'a> MessageHandler<'a, M>,
70 M: proto::EnvelopedMessage,
71 {
72 let this = self.clone();
73 let mut messages = smol::block_on(this.peer.add_message_handler::<M>());
74 cx.spawn(|mut cx| async move {
75 while let Some(message) = messages.recv().await {
76 if let Err(err) = handler.handle(message, &this, &mut cx).await {
77 log::error!("error handling message: {:?}", err);
78 }
79 }
80 })
81 .detach();
82 }
83
84 pub async fn log_in_and_connect(&self, cx: &AsyncAppContext) -> surf::Result<()> {
85 if self.state.lock().await.connection_id.is_some() {
86 return Ok(());
87 }
88
89 let (user_id, access_token) = Self::login(cx.platform(), &cx.background()).await?;
90
91 let mut response = surf::get(format!(
92 "{}{}",
93 *ZED_SERVER_URL,
94 &rest::GET_RPC_ADDRESS_PATH
95 ))
96 .header(
97 "Authorization",
98 http_auth_basic::Credentials::new(&user_id, &access_token).as_http_header(),
99 )
100 .await
101 .context("rpc address request failed")?;
102
103 let rest::GetRpcAddressResponse { address } = response
104 .body_json()
105 .await
106 .context("failed to parse rpc address response")?;
107
108 self.connect(&address, user_id.parse()?, access_token, &cx.background())
109 .await
110 }
111
112 pub async fn connect(
113 &self,
114 address: &str,
115 user_id: i32,
116 access_token: String,
117 executor: &Arc<Background>,
118 ) -> surf::Result<()> {
119 // TODO - If the `ZED_SERVER_URL` uses https, then wrap this stream in
120 // a TLS stream using `native-tls`.
121 let stream = smol::net::TcpStream::connect(&address).await?;
122 log::info!("connected to rpc address {}", address);
123
124 let connection_id = self.peer.add_connection(stream).await;
125 executor
126 .spawn(self.peer.handle_messages(connection_id))
127 .detach();
128
129 let auth_response = self
130 .peer
131 .request(
132 connection_id,
133 proto::Auth {
134 user_id,
135 access_token,
136 },
137 )
138 .await
139 .context("rpc auth request failed")?;
140 if !auth_response.credentials_valid {
141 Err(anyhow!("failed to authenticate with RPC server"))?;
142 }
143
144 self.state.lock().await.connection_id = Some(connection_id);
145 Ok(())
146 }
147
148 pub fn login(
149 platform: Arc<dyn gpui::Platform>,
150 executor: &Arc<gpui::executor::Background>,
151 ) -> Task<Result<(String, String)>> {
152 let executor = executor.clone();
153 executor.clone().spawn(async move {
154 if let Some((user_id, access_token)) = platform.read_credentials(&ZED_SERVER_URL) {
155 log::info!("already signed in. user_id: {}", user_id);
156 return Ok((user_id, String::from_utf8(access_token).unwrap()));
157 }
158
159 // Generate a pair of asymmetric encryption keys. The public key will be used by the
160 // zed server to encrypt the user's access token, so that it can'be intercepted by
161 // any other app running on the user's device.
162 let (public_key, private_key) =
163 zed_rpc::auth::keypair().expect("failed to generate keypair for auth");
164 let public_key_string =
165 String::try_from(public_key).expect("failed to serialize public key for auth");
166
167 // Start an HTTP server to receive the redirect from Zed's sign-in page.
168 let server = tiny_http::Server::http("127.0.0.1:0").expect("failed to find open port");
169 let port = server.server_addr().port();
170
171 // Open the Zed sign-in page in the user's browser, with query parameters that indicate
172 // that the user is signing in from a Zed app running on the same device.
173 platform.open_url(&format!(
174 "{}/sign_in?native_app_port={}&native_app_public_key={}",
175 *ZED_SERVER_URL, port, public_key_string
176 ));
177
178 // Receive the HTTP request from the user's browser. Retrieve the user id and encrypted
179 // access token from the query params.
180 //
181 // TODO - Avoid ever starting more than one HTTP server. Maybe switch to using a
182 // custom URL scheme instead of this local HTTP server.
183 let (user_id, access_token) = executor
184 .spawn(async move {
185 if let Some(req) = server.recv_timeout(Duration::from_secs(10 * 60))? {
186 let path = req.url();
187 let mut user_id = None;
188 let mut access_token = None;
189 let url = Url::parse(&format!("http://example.com{}", path))
190 .context("failed to parse login notification url")?;
191 for (key, value) in url.query_pairs() {
192 if key == "access_token" {
193 access_token = Some(value.to_string());
194 } else if key == "user_id" {
195 user_id = Some(value.to_string());
196 }
197 }
198 req.respond(
199 tiny_http::Response::from_string(LOGIN_RESPONSE).with_header(
200 tiny_http::Header::from_bytes("Content-Type", "text/html").unwrap(),
201 ),
202 )
203 .context("failed to respond to login http request")?;
204 Ok((
205 user_id.ok_or_else(|| anyhow!("missing user_id parameter"))?,
206 access_token
207 .ok_or_else(|| anyhow!("missing access_token parameter"))?,
208 ))
209 } else {
210 Err(anyhow!("didn't receive login redirect"))
211 }
212 })
213 .await?;
214
215 let access_token = private_key
216 .decrypt_string(&access_token)
217 .context("failed to decrypt access token")?;
218 platform.activate(true);
219 platform.write_credentials(&ZED_SERVER_URL, &user_id, access_token.as_bytes());
220 Ok((user_id.to_string(), access_token))
221 })
222 }
223
224 async fn connection_id(&self) -> Result<ConnectionId> {
225 self.state
226 .lock()
227 .await
228 .connection_id
229 .ok_or_else(|| anyhow!("not connected"))
230 }
231
232 pub async fn send<T: EnvelopedMessage>(&self, message: T) -> Result<()> {
233 self.peer.send(self.connection_id().await?, message).await
234 }
235
236 pub async fn request<T: RequestMessage>(&self, request: T) -> Result<T::Response> {
237 self.peer
238 .request(self.connection_id().await?, request)
239 .await
240 }
241
242 pub fn respond<T: RequestMessage>(
243 &self,
244 receipt: Receipt<T>,
245 response: T::Response,
246 ) -> impl Future<Output = Result<()>> {
247 self.peer.respond(receipt, response)
248 }
249}
250
251pub trait MessageHandler<'a, M: proto::EnvelopedMessage> {
252 type Output: 'a + Future<Output = anyhow::Result<()>>;
253
254 fn handle(
255 &self,
256 message: TypedEnvelope<M>,
257 rpc: &'a Client,
258 cx: &'a mut gpui::AsyncAppContext,
259 ) -> Self::Output;
260}
261
262impl<'a, M, F, Fut> MessageHandler<'a, M> for F
263where
264 M: proto::EnvelopedMessage,
265 F: Fn(TypedEnvelope<M>, &'a Client, &'a mut gpui::AsyncAppContext) -> Fut,
266 Fut: 'a + Future<Output = anyhow::Result<()>>,
267{
268 type Output = Fut;
269
270 fn handle(
271 &self,
272 message: TypedEnvelope<M>,
273 rpc: &'a Client,
274 cx: &'a mut gpui::AsyncAppContext,
275 ) -> Self::Output {
276 (self)(message, rpc, cx)
277 }
278}
279
/// Scheme and path prefix shared by every worktree url.
const WORKTREE_URL_PREFIX: &str = "zed://worktrees/";

/// Builds a url of the form `zed://worktrees/{id}/{access_token}` for a shared worktree.
pub fn encode_worktree_url(id: u64, access_token: &str) -> String {
    format!("{}{}/{}", WORKTREE_URL_PREFIX, id, access_token)
}

/// Parses a url produced by [`encode_worktree_url`] back into `(id, access_token)`.
///
/// Returns `None` if:
/// - `url` does not start with `zed://worktrees/`;
/// - the id segment is not a valid `u64`;
/// - the access-token segment is missing or empty.
///
/// Path segments after the access token are ignored.
pub fn decode_worktree_url(url: &str) -> Option<(u64, String)> {
    let path = url.strip_prefix(WORKTREE_URL_PREFIX)?;
    let mut parts = path.split('/');
    let id = parts.next()?.parse::<u64>().ok()?;
    let access_token = parts.next().filter(|token| !token.is_empty())?;
    Some((id, access_token.to_string()))
}
296
/// HTML returned to the browser after the sign-in redirect; its script asks the browser to
/// close the window.
const LOGIN_RESPONSE: &str = "
<!DOCTYPE html>
<html>
<script>window.close();</script>
</html>
";
303
#[test]
fn test_encode_and_decode_worktree_url() {
    // Round trip, and pin the exact encoded format.
    let url = encode_worktree_url(5, "deadbeef");
    assert_eq!(url, "zed://worktrees/5/deadbeef");
    assert_eq!(decode_worktree_url(&url), Some((5, "deadbeef".to_string())));

    // Malformed urls are rejected rather than partially decoded.
    assert_eq!(decode_worktree_url("not://the-right-format"), None);
    assert_eq!(decode_worktree_url("zed://worktrees/5"), None);
    assert_eq!(decode_worktree_url("zed://worktrees/5/"), None);
    assert_eq!(decode_worktree_url("zed://worktrees/abc/deadbeef"), None);
}