rpc.rs

  1use super::{
  2    auth::{self, PeerExt as _},
  3    db::{ChannelId, UserId},
  4    AppState,
  5};
  6use anyhow::anyhow;
  7use async_std::task;
  8use async_tungstenite::{
  9    tungstenite::{protocol::Role, Error as WebSocketError, Message as WebSocketMessage},
 10    WebSocketStream,
 11};
 12use sha1::{Digest as _, Sha1};
 13use std::{
 14    collections::{HashMap, HashSet},
 15    future::Future,
 16    mem,
 17    sync::Arc,
 18    time::Instant,
 19};
 20use surf::StatusCode;
 21use tide::log;
 22use tide::{
 23    http::headers::{HeaderName, CONNECTION, UPGRADE},
 24    Request, Response,
 25};
 26use zrpc::{
 27    auth::random_token,
 28    proto::{self, EnvelopedMessage},
 29    ConnectionId, Peer, Router, TypedEnvelope,
 30};
 31
/// Per-worktree replica number. The host is reported as replica 0; guests are given the
/// lowest unused number starting at 1 (see `State::join_worktree`).
type ReplicaId = u16;
 33
/// In-memory bookkeeping for live RPC connections and the worktrees shared over them.
#[derive(Default)]
pub struct State {
    /// Every registered connection, keyed by its id (populated by `add_connection`).
    connections: HashMap<ConnectionId, ConnectionState>,
    /// Shared worktrees, keyed by the id handed out when the worktree was shared.
    pub worktrees: HashMap<u64, WorktreeState>,
    /// Id to assign to the next shared worktree; starts at 0 via `Default` and is
    /// incremented on every share.
    next_worktree_id: u64,
}
 40
/// Per-connection bookkeeping.
struct ConnectionState {
    /// The user this connection was registered for in `State::add_connection`.
    user_id: UserId,
    /// Ids of the worktrees this connection has been associated with; walked by
    /// `State::remove_connection` when the connection goes away.
    worktrees: HashSet<u64>,
}
 45
