1use super::{
2 auth::{self, PeerExt as _},
3 db::UserId,
4 AppState,
5};
6use anyhow::anyhow;
7use async_std::task;
8use async_tungstenite::{
9 tungstenite::{protocol::Role, Error as WebSocketError, Message as WebSocketMessage},
10 WebSocketStream,
11};
12use sha1::{Digest as _, Sha1};
13use std::{
14 collections::{HashMap, HashSet},
15 future::Future,
16 mem,
17 sync::Arc,
18 time::Instant,
19};
20use surf::StatusCode;
21use tide::log;
22use tide::{
23 http::headers::{HeaderName, CONNECTION, UPGRADE},
24 Request, Response,
25};
26use zrpc::{
27 auth::random_token,
28 proto::{self, EnvelopedMessage},
29 ConnectionId, Peer, Router, TypedEnvelope,
30};
31
/// Identifier of one participant's replica inside a shared worktree.
///
/// Within this file guests are handed ids starting at 1, while the host is
/// advertised to peers with replica id 0.
type ReplicaId = u16;
33
34#[derive(Default)]
35pub struct State {
36 connections: HashMap<ConnectionId, ConnectionState>,
37 pub worktrees: HashMap<u64, WorktreeState>,
38 next_worktree_id: u64,
39}
40
41struct ConnectionState {
42 _user_id: UserId,
43 worktrees: HashSet<u64>,
44}
45
46pub struct WorktreeState {
47 host_connection_id: Option<ConnectionId>,
48 guest_connection_ids: HashMap<ConnectionId, ReplicaId>,
49 active_replica_ids: HashSet<ReplicaId>,
50 access_token: String,
51 root_name: String,
52 entries: HashMap<u64, proto::Entry>,
53}
54
55impl WorktreeState {
56 pub fn connection_ids(&self) -> Vec<ConnectionId> {
57 self.guest_connection_ids
58 .keys()
59 .copied()
60 .chain(self.host_connection_id)
61 .collect()
62 }
63
64 fn host_connection_id(&self) -> tide::Result<ConnectionId> {
65 Ok(self
66 .host_connection_id
67 .ok_or_else(|| anyhow!("host disconnected from worktree"))?)
68 }
69}
70
71impl State {
72 // Add a new connection associated with a given user.
73 pub fn add_connection(&mut self, connection_id: ConnectionId, _user_id: UserId) {
74 self.connections.insert(
75 connection_id,
76 ConnectionState {
77 _user_id,
78 worktrees: Default::default(),
79 },
80 );
81 }
82
83 // Remove the given connection and its association with any worktrees.
84 pub fn remove_connection(&mut self, connection_id: ConnectionId) -> Vec<u64> {
85 let mut worktree_ids = Vec::new();
86 if let Some(connection_state) = self.connections.remove(&connection_id) {
87 for worktree_id in connection_state.worktrees {
88 if let Some(worktree) = self.worktrees.get_mut(&worktree_id) {
89 if worktree.host_connection_id == Some(connection_id) {
90 worktree_ids.push(worktree_id);
91 } else if let Some(replica_id) =
92 worktree.guest_connection_ids.remove(&connection_id)
93 {
94 worktree.active_replica_ids.remove(&replica_id);
95 worktree_ids.push(worktree_id);
96 }
97 }
98 }
99 }
100 worktree_ids
101 }
102
103 // Add the given connection as a guest of the given worktree
104 pub fn join_worktree(
105 &mut self,
106 connection_id: ConnectionId,
107 worktree_id: u64,
108 access_token: &str,
109 ) -> Option<(ReplicaId, &WorktreeState)> {
110 if let Some(worktree_state) = self.worktrees.get_mut(&worktree_id) {
111 if access_token == worktree_state.access_token {
112 if let Some(connection_state) = self.connections.get_mut(&connection_id) {
113 connection_state.worktrees.insert(worktree_id);
114 }
115
116 let mut replica_id = 1;
117 while worktree_state.active_replica_ids.contains(&replica_id) {
118 replica_id += 1;
119 }
120 worktree_state.active_replica_ids.insert(replica_id);
121 worktree_state
122 .guest_connection_ids
123 .insert(connection_id, replica_id);
124 Some((replica_id, worktree_state))
125 } else {
126 None
127 }
128 } else {
129 None
130 }
131 }
132
133 fn read_worktree(
134 &self,
135 worktree_id: u64,
136 connection_id: ConnectionId,
137 ) -> tide::Result<&WorktreeState> {
138 let worktree = self
139 .worktrees
140 .get(&worktree_id)
141 .ok_or_else(|| anyhow!("worktree not found"))?;
142
143 if worktree.host_connection_id == Some(connection_id)
144 || worktree.guest_connection_ids.contains_key(&connection_id)
145 {
146 Ok(worktree)
147 } else {
148 Err(anyhow!(
149 "{} is not a member of worktree {}",
150 connection_id,
151 worktree_id
152 ))?
153 }
154 }
155
156 fn write_worktree(
157 &mut self,
158 worktree_id: u64,
159 connection_id: ConnectionId,
160 ) -> tide::Result<&mut WorktreeState> {
161 let worktree = self
162 .worktrees
163 .get_mut(&worktree_id)
164 .ok_or_else(|| anyhow!("worktree not found"))?;
165
166 if worktree.host_connection_id == Some(connection_id)
167 || worktree.guest_connection_ids.contains_key(&connection_id)
168 {
169 Ok(worktree)
170 } else {
171 Err(anyhow!(
172 "{} is not a member of worktree {}",
173 connection_id,
174 worktree_id
175 ))?
176 }
177 }
178}
179
180trait MessageHandler<'a, M: proto::EnvelopedMessage> {
181 type Output: 'a + Send + Future<Output = tide::Result<()>>;
182
183 fn handle(
184 &self,
185 message: TypedEnvelope<M>,
186 rpc: &'a Arc<Peer>,
187 app_state: &'a Arc<AppState>,
188 ) -> Self::Output;
189}
190
191impl<'a, M, F, Fut> MessageHandler<'a, M> for F
192where
193 M: proto::EnvelopedMessage,
194 F: Fn(TypedEnvelope<M>, &'a Arc<Peer>, &'a Arc<AppState>) -> Fut,
195 Fut: 'a + Send + Future<Output = tide::Result<()>>,
196{
197 type Output = Fut;
198
199 fn handle(
200 &self,
201 message: TypedEnvelope<M>,
202 rpc: &'a Arc<Peer>,
203 app_state: &'a Arc<AppState>,
204 ) -> Self::Output {
205 (self)(message, rpc, app_state)
206 }
207}
208
209fn on_message<M, H>(router: &mut Router, rpc: &Arc<Peer>, app_state: &Arc<AppState>, handler: H)
210where
211 M: EnvelopedMessage,
212 H: 'static + Clone + Send + Sync + for<'a> MessageHandler<'a, M>,
213{
214 let rpc = rpc.clone();
215 let handler = handler.clone();
216 let app_state = app_state.clone();
217 router.add_message_handler(move |message| {
218 let rpc = rpc.clone();
219 let handler = handler.clone();
220 let app_state = app_state.clone();
221 async move {
222 let sender_id = message.sender_id;
223 let message_id = message.message_id;
224 let start_time = Instant::now();
225 log::info!(
226 "RPC message received. id: {}.{}, type:{}",
227 sender_id,
228 message_id,
229 M::NAME
230 );
231 if let Err(err) = handler.handle(message, &rpc, &app_state).await {
232 log::error!("error handling message: {:?}", err);
233 } else {
234 log::info!(
235 "RPC message handled. id:{}.{}, duration:{:?}",
236 sender_id,
237 message_id,
238 start_time.elapsed()
239 );
240 }
241
242 Ok(())
243 }
244 });
245}
246
247pub fn add_rpc_routes(router: &mut Router, state: &Arc<AppState>, rpc: &Arc<Peer>) {
248 on_message(router, rpc, state, share_worktree);
249 on_message(router, rpc, state, join_worktree);
250 on_message(router, rpc, state, update_worktree);
251 on_message(router, rpc, state, close_worktree);
252 on_message(router, rpc, state, open_buffer);
253 on_message(router, rpc, state, close_buffer);
254 on_message(router, rpc, state, update_buffer);
255 on_message(router, rpc, state, buffer_saved);
256 on_message(router, rpc, state, save_buffer);
257}
258
259pub fn add_routes(app: &mut tide::Server<Arc<AppState>>, rpc: &Arc<Peer>) {
260 let mut router = Router::new();
261 add_rpc_routes(&mut router, app.state(), rpc);
262 let router = Arc::new(router);
263
264 let rpc = rpc.clone();
265 app.at("/rpc").with(auth::VerifyToken).get(move |request: Request<Arc<AppState>>| {
266 let user_id = request.ext::<UserId>().copied();
267 let rpc = rpc.clone();
268 let router = router.clone();
269 async move {
270 const WEBSOCKET_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
271
272 let connection_upgrade = header_contains_ignore_case(&request, CONNECTION, "upgrade");
273 let upgrade_to_websocket = header_contains_ignore_case(&request, UPGRADE, "websocket");
274 let upgrade_requested = connection_upgrade && upgrade_to_websocket;
275
276 if !upgrade_requested {
277 return Ok(Response::new(StatusCode::UpgradeRequired));
278 }
279
280 let header = match request.header("Sec-Websocket-Key") {
281 Some(h) => h.as_str(),
282 None => return Err(anyhow!("expected sec-websocket-key"))?,
283 };
284
285 let mut response = Response::new(StatusCode::SwitchingProtocols);
286 response.insert_header(UPGRADE, "websocket");
287 response.insert_header(CONNECTION, "Upgrade");
288 let hash = Sha1::new().chain(header).chain(WEBSOCKET_GUID).finalize();
289 response.insert_header("Sec-Websocket-Accept", base64::encode(&hash[..]));
290 response.insert_header("Sec-Websocket-Version", "13");
291
292 let http_res: &mut tide::http::Response = response.as_mut();
293 let upgrade_receiver = http_res.recv_upgrade().await;
294 let addr = request.remote().unwrap_or("unknown").to_string();
295 let state = request.state().clone();
296 let user_id = user_id.ok_or_else(|| anyhow!("user_id is not present on request. ensure auth::VerifyToken middleware is present"))?;
297 task::spawn(async move {
298 if let Some(stream) = upgrade_receiver.await {
299 let stream = WebSocketStream::from_raw_socket(stream, Role::Server, None).await;
300 handle_connection(rpc, router, state, addr, stream, user_id).await;
301 }
302 });
303
304 Ok(response)
305 }
306 });
307}
308
309pub async fn handle_connection<Conn>(
310 rpc: Arc<Peer>,
311 router: Arc<Router>,
312 state: Arc<AppState>,
313 addr: String,
314 stream: Conn,
315 user_id: UserId,
316) where
317 Conn: 'static
318 + futures::Sink<WebSocketMessage, Error = WebSocketError>
319 + futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>
320 + Send
321 + Unpin,
322{
323 log::info!("accepted rpc connection: {:?}", addr);
324 let (connection_id, handle_io, handle_messages) = rpc.add_connection(stream, router).await;
325 state
326 .rpc
327 .write()
328 .await
329 .add_connection(connection_id, user_id);
330
331 let handle_messages = async move {
332 handle_messages.await;
333 Ok(())
334 };
335
336 if let Err(e) = futures::try_join!(handle_messages, handle_io) {
337 log::error!("error handling rpc connection {:?} - {:?}", addr, e);
338 }
339
340 log::info!("closing connection to {:?}", addr);
341 if let Err(e) = rpc.sign_out(connection_id, &state).await {
342 log::error!("error signing out connection {:?} - {:?}", addr, e);
343 }
344}
345
346async fn share_worktree(
347 mut request: TypedEnvelope<proto::ShareWorktree>,
348 rpc: &Arc<Peer>,
349 state: &Arc<AppState>,
350) -> tide::Result<()> {
351 let mut state = state.rpc.write().await;
352 let worktree_id = state.next_worktree_id;
353 state.next_worktree_id += 1;
354 let access_token = random_token();
355 let worktree = request
356 .payload
357 .worktree
358 .as_mut()
359 .ok_or_else(|| anyhow!("missing worktree"))?;
360 let entries = mem::take(&mut worktree.entries)
361 .into_iter()
362 .map(|entry| (entry.id, entry))
363 .collect();
364 state.worktrees.insert(
365 worktree_id,
366 WorktreeState {
367 host_connection_id: Some(request.sender_id),
368 guest_connection_ids: Default::default(),
369 active_replica_ids: Default::default(),
370 access_token: access_token.clone(),
371 root_name: mem::take(&mut worktree.root_name),
372 entries,
373 },
374 );
375
376 rpc.respond(
377 request.receipt(),
378 proto::ShareWorktreeResponse {
379 worktree_id,
380 access_token,
381 },
382 )
383 .await?;
384 Ok(())
385}
386
387async fn join_worktree(
388 request: TypedEnvelope<proto::OpenWorktree>,
389 rpc: &Arc<Peer>,
390 state: &Arc<AppState>,
391) -> tide::Result<()> {
392 let worktree_id = request.payload.worktree_id;
393 let access_token = &request.payload.access_token;
394
395 let mut state = state.rpc.write().await;
396 if let Some((peer_replica_id, worktree)) =
397 state.join_worktree(request.sender_id, worktree_id, access_token)
398 {
399 let mut peers = Vec::new();
400 if let Some(host_connection_id) = worktree.host_connection_id {
401 peers.push(proto::Peer {
402 peer_id: host_connection_id.0,
403 replica_id: 0,
404 });
405 }
406 for (peer_conn_id, peer_replica_id) in &worktree.guest_connection_ids {
407 if *peer_conn_id != request.sender_id {
408 peers.push(proto::Peer {
409 peer_id: peer_conn_id.0,
410 replica_id: *peer_replica_id as u32,
411 });
412 }
413 }
414
415 broadcast(request.sender_id, worktree.connection_ids(), |conn_id| {
416 rpc.send(
417 conn_id,
418 proto::AddPeer {
419 worktree_id,
420 peer: Some(proto::Peer {
421 peer_id: request.sender_id.0,
422 replica_id: peer_replica_id as u32,
423 }),
424 },
425 )
426 })
427 .await?;
428 rpc.respond(
429 request.receipt(),
430 proto::OpenWorktreeResponse {
431 worktree_id,
432 worktree: Some(proto::Worktree {
433 root_name: worktree.root_name.clone(),
434 entries: worktree.entries.values().cloned().collect(),
435 }),
436 replica_id: peer_replica_id as u32,
437 peers,
438 },
439 )
440 .await?;
441 } else {
442 rpc.respond(
443 request.receipt(),
444 proto::OpenWorktreeResponse {
445 worktree_id,
446 worktree: None,
447 replica_id: 0,
448 peers: Vec::new(),
449 },
450 )
451 .await?;
452 }
453
454 Ok(())
455}
456
457async fn update_worktree(
458 request: TypedEnvelope<proto::UpdateWorktree>,
459 rpc: &Arc<Peer>,
460 state: &Arc<AppState>,
461) -> tide::Result<()> {
462 {
463 let mut state = state.rpc.write().await;
464 let worktree = state.write_worktree(request.payload.worktree_id, request.sender_id)?;
465 for entry_id in &request.payload.removed_entries {
466 worktree.entries.remove(&entry_id);
467 }
468
469 for entry in &request.payload.updated_entries {
470 worktree.entries.insert(entry.id, entry.clone());
471 }
472 }
473
474 broadcast_in_worktree(request.payload.worktree_id, request, rpc, state).await?;
475 Ok(())
476}
477
478async fn close_worktree(
479 request: TypedEnvelope<proto::CloseWorktree>,
480 rpc: &Arc<Peer>,
481 state: &Arc<AppState>,
482) -> tide::Result<()> {
483 let connection_ids;
484 {
485 let mut state = state.rpc.write().await;
486 let worktree = state.write_worktree(request.payload.worktree_id, request.sender_id)?;
487 connection_ids = worktree.connection_ids();
488 if worktree.host_connection_id == Some(request.sender_id) {
489 worktree.host_connection_id = None;
490 } else if let Some(replica_id) = worktree.guest_connection_ids.remove(&request.sender_id) {
491 worktree.active_replica_ids.remove(&replica_id);
492 }
493 }
494
495 broadcast(request.sender_id, connection_ids, |conn_id| {
496 rpc.send(
497 conn_id,
498 proto::RemovePeer {
499 worktree_id: request.payload.worktree_id,
500 peer_id: request.sender_id.0,
501 },
502 )
503 })
504 .await?;
505
506 Ok(())
507}
508
509async fn open_buffer(
510 request: TypedEnvelope<proto::OpenBuffer>,
511 rpc: &Arc<Peer>,
512 state: &Arc<AppState>,
513) -> tide::Result<()> {
514 let receipt = request.receipt();
515 let worktree_id = request.payload.worktree_id;
516 let host_connection_id = state
517 .rpc
518 .read()
519 .await
520 .read_worktree(worktree_id, request.sender_id)?
521 .host_connection_id()?;
522
523 let response = rpc
524 .forward_request(request.sender_id, host_connection_id, request.payload)
525 .await?;
526 rpc.respond(receipt, response).await?;
527 Ok(())
528}
529
530async fn close_buffer(
531 request: TypedEnvelope<proto::CloseBuffer>,
532 rpc: &Arc<Peer>,
533 state: &Arc<AppState>,
534) -> tide::Result<()> {
535 let host_connection_id = state
536 .rpc
537 .read()
538 .await
539 .read_worktree(request.payload.worktree_id, request.sender_id)?
540 .host_connection_id()?;
541
542 rpc.forward_send(request.sender_id, host_connection_id, request.payload)
543 .await?;
544
545 Ok(())
546}
547
548async fn save_buffer(
549 request: TypedEnvelope<proto::SaveBuffer>,
550 rpc: &Arc<Peer>,
551 state: &Arc<AppState>,
552) -> tide::Result<()> {
553 let host;
554 let guests;
555 {
556 let state = state.rpc.read().await;
557 let worktree = state.read_worktree(request.payload.worktree_id, request.sender_id)?;
558 host = worktree.host_connection_id()?;
559 guests = worktree
560 .guest_connection_ids
561 .keys()
562 .copied()
563 .collect::<Vec<_>>();
564 }
565
566 let sender = request.sender_id;
567 let receipt = request.receipt();
568 let response = rpc
569 .forward_request(sender, host, request.payload.clone())
570 .await?;
571
572 broadcast(host, guests, |conn_id| {
573 let response = response.clone();
574 async move {
575 if conn_id == sender {
576 rpc.respond(receipt, response).await
577 } else {
578 rpc.forward_send(host, conn_id, response).await
579 }
580 }
581 })
582 .await?;
583
584 Ok(())
585}
586
587async fn update_buffer(
588 request: TypedEnvelope<proto::UpdateBuffer>,
589 rpc: &Arc<Peer>,
590 state: &Arc<AppState>,
591) -> tide::Result<()> {
592 broadcast_in_worktree(request.payload.worktree_id, request, rpc, state).await
593}
594
595async fn buffer_saved(
596 request: TypedEnvelope<proto::BufferSaved>,
597 rpc: &Arc<Peer>,
598 state: &Arc<AppState>,
599) -> tide::Result<()> {
600 broadcast_in_worktree(request.payload.worktree_id, request, rpc, state).await
601}
602
603async fn broadcast_in_worktree<T: proto::EnvelopedMessage>(
604 worktree_id: u64,
605 request: TypedEnvelope<T>,
606 rpc: &Arc<Peer>,
607 state: &Arc<AppState>,
608) -> tide::Result<()> {
609 let connection_ids = state
610 .rpc
611 .read()
612 .await
613 .read_worktree(worktree_id, request.sender_id)?
614 .connection_ids();
615
616 broadcast(request.sender_id, connection_ids, |conn_id| {
617 rpc.forward_send(request.sender_id, conn_id, request.payload.clone())
618 })
619 .await?;
620
621 Ok(())
622}
623
624pub async fn broadcast<F, T>(
625 sender_id: ConnectionId,
626 receiver_ids: Vec<ConnectionId>,
627 mut f: F,
628) -> anyhow::Result<()>
629where
630 F: FnMut(ConnectionId) -> T,
631 T: Future<Output = anyhow::Result<()>>,
632{
633 let futures = receiver_ids
634 .into_iter()
635 .filter(|id| *id != sender_id)
636 .map(|id| f(id));
637 futures::future::try_join_all(futures).await?;
638 Ok(())
639}
640
641fn header_contains_ignore_case<T>(
642 request: &tide::Request<T>,
643 header_name: HeaderName,
644 value: &str,
645) -> bool {
646 request
647 .header(header_name)
648 .map(|h| {
649 h.as_str()
650 .split(',')
651 .any(|s| s.trim().eq_ignore_ascii_case(value.trim()))
652 })
653 .unwrap_or(false)
654}