rpc.rs

  1use crate::auth::{self, UserId};
  2
  3use super::{auth::PeerExt as _, AppState};
  4use anyhow::anyhow;
  5use async_std::task;
  6use async_tungstenite::{
  7    tungstenite::{protocol::Role, Error as WebSocketError, Message as WebSocketMessage},
  8    WebSocketStream,
  9};
 10use sha1::{Digest as _, Sha1};
 11use std::{
 12    collections::{HashMap, HashSet},
 13    future::Future,
 14    mem,
 15    sync::Arc,
 16    time::Instant,
 17};
 18use surf::StatusCode;
 19use tide::log;
 20use tide::{
 21    http::headers::{HeaderName, CONNECTION, UPGRADE},
 22    Request, Response,
 23};
 24use zrpc::{
 25    auth::random_token,
 26    proto::{self, EnvelopedMessage},
 27    ConnectionId, Peer, Router, TypedEnvelope,
 28};
 29
 30type ReplicaId = u16;
 31
 32#[derive(Default)]
 33pub struct State {
 34    connections: HashMap<ConnectionId, ConnectionState>,
 35    pub worktrees: HashMap<u64, WorktreeState>,
 36    next_worktree_id: u64,
 37}
 38
 39struct ConnectionState {
 40    _user_id: i32,
 41    worktrees: HashSet<u64>,
 42}
 43
 44pub struct WorktreeState {
 45    host_connection_id: Option<ConnectionId>,
 46    guest_connection_ids: HashMap<ConnectionId, ReplicaId>,
 47    active_replica_ids: HashSet<ReplicaId>,
 48    access_token: String,
 49    root_name: String,
 50    entries: HashMap<u64, proto::Entry>,
 51}
 52
 53impl WorktreeState {
 54    pub fn connection_ids(&self) -> Vec<ConnectionId> {
 55        self.guest_connection_ids
 56            .keys()
 57            .copied()
 58            .chain(self.host_connection_id)
 59            .collect()
 60    }
 61
 62    fn host_connection_id(&self) -> tide::Result<ConnectionId> {
 63        Ok(self
 64            .host_connection_id
 65            .ok_or_else(|| anyhow!("host disconnected from worktree"))?)
 66    }
 67}
 68
 69impl State {
 70    // Add a new connection associated with a given user.
 71    pub fn add_connection(&mut self, connection_id: ConnectionId, _user_id: i32) {
 72        self.connections.insert(
 73            connection_id,
 74            ConnectionState {
 75                _user_id,
 76                worktrees: Default::default(),
 77            },
 78        );
 79    }
 80
 81    // Remove the given connection and its association with any worktrees.
 82    pub fn remove_connection(&mut self, connection_id: ConnectionId) -> Vec<u64> {
 83        let mut worktree_ids = Vec::new();
 84        if let Some(connection_state) = self.connections.remove(&connection_id) {
 85            for worktree_id in connection_state.worktrees {
 86                if let Some(worktree) = self.worktrees.get_mut(&worktree_id) {
 87                    if worktree.host_connection_id == Some(connection_id) {
 88                        worktree_ids.push(worktree_id);
 89                    } else if let Some(replica_id) =
 90                        worktree.guest_connection_ids.remove(&connection_id)
 91                    {
 92                        worktree.active_replica_ids.remove(&replica_id);
 93                        worktree_ids.push(worktree_id);
 94                    }
 95                }
 96            }
 97        }
 98        worktree_ids
 99    }
100
101    // Add the given connection as a guest of the given worktree
102    pub fn join_worktree(
103        &mut self,
104        connection_id: ConnectionId,
105        worktree_id: u64,
106        access_token: &str,
107    ) -> Option<(ReplicaId, &WorktreeState)> {
108        if let Some(worktree_state) = self.worktrees.get_mut(&worktree_id) {
109            if access_token == worktree_state.access_token {
110                if let Some(connection_state) = self.connections.get_mut(&connection_id) {
111                    connection_state.worktrees.insert(worktree_id);
112                }
113
114                let mut replica_id = 1;
115                while worktree_state.active_replica_ids.contains(&replica_id) {
116                    replica_id += 1;
117                }
118                worktree_state.active_replica_ids.insert(replica_id);
119                worktree_state
120                    .guest_connection_ids
121                    .insert(connection_id, replica_id);
122                Some((replica_id, worktree_state))
123            } else {
124                None
125            }
126        } else {
127            None
128        }
129    }
130
131    fn read_worktree(
132        &self,
133        worktree_id: u64,
134        connection_id: ConnectionId,
135    ) -> tide::Result<&WorktreeState> {
136        let worktree = self
137            .worktrees
138            .get(&worktree_id)
139            .ok_or_else(|| anyhow!("worktree not found"))?;
140
141        if worktree.host_connection_id == Some(connection_id)
142            || worktree.guest_connection_ids.contains_key(&connection_id)
143        {
144            Ok(worktree)
145        } else {
146            Err(anyhow!(
147                "{} is not a member of worktree {}",
148                connection_id,
149                worktree_id
150            ))?
151        }
152    }
153
154    fn write_worktree(
155        &mut self,
156        worktree_id: u64,
157        connection_id: ConnectionId,
158    ) -> tide::Result<&mut WorktreeState> {
159        let worktree = self
160            .worktrees
161            .get_mut(&worktree_id)
162            .ok_or_else(|| anyhow!("worktree not found"))?;
163
164        if worktree.host_connection_id == Some(connection_id)
165            || worktree.guest_connection_ids.contains_key(&connection_id)
166        {
167            Ok(worktree)
168        } else {
169            Err(anyhow!(
170                "{} is not a member of worktree {}",
171                connection_id,
172                worktree_id
173            ))?
174        }
175    }
176}
177
178trait MessageHandler<'a, M: proto::EnvelopedMessage> {
179    type Output: 'a + Send + Future<Output = tide::Result<()>>;
180
181    fn handle(
182        &self,
183        message: TypedEnvelope<M>,
184        rpc: &'a Arc<Peer>,
185        app_state: &'a Arc<AppState>,
186    ) -> Self::Output;
187}
188
189impl<'a, M, F, Fut> MessageHandler<'a, M> for F
190where
191    M: proto::EnvelopedMessage,
192    F: Fn(TypedEnvelope<M>, &'a Arc<Peer>, &'a Arc<AppState>) -> Fut,
193    Fut: 'a + Send + Future<Output = tide::Result<()>>,
194{
195    type Output = Fut;
196
197    fn handle(
198        &self,
199        message: TypedEnvelope<M>,
200        rpc: &'a Arc<Peer>,
201        app_state: &'a Arc<AppState>,
202    ) -> Self::Output {
203        (self)(message, rpc, app_state)
204    }
205}
206
207fn on_message<M, H>(router: &mut Router, rpc: &Arc<Peer>, app_state: &Arc<AppState>, handler: H)
208where
209    M: EnvelopedMessage,
210    H: 'static + Clone + Send + Sync + for<'a> MessageHandler<'a, M>,
211{
212    let rpc = rpc.clone();
213    let handler = handler.clone();
214    let app_state = app_state.clone();
215    router.add_message_handler(move |message| {
216        let rpc = rpc.clone();
217        let handler = handler.clone();
218        let app_state = app_state.clone();
219        async move {
220            let sender_id = message.sender_id;
221            let message_id = message.message_id;
222            let start_time = Instant::now();
223            log::info!(
224                "RPC message received. id: {}.{}, type:{}",
225                sender_id,
226                message_id,
227                M::NAME
228            );
229            if let Err(err) = handler.handle(message, &rpc, &app_state).await {
230                log::error!("error handling message: {:?}", err);
231            } else {
232                log::info!(
233                    "RPC message handled. id:{}.{}, duration:{:?}",
234                    sender_id,
235                    message_id,
236                    start_time.elapsed()
237                );
238            }
239
240            Ok(())
241        }
242    });
243}
244
245pub fn add_rpc_routes(router: &mut Router, state: &Arc<AppState>, rpc: &Arc<Peer>) {
246    on_message(router, rpc, state, share_worktree);
247    on_message(router, rpc, state, join_worktree);
248    on_message(router, rpc, state, update_worktree);
249    on_message(router, rpc, state, close_worktree);
250    on_message(router, rpc, state, open_buffer);
251    on_message(router, rpc, state, close_buffer);
252    on_message(router, rpc, state, update_buffer);
253    on_message(router, rpc, state, buffer_saved);
254    on_message(router, rpc, state, save_buffer);
255}
256
257pub fn add_routes(app: &mut tide::Server<Arc<AppState>>, rpc: &Arc<Peer>) {
258    let mut router = Router::new();
259    add_rpc_routes(&mut router, app.state(), rpc);
260    let router = Arc::new(router);
261
262    let rpc = rpc.clone();
263    app.at("/rpc").with(auth::VerifyToken).get(move |request: Request<Arc<AppState>>| {
264        let user_id = request.ext::<UserId>().copied();
265        let rpc = rpc.clone();
266        let router = router.clone();
267        async move {
268            const WEBSOCKET_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
269
270            let connection_upgrade = header_contains_ignore_case(&request, CONNECTION, "upgrade");
271            let upgrade_to_websocket = header_contains_ignore_case(&request, UPGRADE, "websocket");
272            let upgrade_requested = connection_upgrade && upgrade_to_websocket;
273
274            if !upgrade_requested {
275                return Ok(Response::new(StatusCode::UpgradeRequired));
276            }
277
278            let header = match request.header("Sec-Websocket-Key") {
279                Some(h) => h.as_str(),
280                None => return Err(anyhow!("expected sec-websocket-key"))?,
281            };
282
283            let mut response = Response::new(StatusCode::SwitchingProtocols);
284            response.insert_header(UPGRADE, "websocket");
285            response.insert_header(CONNECTION, "Upgrade");
286            let hash = Sha1::new().chain(header).chain(WEBSOCKET_GUID).finalize();
287            response.insert_header("Sec-Websocket-Accept", base64::encode(&hash[..]));
288            response.insert_header("Sec-Websocket-Version", "13");
289
290            let http_res: &mut tide::http::Response = response.as_mut();
291            let upgrade_receiver = http_res.recv_upgrade().await;
292            let addr = request.remote().unwrap_or("unknown").to_string();
293            let state = request.state().clone();
294            let user_id = user_id.ok_or_else(|| anyhow!("user_id is not present on request. ensure auth::VerifyToken middleware is present"))?.0;
295            task::spawn(async move {
296                if let Some(stream) = upgrade_receiver.await {
297                    let stream = WebSocketStream::from_raw_socket(stream, Role::Server, None).await;
298                    handle_connection(rpc, router, state, addr, stream, user_id).await;
299                }
300            });
301
302            Ok(response)
303        }
304    });
305}
306
307pub async fn handle_connection<Conn>(
308    rpc: Arc<Peer>,
309    router: Arc<Router>,
310    state: Arc<AppState>,
311    addr: String,
312    stream: Conn,
313    user_id: i32,
314) where
315    Conn: 'static
316        + futures::Sink<WebSocketMessage, Error = WebSocketError>
317        + futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>
318        + Send
319        + Unpin,
320{
321    log::info!("accepted rpc connection: {:?}", addr);
322    let (connection_id, handle_io, handle_messages) = rpc.add_connection(stream, router).await;
323    state
324        .rpc
325        .write()
326        .await
327        .add_connection(connection_id, user_id);
328
329    let handle_messages = async move {
330        handle_messages.await;
331        Ok(())
332    };
333
334    if let Err(e) = futures::try_join!(handle_messages, handle_io) {
335        log::error!("error handling rpc connection {:?} - {:?}", addr, e);
336    }
337
338    log::info!("closing connection to {:?}", addr);
339    if let Err(e) = rpc.sign_out(connection_id, &state).await {
340        log::error!("error signing out connection {:?} - {:?}", addr, e);
341    }
342}
343
344async fn share_worktree(
345    mut request: TypedEnvelope<proto::ShareWorktree>,
346    rpc: &Arc<Peer>,
347    state: &Arc<AppState>,
348) -> tide::Result<()> {
349    let mut state = state.rpc.write().await;
350    let worktree_id = state.next_worktree_id;
351    state.next_worktree_id += 1;
352    let access_token = random_token();
353    let worktree = request
354        .payload
355        .worktree
356        .as_mut()
357        .ok_or_else(|| anyhow!("missing worktree"))?;
358    let entries = mem::take(&mut worktree.entries)
359        .into_iter()
360        .map(|entry| (entry.id, entry))
361        .collect();
362    state.worktrees.insert(
363        worktree_id,
364        WorktreeState {
365            host_connection_id: Some(request.sender_id),
366            guest_connection_ids: Default::default(),
367            active_replica_ids: Default::default(),
368            access_token: access_token.clone(),
369            root_name: mem::take(&mut worktree.root_name),
370            entries,
371        },
372    );
373
374    rpc.respond(
375        request.receipt(),
376        proto::ShareWorktreeResponse {
377            worktree_id,
378            access_token,
379        },
380    )
381    .await?;
382    Ok(())
383}
384
385async fn join_worktree(
386    request: TypedEnvelope<proto::OpenWorktree>,
387    rpc: &Arc<Peer>,
388    state: &Arc<AppState>,
389) -> tide::Result<()> {
390    let worktree_id = request.payload.worktree_id;
391    let access_token = &request.payload.access_token;
392
393    let mut state = state.rpc.write().await;
394    if let Some((peer_replica_id, worktree)) =
395        state.join_worktree(request.sender_id, worktree_id, access_token)
396    {
397        let mut peers = Vec::new();
398        if let Some(host_connection_id) = worktree.host_connection_id {
399            peers.push(proto::Peer {
400                peer_id: host_connection_id.0,
401                replica_id: 0,
402            });
403        }
404        for (peer_conn_id, peer_replica_id) in &worktree.guest_connection_ids {
405            if *peer_conn_id != request.sender_id {
406                peers.push(proto::Peer {
407                    peer_id: peer_conn_id.0,
408                    replica_id: *peer_replica_id as u32,
409                });
410            }
411        }
412
413        broadcast(request.sender_id, worktree.connection_ids(), |conn_id| {
414            rpc.send(
415                conn_id,
416                proto::AddPeer {
417                    worktree_id,
418                    peer: Some(proto::Peer {
419                        peer_id: request.sender_id.0,
420                        replica_id: peer_replica_id as u32,
421                    }),
422                },
423            )
424        })
425        .await?;
426        rpc.respond(
427            request.receipt(),
428            proto::OpenWorktreeResponse {
429                worktree_id,
430                worktree: Some(proto::Worktree {
431                    root_name: worktree.root_name.clone(),
432                    entries: worktree.entries.values().cloned().collect(),
433                }),
434                replica_id: peer_replica_id as u32,
435                peers,
436            },
437        )
438        .await?;
439    } else {
440        rpc.respond(
441            request.receipt(),
442            proto::OpenWorktreeResponse {
443                worktree_id,
444                worktree: None,
445                replica_id: 0,
446                peers: Vec::new(),
447            },
448        )
449        .await?;
450    }
451
452    Ok(())
453}
454
455async fn update_worktree(
456    request: TypedEnvelope<proto::UpdateWorktree>,
457    rpc: &Arc<Peer>,
458    state: &Arc<AppState>,
459) -> tide::Result<()> {
460    {
461        let mut state = state.rpc.write().await;
462        let worktree = state.write_worktree(request.payload.worktree_id, request.sender_id)?;
463        for entry_id in &request.payload.removed_entries {
464            worktree.entries.remove(&entry_id);
465        }
466
467        for entry in &request.payload.updated_entries {
468            worktree.entries.insert(entry.id, entry.clone());
469        }
470    }
471
472    broadcast_in_worktree(request.payload.worktree_id, request, rpc, state).await?;
473    Ok(())
474}
475
476async fn close_worktree(
477    request: TypedEnvelope<proto::CloseWorktree>,
478    rpc: &Arc<Peer>,
479    state: &Arc<AppState>,
480) -> tide::Result<()> {
481    let connection_ids;
482    {
483        let mut state = state.rpc.write().await;
484        let worktree = state.write_worktree(request.payload.worktree_id, request.sender_id)?;
485        connection_ids = worktree.connection_ids();
486        if worktree.host_connection_id == Some(request.sender_id) {
487            worktree.host_connection_id = None;
488        } else if let Some(replica_id) = worktree.guest_connection_ids.remove(&request.sender_id) {
489            worktree.active_replica_ids.remove(&replica_id);
490        }
491    }
492
493    broadcast(request.sender_id, connection_ids, |conn_id| {
494        rpc.send(
495            conn_id,
496            proto::RemovePeer {
497                worktree_id: request.payload.worktree_id,
498                peer_id: request.sender_id.0,
499            },
500        )
501    })
502    .await?;
503
504    Ok(())
505}
506
507async fn open_buffer(
508    request: TypedEnvelope<proto::OpenBuffer>,
509    rpc: &Arc<Peer>,
510    state: &Arc<AppState>,
511) -> tide::Result<()> {
512    let receipt = request.receipt();
513    let worktree_id = request.payload.worktree_id;
514    let host_connection_id = state
515        .rpc
516        .read()
517        .await
518        .read_worktree(worktree_id, request.sender_id)?
519        .host_connection_id()?;
520
521    let response = rpc
522        .forward_request(request.sender_id, host_connection_id, request.payload)
523        .await?;
524    rpc.respond(receipt, response).await?;
525    Ok(())
526}
527
528async fn close_buffer(
529    request: TypedEnvelope<proto::CloseBuffer>,
530    rpc: &Arc<Peer>,
531    state: &Arc<AppState>,
532) -> tide::Result<()> {
533    let host_connection_id = state
534        .rpc
535        .read()
536        .await
537        .read_worktree(request.payload.worktree_id, request.sender_id)?
538        .host_connection_id()?;
539
540    rpc.forward_send(request.sender_id, host_connection_id, request.payload)
541        .await?;
542
543    Ok(())
544}
545
546async fn save_buffer(
547    request: TypedEnvelope<proto::SaveBuffer>,
548    rpc: &Arc<Peer>,
549    state: &Arc<AppState>,
550) -> tide::Result<()> {
551    let host;
552    let guests;
553    {
554        let state = state.rpc.read().await;
555        let worktree = state.read_worktree(request.payload.worktree_id, request.sender_id)?;
556        host = worktree.host_connection_id()?;
557        guests = worktree
558            .guest_connection_ids
559            .keys()
560            .copied()
561            .collect::<Vec<_>>();
562    }
563
564    let sender = request.sender_id;
565    let receipt = request.receipt();
566    let response = rpc
567        .forward_request(sender, host, request.payload.clone())
568        .await?;
569
570    broadcast(host, guests, |conn_id| {
571        let response = response.clone();
572        async move {
573            if conn_id == sender {
574                rpc.respond(receipt, response).await
575            } else {
576                rpc.forward_send(host, conn_id, response).await
577            }
578        }
579    })
580    .await?;
581
582    Ok(())
583}
584
585async fn update_buffer(
586    request: TypedEnvelope<proto::UpdateBuffer>,
587    rpc: &Arc<Peer>,
588    state: &Arc<AppState>,
589) -> tide::Result<()> {
590    broadcast_in_worktree(request.payload.worktree_id, request, rpc, state).await
591}
592
593async fn buffer_saved(
594    request: TypedEnvelope<proto::BufferSaved>,
595    rpc: &Arc<Peer>,
596    state: &Arc<AppState>,
597) -> tide::Result<()> {
598    broadcast_in_worktree(request.payload.worktree_id, request, rpc, state).await
599}
600
601async fn broadcast_in_worktree<T: proto::EnvelopedMessage>(
602    worktree_id: u64,
603    request: TypedEnvelope<T>,
604    rpc: &Arc<Peer>,
605    state: &Arc<AppState>,
606) -> tide::Result<()> {
607    let connection_ids = state
608        .rpc
609        .read()
610        .await
611        .read_worktree(worktree_id, request.sender_id)?
612        .connection_ids();
613
614    broadcast(request.sender_id, connection_ids, |conn_id| {
615        rpc.forward_send(request.sender_id, conn_id, request.payload.clone())
616    })
617    .await?;
618
619    Ok(())
620}
621
622pub async fn broadcast<F, T>(
623    sender_id: ConnectionId,
624    receiver_ids: Vec<ConnectionId>,
625    mut f: F,
626) -> anyhow::Result<()>
627where
628    F: FnMut(ConnectionId) -> T,
629    T: Future<Output = anyhow::Result<()>>,
630{
631    let futures = receiver_ids
632        .into_iter()
633        .filter(|id| *id != sender_id)
634        .map(|id| f(id));
635    futures::future::try_join_all(futures).await?;
636    Ok(())
637}
638
639fn header_contains_ignore_case<T>(
640    request: &tide::Request<T>,
641    header_name: HeaderName,
642    value: &str,
643) -> bool {
644    request
645        .header(header_name)
646        .map(|h| {
647            h.as_str()
648                .split(',')
649                .any(|s| s.trim().eq_ignore_ascii_case(value.trim()))
650        })
651        .unwrap_or(false)
652}