/// Server-side state of a single shared worktree.
pub struct WorktreeState {
    /// Connection of the sharing host; `None` once the host has detached
    /// (e.g. via `close_worktree`).
    host_connection_id: Option<ConnectionId>,
    /// Guest connections, mapped to the replica id each was assigned when joining.
    guest_connection_ids: HashMap<ConnectionId, ReplicaId>,
    /// Replica ids currently held by guests; used to pick the lowest free id.
    active_replica_ids: HashSet<ReplicaId>,
    /// Token a guest must present to join this worktree.
    access_token: String,
    /// Root name reported by the host when sharing.
    root_name: String,
    /// The worktree's entries keyed by entry id, kept current by `update_worktree`.
    entries: HashMap<u64, proto::Entry>,
}
 54
 55impl WorktreeState {
 56    pub fn connection_ids(&self) -> Vec<ConnectionId> {
 57        self.guest_connection_ids
 58            .keys()
 59            .copied()
 60            .chain(self.host_connection_id)
 61            .collect()
 62    }
 63
 64    fn host_connection_id(&self) -> tide::Result<ConnectionId> {
 65        Ok(self
 66            .host_connection_id
 67            .ok_or_else(|| anyhow!("host disconnected from worktree"))?)
 68    }
 69}
 70
 71impl State {
 72    // Add a new connection associated with a given user.
 73    pub fn add_connection(&mut self, connection_id: ConnectionId, user_id: UserId) {
 74        self.connections.insert(
 75            connection_id,
 76            ConnectionState {
 77                user_id,
 78                worktrees: Default::default(),
 79            },
 80        );
 81    }
 82
 83    // Remove the given connection and its association with any worktrees.
 84    pub fn remove_connection(&mut self, connection_id: ConnectionId) -> Vec<u64> {
 85        let mut worktree_ids = Vec::new();
 86        if let Some(connection_state) = self.connections.remove(&connection_id) {
 87            for worktree_id in connection_state.worktrees {
 88                if let Some(worktree) = self.worktrees.get_mut(&worktree_id) {
 89                    if worktree.host_connection_id == Some(connection_id) {
 90                        worktree_ids.push(worktree_id);
 91                    } else if let Some(replica_id) =
 92                        worktree.guest_connection_ids.remove(&connection_id)
 93                    {
 94                        worktree.active_replica_ids.remove(&replica_id);
 95                        worktree_ids.push(worktree_id);
 96                    }
 97                }
 98            }
 99        }
100        worktree_ids
101    }
102
103    // Add the given connection as a guest of the given worktree
104    pub fn join_worktree(
105        &mut self,
106        connection_id: ConnectionId,
107        worktree_id: u64,
108        access_token: &str,
109    ) -> Option<(ReplicaId, &WorktreeState)> {
110        if let Some(worktree_state) = self.worktrees.get_mut(&worktree_id) {
111            if access_token == worktree_state.access_token {
112                if let Some(connection_state) = self.connections.get_mut(&connection_id) {
113                    connection_state.worktrees.insert(worktree_id);
114                }
115
116                let mut replica_id = 1;
117                while worktree_state.active_replica_ids.contains(&replica_id) {
118                    replica_id += 1;
119                }
120                worktree_state.active_replica_ids.insert(replica_id);
121                worktree_state
122                    .guest_connection_ids
123                    .insert(connection_id, replica_id);
124                Some((replica_id, worktree_state))
125            } else {
126                None
127            }
128        } else {
129            None
130        }
131    }
132
133    fn user_id_for_connection(&self, connection_id: ConnectionId) -> tide::Result<UserId> {
134        Ok(self
135            .connections
136            .get(&connection_id)
137            .ok_or_else(|| anyhow!("unknown connection"))?
138            .user_id)
139    }
140
141    fn read_worktree(
142        &self,
143        worktree_id: u64,
144        connection_id: ConnectionId,
145    ) -> tide::Result<&WorktreeState> {
146        let worktree = self
147            .worktrees
148            .get(&worktree_id)
149            .ok_or_else(|| anyhow!("worktree not found"))?;
150
151        if worktree.host_connection_id == Some(connection_id)
152            || worktree.guest_connection_ids.contains_key(&connection_id)
153        {
154            Ok(worktree)
155        } else {
156            Err(anyhow!(
157                "{} is not a member of worktree {}",
158                connection_id,
159                worktree_id
160            ))?
161        }
162    }
163
164    fn write_worktree(
165        &mut self,
166        worktree_id: u64,
167        connection_id: ConnectionId,
168    ) -> tide::Result<&mut WorktreeState> {
169        let worktree = self
170            .worktrees
171            .get_mut(&worktree_id)
172            .ok_or_else(|| anyhow!("worktree not found"))?;
173
174        if worktree.host_connection_id == Some(connection_id)
175            || worktree.guest_connection_ids.contains_key(&connection_id)
176        {
177            Ok(worktree)
178        } else {
179            Err(anyhow!(
180                "{} is not a member of worktree {}",
181                connection_id,
182                worktree_id
183            ))?
184        }
185    }
186}
187
/// An async handler for messages of type `M`.
///
/// The lifetime `'a` lets the returned future borrow the `Peer` and `AppState`
/// references it is given, so plain `async fn`s taking `&Arc<_>` arguments can serve as
/// handlers. `on_message` requires an implementation for every `'a`
/// (`for<'a> MessageHandler<'a, M>`).
trait MessageHandler<'a, M: proto::EnvelopedMessage> {
    /// Future returned by `handle`; it may borrow from the `'a` arguments.
    type Output: 'a + Send + Future<Output = tide::Result<()>>;

    /// Handles `message`, with access to the RPC peer and the shared application state.
    fn handle(
        &self,
        message: TypedEnvelope<M>,
        rpc: &'a Arc<Peer>,
        app_state: &'a Arc<AppState>,
    ) -> Self::Output;
}
198
/// Blanket implementation: any function or closure of shape
/// `Fn(TypedEnvelope<M>, &'a Arc<Peer>, &'a Arc<AppState>) -> Fut` is a handler.
impl<'a, M, F, Fut> MessageHandler<'a, M> for F
where
    M: proto::EnvelopedMessage,
    F: Fn(TypedEnvelope<M>, &'a Arc<Peer>, &'a Arc<AppState>) -> Fut,
    Fut: 'a + Send + Future<Output = tide::Result<()>>,
{
    type Output = Fut;

    // Forwards directly to the wrapped function.
    fn handle(
        &self,
        message: TypedEnvelope<M>,
        rpc: &'a Arc<Peer>,
        app_state: &'a Arc<AppState>,
    ) -> Self::Output {
        (self)(message, rpc, app_state)
    }
}
216
217fn on_message<M, H>(router: &mut Router, rpc: &Arc<Peer>, app_state: &Arc<AppState>, handler: H)
218where
219    M: EnvelopedMessage,
220    H: 'static + Clone + Send + Sync + for<'a> MessageHandler<'a, M>,
221{
222    let rpc = rpc.clone();
223    let handler = handler.clone();
224    let app_state = app_state.clone();
225    router.add_message_handler(move |message| {
226        let rpc = rpc.clone();
227        let handler = handler.clone();
228        let app_state = app_state.clone();
229        async move {
230            let sender_id = message.sender_id;
231            let message_id = message.message_id;
232            let start_time = Instant::now();
233            log::info!(
234                "RPC message received. id: {}.{}, type:{}",
235                sender_id,
236                message_id,
237                M::NAME
238            );
239            if let Err(err) = handler.handle(message, &rpc, &app_state).await {
240                log::error!("error handling message: {:?}", err);
241            } else {
242                log::info!(
243                    "RPC message handled. id:{}.{}, duration:{:?}",
244                    sender_id,
245                    message_id,
246                    start_time.elapsed()
247                );
248            }
249
250            Ok(())
251        }
252    });
253}
254
255pub fn add_rpc_routes(router: &mut Router, state: &Arc<AppState>, rpc: &Arc<Peer>) {
256    on_message(router, rpc, state, share_worktree);
257    on_message(router, rpc, state, join_worktree);
258    on_message(router, rpc, state, update_worktree);
259    on_message(router, rpc, state, close_worktree);
260    on_message(router, rpc, state, open_buffer);
261    on_message(router, rpc, state, close_buffer);
262    on_message(router, rpc, state, update_buffer);
263    on_message(router, rpc, state, buffer_saved);
264    on_message(router, rpc, state, save_buffer);
265    on_message(router, rpc, state, get_channels);
266    on_message(router, rpc, state, join_channel);
267}
268
269pub fn add_routes(app: &mut tide::Server<Arc<AppState>>, rpc: &Arc<Peer>) {
270    let mut router = Router::new();
271    add_rpc_routes(&mut router, app.state(), rpc);
272    let router = Arc::new(router);
273
274    let rpc = rpc.clone();
275    app.at("/rpc").with(auth::VerifyToken).get(move |request: Request<Arc<AppState>>| {
276        let user_id = request.ext::<UserId>().copied();
277        let rpc = rpc.clone();
278        let router = router.clone();
279        async move {
280            const WEBSOCKET_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
281
282            let connection_upgrade = header_contains_ignore_case(&request, CONNECTION, "upgrade");
283            let upgrade_to_websocket = header_contains_ignore_case(&request, UPGRADE, "websocket");
284            let upgrade_requested = connection_upgrade && upgrade_to_websocket;
285
286            if !upgrade_requested {
287                return Ok(Response::new(StatusCode::UpgradeRequired));
288            }
289
290            let header = match request.header("Sec-Websocket-Key") {
291                Some(h) => h.as_str(),
292                None => return Err(anyhow!("expected sec-websocket-key"))?,
293            };
294
295            let mut response = Response::new(StatusCode::SwitchingProtocols);
296            response.insert_header(UPGRADE, "websocket");
297            response.insert_header(CONNECTION, "Upgrade");
298            let hash = Sha1::new().chain(header).chain(WEBSOCKET_GUID).finalize();
299            response.insert_header("Sec-Websocket-Accept", base64::encode(&hash[..]));
300            response.insert_header("Sec-Websocket-Version", "13");
301
302            let http_res: &mut tide::http::Response = response.as_mut();
303            let upgrade_receiver = http_res.recv_upgrade().await;
304            let addr = request.remote().unwrap_or("unknown").to_string();
305            let state = request.state().clone();
306            let user_id = user_id.ok_or_else(|| anyhow!("user_id is not present on request. ensure auth::VerifyToken middleware is present"))?;
307            task::spawn(async move {
308                if let Some(stream) = upgrade_receiver.await {
309                    let stream = WebSocketStream::from_raw_socket(stream, Role::Server, None).await;
310                    handle_connection(rpc, router, state, addr, stream, user_id).await;
311                }
312            });
313
314            Ok(response)
315        }
316    });
317}
318
319pub async fn handle_connection<Conn>(
320    rpc: Arc<Peer>,
321    router: Arc<Router>,
322    state: Arc<AppState>,
323    addr: String,
324    stream: Conn,
325    user_id: UserId,
326) where
327    Conn: 'static
328        + futures::Sink<WebSocketMessage, Error = WebSocketError>
329        + futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>
330        + Send
331        + Unpin,
332{
333    log::info!("accepted rpc connection: {:?}", addr);
334    let (connection_id, handle_io, handle_messages) = rpc.add_connection(stream, router).await;
335    state
336        .rpc
337        .write()
338        .await
339        .add_connection(connection_id, user_id);
340
341    let handle_messages = async move {
342        handle_messages.await;
343        Ok(())
344    };
345
346    if let Err(e) = futures::try_join!(handle_messages, handle_io) {
347        log::error!("error handling rpc connection {:?} - {:?}", addr, e);
348    }
349
350    log::info!("closing connection to {:?}", addr);
351    if let Err(e) = rpc.sign_out(connection_id, &state).await {
352        log::error!("error signing out connection {:?} - {:?}", addr, e);
353    }
354}
355
356async fn share_worktree(
357    mut request: TypedEnvelope<proto::ShareWorktree>,
358    rpc: &Arc<Peer>,
359    state: &Arc<AppState>,
360) -> tide::Result<()> {
361    let mut state = state.rpc.write().await;
362    let worktree_id = state.next_worktree_id;
363    state.next_worktree_id += 1;
364    let access_token = random_token();
365    let worktree = request
366        .payload
367        .worktree
368        .as_mut()
369        .ok_or_else(|| anyhow!("missing worktree"))?;
370    let entries = mem::take(&mut worktree.entries)
371        .into_iter()
372        .map(|entry| (entry.id, entry))
373        .collect();
374    state.worktrees.insert(
375        worktree_id,
376        WorktreeState {
377            host_connection_id: Some(request.sender_id),
378            guest_connection_ids: Default::default(),
379            active_replica_ids: Default::default(),
380            access_token: access_token.clone(),
381            root_name: mem::take(&mut worktree.root_name),
382            entries,
383        },
384    );
385
386    rpc.respond(
387        request.receipt(),
388        proto::ShareWorktreeResponse {
389            worktree_id,
390            access_token,
391        },
392    )
393    .await?;
394    Ok(())
395}
396
397async fn join_worktree(
398    request: TypedEnvelope<proto::OpenWorktree>,
399    rpc: &Arc<Peer>,
400    state: &Arc<AppState>,
401) -> tide::Result<()> {
402    let worktree_id = request.payload.worktree_id;
403    let access_token = &request.payload.access_token;
404
405    let mut state = state.rpc.write().await;
406    if let Some((peer_replica_id, worktree)) =
407        state.join_worktree(request.sender_id, worktree_id, access_token)
408    {
409        let mut peers = Vec::new();
410        if let Some(host_connection_id) = worktree.host_connection_id {
411            peers.push(proto::Peer {
412                peer_id: host_connection_id.0,
413                replica_id: 0,
414            });
415        }
416        for (peer_conn_id, peer_replica_id) in &worktree.guest_connection_ids {
417            if *peer_conn_id != request.sender_id {
418                peers.push(proto::Peer {
419                    peer_id: peer_conn_id.0,
420                    replica_id: *peer_replica_id as u32,
421                });
422            }
423        }
424
425        broadcast(request.sender_id, worktree.connection_ids(), |conn_id| {
426            rpc.send(
427                conn_id,
428                proto::AddPeer {
429                    worktree_id,
430                    peer: Some(proto::Peer {
431                        peer_id: request.sender_id.0,
432                        replica_id: peer_replica_id as u32,
433                    }),
434                },
435            )
436        })
437        .await?;
438        rpc.respond(
439            request.receipt(),
440            proto::OpenWorktreeResponse {
441                worktree_id,
442                worktree: Some(proto::Worktree {
443                    root_name: worktree.root_name.clone(),
444                    entries: worktree.entries.values().cloned().collect(),
445                }),
446                replica_id: peer_replica_id as u32,
447                peers,
448            },
449        )
450        .await?;
451    } else {
452        rpc.respond(
453            request.receipt(),
454            proto::OpenWorktreeResponse {
455                worktree_id,
456                worktree: None,
457                replica_id: 0,
458                peers: Vec::new(),
459            },
460        )
461        .await?;
462    }
463
464    Ok(())
465}
466
467async fn update_worktree(
468    request: TypedEnvelope<proto::UpdateWorktree>,
469    rpc: &Arc<Peer>,
470    state: &Arc<AppState>,
471) -> tide::Result<()> {
472    {
473        let mut state = state.rpc.write().await;
474        let worktree = state.write_worktree(request.payload.worktree_id, request.sender_id)?;
475        for entry_id in &request.payload.removed_entries {
476            worktree.entries.remove(&entry_id);
477        }
478
479        for entry in &request.payload.updated_entries {
480            worktree.entries.insert(entry.id, entry.clone());
481        }
482    }
483
484    broadcast_in_worktree(request.payload.worktree_id, request, rpc, state).await?;
485    Ok(())
486}
487
488async fn close_worktree(
489    request: TypedEnvelope<proto::CloseWorktree>,
490    rpc: &Arc<Peer>,
491    state: &Arc<AppState>,
492) -> tide::Result<()> {
493    let connection_ids;
494    {
495        let mut state = state.rpc.write().await;
496        let worktree = state.write_worktree(request.payload.worktree_id, request.sender_id)?;
497        connection_ids = worktree.connection_ids();
498        if worktree.host_connection_id == Some(request.sender_id) {
499            worktree.host_connection_id = None;
500        } else if let Some(replica_id) = worktree.guest_connection_ids.remove(&request.sender_id) {
501            worktree.active_replica_ids.remove(&replica_id);
502        }
503    }
504
505    broadcast(request.sender_id, connection_ids, |conn_id| {
506        rpc.send(
507            conn_id,
508            proto::RemovePeer {
509                worktree_id: request.payload.worktree_id,
510                peer_id: request.sender_id.0,
511            },
512        )
513    })
514    .await?;
515
516    Ok(())
517}
518
519async fn open_buffer(
520    request: TypedEnvelope<proto::OpenBuffer>,
521    rpc: &Arc<Peer>,
522    state: &Arc<AppState>,
523) -> tide::Result<()> {
524    let receipt = request.receipt();
525    let worktree_id = request.payload.worktree_id;
526    let host_connection_id = state
527        .rpc
528        .read()
529        .await
530        .read_worktree(worktree_id, request.sender_id)?
531        .host_connection_id()?;
532
533    let response = rpc
534        .forward_request(request.sender_id, host_connection_id, request.payload)
535        .await?;
536    rpc.respond(receipt, response).await?;
537    Ok(())
538}
539
540async fn close_buffer(
541    request: TypedEnvelope<proto::CloseBuffer>,
542    rpc: &Arc<Peer>,
543    state: &Arc<AppState>,
544) -> tide::Result<()> {
545    let host_connection_id = state
546        .rpc
547        .read()
548        .await
549        .read_worktree(request.payload.worktree_id, request.sender_id)?
550        .host_connection_id()?;
551
552    rpc.forward_send(request.sender_id, host_connection_id, request.payload)
553        .await?;
554
555    Ok(())
556}
557
558async fn save_buffer(
559    request: TypedEnvelope<proto::SaveBuffer>,
560    rpc: &Arc<Peer>,
561    state: &Arc<AppState>,
562) -> tide::Result<()> {
563    let host;
564    let guests;
565    {
566        let state = state.rpc.read().await;
567        let worktree = state.read_worktree(request.payload.worktree_id, request.sender_id)?;
568        host = worktree.host_connection_id()?;
569        guests = worktree
570            .guest_connection_ids
571            .keys()
572            .copied()
573            .collect::<Vec<_>>();
574    }
575
576    let sender = request.sender_id;
577    let receipt = request.receipt();
578    let response = rpc
579        .forward_request(sender, host, request.payload.clone())
580        .await?;
581
582    broadcast(host, guests, |conn_id| {
583        let response = response.clone();
584        async move {
585            if conn_id == sender {
586                rpc.respond(receipt, response).await
587            } else {
588                rpc.forward_send(host, conn_id, response).await
589            }
590        }
591    })
592    .await?;
593
594    Ok(())
595}
596
597async fn update_buffer(
598    request: TypedEnvelope<proto::UpdateBuffer>,
599    rpc: &Arc<Peer>,
600    state: &Arc<AppState>,
601) -> tide::Result<()> {
602    broadcast_in_worktree(request.payload.worktree_id, request, rpc, state).await
603}
604
605async fn buffer_saved(
606    request: TypedEnvelope<proto::BufferSaved>,
607    rpc: &Arc<Peer>,
608    state: &Arc<AppState>,
609) -> tide::Result<()> {
610    broadcast_in_worktree(request.payload.worktree_id, request, rpc, state).await
611}
612
613async fn get_channels(
614    request: TypedEnvelope<proto::GetChannels>,
615    rpc: &Arc<Peer>,
616    state: &Arc<AppState>,
617) -> tide::Result<()> {
618    let user_id = state
619        .rpc
620        .read()
621        .await
622        .user_id_for_connection(request.sender_id)?;
623    let channels = state.db.get_channels_for_user(user_id).await?;
624    rpc.respond(
625        request.receipt(),
626        proto::GetChannelsResponse {
627            channels: channels
628                .into_iter()
629                .map(|chan| proto::Channel {
630                    id: chan.id().0 as u64,
631                    name: chan.name,
632                })
633                .collect(),
634        },
635    )
636    .await?;
637    Ok(())
638}
639
640async fn join_channel(
641    request: TypedEnvelope<proto::JoinChannel>,
642    rpc: &Arc<Peer>,
643    state: &Arc<AppState>,
644) -> tide::Result<()> {
645    let user_id = state
646        .rpc
647        .read()
648        .await
649        .user_id_for_connection(request.sender_id)?;
650    if !state
651        .db
652        .can_user_access_channel(user_id, ChannelId(request.payload.channel_id as i32))
653        .await?
654    {
655        Err(anyhow!("access denied"))?;
656    }
657
658    Ok(())
659}
660
661async fn broadcast_in_worktree<T: proto::EnvelopedMessage>(
662    worktree_id: u64,
663    request: TypedEnvelope<T>,
664    rpc: &Arc<Peer>,
665    state: &Arc<AppState>,
666) -> tide::Result<()> {
667    let connection_ids = state
668        .rpc
669        .read()
670        .await
671        .read_worktree(worktree_id, request.sender_id)?
672        .connection_ids();
673
674    broadcast(request.sender_id, connection_ids, |conn_id| {
675        rpc.forward_send(request.sender_id, conn_id, request.payload.clone())
676    })
677    .await?;
678
679    Ok(())
680}
681
682pub async fn broadcast<F, T>(
683    sender_id: ConnectionId,
684    receiver_ids: Vec<ConnectionId>,
685    mut f: F,
686) -> anyhow::Result<()>
687where
688    F: FnMut(ConnectionId) -> T,
689    T: Future<Output = anyhow::Result<()>>,
690{
691    let futures = receiver_ids
692        .into_iter()
693        .filter(|id| *id != sender_id)
694        .map(|id| f(id));
695    futures::future::try_join_all(futures).await?;
696    Ok(())
697}
698
699fn header_contains_ignore_case<T>(
700    request: &tide::Request<T>,
701    header_name: HeaderName,
702    value: &str,
703) -> bool {
704    request
705        .header(header_name)
706        .map(|h| {
707            h.as_str()
708                .split(',')
709                .any(|s| s.trim().eq_ignore_ascii_case(value.trim()))
710        })
711        .unwrap_or(false)
712}