1use super::{
2 auth::{self, PeerExt as _},
3 db::{ChannelId, UserId},
4 AppState,
5};
6use anyhow::anyhow;
7use async_std::task;
8use async_tungstenite::{
9 tungstenite::{protocol::Role, Error as WebSocketError, Message as WebSocketMessage},
10 WebSocketStream,
11};
12use sha1::{Digest as _, Sha1};
13use std::{
14 collections::{HashMap, HashSet},
15 future::Future,
16 mem,
17 sync::Arc,
18 time::Instant,
19};
20use surf::StatusCode;
21use tide::log;
22use tide::{
23 http::headers::{HeaderName, CONNECTION, UPGRADE},
24 Request, Response,
25};
26use zrpc::{
27 auth::random_token,
28 proto::{self, EnvelopedMessage},
29 ConnectionId, Peer, Router, TypedEnvelope,
30};
31
/// Identifier of a participant within a single worktree.
///
/// Guests are assigned ids starting from 1 (the lowest id not already in use); the host is
/// reported to peers with replica id 0.
type ReplicaId = u16;
33
/// Server-side RPC bookkeeping, reached through `AppState::rpc` behind an async lock.
#[derive(Default)]
pub struct State {
    /// Registered connections, keyed by connection id.
    connections: HashMap<ConnectionId, ConnectionState>,
    /// Shared worktrees, keyed by the server-assigned worktree id.
    pub worktrees: HashMap<u64, WorktreeState>,
    /// Id that will be assigned to the next worktree shared by a client.
    next_worktree_id: u64,
}
40
/// Per-connection bookkeeping.
struct ConnectionState {
    /// User the connection was authenticated as.
    user_id: UserId,
    /// Ids of worktrees this connection has been associated with; visited by
    /// `State::remove_connection` when the connection goes away.
    worktrees: HashSet<u64>,
}
45
/// A worktree shared by a host with zero or more guests.
pub struct WorktreeState {
    /// Connection of the sharing host; set to `None` when the host closes the worktree.
    host_connection_id: Option<ConnectionId>,
    /// Guest connections and the replica id each was assigned on joining.
    guest_connection_ids: HashMap<ConnectionId, ReplicaId>,
    /// Replica ids currently held by guests, used to hand out the lowest free id.
    active_replica_ids: HashSet<ReplicaId>,
    /// Token a guest must present to join (generated when the worktree is shared).
    access_token: String,
    /// Name of the worktree's root, as sent by the host.
    root_name: String,
    /// Current file-tree entries, keyed by `proto::Entry::id`.
    entries: HashMap<u64, proto::Entry>,
}
54
55impl WorktreeState {
56 pub fn connection_ids(&self) -> Vec<ConnectionId> {
57 self.guest_connection_ids
58 .keys()
59 .copied()
60 .chain(self.host_connection_id)
61 .collect()
62 }
63
64 fn host_connection_id(&self) -> tide::Result<ConnectionId> {
65 Ok(self
66 .host_connection_id
67 .ok_or_else(|| anyhow!("host disconnected from worktree"))?)
68 }
69}
70
71impl State {
72 // Add a new connection associated with a given user.
73 pub fn add_connection(&mut self, connection_id: ConnectionId, user_id: UserId) {
74 self.connections.insert(
75 connection_id,
76 ConnectionState {
77 user_id,
78 worktrees: Default::default(),
79 },
80 );
81 }
82
83 // Remove the given connection and its association with any worktrees.
84 pub fn remove_connection(&mut self, connection_id: ConnectionId) -> Vec<u64> {
85 let mut worktree_ids = Vec::new();
86 if let Some(connection_state) = self.connections.remove(&connection_id) {
87 for worktree_id in connection_state.worktrees {
88 if let Some(worktree) = self.worktrees.get_mut(&worktree_id) {
89 if worktree.host_connection_id == Some(connection_id) {
90 worktree_ids.push(worktree_id);
91 } else if let Some(replica_id) =
92 worktree.guest_connection_ids.remove(&connection_id)
93 {
94 worktree.active_replica_ids.remove(&replica_id);
95 worktree_ids.push(worktree_id);
96 }
97 }
98 }
99 }
100 worktree_ids
101 }
102
103 // Add the given connection as a guest of the given worktree
104 pub fn join_worktree(
105 &mut self,
106 connection_id: ConnectionId,
107 worktree_id: u64,
108 access_token: &str,
109 ) -> Option<(ReplicaId, &WorktreeState)> {
110 if let Some(worktree_state) = self.worktrees.get_mut(&worktree_id) {
111 if access_token == worktree_state.access_token {
112 if let Some(connection_state) = self.connections.get_mut(&connection_id) {
113 connection_state.worktrees.insert(worktree_id);
114 }
115
116 let mut replica_id = 1;
117 while worktree_state.active_replica_ids.contains(&replica_id) {
118 replica_id += 1;
119 }
120 worktree_state.active_replica_ids.insert(replica_id);
121 worktree_state
122 .guest_connection_ids
123 .insert(connection_id, replica_id);
124 Some((replica_id, worktree_state))
125 } else {
126 None
127 }
128 } else {
129 None
130 }
131 }
132
133 fn user_id_for_connection(&self, connection_id: ConnectionId) -> tide::Result<UserId> {
134 Ok(self
135 .connections
136 .get(&connection_id)
137 .ok_or_else(|| anyhow!("unknown connection"))?
138 .user_id)
139 }
140
141 fn read_worktree(
142 &self,
143 worktree_id: u64,
144 connection_id: ConnectionId,
145 ) -> tide::Result<&WorktreeState> {
146 let worktree = self
147 .worktrees
148 .get(&worktree_id)
149 .ok_or_else(|| anyhow!("worktree not found"))?;
150
151 if worktree.host_connection_id == Some(connection_id)
152 || worktree.guest_connection_ids.contains_key(&connection_id)
153 {
154 Ok(worktree)
155 } else {
156 Err(anyhow!(
157 "{} is not a member of worktree {}",
158 connection_id,
159 worktree_id
160 ))?
161 }
162 }
163
164 fn write_worktree(
165 &mut self,
166 worktree_id: u64,
167 connection_id: ConnectionId,
168 ) -> tide::Result<&mut WorktreeState> {
169 let worktree = self
170 .worktrees
171 .get_mut(&worktree_id)
172 .ok_or_else(|| anyhow!("worktree not found"))?;
173
174 if worktree.host_connection_id == Some(connection_id)
175 || worktree.guest_connection_ids.contains_key(&connection_id)
176 {
177 Ok(worktree)
178 } else {
179 Err(anyhow!(
180 "{} is not a member of worktree {}",
181 connection_id,
182 worktree_id
183 ))?
184 }
185 }
186}
187
/// Abstraction over async RPC handler functions for messages of type `M`.
///
/// The lifetime `'a` ties the returned future to the borrowed `Peer` and `AppState`
/// arguments, so a handler can be a plain `async fn` taking `&Arc<Peer>` / `&Arc<AppState>`.
/// `on_message` requires `for<'a> MessageHandler<'a, M>`, i.e. the handler must work for any
/// such borrow.
trait MessageHandler<'a, M: proto::EnvelopedMessage> {
    /// Future produced by the handler; it may borrow from the `'a` arguments.
    type Output: 'a + Send + Future<Output = tide::Result<()>>;

    /// Starts handling `message`, returning the future that performs the work.
    fn handle(
        &self,
        message: TypedEnvelope<M>,
        rpc: &'a Arc<Peer>,
        app_state: &'a Arc<AppState>,
    ) -> Self::Output;
}
198
/// Blanket implementation: any function or closure with the matching signature, such as the
/// `async fn` handlers in this file, is a `MessageHandler`.
impl<'a, M, F, Fut> MessageHandler<'a, M> for F
where
    M: proto::EnvelopedMessage,
    F: Fn(TypedEnvelope<M>, &'a Arc<Peer>, &'a Arc<AppState>) -> Fut,
    Fut: 'a + Send + Future<Output = tide::Result<()>>,
{
    type Output = Fut;

    /// Calls the wrapped function directly.
    fn handle(
        &self,
        message: TypedEnvelope<M>,
        rpc: &'a Arc<Peer>,
        app_state: &'a Arc<AppState>,
    ) -> Self::Output {
        (self)(message, rpc, app_state)
    }
}
216
217fn on_message<M, H>(router: &mut Router, rpc: &Arc<Peer>, app_state: &Arc<AppState>, handler: H)
218where
219 M: EnvelopedMessage,
220 H: 'static + Clone + Send + Sync + for<'a> MessageHandler<'a, M>,
221{
222 let rpc = rpc.clone();
223 let handler = handler.clone();
224 let app_state = app_state.clone();
225 router.add_message_handler(move |message| {
226 let rpc = rpc.clone();
227 let handler = handler.clone();
228 let app_state = app_state.clone();
229 async move {
230 let sender_id = message.sender_id;
231 let message_id = message.message_id;
232 let start_time = Instant::now();
233 log::info!(
234 "RPC message received. id: {}.{}, type:{}",
235 sender_id,
236 message_id,
237 M::NAME
238 );
239 if let Err(err) = handler.handle(message, &rpc, &app_state).await {
240 log::error!("error handling message: {:?}", err);
241 } else {
242 log::info!(
243 "RPC message handled. id:{}.{}, duration:{:?}",
244 sender_id,
245 message_id,
246 start_time.elapsed()
247 );
248 }
249
250 Ok(())
251 }
252 });
253}
254
255pub fn add_rpc_routes(router: &mut Router, state: &Arc<AppState>, rpc: &Arc<Peer>) {
256 on_message(router, rpc, state, share_worktree);
257 on_message(router, rpc, state, join_worktree);
258 on_message(router, rpc, state, update_worktree);
259 on_message(router, rpc, state, close_worktree);
260 on_message(router, rpc, state, open_buffer);
261 on_message(router, rpc, state, close_buffer);
262 on_message(router, rpc, state, update_buffer);
263 on_message(router, rpc, state, buffer_saved);
264 on_message(router, rpc, state, save_buffer);
265 on_message(router, rpc, state, get_channels);
266 on_message(router, rpc, state, join_channel);
267}
268
269pub fn add_routes(app: &mut tide::Server<Arc<AppState>>, rpc: &Arc<Peer>) {
270 let mut router = Router::new();
271 add_rpc_routes(&mut router, app.state(), rpc);
272 let router = Arc::new(router);
273
274 let rpc = rpc.clone();
275 app.at("/rpc").with(auth::VerifyToken).get(move |request: Request<Arc<AppState>>| {
276 let user_id = request.ext::<UserId>().copied();
277 let rpc = rpc.clone();
278 let router = router.clone();
279 async move {
280 const WEBSOCKET_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
281
282 let connection_upgrade = header_contains_ignore_case(&request, CONNECTION, "upgrade");
283 let upgrade_to_websocket = header_contains_ignore_case(&request, UPGRADE, "websocket");
284 let upgrade_requested = connection_upgrade && upgrade_to_websocket;
285
286 if !upgrade_requested {
287 return Ok(Response::new(StatusCode::UpgradeRequired));
288 }
289
290 let header = match request.header("Sec-Websocket-Key") {
291 Some(h) => h.as_str(),
292 None => return Err(anyhow!("expected sec-websocket-key"))?,
293 };
294
295 let mut response = Response::new(StatusCode::SwitchingProtocols);
296 response.insert_header(UPGRADE, "websocket");
297 response.insert_header(CONNECTION, "Upgrade");
298 let hash = Sha1::new().chain(header).chain(WEBSOCKET_GUID).finalize();
299 response.insert_header("Sec-Websocket-Accept", base64::encode(&hash[..]));
300 response.insert_header("Sec-Websocket-Version", "13");
301
302 let http_res: &mut tide::http::Response = response.as_mut();
303 let upgrade_receiver = http_res.recv_upgrade().await;
304 let addr = request.remote().unwrap_or("unknown").to_string();
305 let state = request.state().clone();
306 let user_id = user_id.ok_or_else(|| anyhow!("user_id is not present on request. ensure auth::VerifyToken middleware is present"))?;
307 task::spawn(async move {
308 if let Some(stream) = upgrade_receiver.await {
309 let stream = WebSocketStream::from_raw_socket(stream, Role::Server, None).await;
310 handle_connection(rpc, router, state, addr, stream, user_id).await;
311 }
312 });
313
314 Ok(response)
315 }
316 });
317}
318
319pub async fn handle_connection<Conn>(
320 rpc: Arc<Peer>,
321 router: Arc<Router>,
322 state: Arc<AppState>,
323 addr: String,
324 stream: Conn,
325 user_id: UserId,
326) where
327 Conn: 'static
328 + futures::Sink<WebSocketMessage, Error = WebSocketError>
329 + futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>
330 + Send
331 + Unpin,
332{
333 log::info!("accepted rpc connection: {:?}", addr);
334 let (connection_id, handle_io, handle_messages) = rpc.add_connection(stream, router).await;
335 state
336 .rpc
337 .write()
338 .await
339 .add_connection(connection_id, user_id);
340
341 let handle_messages = async move {
342 handle_messages.await;
343 Ok(())
344 };
345
346 if let Err(e) = futures::try_join!(handle_messages, handle_io) {
347 log::error!("error handling rpc connection {:?} - {:?}", addr, e);
348 }
349
350 log::info!("closing connection to {:?}", addr);
351 if let Err(e) = rpc.sign_out(connection_id, &state).await {
352 log::error!("error signing out connection {:?} - {:?}", addr, e);
353 }
354}
355
356async fn share_worktree(
357 mut request: TypedEnvelope<proto::ShareWorktree>,
358 rpc: &Arc<Peer>,
359 state: &Arc<AppState>,
360) -> tide::Result<()> {
361 let mut state = state.rpc.write().await;
362 let worktree_id = state.next_worktree_id;
363 state.next_worktree_id += 1;
364 let access_token = random_token();
365 let worktree = request
366 .payload
367 .worktree
368 .as_mut()
369 .ok_or_else(|| anyhow!("missing worktree"))?;
370 let entries = mem::take(&mut worktree.entries)
371 .into_iter()
372 .map(|entry| (entry.id, entry))
373 .collect();
374 state.worktrees.insert(
375 worktree_id,
376 WorktreeState {
377 host_connection_id: Some(request.sender_id),
378 guest_connection_ids: Default::default(),
379 active_replica_ids: Default::default(),
380 access_token: access_token.clone(),
381 root_name: mem::take(&mut worktree.root_name),
382 entries,
383 },
384 );
385
386 rpc.respond(
387 request.receipt(),
388 proto::ShareWorktreeResponse {
389 worktree_id,
390 access_token,
391 },
392 )
393 .await?;
394 Ok(())
395}
396
397async fn join_worktree(
398 request: TypedEnvelope<proto::OpenWorktree>,
399 rpc: &Arc<Peer>,
400 state: &Arc<AppState>,
401) -> tide::Result<()> {
402 let worktree_id = request.payload.worktree_id;
403 let access_token = &request.payload.access_token;
404
405 let mut state = state.rpc.write().await;
406 if let Some((peer_replica_id, worktree)) =
407 state.join_worktree(request.sender_id, worktree_id, access_token)
408 {
409 let mut peers = Vec::new();
410 if let Some(host_connection_id) = worktree.host_connection_id {
411 peers.push(proto::Peer {
412 peer_id: host_connection_id.0,
413 replica_id: 0,
414 });
415 }
416 for (peer_conn_id, peer_replica_id) in &worktree.guest_connection_ids {
417 if *peer_conn_id != request.sender_id {
418 peers.push(proto::Peer {
419 peer_id: peer_conn_id.0,
420 replica_id: *peer_replica_id as u32,
421 });
422 }
423 }
424
425 broadcast(request.sender_id, worktree.connection_ids(), |conn_id| {
426 rpc.send(
427 conn_id,
428 proto::AddPeer {
429 worktree_id,
430 peer: Some(proto::Peer {
431 peer_id: request.sender_id.0,
432 replica_id: peer_replica_id as u32,
433 }),
434 },
435 )
436 })
437 .await?;
438 rpc.respond(
439 request.receipt(),
440 proto::OpenWorktreeResponse {
441 worktree_id,
442 worktree: Some(proto::Worktree {
443 root_name: worktree.root_name.clone(),
444 entries: worktree.entries.values().cloned().collect(),
445 }),
446 replica_id: peer_replica_id as u32,
447 peers,
448 },
449 )
450 .await?;
451 } else {
452 rpc.respond(
453 request.receipt(),
454 proto::OpenWorktreeResponse {
455 worktree_id,
456 worktree: None,
457 replica_id: 0,
458 peers: Vec::new(),
459 },
460 )
461 .await?;
462 }
463
464 Ok(())
465}
466
467async fn update_worktree(
468 request: TypedEnvelope<proto::UpdateWorktree>,
469 rpc: &Arc<Peer>,
470 state: &Arc<AppState>,
471) -> tide::Result<()> {
472 {
473 let mut state = state.rpc.write().await;
474 let worktree = state.write_worktree(request.payload.worktree_id, request.sender_id)?;
475 for entry_id in &request.payload.removed_entries {
476 worktree.entries.remove(&entry_id);
477 }
478
479 for entry in &request.payload.updated_entries {
480 worktree.entries.insert(entry.id, entry.clone());
481 }
482 }
483
484 broadcast_in_worktree(request.payload.worktree_id, request, rpc, state).await?;
485 Ok(())
486}
487
488async fn close_worktree(
489 request: TypedEnvelope<proto::CloseWorktree>,
490 rpc: &Arc<Peer>,
491 state: &Arc<AppState>,
492) -> tide::Result<()> {
493 let connection_ids;
494 {
495 let mut state = state.rpc.write().await;
496 let worktree = state.write_worktree(request.payload.worktree_id, request.sender_id)?;
497 connection_ids = worktree.connection_ids();
498 if worktree.host_connection_id == Some(request.sender_id) {
499 worktree.host_connection_id = None;
500 } else if let Some(replica_id) = worktree.guest_connection_ids.remove(&request.sender_id) {
501 worktree.active_replica_ids.remove(&replica_id);
502 }
503 }
504
505 broadcast(request.sender_id, connection_ids, |conn_id| {
506 rpc.send(
507 conn_id,
508 proto::RemovePeer {
509 worktree_id: request.payload.worktree_id,
510 peer_id: request.sender_id.0,
511 },
512 )
513 })
514 .await?;
515
516 Ok(())
517}
518
519async fn open_buffer(
520 request: TypedEnvelope<proto::OpenBuffer>,
521 rpc: &Arc<Peer>,
522 state: &Arc<AppState>,
523) -> tide::Result<()> {
524 let receipt = request.receipt();
525 let worktree_id = request.payload.worktree_id;
526 let host_connection_id = state
527 .rpc
528 .read()
529 .await
530 .read_worktree(worktree_id, request.sender_id)?
531 .host_connection_id()?;
532
533 let response = rpc
534 .forward_request(request.sender_id, host_connection_id, request.payload)
535 .await?;
536 rpc.respond(receipt, response).await?;
537 Ok(())
538}
539
540async fn close_buffer(
541 request: TypedEnvelope<proto::CloseBuffer>,
542 rpc: &Arc<Peer>,
543 state: &Arc<AppState>,
544) -> tide::Result<()> {
545 let host_connection_id = state
546 .rpc
547 .read()
548 .await
549 .read_worktree(request.payload.worktree_id, request.sender_id)?
550 .host_connection_id()?;
551
552 rpc.forward_send(request.sender_id, host_connection_id, request.payload)
553 .await?;
554
555 Ok(())
556}
557
558async fn save_buffer(
559 request: TypedEnvelope<proto::SaveBuffer>,
560 rpc: &Arc<Peer>,
561 state: &Arc<AppState>,
562) -> tide::Result<()> {
563 let host;
564 let guests;
565 {
566 let state = state.rpc.read().await;
567 let worktree = state.read_worktree(request.payload.worktree_id, request.sender_id)?;
568 host = worktree.host_connection_id()?;
569 guests = worktree
570 .guest_connection_ids
571 .keys()
572 .copied()
573 .collect::<Vec<_>>();
574 }
575
576 let sender = request.sender_id;
577 let receipt = request.receipt();
578 let response = rpc
579 .forward_request(sender, host, request.payload.clone())
580 .await?;
581
582 broadcast(host, guests, |conn_id| {
583 let response = response.clone();
584 async move {
585 if conn_id == sender {
586 rpc.respond(receipt, response).await
587 } else {
588 rpc.forward_send(host, conn_id, response).await
589 }
590 }
591 })
592 .await?;
593
594 Ok(())
595}
596
597async fn update_buffer(
598 request: TypedEnvelope<proto::UpdateBuffer>,
599 rpc: &Arc<Peer>,
600 state: &Arc<AppState>,
601) -> tide::Result<()> {
602 broadcast_in_worktree(request.payload.worktree_id, request, rpc, state).await
603}
604
605async fn buffer_saved(
606 request: TypedEnvelope<proto::BufferSaved>,
607 rpc: &Arc<Peer>,
608 state: &Arc<AppState>,
609) -> tide::Result<()> {
610 broadcast_in_worktree(request.payload.worktree_id, request, rpc, state).await
611}
612
613async fn get_channels(
614 request: TypedEnvelope<proto::GetChannels>,
615 rpc: &Arc<Peer>,
616 state: &Arc<AppState>,
617) -> tide::Result<()> {
618 let user_id = state
619 .rpc
620 .read()
621 .await
622 .user_id_for_connection(request.sender_id)?;
623 let channels = state.db.get_channels_for_user(user_id).await?;
624 rpc.respond(
625 request.receipt(),
626 proto::GetChannelsResponse {
627 channels: channels
628 .into_iter()
629 .map(|chan| proto::Channel {
630 id: chan.id().0 as u64,
631 name: chan.name,
632 })
633 .collect(),
634 },
635 )
636 .await?;
637 Ok(())
638}
639
640async fn join_channel(
641 request: TypedEnvelope<proto::JoinChannel>,
642 rpc: &Arc<Peer>,
643 state: &Arc<AppState>,
644) -> tide::Result<()> {
645 let user_id = state
646 .rpc
647 .read()
648 .await
649 .user_id_for_connection(request.sender_id)?;
650 if !state
651 .db
652 .can_user_access_channel(user_id, ChannelId(request.payload.channel_id as i32))
653 .await?
654 {
655 Err(anyhow!("access denied"))?;
656 }
657
658 Ok(())
659}
660
661async fn broadcast_in_worktree<T: proto::EnvelopedMessage>(
662 worktree_id: u64,
663 request: TypedEnvelope<T>,
664 rpc: &Arc<Peer>,
665 state: &Arc<AppState>,
666) -> tide::Result<()> {
667 let connection_ids = state
668 .rpc
669 .read()
670 .await
671 .read_worktree(worktree_id, request.sender_id)?
672 .connection_ids();
673
674 broadcast(request.sender_id, connection_ids, |conn_id| {
675 rpc.forward_send(request.sender_id, conn_id, request.payload.clone())
676 })
677 .await?;
678
679 Ok(())
680}
681
682pub async fn broadcast<F, T>(
683 sender_id: ConnectionId,
684 receiver_ids: Vec<ConnectionId>,
685 mut f: F,
686) -> anyhow::Result<()>
687where
688 F: FnMut(ConnectionId) -> T,
689 T: Future<Output = anyhow::Result<()>>,
690{
691 let futures = receiver_ids
692 .into_iter()
693 .filter(|id| *id != sender_id)
694 .map(|id| f(id));
695 futures::future::try_join_all(futures).await?;
696 Ok(())
697}
698
699fn header_contains_ignore_case<T>(
700 request: &tide::Request<T>,
701 header_name: HeaderName,
702 value: &str,
703) -> bool {
704 request
705 .header(header_name)
706 .map(|h| {
707 h.as_str()
708 .split(',')
709 .any(|s| s.trim().eq_ignore_ascii_case(value.trim()))
710 })
711 .unwrap_or(false)
712}