1use crate::auth::{self, UserId};
2
3use super::{auth::PeerExt as _, AppState};
4use anyhow::anyhow;
5use async_std::task;
6use async_tungstenite::{
7 tungstenite::{protocol::Role, Error as WebSocketError, Message as WebSocketMessage},
8 WebSocketStream,
9};
10use sha1::{Digest as _, Sha1};
11use std::{
12 collections::{HashMap, HashSet},
13 future::Future,
14 mem,
15 sync::Arc,
16 time::Instant,
17};
18use surf::StatusCode;
19use tide::log;
20use tide::{
21 http::headers::{HeaderName, CONNECTION, UPGRADE},
22 Request, Response,
23};
24use zed_rpc::{
25 auth::random_token,
26 proto::{self, EnvelopedMessage},
27 ConnectionId, Peer, Router, TypedEnvelope,
28};
29
30type ReplicaId = u16;
31
32#[derive(Default)]
33pub struct State {
34 connections: HashMap<ConnectionId, ConnectionState>,
35 pub worktrees: HashMap<u64, WorktreeState>,
36 next_worktree_id: u64,
37}
38
39struct ConnectionState {
40 _user_id: i32,
41 worktrees: HashSet<u64>,
42}
43
44pub struct WorktreeState {
45 host_connection_id: Option<ConnectionId>,
46 guest_connection_ids: HashMap<ConnectionId, ReplicaId>,
47 active_replica_ids: HashSet<ReplicaId>,
48 access_token: String,
49 root_name: String,
50 entries: HashMap<u64, proto::Entry>,
51}
52
53impl WorktreeState {
54 pub fn connection_ids(&self) -> Vec<ConnectionId> {
55 self.guest_connection_ids
56 .keys()
57 .copied()
58 .chain(self.host_connection_id)
59 .collect()
60 }
61
62 fn host_connection_id(&self) -> tide::Result<ConnectionId> {
63 Ok(self
64 .host_connection_id
65 .ok_or_else(|| anyhow!("host disconnected from worktree"))?)
66 }
67}
68
69impl State {
70 // Add a new connection associated with a given user.
71 pub fn add_connection(&mut self, connection_id: ConnectionId, _user_id: i32) {
72 self.connections.insert(
73 connection_id,
74 ConnectionState {
75 _user_id,
76 worktrees: Default::default(),
77 },
78 );
79 }
80
81 // Remove the given connection and its association with any worktrees.
82 pub fn remove_connection(&mut self, connection_id: ConnectionId) -> Vec<u64> {
83 let mut worktree_ids = Vec::new();
84 if let Some(connection_state) = self.connections.remove(&connection_id) {
85 for worktree_id in connection_state.worktrees {
86 if let Some(worktree) = self.worktrees.get_mut(&worktree_id) {
87 if worktree.host_connection_id == Some(connection_id) {
88 worktree_ids.push(worktree_id);
89 } else if let Some(replica_id) =
90 worktree.guest_connection_ids.remove(&connection_id)
91 {
92 worktree.active_replica_ids.remove(&replica_id);
93 worktree_ids.push(worktree_id);
94 }
95 }
96 }
97 }
98 worktree_ids
99 }
100
101 // Add the given connection as a guest of the given worktree
102 pub fn join_worktree(
103 &mut self,
104 connection_id: ConnectionId,
105 worktree_id: u64,
106 access_token: &str,
107 ) -> Option<(ReplicaId, &WorktreeState)> {
108 if let Some(worktree_state) = self.worktrees.get_mut(&worktree_id) {
109 if access_token == worktree_state.access_token {
110 if let Some(connection_state) = self.connections.get_mut(&connection_id) {
111 connection_state.worktrees.insert(worktree_id);
112 }
113
114 let mut replica_id = 1;
115 while worktree_state.active_replica_ids.contains(&replica_id) {
116 replica_id += 1;
117 }
118 worktree_state.active_replica_ids.insert(replica_id);
119 worktree_state
120 .guest_connection_ids
121 .insert(connection_id, replica_id);
122 Some((replica_id, worktree_state))
123 } else {
124 None
125 }
126 } else {
127 None
128 }
129 }
130
131 fn read_worktree(
132 &self,
133 worktree_id: u64,
134 connection_id: ConnectionId,
135 ) -> tide::Result<&WorktreeState> {
136 let worktree = self
137 .worktrees
138 .get(&worktree_id)
139 .ok_or_else(|| anyhow!("worktree not found"))?;
140
141 if worktree.host_connection_id == Some(connection_id)
142 || worktree.guest_connection_ids.contains_key(&connection_id)
143 {
144 Ok(worktree)
145 } else {
146 Err(anyhow!(
147 "{} is not a member of worktree {}",
148 connection_id,
149 worktree_id
150 ))?
151 }
152 }
153
154 fn write_worktree(
155 &mut self,
156 worktree_id: u64,
157 connection_id: ConnectionId,
158 ) -> tide::Result<&mut WorktreeState> {
159 let worktree = self
160 .worktrees
161 .get_mut(&worktree_id)
162 .ok_or_else(|| anyhow!("worktree not found"))?;
163
164 if worktree.host_connection_id == Some(connection_id)
165 || worktree.guest_connection_ids.contains_key(&connection_id)
166 {
167 Ok(worktree)
168 } else {
169 Err(anyhow!(
170 "{} is not a member of worktree {}",
171 connection_id,
172 worktree_id
173 ))?
174 }
175 }
176}
177
178trait MessageHandler<'a, M: proto::EnvelopedMessage> {
179 type Output: 'a + Send + Future<Output = tide::Result<()>>;
180
181 fn handle(
182 &self,
183 message: TypedEnvelope<M>,
184 rpc: &'a Arc<Peer>,
185 app_state: &'a Arc<AppState>,
186 ) -> Self::Output;
187}
188
189impl<'a, M, F, Fut> MessageHandler<'a, M> for F
190where
191 M: proto::EnvelopedMessage,
192 F: Fn(TypedEnvelope<M>, &'a Arc<Peer>, &'a Arc<AppState>) -> Fut,
193 Fut: 'a + Send + Future<Output = tide::Result<()>>,
194{
195 type Output = Fut;
196
197 fn handle(
198 &self,
199 message: TypedEnvelope<M>,
200 rpc: &'a Arc<Peer>,
201 app_state: &'a Arc<AppState>,
202 ) -> Self::Output {
203 (self)(message, rpc, app_state)
204 }
205}
206
207fn on_message<M, H>(router: &mut Router, rpc: &Arc<Peer>, app_state: &Arc<AppState>, handler: H)
208where
209 M: EnvelopedMessage,
210 H: 'static + Clone + Send + Sync + for<'a> MessageHandler<'a, M>,
211{
212 let rpc = rpc.clone();
213 let handler = handler.clone();
214 let app_state = app_state.clone();
215 router.add_message_handler(move |message| {
216 let rpc = rpc.clone();
217 let handler = handler.clone();
218 let app_state = app_state.clone();
219 async move {
220 let sender_id = message.sender_id;
221 let message_id = message.message_id;
222 let start_time = Instant::now();
223 log::info!(
224 "RPC message received. id: {}.{}, type:{}",
225 sender_id,
226 message_id,
227 M::NAME
228 );
229 if let Err(err) = handler.handle(message, &rpc, &app_state).await {
230 log::error!("error handling message: {:?}", err);
231 } else {
232 log::info!(
233 "RPC message handled. id:{}.{}, duration:{:?}",
234 sender_id,
235 message_id,
236 start_time.elapsed()
237 );
238 }
239
240 Ok(())
241 }
242 });
243}
244
245pub fn add_rpc_routes(router: &mut Router, state: &Arc<AppState>, rpc: &Arc<Peer>) {
246 on_message(router, rpc, state, share_worktree);
247 on_message(router, rpc, state, join_worktree);
248 on_message(router, rpc, state, update_worktree);
249 on_message(router, rpc, state, close_worktree);
250 on_message(router, rpc, state, open_buffer);
251 on_message(router, rpc, state, close_buffer);
252 on_message(router, rpc, state, update_buffer);
253 on_message(router, rpc, state, buffer_saved);
254 on_message(router, rpc, state, save_buffer);
255}
256
257pub fn add_routes(app: &mut tide::Server<Arc<AppState>>, rpc: &Arc<Peer>) {
258 let mut router = Router::new();
259 add_rpc_routes(&mut router, app.state(), rpc);
260 let router = Arc::new(router);
261
262 let rpc = rpc.clone();
263 app.at("/rpc").with(auth::VerifyToken).get(move |request: Request<Arc<AppState>>| {
264 let user_id = request.ext::<UserId>().copied();
265 let rpc = rpc.clone();
266 let router = router.clone();
267 async move {
268 const WEBSOCKET_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
269
270 let connection_upgrade = header_contains_ignore_case(&request, CONNECTION, "upgrade");
271 let upgrade_to_websocket = header_contains_ignore_case(&request, UPGRADE, "websocket");
272 let upgrade_requested = connection_upgrade && upgrade_to_websocket;
273
274 if !upgrade_requested {
275 return Ok(Response::new(StatusCode::UpgradeRequired));
276 }
277
278 let header = match request.header("Sec-Websocket-Key") {
279 Some(h) => h.as_str(),
280 None => return Err(anyhow!("expected sec-websocket-key"))?,
281 };
282
283 let mut response = Response::new(StatusCode::SwitchingProtocols);
284 response.insert_header(UPGRADE, "websocket");
285 response.insert_header(CONNECTION, "Upgrade");
286 let hash = Sha1::new().chain(header).chain(WEBSOCKET_GUID).finalize();
287 response.insert_header("Sec-Websocket-Accept", base64::encode(&hash[..]));
288 response.insert_header("Sec-Websocket-Version", "13");
289
290 let http_res: &mut tide::http::Response = response.as_mut();
291 let upgrade_receiver = http_res.recv_upgrade().await;
292 let addr = request.remote().unwrap_or("unknown").to_string();
293 let state = request.state().clone();
294 let user_id = user_id.ok_or_else(|| anyhow!("user_id is not present on request. ensure auth::VerifyToken middleware is present"))?.0;
295 task::spawn(async move {
296 if let Some(stream) = upgrade_receiver.await {
297 let stream = WebSocketStream::from_raw_socket(stream, Role::Server, None).await;
298 handle_connection(rpc, router, state, addr, stream, user_id).await;
299 }
300 });
301
302 Ok(response)
303 }
304 });
305}
306
307pub async fn handle_connection<Conn>(
308 rpc: Arc<Peer>,
309 router: Arc<Router>,
310 state: Arc<AppState>,
311 addr: String,
312 stream: Conn,
313 user_id: i32,
314) where
315 Conn: 'static
316 + futures::Sink<WebSocketMessage, Error = WebSocketError>
317 + futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>
318 + Send
319 + Unpin,
320{
321 log::info!("accepted rpc connection: {:?}", addr);
322 let (connection_id, handle_io, handle_messages) = rpc.add_connection(stream, router).await;
323 state
324 .rpc
325 .write()
326 .await
327 .add_connection(connection_id, user_id);
328
329 let handle_messages = async move {
330 handle_messages.await;
331 Ok(())
332 };
333
334 if let Err(e) = futures::try_join!(handle_messages, handle_io) {
335 log::error!("error handling rpc connection {:?} - {:?}", addr, e);
336 }
337
338 log::info!("closing connection to {:?}", addr);
339 if let Err(e) = rpc.sign_out(connection_id, &state).await {
340 log::error!("error signing out connection {:?} - {:?}", addr, e);
341 }
342}
343
344async fn share_worktree(
345 mut request: TypedEnvelope<proto::ShareWorktree>,
346 rpc: &Arc<Peer>,
347 state: &Arc<AppState>,
348) -> tide::Result<()> {
349 let mut state = state.rpc.write().await;
350 let worktree_id = state.next_worktree_id;
351 state.next_worktree_id += 1;
352 let access_token = random_token();
353 let worktree = request
354 .payload
355 .worktree
356 .as_mut()
357 .ok_or_else(|| anyhow!("missing worktree"))?;
358 let entries = mem::take(&mut worktree.entries)
359 .into_iter()
360 .map(|entry| (entry.id, entry))
361 .collect();
362 state.worktrees.insert(
363 worktree_id,
364 WorktreeState {
365 host_connection_id: Some(request.sender_id),
366 guest_connection_ids: Default::default(),
367 active_replica_ids: Default::default(),
368 access_token: access_token.clone(),
369 root_name: mem::take(&mut worktree.root_name),
370 entries,
371 },
372 );
373
374 rpc.respond(
375 request.receipt(),
376 proto::ShareWorktreeResponse {
377 worktree_id,
378 access_token,
379 },
380 )
381 .await?;
382 Ok(())
383}
384
385async fn join_worktree(
386 request: TypedEnvelope<proto::OpenWorktree>,
387 rpc: &Arc<Peer>,
388 state: &Arc<AppState>,
389) -> tide::Result<()> {
390 let worktree_id = request.payload.worktree_id;
391 let access_token = &request.payload.access_token;
392
393 let mut state = state.rpc.write().await;
394 if let Some((peer_replica_id, worktree)) =
395 state.join_worktree(request.sender_id, worktree_id, access_token)
396 {
397 let mut peers = Vec::new();
398 if let Some(host_connection_id) = worktree.host_connection_id {
399 peers.push(proto::Peer {
400 peer_id: host_connection_id.0,
401 replica_id: 0,
402 });
403 }
404 for (peer_conn_id, peer_replica_id) in &worktree.guest_connection_ids {
405 if *peer_conn_id != request.sender_id {
406 peers.push(proto::Peer {
407 peer_id: peer_conn_id.0,
408 replica_id: *peer_replica_id as u32,
409 });
410 }
411 }
412
413 broadcast(request.sender_id, worktree.connection_ids(), |conn_id| {
414 rpc.send(
415 conn_id,
416 proto::AddPeer {
417 worktree_id,
418 peer: Some(proto::Peer {
419 peer_id: request.sender_id.0,
420 replica_id: peer_replica_id as u32,
421 }),
422 },
423 )
424 })
425 .await?;
426 rpc.respond(
427 request.receipt(),
428 proto::OpenWorktreeResponse {
429 worktree_id,
430 worktree: Some(proto::Worktree {
431 root_name: worktree.root_name.clone(),
432 entries: worktree.entries.values().cloned().collect(),
433 }),
434 replica_id: peer_replica_id as u32,
435 peers,
436 },
437 )
438 .await?;
439 } else {
440 rpc.respond(
441 request.receipt(),
442 proto::OpenWorktreeResponse {
443 worktree_id,
444 worktree: None,
445 replica_id: 0,
446 peers: Vec::new(),
447 },
448 )
449 .await?;
450 }
451
452 Ok(())
453}
454
455async fn update_worktree(
456 request: TypedEnvelope<proto::UpdateWorktree>,
457 rpc: &Arc<Peer>,
458 state: &Arc<AppState>,
459) -> tide::Result<()> {
460 {
461 let mut state = state.rpc.write().await;
462 let worktree = state.write_worktree(request.payload.worktree_id, request.sender_id)?;
463 for entry_id in &request.payload.removed_entries {
464 worktree.entries.remove(&entry_id);
465 }
466
467 for entry in &request.payload.updated_entries {
468 worktree.entries.insert(entry.id, entry.clone());
469 }
470 }
471
472 broadcast_in_worktree(request.payload.worktree_id, request, rpc, state).await?;
473 Ok(())
474}
475
476async fn close_worktree(
477 request: TypedEnvelope<proto::CloseWorktree>,
478 rpc: &Arc<Peer>,
479 state: &Arc<AppState>,
480) -> tide::Result<()> {
481 let connection_ids;
482 {
483 let mut state = state.rpc.write().await;
484 let worktree = state.write_worktree(request.payload.worktree_id, request.sender_id)?;
485 connection_ids = worktree.connection_ids();
486 if worktree.host_connection_id == Some(request.sender_id) {
487 worktree.host_connection_id = None;
488 } else if let Some(replica_id) = worktree.guest_connection_ids.remove(&request.sender_id) {
489 worktree.active_replica_ids.remove(&replica_id);
490 }
491 }
492
493 broadcast(request.sender_id, connection_ids, |conn_id| {
494 rpc.send(
495 conn_id,
496 proto::RemovePeer {
497 worktree_id: request.payload.worktree_id,
498 peer_id: request.sender_id.0,
499 },
500 )
501 })
502 .await?;
503
504 Ok(())
505}
506
507async fn open_buffer(
508 request: TypedEnvelope<proto::OpenBuffer>,
509 rpc: &Arc<Peer>,
510 state: &Arc<AppState>,
511) -> tide::Result<()> {
512 let receipt = request.receipt();
513 let worktree_id = request.payload.worktree_id;
514 let host_connection_id = state
515 .rpc
516 .read()
517 .await
518 .read_worktree(worktree_id, request.sender_id)?
519 .host_connection_id()?;
520
521 let response = rpc
522 .forward_request(request.sender_id, host_connection_id, request.payload)
523 .await?;
524 rpc.respond(receipt, response).await?;
525 Ok(())
526}
527
528async fn close_buffer(
529 request: TypedEnvelope<proto::CloseBuffer>,
530 rpc: &Arc<Peer>,
531 state: &Arc<AppState>,
532) -> tide::Result<()> {
533 let host_connection_id = state
534 .rpc
535 .read()
536 .await
537 .read_worktree(request.payload.worktree_id, request.sender_id)?
538 .host_connection_id()?;
539
540 rpc.forward_send(request.sender_id, host_connection_id, request.payload)
541 .await?;
542
543 Ok(())
544}
545
546async fn save_buffer(
547 request: TypedEnvelope<proto::SaveBuffer>,
548 rpc: &Arc<Peer>,
549 state: &Arc<AppState>,
550) -> tide::Result<()> {
551 let host;
552 let guests;
553 {
554 let state = state.rpc.read().await;
555 let worktree = state.read_worktree(request.payload.worktree_id, request.sender_id)?;
556 host = worktree.host_connection_id()?;
557 guests = worktree
558 .guest_connection_ids
559 .keys()
560 .copied()
561 .collect::<Vec<_>>();
562 }
563
564 let sender = request.sender_id;
565 let receipt = request.receipt();
566 let response = rpc
567 .forward_request(sender, host, request.payload.clone())
568 .await?;
569
570 broadcast(host, guests, |conn_id| {
571 let response = response.clone();
572 async move {
573 if conn_id == sender {
574 rpc.respond(receipt, response).await
575 } else {
576 rpc.forward_send(host, conn_id, response).await
577 }
578 }
579 })
580 .await?;
581
582 Ok(())
583}
584
585async fn update_buffer(
586 request: TypedEnvelope<proto::UpdateBuffer>,
587 rpc: &Arc<Peer>,
588 state: &Arc<AppState>,
589) -> tide::Result<()> {
590 broadcast_in_worktree(request.payload.worktree_id, request, rpc, state).await
591}
592
593async fn buffer_saved(
594 request: TypedEnvelope<proto::BufferSaved>,
595 rpc: &Arc<Peer>,
596 state: &Arc<AppState>,
597) -> tide::Result<()> {
598 broadcast_in_worktree(request.payload.worktree_id, request, rpc, state).await
599}
600
601async fn broadcast_in_worktree<T: proto::EnvelopedMessage>(
602 worktree_id: u64,
603 request: TypedEnvelope<T>,
604 rpc: &Arc<Peer>,
605 state: &Arc<AppState>,
606) -> tide::Result<()> {
607 let connection_ids = state
608 .rpc
609 .read()
610 .await
611 .read_worktree(worktree_id, request.sender_id)?
612 .connection_ids();
613
614 broadcast(request.sender_id, connection_ids, |conn_id| {
615 rpc.forward_send(request.sender_id, conn_id, request.payload.clone())
616 })
617 .await?;
618
619 Ok(())
620}
621
622pub async fn broadcast<F, T>(
623 sender_id: ConnectionId,
624 receiver_ids: Vec<ConnectionId>,
625 mut f: F,
626) -> anyhow::Result<()>
627where
628 F: FnMut(ConnectionId) -> T,
629 T: Future<Output = anyhow::Result<()>>,
630{
631 let futures = receiver_ids
632 .into_iter()
633 .filter(|id| *id != sender_id)
634 .map(|id| f(id));
635 futures::future::try_join_all(futures).await?;
636 Ok(())
637}
638
639fn header_contains_ignore_case<T>(
640 request: &tide::Request<T>,
641 header_name: HeaderName,
642 value: &str,
643) -> bool {
644 request
645 .header(header_name)
646 .map(|h| {
647 h.as_str()
648 .split(',')
649 .any(|s| s.trim().eq_ignore_ascii_case(value.trim()))
650 })
651 .unwrap_or(false)
652}