1use crate::worktree::{File, Worktree};
2
3use super::util::SurfResultExt as _;
4use anyhow::{anyhow, Context, Result};
5use gpui::executor::Background;
6use gpui::{AsyncAppContext, ModelHandle, Task};
7use lazy_static::lazy_static;
8use postage::prelude::Stream;
9use smol::lock::Mutex;
10use std::collections::HashMap;
11use std::time::Duration;
12use std::{convert::TryFrom, future::Future, sync::Arc};
13use surf::Url;
14use zed_rpc::{
15 proto::{EnvelopedMessage, RequestMessage},
16 rest, Peer, Receipt, TypedEnvelope,
17};
18
19pub use zed_rpc::{proto, ConnectionId, PeerId};
20
21lazy_static! {
22 static ref ZED_SERVER_URL: String =
23 std::env::var("ZED_SERVER_URL").unwrap_or("https://zed.dev".to_string());
24}
25
26#[derive(Clone)]
27pub struct Client {
28 peer: Arc<Peer>,
29 pub state: Arc<Mutex<ClientState>>,
30}
31
32#[derive(Default)]
33pub struct ClientState {
34 connection_id: Option<ConnectionId>,
35 pub shared_worktrees: HashMap<u64, ModelHandle<Worktree>>,
36 pub shared_files: HashMap<File, HashMap<PeerId, usize>>,
37}
38
39impl Client {
40 pub fn new() -> Self {
41 Self {
42 peer: Peer::new(),
43 state: Default::default(),
44 }
45 }
46
47 pub fn on_message<H, M>(&self, handler: H, cx: &mut gpui::MutableAppContext)
48 where
49 H: 'static + for<'a> MessageHandler<'a, M>,
50 M: proto::EnvelopedMessage,
51 {
52 let this = self.clone();
53 let mut messages = smol::block_on(this.peer.add_message_handler::<M>());
54 cx.spawn(|mut cx| async move {
55 while let Some(message) = messages.recv().await {
56 if let Err(err) = handler.handle(message, &this, &mut cx).await {
57 log::error!("error handling message: {:?}", err);
58 }
59 }
60 })
61 .detach();
62 }
63
64 pub async fn connect_to_server(
65 &self,
66 cx: &AsyncAppContext,
67 executor: &Arc<Background>,
68 ) -> surf::Result<ConnectionId> {
69 if let Some(connection_id) = self.state.lock().await.connection_id {
70 return Ok(connection_id);
71 }
72
73 let (user_id, access_token) = Self::login(cx.platform(), executor).await?;
74
75 let mut response = surf::get(format!(
76 "{}{}",
77 *ZED_SERVER_URL,
78 &rest::GET_RPC_ADDRESS_PATH
79 ))
80 .header(
81 "Authorization",
82 http_auth_basic::Credentials::new(&user_id, &access_token).as_http_header(),
83 )
84 .await
85 .context("rpc address request failed")?;
86
87 let rest::GetRpcAddressResponse { address } = response
88 .body_json()
89 .await
90 .context("failed to parse rpc address response")?;
91
92 // TODO - If the `ZED_SERVER_URL` uses https, then wrap this stream in
93 // a TLS stream using `native-tls`.
94 let stream = smol::net::TcpStream::connect(&address).await?;
95 log::info!("connected to rpc address {}", address);
96
97 let connection_id = self.peer.add_connection(stream).await;
98 executor
99 .spawn(self.peer.handle_messages(connection_id))
100 .detach();
101
102 let auth_response = self
103 .peer
104 .request(
105 connection_id,
106 proto::Auth {
107 user_id: user_id.parse()?,
108 access_token,
109 },
110 )
111 .await
112 .context("rpc auth request failed")?;
113 if !auth_response.credentials_valid {
114 Err(anyhow!("failed to authenticate with RPC server"))?;
115 }
116
117 Ok(connection_id)
118 }
119
120 pub fn login(
121 platform: Arc<dyn gpui::Platform>,
122 executor: &Arc<gpui::executor::Background>,
123 ) -> Task<Result<(String, String)>> {
124 let executor = executor.clone();
125 executor.clone().spawn(async move {
126 if let Some((user_id, access_token)) = platform.read_credentials(&ZED_SERVER_URL) {
127 log::info!("already signed in. user_id: {}", user_id);
128 return Ok((user_id, String::from_utf8(access_token).unwrap()));
129 }
130
131 // Generate a pair of asymmetric encryption keys. The public key will be used by the
132 // zed server to encrypt the user's access token, so that it can'be intercepted by
133 // any other app running on the user's device.
134 let (public_key, private_key) =
135 zed_rpc::auth::keypair().expect("failed to generate keypair for auth");
136 let public_key_string =
137 String::try_from(public_key).expect("failed to serialize public key for auth");
138
139 // Start an HTTP server to receive the redirect from Zed's sign-in page.
140 let server = tiny_http::Server::http("127.0.0.1:0").expect("failed to find open port");
141 let port = server.server_addr().port();
142
143 // Open the Zed sign-in page in the user's browser, with query parameters that indicate
144 // that the user is signing in from a Zed app running on the same device.
145 platform.open_url(&format!(
146 "{}/sign_in?native_app_port={}&native_app_public_key={}",
147 *ZED_SERVER_URL, port, public_key_string
148 ));
149
150 // Receive the HTTP request from the user's browser. Retrieve the user id and encrypted
151 // access token from the query params.
152 //
153 // TODO - Avoid ever starting more than one HTTP server. Maybe switch to using a
154 // custom URL scheme instead of this local HTTP server.
155 let (user_id, access_token) = executor
156 .spawn(async move {
157 if let Some(req) = server.recv_timeout(Duration::from_secs(10 * 60))? {
158 let path = req.url();
159 let mut user_id = None;
160 let mut access_token = None;
161 let url = Url::parse(&format!("http://example.com{}", path))
162 .context("failed to parse login notification url")?;
163 for (key, value) in url.query_pairs() {
164 if key == "access_token" {
165 access_token = Some(value.to_string());
166 } else if key == "user_id" {
167 user_id = Some(value.to_string());
168 }
169 }
170 req.respond(
171 tiny_http::Response::from_string(LOGIN_RESPONSE).with_header(
172 tiny_http::Header::from_bytes("Content-Type", "text/html").unwrap(),
173 ),
174 )
175 .context("failed to respond to login http request")?;
176 Ok((
177 user_id.ok_or_else(|| anyhow!("missing user_id parameter"))?,
178 access_token
179 .ok_or_else(|| anyhow!("missing access_token parameter"))?,
180 ))
181 } else {
182 Err(anyhow!("didn't receive login redirect"))
183 }
184 })
185 .await?;
186
187 let access_token = private_key
188 .decrypt_string(&access_token)
189 .context("failed to decrypt access token")?;
190 platform.activate(true);
191 platform.write_credentials(&ZED_SERVER_URL, &user_id, access_token.as_bytes());
192 Ok((user_id.to_string(), access_token))
193 })
194 }
195
196 pub fn send<T: EnvelopedMessage>(
197 &self,
198 connection_id: ConnectionId,
199 message: T,
200 ) -> impl Future<Output = Result<()>> {
201 self.peer.send(connection_id, message)
202 }
203
204 pub fn request<T: RequestMessage>(
205 &self,
206 connection_id: ConnectionId,
207 request: T,
208 ) -> impl Future<Output = Result<T::Response>> {
209 self.peer.request(connection_id, request)
210 }
211
212 pub fn respond<T: RequestMessage>(
213 &self,
214 receipt: Receipt<T>,
215 response: T::Response,
216 ) -> impl Future<Output = Result<()>> {
217 self.peer.respond(receipt, response)
218 }
219}
220
221pub trait MessageHandler<'a, M: proto::EnvelopedMessage> {
222 type Output: 'a + Future<Output = anyhow::Result<()>>;
223
224 fn handle(
225 &self,
226 message: TypedEnvelope<M>,
227 rpc: &'a Client,
228 cx: &'a mut gpui::AsyncAppContext,
229 ) -> Self::Output;
230}
231
232impl<'a, M, F, Fut> MessageHandler<'a, M> for F
233where
234 M: proto::EnvelopedMessage,
235 F: Fn(TypedEnvelope<M>, &'a Client, &'a mut gpui::AsyncAppContext) -> Fut,
236 Fut: 'a + Future<Output = anyhow::Result<()>>,
237{
238 type Output = Fut;
239
240 fn handle(
241 &self,
242 message: TypedEnvelope<M>,
243 rpc: &'a Client,
244 cx: &'a mut gpui::AsyncAppContext,
245 ) -> Self::Output {
246 (self)(message, rpc, cx)
247 }
248}
249
/// Scheme and path prefix shared by every worktree URL.
const WORKTREE_URL_PREFIX: &str = "zed://worktrees/";

