use super::{
    auth,
    db::{ChannelId, MessageId, UserId},
    AppState,
};
use anyhow::anyhow;
use async_std::{sync::RwLock, task};
use async_tungstenite::{tungstenite::protocol::Role, WebSocketStream};
use futures::{future::BoxFuture, FutureExt};
use postage::{mpsc, prelude::Sink as _, prelude::Stream as _};
use sha1::{Digest as _, Sha1};
use std::{
    any::TypeId,
    collections::{hash_map, HashMap, HashSet},
    future::Future,
    mem,
    sync::Arc,
    time::Instant,
};
use surf::StatusCode;
use tide::log;
use tide::{
    http::headers::{HeaderName, CONNECTION, UPGRADE},
    Request, Response,
};
use time::OffsetDateTime;
use zrpc::{
    proto::{self, AnyTypedEnvelope, EnvelopedMessage},
    Connection, ConnectionId, Peer, TypedEnvelope,
};

type ReplicaId = u16;

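/// A type-erased message handler. `add_handler` wraps each strongly-typed
/// handler in one of these so that handlers for every message type can live
/// in a single `HashMap` keyed by the message's `TypeId`.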
type MessageHandler = Box<
    dyn Send
        + Sync
        + Fn(Arc<Server>, Box<dyn AnyTypedEnvelope>) -> BoxFuture<'static, tide::Result<()>>,
>;

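/// The collaboration server. It owns the `zrpc::Peer` used to communicate with
/// clients, keeps in-memory connection/worktree/channel state behind an
/// `RwLock`, and dispatches incoming RPC messages to the handlers registered
/// in `Server::new`.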
pub struct Server {
    peer: Arc<Peer>,
    state: RwLock<ServerState>,
    app_state: Arc<AppState>,
    handlers: HashMap<TypeId, MessageHandler>,
    notifications: Option<mpsc::Sender<()>>,
}

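/// Mutable, in-memory server state. This is rebuilt from scratch whenever the
/// server restarts; durable data (users, channels, messages) lives in the
/// database reachable through `AppState`.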
#[derive(Default)]
struct ServerState {
    connections: HashMap<ConnectionId, ConnectionState>,
    connections_by_user_id: HashMap<UserId, HashSet<ConnectionId>>,
    pub worktrees: HashMap<u64, Worktree>,
    visible_worktrees_by_user_id: HashMap<UserId, HashSet<u64>>,
    channels: HashMap<ChannelId, Channel>,
    next_worktree_id: u64,
}

struct ConnectionState {
    user_id: UserId,
    worktrees: HashSet<u64>,
    channels: HashSet<ChannelId>,
}

struct Worktree {
    host_connection_id: ConnectionId,
    collaborator_user_ids: Vec<UserId>,
    root_name: String,
    share: Option<WorktreeShare>,
}

struct WorktreeShare {
    guest_connection_ids: HashMap<ConnectionId, ReplicaId>,
    active_replica_ids: HashSet<ReplicaId>,
    entries: HashMap<u64, proto::Entry>,
}

#[derive(Default)]
struct Channel {
    connection_ids: HashSet<ConnectionId>,
}

const MESSAGE_COUNT_PER_PAGE: usize = 100;
const MAX_MESSAGE_LEN: usize = 1024;

85impl Server {
86 pub fn new(
87 app_state: Arc<AppState>,
88 peer: Arc<Peer>,
89 notifications: Option<mpsc::Sender<()>>,
90 ) -> Arc<Self> {
91 let mut server = Self {
92 peer,
93 app_state,
94 state: Default::default(),
95 handlers: Default::default(),
96 notifications,
97 };
98
99 server
100 .add_handler(Server::ping)
101 .add_handler(Server::open_worktree)
102 .add_handler(Server::close_worktree)
103 .add_handler(Server::share_worktree)
104 .add_handler(Server::unshare_worktree)
105 .add_handler(Server::join_worktree)
106 .add_handler(Server::update_worktree)
107 .add_handler(Server::open_buffer)
108 .add_handler(Server::close_buffer)
109 .add_handler(Server::update_buffer)
110 .add_handler(Server::buffer_saved)
111 .add_handler(Server::save_buffer)
112 .add_handler(Server::get_channels)
113 .add_handler(Server::get_users)
114 .add_handler(Server::join_channel)
115 .add_handler(Server::leave_channel)
116 .add_handler(Server::send_channel_message)
117 .add_handler(Server::get_channel_messages);
118
119 Arc::new(server)
120 }
121
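    /// Register `handler` for the message type `M`, wrapping it in a
    /// type-erased `MessageHandler`. Registering two handlers for the same
    /// message type is a programming error and panics.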
122 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
123 where
124 F: 'static + Send + Sync + Fn(Arc<Self>, TypedEnvelope<M>) -> Fut,
125 Fut: 'static + Send + Future<Output = tide::Result<()>>,
126 M: EnvelopedMessage,
127 {
128 let prev_handler = self.handlers.insert(
129 TypeId::of::<M>(),
130 Box::new(move |server, envelope| {
131 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
132 (handler)(server, *envelope).boxed()
133 }),
134 );
135 if prev_handler.is_some() {
136 panic!("registered a handler for the same message twice");
137 }
138 self
139 }
140
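    /// Drive a single client connection: register it with the peer, then loop,
    /// dispatching each incoming message to its registered handler until either
    /// the message stream or the connection's I/O future terminates.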
141 pub fn handle_connection(
142 self: &Arc<Self>,
143 connection: Connection,
144 addr: String,
145 user_id: UserId,
146 ) -> impl Future<Output = ()> {
147 let this = self.clone();
148 async move {
149 let (connection_id, handle_io, mut incoming_rx) =
150 this.peer.add_connection(connection).await;
151 this.add_connection(connection_id, user_id).await;
152
153 let handle_io = handle_io.fuse();
154 futures::pin_mut!(handle_io);
155 loop {
156 let next_message = incoming_rx.recv().fuse();
157 futures::pin_mut!(next_message);
158 futures::select_biased! {
159 message = next_message => {
160 if let Some(message) = message {
161 let start_time = Instant::now();
162 log::info!("RPC message received: {}", message.payload_type_name());
163 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
164 if let Err(err) = (handler)(this.clone(), message).await {
165 log::error!("error handling message: {:?}", err);
166 } else {
167 log::info!("RPC message handled. duration:{:?}", start_time.elapsed());
168 }
169
170 if let Some(mut notifications) = this.notifications.clone() {
171 let _ = notifications.send(()).await;
172 }
173 } else {
174 log::warn!("unhandled message: {}", message.payload_type_name());
175 }
176 } else {
177 log::info!("rpc connection closed {:?}", addr);
178 break;
179 }
180 }
181 handle_io = handle_io => {
182 if let Err(err) = handle_io {
183 log::error!("error handling rpc connection {:?} - {:?}", addr, err);
184 }
185 break;
186 }
187 }
188 }
189
190 if let Err(err) = this.sign_out(connection_id).await {
191 log::error!("error signing out connection {:?} - {:?}", addr, err);
192 }
193 }
194 }
195
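    /// Tear down a connection: disconnect it from the peer, remove its state,
    /// and tell the remaining members of any affected worktrees that this peer
    /// has left.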
196 async fn sign_out(self: &Arc<Self>, connection_id: zrpc::ConnectionId) -> tide::Result<()> {
197 self.peer.disconnect(connection_id).await;
198 let worktree_ids = self.remove_connection(connection_id).await;
199 for worktree_id in worktree_ids {
200 let state = self.state.read().await;
201 if let Some(worktree) = state.worktrees.get(&worktree_id) {
202 broadcast(connection_id, worktree.connection_ids(), |conn_id| {
203 self.peer.send(
204 conn_id,
205 proto::RemovePeer {
206 worktree_id,
207 peer_id: connection_id.0,
208 },
209 )
210 })
211 .await?;
212 }
213 }
214 Ok(())
215 }
216
217 // Add a new connection associated with a given user.
218 async fn add_connection(&self, connection_id: ConnectionId, user_id: UserId) {
219 let mut state = self.state.write().await;
220 state.connections.insert(
221 connection_id,
222 ConnectionState {
223 user_id,
224 worktrees: Default::default(),
225 channels: Default::default(),
226 },
227 );
228 state
229 .connections_by_user_id
230 .entry(user_id)
231 .or_default()
232 .insert(connection_id);
233 }
234
235 // Remove the given connection and its association with any worktrees.
236 async fn remove_connection(&self, connection_id: ConnectionId) -> Vec<u64> {
237 let mut worktree_ids = Vec::new();
238 let mut state = self.state.write().await;
239 if let Some(connection) = state.connections.remove(&connection_id) {
240 for channel_id in connection.channels {
241 if let Some(channel) = state.channels.get_mut(&channel_id) {
242 channel.connection_ids.remove(&connection_id);
243 }
244 }
245 for worktree_id in connection.worktrees {
246 if let Some(worktree) = state.worktrees.get_mut(&worktree_id) {
247 if worktree.host_connection_id == connection_id {
248 worktree_ids.push(worktree_id);
249 } else if let Some(share_state) = worktree.share.as_mut() {
250 if let Some(replica_id) =
251 share_state.guest_connection_ids.remove(&connection_id)
252 {
253 share_state.active_replica_ids.remove(&replica_id);
254 worktree_ids.push(worktree_id);
255 }
256 }
257 }
258 }
259
260 let user_connections = state
261 .connections_by_user_id
262 .get_mut(&connection.user_id)
263 .unwrap();
264 user_connections.remove(&connection_id);
265 if user_connections.is_empty() {
266 state.connections_by_user_id.remove(&connection.user_id);
267 }
268 }
269 worktree_ids
270 }
271
272 async fn ping(self: Arc<Server>, request: TypedEnvelope<proto::Ping>) -> tide::Result<()> {
273 self.peer.respond(request.receipt(), proto::Ack {}).await?;
274 Ok(())
275 }
276
277 async fn open_worktree(
278 self: Arc<Server>,
279 request: TypedEnvelope<proto::OpenWorktree>,
280 ) -> tide::Result<()> {
281 let receipt = request.receipt();
282 let user_id = self
283 .state
284 .read()
285 .await
286 .user_id_for_connection(request.sender_id)?;
287
288 let mut collaborator_user_ids = Vec::new();
289 for github_login in request.payload.collaborator_logins {
290 match self.app_state.db.create_user(&github_login, false).await {
291 Ok(collaborator_user_id) => {
292 if collaborator_user_id != user_id {
293 collaborator_user_ids.push(collaborator_user_id);
294 }
295 }
296 Err(err) => {
297 let message = err.to_string();
298 self.peer
299 .respond_with_error(receipt, proto::Error { message })
300 .await?;
301 return Ok(());
302 }
303 }
304 }
305
        let mut state = self.state.write().await;
        let worktree_id = state.add_worktree(Worktree {
            host_connection_id: request.sender_id,
            collaborator_user_ids: collaborator_user_ids.clone(),
            root_name: request.payload.root_name,
            share: None,
        });
        // Release the state lock before responding and notifying collaborators;
        // `update_collaborators` re-acquires it and would otherwise deadlock.
        drop(state);

        self.peer
            .respond(receipt, proto::OpenWorktreeResponse { worktree_id })
            .await?;
        self.update_collaborators(&collaborator_user_ids).await?;
318
319 Ok(())
320 }
321
322 async fn share_worktree(
323 self: Arc<Server>,
324 mut request: TypedEnvelope<proto::ShareWorktree>,
325 ) -> tide::Result<()> {
326 let worktree = request
327 .payload
328 .worktree
329 .as_mut()
330 .ok_or_else(|| anyhow!("missing worktree"))?;
331 let entries = mem::take(&mut worktree.entries)
332 .into_iter()
333 .map(|entry| (entry.id, entry))
334 .collect();
335 let mut state = self.state.write().await;
336 if let Some(worktree) = state.worktrees.get_mut(&worktree.id) {
337 worktree.share = Some(WorktreeShare {
338 guest_connection_ids: Default::default(),
339 active_replica_ids: Default::default(),
340 entries,
341 });
342 self.peer
343 .respond(request.receipt(), proto::ShareWorktreeResponse {})
344 .await?;
345
346 let collaborator_user_ids = worktree.collaborator_user_ids.clone();
347 drop(state);
348 self.update_collaborators(&collaborator_user_ids).await?;
349 } else {
350 self.peer
351 .respond_with_error(
352 request.receipt(),
353 proto::Error {
354 message: "no such worktree".to_string(),
355 },
356 )
357 .await?;
358 }
359 Ok(())
360 }
361
362 async fn unshare_worktree(
363 self: Arc<Server>,
364 request: TypedEnvelope<proto::UnshareWorktree>,
365 ) -> tide::Result<()> {
366 let worktree_id = request.payload.worktree_id;
367
368 let connection_ids;
369 let collaborator_user_ids;
370 {
371 let mut state = self.state.write().await;
372 let worktree = state.write_worktree(worktree_id, request.sender_id)?;
373 if worktree.host_connection_id != request.sender_id {
374 return Err(anyhow!("no such worktree"))?;
375 }
376
377 connection_ids = worktree.connection_ids();
378 collaborator_user_ids = worktree.collaborator_user_ids.clone();
379 worktree.share.take();
380 for connection_id in &connection_ids {
381 if let Some(connection) = state.connections.get_mut(connection_id) {
382 connection.worktrees.remove(&worktree_id);
383 }
384 }
385 }
386
387 broadcast(request.sender_id, connection_ids, |conn_id| {
388 self.peer
389 .send(conn_id, proto::UnshareWorktree { worktree_id })
390 })
391 .await?;
392 self.update_collaborators(&collaborator_user_ids).await?;
393
394 Ok(())
395 }
396
397 async fn join_worktree(
398 self: Arc<Server>,
399 request: TypedEnvelope<proto::JoinWorktree>,
400 ) -> tide::Result<()> {
401 let worktree_id = request.payload.worktree_id;
402 let user_id = self
403 .state
404 .read()
405 .await
406 .user_id_for_connection(request.sender_id)?;
407
408 let response;
409 let connection_ids;
410 let collaborator_user_ids;
411 let mut state = self.state.write().await;
412 match state.join_worktree(request.sender_id, user_id, worktree_id) {
413 Ok((peer_replica_id, worktree)) => {
414 let share = worktree.share()?;
415 let peer_count = share.guest_connection_ids.len();
416 let mut peers = Vec::with_capacity(peer_count);
417 peers.push(proto::Peer {
418 peer_id: worktree.host_connection_id.0,
419 replica_id: 0,
420 });
421 for (peer_conn_id, peer_replica_id) in &share.guest_connection_ids {
422 if *peer_conn_id != request.sender_id {
423 peers.push(proto::Peer {
424 peer_id: peer_conn_id.0,
425 replica_id: *peer_replica_id as u32,
426 });
427 }
428 }
429 connection_ids = worktree.connection_ids();
430 collaborator_user_ids = worktree.collaborator_user_ids.clone();
431 response = proto::JoinWorktreeResponse {
432 worktree: Some(proto::Worktree {
433 id: worktree_id,
434 root_name: worktree.root_name.clone(),
435 entries: share.entries.values().cloned().collect(),
436 }),
437 replica_id: peer_replica_id as u32,
438 peers,
439 };
440 }
441 Err(error) => {
442 self.peer
443 .respond_with_error(
444 request.receipt(),
445 proto::Error {
446 message: error.to_string(),
447 },
448 )
449 .await?;
450 return Ok(());
451 }
        }
        // The write guard must be released before `update_collaborators` below,
        // which re-locks the state and would otherwise deadlock on it.
        drop(state);

454 broadcast(request.sender_id, connection_ids, |conn_id| {
455 self.peer.send(
456 conn_id,
457 proto::AddPeer {
458 worktree_id,
459 peer: Some(proto::Peer {
460 peer_id: request.sender_id.0,
461 replica_id: response.replica_id,
462 }),
463 },
464 )
465 })
466 .await?;
467 self.peer.respond(request.receipt(), response).await?;
468 self.update_collaborators(&collaborator_user_ids).await?;
469
470 Ok(())
471 }
472
473 async fn close_worktree(
474 self: Arc<Server>,
475 request: TypedEnvelope<proto::CloseWorktree>,
476 ) -> tide::Result<()> {
477 let worktree_id = request.payload.worktree_id;
478 let connection_ids;
479 let mut is_host = false;
480 let mut is_guest = false;
481 {
482 let mut state = self.state.write().await;
483 let worktree = state.write_worktree(worktree_id, request.sender_id)?;
484 connection_ids = worktree.connection_ids();
485
486 if worktree.host_connection_id == request.sender_id {
487 is_host = true;
488 state.remove_worktree(worktree_id);
489 } else {
490 let share = worktree.share_mut()?;
491 if let Some(replica_id) = share.guest_connection_ids.remove(&request.sender_id) {
492 is_guest = true;
493 share.active_replica_ids.remove(&replica_id);
494 }
495 }
496 }
497
498 if is_host {
499 broadcast(request.sender_id, connection_ids, |conn_id| {
500 self.peer
501 .send(conn_id, proto::UnshareWorktree { worktree_id })
502 })
503 .await?;
504 } else if is_guest {
505 broadcast(request.sender_id, connection_ids, |conn_id| {
506 self.peer.send(
507 conn_id,
508 proto::RemovePeer {
509 worktree_id,
510 peer_id: request.sender_id.0,
511 },
512 )
513 })
514 .await?
515 }
516
517 Ok(())
518 }
519
520 async fn update_worktree(
521 self: Arc<Server>,
522 request: TypedEnvelope<proto::UpdateWorktree>,
523 ) -> tide::Result<()> {
524 {
525 let mut state = self.state.write().await;
526 let worktree = state.write_worktree(request.payload.worktree_id, request.sender_id)?;
527 let share = worktree.share_mut()?;
528
            for entry_id in &request.payload.removed_entries {
                share.entries.remove(entry_id);
            }
532
533 for entry in &request.payload.updated_entries {
534 share.entries.insert(entry.id, entry.clone());
535 }
536 }
537
538 self.broadcast_in_worktree(request.payload.worktree_id, &request)
539 .await?;
540 Ok(())
541 }
542
543 async fn open_buffer(
544 self: Arc<Server>,
545 request: TypedEnvelope<proto::OpenBuffer>,
546 ) -> tide::Result<()> {
547 let receipt = request.receipt();
548 let worktree_id = request.payload.worktree_id;
549 let host_connection_id = self
550 .state
551 .read()
552 .await
553 .read_worktree(worktree_id, request.sender_id)?
554 .host_connection_id;
555
556 let response = self
557 .peer
558 .forward_request(request.sender_id, host_connection_id, request.payload)
559 .await?;
560 self.peer.respond(receipt, response).await?;
561 Ok(())
562 }
563
564 async fn close_buffer(
565 self: Arc<Server>,
566 request: TypedEnvelope<proto::CloseBuffer>,
567 ) -> tide::Result<()> {
568 let host_connection_id = self
569 .state
570 .read()
571 .await
572 .read_worktree(request.payload.worktree_id, request.sender_id)?
573 .host_connection_id;
574
575 self.peer
576 .forward_send(request.sender_id, host_connection_id, request.payload)
577 .await?;
578
579 Ok(())
580 }
581
582 async fn save_buffer(
583 self: Arc<Server>,
584 request: TypedEnvelope<proto::SaveBuffer>,
585 ) -> tide::Result<()> {
586 let host;
587 let guests;
588 {
589 let state = self.state.read().await;
590 let worktree = state.read_worktree(request.payload.worktree_id, request.sender_id)?;
591 host = worktree.host_connection_id;
592 guests = worktree
593 .share()?
594 .guest_connection_ids
595 .keys()
596 .copied()
597 .collect::<Vec<_>>();
598 }
599
600 let sender = request.sender_id;
601 let receipt = request.receipt();
602 let response = self
603 .peer
604 .forward_request(sender, host, request.payload.clone())
605 .await?;
606
607 broadcast(host, guests, |conn_id| {
608 let response = response.clone();
609 let peer = &self.peer;
610 async move {
611 if conn_id == sender {
612 peer.respond(receipt, response).await
613 } else {
614 peer.forward_send(host, conn_id, response).await
615 }
616 }
617 })
618 .await?;
619
620 Ok(())
621 }
622
623 async fn update_buffer(
624 self: Arc<Server>,
625 request: TypedEnvelope<proto::UpdateBuffer>,
626 ) -> tide::Result<()> {
627 self.broadcast_in_worktree(request.payload.worktree_id, &request)
628 .await?;
629 self.peer.respond(request.receipt(), proto::Ack {}).await?;
630 Ok(())
631 }
632
633 async fn buffer_saved(
634 self: Arc<Server>,
635 request: TypedEnvelope<proto::BufferSaved>,
636 ) -> tide::Result<()> {
637 self.broadcast_in_worktree(request.payload.worktree_id, &request)
638 .await
639 }
640
641 async fn get_channels(
642 self: Arc<Server>,
643 request: TypedEnvelope<proto::GetChannels>,
644 ) -> tide::Result<()> {
645 let user_id = self
646 .state
647 .read()
648 .await
649 .user_id_for_connection(request.sender_id)?;
650 let channels = self.app_state.db.get_accessible_channels(user_id).await?;
651 self.peer
652 .respond(
653 request.receipt(),
654 proto::GetChannelsResponse {
655 channels: channels
656 .into_iter()
657 .map(|chan| proto::Channel {
658 id: chan.id.to_proto(),
659 name: chan.name,
660 })
661 .collect(),
662 },
663 )
664 .await?;
665 Ok(())
666 }
667
668 async fn get_users(
669 self: Arc<Server>,
670 request: TypedEnvelope<proto::GetUsers>,
671 ) -> tide::Result<()> {
672 let user_id = self
673 .state
674 .read()
675 .await
676 .user_id_for_connection(request.sender_id)?;
677 let receipt = request.receipt();
678 let user_ids = request.payload.user_ids.into_iter().map(UserId::from_proto);
679 let users = self
680 .app_state
681 .db
682 .get_users_by_ids(user_id, user_ids)
683 .await?
684 .into_iter()
685 .map(|user| proto::User {
686 id: user.id.to_proto(),
687 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
688 github_login: user.github_login,
689 })
690 .collect();
691 self.peer
692 .respond(receipt, proto::GetUsersResponse { users })
693 .await?;
694 Ok(())
695 }
696
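    /// Send an `UpdateCollaborators` message to every connection of each given
    /// user, summarizing the worktrees currently visible to that user and who
    /// is participating in each share.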
697 async fn update_collaborators(self: &Arc<Server>, user_ids: &[UserId]) -> tide::Result<()> {
698 let mut send_futures = Vec::new();
699
700 let state = self.state.read().await;
701 for user_id in user_ids {
702 let mut collaborators = HashMap::new();
            for worktree_id in state
                .visible_worktrees_by_user_id
                .get(user_id)
                .unwrap_or(&HashSet::new())
            {
708 let worktree = &state.worktrees[worktree_id];
709
710 let mut participants = HashSet::new();
711 if let Ok(share) = worktree.share() {
712 for guest_connection_id in share.guest_connection_ids.keys() {
713 let user_id = state.user_id_for_connection(*guest_connection_id)?;
714 participants.insert(user_id.to_proto());
715 }
716 }
717
718 let host_user_id = state.user_id_for_connection(worktree.host_connection_id)?;
719 let host =
720 collaborators
721 .entry(host_user_id)
722 .or_insert_with(|| proto::Collaborator {
723 user_id: host_user_id.to_proto(),
724 worktrees: Vec::new(),
725 });
726 host.worktrees.push(proto::WorktreeMetadata {
727 root_name: worktree.root_name.clone(),
728 is_shared: worktree.share().is_ok(),
729 participants: participants.into_iter().collect(),
730 });
731 }
732
            // The read guard acquired above is still held, so reuse it rather
            // than locking `self.state` a second time.
            let connection_ids = state.user_connection_ids(*user_id).collect::<Vec<_>>();
739
740 let collaborators = collaborators.into_values().collect::<Vec<_>>();
741 for connection_id in connection_ids {
742 send_futures.push(self.peer.send(
743 connection_id,
744 proto::UpdateCollaborators {
745 collaborators: collaborators.clone(),
746 },
747 ));
748 }
749 }
750
751 futures::future::try_join_all(send_futures).await?;
752
753 Ok(())
754 }
755
756 async fn join_channel(
757 self: Arc<Self>,
758 request: TypedEnvelope<proto::JoinChannel>,
759 ) -> tide::Result<()> {
760 let user_id = self
761 .state
762 .read()
763 .await
764 .user_id_for_connection(request.sender_id)?;
765 let channel_id = ChannelId::from_proto(request.payload.channel_id);
766 if !self
767 .app_state
768 .db
769 .can_user_access_channel(user_id, channel_id)
770 .await?
771 {
772 Err(anyhow!("access denied"))?;
773 }
774
775 self.state
776 .write()
777 .await
778 .join_channel(request.sender_id, channel_id);
779 let messages = self
780 .app_state
781 .db
782 .get_channel_messages(channel_id, MESSAGE_COUNT_PER_PAGE, None)
783 .await?
784 .into_iter()
785 .map(|msg| proto::ChannelMessage {
786 id: msg.id.to_proto(),
787 body: msg.body,
788 timestamp: msg.sent_at.unix_timestamp() as u64,
789 sender_id: msg.sender_id.to_proto(),
790 nonce: Some(msg.nonce.as_u128().into()),
791 })
792 .collect::<Vec<_>>();
793 self.peer
794 .respond(
795 request.receipt(),
796 proto::JoinChannelResponse {
797 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
798 messages,
799 },
800 )
801 .await?;
802 Ok(())
803 }
804
805 async fn leave_channel(
806 self: Arc<Self>,
807 request: TypedEnvelope<proto::LeaveChannel>,
808 ) -> tide::Result<()> {
809 let user_id = self
810 .state
811 .read()
812 .await
813 .user_id_for_connection(request.sender_id)?;
814 let channel_id = ChannelId::from_proto(request.payload.channel_id);
815 if !self
816 .app_state
817 .db
818 .can_user_access_channel(user_id, channel_id)
819 .await?
820 {
821 Err(anyhow!("access denied"))?;
822 }
823
824 self.state
825 .write()
826 .await
827 .leave_channel(request.sender_id, channel_id);
828
829 Ok(())
830 }
831
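    /// Validate and persist a chat message, fan it out to the other connections
    /// currently present in the channel, and acknowledge the sender with the
    /// stored message (including its server-assigned id).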
832 async fn send_channel_message(
833 self: Arc<Self>,
834 request: TypedEnvelope<proto::SendChannelMessage>,
835 ) -> tide::Result<()> {
836 let receipt = request.receipt();
837 let channel_id = ChannelId::from_proto(request.payload.channel_id);
838 let user_id;
839 let connection_ids;
840 {
841 let state = self.state.read().await;
842 user_id = state.user_id_for_connection(request.sender_id)?;
843 if let Some(channel) = state.channels.get(&channel_id) {
844 connection_ids = channel.connection_ids();
845 } else {
846 return Ok(());
847 }
848 }
849
850 // Validate the message body.
851 let body = request.payload.body.trim().to_string();
852 if body.len() > MAX_MESSAGE_LEN {
853 self.peer
854 .respond_with_error(
855 receipt,
856 proto::Error {
857 message: "message is too long".to_string(),
858 },
859 )
860 .await?;
861 return Ok(());
862 }
863 if body.is_empty() {
864 self.peer
865 .respond_with_error(
866 receipt,
867 proto::Error {
868 message: "message can't be blank".to_string(),
869 },
870 )
871 .await?;
872 return Ok(());
873 }
874
875 let timestamp = OffsetDateTime::now_utc();
876 let nonce = if let Some(nonce) = request.payload.nonce {
877 nonce
878 } else {
879 self.peer
880 .respond_with_error(
881 receipt,
882 proto::Error {
883 message: "nonce can't be blank".to_string(),
884 },
885 )
886 .await?;
887 return Ok(());
888 };
889
890 let message_id = self
891 .app_state
892 .db
893 .create_channel_message(channel_id, user_id, &body, timestamp, nonce.clone().into())
894 .await?
895 .to_proto();
896 let message = proto::ChannelMessage {
897 sender_id: user_id.to_proto(),
898 id: message_id,
899 body,
900 timestamp: timestamp.unix_timestamp() as u64,
901 nonce: Some(nonce),
902 };
903 broadcast(request.sender_id, connection_ids, |conn_id| {
904 self.peer.send(
905 conn_id,
906 proto::ChannelMessageSent {
907 channel_id: channel_id.to_proto(),
908 message: Some(message.clone()),
909 },
910 )
911 })
912 .await?;
913 self.peer
914 .respond(
915 receipt,
916 proto::SendChannelMessageResponse {
917 message: Some(message),
918 },
919 )
920 .await?;
921 Ok(())
922 }
923
924 async fn get_channel_messages(
925 self: Arc<Self>,
926 request: TypedEnvelope<proto::GetChannelMessages>,
927 ) -> tide::Result<()> {
928 let user_id = self
929 .state
930 .read()
931 .await
932 .user_id_for_connection(request.sender_id)?;
933 let channel_id = ChannelId::from_proto(request.payload.channel_id);
934 if !self
935 .app_state
936 .db
937 .can_user_access_channel(user_id, channel_id)
938 .await?
939 {
940 Err(anyhow!("access denied"))?;
941 }
942
943 let messages = self
944 .app_state
945 .db
946 .get_channel_messages(
947 channel_id,
948 MESSAGE_COUNT_PER_PAGE,
949 Some(MessageId::from_proto(request.payload.before_message_id)),
950 )
951 .await?
952 .into_iter()
953 .map(|msg| proto::ChannelMessage {
954 id: msg.id.to_proto(),
955 body: msg.body,
956 timestamp: msg.sent_at.unix_timestamp() as u64,
957 sender_id: msg.sender_id.to_proto(),
958 nonce: Some(msg.nonce.as_u128().into()),
959 })
960 .collect::<Vec<_>>();
961 self.peer
962 .respond(
963 request.receipt(),
964 proto::GetChannelMessagesResponse {
965 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
966 messages,
967 },
968 )
969 .await?;
970 Ok(())
971 }
972
973 async fn broadcast_in_worktree<T: proto::EnvelopedMessage>(
974 &self,
975 worktree_id: u64,
976 message: &TypedEnvelope<T>,
977 ) -> tide::Result<()> {
978 let connection_ids = self
979 .state
980 .read()
981 .await
982 .read_worktree(worktree_id, message.sender_id)?
983 .connection_ids();
984
985 broadcast(message.sender_id, connection_ids, |conn_id| {
986 self.peer
987 .forward_send(message.sender_id, conn_id, message.payload.clone())
988 })
989 .await?;
990
991 Ok(())
992 }
993}
994
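/// Invoke `f` once per receiver (skipping the sender) and await all of the
/// resulting futures concurrently. Handlers in this module use it along these
/// lines:
///
/// ```ignore
/// broadcast(request.sender_id, connection_ids, |conn_id| {
///     self.peer.send(conn_id, message.clone())
/// })
/// .await?;
/// ```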
995pub async fn broadcast<F, T>(
996 sender_id: ConnectionId,
997 receiver_ids: Vec<ConnectionId>,
998 mut f: F,
999) -> anyhow::Result<()>
1000where
1001 F: FnMut(ConnectionId) -> T,
1002 T: Future<Output = anyhow::Result<()>>,
1003{
1004 let futures = receiver_ids
1005 .into_iter()
1006 .filter(|id| *id != sender_id)
1007 .map(|id| f(id));
1008 futures::future::try_join_all(futures).await?;
1009 Ok(())
1010}
1011
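// These methods operate on state that the caller has already locked via
// `Server::state`; they never acquire any locks themselves.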
1012impl ServerState {
1013 fn join_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
1014 if let Some(connection) = self.connections.get_mut(&connection_id) {
1015 connection.channels.insert(channel_id);
1016 self.channels
1017 .entry(channel_id)
1018 .or_default()
1019 .connection_ids
1020 .insert(connection_id);
1021 }
1022 }
1023
1024 fn leave_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
1025 if let Some(connection) = self.connections.get_mut(&connection_id) {
1026 connection.channels.remove(&channel_id);
1027 if let hash_map::Entry::Occupied(mut entry) = self.channels.entry(channel_id) {
1028 entry.get_mut().connection_ids.remove(&connection_id);
1029 if entry.get_mut().connection_ids.is_empty() {
1030 entry.remove();
1031 }
1032 }
1033 }
1034 }
1035
1036 fn user_id_for_connection(&self, connection_id: ConnectionId) -> tide::Result<UserId> {
1037 Ok(self
1038 .connections
1039 .get(&connection_id)
1040 .ok_or_else(|| anyhow!("unknown connection"))?
1041 .user_id)
1042 }
1043
1044 fn user_connection_ids<'a>(
1045 &'a self,
1046 user_id: UserId,
1047 ) -> impl 'a + Iterator<Item = ConnectionId> {
1048 self.connections_by_user_id
1049 .get(&user_id)
1050 .into_iter()
1051 .flatten()
1052 .copied()
1053 }
1054
1055 fn is_online(&self, user_id: UserId) -> bool {
1056 self.connections_by_user_id.contains_key(&user_id)
1057 }
1058
1059 // Add the given connection as a guest of the given worktree
1060 fn join_worktree(
1061 &mut self,
1062 connection_id: ConnectionId,
1063 user_id: UserId,
1064 worktree_id: u64,
1065 ) -> tide::Result<(ReplicaId, &Worktree)> {
1066 let connection = self
1067 .connections
1068 .get_mut(&connection_id)
1069 .ok_or_else(|| anyhow!("no such connection"))?;
1070 let worktree = self
1071 .worktrees
1072 .get_mut(&worktree_id)
1073 .ok_or_else(|| anyhow!("no such worktree"))?;
1074 if !worktree.collaborator_user_ids.contains(&user_id) {
1075 Err(anyhow!("no such worktree"))?;
1076 }
1077
1078 let share = worktree.share_mut()?;
1079 connection.worktrees.insert(worktree_id);
1080
1081 let mut replica_id = 1;
1082 while share.active_replica_ids.contains(&replica_id) {
1083 replica_id += 1;
1084 }
1085 share.active_replica_ids.insert(replica_id);
1086 share.guest_connection_ids.insert(connection_id, replica_id);
        Ok((replica_id, worktree))
    }
1089
1090 fn read_worktree(
1091 &self,
1092 worktree_id: u64,
1093 connection_id: ConnectionId,
1094 ) -> tide::Result<&Worktree> {
1095 let worktree = self
1096 .worktrees
1097 .get(&worktree_id)
1098 .ok_or_else(|| anyhow!("worktree not found"))?;
1099
1100 if worktree.host_connection_id == connection_id
1101 || worktree
1102 .share()?
1103 .guest_connection_ids
1104 .contains_key(&connection_id)
1105 {
1106 Ok(worktree)
1107 } else {
1108 Err(anyhow!(
1109 "{} is not a member of worktree {}",
1110 connection_id,
1111 worktree_id
1112 ))?
1113 }
1114 }
1115
1116 fn write_worktree(
1117 &mut self,
1118 worktree_id: u64,
1119 connection_id: ConnectionId,
1120 ) -> tide::Result<&mut Worktree> {
1121 let worktree = self
1122 .worktrees
1123 .get_mut(&worktree_id)
1124 .ok_or_else(|| anyhow!("worktree not found"))?;
1125
1126 if worktree.host_connection_id == connection_id
1127 || worktree.share.as_ref().map_or(false, |share| {
1128 share.guest_connection_ids.contains_key(&connection_id)
1129 })
1130 {
1131 Ok(worktree)
1132 } else {
1133 Err(anyhow!(
1134 "{} is not a member of worktree {}",
1135 connection_id,
1136 worktree_id
1137 ))?
1138 }
1139 }
1140
1141 fn add_worktree(&mut self, worktree: Worktree) -> u64 {
1142 let worktree_id = self.next_worktree_id;
1143 for collaborator_user_id in &worktree.collaborator_user_ids {
1144 self.visible_worktrees_by_user_id
1145 .entry(*collaborator_user_id)
1146 .or_default()
1147 .insert(worktree_id);
1148 }
1149 self.next_worktree_id += 1;
1150 self.worktrees.insert(worktree_id, worktree);
1151 worktree_id
1152 }
1153
1154 fn remove_worktree(&mut self, worktree_id: u64) {
1155 let worktree = self.worktrees.remove(&worktree_id).unwrap();
1156 if let Some(connection) = self.connections.get_mut(&worktree.host_connection_id) {
1157 connection.worktrees.remove(&worktree_id);
1158 }
1159 if let Some(share) = worktree.share {
1160 for connection_id in share.guest_connection_ids.keys() {
1161 if let Some(connection) = self.connections.get_mut(connection_id) {
1162 connection.worktrees.remove(&worktree_id);
1163 }
1164 }
1165 }
1166 for collaborator_user_id in worktree.collaborator_user_ids {
1167 if let Some(visible_worktrees) = self
1168 .visible_worktrees_by_user_id
1169 .get_mut(&collaborator_user_id)
1170 {
1171 visible_worktrees.remove(&worktree_id);
1172 }
1173 }
1174 }
1175}
1176
1177impl Worktree {
1178 pub fn connection_ids(&self) -> Vec<ConnectionId> {
1179 if let Some(share) = &self.share {
1180 share
1181 .guest_connection_ids
1182 .keys()
1183 .copied()
1184 .chain(Some(self.host_connection_id))
1185 .collect()
1186 } else {
1187 vec![self.host_connection_id]
1188 }
1189 }
1190
1191 fn share(&self) -> tide::Result<&WorktreeShare> {
1192 Ok(self
1193 .share
1194 .as_ref()
1195 .ok_or_else(|| anyhow!("worktree is not shared"))?)
1196 }
1197
1198 fn share_mut(&mut self) -> tide::Result<&mut WorktreeShare> {
1199 Ok(self
1200 .share
1201 .as_mut()
1202 .ok_or_else(|| anyhow!("worktree is not shared"))?)
1203 }
1204}
1205
1206impl Channel {
1207 fn connection_ids(&self) -> Vec<ConnectionId> {
1208 self.connection_ids.iter().copied().collect()
1209 }
1210}
1211
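/// Mount the `/rpc` websocket endpoint on the given tide app. The handler
/// performs the websocket upgrade handshake itself and then hands the upgraded
/// stream to `Server::handle_connection` on a background task.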
1212pub fn add_routes(app: &mut tide::Server<Arc<AppState>>, rpc: &Arc<Peer>) {
1213 let server = Server::new(app.state().clone(), rpc.clone(), None);
1214 app.at("/rpc").with(auth::VerifyToken).get(move |request: Request<Arc<AppState>>| {
1215 let user_id = request.ext::<UserId>().copied();
1216 let server = server.clone();
1217 async move {
1218 const WEBSOCKET_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
1219
1220 let connection_upgrade = header_contains_ignore_case(&request, CONNECTION, "upgrade");
1221 let upgrade_to_websocket = header_contains_ignore_case(&request, UPGRADE, "websocket");
1222 let upgrade_requested = connection_upgrade && upgrade_to_websocket;
1223
1224 if !upgrade_requested {
1225 return Ok(Response::new(StatusCode::UpgradeRequired));
1226 }
1227
1228 let header = match request.header("Sec-Websocket-Key") {
1229 Some(h) => h.as_str(),
1230 None => return Err(anyhow!("expected sec-websocket-key"))?,
1231 };
1232
1233 let mut response = Response::new(StatusCode::SwitchingProtocols);
1234 response.insert_header(UPGRADE, "websocket");
1235 response.insert_header(CONNECTION, "Upgrade");
1236 let hash = Sha1::new().chain(header).chain(WEBSOCKET_GUID).finalize();
1237 response.insert_header("Sec-Websocket-Accept", base64::encode(&hash[..]));
1238 response.insert_header("Sec-Websocket-Version", "13");
1239
1240 let http_res: &mut tide::http::Response = response.as_mut();
1241 let upgrade_receiver = http_res.recv_upgrade().await;
1242 let addr = request.remote().unwrap_or("unknown").to_string();
1243 let user_id = user_id.ok_or_else(|| anyhow!("user_id is not present on request. ensure auth::VerifyToken middleware is present"))?;
1244 task::spawn(async move {
1245 if let Some(stream) = upgrade_receiver.await {
1246 server.handle_connection(Connection::new(WebSocketStream::from_raw_socket(stream, Role::Server, None).await), addr, user_id).await;
1247 }
1248 });
1249
1250 Ok(response)
1251 }
1252 });
1253}
1254
1255fn header_contains_ignore_case<T>(
1256 request: &tide::Request<T>,
1257 header_name: HeaderName,
1258 value: &str,
1259) -> bool {
1260 request
1261 .header(header_name)
1262 .map(|h| {
1263 h.as_str()
1264 .split(',')
1265 .any(|s| s.trim().eq_ignore_ascii_case(value.trim()))
1266 })
1267 .unwrap_or(false)
1268}
1269
1270#[cfg(test)]
1271mod tests {
1272 use super::*;
1273 use crate::{
1274 auth,
1275 db::{tests::TestDb, UserId},
1276 github, AppState, Config,
1277 };
1278 use async_std::{sync::RwLockReadGuard, task};
1279 use gpui::TestAppContext;
1280 use parking_lot::Mutex;
1281 use postage::{mpsc, watch};
1282 use serde_json::json;
1283 use sqlx::types::time::OffsetDateTime;
1284 use std::{
1285 path::Path,
1286 sync::{
1287 atomic::{AtomicBool, Ordering::SeqCst},
1288 Arc,
1289 },
1290 time::Duration,
1291 };
1292 use zed::{
1293 channel::{Channel, ChannelDetails, ChannelList},
1294 editor::{Editor, Insert},
1295 fs::{FakeFs, Fs as _},
1296 language::LanguageRegistry,
1297 rpc::{self, Client, Credentials, EstablishConnectionError},
1298 settings,
1299 test::FakeHttpClient,
1300 user::UserStore,
1301 worktree::Worktree,
1302 };
1303 use zrpc::Peer;
1304
1305 #[gpui::test]
1306 async fn test_share_worktree(mut cx_a: TestAppContext, mut cx_b: TestAppContext) {
1307 let (window_b, _) = cx_b.add_window(|_| EmptyView);
1308 let settings = cx_b.read(settings::test).1;
1309 let lang_registry = Arc::new(LanguageRegistry::new());
1310
1311 // Connect to a server as 2 clients.
1312 let mut server = TestServer::start().await;
1313 let (client_a, _) = server.create_client(&mut cx_a, "user_a").await;
1314 let (client_b, _) = server.create_client(&mut cx_b, "user_b").await;
1315
1316 cx_a.foreground().forbid_parking();
1317
1318 // Share a local worktree as client A
1319 let fs = Arc::new(FakeFs::new());
1320 fs.insert_tree(
1321 "/a",
1322 json!({
1323 ".zed.toml": r#"collaborators = ["user_b"]"#,
1324 "a.txt": "a-contents",
1325 "b.txt": "b-contents",
1326 }),
1327 )
1328 .await;
1329 let worktree_a = Worktree::open_local(
1330 client_a.clone(),
1331 "/a".as_ref(),
1332 fs,
1333 lang_registry.clone(),
1334 &mut cx_a.to_async(),
1335 )
1336 .await
1337 .unwrap();
1338 worktree_a
1339 .read_with(&cx_a, |tree, _| tree.as_local().unwrap().scan_complete())
1340 .await;
1341 let worktree_id = worktree_a
1342 .update(&mut cx_a, |tree, cx| tree.as_local_mut().unwrap().share(cx))
1343 .await
1344 .unwrap();
1345
1346 // Join that worktree as client B, and see that a guest has joined as client A.
1347 let worktree_b = Worktree::open_remote(
1348 client_b.clone(),
1349 worktree_id,
1350 lang_registry.clone(),
1351 &mut cx_b.to_async(),
1352 )
1353 .await
1354 .unwrap();
1355 let replica_id_b = worktree_b.read_with(&cx_b, |tree, _| tree.replica_id());
1356 worktree_a
1357 .condition(&cx_a, |tree, _| {
1358 tree.peers()
1359 .values()
1360 .any(|replica_id| *replica_id == replica_id_b)
1361 })
1362 .await;
1363
1364 // Open the same file as client B and client A.
1365 let buffer_b = worktree_b
1366 .update(&mut cx_b, |worktree, cx| worktree.open_buffer("b.txt", cx))
1367 .await
1368 .unwrap();
1369 buffer_b.read_with(&cx_b, |buf, _| assert_eq!(buf.text(), "b-contents"));
1370 worktree_a.read_with(&cx_a, |tree, cx| assert!(tree.has_open_buffer("b.txt", cx)));
1371 let buffer_a = worktree_a
1372 .update(&mut cx_a, |tree, cx| tree.open_buffer("b.txt", cx))
1373 .await
1374 .unwrap();
1375
1376 // Create a selection set as client B and see that selection set as client A.
1377 let editor_b = cx_b.add_view(window_b, |cx| Editor::for_buffer(buffer_b, settings, cx));
1378 buffer_a
1379 .condition(&cx_a, |buffer, _| buffer.selection_sets().count() == 1)
1380 .await;
1381
1382 // Edit the buffer as client B and see that edit as client A.
1383 editor_b.update(&mut cx_b, |editor, cx| {
1384 editor.insert(&Insert("ok, ".into()), cx)
1385 });
1386 buffer_a
1387 .condition(&cx_a, |buffer, _| buffer.text() == "ok, b-contents")
1388 .await;
1389
1390 // Remove the selection set as client B, see those selections disappear as client A.
1391 cx_b.update(move |_| drop(editor_b));
1392 buffer_a
1393 .condition(&cx_a, |buffer, _| buffer.selection_sets().count() == 0)
1394 .await;
1395
1396 // Close the buffer as client A, see that the buffer is closed.
1397 cx_a.update(move |_| drop(buffer_a));
1398 worktree_a
1399 .condition(&cx_a, |tree, cx| !tree.has_open_buffer("b.txt", cx))
1400 .await;
1401
1402 // Dropping the worktree removes client B from client A's peers.
1403 cx_b.update(move |_| drop(worktree_b));
1404 worktree_a
1405 .condition(&cx_a, |tree, _| tree.peers().is_empty())
1406 .await;
1407 }
1408
1409 #[gpui::test]
1410 async fn test_propagate_saves_and_fs_changes_in_shared_worktree(
1411 mut cx_a: TestAppContext,
1412 mut cx_b: TestAppContext,
1413 mut cx_c: TestAppContext,
1414 ) {
1415 cx_a.foreground().forbid_parking();
1416 let lang_registry = Arc::new(LanguageRegistry::new());
1417
1418 // Connect to a server as 3 clients.
1419 let mut server = TestServer::start().await;
1420 let (client_a, _) = server.create_client(&mut cx_a, "user_a").await;
1421 let (client_b, _) = server.create_client(&mut cx_b, "user_b").await;
1422 let (client_c, _) = server.create_client(&mut cx_c, "user_c").await;
1423
1424 let fs = Arc::new(FakeFs::new());
1425
1426 // Share a worktree as client A.
1427 fs.insert_tree(
1428 "/a",
1429 json!({
1430 ".zed.toml": r#"collaborators = ["user_b", "user_c"]"#,
1431 "file1": "",
1432 "file2": ""
1433 }),
1434 )
1435 .await;
1436
1437 let worktree_a = Worktree::open_local(
1438 client_a.clone(),
1439 "/a".as_ref(),
1440 fs.clone(),
1441 lang_registry.clone(),
1442 &mut cx_a.to_async(),
1443 )
1444 .await
1445 .unwrap();
1446 worktree_a
1447 .read_with(&cx_a, |tree, _| tree.as_local().unwrap().scan_complete())
1448 .await;
1449 let worktree_id = worktree_a
1450 .update(&mut cx_a, |tree, cx| tree.as_local_mut().unwrap().share(cx))
1451 .await
1452 .unwrap();
1453
1454 // Join that worktree as clients B and C.
1455 let worktree_b = Worktree::open_remote(
1456 client_b.clone(),
1457 worktree_id,
1458 lang_registry.clone(),
1459 &mut cx_b.to_async(),
1460 )
1461 .await
1462 .unwrap();
1463 let worktree_c = Worktree::open_remote(
1464 client_c.clone(),
1465 worktree_id,
1466 lang_registry.clone(),
1467 &mut cx_c.to_async(),
1468 )
1469 .await
1470 .unwrap();
1471
1472 // Open and edit a buffer as both guests B and C.
1473 let buffer_b = worktree_b
1474 .update(&mut cx_b, |tree, cx| tree.open_buffer("file1", cx))
1475 .await
1476 .unwrap();
1477 let buffer_c = worktree_c
1478 .update(&mut cx_c, |tree, cx| tree.open_buffer("file1", cx))
1479 .await
1480 .unwrap();
1481 buffer_b.update(&mut cx_b, |buf, cx| buf.edit([0..0], "i-am-b, ", cx));
1482 buffer_c.update(&mut cx_c, |buf, cx| buf.edit([0..0], "i-am-c, ", cx));
1483
1484 // Open and edit that buffer as the host.
1485 let buffer_a = worktree_a
1486 .update(&mut cx_a, |tree, cx| tree.open_buffer("file1", cx))
1487 .await
1488 .unwrap();
1489
1490 buffer_a
1491 .condition(&mut cx_a, |buf, _| buf.text() == "i-am-c, i-am-b, ")
1492 .await;
1493 buffer_a.update(&mut cx_a, |buf, cx| {
1494 buf.edit([buf.len()..buf.len()], "i-am-a", cx)
1495 });
1496
1497 // Wait for edits to propagate
1498 buffer_a
1499 .condition(&mut cx_a, |buf, _| buf.text() == "i-am-c, i-am-b, i-am-a")
1500 .await;
1501 buffer_b
1502 .condition(&mut cx_b, |buf, _| buf.text() == "i-am-c, i-am-b, i-am-a")
1503 .await;
1504 buffer_c
1505 .condition(&mut cx_c, |buf, _| buf.text() == "i-am-c, i-am-b, i-am-a")
1506 .await;
1507
1508 // Edit the buffer as the host and concurrently save as guest B.
1509 let save_b = buffer_b.update(&mut cx_b, |buf, cx| buf.save(cx).unwrap());
1510 buffer_a.update(&mut cx_a, |buf, cx| buf.edit([0..0], "hi-a, ", cx));
1511 save_b.await.unwrap();
1512 assert_eq!(
1513 fs.load("/a/file1".as_ref()).await.unwrap(),
1514 "hi-a, i-am-c, i-am-b, i-am-a"
1515 );
1516 buffer_a.read_with(&cx_a, |buf, _| assert!(!buf.is_dirty()));
1517 buffer_b.read_with(&cx_b, |buf, _| assert!(!buf.is_dirty()));
1518 buffer_c.condition(&cx_c, |buf, _| !buf.is_dirty()).await;
1519
1520 // Make changes on host's file system, see those changes on the guests.
1521 fs.rename("/a/file2".as_ref(), "/a/file3".as_ref())
1522 .await
1523 .unwrap();
1524 fs.insert_file(Path::new("/a/file4"), "4".into())
1525 .await
1526 .unwrap();
1527
1528 worktree_b
1529 .condition(&cx_b, |tree, _| tree.file_count() == 4)
1530 .await;
1531 worktree_c
1532 .condition(&cx_c, |tree, _| tree.file_count() == 4)
1533 .await;
1534 worktree_b.read_with(&cx_b, |tree, _| {
1535 assert_eq!(
1536 tree.paths()
1537 .map(|p| p.to_string_lossy())
1538 .collect::<Vec<_>>(),
1539 &[".zed.toml", "file1", "file3", "file4"]
1540 )
1541 });
1542 worktree_c.read_with(&cx_c, |tree, _| {
1543 assert_eq!(
1544 tree.paths()
1545 .map(|p| p.to_string_lossy())
1546 .collect::<Vec<_>>(),
1547 &[".zed.toml", "file1", "file3", "file4"]
1548 )
1549 });
1550 }
1551
1552 #[gpui::test]
1553 async fn test_buffer_conflict_after_save(mut cx_a: TestAppContext, mut cx_b: TestAppContext) {
1554 cx_a.foreground().forbid_parking();
1555 let lang_registry = Arc::new(LanguageRegistry::new());
1556
1557 // Connect to a server as 2 clients.
1558 let mut server = TestServer::start().await;
1559 let (client_a, _) = server.create_client(&mut cx_a, "user_a").await;
1560 let (client_b, _) = server.create_client(&mut cx_b, "user_b").await;
1561
1562 // Share a local worktree as client A
1563 let fs = Arc::new(FakeFs::new());
1564 fs.insert_tree(
1565 "/dir",
1566 json!({
1567 ".zed.toml": r#"collaborators = ["user_b", "user_c"]"#,
1568 "a.txt": "a-contents",
1569 }),
1570 )
1571 .await;
1572
1573 let worktree_a = Worktree::open_local(
1574 client_a.clone(),
1575 "/dir".as_ref(),
1576 fs,
1577 lang_registry.clone(),
1578 &mut cx_a.to_async(),
1579 )
1580 .await
1581 .unwrap();
1582 worktree_a
1583 .read_with(&cx_a, |tree, _| tree.as_local().unwrap().scan_complete())
1584 .await;
1585 let worktree_id = worktree_a
1586 .update(&mut cx_a, |tree, cx| tree.as_local_mut().unwrap().share(cx))
1587 .await
1588 .unwrap();
1589
1590 // Join that worktree as client B, and see that a guest has joined as client A.
1591 let worktree_b = Worktree::open_remote(
1592 client_b.clone(),
1593 worktree_id,
1594 lang_registry.clone(),
1595 &mut cx_b.to_async(),
1596 )
1597 .await
1598 .unwrap();
1599
1600 let buffer_b = worktree_b
1601 .update(&mut cx_b, |worktree, cx| worktree.open_buffer("a.txt", cx))
1602 .await
1603 .unwrap();
1604 let mtime = buffer_b.read_with(&cx_b, |buf, _| buf.file().unwrap().mtime);
1605
1606 buffer_b.update(&mut cx_b, |buf, cx| buf.edit([0..0], "world ", cx));
1607 buffer_b.read_with(&cx_b, |buf, _| {
1608 assert!(buf.is_dirty());
1609 assert!(!buf.has_conflict());
1610 });
1611
1612 buffer_b
1613 .update(&mut cx_b, |buf, cx| buf.save(cx))
1614 .unwrap()
1615 .await
1616 .unwrap();
1617 worktree_b
1618 .condition(&cx_b, |_, cx| {
1619 buffer_b.read(cx).file().unwrap().mtime != mtime
1620 })
1621 .await;
1622 buffer_b.read_with(&cx_b, |buf, _| {
1623 assert!(!buf.is_dirty());
1624 assert!(!buf.has_conflict());
1625 });
1626
1627 buffer_b.update(&mut cx_b, |buf, cx| buf.edit([0..0], "hello ", cx));
1628 buffer_b.read_with(&cx_b, |buf, _| {
1629 assert!(buf.is_dirty());
1630 assert!(!buf.has_conflict());
1631 });
1632 }
1633
1634 #[gpui::test]
1635 async fn test_editing_while_guest_opens_buffer(
1636 mut cx_a: TestAppContext,
1637 mut cx_b: TestAppContext,
1638 ) {
1639 cx_a.foreground().forbid_parking();
1640 let lang_registry = Arc::new(LanguageRegistry::new());
1641
1642 // Connect to a server as 2 clients.
1643 let mut server = TestServer::start().await;
1644 let (client_a, _) = server.create_client(&mut cx_a, "user_a").await;
1645 let (client_b, _) = server.create_client(&mut cx_b, "user_b").await;
1646
1647 // Share a local worktree as client A
1648 let fs = Arc::new(FakeFs::new());
1649 fs.insert_tree(
1650 "/dir",
1651 json!({
1652 ".zed.toml": r#"collaborators = ["user_b"]"#,
1653 "a.txt": "a-contents",
1654 }),
1655 )
1656 .await;
1657 let worktree_a = Worktree::open_local(
1658 client_a.clone(),
1659 "/dir".as_ref(),
1660 fs,
1661 lang_registry.clone(),
1662 &mut cx_a.to_async(),
1663 )
1664 .await
1665 .unwrap();
1666 worktree_a
1667 .read_with(&cx_a, |tree, _| tree.as_local().unwrap().scan_complete())
1668 .await;
1669 let worktree_id = worktree_a
1670 .update(&mut cx_a, |tree, cx| tree.as_local_mut().unwrap().share(cx))
1671 .await
1672 .unwrap();
1673
1674 // Join that worktree as client B, and see that a guest has joined as client A.
1675 let worktree_b = Worktree::open_remote(
1676 client_b.clone(),
1677 worktree_id,
1678 lang_registry.clone(),
1679 &mut cx_b.to_async(),
1680 )
1681 .await
1682 .unwrap();
1683
1684 let buffer_a = worktree_a
1685 .update(&mut cx_a, |tree, cx| tree.open_buffer("a.txt", cx))
1686 .await
1687 .unwrap();
1688 let buffer_b = cx_b
1689 .background()
1690 .spawn(worktree_b.update(&mut cx_b, |worktree, cx| worktree.open_buffer("a.txt", cx)));
1691
1692 task::yield_now().await;
1693 buffer_a.update(&mut cx_a, |buf, cx| buf.edit([0..0], "z", cx));
1694
1695 let text = buffer_a.read_with(&cx_a, |buf, _| buf.text());
1696 let buffer_b = buffer_b.await.unwrap();
1697 buffer_b.condition(&cx_b, |buf, _| buf.text() == text).await;
1698 }
1699
1700 #[gpui::test]
1701 async fn test_peer_disconnection(mut cx_a: TestAppContext, cx_b: TestAppContext) {
1702 cx_a.foreground().forbid_parking();
1703 let lang_registry = Arc::new(LanguageRegistry::new());
1704
1705 // Connect to a server as 2 clients.
1706 let mut server = TestServer::start().await;
1707 let (client_a, _) = server.create_client(&mut cx_a, "user_a").await;
1708 let (client_b, _) = server.create_client(&mut cx_a, "user_b").await;
1709
1710 // Share a local worktree as client A
1711 let fs = Arc::new(FakeFs::new());
1712 fs.insert_tree(
1713 "/a",
1714 json!({
1715 ".zed.toml": r#"collaborators = ["user_b"]"#,
1716 "a.txt": "a-contents",
1717 "b.txt": "b-contents",
1718 }),
1719 )
1720 .await;
1721 let worktree_a = Worktree::open_local(
1722 client_a.clone(),
1723 "/a".as_ref(),
1724 fs,
1725 lang_registry.clone(),
1726 &mut cx_a.to_async(),
1727 )
1728 .await
1729 .unwrap();
1730 worktree_a
1731 .read_with(&cx_a, |tree, _| tree.as_local().unwrap().scan_complete())
1732 .await;
1733 let worktree_id = worktree_a
1734 .update(&mut cx_a, |tree, cx| tree.as_local_mut().unwrap().share(cx))
1735 .await
1736 .unwrap();
1737
1738 // Join that worktree as client B, and see that a guest has joined as client A.
1739 let _worktree_b = Worktree::open_remote(
1740 client_b.clone(),
1741 worktree_id,
1742 lang_registry.clone(),
1743 &mut cx_b.to_async(),
1744 )
1745 .await
1746 .unwrap();
1747 worktree_a
1748 .condition(&cx_a, |tree, _| tree.peers().len() == 1)
1749 .await;
1750
1751 // Drop client B's connection and ensure client A observes client B leaving the worktree.
1752 client_b.disconnect(&cx_b.to_async()).await.unwrap();
1753 worktree_a
1754 .condition(&cx_a, |tree, _| tree.peers().len() == 0)
1755 .await;
1756 }
1757
1758 #[gpui::test]
1759 async fn test_basic_chat(mut cx_a: TestAppContext, mut cx_b: TestAppContext) {
1760 cx_a.foreground().forbid_parking();
1761
1762 // Connect to a server as 2 clients.
1763 let mut server = TestServer::start().await;
1764 let (client_a, user_store_a) = server.create_client(&mut cx_a, "user_a").await;
1765 let (client_b, user_store_b) = server.create_client(&mut cx_b, "user_b").await;
1766
1767 // Create an org that includes these 2 users.
1768 let db = &server.app_state.db;
1769 let org_id = db.create_org("Test Org", "test-org").await.unwrap();
1770 db.add_org_member(org_id, current_user_id(&user_store_a), false)
1771 .await
1772 .unwrap();
1773 db.add_org_member(org_id, current_user_id(&user_store_b), false)
1774 .await
1775 .unwrap();
1776
1777 // Create a channel that includes all the users.
1778 let channel_id = db.create_org_channel(org_id, "test-channel").await.unwrap();
1779 db.add_channel_member(channel_id, current_user_id(&user_store_a), false)
1780 .await
1781 .unwrap();
1782 db.add_channel_member(channel_id, current_user_id(&user_store_b), false)
1783 .await
1784 .unwrap();
1785 db.create_channel_message(
1786 channel_id,
1787 current_user_id(&user_store_b),
1788 "hello A, it's B.",
1789 OffsetDateTime::now_utc(),
1790 1,
1791 )
1792 .await
1793 .unwrap();
1794
1795 let channels_a = cx_a.add_model(|cx| ChannelList::new(user_store_a, client_a, cx));
1796 channels_a
1797 .condition(&mut cx_a, |list, _| list.available_channels().is_some())
1798 .await;
1799 channels_a.read_with(&cx_a, |list, _| {
1800 assert_eq!(
1801 list.available_channels().unwrap(),
1802 &[ChannelDetails {
1803 id: channel_id.to_proto(),
1804 name: "test-channel".to_string()
1805 }]
1806 )
1807 });
1808 let channel_a = channels_a.update(&mut cx_a, |this, cx| {
1809 this.get_channel(channel_id.to_proto(), cx).unwrap()
1810 });
1811 channel_a.read_with(&cx_a, |channel, _| assert!(channel.messages().is_empty()));
1812 channel_a
1813 .condition(&cx_a, |channel, _| {
1814 channel_messages(channel)
1815 == [("user_b".to_string(), "hello A, it's B.".to_string(), false)]
1816 })
1817 .await;
1818
1819 let channels_b = cx_b.add_model(|cx| ChannelList::new(user_store_b, client_b, cx));
1820 channels_b
1821 .condition(&mut cx_b, |list, _| list.available_channels().is_some())
1822 .await;
1823 channels_b.read_with(&cx_b, |list, _| {
1824 assert_eq!(
1825 list.available_channels().unwrap(),
1826 &[ChannelDetails {
1827 id: channel_id.to_proto(),
1828 name: "test-channel".to_string()
1829 }]
1830 )
1831 });
1832
1833 let channel_b = channels_b.update(&mut cx_b, |this, cx| {
1834 this.get_channel(channel_id.to_proto(), cx).unwrap()
1835 });
1836 channel_b.read_with(&cx_b, |channel, _| assert!(channel.messages().is_empty()));
1837 channel_b
1838 .condition(&cx_b, |channel, _| {
1839 channel_messages(channel)
1840 == [("user_b".to_string(), "hello A, it's B.".to_string(), false)]
1841 })
1842 .await;
1843
1844 channel_a
1845 .update(&mut cx_a, |channel, cx| {
1846 channel
1847 .send_message("oh, hi B.".to_string(), cx)
1848 .unwrap()
1849 .detach();
1850 let task = channel.send_message("sup".to_string(), cx).unwrap();
1851 assert_eq!(
1852 channel_messages(channel),
1853 &[
1854 ("user_b".to_string(), "hello A, it's B.".to_string(), false),
1855 ("user_a".to_string(), "oh, hi B.".to_string(), true),
1856 ("user_a".to_string(), "sup".to_string(), true)
1857 ]
1858 );
1859 task
1860 })
1861 .await
1862 .unwrap();
1863
1864 channel_b
1865 .condition(&cx_b, |channel, _| {
1866 channel_messages(channel)
1867 == [
1868 ("user_b".to_string(), "hello A, it's B.".to_string(), false),
1869 ("user_a".to_string(), "oh, hi B.".to_string(), false),
1870 ("user_a".to_string(), "sup".to_string(), false),
1871 ]
1872 })
1873 .await;
1874
1875 assert_eq!(
1876 server.state().await.channels[&channel_id]
1877 .connection_ids
1878 .len(),
1879 2
1880 );
1881 cx_b.update(|_| drop(channel_b));
1882 server
1883 .condition(|state| state.channels[&channel_id].connection_ids.len() == 1)
1884 .await;
1885
1886 cx_a.update(|_| drop(channel_a));
1887 server
1888 .condition(|state| !state.channels.contains_key(&channel_id))
1889 .await;
1890 }
1891
1892 #[gpui::test]
1893 async fn test_chat_message_validation(mut cx_a: TestAppContext) {
1894 cx_a.foreground().forbid_parking();
1895
1896 let mut server = TestServer::start().await;
1897 let (client_a, user_store_a) = server.create_client(&mut cx_a, "user_a").await;
1898
1899 let db = &server.app_state.db;
1900 let org_id = db.create_org("Test Org", "test-org").await.unwrap();
1901 let channel_id = db.create_org_channel(org_id, "test-channel").await.unwrap();
1902 db.add_org_member(org_id, current_user_id(&user_store_a), false)
1903 .await
1904 .unwrap();
1905 db.add_channel_member(channel_id, current_user_id(&user_store_a), false)
1906 .await
1907 .unwrap();
1908
1909 let channels_a = cx_a.add_model(|cx| ChannelList::new(user_store_a, client_a, cx));
1910 channels_a
1911 .condition(&mut cx_a, |list, _| list.available_channels().is_some())
1912 .await;
1913 let channel_a = channels_a.update(&mut cx_a, |this, cx| {
1914 this.get_channel(channel_id.to_proto(), cx).unwrap()
1915 });
1916
1917 // Messages aren't allowed to be too long.
1918 channel_a
1919 .update(&mut cx_a, |channel, cx| {
1920 let long_body = "this is long.\n".repeat(1024);
1921 channel.send_message(long_body, cx).unwrap()
1922 })
1923 .await
1924 .unwrap_err();
1925
1926 // Messages aren't allowed to be blank.
1927 channel_a.update(&mut cx_a, |channel, cx| {
1928 channel.send_message(String::new(), cx).unwrap_err()
1929 });
1930
1931 // Leading and trailing whitespace are trimmed.
1932 channel_a
1933 .update(&mut cx_a, |channel, cx| {
1934 channel
1935 .send_message("\n surrounded by whitespace \n".to_string(), cx)
1936 .unwrap()
1937 })
1938 .await
1939 .unwrap();
1940 assert_eq!(
1941 db.get_channel_messages(channel_id, 10, None)
1942 .await
1943 .unwrap()
1944 .iter()
1945 .map(|m| &m.body)
1946 .collect::<Vec<_>>(),
1947 &["surrounded by whitespace"]
1948 );
1949 }
1950
1951 #[gpui::test]
1952 async fn test_chat_reconnection(mut cx_a: TestAppContext, mut cx_b: TestAppContext) {
1953 cx_a.foreground().forbid_parking();
1954 let http = FakeHttpClient::new(|_| async move { Ok(surf::http::Response::new(404)) });
1955
1956 // Connect to a server as 2 clients.
1957 let mut server = TestServer::start().await;
1958 let (client_a, user_store_a) = server.create_client(&mut cx_a, "user_a").await;
1959 let (client_b, user_store_b) = server.create_client(&mut cx_b, "user_b").await;
1960 let mut status_b = client_b.status();
1961
1962 // Create an org that includes these 2 users.
1963 let db = &server.app_state.db;
1964 let org_id = db.create_org("Test Org", "test-org").await.unwrap();
1965 db.add_org_member(org_id, current_user_id(&user_store_a), false)
1966 .await
1967 .unwrap();
1968 db.add_org_member(org_id, current_user_id(&user_store_b), false)
1969 .await
1970 .unwrap();
1971
1972 // Create a channel that includes all the users.
1973 let channel_id = db.create_org_channel(org_id, "test-channel").await.unwrap();
1974 db.add_channel_member(channel_id, current_user_id(&user_store_a), false)
1975 .await
1976 .unwrap();
1977 db.add_channel_member(channel_id, current_user_id(&user_store_b), false)
1978 .await
1979 .unwrap();
1980 db.create_channel_message(
1981 channel_id,
1982 current_user_id(&user_store_b),
1983 "hello A, it's B.",
1984 OffsetDateTime::now_utc(),
1985 2,
1986 )
1987 .await
1988 .unwrap();
1989
1990 let user_store_a =
1991 UserStore::new(client_a.clone(), http.clone(), cx_a.background().as_ref());
        let channels_a = cx_a.add_model(|cx| ChannelList::new(user_store_a, client_a, cx));
        channels_a
            .condition(&mut cx_a, |list, _| list.available_channels().is_some())
            .await;

        channels_a.read_with(&cx_a, |list, _| {
            assert_eq!(
                list.available_channels().unwrap(),
                &[ChannelDetails {
                    id: channel_id.to_proto(),
                    name: "test-channel".to_string()
                }]
            )
        });
        let channel_a = channels_a.update(&mut cx_a, |this, cx| {
            this.get_channel(channel_id.to_proto(), cx).unwrap()
        });
        channel_a.read_with(&cx_a, |channel, _| assert!(channel.messages().is_empty()));
        channel_a
            .condition(&cx_a, |channel, _| {
                channel_messages(channel)
                    == [("user_b".to_string(), "hello A, it's B.".to_string(), false)]
            })
            .await;

        let channels_b = cx_b.add_model(|cx| ChannelList::new(user_store_b.clone(), client_b, cx));
        channels_b
            .condition(&mut cx_b, |list, _| list.available_channels().is_some())
            .await;
        channels_b.read_with(&cx_b, |list, _| {
            assert_eq!(
                list.available_channels().unwrap(),
                &[ChannelDetails {
                    id: channel_id.to_proto(),
                    name: "test-channel".to_string()
                }]
            )
        });

        let channel_b = channels_b.update(&mut cx_b, |this, cx| {
            this.get_channel(channel_id.to_proto(), cx).unwrap()
        });
        channel_b.read_with(&cx_b, |channel, _| assert!(channel.messages().is_empty()));
        channel_b
            .condition(&cx_b, |channel, _| {
                channel_messages(channel)
                    == [("user_b".to_string(), "hello A, it's B.".to_string(), false)]
            })
            .await;

        // Disconnect client B, ensuring we can still access its cached channel data.
        server.forbid_connections();
        server.disconnect_client(current_user_id(&user_store_b));
        while !matches!(
            status_b.recv().await,
            Some(rpc::Status::ReconnectionError { .. })
        ) {}

        channels_b.read_with(&cx_b, |channels, _| {
            assert_eq!(
                channels.available_channels().unwrap(),
                [ChannelDetails {
                    id: channel_id.to_proto(),
                    name: "test-channel".to_string()
                }]
            )
        });
        channel_b.read_with(&cx_b, |channel, _| {
            assert_eq!(
                channel_messages(channel),
                [("user_b".to_string(), "hello A, it's B.".to_string(), false)]
            )
        });

        // Send a message from client B while it is disconnected.
        channel_b
            .update(&mut cx_b, |channel, cx| {
                let task = channel
                    .send_message("can you see this?".to_string(), cx)
                    .unwrap();
                assert_eq!(
                    channel_messages(channel),
                    &[
                        ("user_b".to_string(), "hello A, it's B.".to_string(), false),
                        ("user_b".to_string(), "can you see this?".to_string(), true)
                    ]
                );
                task
            })
            .await
            .unwrap_err();

        // Send a message from client A while B is disconnected.
        channel_a
            .update(&mut cx_a, |channel, cx| {
                channel
                    .send_message("oh, hi B.".to_string(), cx)
                    .unwrap()
                    .detach();
                let task = channel.send_message("sup".to_string(), cx).unwrap();
                assert_eq!(
                    channel_messages(channel),
                    &[
                        ("user_b".to_string(), "hello A, it's B.".to_string(), false),
                        ("user_a".to_string(), "oh, hi B.".to_string(), true),
                        ("user_a".to_string(), "sup".to_string(), true)
                    ]
                );
                task
            })
            .await
            .unwrap();

        // Give client B a chance to reconnect.
        server.allow_connections();
        cx_b.foreground().advance_clock(Duration::from_secs(10));

        // Verify that B sees the new messages upon reconnection, as well as the message client B
        // sent while offline.
        channel_b
            .condition(&cx_b, |channel, _| {
                channel_messages(channel)
                    == [
                        ("user_b".to_string(), "hello A, it's B.".to_string(), false),
                        ("user_a".to_string(), "oh, hi B.".to_string(), false),
                        ("user_a".to_string(), "sup".to_string(), false),
                        ("user_b".to_string(), "can you see this?".to_string(), false),
                    ]
            })
            .await;

        // Ensure client A and B can communicate normally after reconnection.
        channel_a
            .update(&mut cx_a, |channel, cx| {
                channel.send_message("you online?".to_string(), cx).unwrap()
            })
            .await
            .unwrap();
        channel_b
            .condition(&cx_b, |channel, _| {
                channel_messages(channel)
                    == [
                        ("user_b".to_string(), "hello A, it's B.".to_string(), false),
                        ("user_a".to_string(), "oh, hi B.".to_string(), false),
                        ("user_a".to_string(), "sup".to_string(), false),
                        ("user_b".to_string(), "can you see this?".to_string(), false),
                        ("user_a".to_string(), "you online?".to_string(), false),
                    ]
            })
            .await;

        channel_b
            .update(&mut cx_b, |channel, cx| {
                channel.send_message("yep".to_string(), cx).unwrap()
            })
            .await
            .unwrap();
        channel_a
            .condition(&cx_a, |channel, _| {
                channel_messages(channel)
                    == [
                        ("user_b".to_string(), "hello A, it's B.".to_string(), false),
                        ("user_a".to_string(), "oh, hi B.".to_string(), false),
                        ("user_a".to_string(), "sup".to_string(), false),
                        ("user_b".to_string(), "can you see this?".to_string(), false),
                        ("user_a".to_string(), "you online?".to_string(), false),
                        ("user_b".to_string(), "yep".to_string(), false),
                    ]
            })
            .await;
    }

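    /// Test harness that runs a real `Server` over in-memory connections against a fresh
    /// test database. `connection_killers` holds a per-user kill switch for severing a
    /// client's connection, `forbid_connections` makes new connection attempts fail, and
    /// `notifications` is a channel from the server that `condition` uses to know when to
    /// re-check its predicate.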
    struct TestServer {
        peer: Arc<Peer>,
        app_state: Arc<AppState>,
        server: Arc<Server>,
        notifications: mpsc::Receiver<()>,
        connection_killers: Arc<Mutex<HashMap<UserId, watch::Sender<Option<()>>>>>,
        forbid_connections: Arc<AtomicBool>,
        _test_db: TestDb,
    }

    impl TestServer {
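        /// Spin up a server backed by a brand-new test database.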
        async fn start() -> Self {
            let test_db = TestDb::new();
            let app_state = Self::build_app_state(&test_db).await;
            let peer = Peer::new();
            let notifications = mpsc::channel(128);
            let server = Server::new(app_state.clone(), peer.clone(), Some(notifications.0));
            Self {
                peer,
                app_state,
                server,
                notifications: notifications.1,
                connection_killers: Default::default(),
                forbid_connections: Default::default(),
                _test_db: test_db,
            }
        }

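        /// Create a user named `name` in the test database and connect a client for them.
        /// Authentication is stubbed out, and the connection is an in-memory pipe that can
        /// later be severed via `disconnect_client`.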
        async fn create_client(
            &mut self,
            cx: &mut TestAppContext,
            name: &str,
        ) -> (Arc<Client>, Arc<UserStore>) {
            let user_id = self.app_state.db.create_user(name, false).await.unwrap();
            let client_name = name.to_string();
            let mut client = Client::new();
            let server = self.server.clone();
            let connection_killers = self.connection_killers.clone();
            let forbid_connections = self.forbid_connections.clone();
            Arc::get_mut(&mut client)
                .unwrap()
                .override_authenticate(move |cx| {
                    cx.spawn(|_| async move {
                        let access_token = "the-token".to_string();
                        Ok(Credentials {
                            user_id: user_id.0 as u64,
                            access_token,
                        })
                    })
                })
                .override_establish_connection(move |credentials, cx| {
                    assert_eq!(credentials.user_id, user_id.0 as u64);
                    assert_eq!(credentials.access_token, "the-token");

                    let server = server.clone();
                    let connection_killers = connection_killers.clone();
                    let forbid_connections = forbid_connections.clone();
                    let client_name = client_name.clone();
                    cx.spawn(move |cx| async move {
                        if forbid_connections.load(SeqCst) {
                            Err(EstablishConnectionError::other(anyhow!(
                                "server is forbidding connections"
                            )))
                        } else {
                            let (client_conn, server_conn, kill_conn) = Connection::in_memory();
                            connection_killers.lock().insert(user_id, kill_conn);
                            cx.background()
                                .spawn(server.handle_connection(server_conn, client_name, user_id))
                                .detach();
                            Ok(client_conn)
                        }
                    })
                });

            let http = FakeHttpClient::new(|_| async move { Ok(surf::http::Response::new(404)) });
            client
                .authenticate_and_connect(&cx.to_async())
                .await
                .unwrap();

            let user_store = UserStore::new(client.clone(), http, &cx.background());
            let mut authed_user = user_store.watch_current_user();
            while authed_user.recv().await.unwrap().is_none() {}

            (client, user_store)
        }

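        /// Sever the in-memory connection for `user_id`, simulating a dropped connection.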
        fn disconnect_client(&self, user_id: UserId) {
            if let Some(mut kill_conn) = self.connection_killers.lock().remove(&user_id) {
                let _ = kill_conn.try_send(Some(()));
            }
        }

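        /// Make any further connection attempts fail until `allow_connections` is called.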
        fn forbid_connections(&self) {
            self.forbid_connections.store(true, SeqCst);
        }

        fn allow_connections(&self) {
            self.forbid_connections.store(false, SeqCst);
        }

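        /// Build an `AppState` wired to the test database, with stub auth and GitHub clients.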
        async fn build_app_state(test_db: &TestDb) -> Arc<AppState> {
            let mut config = Config::default();
            config.session_secret = "a".repeat(32);
            config.database_url = test_db.url.clone();
            let github_client = github::AppClient::test();
            Arc::new(AppState {
                db: test_db.db().clone(),
                handlebars: Default::default(),
                auth_client: auth::build_client("", ""),
                repo_client: github::RepoClient::test(&github_client),
                github_client,
                config,
            })
        }

        async fn state<'a>(&'a self) -> RwLockReadGuard<'a, ServerState> {
            self.server.state.read().await
        }

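        /// Wait (up to 500ms) for `predicate` to hold over the server's state, re-checking
        /// whenever a notification arrives from the server.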
        async fn condition<F>(&mut self, mut predicate: F)
        where
            F: FnMut(&ServerState) -> bool,
        {
            async_std::future::timeout(Duration::from_millis(500), async {
                while !(predicate)(&*self.server.state.read().await) {
                    self.notifications.recv().await;
                }
            })
            .await
            .expect("condition timed out");
        }
    }

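    // Reset the peer when the harness is dropped so connections don't outlive the test.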
    impl Drop for TestServer {
        fn drop(&mut self) {
            task::block_on(self.peer.reset());
        }
    }

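    /// The `UserId` of the store's currently-authenticated user.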
    fn current_user_id(user_store: &Arc<UserStore>) -> UserId {
        UserId::from_proto(user_store.current_user().unwrap().id)
    }

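    /// Flatten a channel's messages into `(sender github login, body, is_pending)` tuples
    /// so tests can assert on them directly.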
    fn channel_messages(channel: &Channel) -> Vec<(String, String, bool)> {
        channel
            .messages()
            .cursor::<(), ()>()
            .map(|m| {
                (
                    m.sender.github_login.clone(),
                    m.body.clone(),
                    m.is_pending(),
                )
            })
            .collect()
    }

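    // A view that renders nothing, available as a placeholder root view for tests.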
    struct EmptyView;

    impl gpui::Entity for EmptyView {
        type Event = ();
    }

    impl gpui::View for EmptyView {
        fn ui_name() -> &'static str {
            "empty view"
        }

        fn render(&mut self, _: &mut gpui::RenderContext<Self>) -> gpui::ElementBox {
            gpui::Element::boxed(gpui::elements::Empty)
        }
    }
}