1use anyhow::{anyhow, Context, Result};
2use async_tungstenite::tungstenite::http::Request;
3use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
4use gpui::{AsyncAppContext, Entity, ModelContext, Task};
5use lazy_static::lazy_static;
6use parking_lot::RwLock;
7use postage::prelude::Stream;
8use postage::sink::Sink;
9use postage::watch;
10use std::any::TypeId;
11use std::collections::HashMap;
12use std::sync::Weak;
13use std::time::{Duration, Instant};
14use std::{convert::TryFrom, future::Future, sync::Arc};
15use surf::Url;
16use zrpc::proto::{AnyTypedEnvelope, EntityMessage};
17pub use zrpc::{proto, ConnectionId, PeerId, TypedEnvelope};
18use zrpc::{
19 proto::{EnvelopedMessage, RequestMessage},
20 Peer, Receipt,
21};
22
23lazy_static! {
24 static ref ZED_SERVER_URL: String =
25 std::env::var("ZED_SERVER_URL").unwrap_or("https://zed.dev:443".to_string());
26}
27
28pub struct Client {
29 peer: Arc<Peer>,
30 state: RwLock<ClientState>,
31}
32
33struct ClientState {
34 connection_id: Option<ConnectionId>,
35 user_id: (watch::Sender<Option<u64>>, watch::Receiver<Option<u64>>),
36 entity_id_extractors: HashMap<TypeId, Box<dyn Send + Sync + Fn(&dyn AnyTypedEnvelope) -> u64>>,
37 model_handlers: HashMap<
38 (TypeId, u64),
39 Box<dyn Send + Sync + FnMut(Box<dyn AnyTypedEnvelope>, &mut AsyncAppContext)>,
40 >,
41}
42
43impl Default for ClientState {
44 fn default() -> Self {
45 Self {
46 connection_id: Default::default(),
47 user_id: watch::channel(),
48 entity_id_extractors: Default::default(),
49 model_handlers: Default::default(),
50 }
51 }
52}
53
54pub struct Subscription {
55 client: Weak<Client>,
56 id: (TypeId, u64),
57}
58
59impl Drop for Subscription {
60 fn drop(&mut self) {
61 if let Some(client) = self.client.upgrade() {
62 drop(
63 client
64 .state
65 .write()
66 .model_handlers
67 .remove(&self.id)
68 .unwrap(),
69 );
70 }
71 }
72}
73
74impl Client {
75 pub fn new() -> Arc<Self> {
76 Arc::new(Self {
77 peer: Peer::new(),
78 state: Default::default(),
79 })
80 }
81
82 pub fn user_id(&self) -> watch::Receiver<Option<u64>> {
83 self.state.read().user_id.1.clone()
84 }
85
86 pub fn subscribe_from_model<T, M, F>(
87 self: &Arc<Self>,
88 remote_id: u64,
89 cx: &mut ModelContext<M>,
90 mut handler: F,
91 ) -> Subscription
92 where
93 T: EntityMessage,
94 M: Entity,
95 F: 'static
96 + Send
97 + Sync
98 + FnMut(&mut M, TypedEnvelope<T>, Arc<Self>, &mut ModelContext<M>) -> Result<()>,
99 {
100 let subscription_id = (TypeId::of::<T>(), remote_id);
101 let client = self.clone();
102 let mut state = self.state.write();
103 let model = cx.handle().downgrade();
104 state
105 .entity_id_extractors
106 .entry(subscription_id.0)
107 .or_insert_with(|| {
108 Box::new(|envelope| {
109 let envelope = envelope
110 .as_any()
111 .downcast_ref::<TypedEnvelope<T>>()
112 .unwrap();
113 envelope.payload.remote_entity_id()
114 })
115 });
116 let prev_handler = state.model_handlers.insert(
117 subscription_id,
118 Box::new(move |envelope, cx| {
119 if let Some(model) = model.upgrade(cx) {
120 let envelope = envelope.into_any().downcast::<TypedEnvelope<T>>().unwrap();
121 model.update(cx, |model, cx| {
122 if let Err(error) = handler(model, *envelope, client.clone(), cx) {
123 log::error!("error handling message: {}", error)
124 }
125 });
126 }
127 }),
128 );
129 if prev_handler.is_some() {
130 panic!("registered a handler for the same entity twice")
131 }
132
133 Subscription {
134 client: Arc::downgrade(self),
135 id: subscription_id,
136 }
137 }
138
139 pub async fn authenticate_and_connect(
140 self: &Arc<Self>,
141 cx: AsyncAppContext,
142 ) -> anyhow::Result<()> {
143 if self.state.read().connection_id.is_some() {
144 return Ok(());
145 }
146
147 let (user_id, access_token) = Self::login(cx.platform(), &cx.background()).await?;
148 let user_id = user_id.parse::<u64>()?;
149 let request =
150 Request::builder().header("Authorization", format!("{} {}", user_id, access_token));
151
152 if let Some(host) = ZED_SERVER_URL.strip_prefix("https://") {
153 let stream = smol::net::TcpStream::connect(host).await?;
154 let request = request.uri(format!("wss://{}/rpc", host)).body(())?;
155 let (stream, _) = async_tungstenite::async_tls::client_async_tls(request, stream)
156 .await
157 .context("websocket handshake")?;
158 self.add_connection(user_id, stream, cx).await?;
159 } else if let Some(host) = ZED_SERVER_URL.strip_prefix("http://") {
160 let stream = smol::net::TcpStream::connect(host).await?;
161 let request = request.uri(format!("ws://{}/rpc", host)).body(())?;
162 let (stream, _) = async_tungstenite::client_async(request, stream)
163 .await
164 .context("websocket handshake")?;
165 self.add_connection(user_id, stream, cx).await?;
166 } else {
167 return Err(anyhow!("invalid server url: {}", *ZED_SERVER_URL))?;
168 };
169
170 log::info!("connected to rpc address {}", *ZED_SERVER_URL);
171 Ok(())
172 }
173
174 pub async fn add_connection<Conn>(
175 self: &Arc<Self>,
176 user_id: u64,
177 conn: Conn,
178 cx: AsyncAppContext,
179 ) -> anyhow::Result<()>
180 where
181 Conn: 'static
182 + futures::Sink<WebSocketMessage, Error = WebSocketError>
183 + futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>
184 + Unpin
185 + Send,
186 {
187 let (connection_id, handle_io, mut incoming) = self.peer.add_connection(conn).await;
188 {
189 let mut cx = cx.clone();
190 let this = self.clone();
191 cx.foreground()
192 .spawn(async move {
193 while let Some(message) = incoming.recv().await {
194 let mut state = this.state.write();
195 if let Some(extract_entity_id) =
196 state.entity_id_extractors.get(&message.payload_type_id())
197 {
198 let entity_id = (extract_entity_id)(message.as_ref());
199 if let Some(handler) = state
200 .model_handlers
201 .get_mut(&(message.payload_type_id(), entity_id))
202 {
203 let start_time = Instant::now();
204 log::info!("RPC client message {}", message.payload_type_name());
205 (handler)(message, &mut cx);
206 log::info!(
207 "RPC message handled. duration:{:?}",
208 start_time.elapsed()
209 );
210 } else {
211 log::info!("unhandled message {}", message.payload_type_name());
212 }
213 } else {
214 log::info!("unhandled message {}", message.payload_type_name());
215 }
216 }
217 })
218 .detach();
219 }
220 cx.background()
221 .spawn(async move {
222 if let Err(error) = handle_io.await {
223 log::error!("connection error: {:?}", error);
224 }
225 })
226 .detach();
227 let mut state = self.state.write();
228 state.connection_id = Some(connection_id);
229 state.user_id.0.send(Some(user_id)).await?;
230 Ok(())
231 }
232
233 pub fn login(
234 platform: Arc<dyn gpui::Platform>,
235 executor: &Arc<gpui::executor::Background>,
236 ) -> Task<Result<(String, String)>> {
237 let executor = executor.clone();
238 executor.clone().spawn(async move {
239 if let Some((user_id, access_token)) = platform.read_credentials(&ZED_SERVER_URL) {
240 log::info!("already signed in. user_id: {}", user_id);
241 return Ok((user_id, String::from_utf8(access_token).unwrap()));
242 }
243
244 // Generate a pair of asymmetric encryption keys. The public key will be used by the
245 // zed server to encrypt the user's access token, so that it can'be intercepted by
246 // any other app running on the user's device.
247 let (public_key, private_key) =
248 zrpc::auth::keypair().expect("failed to generate keypair for auth");
249 let public_key_string =
250 String::try_from(public_key).expect("failed to serialize public key for auth");
251
252 // Start an HTTP server to receive the redirect from Zed's sign-in page.
253 let server = tiny_http::Server::http("127.0.0.1:0").expect("failed to find open port");
254 let port = server.server_addr().port();
255
256 // Open the Zed sign-in page in the user's browser, with query parameters that indicate
257 // that the user is signing in from a Zed app running on the same device.
258 platform.open_url(&format!(
259 "{}/sign_in?native_app_port={}&native_app_public_key={}",
260 *ZED_SERVER_URL, port, public_key_string
261 ));
262
263 // Receive the HTTP request from the user's browser. Retrieve the user id and encrypted
264 // access token from the query params.
265 //
266 // TODO - Avoid ever starting more than one HTTP server. Maybe switch to using a
267 // custom URL scheme instead of this local HTTP server.
268 let (user_id, access_token) = executor
269 .spawn(async move {
270 if let Some(req) = server.recv_timeout(Duration::from_secs(10 * 60))? {
271 let path = req.url();
272 let mut user_id = None;
273 let mut access_token = None;
274 let url = Url::parse(&format!("http://example.com{}", path))
275 .context("failed to parse login notification url")?;
276 for (key, value) in url.query_pairs() {
277 if key == "access_token" {
278 access_token = Some(value.to_string());
279 } else if key == "user_id" {
280 user_id = Some(value.to_string());
281 }
282 }
283 req.respond(
284 tiny_http::Response::from_string(LOGIN_RESPONSE).with_header(
285 tiny_http::Header::from_bytes("Content-Type", "text/html").unwrap(),
286 ),
287 )
288 .context("failed to respond to login http request")?;
289 Ok((
290 user_id.ok_or_else(|| anyhow!("missing user_id parameter"))?,
291 access_token
292 .ok_or_else(|| anyhow!("missing access_token parameter"))?,
293 ))
294 } else {
295 Err(anyhow!("didn't receive login redirect"))
296 }
297 })
298 .await?;
299
300 let access_token = private_key
301 .decrypt_string(&access_token)
302 .context("failed to decrypt access token")?;
303 platform.activate(true);
304 platform.write_credentials(&ZED_SERVER_URL, &user_id, access_token.as_bytes());
305 Ok((user_id.to_string(), access_token))
306 })
307 }
308
309 pub async fn disconnect(&self) -> Result<()> {
310 let conn_id = self.connection_id()?;
311 self.peer.disconnect(conn_id).await;
312 Ok(())
313 }
314
315 fn connection_id(&self) -> Result<ConnectionId> {
316 self.state
317 .read()
318 .connection_id
319 .ok_or_else(|| anyhow!("not connected"))
320 }
321
322 pub async fn send<T: EnvelopedMessage>(&self, message: T) -> Result<()> {
323 self.peer.send(self.connection_id()?, message).await
324 }
325
326 pub async fn request<T: RequestMessage>(&self, request: T) -> Result<T::Response> {
327 self.peer.request(self.connection_id()?, request).await
328 }
329
330 pub fn respond<T: RequestMessage>(
331 &self,
332 receipt: Receipt<T>,
333 response: T::Response,
334 ) -> impl Future<Output = Result<()>> {
335 self.peer.respond(receipt, response)
336 }
337}
338
339pub trait MessageHandler<'a, M: proto::EnvelopedMessage>: Clone {
340 type Output: 'a + Future<Output = anyhow::Result<()>>;
341
342 fn handle(
343 &self,
344 message: TypedEnvelope<M>,
345 rpc: &'a Client,
346 cx: &'a mut gpui::AsyncAppContext,
347 ) -> Self::Output;
348}
349
350impl<'a, M, F, Fut> MessageHandler<'a, M> for F
351where
352 M: proto::EnvelopedMessage,
353 F: Clone + Fn(TypedEnvelope<M>, &'a Client, &'a mut gpui::AsyncAppContext) -> Fut,
354 Fut: 'a + Future<Output = anyhow::Result<()>>,
355{
356 type Output = Fut;
357
358 fn handle(
359 &self,
360 message: TypedEnvelope<M>,
361 rpc: &'a Client,
362 cx: &'a mut gpui::AsyncAppContext,
363 ) -> Self::Output {
364 (self)(message, rpc, cx)
365 }
366}
367
/// Prefix shared by every worktree URL.
const WORKTREE_URL_PREFIX: &str = "zed://worktrees/";

