1use crate::{language::LanguageRegistry, worktree::Worktree};
2use anyhow::{anyhow, Context, Result};
3use async_tungstenite::tungstenite::http::Request;
4use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
5use futures::Stream;
6use gpui::{AsyncAppContext, ModelHandle, Task, WeakModelHandle};
7use lazy_static::lazy_static;
8use smol::lock::RwLock;
9use std::collections::HashMap;
10use std::time::Duration;
11use std::{convert::TryFrom, future::Future, sync::Arc};
12use surf::Url;
13pub use zrpc::{proto, ConnectionId, PeerId, TypedEnvelope};
14use zrpc::{
15 proto::{EnvelopedMessage, RequestMessage},
16 ForegroundRouter, Peer, Receipt,
17};
18
19lazy_static! {
20 static ref ZED_SERVER_URL: String =
21 std::env::var("ZED_SERVER_URL").unwrap_or("https://zed.dev:443".to_string());
22}
23
/// Client side of the RPC connection to the Zed server.
///
/// Cloning is cheap. Every clone shares the same peer and the same [`ClientState`] through
/// `Arc`s.
#[derive(Clone)]
pub struct Client {
    // Underlying RPC peer that owns the connection(s).
    peer: Arc<Peer>,
    /// State shared by all clones of this client, behind an async read/write lock.
    pub state: Arc<RwLock<ClientState>>,
}
29
/// Mutable state shared by every clone of a [`Client`].
pub struct ClientState {
    // Id of the current server connection. It is `None` in a freshly created client and is set
    // once `Client::add_connection` registers a connection.
    connection_id: Option<ConnectionId>,
    /// Worktrees keyed by their numeric id (presumably the ones shared over this connection;
    /// confirm against callers). Handles are weak, so an entry may outlive its worktree.
    pub shared_worktrees: HashMap<u64, WeakModelHandle<Worktree>>,
    /// Language registry handed to `Client::new`.
    pub languages: Arc<LanguageRegistry>,
}
35
36impl ClientState {
37 pub fn shared_worktree(
38 &self,
39 id: u64,
40 cx: &mut AsyncAppContext,
41 ) -> Result<ModelHandle<Worktree>> {
42 if let Some(worktree) = self.shared_worktrees.get(&id) {
43 if let Some(worktree) = cx.read(|cx| worktree.upgrade(cx)) {
44 Ok(worktree)
45 } else {
46 Err(anyhow!("worktree {} was dropped", id))
47 }
48 } else {
49 Err(anyhow!("worktree {} does not exist", id))
50 }
51 }
52}
53
54impl Client {
55 pub fn new(languages: Arc<LanguageRegistry>) -> Self {
56 Self {
57 peer: Peer::new(),
58 state: Arc::new(RwLock::new(ClientState {
59 connection_id: None,
60 shared_worktrees: Default::default(),
61 languages,
62 })),
63 }
64 }
65
66 pub fn on_message<H, M>(
67 &self,
68 router: &mut ForegroundRouter,
69 handler: H,
70 cx: &mut gpui::MutableAppContext,
71 ) where
72 H: 'static + Clone + for<'a> MessageHandler<'a, M>,
73 M: proto::EnvelopedMessage,
74 {
75 let this = self.clone();
76 let cx = cx.to_async();
77 router.add_message_handler(move |message| {
78 let this = this.clone();
79 let mut cx = cx.clone();
80 let handler = handler.clone();
81 async move { handler.handle(message, &this, &mut cx).await }
82 });
83 }
84
85 pub fn subscribe<T: EnvelopedMessage>(&self) -> impl Stream<Item = Arc<TypedEnvelope<T>>> {
86 self.peer.subscribe()
87 }
88
89 pub async fn log_in_and_connect(
90 &self,
91 router: Arc<ForegroundRouter>,
92 cx: AsyncAppContext,
93 ) -> surf::Result<()> {
94 if self.state.read().await.connection_id.is_some() {
95 return Ok(());
96 }
97
98 let (user_id, access_token) = Self::login(cx.platform(), &cx.background()).await?;
99 let user_id: i32 = user_id.parse()?;
100 let request =
101 Request::builder().header("Authorization", format!("{} {}", user_id, access_token));
102
103 if let Some(host) = ZED_SERVER_URL.strip_prefix("https://") {
104 let stream = smol::net::TcpStream::connect(host).await?;
105 let request = request.uri(format!("wss://{}/rpc", host)).body(())?;
106 let (stream, _) = async_tungstenite::async_tls::client_async_tls(request, stream)
107 .await
108 .context("websocket handshake")?;
109 log::info!("connected to rpc address {}", *ZED_SERVER_URL);
110 self.add_connection(stream, router, cx).await?;
111 } else if let Some(host) = ZED_SERVER_URL.strip_prefix("http://") {
112 let stream = smol::net::TcpStream::connect(host).await?;
113 let request = request.uri(format!("ws://{}/rpc", host)).body(())?;
114 let (stream, _) = async_tungstenite::client_async(request, stream).await?;
115 log::info!("connected to rpc address {}", *ZED_SERVER_URL);
116 self.add_connection(stream, router, cx).await?;
117 } else {
118 return Err(anyhow!("invalid server url: {}", *ZED_SERVER_URL))?;
119 };
120
121 Ok(())
122 }
123
124 pub async fn add_connection<Conn>(
125 &self,
126 conn: Conn,
127 router: Arc<ForegroundRouter>,
128 cx: AsyncAppContext,
129 ) -> surf::Result<()>
130 where
131 Conn: 'static
132 + futures::Sink<WebSocketMessage, Error = WebSocketError>
133 + futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>
134 + Unpin
135 + Send,
136 {
137 let (connection_id, handle_io, handle_messages) =
138 self.peer.add_connection(conn, router).await;
139 cx.foreground().spawn(handle_messages).detach();
140 cx.background()
141 .spawn(async move {
142 if let Err(error) = handle_io.await {
143 log::error!("connection error: {:?}", error);
144 }
145 })
146 .detach();
147 self.state.write().await.connection_id = Some(connection_id);
148 Ok(())
149 }
150
151 pub fn login(
152 platform: Arc<dyn gpui::Platform>,
153 executor: &Arc<gpui::executor::Background>,
154 ) -> Task<Result<(String, String)>> {
155 let executor = executor.clone();
156 executor.clone().spawn(async move {
157 if let Some((user_id, access_token)) = platform.read_credentials(&ZED_SERVER_URL) {
158 log::info!("already signed in. user_id: {}", user_id);
159 return Ok((user_id, String::from_utf8(access_token).unwrap()));
160 }
161
162 // Generate a pair of asymmetric encryption keys. The public key will be used by the
163 // zed server to encrypt the user's access token, so that it can'be intercepted by
164 // any other app running on the user's device.
165 let (public_key, private_key) =
166 zrpc::auth::keypair().expect("failed to generate keypair for auth");
167 let public_key_string =
168 String::try_from(public_key).expect("failed to serialize public key for auth");
169
170 // Start an HTTP server to receive the redirect from Zed's sign-in page.
171 let server = tiny_http::Server::http("127.0.0.1:0").expect("failed to find open port");
172 let port = server.server_addr().port();
173
174 // Open the Zed sign-in page in the user's browser, with query parameters that indicate
175 // that the user is signing in from a Zed app running on the same device.
176 platform.open_url(&format!(
177 "{}/sign_in?native_app_port={}&native_app_public_key={}",
178 *ZED_SERVER_URL, port, public_key_string
179 ));
180
181 // Receive the HTTP request from the user's browser. Retrieve the user id and encrypted
182 // access token from the query params.
183 //
184 // TODO - Avoid ever starting more than one HTTP server. Maybe switch to using a
185 // custom URL scheme instead of this local HTTP server.
186 let (user_id, access_token) = executor
187 .spawn(async move {
188 if let Some(req) = server.recv_timeout(Duration::from_secs(10 * 60))? {
189 let path = req.url();
190 let mut user_id = None;
191 let mut access_token = None;
192 let url = Url::parse(&format!("http://example.com{}", path))
193 .context("failed to parse login notification url")?;
194 for (key, value) in url.query_pairs() {
195 if key == "access_token" {
196 access_token = Some(value.to_string());
197 } else if key == "user_id" {
198 user_id = Some(value.to_string());
199 }
200 }
201 req.respond(
202 tiny_http::Response::from_string(LOGIN_RESPONSE).with_header(
203 tiny_http::Header::from_bytes("Content-Type", "text/html").unwrap(),
204 ),
205 )
206 .context("failed to respond to login http request")?;
207 Ok((
208 user_id.ok_or_else(|| anyhow!("missing user_id parameter"))?,
209 access_token
210 .ok_or_else(|| anyhow!("missing access_token parameter"))?,
211 ))
212 } else {
213 Err(anyhow!("didn't receive login redirect"))
214 }
215 })
216 .await?;
217
218 let access_token = private_key
219 .decrypt_string(&access_token)
220 .context("failed to decrypt access token")?;
221 platform.activate(true);
222 platform.write_credentials(&ZED_SERVER_URL, &user_id, access_token.as_bytes());
223 Ok((user_id.to_string(), access_token))
224 })
225 }
226
227 pub async fn disconnect(&self) -> Result<()> {
228 let conn_id = self.connection_id().await?;
229 self.peer.disconnect(conn_id).await;
230 Ok(())
231 }
232
233 async fn connection_id(&self) -> Result<ConnectionId> {
234 self.state
235 .read()
236 .await
237 .connection_id
238 .ok_or_else(|| anyhow!("not connected"))
239 }
240
241 pub async fn send<T: EnvelopedMessage>(&self, message: T) -> Result<()> {
242 self.peer.send(self.connection_id().await?, message).await
243 }
244
245 pub async fn request<T: RequestMessage>(&self, request: T) -> Result<T::Response> {
246 self.peer
247 .request(self.connection_id().await?, request)
248 .await
249 }
250
251 pub fn respond<T: RequestMessage>(
252 &self,
253 receipt: Receipt<T>,
254 response: T::Response,
255 ) -> impl Future<Output = Result<()>> {
256 self.peer.respond(receipt, response)
257 }
258}
259
/// An async handler for incoming messages of type `M`.
///
/// The lifetime `'a` ties the returned future to the borrowed [`Client`] and app context, so
/// the future may keep using both until it completes. Cloneable closures of the matching shape
/// implement this trait through the blanket impl that follows.
pub trait MessageHandler<'a, M: proto::EnvelopedMessage>: Clone {
    /// Future produced by [`MessageHandler::handle`]. It resolves once the message is handled.
    type Output: 'a + Future<Output = anyhow::Result<()>>;

    /// Starts handling `message`, borrowing the client as `rpc` and the app context as `cx`.
    fn handle(
        &self,
        message: TypedEnvelope<M>,
        rpc: &'a Client,
        cx: &'a mut gpui::AsyncAppContext,
    ) -> Self::Output;
}
270
271impl<'a, M, F, Fut> MessageHandler<'a, M> for F
272where
273 M: proto::EnvelopedMessage,
274 F: Clone + Fn(TypedEnvelope<M>, &'a Client, &'a mut gpui::AsyncAppContext) -> Fut,
275 Fut: 'a + Future<Output = anyhow::Result<()>>,
276{
277 type Output = Fut;
278
279 fn handle(
280 &self,
281 message: TypedEnvelope<M>,
282 rpc: &'a Client,
283 cx: &'a mut gpui::AsyncAppContext,
284 ) -> Self::Output {
285 (self)(message, rpc, cx)
286 }
287}
288
/// Scheme and path prefix shared by [`encode_worktree_url`] and [`decode_worktree_url`].
const WORKTREE_URL_PREFIX: &str = "zed://worktrees/";