/// Builds a `zed://worktrees/{id}/{access_token}` URL.
pub fn encode_worktree_url(id: u64, access_token: &str) -> String {
    [WORKTREE_URL_PREFIX, id.to_string().as_str(), "/", access_token].concat()
}

/// Parses a URL built by [`encode_worktree_url`] back into `(id, access_token)`.
///
/// Returns `None` when:
///
/// - the prefix is missing;
/// - the id segment is not a `u64`;
/// - the token segment is missing or empty.
///
/// Only the segment right after the id is taken as the token; any further
/// `/`-separated segments are ignored.
pub fn decode_worktree_url(url: &str) -> Option<(u64, String)> {
    let mut segments = url.strip_prefix(WORKTREE_URL_PREFIX)?.split('/');
    let id: u64 = segments.next()?.parse().ok()?;
    match segments.next() {
        Some(token) if !token.is_empty() => Some((id, token.to_owned())),
        _ => None,
    }
}
266
/// HTML served back to the browser once the sign-in redirect reaches the local
/// login server; its script asks the browser to close the window.
const LOGIN_RESPONSE: &str = concat!(
    "\n",
    "<!DOCTYPE html>\n",
    "<html>\n",
    "<script>window.close();</script>\n",
    "</html>\n",
);
273
/// Round-trips a worktree URL and checks that malformed URLs are rejected.
#[test]
fn test_encode_and_decode_worktree_url() {
    let url = encode_worktree_url(5, "deadbeef");
    assert_eq!(url, "zed://worktrees/5/deadbeef");
    assert_eq!(decode_worktree_url(&url), Some((5, "deadbeef".to_string())));
    assert_eq!(decode_worktree_url("not://the-right-format"), None);

    // Right prefix but malformed: missing token, empty token, non-numeric id.
    assert_eq!(decode_worktree_url("zed://worktrees/5"), None);
    assert_eq!(decode_worktree_url("zed://worktrees/5/"), None);
    assert_eq!(decode_worktree_url("zed://worktrees/five/deadbeef"), None);
}