1use super::{
2 auth::{self, PeerExt as _},
3 db::{ChannelId, UserId},
4 AppState,
5};
6use anyhow::anyhow;
7use async_std::task;
8use async_tungstenite::{
9 tungstenite::{protocol::Role, Error as WebSocketError, Message as WebSocketMessage},
10 WebSocketStream,
11};
12use futures::{future::BoxFuture, FutureExt};
13use postage::prelude::Stream as _;
14use sha1::{Digest as _, Sha1};
15use std::{
16 any::{Any, TypeId},
17 collections::{HashMap, HashSet},
18 future::Future,
19 mem,
20 sync::Arc,
21 time::Instant,
22};
23use surf::StatusCode;
24use tide::log;
25use tide::{
26 http::headers::{HeaderName, CONNECTION, UPGRADE},
27 Request, Response,
28};
29use time::OffsetDateTime;
30use zrpc::{
31 auth::random_token,
32 proto::{self, EnvelopedMessage},
33 ConnectionId, Peer, Router, TypedEnvelope,
34};
35
/// Identifier of one participant's replica within a shared worktree.
///
/// In this file the host is reported to peers as replica `0`, and guests are
/// handed the lowest free id starting from `1`.
type ReplicaId = u16;
37
38type Handler = Box<
39 dyn Fn(&mut Option<Box<dyn Any + Send + Sync>>, Arc<Server>) -> Option<BoxFuture<'static, ()>>,
40>;
41
42#[derive(Default)]
43struct ServerBuilder {
44 handlers: Vec<Handler>,
45 handler_types: HashSet<TypeId>,
46}
47
48impl ServerBuilder {
49 pub fn on_message<F, Fut, M>(&mut self, handler: F) -> &mut Self
50 where
51 F: 'static + Fn(Box<TypedEnvelope<M>>, Arc<Server>) -> Fut,
52 Fut: 'static + Send + Future<Output = ()>,
53 M: EnvelopedMessage,
54 {
55 if self.handler_types.insert(TypeId::of::<M>()) {
56 panic!("registered a handler for the same message twice");
57 }
58
59 self.handlers
60 .push(Box::new(move |untyped_envelope, server| {
61 if let Some(typed_envelope) = untyped_envelope.take() {
62 match typed_envelope.downcast::<TypedEnvelope<M>>() {
63 Ok(typed_envelope) => Some((handler)(typed_envelope, server).boxed()),
64 Err(envelope) => {
65 *untyped_envelope = Some(envelope);
66 None
67 }
68 }
69 } else {
70 None
71 }
72 }));
73 self
74 }
75
76 pub fn build(self, rpc: Arc<zrpc::peer2::Peer>, state: Arc<AppState>) -> Arc<Server> {
77 Arc::new(Server {
78 rpc,
79 state,
80 handlers: self.handlers,
81 })
82 }
83}
84
85struct Server {
86 rpc: Arc<zrpc::peer2::Peer>,
87 state: Arc<AppState>,
88 handlers: Vec<Handler>,
89}
90
91impl Server {
92 pub async fn add_connection<Conn>(
93 self: &Arc<Self>,
94 connection: Conn,
95 addr: String,
96 user_id: UserId,
97 ) where
98 Conn: 'static
99 + futures::Sink<WebSocketMessage, Error = WebSocketError>
100 + futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>
101 + Send
102 + Unpin,
103 {
104 let this = self.clone();
105 let (connection_id, handle_io, mut incoming_rx) = this.rpc.add_connection(connection).await;
106 this.state
107 .rpc
108 .write()
109 .await
110 .add_connection(connection_id, user_id);
111
112 let handle_io = handle_io.fuse();
113 futures::pin_mut!(handle_io);
114 loop {
115 let next_message = incoming_rx.recv().fuse();
116 futures::pin_mut!(next_message);
117 futures::select_biased! {
118 message = next_message => {
119 if let Some(message) = message {
120 let mut message = Some(message);
121 for handler in &this.handlers {
122 if let Some(future) = (handler)(&mut message, this.clone()) {
123 future.await;
124 break;
125 }
126 }
127
128 if let Some(message) = message {
129 log::warn!("unhandled message: {:?}", message);
130 }
131 } else {
132 log::info!("rpc connection closed {:?}", addr);
133 break;
134 }
135 }
136 handle_io = handle_io => {
137 if let Err(err) = handle_io {
138 log::error!("error handling rpc connection {:?} - {:?}", addr, err);
139 }
140 break;
141 }
142 }
143 }
144
145 if let Err(err) = this.rpc.sign_out(connection_id, &this.state).await {
146 log::error!("error signing out connection {:?} - {:?}", addr, err);
147 }
148 }
149}
150
151#[derive(Default)]
152pub struct State {
153 connections: HashMap<ConnectionId, Connection>,
154 pub worktrees: HashMap<u64, Worktree>,
155 channels: HashMap<ChannelId, Channel>,
156 next_worktree_id: u64,
157}
158
159struct Connection {
160 user_id: UserId,
161 worktrees: HashSet<u64>,
162 channels: HashSet<ChannelId>,
163}
164
165pub struct Worktree {
166 host_connection_id: Option<ConnectionId>,
167 guest_connection_ids: HashMap<ConnectionId, ReplicaId>,
168 active_replica_ids: HashSet<ReplicaId>,
169 access_token: String,
170 root_name: String,
171 entries: HashMap<u64, proto::Entry>,
172}
173
174#[derive(Default)]
175struct Channel {
176 connection_ids: HashSet<ConnectionId>,
177}
178
179impl Worktree {
180 pub fn connection_ids(&self) -> Vec<ConnectionId> {
181 self.guest_connection_ids
182 .keys()
183 .copied()
184 .chain(self.host_connection_id)
185 .collect()
186 }
187
188 fn host_connection_id(&self) -> tide::Result<ConnectionId> {
189 Ok(self
190 .host_connection_id
191 .ok_or_else(|| anyhow!("host disconnected from worktree"))?)
192 }
193}
194
195impl Channel {
196 fn connection_ids(&self) -> Vec<ConnectionId> {
197 self.connection_ids.iter().copied().collect()
198 }
199}
200
201impl State {
202 // Add a new connection associated with a given user.
203 pub fn add_connection(&mut self, connection_id: ConnectionId, user_id: UserId) {
204 self.connections.insert(
205 connection_id,
206 Connection {
207 user_id,
208 worktrees: Default::default(),
209 channels: Default::default(),
210 },
211 );
212 }
213
214 // Remove the given connection and its association with any worktrees.
215 pub fn remove_connection(&mut self, connection_id: ConnectionId) -> Vec<u64> {
216 let mut worktree_ids = Vec::new();
217 if let Some(connection) = self.connections.remove(&connection_id) {
218 for channel_id in connection.channels {
219 if let Some(channel) = self.channels.get_mut(&channel_id) {
220 channel.connection_ids.remove(&connection_id);
221 }
222 }
223 for worktree_id in connection.worktrees {
224 if let Some(worktree) = self.worktrees.get_mut(&worktree_id) {
225 if worktree.host_connection_id == Some(connection_id) {
226 worktree_ids.push(worktree_id);
227 } else if let Some(replica_id) =
228 worktree.guest_connection_ids.remove(&connection_id)
229 {
230 worktree.active_replica_ids.remove(&replica_id);
231 worktree_ids.push(worktree_id);
232 }
233 }
234 }
235 }
236 worktree_ids
237 }
238
239 fn join_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
240 if let Some(connection) = self.connections.get_mut(&connection_id) {
241 connection.channels.insert(channel_id);
242 self.channels
243 .entry(channel_id)
244 .or_default()
245 .connection_ids
246 .insert(connection_id);
247 }
248 }
249
250 // Add the given connection as a guest of the given worktree
251 pub fn join_worktree(
252 &mut self,
253 connection_id: ConnectionId,
254 worktree_id: u64,
255 access_token: &str,
256 ) -> Option<(ReplicaId, &Worktree)> {
257 if let Some(worktree) = self.worktrees.get_mut(&worktree_id) {
258 if access_token == worktree.access_token {
259 if let Some(connection) = self.connections.get_mut(&connection_id) {
260 connection.worktrees.insert(worktree_id);
261 }
262
263 let mut replica_id = 1;
264 while worktree.active_replica_ids.contains(&replica_id) {
265 replica_id += 1;
266 }
267 worktree.active_replica_ids.insert(replica_id);
268 worktree
269 .guest_connection_ids
270 .insert(connection_id, replica_id);
271 Some((replica_id, worktree))
272 } else {
273 None
274 }
275 } else {
276 None
277 }
278 }
279
280 fn user_id_for_connection(&self, connection_id: ConnectionId) -> tide::Result<UserId> {
281 Ok(self
282 .connections
283 .get(&connection_id)
284 .ok_or_else(|| anyhow!("unknown connection"))?
285 .user_id)
286 }
287
288 fn read_worktree(
289 &self,
290 worktree_id: u64,
291 connection_id: ConnectionId,
292 ) -> tide::Result<&Worktree> {
293 let worktree = self
294 .worktrees
295 .get(&worktree_id)
296 .ok_or_else(|| anyhow!("worktree not found"))?;
297
298 if worktree.host_connection_id == Some(connection_id)
299 || worktree.guest_connection_ids.contains_key(&connection_id)
300 {
301 Ok(worktree)
302 } else {
303 Err(anyhow!(
304 "{} is not a member of worktree {}",
305 connection_id,
306 worktree_id
307 ))?
308 }
309 }
310
311 fn write_worktree(
312 &mut self,
313 worktree_id: u64,
314 connection_id: ConnectionId,
315 ) -> tide::Result<&mut Worktree> {
316 let worktree = self
317 .worktrees
318 .get_mut(&worktree_id)
319 .ok_or_else(|| anyhow!("worktree not found"))?;
320
321 if worktree.host_connection_id == Some(connection_id)
322 || worktree.guest_connection_ids.contains_key(&connection_id)
323 {
324 Ok(worktree)
325 } else {
326 Err(anyhow!(
327 "{} is not a member of worktree {}",
328 connection_id,
329 worktree_id
330 ))?
331 }
332 }
333}
334
335trait MessageHandler<'a, M: proto::EnvelopedMessage> {
336 type Output: 'a + Send + Future<Output = tide::Result<()>>;
337
338 fn handle(
339 &self,
340 message: TypedEnvelope<M>,
341 rpc: &'a Arc<Peer>,
342 app_state: &'a Arc<AppState>,
343 ) -> Self::Output;
344}
345
346impl<'a, M, F, Fut> MessageHandler<'a, M> for F
347where
348 M: proto::EnvelopedMessage,
349 F: Fn(TypedEnvelope<M>, &'a Arc<Peer>, &'a Arc<AppState>) -> Fut,
350 Fut: 'a + Send + Future<Output = tide::Result<()>>,
351{
352 type Output = Fut;
353
354 fn handle(
355 &self,
356 message: TypedEnvelope<M>,
357 rpc: &'a Arc<Peer>,
358 app_state: &'a Arc<AppState>,
359 ) -> Self::Output {
360 (self)(message, rpc, app_state)
361 }
362}
363
364fn on_message<M, H>(router: &mut Router, rpc: &Arc<Peer>, app_state: &Arc<AppState>, handler: H)
365where
366 M: EnvelopedMessage,
367 H: 'static + Clone + Send + Sync + for<'a> MessageHandler<'a, M>,
368{
369 let rpc = rpc.clone();
370 let handler = handler.clone();
371 let app_state = app_state.clone();
372 router.add_message_handler(move |message| {
373 let rpc = rpc.clone();
374 let handler = handler.clone();
375 let app_state = app_state.clone();
376 async move {
377 let sender_id = message.sender_id;
378 let message_id = message.message_id;
379 let start_time = Instant::now();
380 log::info!(
381 "RPC message received. id: {}.{}, type:{}",
382 sender_id,
383 message_id,
384 M::NAME
385 );
386 if let Err(err) = handler.handle(message, &rpc, &app_state).await {
387 log::error!("error handling message: {:?}", err);
388 } else {
389 log::info!(
390 "RPC message handled. id:{}.{}, duration:{:?}",
391 sender_id,
392 message_id,
393 start_time.elapsed()
394 );
395 }
396
397 Ok(())
398 }
399 });
400}
401
402pub fn add_rpc_routes(router: &mut Router, state: &Arc<AppState>, rpc: &Arc<Peer>) {
403 on_message(router, rpc, state, share_worktree);
404 on_message(router, rpc, state, join_worktree);
405 on_message(router, rpc, state, update_worktree);
406 on_message(router, rpc, state, close_worktree);
407 on_message(router, rpc, state, open_buffer);
408 on_message(router, rpc, state, close_buffer);
409 on_message(router, rpc, state, update_buffer);
410 on_message(router, rpc, state, buffer_saved);
411 on_message(router, rpc, state, save_buffer);
412 on_message(router, rpc, state, get_channels);
413 on_message(router, rpc, state, get_users);
414 on_message(router, rpc, state, join_channel);
415 on_message(router, rpc, state, send_channel_message);
416}
417
418pub fn add_routes(app: &mut tide::Server<Arc<AppState>>, rpc: &Arc<Peer>) {
419 let mut router = Router::new();
420 add_rpc_routes(&mut router, app.state(), rpc);
421 let router = Arc::new(router);
422
423 let rpc = rpc.clone();
424 app.at("/rpc").with(auth::VerifyToken).get(move |request: Request<Arc<AppState>>| {
425 let user_id = request.ext::<UserId>().copied();
426 let rpc = rpc.clone();
427 let router = router.clone();
428 async move {
429 const WEBSOCKET_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
430
431 let connection_upgrade = header_contains_ignore_case(&request, CONNECTION, "upgrade");
432 let upgrade_to_websocket = header_contains_ignore_case(&request, UPGRADE, "websocket");
433 let upgrade_requested = connection_upgrade && upgrade_to_websocket;
434
435 if !upgrade_requested {
436 return Ok(Response::new(StatusCode::UpgradeRequired));
437 }
438
439 let header = match request.header("Sec-Websocket-Key") {
440 Some(h) => h.as_str(),
441 None => return Err(anyhow!("expected sec-websocket-key"))?,
442 };
443
444 let mut response = Response::new(StatusCode::SwitchingProtocols);
445 response.insert_header(UPGRADE, "websocket");
446 response.insert_header(CONNECTION, "Upgrade");
447 let hash = Sha1::new().chain(header).chain(WEBSOCKET_GUID).finalize();
448 response.insert_header("Sec-Websocket-Accept", base64::encode(&hash[..]));
449 response.insert_header("Sec-Websocket-Version", "13");
450
451 let http_res: &mut tide::http::Response = response.as_mut();
452 let upgrade_receiver = http_res.recv_upgrade().await;
453 let addr = request.remote().unwrap_or("unknown").to_string();
454 let state = request.state().clone();
455 let user_id = user_id.ok_or_else(|| anyhow!("user_id is not present on request. ensure auth::VerifyToken middleware is present"))?;
456 task::spawn(async move {
457 if let Some(stream) = upgrade_receiver.await {
458 let stream = WebSocketStream::from_raw_socket(stream, Role::Server, None).await;
459 handle_connection(rpc, router, state, addr, stream, user_id).await;
460 }
461 });
462
463 Ok(response)
464 }
465 });
466}
467
468pub async fn handle_connection<Conn>(
469 rpc: Arc<Peer>,
470 router: Arc<Router>,
471 state: Arc<AppState>,
472 addr: String,
473 stream: Conn,
474 user_id: UserId,
475) where
476 Conn: 'static
477 + futures::Sink<WebSocketMessage, Error = WebSocketError>
478 + futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>
479 + Send
480 + Unpin,
481{
482 log::info!("accepted rpc connection: {:?}", addr);
483 let (connection_id, handle_io, handle_messages) = rpc.add_connection(stream, router).await;
484 state
485 .rpc
486 .write()
487 .await
488 .add_connection(connection_id, user_id);
489
490 let handle_messages = async move {
491 handle_messages.await;
492 Ok(())
493 };
494
495 if let Err(e) = futures::try_join!(handle_messages, handle_io) {
496 log::error!("error handling rpc connection {:?} - {:?}", addr, e);
497 }
498
499 log::info!("closing connection to {:?}", addr);
500 if let Err(e) = rpc.sign_out(connection_id, &state).await {
501 log::error!("error signing out connection {:?} - {:?}", addr, e);
502 }
503}
504
505async fn share_worktree(
506 mut request: TypedEnvelope<proto::ShareWorktree>,
507 rpc: &Arc<Peer>,
508 state: &Arc<AppState>,
509) -> tide::Result<()> {
510 let mut state = state.rpc.write().await;
511 let worktree_id = state.next_worktree_id;
512 state.next_worktree_id += 1;
513 let access_token = random_token();
514 let worktree = request
515 .payload
516 .worktree
517 .as_mut()
518 .ok_or_else(|| anyhow!("missing worktree"))?;
519 let entries = mem::take(&mut worktree.entries)
520 .into_iter()
521 .map(|entry| (entry.id, entry))
522 .collect();
523 state.worktrees.insert(
524 worktree_id,
525 Worktree {
526 host_connection_id: Some(request.sender_id),
527 guest_connection_ids: Default::default(),
528 active_replica_ids: Default::default(),
529 access_token: access_token.clone(),
530 root_name: mem::take(&mut worktree.root_name),
531 entries,
532 },
533 );
534
535 rpc.respond(
536 request.receipt(),
537 proto::ShareWorktreeResponse {
538 worktree_id,
539 access_token,
540 },
541 )
542 .await?;
543 Ok(())
544}
545
546async fn join_worktree(
547 request: TypedEnvelope<proto::OpenWorktree>,
548 rpc: &Arc<Peer>,
549 state: &Arc<AppState>,
550) -> tide::Result<()> {
551 let worktree_id = request.payload.worktree_id;
552 let access_token = &request.payload.access_token;
553
554 let mut state = state.rpc.write().await;
555 if let Some((peer_replica_id, worktree)) =
556 state.join_worktree(request.sender_id, worktree_id, access_token)
557 {
558 let mut peers = Vec::new();
559 if let Some(host_connection_id) = worktree.host_connection_id {
560 peers.push(proto::Peer {
561 peer_id: host_connection_id.0,
562 replica_id: 0,
563 });
564 }
565 for (peer_conn_id, peer_replica_id) in &worktree.guest_connection_ids {
566 if *peer_conn_id != request.sender_id {
567 peers.push(proto::Peer {
568 peer_id: peer_conn_id.0,
569 replica_id: *peer_replica_id as u32,
570 });
571 }
572 }
573
574 broadcast(request.sender_id, worktree.connection_ids(), |conn_id| {
575 rpc.send(
576 conn_id,
577 proto::AddPeer {
578 worktree_id,
579 peer: Some(proto::Peer {
580 peer_id: request.sender_id.0,
581 replica_id: peer_replica_id as u32,
582 }),
583 },
584 )
585 })
586 .await?;
587 rpc.respond(
588 request.receipt(),
589 proto::OpenWorktreeResponse {
590 worktree_id,
591 worktree: Some(proto::Worktree {
592 root_name: worktree.root_name.clone(),
593 entries: worktree.entries.values().cloned().collect(),
594 }),
595 replica_id: peer_replica_id as u32,
596 peers,
597 },
598 )
599 .await?;
600 } else {
601 rpc.respond(
602 request.receipt(),
603 proto::OpenWorktreeResponse {
604 worktree_id,
605 worktree: None,
606 replica_id: 0,
607 peers: Vec::new(),
608 },
609 )
610 .await?;
611 }
612
613 Ok(())
614}
615
616async fn update_worktree(
617 request: TypedEnvelope<proto::UpdateWorktree>,
618 rpc: &Arc<Peer>,
619 state: &Arc<AppState>,
620) -> tide::Result<()> {
621 {
622 let mut state = state.rpc.write().await;
623 let worktree = state.write_worktree(request.payload.worktree_id, request.sender_id)?;
624 for entry_id in &request.payload.removed_entries {
625 worktree.entries.remove(&entry_id);
626 }
627
628 for entry in &request.payload.updated_entries {
629 worktree.entries.insert(entry.id, entry.clone());
630 }
631 }
632
633 broadcast_in_worktree(request.payload.worktree_id, request, rpc, state).await?;
634 Ok(())
635}
636
637async fn close_worktree(
638 request: TypedEnvelope<proto::CloseWorktree>,
639 rpc: &Arc<Peer>,
640 state: &Arc<AppState>,
641) -> tide::Result<()> {
642 let connection_ids;
643 {
644 let mut state = state.rpc.write().await;
645 let worktree = state.write_worktree(request.payload.worktree_id, request.sender_id)?;
646 connection_ids = worktree.connection_ids();
647 if worktree.host_connection_id == Some(request.sender_id) {
648 worktree.host_connection_id = None;
649 } else if let Some(replica_id) = worktree.guest_connection_ids.remove(&request.sender_id) {
650 worktree.active_replica_ids.remove(&replica_id);
651 }
652 }
653
654 broadcast(request.sender_id, connection_ids, |conn_id| {
655 rpc.send(
656 conn_id,
657 proto::RemovePeer {
658 worktree_id: request.payload.worktree_id,
659 peer_id: request.sender_id.0,
660 },
661 )
662 })
663 .await?;
664
665 Ok(())
666}
667
668async fn open_buffer(
669 request: TypedEnvelope<proto::OpenBuffer>,
670 rpc: &Arc<Peer>,
671 state: &Arc<AppState>,
672) -> tide::Result<()> {
673 let receipt = request.receipt();
674 let worktree_id = request.payload.worktree_id;
675 let host_connection_id = state
676 .rpc
677 .read()
678 .await
679 .read_worktree(worktree_id, request.sender_id)?
680 .host_connection_id()?;
681
682 let response = rpc
683 .forward_request(request.sender_id, host_connection_id, request.payload)
684 .await?;
685 rpc.respond(receipt, response).await?;
686 Ok(())
687}
688
689async fn close_buffer(
690 request: TypedEnvelope<proto::CloseBuffer>,
691 rpc: &Arc<Peer>,
692 state: &Arc<AppState>,
693) -> tide::Result<()> {
694 let host_connection_id = state
695 .rpc
696 .read()
697 .await
698 .read_worktree(request.payload.worktree_id, request.sender_id)?
699 .host_connection_id()?;
700
701 rpc.forward_send(request.sender_id, host_connection_id, request.payload)
702 .await?;
703
704 Ok(())
705}
706
707async fn save_buffer(
708 request: TypedEnvelope<proto::SaveBuffer>,
709 rpc: &Arc<Peer>,
710 state: &Arc<AppState>,
711) -> tide::Result<()> {
712 let host;
713 let guests;
714 {
715 let state = state.rpc.read().await;
716 let worktree = state.read_worktree(request.payload.worktree_id, request.sender_id)?;
717 host = worktree.host_connection_id()?;
718 guests = worktree
719 .guest_connection_ids
720 .keys()
721 .copied()
722 .collect::<Vec<_>>();
723 }
724
725 let sender = request.sender_id;
726 let receipt = request.receipt();
727 let response = rpc
728 .forward_request(sender, host, request.payload.clone())
729 .await?;
730
731 broadcast(host, guests, |conn_id| {
732 let response = response.clone();
733 async move {
734 if conn_id == sender {
735 rpc.respond(receipt, response).await
736 } else {
737 rpc.forward_send(host, conn_id, response).await
738 }
739 }
740 })
741 .await?;
742
743 Ok(())
744}
745
746async fn update_buffer(
747 request: TypedEnvelope<proto::UpdateBuffer>,
748 rpc: &Arc<Peer>,
749 state: &Arc<AppState>,
750) -> tide::Result<()> {
751 broadcast_in_worktree(request.payload.worktree_id, request, rpc, state).await
752}
753
754async fn buffer_saved(
755 request: TypedEnvelope<proto::BufferSaved>,
756 rpc: &Arc<Peer>,
757 state: &Arc<AppState>,
758) -> tide::Result<()> {
759 broadcast_in_worktree(request.payload.worktree_id, request, rpc, state).await
760}
761
762async fn get_channels(
763 request: TypedEnvelope<proto::GetChannels>,
764 rpc: &Arc<Peer>,
765 state: &Arc<AppState>,
766) -> tide::Result<()> {
767 let user_id = state
768 .rpc
769 .read()
770 .await
771 .user_id_for_connection(request.sender_id)?;
772 let channels = state.db.get_channels_for_user(user_id).await?;
773 rpc.respond(
774 request.receipt(),
775 proto::GetChannelsResponse {
776 channels: channels
777 .into_iter()
778 .map(|chan| proto::Channel {
779 id: chan.id.to_proto(),
780 name: chan.name,
781 })
782 .collect(),
783 },
784 )
785 .await?;
786 Ok(())
787}
788
789async fn get_users(
790 request: TypedEnvelope<proto::GetUsers>,
791 rpc: &Arc<Peer>,
792 state: &Arc<AppState>,
793) -> tide::Result<()> {
794 let user_id = state
795 .rpc
796 .read()
797 .await
798 .user_id_for_connection(request.sender_id)?;
799 let receipt = request.receipt();
800 let user_ids = request.payload.user_ids.into_iter().map(UserId::from_proto);
801 let users = state
802 .db
803 .get_users_by_ids(user_id, user_ids)
804 .await?
805 .into_iter()
806 .map(|user| proto::User {
807 id: user.id.to_proto(),
808 github_login: user.github_login,
809 avatar_url: String::new(),
810 })
811 .collect();
812 rpc.respond(receipt, proto::GetUsersResponse { users })
813 .await?;
814 Ok(())
815}
816
817async fn join_channel(
818 request: TypedEnvelope<proto::JoinChannel>,
819 rpc: &Arc<Peer>,
820 state: &Arc<AppState>,
821) -> tide::Result<()> {
822 let user_id = state
823 .rpc
824 .read()
825 .await
826 .user_id_for_connection(request.sender_id)?;
827 let channel_id = ChannelId::from_proto(request.payload.channel_id);
828 if !state
829 .db
830 .can_user_access_channel(user_id, channel_id)
831 .await?
832 {
833 Err(anyhow!("access denied"))?;
834 }
835
836 state
837 .rpc
838 .write()
839 .await
840 .join_channel(request.sender_id, channel_id);
841 let messages = state
842 .db
843 .get_recent_channel_messages(channel_id, 50)
844 .await?
845 .into_iter()
846 .map(|msg| proto::ChannelMessage {
847 id: msg.id.to_proto(),
848 body: msg.body,
849 timestamp: msg.sent_at.unix_timestamp() as u64,
850 sender_id: msg.sender_id.to_proto(),
851 })
852 .collect();
853 rpc.respond(request.receipt(), proto::JoinChannelResponse { messages })
854 .await?;
855 Ok(())
856}
857
858async fn send_channel_message(
859 request: TypedEnvelope<proto::SendChannelMessage>,
860 peer: &Arc<Peer>,
861 app: &Arc<AppState>,
862) -> tide::Result<()> {
863 let channel_id = ChannelId::from_proto(request.payload.channel_id);
864 let user_id;
865 let connection_ids;
866 {
867 let state = app.rpc.read().await;
868 user_id = state.user_id_for_connection(request.sender_id)?;
869 if let Some(channel) = state.channels.get(&channel_id) {
870 connection_ids = channel.connection_ids();
871 } else {
872 return Ok(());
873 }
874 }
875
876 let timestamp = OffsetDateTime::now_utc();
877 let message_id = app
878 .db
879 .create_channel_message(channel_id, user_id, &request.payload.body, timestamp)
880 .await?;
881 let message = proto::ChannelMessageSent {
882 channel_id: channel_id.to_proto(),
883 message: Some(proto::ChannelMessage {
884 sender_id: user_id.to_proto(),
885 id: message_id.to_proto(),
886 body: request.payload.body,
887 timestamp: timestamp.unix_timestamp() as u64,
888 }),
889 };
890 broadcast(request.sender_id, connection_ids, |conn_id| {
891 peer.send(conn_id, message.clone())
892 })
893 .await?;
894
895 Ok(())
896}
897
898async fn broadcast_in_worktree<T: proto::EnvelopedMessage>(
899 worktree_id: u64,
900 request: TypedEnvelope<T>,
901 rpc: &Arc<Peer>,
902 state: &Arc<AppState>,
903) -> tide::Result<()> {
904 let connection_ids = state
905 .rpc
906 .read()
907 .await
908 .read_worktree(worktree_id, request.sender_id)?
909 .connection_ids();
910
911 broadcast(request.sender_id, connection_ids, |conn_id| {
912 rpc.forward_send(request.sender_id, conn_id, request.payload.clone())
913 })
914 .await?;
915
916 Ok(())
917}
918
919pub async fn broadcast<F, T>(
920 sender_id: ConnectionId,
921 receiver_ids: Vec<ConnectionId>,
922 mut f: F,
923) -> anyhow::Result<()>
924where
925 F: FnMut(ConnectionId) -> T,
926 T: Future<Output = anyhow::Result<()>>,
927{
928 let futures = receiver_ids
929 .into_iter()
930 .filter(|id| *id != sender_id)
931 .map(|id| f(id));
932 futures::future::try_join_all(futures).await?;
933 Ok(())
934}
935
936fn header_contains_ignore_case<T>(
937 request: &tide::Request<T>,
938 header_name: HeaderName,
939 value: &str,
940) -> bool {
941 request
942 .header(header_name)
943 .map(|h| {
944 h.as_str()
945 .split(',')
946 .any(|s| s.trim().eq_ignore_ascii_case(value.trim()))
947 })
948 .unwrap_or(false)
949}