rpc.rs

  1use super::{
  2    auth::{self, PeerExt as _},
  3    db::{ChannelId, UserId},
  4    AppState,
  5};
  6use anyhow::anyhow;
  7use async_std::task;
  8use async_tungstenite::{
  9    tungstenite::{protocol::Role, Error as WebSocketError, Message as WebSocketMessage},
 10    WebSocketStream,
 11};
 12use futures::{future::BoxFuture, FutureExt};
 13use postage::prelude::Stream as _;
 14use sha1::{Digest as _, Sha1};
 15use std::{
 16    any::{Any, TypeId},
 17    collections::{HashMap, HashSet},
 18    future::Future,
 19    mem,
 20    sync::Arc,
 21    time::Instant,
 22};
 23use surf::StatusCode;
 24use tide::log;
 25use tide::{
 26    http::headers::{HeaderName, CONNECTION, UPGRADE},
 27    Request, Response,
 28};
 29use time::OffsetDateTime;
 30use zrpc::{
 31    auth::random_token,
 32    proto::{self, EnvelopedMessage},
 33    ConnectionId, Peer, Router, TypedEnvelope,
 34};
 35
 36type ReplicaId = u16;
 37
 38type Handler = Box<
 39    dyn Fn(&mut Option<Box<dyn Any + Send + Sync>>, Arc<Server>) -> Option<BoxFuture<'static, ()>>,
 40>;
 41
 42#[derive(Default)]
 43struct ServerBuilder {
 44    handlers: Vec<Handler>,
 45    handler_types: HashSet<TypeId>,
 46}
 47
 48impl ServerBuilder {
 49    pub fn on_message<F, Fut, M>(&mut self, handler: F) -> &mut Self
 50    where
 51        F: 'static + Fn(Box<TypedEnvelope<M>>, Arc<Server>) -> Fut,
 52        Fut: 'static + Send + Future<Output = ()>,
 53        M: EnvelopedMessage,
 54    {
 55        if self.handler_types.insert(TypeId::of::<M>()) {
 56            panic!("registered a handler for the same message twice");
 57        }
 58
 59        self.handlers
 60            .push(Box::new(move |untyped_envelope, server| {
 61                if let Some(typed_envelope) = untyped_envelope.take() {
 62                    match typed_envelope.downcast::<TypedEnvelope<M>>() {
 63                        Ok(typed_envelope) => Some((handler)(typed_envelope, server).boxed()),
 64                        Err(envelope) => {
 65                            *untyped_envelope = Some(envelope);
 66                            None
 67                        }
 68                    }
 69                } else {
 70                    None
 71                }
 72            }));
 73        self
 74    }
 75
 76    pub fn build(self, rpc: Arc<zrpc::peer2::Peer>, state: Arc<AppState>) -> Arc<Server> {
 77        Arc::new(Server {
 78            rpc,
 79            state,
 80            handlers: self.handlers,
 81        })
 82    }
 83}
 84
 85struct Server {
 86    rpc: Arc<zrpc::peer2::Peer>,
 87    state: Arc<AppState>,
 88    handlers: Vec<Handler>,
 89}
 90
 91impl Server {
 92    pub async fn add_connection<Conn>(
 93        self: &Arc<Self>,
 94        connection: Conn,
 95        addr: String,
 96        user_id: UserId,
 97    ) where
 98        Conn: 'static
 99            + futures::Sink<WebSocketMessage, Error = WebSocketError>
100            + futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>
101            + Send
102            + Unpin,
103    {
104        let this = self.clone();
105        let (connection_id, handle_io, mut incoming_rx) = this.rpc.add_connection(connection).await;
106        this.state
107            .rpc
108            .write()
109            .await
110            .add_connection(connection_id, user_id);
111
112        let handle_io = handle_io.fuse();
113        futures::pin_mut!(handle_io);
114        loop {
115            let next_message = incoming_rx.recv().fuse();
116            futures::pin_mut!(next_message);
117            futures::select_biased! {
118                message = next_message => {
119                    if let Some(message) = message {
120                        let mut message = Some(message);
121                        for handler in &this.handlers {
122                            if let Some(future) = (handler)(&mut message, this.clone()) {
123                                future.await;
124                                break;
125                            }
126                        }
127
128                        if let Some(message) = message {
129                            log::warn!("unhandled message: {:?}", message);
130                        }
131                    } else {
132                        log::info!("rpc connection closed {:?}", addr);
133                        break;
134                    }
135                }
136                handle_io = handle_io => {
137                    if let Err(err) = handle_io {
138                        log::error!("error handling rpc connection {:?} - {:?}", addr, err);
139                    }
140                    break;
141                }
142            }
143        }
144
145        if let Err(err) = this.rpc.sign_out(connection_id, &this.state).await {
146            log::error!("error signing out connection {:?} - {:?}", addr, err);
147        }
148    }
149}
150
151#[derive(Default)]
152pub struct State {
153    connections: HashMap<ConnectionId, Connection>,
154    pub worktrees: HashMap<u64, Worktree>,
155    channels: HashMap<ChannelId, Channel>,
156    next_worktree_id: u64,
157}
158
159struct Connection {
160    user_id: UserId,
161    worktrees: HashSet<u64>,
162    channels: HashSet<ChannelId>,
163}
164
165pub struct Worktree {
166    host_connection_id: Option<ConnectionId>,
167    guest_connection_ids: HashMap<ConnectionId, ReplicaId>,
168    active_replica_ids: HashSet<ReplicaId>,
169    access_token: String,
170    root_name: String,
171    entries: HashMap<u64, proto::Entry>,
172}
173
174#[derive(Default)]
175struct Channel {
176    connection_ids: HashSet<ConnectionId>,
177}
178
179impl Worktree {
180    pub fn connection_ids(&self) -> Vec<ConnectionId> {
181        self.guest_connection_ids
182            .keys()
183            .copied()
184            .chain(self.host_connection_id)
185            .collect()
186    }
187
188    fn host_connection_id(&self) -> tide::Result<ConnectionId> {
189        Ok(self
190            .host_connection_id
191            .ok_or_else(|| anyhow!("host disconnected from worktree"))?)
192    }
193}
194
195impl Channel {
196    fn connection_ids(&self) -> Vec<ConnectionId> {
197        self.connection_ids.iter().copied().collect()
198    }
199}
200
201impl State {
202    // Add a new connection associated with a given user.
203    pub fn add_connection(&mut self, connection_id: ConnectionId, user_id: UserId) {
204        self.connections.insert(
205            connection_id,
206            Connection {
207                user_id,
208                worktrees: Default::default(),
209                channels: Default::default(),
210            },
211        );
212    }
213
214    // Remove the given connection and its association with any worktrees.
215    pub fn remove_connection(&mut self, connection_id: ConnectionId) -> Vec<u64> {
216        let mut worktree_ids = Vec::new();
217        if let Some(connection) = self.connections.remove(&connection_id) {
218            for channel_id in connection.channels {
219                if let Some(channel) = self.channels.get_mut(&channel_id) {
220                    channel.connection_ids.remove(&connection_id);
221                }
222            }
223            for worktree_id in connection.worktrees {
224                if let Some(worktree) = self.worktrees.get_mut(&worktree_id) {
225                    if worktree.host_connection_id == Some(connection_id) {
226                        worktree_ids.push(worktree_id);
227                    } else if let Some(replica_id) =
228                        worktree.guest_connection_ids.remove(&connection_id)
229                    {
230                        worktree.active_replica_ids.remove(&replica_id);
231                        worktree_ids.push(worktree_id);
232                    }
233                }
234            }
235        }
236        worktree_ids
237    }
238
239    fn join_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
240        if let Some(connection) = self.connections.get_mut(&connection_id) {
241            connection.channels.insert(channel_id);
242            self.channels
243                .entry(channel_id)
244                .or_default()
245                .connection_ids
246                .insert(connection_id);
247        }
248    }
249
250    // Add the given connection as a guest of the given worktree
251    pub fn join_worktree(
252        &mut self,
253        connection_id: ConnectionId,
254        worktree_id: u64,
255        access_token: &str,
256    ) -> Option<(ReplicaId, &Worktree)> {
257        if let Some(worktree) = self.worktrees.get_mut(&worktree_id) {
258            if access_token == worktree.access_token {
259                if let Some(connection) = self.connections.get_mut(&connection_id) {
260                    connection.worktrees.insert(worktree_id);
261                }
262
263                let mut replica_id = 1;
264                while worktree.active_replica_ids.contains(&replica_id) {
265                    replica_id += 1;
266                }
267                worktree.active_replica_ids.insert(replica_id);
268                worktree
269                    .guest_connection_ids
270                    .insert(connection_id, replica_id);
271                Some((replica_id, worktree))
272            } else {
273                None
274            }
275        } else {
276            None
277        }
278    }
279
280    fn user_id_for_connection(&self, connection_id: ConnectionId) -> tide::Result<UserId> {
281        Ok(self
282            .connections
283            .get(&connection_id)
284            .ok_or_else(|| anyhow!("unknown connection"))?
285            .user_id)
286    }
287
288    fn read_worktree(
289        &self,
290        worktree_id: u64,
291        connection_id: ConnectionId,
292    ) -> tide::Result<&Worktree> {
293        let worktree = self
294            .worktrees
295            .get(&worktree_id)
296            .ok_or_else(|| anyhow!("worktree not found"))?;
297
298        if worktree.host_connection_id == Some(connection_id)
299            || worktree.guest_connection_ids.contains_key(&connection_id)
300        {
301            Ok(worktree)
302        } else {
303            Err(anyhow!(
304                "{} is not a member of worktree {}",
305                connection_id,
306                worktree_id
307            ))?
308        }
309    }
310
311    fn write_worktree(
312        &mut self,
313        worktree_id: u64,
314        connection_id: ConnectionId,
315    ) -> tide::Result<&mut Worktree> {
316        let worktree = self
317            .worktrees
318            .get_mut(&worktree_id)
319            .ok_or_else(|| anyhow!("worktree not found"))?;
320
321        if worktree.host_connection_id == Some(connection_id)
322            || worktree.guest_connection_ids.contains_key(&connection_id)
323        {
324            Ok(worktree)
325        } else {
326            Err(anyhow!(
327                "{} is not a member of worktree {}",
328                connection_id,
329                worktree_id
330            ))?
331        }
332    }
333}
334
335trait MessageHandler<'a, M: proto::EnvelopedMessage> {
336    type Output: 'a + Send + Future<Output = tide::Result<()>>;
337
338    fn handle(
339        &self,
340        message: TypedEnvelope<M>,
341        rpc: &'a Arc<Peer>,
342        app_state: &'a Arc<AppState>,
343    ) -> Self::Output;
344}
345
346impl<'a, M, F, Fut> MessageHandler<'a, M> for F
347where
348    M: proto::EnvelopedMessage,
349    F: Fn(TypedEnvelope<M>, &'a Arc<Peer>, &'a Arc<AppState>) -> Fut,
350    Fut: 'a + Send + Future<Output = tide::Result<()>>,
351{
352    type Output = Fut;
353
354    fn handle(
355        &self,
356        message: TypedEnvelope<M>,
357        rpc: &'a Arc<Peer>,
358        app_state: &'a Arc<AppState>,
359    ) -> Self::Output {
360        (self)(message, rpc, app_state)
361    }
362}
363
364fn on_message<M, H>(router: &mut Router, rpc: &Arc<Peer>, app_state: &Arc<AppState>, handler: H)
365where
366    M: EnvelopedMessage,
367    H: 'static + Clone + Send + Sync + for<'a> MessageHandler<'a, M>,
368{
369    let rpc = rpc.clone();
370    let handler = handler.clone();
371    let app_state = app_state.clone();
372    router.add_message_handler(move |message| {
373        let rpc = rpc.clone();
374        let handler = handler.clone();
375        let app_state = app_state.clone();
376        async move {
377            let sender_id = message.sender_id;
378            let message_id = message.message_id;
379            let start_time = Instant::now();
380            log::info!(
381                "RPC message received. id: {}.{}, type:{}",
382                sender_id,
383                message_id,
384                M::NAME
385            );
386            if let Err(err) = handler.handle(message, &rpc, &app_state).await {
387                log::error!("error handling message: {:?}", err);
388            } else {
389                log::info!(
390                    "RPC message handled. id:{}.{}, duration:{:?}",
391                    sender_id,
392                    message_id,
393                    start_time.elapsed()
394                );
395            }
396
397            Ok(())
398        }
399    });
400}
401
402pub fn add_rpc_routes(router: &mut Router, state: &Arc<AppState>, rpc: &Arc<Peer>) {
403    on_message(router, rpc, state, share_worktree);
404    on_message(router, rpc, state, join_worktree);
405    on_message(router, rpc, state, update_worktree);
406    on_message(router, rpc, state, close_worktree);
407    on_message(router, rpc, state, open_buffer);
408    on_message(router, rpc, state, close_buffer);
409    on_message(router, rpc, state, update_buffer);
410    on_message(router, rpc, state, buffer_saved);
411    on_message(router, rpc, state, save_buffer);
412    on_message(router, rpc, state, get_channels);
413    on_message(router, rpc, state, get_users);
414    on_message(router, rpc, state, join_channel);
415    on_message(router, rpc, state, send_channel_message);
416}
417
418pub fn add_routes(app: &mut tide::Server<Arc<AppState>>, rpc: &Arc<Peer>) {
419    let mut router = Router::new();
420    add_rpc_routes(&mut router, app.state(), rpc);
421    let router = Arc::new(router);
422
423    let rpc = rpc.clone();
424    app.at("/rpc").with(auth::VerifyToken).get(move |request: Request<Arc<AppState>>| {
425        let user_id = request.ext::<UserId>().copied();
426        let rpc = rpc.clone();
427        let router = router.clone();
428        async move {
429            const WEBSOCKET_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
430
431            let connection_upgrade = header_contains_ignore_case(&request, CONNECTION, "upgrade");
432            let upgrade_to_websocket = header_contains_ignore_case(&request, UPGRADE, "websocket");
433            let upgrade_requested = connection_upgrade && upgrade_to_websocket;
434
435            if !upgrade_requested {
436                return Ok(Response::new(StatusCode::UpgradeRequired));
437            }
438
439            let header = match request.header("Sec-Websocket-Key") {
440                Some(h) => h.as_str(),
441                None => return Err(anyhow!("expected sec-websocket-key"))?,
442            };
443
444            let mut response = Response::new(StatusCode::SwitchingProtocols);
445            response.insert_header(UPGRADE, "websocket");
446            response.insert_header(CONNECTION, "Upgrade");
447            let hash = Sha1::new().chain(header).chain(WEBSOCKET_GUID).finalize();
448            response.insert_header("Sec-Websocket-Accept", base64::encode(&hash[..]));
449            response.insert_header("Sec-Websocket-Version", "13");
450
451            let http_res: &mut tide::http::Response = response.as_mut();
452            let upgrade_receiver = http_res.recv_upgrade().await;
453            let addr = request.remote().unwrap_or("unknown").to_string();
454            let state = request.state().clone();
455            let user_id = user_id.ok_or_else(|| anyhow!("user_id is not present on request. ensure auth::VerifyToken middleware is present"))?;
456            task::spawn(async move {
457                if let Some(stream) = upgrade_receiver.await {
458                    let stream = WebSocketStream::from_raw_socket(stream, Role::Server, None).await;
459                    handle_connection(rpc, router, state, addr, stream, user_id).await;
460                }
461            });
462
463            Ok(response)
464        }
465    });
466}
467
468pub async fn handle_connection<Conn>(
469    rpc: Arc<Peer>,
470    router: Arc<Router>,
471    state: Arc<AppState>,
472    addr: String,
473    stream: Conn,
474    user_id: UserId,
475) where
476    Conn: 'static
477        + futures::Sink<WebSocketMessage, Error = WebSocketError>
478        + futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>
479        + Send
480        + Unpin,
481{
482    log::info!("accepted rpc connection: {:?}", addr);
483    let (connection_id, handle_io, handle_messages) = rpc.add_connection(stream, router).await;
484    state
485        .rpc
486        .write()
487        .await
488        .add_connection(connection_id, user_id);
489
490    let handle_messages = async move {
491        handle_messages.await;
492        Ok(())
493    };
494
495    if let Err(e) = futures::try_join!(handle_messages, handle_io) {
496        log::error!("error handling rpc connection {:?} - {:?}", addr, e);
497    }
498
499    log::info!("closing connection to {:?}", addr);
500    if let Err(e) = rpc.sign_out(connection_id, &state).await {
501        log::error!("error signing out connection {:?} - {:?}", addr, e);
502    }
503}
504
505async fn share_worktree(
506    mut request: TypedEnvelope<proto::ShareWorktree>,
507    rpc: &Arc<Peer>,
508    state: &Arc<AppState>,
509) -> tide::Result<()> {
510    let mut state = state.rpc.write().await;
511    let worktree_id = state.next_worktree_id;
512    state.next_worktree_id += 1;
513    let access_token = random_token();
514    let worktree = request
515        .payload
516        .worktree
517        .as_mut()
518        .ok_or_else(|| anyhow!("missing worktree"))?;
519    let entries = mem::take(&mut worktree.entries)
520        .into_iter()
521        .map(|entry| (entry.id, entry))
522        .collect();
523    state.worktrees.insert(
524        worktree_id,
525        Worktree {
526            host_connection_id: Some(request.sender_id),
527            guest_connection_ids: Default::default(),
528            active_replica_ids: Default::default(),
529            access_token: access_token.clone(),
530            root_name: mem::take(&mut worktree.root_name),
531            entries,
532        },
533    );
534
535    rpc.respond(
536        request.receipt(),
537        proto::ShareWorktreeResponse {
538            worktree_id,
539            access_token,
540        },
541    )
542    .await?;
543    Ok(())
544}
545
546async fn join_worktree(
547    request: TypedEnvelope<proto::OpenWorktree>,
548    rpc: &Arc<Peer>,
549    state: &Arc<AppState>,
550) -> tide::Result<()> {
551    let worktree_id = request.payload.worktree_id;
552    let access_token = &request.payload.access_token;
553
554    let mut state = state.rpc.write().await;
555    if let Some((peer_replica_id, worktree)) =
556        state.join_worktree(request.sender_id, worktree_id, access_token)
557    {
558        let mut peers = Vec::new();
559        if let Some(host_connection_id) = worktree.host_connection_id {
560            peers.push(proto::Peer {
561                peer_id: host_connection_id.0,
562                replica_id: 0,
563            });
564        }
565        for (peer_conn_id, peer_replica_id) in &worktree.guest_connection_ids {
566            if *peer_conn_id != request.sender_id {
567                peers.push(proto::Peer {
568                    peer_id: peer_conn_id.0,
569                    replica_id: *peer_replica_id as u32,
570                });
571            }
572        }
573
574        broadcast(request.sender_id, worktree.connection_ids(), |conn_id| {
575            rpc.send(
576                conn_id,
577                proto::AddPeer {
578                    worktree_id,
579                    peer: Some(proto::Peer {
580                        peer_id: request.sender_id.0,
581                        replica_id: peer_replica_id as u32,
582                    }),
583                },
584            )
585        })
586        .await?;
587        rpc.respond(
588            request.receipt(),
589            proto::OpenWorktreeResponse {
590                worktree_id,
591                worktree: Some(proto::Worktree {
592                    root_name: worktree.root_name.clone(),
593                    entries: worktree.entries.values().cloned().collect(),
594                }),
595                replica_id: peer_replica_id as u32,
596                peers,
597            },
598        )
599        .await?;
600    } else {
601        rpc.respond(
602            request.receipt(),
603            proto::OpenWorktreeResponse {
604                worktree_id,
605                worktree: None,
606                replica_id: 0,
607                peers: Vec::new(),
608            },
609        )
610        .await?;
611    }
612
613    Ok(())
614}
615
616async fn update_worktree(
617    request: TypedEnvelope<proto::UpdateWorktree>,
618    rpc: &Arc<Peer>,
619    state: &Arc<AppState>,
620) -> tide::Result<()> {
621    {
622        let mut state = state.rpc.write().await;
623        let worktree = state.write_worktree(request.payload.worktree_id, request.sender_id)?;
624        for entry_id in &request.payload.removed_entries {
625            worktree.entries.remove(&entry_id);
626        }
627
628        for entry in &request.payload.updated_entries {
629            worktree.entries.insert(entry.id, entry.clone());
630        }
631    }
632
633    broadcast_in_worktree(request.payload.worktree_id, request, rpc, state).await?;
634    Ok(())
635}
636
637async fn close_worktree(
638    request: TypedEnvelope<proto::CloseWorktree>,
639    rpc: &Arc<Peer>,
640    state: &Arc<AppState>,
641) -> tide::Result<()> {
642    let connection_ids;
643    {
644        let mut state = state.rpc.write().await;
645        let worktree = state.write_worktree(request.payload.worktree_id, request.sender_id)?;
646        connection_ids = worktree.connection_ids();
647        if worktree.host_connection_id == Some(request.sender_id) {
648            worktree.host_connection_id = None;
649        } else if let Some(replica_id) = worktree.guest_connection_ids.remove(&request.sender_id) {
650            worktree.active_replica_ids.remove(&replica_id);
651        }
652    }
653
654    broadcast(request.sender_id, connection_ids, |conn_id| {
655        rpc.send(
656            conn_id,
657            proto::RemovePeer {
658                worktree_id: request.payload.worktree_id,
659                peer_id: request.sender_id.0,
660            },
661        )
662    })
663    .await?;
664
665    Ok(())
666}
667
668async fn open_buffer(
669    request: TypedEnvelope<proto::OpenBuffer>,
670    rpc: &Arc<Peer>,
671    state: &Arc<AppState>,
672) -> tide::Result<()> {
673    let receipt = request.receipt();
674    let worktree_id = request.payload.worktree_id;
675    let host_connection_id = state
676        .rpc
677        .read()
678        .await
679        .read_worktree(worktree_id, request.sender_id)?
680        .host_connection_id()?;
681
682    let response = rpc
683        .forward_request(request.sender_id, host_connection_id, request.payload)
684        .await?;
685    rpc.respond(receipt, response).await?;
686    Ok(())
687}
688
689async fn close_buffer(
690    request: TypedEnvelope<proto::CloseBuffer>,
691    rpc: &Arc<Peer>,
692    state: &Arc<AppState>,
693) -> tide::Result<()> {
694    let host_connection_id = state
695        .rpc
696        .read()
697        .await
698        .read_worktree(request.payload.worktree_id, request.sender_id)?
699        .host_connection_id()?;
700
701    rpc.forward_send(request.sender_id, host_connection_id, request.payload)
702        .await?;
703
704    Ok(())
705}
706
707async fn save_buffer(
708    request: TypedEnvelope<proto::SaveBuffer>,
709    rpc: &Arc<Peer>,
710    state: &Arc<AppState>,
711) -> tide::Result<()> {
712    let host;
713    let guests;
714    {
715        let state = state.rpc.read().await;
716        let worktree = state.read_worktree(request.payload.worktree_id, request.sender_id)?;
717        host = worktree.host_connection_id()?;
718        guests = worktree
719            .guest_connection_ids
720            .keys()
721            .copied()
722            .collect::<Vec<_>>();
723    }
724
725    let sender = request.sender_id;
726    let receipt = request.receipt();
727    let response = rpc
728        .forward_request(sender, host, request.payload.clone())
729        .await?;
730
731    broadcast(host, guests, |conn_id| {
732        let response = response.clone();
733        async move {
734            if conn_id == sender {
735                rpc.respond(receipt, response).await
736            } else {
737                rpc.forward_send(host, conn_id, response).await
738            }
739        }
740    })
741    .await?;
742
743    Ok(())
744}
745
746async fn update_buffer(
747    request: TypedEnvelope<proto::UpdateBuffer>,
748    rpc: &Arc<Peer>,
749    state: &Arc<AppState>,
750) -> tide::Result<()> {
751    broadcast_in_worktree(request.payload.worktree_id, request, rpc, state).await
752}
753
754async fn buffer_saved(
755    request: TypedEnvelope<proto::BufferSaved>,
756    rpc: &Arc<Peer>,
757    state: &Arc<AppState>,
758) -> tide::Result<()> {
759    broadcast_in_worktree(request.payload.worktree_id, request, rpc, state).await
760}
761
762async fn get_channels(
763    request: TypedEnvelope<proto::GetChannels>,
764    rpc: &Arc<Peer>,
765    state: &Arc<AppState>,
766) -> tide::Result<()> {
767    let user_id = state
768        .rpc
769        .read()
770        .await
771        .user_id_for_connection(request.sender_id)?;
772    let channels = state.db.get_channels_for_user(user_id).await?;
773    rpc.respond(
774        request.receipt(),
775        proto::GetChannelsResponse {
776            channels: channels
777                .into_iter()
778                .map(|chan| proto::Channel {
779                    id: chan.id.to_proto(),
780                    name: chan.name,
781                })
782                .collect(),
783        },
784    )
785    .await?;
786    Ok(())
787}
788
789async fn get_users(
790    request: TypedEnvelope<proto::GetUsers>,
791    rpc: &Arc<Peer>,
792    state: &Arc<AppState>,
793) -> tide::Result<()> {
794    let user_id = state
795        .rpc
796        .read()
797        .await
798        .user_id_for_connection(request.sender_id)?;
799    let receipt = request.receipt();
800    let user_ids = request.payload.user_ids.into_iter().map(UserId::from_proto);
801    let users = state
802        .db
803        .get_users_by_ids(user_id, user_ids)
804        .await?
805        .into_iter()
806        .map(|user| proto::User {
807            id: user.id.to_proto(),
808            github_login: user.github_login,
809            avatar_url: String::new(),
810        })
811        .collect();
812    rpc.respond(receipt, proto::GetUsersResponse { users })
813        .await?;
814    Ok(())
815}
816
817async fn join_channel(
818    request: TypedEnvelope<proto::JoinChannel>,
819    rpc: &Arc<Peer>,
820    state: &Arc<AppState>,
821) -> tide::Result<()> {
822    let user_id = state
823        .rpc
824        .read()
825        .await
826        .user_id_for_connection(request.sender_id)?;
827    let channel_id = ChannelId::from_proto(request.payload.channel_id);
828    if !state
829        .db
830        .can_user_access_channel(user_id, channel_id)
831        .await?
832    {
833        Err(anyhow!("access denied"))?;
834    }
835
836    state
837        .rpc
838        .write()
839        .await
840        .join_channel(request.sender_id, channel_id);
841    let messages = state
842        .db
843        .get_recent_channel_messages(channel_id, 50)
844        .await?
845        .into_iter()
846        .map(|msg| proto::ChannelMessage {
847            id: msg.id.to_proto(),
848            body: msg.body,
849            timestamp: msg.sent_at.unix_timestamp() as u64,
850            sender_id: msg.sender_id.to_proto(),
851        })
852        .collect();
853    rpc.respond(request.receipt(), proto::JoinChannelResponse { messages })
854        .await?;
855    Ok(())
856}
857
858async fn send_channel_message(
859    request: TypedEnvelope<proto::SendChannelMessage>,
860    peer: &Arc<Peer>,
861    app: &Arc<AppState>,
862) -> tide::Result<()> {
863    let channel_id = ChannelId::from_proto(request.payload.channel_id);
864    let user_id;
865    let connection_ids;
866    {
867        let state = app.rpc.read().await;
868        user_id = state.user_id_for_connection(request.sender_id)?;
869        if let Some(channel) = state.channels.get(&channel_id) {
870            connection_ids = channel.connection_ids();
871        } else {
872            return Ok(());
873        }
874    }
875
876    let timestamp = OffsetDateTime::now_utc();
877    let message_id = app
878        .db
879        .create_channel_message(channel_id, user_id, &request.payload.body, timestamp)
880        .await?;
881    let message = proto::ChannelMessageSent {
882        channel_id: channel_id.to_proto(),
883        message: Some(proto::ChannelMessage {
884            sender_id: user_id.to_proto(),
885            id: message_id.to_proto(),
886            body: request.payload.body,
887            timestamp: timestamp.unix_timestamp() as u64,
888        }),
889    };
890    broadcast(request.sender_id, connection_ids, |conn_id| {
891        peer.send(conn_id, message.clone())
892    })
893    .await?;
894
895    Ok(())
896}
897
898async fn broadcast_in_worktree<T: proto::EnvelopedMessage>(
899    worktree_id: u64,
900    request: TypedEnvelope<T>,
901    rpc: &Arc<Peer>,
902    state: &Arc<AppState>,
903) -> tide::Result<()> {
904    let connection_ids = state
905        .rpc
906        .read()
907        .await
908        .read_worktree(worktree_id, request.sender_id)?
909        .connection_ids();
910
911    broadcast(request.sender_id, connection_ids, |conn_id| {
912        rpc.forward_send(request.sender_id, conn_id, request.payload.clone())
913    })
914    .await?;
915
916    Ok(())
917}
918
919pub async fn broadcast<F, T>(
920    sender_id: ConnectionId,
921    receiver_ids: Vec<ConnectionId>,
922    mut f: F,
923) -> anyhow::Result<()>
924where
925    F: FnMut(ConnectionId) -> T,
926    T: Future<Output = anyhow::Result<()>>,
927{
928    let futures = receiver_ids
929        .into_iter()
930        .filter(|id| *id != sender_id)
931        .map(|id| f(id));
932    futures::future::try_join_all(futures).await?;
933    Ok(())
934}
935
936fn header_contains_ignore_case<T>(
937    request: &tide::Request<T>,
938    header_name: HeaderName,
939    value: &str,
940) -> bool {
941    request
942        .header(header_name)
943        .map(|h| {
944            h.as_str()
945                .split(',')
946                .any(|s| s.trim().eq_ignore_ascii_case(value.trim()))
947        })
948        .unwrap_or(false)
949}