rpc.rs

  1use super::{
  2    auth::{self, PeerExt as _},
  3    db::UserId,
  4    AppState,
  5};
  6use anyhow::anyhow;
  7use async_std::task;
  8use async_tungstenite::{
  9    tungstenite::{protocol::Role, Error as WebSocketError, Message as WebSocketMessage},
 10    WebSocketStream,
 11};
 12use sha1::{Digest as _, Sha1};
 13use std::{
 14    collections::{HashMap, HashSet},
 15    future::Future,
 16    mem,
 17    sync::Arc,
 18    time::Instant,
 19};
 20use surf::StatusCode;
 21use tide::log;
 22use tide::{
 23    http::headers::{HeaderName, CONNECTION, UPGRADE},
 24    Request, Response,
 25};
 26use zrpc::{
 27    auth::random_token,
 28    proto::{self, EnvelopedMessage},
 29    ConnectionId, Peer, Router, TypedEnvelope,
 30};
 31
 32type ReplicaId = u16;
 33
 34#[derive(Default)]
 35pub struct State {
 36    connections: HashMap<ConnectionId, ConnectionState>,
 37    pub worktrees: HashMap<u64, WorktreeState>,
 38    next_worktree_id: u64,
 39}
 40
 41struct ConnectionState {
 42    _user_id: UserId,
 43    worktrees: HashSet<u64>,
 44}
 45
 46pub struct WorktreeState {
 47    host_connection_id: Option<ConnectionId>,
 48    guest_connection_ids: HashMap<ConnectionId, ReplicaId>,
 49    active_replica_ids: HashSet<ReplicaId>,
 50    access_token: String,
 51    root_name: String,
 52    entries: HashMap<u64, proto::Entry>,
 53}
 54
 55impl WorktreeState {
 56    pub fn connection_ids(&self) -> Vec<ConnectionId> {
 57        self.guest_connection_ids
 58            .keys()
 59            .copied()
 60            .chain(self.host_connection_id)
 61            .collect()
 62    }
 63
 64    fn host_connection_id(&self) -> tide::Result<ConnectionId> {
 65        Ok(self
 66            .host_connection_id
 67            .ok_or_else(|| anyhow!("host disconnected from worktree"))?)
 68    }
 69}
 70
 71impl State {
 72    // Add a new connection associated with a given user.
 73    pub fn add_connection(&mut self, connection_id: ConnectionId, _user_id: UserId) {
 74        self.connections.insert(
 75            connection_id,
 76            ConnectionState {
 77                _user_id,
 78                worktrees: Default::default(),
 79            },
 80        );
 81    }
 82
 83    // Remove the given connection and its association with any worktrees.
 84    pub fn remove_connection(&mut self, connection_id: ConnectionId) -> Vec<u64> {
 85        let mut worktree_ids = Vec::new();
 86        if let Some(connection_state) = self.connections.remove(&connection_id) {
 87            for worktree_id in connection_state.worktrees {
 88                if let Some(worktree) = self.worktrees.get_mut(&worktree_id) {
 89                    if worktree.host_connection_id == Some(connection_id) {
 90                        worktree_ids.push(worktree_id);
 91                    } else if let Some(replica_id) =
 92                        worktree.guest_connection_ids.remove(&connection_id)
 93                    {
 94                        worktree.active_replica_ids.remove(&replica_id);
 95                        worktree_ids.push(worktree_id);
 96                    }
 97                }
 98            }
 99        }
100        worktree_ids
101    }
102
103    // Add the given connection as a guest of the given worktree
104    pub fn join_worktree(
105        &mut self,
106        connection_id: ConnectionId,
107        worktree_id: u64,
108        access_token: &str,
109    ) -> Option<(ReplicaId, &WorktreeState)> {
110        if let Some(worktree_state) = self.worktrees.get_mut(&worktree_id) {
111            if access_token == worktree_state.access_token {
112                if let Some(connection_state) = self.connections.get_mut(&connection_id) {
113                    connection_state.worktrees.insert(worktree_id);
114                }
115
116                let mut replica_id = 1;
117                while worktree_state.active_replica_ids.contains(&replica_id) {
118                    replica_id += 1;
119                }
120                worktree_state.active_replica_ids.insert(replica_id);
121                worktree_state
122                    .guest_connection_ids
123                    .insert(connection_id, replica_id);
124                Some((replica_id, worktree_state))
125            } else {
126                None
127            }
128        } else {
129            None
130        }
131    }
132
133    fn read_worktree(
134        &self,
135        worktree_id: u64,
136        connection_id: ConnectionId,
137    ) -> tide::Result<&WorktreeState> {
138        let worktree = self
139            .worktrees
140            .get(&worktree_id)
141            .ok_or_else(|| anyhow!("worktree not found"))?;
142
143        if worktree.host_connection_id == Some(connection_id)
144            || worktree.guest_connection_ids.contains_key(&connection_id)
145        {
146            Ok(worktree)
147        } else {
148            Err(anyhow!(
149                "{} is not a member of worktree {}",
150                connection_id,
151                worktree_id
152            ))?
153        }
154    }
155
156    fn write_worktree(
157        &mut self,
158        worktree_id: u64,
159        connection_id: ConnectionId,
160    ) -> tide::Result<&mut WorktreeState> {
161        let worktree = self
162            .worktrees
163            .get_mut(&worktree_id)
164            .ok_or_else(|| anyhow!("worktree not found"))?;
165
166        if worktree.host_connection_id == Some(connection_id)
167            || worktree.guest_connection_ids.contains_key(&connection_id)
168        {
169            Ok(worktree)
170        } else {
171            Err(anyhow!(
172                "{} is not a member of worktree {}",
173                connection_id,
174                worktree_id
175            ))?
176        }
177    }
178}
179
180trait MessageHandler<'a, M: proto::EnvelopedMessage> {
181    type Output: 'a + Send + Future<Output = tide::Result<()>>;
182
183    fn handle(
184        &self,
185        message: TypedEnvelope<M>,
186        rpc: &'a Arc<Peer>,
187        app_state: &'a Arc<AppState>,
188    ) -> Self::Output;
189}
190
191impl<'a, M, F, Fut> MessageHandler<'a, M> for F
192where
193    M: proto::EnvelopedMessage,
194    F: Fn(TypedEnvelope<M>, &'a Arc<Peer>, &'a Arc<AppState>) -> Fut,
195    Fut: 'a + Send + Future<Output = tide::Result<()>>,
196{
197    type Output = Fut;
198
199    fn handle(
200        &self,
201        message: TypedEnvelope<M>,
202        rpc: &'a Arc<Peer>,
203        app_state: &'a Arc<AppState>,
204    ) -> Self::Output {
205        (self)(message, rpc, app_state)
206    }
207}
208
209fn on_message<M, H>(router: &mut Router, rpc: &Arc<Peer>, app_state: &Arc<AppState>, handler: H)
210where
211    M: EnvelopedMessage,
212    H: 'static + Clone + Send + Sync + for<'a> MessageHandler<'a, M>,
213{
214    let rpc = rpc.clone();
215    let handler = handler.clone();
216    let app_state = app_state.clone();
217    router.add_message_handler(move |message| {
218        let rpc = rpc.clone();
219        let handler = handler.clone();
220        let app_state = app_state.clone();
221        async move {
222            let sender_id = message.sender_id;
223            let message_id = message.message_id;
224            let start_time = Instant::now();
225            log::info!(
226                "RPC message received. id: {}.{}, type:{}",
227                sender_id,
228                message_id,
229                M::NAME
230            );
231            if let Err(err) = handler.handle(message, &rpc, &app_state).await {
232                log::error!("error handling message: {:?}", err);
233            } else {
234                log::info!(
235                    "RPC message handled. id:{}.{}, duration:{:?}",
236                    sender_id,
237                    message_id,
238                    start_time.elapsed()
239                );
240            }
241
242            Ok(())
243        }
244    });
245}
246
247pub fn add_rpc_routes(router: &mut Router, state: &Arc<AppState>, rpc: &Arc<Peer>) {
248    on_message(router, rpc, state, share_worktree);
249    on_message(router, rpc, state, join_worktree);
250    on_message(router, rpc, state, update_worktree);
251    on_message(router, rpc, state, close_worktree);
252    on_message(router, rpc, state, open_buffer);
253    on_message(router, rpc, state, close_buffer);
254    on_message(router, rpc, state, update_buffer);
255    on_message(router, rpc, state, buffer_saved);
256    on_message(router, rpc, state, save_buffer);
257}
258
259pub fn add_routes(app: &mut tide::Server<Arc<AppState>>, rpc: &Arc<Peer>) {
260    let mut router = Router::new();
261    add_rpc_routes(&mut router, app.state(), rpc);
262    let router = Arc::new(router);
263
264    let rpc = rpc.clone();
265    app.at("/rpc").with(auth::VerifyToken).get(move |request: Request<Arc<AppState>>| {
266        let user_id = request.ext::<UserId>().copied();
267        let rpc = rpc.clone();
268        let router = router.clone();
269        async move {
270            const WEBSOCKET_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
271
272            let connection_upgrade = header_contains_ignore_case(&request, CONNECTION, "upgrade");
273            let upgrade_to_websocket = header_contains_ignore_case(&request, UPGRADE, "websocket");
274            let upgrade_requested = connection_upgrade && upgrade_to_websocket;
275
276            if !upgrade_requested {
277                return Ok(Response::new(StatusCode::UpgradeRequired));
278            }
279
280            let header = match request.header("Sec-Websocket-Key") {
281                Some(h) => h.as_str(),
282                None => return Err(anyhow!("expected sec-websocket-key"))?,
283            };
284
285            let mut response = Response::new(StatusCode::SwitchingProtocols);
286            response.insert_header(UPGRADE, "websocket");
287            response.insert_header(CONNECTION, "Upgrade");
288            let hash = Sha1::new().chain(header).chain(WEBSOCKET_GUID).finalize();
289            response.insert_header("Sec-Websocket-Accept", base64::encode(&hash[..]));
290            response.insert_header("Sec-Websocket-Version", "13");
291
292            let http_res: &mut tide::http::Response = response.as_mut();
293            let upgrade_receiver = http_res.recv_upgrade().await;
294            let addr = request.remote().unwrap_or("unknown").to_string();
295            let state = request.state().clone();
296            let user_id = user_id.ok_or_else(|| anyhow!("user_id is not present on request. ensure auth::VerifyToken middleware is present"))?;
297            task::spawn(async move {
298                if let Some(stream) = upgrade_receiver.await {
299                    let stream = WebSocketStream::from_raw_socket(stream, Role::Server, None).await;
300                    handle_connection(rpc, router, state, addr, stream, user_id).await;
301                }
302            });
303
304            Ok(response)
305        }
306    });
307}
308
309pub async fn handle_connection<Conn>(
310    rpc: Arc<Peer>,
311    router: Arc<Router>,
312    state: Arc<AppState>,
313    addr: String,
314    stream: Conn,
315    user_id: UserId,
316) where
317    Conn: 'static
318        + futures::Sink<WebSocketMessage, Error = WebSocketError>
319        + futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>
320        + Send
321        + Unpin,
322{
323    log::info!("accepted rpc connection: {:?}", addr);
324    let (connection_id, handle_io, handle_messages) = rpc.add_connection(stream, router).await;
325    state
326        .rpc
327        .write()
328        .await
329        .add_connection(connection_id, user_id);
330
331    let handle_messages = async move {
332        handle_messages.await;
333        Ok(())
334    };
335
336    if let Err(e) = futures::try_join!(handle_messages, handle_io) {
337        log::error!("error handling rpc connection {:?} - {:?}", addr, e);
338    }
339
340    log::info!("closing connection to {:?}", addr);
341    if let Err(e) = rpc.sign_out(connection_id, &state).await {
342        log::error!("error signing out connection {:?} - {:?}", addr, e);
343    }
344}
345
346async fn share_worktree(
347    mut request: TypedEnvelope<proto::ShareWorktree>,
348    rpc: &Arc<Peer>,
349    state: &Arc<AppState>,
350) -> tide::Result<()> {
351    let mut state = state.rpc.write().await;
352    let worktree_id = state.next_worktree_id;
353    state.next_worktree_id += 1;
354    let access_token = random_token();
355    let worktree = request
356        .payload
357        .worktree
358        .as_mut()
359        .ok_or_else(|| anyhow!("missing worktree"))?;
360    let entries = mem::take(&mut worktree.entries)
361        .into_iter()
362        .map(|entry| (entry.id, entry))
363        .collect();
364    state.worktrees.insert(
365        worktree_id,
366        WorktreeState {
367            host_connection_id: Some(request.sender_id),
368            guest_connection_ids: Default::default(),
369            active_replica_ids: Default::default(),
370            access_token: access_token.clone(),
371            root_name: mem::take(&mut worktree.root_name),
372            entries,
373        },
374    );
375
376    rpc.respond(
377        request.receipt(),
378        proto::ShareWorktreeResponse {
379            worktree_id,
380            access_token,
381        },
382    )
383    .await?;
384    Ok(())
385}
386
387async fn join_worktree(
388    request: TypedEnvelope<proto::OpenWorktree>,
389    rpc: &Arc<Peer>,
390    state: &Arc<AppState>,
391) -> tide::Result<()> {
392    let worktree_id = request.payload.worktree_id;
393    let access_token = &request.payload.access_token;
394
395    let mut state = state.rpc.write().await;
396    if let Some((peer_replica_id, worktree)) =
397        state.join_worktree(request.sender_id, worktree_id, access_token)
398    {
399        let mut peers = Vec::new();
400        if let Some(host_connection_id) = worktree.host_connection_id {
401            peers.push(proto::Peer {
402                peer_id: host_connection_id.0,
403                replica_id: 0,
404            });
405        }
406        for (peer_conn_id, peer_replica_id) in &worktree.guest_connection_ids {
407            if *peer_conn_id != request.sender_id {
408                peers.push(proto::Peer {
409                    peer_id: peer_conn_id.0,
410                    replica_id: *peer_replica_id as u32,
411                });
412            }
413        }
414
415        broadcast(request.sender_id, worktree.connection_ids(), |conn_id| {
416            rpc.send(
417                conn_id,
418                proto::AddPeer {
419                    worktree_id,
420                    peer: Some(proto::Peer {
421                        peer_id: request.sender_id.0,
422                        replica_id: peer_replica_id as u32,
423                    }),
424                },
425            )
426        })
427        .await?;
428        rpc.respond(
429            request.receipt(),
430            proto::OpenWorktreeResponse {
431                worktree_id,
432                worktree: Some(proto::Worktree {
433                    root_name: worktree.root_name.clone(),
434                    entries: worktree.entries.values().cloned().collect(),
435                }),
436                replica_id: peer_replica_id as u32,
437                peers,
438            },
439        )
440        .await?;
441    } else {
442        rpc.respond(
443            request.receipt(),
444            proto::OpenWorktreeResponse {
445                worktree_id,
446                worktree: None,
447                replica_id: 0,
448                peers: Vec::new(),
449            },
450        )
451        .await?;
452    }
453
454    Ok(())
455}
456
457async fn update_worktree(
458    request: TypedEnvelope<proto::UpdateWorktree>,
459    rpc: &Arc<Peer>,
460    state: &Arc<AppState>,
461) -> tide::Result<()> {
462    {
463        let mut state = state.rpc.write().await;
464        let worktree = state.write_worktree(request.payload.worktree_id, request.sender_id)?;
465        for entry_id in &request.payload.removed_entries {
466            worktree.entries.remove(&entry_id);
467        }
468
469        for entry in &request.payload.updated_entries {
470            worktree.entries.insert(entry.id, entry.clone());
471        }
472    }
473
474    broadcast_in_worktree(request.payload.worktree_id, request, rpc, state).await?;
475    Ok(())
476}
477
478async fn close_worktree(
479    request: TypedEnvelope<proto::CloseWorktree>,
480    rpc: &Arc<Peer>,
481    state: &Arc<AppState>,
482) -> tide::Result<()> {
483    let connection_ids;
484    {
485        let mut state = state.rpc.write().await;
486        let worktree = state.write_worktree(request.payload.worktree_id, request.sender_id)?;
487        connection_ids = worktree.connection_ids();
488        if worktree.host_connection_id == Some(request.sender_id) {
489            worktree.host_connection_id = None;
490        } else if let Some(replica_id) = worktree.guest_connection_ids.remove(&request.sender_id) {
491            worktree.active_replica_ids.remove(&replica_id);
492        }
493    }
494
495    broadcast(request.sender_id, connection_ids, |conn_id| {
496        rpc.send(
497            conn_id,
498            proto::RemovePeer {
499                worktree_id: request.payload.worktree_id,
500                peer_id: request.sender_id.0,
501            },
502        )
503    })
504    .await?;
505
506    Ok(())
507}
508
509async fn open_buffer(
510    request: TypedEnvelope<proto::OpenBuffer>,
511    rpc: &Arc<Peer>,
512    state: &Arc<AppState>,
513) -> tide::Result<()> {
514    let receipt = request.receipt();
515    let worktree_id = request.payload.worktree_id;
516    let host_connection_id = state
517        .rpc
518        .read()
519        .await
520        .read_worktree(worktree_id, request.sender_id)?
521        .host_connection_id()?;
522
523    let response = rpc
524        .forward_request(request.sender_id, host_connection_id, request.payload)
525        .await?;
526    rpc.respond(receipt, response).await?;
527    Ok(())
528}
529
530async fn close_buffer(
531    request: TypedEnvelope<proto::CloseBuffer>,
532    rpc: &Arc<Peer>,
533    state: &Arc<AppState>,
534) -> tide::Result<()> {
535    let host_connection_id = state
536        .rpc
537        .read()
538        .await
539        .read_worktree(request.payload.worktree_id, request.sender_id)?
540        .host_connection_id()?;
541
542    rpc.forward_send(request.sender_id, host_connection_id, request.payload)
543        .await?;
544
545    Ok(())
546}
547
548async fn save_buffer(
549    request: TypedEnvelope<proto::SaveBuffer>,
550    rpc: &Arc<Peer>,
551    state: &Arc<AppState>,
552) -> tide::Result<()> {
553    let host;
554    let guests;
555    {
556        let state = state.rpc.read().await;
557        let worktree = state.read_worktree(request.payload.worktree_id, request.sender_id)?;
558        host = worktree.host_connection_id()?;
559        guests = worktree
560            .guest_connection_ids
561            .keys()
562            .copied()
563            .collect::<Vec<_>>();
564    }
565
566    let sender = request.sender_id;
567    let receipt = request.receipt();
568    let response = rpc
569        .forward_request(sender, host, request.payload.clone())
570        .await?;
571
572    broadcast(host, guests, |conn_id| {
573        let response = response.clone();
574        async move {
575            if conn_id == sender {
576                rpc.respond(receipt, response).await
577            } else {
578                rpc.forward_send(host, conn_id, response).await
579            }
580        }
581    })
582    .await?;
583
584    Ok(())
585}
586
587async fn update_buffer(
588    request: TypedEnvelope<proto::UpdateBuffer>,
589    rpc: &Arc<Peer>,
590    state: &Arc<AppState>,
591) -> tide::Result<()> {
592    broadcast_in_worktree(request.payload.worktree_id, request, rpc, state).await
593}
594
595async fn buffer_saved(
596    request: TypedEnvelope<proto::BufferSaved>,
597    rpc: &Arc<Peer>,
598    state: &Arc<AppState>,
599) -> tide::Result<()> {
600    broadcast_in_worktree(request.payload.worktree_id, request, rpc, state).await
601}
602
603async fn broadcast_in_worktree<T: proto::EnvelopedMessage>(
604    worktree_id: u64,
605    request: TypedEnvelope<T>,
606    rpc: &Arc<Peer>,
607    state: &Arc<AppState>,
608) -> tide::Result<()> {
609    let connection_ids = state
610        .rpc
611        .read()
612        .await
613        .read_worktree(worktree_id, request.sender_id)?
614        .connection_ids();
615
616    broadcast(request.sender_id, connection_ids, |conn_id| {
617        rpc.forward_send(request.sender_id, conn_id, request.payload.clone())
618    })
619    .await?;
620
621    Ok(())
622}
623
624pub async fn broadcast<F, T>(
625    sender_id: ConnectionId,
626    receiver_ids: Vec<ConnectionId>,
627    mut f: F,
628) -> anyhow::Result<()>
629where
630    F: FnMut(ConnectionId) -> T,
631    T: Future<Output = anyhow::Result<()>>,
632{
633    let futures = receiver_ids
634        .into_iter()
635        .filter(|id| *id != sender_id)
636        .map(|id| f(id));
637    futures::future::try_join_all(futures).await?;
638    Ok(())
639}
640
641fn header_contains_ignore_case<T>(
642    request: &tide::Request<T>,
643    header_name: HeaderName,
644    value: &str,
645) -> bool {
646    request
647        .header(header_name)
648        .map(|h| {
649            h.as_str()
650                .split(',')
651                .any(|s| s.trim().eq_ignore_ascii_case(value.trim()))
652        })
653        .unwrap_or(false)
654}