/// Builds a worktree URL of the form `zed://worktrees/{id}/{access_token}`.
pub fn encode_worktree_url(id: u64, access_token: &str) -> String {
    let mut url = String::from(WORKTREE_URL_PREFIX);
    url.push_str(&id.to_string());
    url.push('/');
    url.push_str(access_token);
    url
}

/// Parses a URL produced by [`encode_worktree_url`] back into `(id, access_token)`.
///
/// Surrounding whitespace is ignored. Returns `None` if the prefix is missing, the id is not a
/// valid `u64`, or the token segment is missing or empty. Only the segment right after the id is
/// taken as the token; any further `/`-separated segments are ignored.
pub fn decode_worktree_url(url: &str) -> Option<(u64, String)> {
    let path = url.trim().strip_prefix(WORKTREE_URL_PREFIX)?;
    let mut segments = path.splitn(3, '/');
    let id: u64 = segments.next()?.parse().ok()?;
    let token = segments.next().filter(|segment| !segment.is_empty())?;
    Some((id, token.to_owned()))
}
384
/// HTML page returned to the browser after the sign-in redirect; it closes its own window.
const LOGIN_RESPONSE: &str = r#"
<!DOCTYPE html>
<html>
<script>window.close();</script>
</html>
"#;
391
#[test]
fn test_encode_and_decode_worktree_url() {
    let expected = Some((5, String::from("deadbeef")));
    let url = encode_worktree_url(5, "deadbeef");
    // Round trip, with and without surrounding whitespace.
    assert_eq!(decode_worktree_url(&url), expected);
    let padded = format!("\n {}\t", url);
    assert_eq!(decode_worktree_url(&padded), expected);
    // URLs with a different prefix are rejected.
    assert!(decode_worktree_url("not://the-right-format").is_none());
}