/// Builds a worktree URL of the form `zed://worktrees/{id}/{access_token}`.
pub fn encode_worktree_url(id: u64, access_token: &str) -> String {
    format!("{}{}/{}", WORKTREE_URL_PREFIX, id, access_token)
}

/// Parses a URL produced by [`encode_worktree_url`] back into `(id, access_token)`.
///
/// Surrounding whitespace is ignored. Everything after the `/` that follows the id is the
/// access token. Decoding is therefore the exact inverse of encoding, even when the token
/// itself contains `/`.
///
/// Returns `None` in any of these cases:
/// - the prefix is missing;
/// - the id is not a valid `u64`;
/// - the token is missing or empty.
pub fn decode_worktree_url(url: &str) -> Option<(u64, String)> {
    let path = url.trim().strip_prefix(WORKTREE_URL_PREFIX)?;
    // Split only once, so a `/` inside the token is kept instead of truncating the token.
    let mut parts = path.splitn(2, '/');
    let id = parts.next()?.parse::<u64>().ok()?;
    let access_token = parts.next()?;
    if access_token.is_empty() {
        return None;
    }
    Some((id, access_token.to_string()))
}
305
/// HTML sent back to the browser once the sign-in redirect has been received. Its inline
/// script asks the browser to close the window.
const LOGIN_RESPONSE: &str = "
<!DOCTYPE html>
<html>
<script>window.close();</script>
</html>
";
312
/// Round-trips a worktree URL and checks that malformed URLs are rejected.
#[test]
fn test_encode_and_decode_worktree_url() {
    let url = encode_worktree_url(5, "deadbeef");
    assert_eq!(decode_worktree_url(&url), Some((5, "deadbeef".to_string())));
    // Surrounding whitespace (e.g. from a pasted URL) is ignored.
    assert_eq!(
        decode_worktree_url(&format!("\n {}\t", url)),
        Some((5, "deadbeef".to_string()))
    );
    assert_eq!(decode_worktree_url("not://the-right-format"), None);
    // Missing or empty access token.
    assert_eq!(decode_worktree_url("zed://worktrees/5"), None);
    assert_eq!(decode_worktree_url("zed://worktrees/5/"), None);
    // Non-numeric id.
    assert_eq!(decode_worktree_url("zed://worktrees/five/deadbeef"), None);
}