1use super::{
2 auth,
3 db::{ChannelId, MessageId, UserId},
4 AppState,
5};
6use anyhow::anyhow;
7use async_std::{sync::RwLock, task};
8use async_tungstenite::{tungstenite::protocol::Role, WebSocketStream};
9use futures::{future::BoxFuture, FutureExt};
10use postage::{mpsc, prelude::Sink as _, prelude::Stream as _};
11use sha1::{Digest as _, Sha1};
12use std::{
13 any::TypeId,
14 collections::{hash_map, HashMap, HashSet},
15 future::Future,
16 mem,
17 sync::Arc,
18 time::Instant,
19};
20use surf::StatusCode;
21use tide::log;
22use tide::{
23 http::headers::{HeaderName, CONNECTION, UPGRADE},
24 Request, Response,
25};
26use time::OffsetDateTime;
27use zrpc::{
28 auth::random_token,
29 proto::{self, AnyTypedEnvelope, EnvelopedMessage},
30 Conn, ConnectionId, Peer, TypedEnvelope,
31};
32
33type ReplicaId = u16;
34
35type MessageHandler = Box<
36 dyn Send
37 + Sync
38 + Fn(Arc<Server>, Box<dyn AnyTypedEnvelope>) -> BoxFuture<'static, tide::Result<()>>,
39>;
40
41pub struct Server {
42 peer: Arc<Peer>,
43 state: RwLock<ServerState>,
44 app_state: Arc<AppState>,
45 handlers: HashMap<TypeId, MessageHandler>,
46 notifications: Option<mpsc::Sender<()>>,
47}
48
49#[derive(Default)]
50struct ServerState {
51 connections: HashMap<ConnectionId, Connection>,
52 pub worktrees: HashMap<u64, Worktree>,
53 channels: HashMap<ChannelId, Channel>,
54 next_worktree_id: u64,
55}
56
57struct Connection {
58 user_id: UserId,
59 worktrees: HashSet<u64>,
60 channels: HashSet<ChannelId>,
61}
62
63struct Worktree {
64 host_connection_id: Option<ConnectionId>,
65 guest_connection_ids: HashMap<ConnectionId, ReplicaId>,
66 active_replica_ids: HashSet<ReplicaId>,
67 access_token: String,
68 root_name: String,
69 entries: HashMap<u64, proto::Entry>,
70}
71
72#[derive(Default)]
73struct Channel {
74 connection_ids: HashSet<ConnectionId>,
75}
76
77const MESSAGE_COUNT_PER_PAGE: usize = 100;
78const MAX_MESSAGE_LEN: usize = 1024;
79
80impl Server {
81 pub fn new(
82 app_state: Arc<AppState>,
83 peer: Arc<Peer>,
84 notifications: Option<mpsc::Sender<()>>,
85 ) -> Arc<Self> {
86 let mut server = Self {
87 peer,
88 app_state,
89 state: Default::default(),
90 handlers: Default::default(),
91 notifications,
92 };
93
94 server
95 .add_handler(Server::ping)
96 .add_handler(Server::share_worktree)
97 .add_handler(Server::join_worktree)
98 .add_handler(Server::update_worktree)
99 .add_handler(Server::close_worktree)
100 .add_handler(Server::open_buffer)
101 .add_handler(Server::close_buffer)
102 .add_handler(Server::update_buffer)
103 .add_handler(Server::buffer_saved)
104 .add_handler(Server::save_buffer)
105 .add_handler(Server::get_channels)
106 .add_handler(Server::get_users)
107 .add_handler(Server::join_channel)
108 .add_handler(Server::leave_channel)
109 .add_handler(Server::send_channel_message)
110 .add_handler(Server::get_channel_messages);
111
112 Arc::new(server)
113 }
114
115 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
116 where
117 F: 'static + Send + Sync + Fn(Arc<Self>, TypedEnvelope<M>) -> Fut,
118 Fut: 'static + Send + Future<Output = tide::Result<()>>,
119 M: EnvelopedMessage,
120 {
121 let prev_handler = self.handlers.insert(
122 TypeId::of::<M>(),
123 Box::new(move |server, envelope| {
124 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
125 (handler)(server, *envelope).boxed()
126 }),
127 );
128 if prev_handler.is_some() {
129 panic!("registered a handler for the same message twice");
130 }
131 self
132 }
133
134 pub fn handle_connection(
135 self: &Arc<Self>,
136 connection: Conn,
137 addr: String,
138 user_id: UserId,
139 ) -> impl Future<Output = ()> {
140 let this = self.clone();
141 async move {
142 let (connection_id, handle_io, mut incoming_rx) =
143 this.peer.add_connection(connection).await;
144 this.add_connection(connection_id, user_id).await;
145
146 let handle_io = handle_io.fuse();
147 futures::pin_mut!(handle_io);
148 loop {
149 let next_message = incoming_rx.recv().fuse();
150 futures::pin_mut!(next_message);
151 futures::select_biased! {
152 message = next_message => {
153 if let Some(message) = message {
154 let start_time = Instant::now();
155 log::info!("RPC message received: {}", message.payload_type_name());
156 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
157 if let Err(err) = (handler)(this.clone(), message).await {
158 log::error!("error handling message: {:?}", err);
159 } else {
160 log::info!("RPC message handled. duration:{:?}", start_time.elapsed());
161 }
162
163 if let Some(mut notifications) = this.notifications.clone() {
164 let _ = notifications.send(()).await;
165 }
166 } else {
167 log::warn!("unhandled message: {}", message.payload_type_name());
168 }
169 } else {
170 log::info!("rpc connection closed {:?}", addr);
171 break;
172 }
173 }
174 handle_io = handle_io => {
175 if let Err(err) = handle_io {
176 log::error!("error handling rpc connection {:?} - {:?}", addr, err);
177 }
178 break;
179 }
180 }
181 }
182
183 if let Err(err) = this.sign_out(connection_id).await {
184 log::error!("error signing out connection {:?} - {:?}", addr, err);
185 }
186 }
187 }
188
189 async fn sign_out(self: &Arc<Self>, connection_id: zrpc::ConnectionId) -> tide::Result<()> {
190 self.peer.disconnect(connection_id).await;
191 let worktree_ids = self.remove_connection(connection_id).await;
192 for worktree_id in worktree_ids {
193 let state = self.state.read().await;
194 if let Some(worktree) = state.worktrees.get(&worktree_id) {
195 broadcast(connection_id, worktree.connection_ids(), |conn_id| {
196 self.peer.send(
197 conn_id,
198 proto::RemovePeer {
199 worktree_id,
200 peer_id: connection_id.0,
201 },
202 )
203 })
204 .await?;
205 }
206 }
207 Ok(())
208 }
209
210 // Add a new connection associated with a given user.
211 async fn add_connection(&self, connection_id: ConnectionId, user_id: UserId) {
212 self.state.write().await.connections.insert(
213 connection_id,
214 Connection {
215 user_id,
216 worktrees: Default::default(),
217 channels: Default::default(),
218 },
219 );
220 }
221
222 // Remove the given connection and its association with any worktrees.
223 async fn remove_connection(&self, connection_id: ConnectionId) -> Vec<u64> {
224 let mut worktree_ids = Vec::new();
225 let mut state = self.state.write().await;
226 if let Some(connection) = state.connections.remove(&connection_id) {
227 for channel_id in connection.channels {
228 if let Some(channel) = state.channels.get_mut(&channel_id) {
229 channel.connection_ids.remove(&connection_id);
230 }
231 }
232 for worktree_id in connection.worktrees {
233 if let Some(worktree) = state.worktrees.get_mut(&worktree_id) {
234 if worktree.host_connection_id == Some(connection_id) {
235 worktree_ids.push(worktree_id);
236 } else if let Some(replica_id) =
237 worktree.guest_connection_ids.remove(&connection_id)
238 {
239 worktree.active_replica_ids.remove(&replica_id);
240 worktree_ids.push(worktree_id);
241 }
242 }
243 }
244 }
245 worktree_ids
246 }
247
248 async fn ping(self: Arc<Server>, request: TypedEnvelope<proto::Ping>) -> tide::Result<()> {
249 self.peer
250 .respond(
251 request.receipt(),
252 proto::Pong {
253 id: request.payload.id,
254 },
255 )
256 .await?;
257 Ok(())
258 }
259
260 async fn share_worktree(
261 self: Arc<Server>,
262 mut request: TypedEnvelope<proto::ShareWorktree>,
263 ) -> tide::Result<()> {
264 let mut state = self.state.write().await;
265 let worktree_id = state.next_worktree_id;
266 state.next_worktree_id += 1;
267 let access_token = random_token();
268 let worktree = request
269 .payload
270 .worktree
271 .as_mut()
272 .ok_or_else(|| anyhow!("missing worktree"))?;
273 let entries = mem::take(&mut worktree.entries)
274 .into_iter()
275 .map(|entry| (entry.id, entry))
276 .collect();
277 state.worktrees.insert(
278 worktree_id,
279 Worktree {
280 host_connection_id: Some(request.sender_id),
281 guest_connection_ids: Default::default(),
282 active_replica_ids: Default::default(),
283 access_token: access_token.clone(),
284 root_name: mem::take(&mut worktree.root_name),
285 entries,
286 },
287 );
288
289 self.peer
290 .respond(
291 request.receipt(),
292 proto::ShareWorktreeResponse {
293 worktree_id,
294 access_token,
295 },
296 )
297 .await?;
298 Ok(())
299 }
300
301 async fn join_worktree(
302 self: Arc<Server>,
303 request: TypedEnvelope<proto::OpenWorktree>,
304 ) -> tide::Result<()> {
305 let worktree_id = request.payload.worktree_id;
306 let access_token = &request.payload.access_token;
307
308 let mut state = self.state.write().await;
309 if let Some((peer_replica_id, worktree)) =
310 state.join_worktree(request.sender_id, worktree_id, access_token)
311 {
312 let mut peers = Vec::new();
313 if let Some(host_connection_id) = worktree.host_connection_id {
314 peers.push(proto::Peer {
315 peer_id: host_connection_id.0,
316 replica_id: 0,
317 });
318 }
319 for (peer_conn_id, peer_replica_id) in &worktree.guest_connection_ids {
320 if *peer_conn_id != request.sender_id {
321 peers.push(proto::Peer {
322 peer_id: peer_conn_id.0,
323 replica_id: *peer_replica_id as u32,
324 });
325 }
326 }
327
328 broadcast(request.sender_id, worktree.connection_ids(), |conn_id| {
329 self.peer.send(
330 conn_id,
331 proto::AddPeer {
332 worktree_id,
333 peer: Some(proto::Peer {
334 peer_id: request.sender_id.0,
335 replica_id: peer_replica_id as u32,
336 }),
337 },
338 )
339 })
340 .await?;
341 self.peer
342 .respond(
343 request.receipt(),
344 proto::OpenWorktreeResponse {
345 worktree_id,
346 worktree: Some(proto::Worktree {
347 root_name: worktree.root_name.clone(),
348 entries: worktree.entries.values().cloned().collect(),
349 }),
350 replica_id: peer_replica_id as u32,
351 peers,
352 },
353 )
354 .await?;
355 } else {
356 self.peer
357 .respond(
358 request.receipt(),
359 proto::OpenWorktreeResponse {
360 worktree_id,
361 worktree: None,
362 replica_id: 0,
363 peers: Vec::new(),
364 },
365 )
366 .await?;
367 }
368
369 Ok(())
370 }
371
372 async fn update_worktree(
373 self: Arc<Server>,
374 request: TypedEnvelope<proto::UpdateWorktree>,
375 ) -> tide::Result<()> {
376 {
377 let mut state = self.state.write().await;
378 let worktree = state.write_worktree(request.payload.worktree_id, request.sender_id)?;
379 for entry_id in &request.payload.removed_entries {
380 worktree.entries.remove(&entry_id);
381 }
382
383 for entry in &request.payload.updated_entries {
384 worktree.entries.insert(entry.id, entry.clone());
385 }
386 }
387
388 self.broadcast_in_worktree(request.payload.worktree_id, &request)
389 .await?;
390 Ok(())
391 }
392
393 async fn close_worktree(
394 self: Arc<Server>,
395 request: TypedEnvelope<proto::CloseWorktree>,
396 ) -> tide::Result<()> {
397 let connection_ids;
398 {
399 let mut state = self.state.write().await;
400 let worktree = state.write_worktree(request.payload.worktree_id, request.sender_id)?;
401 connection_ids = worktree.connection_ids();
402 if worktree.host_connection_id == Some(request.sender_id) {
403 worktree.host_connection_id = None;
404 } else if let Some(replica_id) =
405 worktree.guest_connection_ids.remove(&request.sender_id)
406 {
407 worktree.active_replica_ids.remove(&replica_id);
408 }
409 }
410
411 broadcast(request.sender_id, connection_ids, |conn_id| {
412 self.peer.send(
413 conn_id,
414 proto::RemovePeer {
415 worktree_id: request.payload.worktree_id,
416 peer_id: request.sender_id.0,
417 },
418 )
419 })
420 .await?;
421
422 Ok(())
423 }
424
425 async fn open_buffer(
426 self: Arc<Server>,
427 request: TypedEnvelope<proto::OpenBuffer>,
428 ) -> tide::Result<()> {
429 let receipt = request.receipt();
430 let worktree_id = request.payload.worktree_id;
431 let host_connection_id = self
432 .state
433 .read()
434 .await
435 .read_worktree(worktree_id, request.sender_id)?
436 .host_connection_id()?;
437
438 let response = self
439 .peer
440 .forward_request(request.sender_id, host_connection_id, request.payload)
441 .await?;
442 self.peer.respond(receipt, response).await?;
443 Ok(())
444 }
445
446 async fn close_buffer(
447 self: Arc<Server>,
448 request: TypedEnvelope<proto::CloseBuffer>,
449 ) -> tide::Result<()> {
450 let host_connection_id = self
451 .state
452 .read()
453 .await
454 .read_worktree(request.payload.worktree_id, request.sender_id)?
455 .host_connection_id()?;
456
457 self.peer
458 .forward_send(request.sender_id, host_connection_id, request.payload)
459 .await?;
460
461 Ok(())
462 }
463
464 async fn save_buffer(
465 self: Arc<Server>,
466 request: TypedEnvelope<proto::SaveBuffer>,
467 ) -> tide::Result<()> {
468 let host;
469 let guests;
470 {
471 let state = self.state.read().await;
472 let worktree = state.read_worktree(request.payload.worktree_id, request.sender_id)?;
473 host = worktree.host_connection_id()?;
474 guests = worktree
475 .guest_connection_ids
476 .keys()
477 .copied()
478 .collect::<Vec<_>>();
479 }
480
481 let sender = request.sender_id;
482 let receipt = request.receipt();
483 let response = self
484 .peer
485 .forward_request(sender, host, request.payload.clone())
486 .await?;
487
488 broadcast(host, guests, |conn_id| {
489 let response = response.clone();
490 let peer = &self.peer;
491 async move {
492 if conn_id == sender {
493 peer.respond(receipt, response).await
494 } else {
495 peer.forward_send(host, conn_id, response).await
496 }
497 }
498 })
499 .await?;
500
501 Ok(())
502 }
503
504 async fn update_buffer(
505 self: Arc<Server>,
506 request: TypedEnvelope<proto::UpdateBuffer>,
507 ) -> tide::Result<()> {
508 self.broadcast_in_worktree(request.payload.worktree_id, &request)
509 .await
510 }
511
512 async fn buffer_saved(
513 self: Arc<Server>,
514 request: TypedEnvelope<proto::BufferSaved>,
515 ) -> tide::Result<()> {
516 self.broadcast_in_worktree(request.payload.worktree_id, &request)
517 .await
518 }
519
520 async fn get_channels(
521 self: Arc<Server>,
522 request: TypedEnvelope<proto::GetChannels>,
523 ) -> tide::Result<()> {
524 let user_id = self
525 .state
526 .read()
527 .await
528 .user_id_for_connection(request.sender_id)?;
529 let channels = self.app_state.db.get_accessible_channels(user_id).await?;
530 self.peer
531 .respond(
532 request.receipt(),
533 proto::GetChannelsResponse {
534 channels: channels
535 .into_iter()
536 .map(|chan| proto::Channel {
537 id: chan.id.to_proto(),
538 name: chan.name,
539 })
540 .collect(),
541 },
542 )
543 .await?;
544 Ok(())
545 }
546
547 async fn get_users(
548 self: Arc<Server>,
549 request: TypedEnvelope<proto::GetUsers>,
550 ) -> tide::Result<()> {
551 let user_id = self
552 .state
553 .read()
554 .await
555 .user_id_for_connection(request.sender_id)?;
556 let receipt = request.receipt();
557 let user_ids = request.payload.user_ids.into_iter().map(UserId::from_proto);
558 let users = self
559 .app_state
560 .db
561 .get_users_by_ids(user_id, user_ids)
562 .await?
563 .into_iter()
564 .map(|user| proto::User {
565 id: user.id.to_proto(),
566 github_login: user.github_login,
567 avatar_url: String::new(),
568 })
569 .collect();
570 self.peer
571 .respond(receipt, proto::GetUsersResponse { users })
572 .await?;
573 Ok(())
574 }
575
576 async fn join_channel(
577 self: Arc<Self>,
578 request: TypedEnvelope<proto::JoinChannel>,
579 ) -> tide::Result<()> {
580 let user_id = self
581 .state
582 .read()
583 .await
584 .user_id_for_connection(request.sender_id)?;
585 let channel_id = ChannelId::from_proto(request.payload.channel_id);
586 if !self
587 .app_state
588 .db
589 .can_user_access_channel(user_id, channel_id)
590 .await?
591 {
592 Err(anyhow!("access denied"))?;
593 }
594
595 self.state
596 .write()
597 .await
598 .join_channel(request.sender_id, channel_id);
599 let messages = self
600 .app_state
601 .db
602 .get_channel_messages(channel_id, MESSAGE_COUNT_PER_PAGE, None)
603 .await?
604 .into_iter()
605 .map(|msg| proto::ChannelMessage {
606 id: msg.id.to_proto(),
607 body: msg.body,
608 timestamp: msg.sent_at.unix_timestamp() as u64,
609 sender_id: msg.sender_id.to_proto(),
610 })
611 .collect::<Vec<_>>();
612 self.peer
613 .respond(
614 request.receipt(),
615 proto::JoinChannelResponse {
616 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
617 messages,
618 },
619 )
620 .await?;
621 Ok(())
622 }
623
624 async fn leave_channel(
625 self: Arc<Self>,
626 request: TypedEnvelope<proto::LeaveChannel>,
627 ) -> tide::Result<()> {
628 let user_id = self
629 .state
630 .read()
631 .await
632 .user_id_for_connection(request.sender_id)?;
633 let channel_id = ChannelId::from_proto(request.payload.channel_id);
634 if !self
635 .app_state
636 .db
637 .can_user_access_channel(user_id, channel_id)
638 .await?
639 {
640 Err(anyhow!("access denied"))?;
641 }
642
643 self.state
644 .write()
645 .await
646 .leave_channel(request.sender_id, channel_id);
647
648 Ok(())
649 }
650
651 async fn send_channel_message(
652 self: Arc<Self>,
653 request: TypedEnvelope<proto::SendChannelMessage>,
654 ) -> tide::Result<()> {
655 let receipt = request.receipt();
656 let channel_id = ChannelId::from_proto(request.payload.channel_id);
657 let user_id;
658 let connection_ids;
659 {
660 let state = self.state.read().await;
661 user_id = state.user_id_for_connection(request.sender_id)?;
662 if let Some(channel) = state.channels.get(&channel_id) {
663 connection_ids = channel.connection_ids();
664 } else {
665 return Ok(());
666 }
667 }
668
669 // Validate the message body.
670 let body = request.payload.body.trim().to_string();
671 if body.len() > MAX_MESSAGE_LEN {
672 self.peer
673 .respond_with_error(
674 receipt,
675 proto::Error {
676 message: "message is too long".to_string(),
677 },
678 )
679 .await?;
680 return Ok(());
681 }
682 if body.is_empty() {
683 self.peer
684 .respond_with_error(
685 receipt,
686 proto::Error {
687 message: "message can't be blank".to_string(),
688 },
689 )
690 .await?;
691 return Ok(());
692 }
693
694 let timestamp = OffsetDateTime::now_utc();
695 let message_id = self
696 .app_state
697 .db
698 .create_channel_message(channel_id, user_id, &body, timestamp)
699 .await?
700 .to_proto();
701 let message = proto::ChannelMessage {
702 sender_id: user_id.to_proto(),
703 id: message_id,
704 body,
705 timestamp: timestamp.unix_timestamp() as u64,
706 };
707 broadcast(request.sender_id, connection_ids, |conn_id| {
708 self.peer.send(
709 conn_id,
710 proto::ChannelMessageSent {
711 channel_id: channel_id.to_proto(),
712 message: Some(message.clone()),
713 },
714 )
715 })
716 .await?;
717 self.peer
718 .respond(
719 receipt,
720 proto::SendChannelMessageResponse {
721 message: Some(message),
722 },
723 )
724 .await?;
725 Ok(())
726 }
727
728 async fn get_channel_messages(
729 self: Arc<Self>,
730 request: TypedEnvelope<proto::GetChannelMessages>,
731 ) -> tide::Result<()> {
732 let user_id = self
733 .state
734 .read()
735 .await
736 .user_id_for_connection(request.sender_id)?;
737 let channel_id = ChannelId::from_proto(request.payload.channel_id);
738 if !self
739 .app_state
740 .db
741 .can_user_access_channel(user_id, channel_id)
742 .await?
743 {
744 Err(anyhow!("access denied"))?;
745 }
746
747 let messages = self
748 .app_state
749 .db
750 .get_channel_messages(
751 channel_id,
752 MESSAGE_COUNT_PER_PAGE,
753 Some(MessageId::from_proto(request.payload.before_message_id)),
754 )
755 .await?
756 .into_iter()
757 .map(|msg| proto::ChannelMessage {
758 id: msg.id.to_proto(),
759 body: msg.body,
760 timestamp: msg.sent_at.unix_timestamp() as u64,
761 sender_id: msg.sender_id.to_proto(),
762 })
763 .collect::<Vec<_>>();
764 self.peer
765 .respond(
766 request.receipt(),
767 proto::GetChannelMessagesResponse {
768 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
769 messages,
770 },
771 )
772 .await?;
773 Ok(())
774 }
775
776 async fn broadcast_in_worktree<T: proto::EnvelopedMessage>(
777 &self,
778 worktree_id: u64,
779 message: &TypedEnvelope<T>,
780 ) -> tide::Result<()> {
781 let connection_ids = self
782 .state
783 .read()
784 .await
785 .read_worktree(worktree_id, message.sender_id)?
786 .connection_ids();
787
788 broadcast(message.sender_id, connection_ids, |conn_id| {
789 self.peer
790 .forward_send(message.sender_id, conn_id, message.payload.clone())
791 })
792 .await?;
793
794 Ok(())
795 }
796}
797
798pub async fn broadcast<F, T>(
799 sender_id: ConnectionId,
800 receiver_ids: Vec<ConnectionId>,
801 mut f: F,
802) -> anyhow::Result<()>
803where
804 F: FnMut(ConnectionId) -> T,
805 T: Future<Output = anyhow::Result<()>>,
806{
807 let futures = receiver_ids
808 .into_iter()
809 .filter(|id| *id != sender_id)
810 .map(|id| f(id));
811 futures::future::try_join_all(futures).await?;
812 Ok(())
813}
814
815impl ServerState {
816 fn join_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
817 if let Some(connection) = self.connections.get_mut(&connection_id) {
818 connection.channels.insert(channel_id);
819 self.channels
820 .entry(channel_id)
821 .or_default()
822 .connection_ids
823 .insert(connection_id);
824 }
825 }
826
827 fn leave_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
828 if let Some(connection) = self.connections.get_mut(&connection_id) {
829 connection.channels.remove(&channel_id);
830 if let hash_map::Entry::Occupied(mut entry) = self.channels.entry(channel_id) {
831 entry.get_mut().connection_ids.remove(&connection_id);
832 if entry.get_mut().connection_ids.is_empty() {
833 entry.remove();
834 }
835 }
836 }
837 }
838
839 fn user_id_for_connection(&self, connection_id: ConnectionId) -> tide::Result<UserId> {
840 Ok(self
841 .connections
842 .get(&connection_id)
843 .ok_or_else(|| anyhow!("unknown connection"))?
844 .user_id)
845 }
846
847 // Add the given connection as a guest of the given worktree
848 fn join_worktree(
849 &mut self,
850 connection_id: ConnectionId,
851 worktree_id: u64,
852 access_token: &str,
853 ) -> Option<(ReplicaId, &Worktree)> {
854 if let Some(worktree) = self.worktrees.get_mut(&worktree_id) {
855 if access_token == worktree.access_token {
856 if let Some(connection) = self.connections.get_mut(&connection_id) {
857 connection.worktrees.insert(worktree_id);
858 }
859
860 let mut replica_id = 1;
861 while worktree.active_replica_ids.contains(&replica_id) {
862 replica_id += 1;
863 }
864 worktree.active_replica_ids.insert(replica_id);
865 worktree
866 .guest_connection_ids
867 .insert(connection_id, replica_id);
868 Some((replica_id, worktree))
869 } else {
870 None
871 }
872 } else {
873 None
874 }
875 }
876
877 fn read_worktree(
878 &self,
879 worktree_id: u64,
880 connection_id: ConnectionId,
881 ) -> tide::Result<&Worktree> {
882 let worktree = self
883 .worktrees
884 .get(&worktree_id)
885 .ok_or_else(|| anyhow!("worktree not found"))?;
886
887 if worktree.host_connection_id == Some(connection_id)
888 || worktree.guest_connection_ids.contains_key(&connection_id)
889 {
890 Ok(worktree)
891 } else {
892 Err(anyhow!(
893 "{} is not a member of worktree {}",
894 connection_id,
895 worktree_id
896 ))?
897 }
898 }
899
900 fn write_worktree(
901 &mut self,
902 worktree_id: u64,
903 connection_id: ConnectionId,
904 ) -> tide::Result<&mut Worktree> {
905 let worktree = self
906 .worktrees
907 .get_mut(&worktree_id)
908 .ok_or_else(|| anyhow!("worktree not found"))?;
909
910 if worktree.host_connection_id == Some(connection_id)
911 || worktree.guest_connection_ids.contains_key(&connection_id)
912 {
913 Ok(worktree)
914 } else {
915 Err(anyhow!(
916 "{} is not a member of worktree {}",
917 connection_id,
918 worktree_id
919 ))?
920 }
921 }
922}
923
924impl Worktree {
925 pub fn connection_ids(&self) -> Vec<ConnectionId> {
926 self.guest_connection_ids
927 .keys()
928 .copied()
929 .chain(self.host_connection_id)
930 .collect()
931 }
932
933 fn host_connection_id(&self) -> tide::Result<ConnectionId> {
934 Ok(self
935 .host_connection_id
936 .ok_or_else(|| anyhow!("host disconnected from worktree"))?)
937 }
938}
939
940impl Channel {
941 fn connection_ids(&self) -> Vec<ConnectionId> {
942 self.connection_ids.iter().copied().collect()
943 }
944}
945
946pub fn add_routes(app: &mut tide::Server<Arc<AppState>>, rpc: &Arc<Peer>) {
947 let server = Server::new(app.state().clone(), rpc.clone(), None);
948 app.at("/rpc").with(auth::VerifyToken).get(move |request: Request<Arc<AppState>>| {
949 let user_id = request.ext::<UserId>().copied();
950 let server = server.clone();
951 async move {
952 const WEBSOCKET_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
953
954 let connection_upgrade = header_contains_ignore_case(&request, CONNECTION, "upgrade");
955 let upgrade_to_websocket = header_contains_ignore_case(&request, UPGRADE, "websocket");
956 let upgrade_requested = connection_upgrade && upgrade_to_websocket;
957
958 if !upgrade_requested {
959 return Ok(Response::new(StatusCode::UpgradeRequired));
960 }
961
962 let header = match request.header("Sec-Websocket-Key") {
963 Some(h) => h.as_str(),
964 None => return Err(anyhow!("expected sec-websocket-key"))?,
965 };
966
967 let mut response = Response::new(StatusCode::SwitchingProtocols);
968 response.insert_header(UPGRADE, "websocket");
969 response.insert_header(CONNECTION, "Upgrade");
970 let hash = Sha1::new().chain(header).chain(WEBSOCKET_GUID).finalize();
971 response.insert_header("Sec-Websocket-Accept", base64::encode(&hash[..]));
972 response.insert_header("Sec-Websocket-Version", "13");
973
974 let http_res: &mut tide::http::Response = response.as_mut();
975 let upgrade_receiver = http_res.recv_upgrade().await;
976 let addr = request.remote().unwrap_or("unknown").to_string();
977 let user_id = user_id.ok_or_else(|| anyhow!("user_id is not present on request. ensure auth::VerifyToken middleware is present"))?;
978 task::spawn(async move {
979 if let Some(stream) = upgrade_receiver.await {
980 server.handle_connection(Conn::new(WebSocketStream::from_raw_socket(stream, Role::Server, None).await), addr, user_id).await;
981 }
982 });
983
984 Ok(response)
985 }
986 });
987}
988
989fn header_contains_ignore_case<T>(
990 request: &tide::Request<T>,
991 header_name: HeaderName,
992 value: &str,
993) -> bool {
994 request
995 .header(header_name)
996 .map(|h| {
997 h.as_str()
998 .split(',')
999 .any(|s| s.trim().eq_ignore_ascii_case(value.trim()))
1000 })
1001 .unwrap_or(false)
1002}
1003
1004#[cfg(test)]
1005mod tests {
1006 use super::*;
1007 use crate::{
1008 auth,
1009 db::{tests::TestDb, UserId},
1010 github, AppState, Config,
1011 };
1012 use async_std::{sync::RwLockReadGuard, task};
1013 use gpui::TestAppContext;
1014 use parking_lot::Mutex;
1015 use postage::{mpsc, watch};
1016 use serde_json::json;
1017 use sqlx::types::time::OffsetDateTime;
1018 use std::{
1019 path::Path,
1020 sync::{
1021 atomic::{AtomicBool, Ordering::SeqCst},
1022 Arc,
1023 },
1024 time::Duration,
1025 };
1026 use zed::{
1027 channel::{Channel, ChannelDetails, ChannelList},
1028 editor::{Editor, Insert},
1029 fs::{FakeFs, Fs as _},
1030 language::LanguageRegistry,
1031 rpc::{self, Client},
1032 settings,
1033 user::UserStore,
1034 worktree::Worktree,
1035 };
1036 use zrpc::Peer;
1037
1038 #[gpui::test]
1039 async fn test_share_worktree(mut cx_a: TestAppContext, mut cx_b: TestAppContext) {
1040 let (window_b, _) = cx_b.add_window(|_| EmptyView);
1041 let settings = cx_b.read(settings::test).1;
1042 let lang_registry = Arc::new(LanguageRegistry::new());
1043
1044 // Connect to a server as 2 clients.
1045 let mut server = TestServer::start().await;
1046 let (_, client_a) = server.create_client(&mut cx_a, "user_a").await;
1047 let (_, client_b) = server.create_client(&mut cx_b, "user_b").await;
1048
1049 cx_a.foreground().forbid_parking();
1050
1051 // Share a local worktree as client A
1052 let fs = Arc::new(FakeFs::new());
1053 fs.insert_tree(
1054 "/a",
1055 json!({
1056 "a.txt": "a-contents",
1057 "b.txt": "b-contents",
1058 }),
1059 )
1060 .await;
1061 let worktree_a = Worktree::open_local(
1062 "/a".as_ref(),
1063 lang_registry.clone(),
1064 fs,
1065 &mut cx_a.to_async(),
1066 )
1067 .await
1068 .unwrap();
1069 worktree_a
1070 .read_with(&cx_a, |tree, _| tree.as_local().unwrap().scan_complete())
1071 .await;
1072 let (worktree_id, worktree_token) = worktree_a
1073 .update(&mut cx_a, |tree, cx| {
1074 tree.as_local_mut().unwrap().share(client_a.clone(), cx)
1075 })
1076 .await
1077 .unwrap();
1078
1079 // Join that worktree as client B, and see that a guest has joined as client A.
1080 let worktree_b = Worktree::open_remote(
1081 client_b.clone(),
1082 worktree_id,
1083 worktree_token,
1084 lang_registry.clone(),
1085 &mut cx_b.to_async(),
1086 )
1087 .await
1088 .unwrap();
1089 let replica_id_b = worktree_b.read_with(&cx_b, |tree, _| tree.replica_id());
1090 worktree_a
1091 .condition(&cx_a, |tree, _| {
1092 tree.peers()
1093 .values()
1094 .any(|replica_id| *replica_id == replica_id_b)
1095 })
1096 .await;
1097
1098 // Open the same file as client B and client A.
1099 let buffer_b = worktree_b
1100 .update(&mut cx_b, |worktree, cx| worktree.open_buffer("b.txt", cx))
1101 .await
1102 .unwrap();
1103 buffer_b.read_with(&cx_b, |buf, _| assert_eq!(buf.text(), "b-contents"));
1104 worktree_a.read_with(&cx_a, |tree, cx| assert!(tree.has_open_buffer("b.txt", cx)));
1105 let buffer_a = worktree_a
1106 .update(&mut cx_a, |tree, cx| tree.open_buffer("b.txt", cx))
1107 .await
1108 .unwrap();
1109
1110 // Create a selection set as client B and see that selection set as client A.
1111 let editor_b = cx_b.add_view(window_b, |cx| Editor::for_buffer(buffer_b, settings, cx));
1112 buffer_a
1113 .condition(&cx_a, |buffer, _| buffer.selection_sets().count() == 1)
1114 .await;
1115
1116 // Edit the buffer as client B and see that edit as client A.
1117 editor_b.update(&mut cx_b, |editor, cx| {
1118 editor.insert(&Insert("ok, ".into()), cx)
1119 });
1120 buffer_a
1121 .condition(&cx_a, |buffer, _| buffer.text() == "ok, b-contents")
1122 .await;
1123
1124 // Remove the selection set as client B, see those selections disappear as client A.
1125 cx_b.update(move |_| drop(editor_b));
1126 buffer_a
1127 .condition(&cx_a, |buffer, _| buffer.selection_sets().count() == 0)
1128 .await;
1129
1130 // Close the buffer as client A, see that the buffer is closed.
1131 drop(buffer_a);
1132 worktree_a
1133 .condition(&cx_a, |tree, cx| !tree.has_open_buffer("b.txt", cx))
1134 .await;
1135
1136 // Dropping the worktree removes client B from client A's peers.
1137 cx_b.update(move |_| drop(worktree_b));
1138 worktree_a
1139 .condition(&cx_a, |tree, _| tree.peers().is_empty())
1140 .await;
1141 }
1142
/// Host A shares a worktree with two guests, B and C. The test checks three things:
/// - concurrent edits from all three participants converge to the same text;
/// - a save requested by guest B writes the host's file on disk and leaves every copy clean;
/// - a rename and a new file on the host's file system show up in both guests' worktrees.
#[gpui::test]
async fn test_propagate_saves_and_fs_changes_in_shared_worktree(
    mut cx_a: TestAppContext,
    mut cx_b: TestAppContext,
    mut cx_c: TestAppContext,
) {
    cx_a.foreground().forbid_parking();
    let lang_registry = Arc::new(LanguageRegistry::new());

    // Connect to a server as 3 clients.
    let mut server = TestServer::start().await;
    let (_, client_a) = server.create_client(&mut cx_a, "user_a").await;
    let (_, client_b) = server.create_client(&mut cx_b, "user_b").await;
    let (_, client_c) = server.create_client(&mut cx_c, "user_c").await;

    let fs = Arc::new(FakeFs::new());

    // Share a worktree as client A.
    fs.insert_tree(
        "/a",
        json!({
            "file1": "",
            "file2": ""
        }),
    )
    .await;

    let worktree_a = Worktree::open_local(
        "/a".as_ref(),
        lang_registry.clone(),
        fs.clone(),
        &mut cx_a.to_async(),
    )
    .await
    .unwrap();
    // Let the local worktree finish its initial scan before sharing it.
    worktree_a
        .read_with(&cx_a, |tree, _| tree.as_local().unwrap().scan_complete())
        .await;
    // Sharing yields the id and access token that guests use to join.
    let (worktree_id, worktree_token) = worktree_a
        .update(&mut cx_a, |tree, cx| {
            tree.as_local_mut().unwrap().share(client_a.clone(), cx)
        })
        .await
        .unwrap();

    // Join that worktree as clients B and C.
    let worktree_b = Worktree::open_remote(
        client_b.clone(),
        worktree_id,
        worktree_token.clone(),
        lang_registry.clone(),
        &mut cx_b.to_async(),
    )
    .await
    .unwrap();
    let worktree_c = Worktree::open_remote(
        client_c.clone(),
        worktree_id,
        worktree_token,
        lang_registry.clone(),
        &mut cx_c.to_async(),
    )
    .await
    .unwrap();

    // Open and edit a buffer as both guests B and C.
    // Both guests insert at offset 0, so the two edits are concurrent at the same position.
    let buffer_b = worktree_b
        .update(&mut cx_b, |tree, cx| tree.open_buffer("file1", cx))
        .await
        .unwrap();
    let buffer_c = worktree_c
        .update(&mut cx_c, |tree, cx| tree.open_buffer("file1", cx))
        .await
        .unwrap();
    buffer_b.update(&mut cx_b, |buf, cx| buf.edit([0..0], "i-am-b, ", cx));
    buffer_c.update(&mut cx_c, |buf, cx| buf.edit([0..0], "i-am-c, ", cx));

    // Open and edit that buffer as the host.
    let buffer_a = worktree_a
        .update(&mut cx_a, |tree, cx| tree.open_buffer("file1", cx))
        .await
        .unwrap();

    // The asserted order (C's text before B's) is how concurrent inserts at the same
    // offset get resolved. NOTE(review): presumably ordered by replica id; confirm.
    buffer_a
        .condition(&mut cx_a, |buf, _| buf.text() == "i-am-c, i-am-b, ")
        .await;
    buffer_a.update(&mut cx_a, |buf, cx| {
        buf.edit([buf.len()..buf.len()], "i-am-a", cx)
    });

    // Wait for edits to propagate
    buffer_a
        .condition(&mut cx_a, |buf, _| buf.text() == "i-am-c, i-am-b, i-am-a")
        .await;
    buffer_b
        .condition(&mut cx_b, |buf, _| buf.text() == "i-am-c, i-am-b, i-am-a")
        .await;
    buffer_c
        .condition(&mut cx_c, |buf, _| buf.text() == "i-am-c, i-am-b, i-am-a")
        .await;

    // Edit the buffer as the host and concurrently save as guest B.
    // The file on disk is expected to include the host's "hi-a, " edit even though B
    // requested the save. NOTE(review): presumably the host writes its own current
    // buffer contents when handling a guest's save; confirm.
    let save_b = buffer_b.update(&mut cx_b, |buf, cx| buf.save(cx).unwrap());
    buffer_a.update(&mut cx_a, |buf, cx| buf.edit([0..0], "hi-a, ", cx));
    save_b.await.unwrap();
    assert_eq!(
        fs.load("/a/file1".as_ref()).await.unwrap(),
        "hi-a, i-am-c, i-am-b, i-am-a"
    );
    // A and B are clean right away; C only becomes clean once the save reaches it.
    buffer_a.read_with(&cx_a, |buf, _| assert!(!buf.is_dirty()));
    buffer_b.read_with(&cx_b, |buf, _| assert!(!buf.is_dirty()));
    buffer_c.condition(&cx_c, |buf, _| !buf.is_dirty()).await;

    // Make changes on host's file system, see those changes on the guests.
    // After renaming file2 -> file3 and adding file4, each guest should list exactly
    // file1, file3 and file4.
    fs.rename("/a/file2".as_ref(), "/a/file3".as_ref())
        .await
        .unwrap();
    fs.insert_file(Path::new("/a/file4"), "4".into())
        .await
        .unwrap();

    worktree_b
        .condition(&cx_b, |tree, _| tree.file_count() == 3)
        .await;
    worktree_c
        .condition(&cx_c, |tree, _| tree.file_count() == 3)
        .await;
    worktree_b.read_with(&cx_b, |tree, _| {
        assert_eq!(
            tree.paths()
                .map(|p| p.to_string_lossy())
                .collect::<Vec<_>>(),
            &["file1", "file3", "file4"]
        )
    });
    worktree_c.read_with(&cx_c, |tree, _| {
        assert_eq!(
            tree.paths()
                .map(|p| p.to_string_lossy())
                .collect::<Vec<_>>(),
            &["file1", "file3", "file4"]
        )
    });
}
1287
/// Checks the dirty and conflict flags on a guest's buffer across a save.
/// - An edit makes the buffer dirty without a conflict.
/// - Saving from the guest leaves it clean and conflict-free.
/// - Editing again after the save makes it dirty again, still without a conflict.
#[gpui::test]
async fn test_buffer_conflict_after_save(mut cx_a: TestAppContext, mut cx_b: TestAppContext) {
    cx_a.foreground().forbid_parking();
    let lang_registry = Arc::new(LanguageRegistry::new());

    // Connect to a server as 2 clients.
    let mut server = TestServer::start().await;
    let (_, client_a) = server.create_client(&mut cx_a, "user_a").await;
    let (_, client_b) = server.create_client(&mut cx_b, "user_b").await;

    // Share a local worktree as client A
    let fs = Arc::new(FakeFs::new());
    fs.save(Path::new("/a.txt"), &"a-contents".into())
        .await
        .unwrap();
    let worktree_a = Worktree::open_local(
        "/".as_ref(),
        lang_registry.clone(),
        fs,
        &mut cx_a.to_async(),
    )
    .await
    .unwrap();
    worktree_a
        .read_with(&cx_a, |tree, _| tree.as_local().unwrap().scan_complete())
        .await;
    let (worktree_id, worktree_token) = worktree_a
        .update(&mut cx_a, |tree, cx| {
            tree.as_local_mut().unwrap().share(client_a.clone(), cx)
        })
        .await
        .unwrap();

    // Join that worktree as client B. (Unlike test_peer_disconnection, this test does
    // not check that client A observes the new guest.)
    let worktree_b = Worktree::open_remote(
        client_b.clone(),
        worktree_id,
        worktree_token,
        lang_registry.clone(),
        &mut cx_b.to_async(),
    )
    .await
    .unwrap();

    let buffer_b = worktree_b
        .update(&mut cx_b, |worktree, cx| worktree.open_buffer("a.txt", cx))
        .await
        .unwrap();
    // Remember the file's mtime as B first sees it, so we can tell when the save lands.
    let mtime = buffer_b.read_with(&cx_b, |buf, _| buf.file().unwrap().mtime);

    buffer_b.update(&mut cx_b, |buf, cx| buf.edit([0..0], "world ", cx));
    buffer_b.read_with(&cx_b, |buf, _| {
        assert!(buf.is_dirty());
        assert!(!buf.has_conflict());
    });

    buffer_b
        .update(&mut cx_b, |buf, cx| buf.save(cx))
        .unwrap()
        .await
        .unwrap();
    // Wait until B's view of the file reports a different mtime than before the save.
    worktree_b
        .condition(&cx_b, |_, cx| {
            buffer_b.read(cx).file().unwrap().mtime != mtime
        })
        .await;
    buffer_b.read_with(&cx_b, |buf, _| {
        assert!(!buf.is_dirty());
        assert!(!buf.has_conflict());
    });

    // Editing after the save makes the buffer dirty again, but must not flag a conflict.
    buffer_b.update(&mut cx_b, |buf, cx| buf.edit([0..0], "hello ", cx));
    buffer_b.read_with(&cx_b, |buf, _| {
        assert!(buf.is_dirty());
        assert!(!buf.has_conflict());
    });
}
1365
/// Host A edits a buffer while guest B's request to open that same buffer is already
/// under way. B's buffer must still end up with A's text.
#[gpui::test]
async fn test_editing_while_guest_opens_buffer(
    mut cx_a: TestAppContext,
    mut cx_b: TestAppContext,
) {
    cx_a.foreground().forbid_parking();
    let lang_registry = Arc::new(LanguageRegistry::new());

    // Connect to a server as 2 clients.
    let mut server = TestServer::start().await;
    let (_, client_a) = server.create_client(&mut cx_a, "user_a").await;
    let (_, client_b) = server.create_client(&mut cx_b, "user_b").await;

    // Share a local worktree as client A
    let fs = Arc::new(FakeFs::new());
    fs.save(Path::new("/a.txt"), &"a-contents".into())
        .await
        .unwrap();
    let worktree_a = Worktree::open_local(
        "/".as_ref(),
        lang_registry.clone(),
        fs,
        &mut cx_a.to_async(),
    )
    .await
    .unwrap();
    worktree_a
        .read_with(&cx_a, |tree, _| tree.as_local().unwrap().scan_complete())
        .await;
    let (worktree_id, worktree_token) = worktree_a
        .update(&mut cx_a, |tree, cx| {
            tree.as_local_mut().unwrap().share(client_a.clone(), cx)
        })
        .await
        .unwrap();

    // Join that worktree as client B. (This test does not check that client A observes
    // the new guest.)
    let worktree_b = Worktree::open_remote(
        client_b.clone(),
        worktree_id,
        worktree_token,
        lang_registry.clone(),
        &mut cx_b.to_async(),
    )
    .await
    .unwrap();

    let buffer_a = worktree_a
        .update(&mut cx_a, |tree, cx| tree.open_buffer("a.txt", cx))
        .await
        .unwrap();
    // Start B's open request on the background executor without awaiting it, so A's
    // edit below races with it.
    let buffer_b = cx_b
        .background()
        .spawn(worktree_b.update(&mut cx_b, |worktree, cx| worktree.open_buffer("a.txt", cx)));

    // Yield once so the spawned open can start before A edits.
    task::yield_now().await;
    buffer_a.update(&mut cx_a, |buf, cx| buf.edit([0..0], "z", cx));

    // Whatever the interleaving, B's buffer must converge on A's post-edit text.
    let text = buffer_a.read_with(&cx_a, |buf, _| buf.text());
    let buffer_b = buffer_b.await.unwrap();
    buffer_b.condition(&cx_b, |buf, _| buf.text() == text).await;
}
1428
1429 #[gpui::test]
1430 async fn test_peer_disconnection(mut cx_a: TestAppContext, cx_b: TestAppContext) {
1431 cx_a.foreground().forbid_parking();
1432 let lang_registry = Arc::new(LanguageRegistry::new());
1433
1434 // Connect to a server as 2 clients.
1435 let mut server = TestServer::start().await;
1436 let (_, client_a) = server.create_client(&mut cx_a, "user_a").await;
1437 let (_, client_b) = server.create_client(&mut cx_a, "user_b").await;
1438
1439 // Share a local worktree as client A
1440 let fs = Arc::new(FakeFs::new());
1441 fs.insert_tree(
1442 "/a",
1443 json!({
1444 "a.txt": "a-contents",
1445 "b.txt": "b-contents",
1446 }),
1447 )
1448 .await;
1449 let worktree_a = Worktree::open_local(
1450 "/a".as_ref(),
1451 lang_registry.clone(),
1452 fs,
1453 &mut cx_a.to_async(),
1454 )
1455 .await
1456 .unwrap();
1457 worktree_a
1458 .read_with(&cx_a, |tree, _| tree.as_local().unwrap().scan_complete())
1459 .await;
1460 let (worktree_id, worktree_token) = worktree_a
1461 .update(&mut cx_a, |tree, cx| {
1462 tree.as_local_mut().unwrap().share(client_a.clone(), cx)
1463 })
1464 .await
1465 .unwrap();
1466
1467 // Join that worktree as client B, and see that a guest has joined as client A.
1468 let _worktree_b = Worktree::open_remote(
1469 client_b.clone(),
1470 worktree_id,
1471 worktree_token,
1472 lang_registry.clone(),
1473 &mut cx_b.to_async(),
1474 )
1475 .await
1476 .unwrap();
1477 worktree_a
1478 .condition(&cx_a, |tree, _| tree.peers().len() == 1)
1479 .await;
1480
1481 // Drop client B's connection and ensure client A observes client B leaving the worktree.
1482 client_b.disconnect(&cx_b.to_async()).await.unwrap();
1483 worktree_a
1484 .condition(&cx_a, |tree, _| tree.peers().len() == 0)
1485 .await;
1486 }
1487
/// End-to-end chat between two members of one org channel.
/// - Both clients list the channel and load the message already stored in the database.
/// - Messages sent by A become pending on A and are then delivered to B.
/// - The server's per-channel connection tracking shrinks as each client drops its
///   channel handle, and the channel entry disappears once nobody holds it.
#[gpui::test]
async fn test_basic_chat(mut cx_a: TestAppContext, mut cx_b: TestAppContext) {
    cx_a.foreground().forbid_parking();

    // Connect to a server as 2 clients.
    let mut server = TestServer::start().await;
    let (user_id_a, client_a) = server.create_client(&mut cx_a, "user_a").await;
    let (user_id_b, client_b) = server.create_client(&mut cx_b, "user_b").await;

    // Create an org that includes these 2 users.
    let db = &server.app_state.db;
    let org_id = db.create_org("Test Org", "test-org").await.unwrap();
    db.add_org_member(org_id, user_id_a, false).await.unwrap();
    db.add_org_member(org_id, user_id_b, false).await.unwrap();

    // Create a channel that includes all the users.
    let channel_id = db.create_org_channel(org_id, "test-channel").await.unwrap();
    db.add_channel_member(channel_id, user_id_a, false)
        .await
        .unwrap();
    db.add_channel_member(channel_id, user_id_b, false)
        .await
        .unwrap();
    // Seed the channel with one message from B, written straight to the database.
    db.create_channel_message(
        channel_id,
        user_id_b,
        "hello A, it's B.",
        OffsetDateTime::now_utc(),
    )
    .await
    .unwrap();

    let user_store_a = Arc::new(UserStore::new(client_a.clone()));
    let channels_a = cx_a.add_model(|cx| ChannelList::new(user_store_a, client_a, cx));
    channels_a
        .condition(&mut cx_a, |list, _| list.available_channels().is_some())
        .await;
    channels_a.read_with(&cx_a, |list, _| {
        assert_eq!(
            list.available_channels().unwrap(),
            &[ChannelDetails {
                id: channel_id.to_proto(),
                name: "test-channel".to_string()
            }]
        )
    });
    let channel_a = channels_a.update(&mut cx_a, |this, cx| {
        this.get_channel(channel_id.to_proto(), cx).unwrap()
    });
    // A freshly obtained channel has no messages yet; the seeded one arrives afterwards.
    channel_a.read_with(&cx_a, |channel, _| assert!(channel.messages().is_empty()));
    channel_a
        .condition(&cx_a, |channel, _| {
            channel_messages(channel)
                == [("user_b".to_string(), "hello A, it's B.".to_string())]
        })
        .await;

    let user_store_b = Arc::new(UserStore::new(client_b.clone()));
    let channels_b = cx_b.add_model(|cx| ChannelList::new(user_store_b, client_b, cx));
    channels_b
        .condition(&mut cx_b, |list, _| list.available_channels().is_some())
        .await;
    channels_b.read_with(&cx_b, |list, _| {
        assert_eq!(
            list.available_channels().unwrap(),
            &[ChannelDetails {
                id: channel_id.to_proto(),
                name: "test-channel".to_string()
            }]
        )
    });

    let channel_b = channels_b.update(&mut cx_b, |this, cx| {
        this.get_channel(channel_id.to_proto(), cx).unwrap()
    });
    channel_b.read_with(&cx_b, |channel, _| assert!(channel.messages().is_empty()));
    channel_b
        .condition(&cx_b, |channel, _| {
            channel_messages(channel)
                == [("user_b".to_string(), "hello A, it's B.".to_string())]
        })
        .await;

    // Send two messages from A. The first send task is detached and only the second is
    // awaited. Right after sending, both bodies must be listed as pending, in send order.
    channel_a
        .update(&mut cx_a, |channel, cx| {
            channel
                .send_message("oh, hi B.".to_string(), cx)
                .unwrap()
                .detach();
            let task = channel.send_message("sup".to_string(), cx).unwrap();
            assert_eq!(
                channel
                    .pending_messages()
                    .iter()
                    .map(|m| &m.body)
                    .collect::<Vec<_>>(),
                &["oh, hi B.", "sup"]
            );
            task
        })
        .await
        .unwrap();

    channel_a
        .condition(&cx_a, |channel, _| channel.pending_messages().is_empty())
        .await;
    channel_b
        .condition(&cx_b, |channel, _| {
            channel_messages(channel)
                == [
                    ("user_b".to_string(), "hello A, it's B.".to_string()),
                    ("user_a".to_string(), "oh, hi B.".to_string()),
                    ("user_a".to_string(), "sup".to_string()),
                ]
        })
        .await;

    // The server tracks one connection per client holding the channel. Dropping each
    // client's channel handle should shrink that set, and the channel entry should be
    // removed entirely once no connection remains.
    assert_eq!(
        server.state().await.channels[&channel_id]
            .connection_ids
            .len(),
        2
    );
    cx_b.update(|_| drop(channel_b));
    server
        .condition(|state| state.channels[&channel_id].connection_ids.len() == 1)
        .await;

    cx_a.update(|_| drop(channel_a));
    server
        .condition(|state| !state.channels.contains_key(&channel_id))
        .await;

    /// Flattens a channel's loaded messages into `(sender github login, body)` pairs.
    fn channel_messages(channel: &Channel) -> Vec<(String, String)> {
        channel
            .messages()
            .cursor::<(), ()>()
            .map(|m| (m.sender.github_login.clone(), m.body.clone()))
            .collect()
    }
}
1629
/// Checks the validation applied to outgoing chat messages.
/// - Overly long messages fail once sent.
/// - Empty messages are rejected immediately.
/// - Surrounding whitespace is trimmed before the message is stored.
#[gpui::test]
async fn test_chat_message_validation(mut cx_a: TestAppContext) {
    cx_a.foreground().forbid_parking();

    let mut server = TestServer::start().await;
    let (user_id_a, client_a) = server.create_client(&mut cx_a, "user_a").await;

    let db = &server.app_state.db;
    let org_id = db.create_org("Test Org", "test-org").await.unwrap();
    let channel_id = db.create_org_channel(org_id, "test-channel").await.unwrap();
    db.add_org_member(org_id, user_id_a, false).await.unwrap();
    db.add_channel_member(channel_id, user_id_a, false)
        .await
        .unwrap();

    let user_store_a = Arc::new(UserStore::new(client_a.clone()));
    let channels_a = cx_a.add_model(|cx| ChannelList::new(user_store_a, client_a, cx));
    channels_a
        .condition(&mut cx_a, |list, _| list.available_channels().is_some())
        .await;
    let channel_a = channels_a.update(&mut cx_a, |this, cx| {
        this.get_channel(channel_id.to_proto(), cx).unwrap()
    });

    // Messages aren't allowed to be too long.
    // The body is 14 * 1024 bytes, far above the 1024-byte MAX_MESSAGE_LEN
    // (NOTE(review): presumably the limit enforced here; confirm). `send_message` itself
    // succeeds, and the error only surfaces when the returned task is awaited.
    channel_a
        .update(&mut cx_a, |channel, cx| {
            let long_body = "this is long.\n".repeat(1024);
            channel.send_message(long_body, cx).unwrap()
        })
        .await
        .unwrap_err();

    // Messages aren't allowed to be blank.
    // `send_message` returns the error directly, so no send task is created at all.
    channel_a.update(&mut cx_a, |channel, cx| {
        channel.send_message(String::new(), cx).unwrap_err()
    });

    // Leading and trailing whitespace are trimmed.
    // The assertion reads the stored message straight from the database.
    channel_a
        .update(&mut cx_a, |channel, cx| {
            channel
                .send_message("\n surrounded by whitespace \n".to_string(), cx)
                .unwrap()
        })
        .await
        .unwrap();
    assert_eq!(
        db.get_channel_messages(channel_id, 10, None)
            .await
            .unwrap()
            .iter()
            .map(|m| &m.body)
            .collect::<Vec<_>>(),
        &["surrounded by whitespace"]
    );
}
1687
1688 #[gpui::test]
1689 async fn test_chat_reconnection(mut cx_a: TestAppContext, mut cx_b: TestAppContext) {
1690 cx_a.foreground().forbid_parking();
1691
1692 // Connect to a server as 2 clients.
1693 let mut server = TestServer::start().await;
1694 let (user_id_a, client_a) = server.create_client(&mut cx_a, "user_a").await;
1695 let (user_id_b, client_b) = server.create_client(&mut cx_b, "user_b").await;
1696 let mut status_b = client_b.status();
1697
1698 // Create an org that includes these 2 users.
1699 let db = &server.app_state.db;
1700 let org_id = db.create_org("Test Org", "test-org").await.unwrap();
1701 db.add_org_member(org_id, user_id_a, false).await.unwrap();
1702 db.add_org_member(org_id, user_id_b, false).await.unwrap();
1703
1704 // Create a channel that includes all the users.
1705 let channel_id = db.create_org_channel(org_id, "test-channel").await.unwrap();
1706 db.add_channel_member(channel_id, user_id_a, false)
1707 .await
1708 .unwrap();
1709 db.add_channel_member(channel_id, user_id_b, false)
1710 .await
1711 .unwrap();
1712 db.create_channel_message(
1713 channel_id,
1714 user_id_b,
1715 "hello A, it's B.",
1716 OffsetDateTime::now_utc(),
1717 )
1718 .await
1719 .unwrap();
1720
1721 let user_store_a = Arc::new(UserStore::new(client_a.clone()));
1722 let channels_a = cx_a.add_model(|cx| ChannelList::new(user_store_a, client_a, cx));
1723 channels_a
1724 .condition(&mut cx_a, |list, _| list.available_channels().is_some())
1725 .await;
1726
1727 channels_a.read_with(&cx_a, |list, _| {
1728 assert_eq!(
1729 list.available_channels().unwrap(),
1730 &[ChannelDetails {
1731 id: channel_id.to_proto(),
1732 name: "test-channel".to_string()
1733 }]
1734 )
1735 });
1736 let channel_a = channels_a.update(&mut cx_a, |this, cx| {
1737 this.get_channel(channel_id.to_proto(), cx).unwrap()
1738 });
1739 channel_a.read_with(&cx_a, |channel, _| assert!(channel.messages().is_empty()));
1740 channel_a
1741 .condition(&cx_a, |channel, _| {
1742 channel_messages(channel)
1743 == [("user_b".to_string(), "hello A, it's B.".to_string())]
1744 })
1745 .await;
1746
1747 let user_store_b = Arc::new(UserStore::new(client_b.clone()));
1748 let channels_b = cx_b.add_model(|cx| ChannelList::new(user_store_b, client_b, cx));
1749 channels_b
1750 .condition(&mut cx_b, |list, _| list.available_channels().is_some())
1751 .await;
1752 channels_b.read_with(&cx_b, |list, _| {
1753 assert_eq!(
1754 list.available_channels().unwrap(),
1755 &[ChannelDetails {
1756 id: channel_id.to_proto(),
1757 name: "test-channel".to_string()
1758 }]
1759 )
1760 });
1761
1762 let channel_b = channels_b.update(&mut cx_b, |this, cx| {
1763 this.get_channel(channel_id.to_proto(), cx).unwrap()
1764 });
1765 channel_b.read_with(&cx_b, |channel, _| assert!(channel.messages().is_empty()));
1766 channel_b
1767 .condition(&cx_b, |channel, _| {
1768 channel_messages(channel)
1769 == [("user_b".to_string(), "hello A, it's B.".to_string())]
1770 })
1771 .await;
1772
1773 // Disconnect client B, ensuring we can still access its cached channel data.
1774 server.forbid_connections();
1775 server.disconnect_client(user_id_b);
1776 while !matches!(
1777 status_b.recv().await,
1778 Some(rpc::Status::ReconnectionError { .. })
1779 ) {}
1780
1781 channels_b.read_with(&cx_b, |channels, _| {
1782 assert_eq!(
1783 channels.available_channels().unwrap(),
1784 [ChannelDetails {
1785 id: channel_id.to_proto(),
1786 name: "test-channel".to_string()
1787 }]
1788 )
1789 });
1790 channel_b.read_with(&cx_b, |channel, _| {
1791 assert_eq!(
1792 channel_messages(channel),
1793 [("user_b".to_string(), "hello A, it's B.".to_string())]
1794 )
1795 });
1796
1797 // Send a message from client A while B is disconnected.
1798 channel_a
1799 .update(&mut cx_a, |channel, cx| {
1800 channel
1801 .send_message("oh, hi B.".to_string(), cx)
1802 .unwrap()
1803 .detach();
1804 let task = channel.send_message("sup".to_string(), cx).unwrap();
1805 assert_eq!(
1806 channel
1807 .pending_messages()
1808 .iter()
1809 .map(|m| &m.body)
1810 .collect::<Vec<_>>(),
1811 &["oh, hi B.", "sup"]
1812 );
1813 task
1814 })
1815 .await
1816 .unwrap();
1817
1818 // Give client B a chance to reconnect.
1819 server.allow_connections();
1820 cx_b.foreground().advance_clock(Duration::from_secs(10));
1821
1822 // Verify that B sees the new messages upon reconnection.
1823 channel_b
1824 .condition(&cx_b, |channel, _| {
1825 channel_messages(channel)
1826 == [
1827 ("user_b".to_string(), "hello A, it's B.".to_string()),
1828 ("user_a".to_string(), "oh, hi B.".to_string()),
1829 ("user_a".to_string(), "sup".to_string()),
1830 ]
1831 })
1832 .await;
1833
1834 fn channel_messages(channel: &Channel) -> Vec<(String, String)> {
1835 channel
1836 .messages()
1837 .cursor::<(), ()>()
1838 .map(|m| (m.sender.github_login.clone(), m.body.clone()))
1839 .collect()
1840 }
1841 }
1842
/// In-process test harness that owns a `Server` and the resources it runs on.
///
/// Rust drops fields in declaration order, so `_test_db` (declared last) is dropped
/// after every other field. Keep it last.
struct TestServer {
    /// Peer shared with `server`; reset when the harness is dropped.
    peer: Arc<Peer>,
    /// State the server was built with: database handle, config and test GitHub clients.
    app_state: Arc<AppState>,
    /// The server under test.
    server: Arc<Server>,
    /// Receiving half of the notification channel whose sender was handed to `server`;
    /// `condition` waits on it between predicate checks.
    notifications: mpsc::Receiver<()>,
    /// Per-user kill switches for the in-memory connections opened by `create_client`.
    connection_killers: Arc<Mutex<HashMap<UserId, watch::Sender<Option<()>>>>>,
    /// When `true`, the clients' connect callback refuses to open a connection.
    forbid_connections: Arc<AtomicBool>,
    /// Owns the test database. Never read; held only so it lives as long as this harness.
    _test_db: TestDb,
}
1852
1853 impl TestServer {
1854 async fn start() -> Self {
1855 let test_db = TestDb::new();
1856 let app_state = Self::build_app_state(&test_db).await;
1857 let peer = Peer::new();
1858 let notifications = mpsc::channel(128);
1859 let server = Server::new(app_state.clone(), peer.clone(), Some(notifications.0));
1860 Self {
1861 peer,
1862 app_state,
1863 server,
1864 notifications: notifications.1,
1865 connection_killers: Default::default(),
1866 forbid_connections: Default::default(),
1867 _test_db: test_db,
1868 }
1869 }
1870
1871 async fn create_client(
1872 &mut self,
1873 cx: &mut TestAppContext,
1874 name: &str,
1875 ) -> (UserId, Arc<Client>) {
1876 let client_user_id = self.app_state.db.create_user(name, false).await.unwrap();
1877 let client_name = name.to_string();
1878 let mut client = Client::new();
1879 let server = self.server.clone();
1880 let connection_killers = self.connection_killers.clone();
1881 let forbid_connections = self.forbid_connections.clone();
1882 Arc::get_mut(&mut client)
1883 .unwrap()
1884 .set_login_and_connect_callbacks(
1885 move |cx| {
1886 cx.spawn(|_| async move {
1887 let access_token = "the-token".to_string();
1888 Ok((client_user_id.0 as u64, access_token))
1889 })
1890 },
1891 move |user_id, access_token, cx| {
1892 assert_eq!(user_id, client_user_id.0 as u64);
1893 assert_eq!(access_token, "the-token");
1894
1895 let server = server.clone();
1896 let connection_killers = connection_killers.clone();
1897 let forbid_connections = forbid_connections.clone();
1898 let client_name = client_name.clone();
1899 cx.spawn(move |cx| async move {
1900 if forbid_connections.load(SeqCst) {
1901 Err(anyhow!("server is forbidding connections"))
1902 } else {
1903 let (client_conn, server_conn, kill_conn) = Conn::in_memory();
1904 connection_killers.lock().insert(client_user_id, kill_conn);
1905 cx.background()
1906 .spawn(server.handle_connection(
1907 server_conn,
1908 client_name,
1909 client_user_id,
1910 ))
1911 .detach();
1912 Ok(client_conn)
1913 }
1914 })
1915 },
1916 );
1917
1918 client
1919 .authenticate_and_connect(&cx.to_async())
1920 .await
1921 .unwrap();
1922 (client_user_id, client)
1923 }
1924
1925 fn disconnect_client(&self, user_id: UserId) {
1926 if let Some(mut kill_conn) = self.connection_killers.lock().remove(&user_id) {
1927 let _ = kill_conn.try_send(Some(()));
1928 }
1929 }
1930
1931 fn forbid_connections(&self) {
1932 self.forbid_connections.store(true, SeqCst);
1933 }
1934
1935 fn allow_connections(&self) {
1936 self.forbid_connections.store(false, SeqCst);
1937 }
1938
1939 async fn build_app_state(test_db: &TestDb) -> Arc<AppState> {
1940 let mut config = Config::default();
1941 config.session_secret = "a".repeat(32);
1942 config.database_url = test_db.url.clone();
1943 let github_client = github::AppClient::test();
1944 Arc::new(AppState {
1945 db: test_db.db().clone(),
1946 handlebars: Default::default(),
1947 auth_client: auth::build_client("", ""),
1948 repo_client: github::RepoClient::test(&github_client),
1949 github_client,
1950 config,
1951 })
1952 }
1953
1954 async fn state<'a>(&'a self) -> RwLockReadGuard<'a, ServerState> {
1955 self.server.state.read().await
1956 }
1957
1958 async fn condition<F>(&mut self, mut predicate: F)
1959 where
1960 F: FnMut(&ServerState) -> bool,
1961 {
1962 async_std::future::timeout(Duration::from_millis(500), async {
1963 while !(predicate)(&*self.server.state.read().await) {
1964 self.notifications.recv().await;
1965 }
1966 })
1967 .await
1968 .expect("condition timed out");
1969 }
1970 }
1971
1972 impl Drop for TestServer {
1973 fn drop(&mut self) {
1974 task::block_on(self.peer.reset());
1975 }
1976 }
1977
/// Placeholder view with no state; its `View` impl renders an empty element.
struct EmptyView;
1979
/// `EmptyView`'s event type is `()`, so its events carry no data.
impl gpui::Entity for EmptyView {
    type Event = ();
}
1983
1984 impl gpui::View for EmptyView {
1985 fn ui_name() -> &'static str {
1986 "empty view"
1987 }
1988
1989 fn render(&mut self, _: &mut gpui::RenderContext<Self>) -> gpui::ElementBox {
1990 gpui::Element::boxed(gpui::elements::Empty)
1991 }
1992 }
1993}