1use super::{
2 auth,
3 db::{ChannelId, MessageId, User, UserId},
4 AppState,
5};
6use anyhow::anyhow;
7use async_std::{sync::RwLock, task};
8use async_tungstenite::{tungstenite::protocol::Role, WebSocketStream};
9use futures::{future::BoxFuture, FutureExt};
10use postage::{mpsc, prelude::Sink as _, prelude::Stream as _};
11use sha1::{Digest as _, Sha1};
12use std::{
13 any::TypeId,
14 collections::{hash_map, HashMap, HashSet},
15 future::Future,
16 mem,
17 sync::Arc,
18 time::Instant,
19};
20use surf::StatusCode;
21use tide::log;
22use tide::{
23 http::headers::{HeaderName, CONNECTION, UPGRADE},
24 Request, Response,
25};
26use time::OffsetDateTime;
27use zrpc::{
28 proto::{self, AnyTypedEnvelope, EnvelopedMessage},
29 Connection, ConnectionId, Peer, TypedEnvelope,
30};
31
/// Identifier of a participant within one shared worktree.
///
/// The host is reported to peers as replica `0`; guests are assigned the
/// lowest unused id starting at `1` when they join.
type ReplicaId = u16;
33
/// Type-erased RPC handler stored in `Server::handlers`.
///
/// It receives the server and an envelope whose concrete payload type is only
/// known at runtime, and returns a boxed future resolving to the handler's
/// result. `Server::add_handler` builds these by downcasting the envelope back
/// to the typed form expected by the wrapped function.
type MessageHandler = Box<
    dyn Send
        + Sync
        + Fn(Arc<Server>, Box<dyn AnyTypedEnvelope>) -> BoxFuture<'static, tide::Result<()>>,
>;
39
/// RPC server: routes incoming messages to handlers and tracks which
/// connections participate in which worktrees and channels.
pub struct Server {
    /// Transport used to send, forward and respond to messages.
    peer: Arc<Peer>,
    /// In-memory bookkeeping of connections, worktrees and channels.
    state: RwLock<ServerState>,
    /// Shared application state; handlers use it to reach the database.
    app_state: Arc<AppState>,
    /// One handler per message payload type, keyed by the payload's `TypeId`.
    handlers: HashMap<TypeId, MessageHandler>,
    /// When present, a `()` is sent on this channel after each message for
    /// which a handler ran (whether it succeeded or failed).
    notifications: Option<mpsc::Sender<()>>,
}
47
/// Mutable server bookkeeping, guarded by `Server::state`.
#[derive(Default)]
struct ServerState {
    /// Every live connection and the resources it participates in.
    connections: HashMap<ConnectionId, ConnectionState>,
    /// Reverse index: all live connections of each signed-in user. An entry is
    /// removed when the user's last connection goes away.
    connections_by_user_id: HashMap<UserId, HashSet<ConnectionId>>,
    /// Worktrees registered via `add_worktree`, keyed by worktree id.
    pub worktrees: HashMap<u64, Worktree>,
    /// For each user, the ids of worktrees that list them as a collaborator.
    visible_worktrees_by_user_id: HashMap<UserId, HashSet<u64>>,
    /// Channels that currently have at least one joined connection.
    channels: HashMap<ChannelId, Channel>,
    /// Id that the next call to `add_worktree` will hand out.
    next_worktree_id: u64,
}
57
/// Per-connection bookkeeping.
struct ConnectionState {
    /// User that authenticated this connection.
    user_id: UserId,
    /// Ids of worktrees this connection is associated with; consulted when the
    /// connection is removed to find the worktrees it must leave.
    worktrees: HashSet<u64>,
    /// Channels this connection has joined.
    channels: HashSet<ChannelId>,
}
63
/// A worktree opened by a host connection.
struct Worktree {
    /// Connection that opened the worktree; buffer requests are forwarded to it.
    host_connection_id: ConnectionId,
    /// Users allowed to join the worktree once it is shared.
    collaborator_user_ids: Vec<UserId>,
    /// Root name supplied by the host; echoed to guests when they join.
    root_name: String,
    /// `Some` while the worktree is shared, `None` otherwise.
    share: Option<WorktreeShare>,
}
70
/// State that only exists while a worktree is shared.
struct WorktreeShare {
    /// Guests currently joined, with the replica id assigned to each.
    guest_connection_ids: HashMap<ConnectionId, ReplicaId>,
    /// Replica ids currently in use by guests; freed when a guest leaves.
    active_replica_ids: HashSet<ReplicaId>,
    /// Latest known worktree entries, keyed by entry id.
    entries: HashMap<u64, proto::Entry>,
}
76
/// Live membership of a chat channel.
#[derive(Default)]
struct Channel {
    /// Connections that have joined the channel and receive its messages.
    connection_ids: HashSet<ConnectionId>,
}
81
/// Maximum number of channel messages returned per page; a page shorter than
/// this is reported to the client as the last one (`done`).
const MESSAGE_COUNT_PER_PAGE: usize = 100;
/// Maximum accepted length of a trimmed channel message body, measured with
/// `str::len` (i.e. in bytes).
const MAX_MESSAGE_LEN: usize = 1024;
84
85impl Server {
86 pub fn new(
87 app_state: Arc<AppState>,
88 peer: Arc<Peer>,
89 notifications: Option<mpsc::Sender<()>>,
90 ) -> Arc<Self> {
91 let mut server = Self {
92 peer,
93 app_state,
94 state: Default::default(),
95 handlers: Default::default(),
96 notifications,
97 };
98
99 server
100 .add_handler(Server::ping)
101 .add_handler(Server::open_worktree)
102 .add_handler(Server::close_worktree)
103 .add_handler(Server::share_worktree)
104 .add_handler(Server::unshare_worktree)
105 .add_handler(Server::join_worktree)
106 .add_handler(Server::update_worktree)
107 .add_handler(Server::open_buffer)
108 .add_handler(Server::close_buffer)
109 .add_handler(Server::update_buffer)
110 .add_handler(Server::buffer_saved)
111 .add_handler(Server::save_buffer)
112 .add_handler(Server::get_channels)
113 .add_handler(Server::get_users)
114 .add_handler(Server::join_channel)
115 .add_handler(Server::leave_channel)
116 .add_handler(Server::send_channel_message)
117 .add_handler(Server::get_channel_messages)
118 .add_handler(Server::get_collaborators);
119
120 Arc::new(server)
121 }
122
123 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
124 where
125 F: 'static + Send + Sync + Fn(Arc<Self>, TypedEnvelope<M>) -> Fut,
126 Fut: 'static + Send + Future<Output = tide::Result<()>>,
127 M: EnvelopedMessage,
128 {
129 let prev_handler = self.handlers.insert(
130 TypeId::of::<M>(),
131 Box::new(move |server, envelope| {
132 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
133 (handler)(server, *envelope).boxed()
134 }),
135 );
136 if prev_handler.is_some() {
137 panic!("registered a handler for the same message twice");
138 }
139 self
140 }
141
    /// Returns a future that serves `connection` until it closes.
    ///
    /// The future registers the connection with the peer and with the server
    /// state, then dispatches each incoming message to the handler registered
    /// for its payload type. When either the message stream ends or the
    /// connection's I/O future finishes, the connection is signed out.
    ///
    /// `addr` is only used in log messages. Handler errors are logged and do
    /// not terminate the connection.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        addr: String,
        user_id: UserId,
    ) -> impl Future<Output = ()> {
        let this = self.clone();
        async move {
            let (connection_id, handle_io, mut incoming_rx) =
                this.peer.add_connection(connection).await;
            this.add_connection(connection_id, user_id).await;

            // `handle_io` is fused and pinned once so it can be polled again on
            // every loop iteration alongside the next incoming message.
            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);
            loop {
                let next_message = incoming_rx.recv().fuse();
                futures::pin_mut!(next_message);
                // The incoming-message branch is listed first, so with
                // `select_biased!` it is polled before the I/O future.
                futures::select_biased! {
                    message = next_message => {
                        if let Some(message) = message {
                            let start_time = Instant::now();
                            log::info!("RPC message received: {}", message.payload_type_name());
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                if let Err(err) = (handler)(this.clone(), message).await {
                                    log::error!("error handling message: {:?}", err);
                                } else {
                                    log::info!("RPC message handled. duration:{:?}", start_time.elapsed());
                                }

                                // Signal observers that a message was processed;
                                // a failed send is deliberately ignored.
                                if let Some(mut notifications) = this.notifications.clone() {
                                    let _ = notifications.send(()).await;
                                }
                            } else {
                                log::warn!("unhandled message: {}", message.payload_type_name());
                            }
                        } else {
                            // The incoming stream ended: the connection is closed.
                            log::info!("rpc connection closed {:?}", addr);
                            break;
                        }
                    }
                    handle_io = handle_io => {
                        if let Err(err) = handle_io {
                            log::error!("error handling rpc connection {:?} - {:?}", addr, err);
                        }
                        break;
                    }
                }
            }

            if let Err(err) = this.sign_out(connection_id).await {
                log::error!("error signing out connection {:?} - {:?}", addr, err);
            }
        }
    }
196
197 async fn sign_out(self: &Arc<Self>, connection_id: zrpc::ConnectionId) -> tide::Result<()> {
198 self.peer.disconnect(connection_id).await;
199 let worktree_ids = self.remove_connection(connection_id).await;
200 for worktree_id in worktree_ids {
201 let state = self.state.read().await;
202 if let Some(worktree) = state.worktrees.get(&worktree_id) {
203 broadcast(connection_id, worktree.connection_ids(), |conn_id| {
204 self.peer.send(
205 conn_id,
206 proto::RemovePeer {
207 worktree_id,
208 peer_id: connection_id.0,
209 },
210 )
211 })
212 .await?;
213 }
214 }
215 Ok(())
216 }
217
218 // Add a new connection associated with a given user.
219 async fn add_connection(&self, connection_id: ConnectionId, user_id: UserId) {
220 let mut state = self.state.write().await;
221 state.connections.insert(
222 connection_id,
223 ConnectionState {
224 user_id,
225 worktrees: Default::default(),
226 channels: Default::default(),
227 },
228 );
229 state
230 .connections_by_user_id
231 .entry(user_id)
232 .or_default()
233 .insert(connection_id);
234 }
235
236 // Remove the given connection and its association with any worktrees.
237 async fn remove_connection(&self, connection_id: ConnectionId) -> Vec<u64> {
238 let mut worktree_ids = Vec::new();
239 let mut state = self.state.write().await;
240 if let Some(connection) = state.connections.remove(&connection_id) {
241 for channel_id in connection.channels {
242 if let Some(channel) = state.channels.get_mut(&channel_id) {
243 channel.connection_ids.remove(&connection_id);
244 }
245 }
246 for worktree_id in connection.worktrees {
247 if let Some(worktree) = state.worktrees.get_mut(&worktree_id) {
248 if worktree.host_connection_id == connection_id {
249 worktree_ids.push(worktree_id);
250 } else if let Some(share_state) = worktree.share.as_mut() {
251 if let Some(replica_id) =
252 share_state.guest_connection_ids.remove(&connection_id)
253 {
254 share_state.active_replica_ids.remove(&replica_id);
255 worktree_ids.push(worktree_id);
256 }
257 }
258 }
259 }
260
261 let user_connections = state
262 .connections_by_user_id
263 .get_mut(&connection.user_id)
264 .unwrap();
265 user_connections.remove(&connection_id);
266 if user_connections.is_empty() {
267 state.connections_by_user_id.remove(&connection.user_id);
268 }
269 }
270 worktree_ids
271 }
272
273 async fn ping(self: Arc<Server>, request: TypedEnvelope<proto::Ping>) -> tide::Result<()> {
274 self.peer.respond(request.receipt(), proto::Ack {}).await?;
275 Ok(())
276 }
277
278 async fn open_worktree(
279 self: Arc<Server>,
280 request: TypedEnvelope<proto::OpenWorktree>,
281 ) -> tide::Result<()> {
282 let receipt = request.receipt();
283
284 let mut collaborator_user_ids = Vec::new();
285 for github_login in request.payload.collaborator_logins {
286 match self.app_state.db.create_user(&github_login, false).await {
287 Ok(user_id) => collaborator_user_ids.push(user_id),
288 Err(err) => {
289 let message = err.to_string();
290 self.peer
291 .respond_with_error(receipt, proto::Error { message })
292 .await?;
293 return Ok(());
294 }
295 }
296 }
297
298 let mut state = self.state.write().await;
299 let worktree_id = state.add_worktree(Worktree {
300 host_connection_id: request.sender_id,
301 collaborator_user_ids,
302 root_name: request.payload.root_name,
303 share: None,
304 });
305
306 self.peer
307 .respond(receipt, proto::OpenWorktreeResponse { worktree_id })
308 .await?;
309 Ok(())
310 }
311
312 async fn share_worktree(
313 self: Arc<Server>,
314 mut request: TypedEnvelope<proto::ShareWorktree>,
315 ) -> tide::Result<()> {
316 let worktree = request
317 .payload
318 .worktree
319 .as_mut()
320 .ok_or_else(|| anyhow!("missing worktree"))?;
321 let entries = mem::take(&mut worktree.entries)
322 .into_iter()
323 .map(|entry| (entry.id, entry))
324 .collect();
325 let mut state = self.state.write().await;
326 if let Some(worktree) = state.worktrees.get_mut(&worktree.id) {
327 worktree.share = Some(WorktreeShare {
328 guest_connection_ids: Default::default(),
329 active_replica_ids: Default::default(),
330 entries,
331 });
332 self.peer
333 .respond(request.receipt(), proto::ShareWorktreeResponse {})
334 .await?;
335 } else {
336 self.peer
337 .respond_with_error(
338 request.receipt(),
339 proto::Error {
340 message: "no such worktree".to_string(),
341 },
342 )
343 .await?;
344 }
345 Ok(())
346 }
347
348 async fn unshare_worktree(
349 self: Arc<Server>,
350 request: TypedEnvelope<proto::UnshareWorktree>,
351 ) -> tide::Result<()> {
352 let worktree_id = request.payload.worktree_id;
353
354 let connection_ids;
355 {
356 let mut state = self.state.write().await;
357 let worktree = state.write_worktree(worktree_id, request.sender_id)?;
358 if worktree.host_connection_id != request.sender_id {
359 return Err(anyhow!("no such worktree"))?;
360 }
361
362 connection_ids = worktree.connection_ids();
363 worktree.share.take();
364 for connection_id in &connection_ids {
365 if let Some(connection) = state.connections.get_mut(connection_id) {
366 connection.worktrees.remove(&worktree_id);
367 }
368 }
369 }
370
371 broadcast(request.sender_id, connection_ids, |conn_id| {
372 self.peer
373 .send(conn_id, proto::UnshareWorktree { worktree_id })
374 })
375 .await?;
376
377 Ok(())
378 }
379
380 async fn join_worktree(
381 self: Arc<Server>,
382 request: TypedEnvelope<proto::JoinWorktree>,
383 ) -> tide::Result<()> {
384 let worktree_id = request.payload.worktree_id;
385 let user_id = self
386 .state
387 .read()
388 .await
389 .user_id_for_connection(request.sender_id)?;
390
391 let response;
392 let connection_ids;
393 let mut state = self.state.write().await;
394 match state.join_worktree(request.sender_id, user_id, worktree_id) {
395 Ok((peer_replica_id, worktree)) => {
396 let share = worktree.share()?;
397 let peer_count = share.guest_connection_ids.len();
398 let mut peers = Vec::with_capacity(peer_count);
399 peers.push(proto::Peer {
400 peer_id: worktree.host_connection_id.0,
401 replica_id: 0,
402 });
403 for (peer_conn_id, peer_replica_id) in &share.guest_connection_ids {
404 if *peer_conn_id != request.sender_id {
405 peers.push(proto::Peer {
406 peer_id: peer_conn_id.0,
407 replica_id: *peer_replica_id as u32,
408 });
409 }
410 }
411 connection_ids = worktree.connection_ids();
412 response = proto::JoinWorktreeResponse {
413 worktree: Some(proto::Worktree {
414 id: worktree_id,
415 root_name: worktree.root_name.clone(),
416 entries: share.entries.values().cloned().collect(),
417 }),
418 replica_id: peer_replica_id as u32,
419 peers,
420 };
421 }
422 Err(error) => {
423 self.peer
424 .respond_with_error(
425 request.receipt(),
426 proto::Error {
427 message: error.to_string(),
428 },
429 )
430 .await?;
431 return Ok(());
432 }
433 }
434
435 broadcast(request.sender_id, connection_ids, |conn_id| {
436 self.peer.send(
437 conn_id,
438 proto::AddPeer {
439 worktree_id,
440 peer: Some(proto::Peer {
441 peer_id: request.sender_id.0,
442 replica_id: response.replica_id,
443 }),
444 },
445 )
446 })
447 .await?;
448 self.peer.respond(request.receipt(), response).await?;
449
450 Ok(())
451 }
452
453 async fn close_worktree(
454 self: Arc<Server>,
455 request: TypedEnvelope<proto::CloseWorktree>,
456 ) -> tide::Result<()> {
457 let worktree_id = request.payload.worktree_id;
458 let connection_ids;
459 let mut is_host = false;
460 let mut is_guest = false;
461 {
462 let mut state = self.state.write().await;
463 let worktree = state.write_worktree(worktree_id, request.sender_id)?;
464 connection_ids = worktree.connection_ids();
465
466 if worktree.host_connection_id == request.sender_id {
467 is_host = true;
468 state.remove_worktree(worktree_id);
469 } else {
470 let share = worktree.share_mut()?;
471 if let Some(replica_id) = share.guest_connection_ids.remove(&request.sender_id) {
472 is_guest = true;
473 share.active_replica_ids.remove(&replica_id);
474 }
475 }
476 }
477
478 if is_host {
479 broadcast(request.sender_id, connection_ids, |conn_id| {
480 self.peer
481 .send(conn_id, proto::UnshareWorktree { worktree_id })
482 })
483 .await?;
484 } else if is_guest {
485 broadcast(request.sender_id, connection_ids, |conn_id| {
486 self.peer.send(
487 conn_id,
488 proto::RemovePeer {
489 worktree_id,
490 peer_id: request.sender_id.0,
491 },
492 )
493 })
494 .await?
495 }
496
497 Ok(())
498 }
499
500 async fn update_worktree(
501 self: Arc<Server>,
502 request: TypedEnvelope<proto::UpdateWorktree>,
503 ) -> tide::Result<()> {
504 {
505 let mut state = self.state.write().await;
506 let worktree = state.write_worktree(request.payload.worktree_id, request.sender_id)?;
507 let share = worktree.share_mut()?;
508
509 for entry_id in &request.payload.removed_entries {
510 share.entries.remove(&entry_id);
511 }
512
513 for entry in &request.payload.updated_entries {
514 share.entries.insert(entry.id, entry.clone());
515 }
516 }
517
518 self.broadcast_in_worktree(request.payload.worktree_id, &request)
519 .await?;
520 Ok(())
521 }
522
523 async fn open_buffer(
524 self: Arc<Server>,
525 request: TypedEnvelope<proto::OpenBuffer>,
526 ) -> tide::Result<()> {
527 let receipt = request.receipt();
528 let worktree_id = request.payload.worktree_id;
529 let host_connection_id = self
530 .state
531 .read()
532 .await
533 .read_worktree(worktree_id, request.sender_id)?
534 .host_connection_id;
535
536 let response = self
537 .peer
538 .forward_request(request.sender_id, host_connection_id, request.payload)
539 .await?;
540 self.peer.respond(receipt, response).await?;
541 Ok(())
542 }
543
544 async fn close_buffer(
545 self: Arc<Server>,
546 request: TypedEnvelope<proto::CloseBuffer>,
547 ) -> tide::Result<()> {
548 let host_connection_id = self
549 .state
550 .read()
551 .await
552 .read_worktree(request.payload.worktree_id, request.sender_id)?
553 .host_connection_id;
554
555 self.peer
556 .forward_send(request.sender_id, host_connection_id, request.payload)
557 .await?;
558
559 Ok(())
560 }
561
562 async fn save_buffer(
563 self: Arc<Server>,
564 request: TypedEnvelope<proto::SaveBuffer>,
565 ) -> tide::Result<()> {
566 let host;
567 let guests;
568 {
569 let state = self.state.read().await;
570 let worktree = state.read_worktree(request.payload.worktree_id, request.sender_id)?;
571 host = worktree.host_connection_id;
572 guests = worktree
573 .share()?
574 .guest_connection_ids
575 .keys()
576 .copied()
577 .collect::<Vec<_>>();
578 }
579
580 let sender = request.sender_id;
581 let receipt = request.receipt();
582 let response = self
583 .peer
584 .forward_request(sender, host, request.payload.clone())
585 .await?;
586
587 broadcast(host, guests, |conn_id| {
588 let response = response.clone();
589 let peer = &self.peer;
590 async move {
591 if conn_id == sender {
592 peer.respond(receipt, response).await
593 } else {
594 peer.forward_send(host, conn_id, response).await
595 }
596 }
597 })
598 .await?;
599
600 Ok(())
601 }
602
603 async fn update_buffer(
604 self: Arc<Server>,
605 request: TypedEnvelope<proto::UpdateBuffer>,
606 ) -> tide::Result<()> {
607 self.broadcast_in_worktree(request.payload.worktree_id, &request)
608 .await?;
609 self.peer.respond(request.receipt(), proto::Ack {}).await?;
610 Ok(())
611 }
612
613 async fn buffer_saved(
614 self: Arc<Server>,
615 request: TypedEnvelope<proto::BufferSaved>,
616 ) -> tide::Result<()> {
617 self.broadcast_in_worktree(request.payload.worktree_id, &request)
618 .await
619 }
620
621 async fn get_channels(
622 self: Arc<Server>,
623 request: TypedEnvelope<proto::GetChannels>,
624 ) -> tide::Result<()> {
625 let user_id = self
626 .state
627 .read()
628 .await
629 .user_id_for_connection(request.sender_id)?;
630 let channels = self.app_state.db.get_accessible_channels(user_id).await?;
631 self.peer
632 .respond(
633 request.receipt(),
634 proto::GetChannelsResponse {
635 channels: channels
636 .into_iter()
637 .map(|chan| proto::Channel {
638 id: chan.id.to_proto(),
639 name: chan.name,
640 })
641 .collect(),
642 },
643 )
644 .await?;
645 Ok(())
646 }
647
648 async fn get_users(
649 self: Arc<Server>,
650 request: TypedEnvelope<proto::GetUsers>,
651 ) -> tide::Result<()> {
652 let user_id = self
653 .state
654 .read()
655 .await
656 .user_id_for_connection(request.sender_id)?;
657 let receipt = request.receipt();
658 let user_ids = request.payload.user_ids.into_iter().map(UserId::from_proto);
659 let users = self
660 .app_state
661 .db
662 .get_users_by_ids(user_id, user_ids)
663 .await?
664 .into_iter()
665 .map(|user| proto::User {
666 id: user.id.to_proto(),
667 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
668 github_login: user.github_login,
669 })
670 .collect();
671 self.peer
672 .respond(receipt, proto::GetUsersResponse { users })
673 .await?;
674 Ok(())
675 }
676
677 async fn get_collaborators(
678 self: Arc<Server>,
679 request: TypedEnvelope<proto::GetCollaborators>,
680 ) -> tide::Result<()> {
681 let mut collaborators = HashMap::new();
682 {
683 let state = self.state.read().await;
684 let user_id = state.user_id_for_connection(request.sender_id)?;
685 for worktree_id in state
686 .visible_worktrees_by_user_id
687 .get(&user_id)
688 .unwrap_or(&HashSet::new())
689 {
690 let worktree = &state.worktrees[worktree_id];
691
692 let mut participants = Vec::new();
693 for collaborator_user_id in &worktree.collaborator_user_ids {
694 collaborators
695 .entry(*collaborator_user_id)
696 .or_insert_with(|| proto::Collaborator {
697 user_id: collaborator_user_id.to_proto(),
698 worktrees: Vec::new(),
699 is_online: state.is_online(*collaborator_user_id),
700 });
701
702 if let Ok(share) = worktree.share() {
703 let mut conn_ids = state.user_connection_ids(*collaborator_user_id);
704 if conn_ids.any(|c| share.guest_connection_ids.contains_key(&c)) {
705 participants.push(collaborator_user_id.to_proto());
706 }
707 }
708 }
709
710 let host_user_id = state.user_id_for_connection(worktree.host_connection_id)?;
711 let host =
712 collaborators
713 .entry(host_user_id)
714 .or_insert_with(|| proto::Collaborator {
715 user_id: host_user_id.to_proto(),
716 worktrees: Vec::new(),
717 is_online: true,
718 });
719 host.worktrees.push(proto::CollaboratorWorktree {
720 is_shared: worktree.share().is_ok(),
721 participants,
722 });
723 }
724 }
725
726 self.peer
727 .respond(
728 request.receipt(),
729 proto::GetCollaboratorsResponse {
730 collaborators: collaborators.into_values().collect(),
731 },
732 )
733 .await?;
734 Ok(())
735 }
736
737 async fn join_channel(
738 self: Arc<Self>,
739 request: TypedEnvelope<proto::JoinChannel>,
740 ) -> tide::Result<()> {
741 let user_id = self
742 .state
743 .read()
744 .await
745 .user_id_for_connection(request.sender_id)?;
746 let channel_id = ChannelId::from_proto(request.payload.channel_id);
747 if !self
748 .app_state
749 .db
750 .can_user_access_channel(user_id, channel_id)
751 .await?
752 {
753 Err(anyhow!("access denied"))?;
754 }
755
756 self.state
757 .write()
758 .await
759 .join_channel(request.sender_id, channel_id);
760 let messages = self
761 .app_state
762 .db
763 .get_channel_messages(channel_id, MESSAGE_COUNT_PER_PAGE, None)
764 .await?
765 .into_iter()
766 .map(|msg| proto::ChannelMessage {
767 id: msg.id.to_proto(),
768 body: msg.body,
769 timestamp: msg.sent_at.unix_timestamp() as u64,
770 sender_id: msg.sender_id.to_proto(),
771 nonce: Some(msg.nonce.as_u128().into()),
772 })
773 .collect::<Vec<_>>();
774 self.peer
775 .respond(
776 request.receipt(),
777 proto::JoinChannelResponse {
778 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
779 messages,
780 },
781 )
782 .await?;
783 Ok(())
784 }
785
786 async fn leave_channel(
787 self: Arc<Self>,
788 request: TypedEnvelope<proto::LeaveChannel>,
789 ) -> tide::Result<()> {
790 let user_id = self
791 .state
792 .read()
793 .await
794 .user_id_for_connection(request.sender_id)?;
795 let channel_id = ChannelId::from_proto(request.payload.channel_id);
796 if !self
797 .app_state
798 .db
799 .can_user_access_channel(user_id, channel_id)
800 .await?
801 {
802 Err(anyhow!("access denied"))?;
803 }
804
805 self.state
806 .write()
807 .await
808 .leave_channel(request.sender_id, channel_id);
809
810 Ok(())
811 }
812
813 async fn send_channel_message(
814 self: Arc<Self>,
815 request: TypedEnvelope<proto::SendChannelMessage>,
816 ) -> tide::Result<()> {
817 let receipt = request.receipt();
818 let channel_id = ChannelId::from_proto(request.payload.channel_id);
819 let user_id;
820 let connection_ids;
821 {
822 let state = self.state.read().await;
823 user_id = state.user_id_for_connection(request.sender_id)?;
824 if let Some(channel) = state.channels.get(&channel_id) {
825 connection_ids = channel.connection_ids();
826 } else {
827 return Ok(());
828 }
829 }
830
831 // Validate the message body.
832 let body = request.payload.body.trim().to_string();
833 if body.len() > MAX_MESSAGE_LEN {
834 self.peer
835 .respond_with_error(
836 receipt,
837 proto::Error {
838 message: "message is too long".to_string(),
839 },
840 )
841 .await?;
842 return Ok(());
843 }
844 if body.is_empty() {
845 self.peer
846 .respond_with_error(
847 receipt,
848 proto::Error {
849 message: "message can't be blank".to_string(),
850 },
851 )
852 .await?;
853 return Ok(());
854 }
855
856 let timestamp = OffsetDateTime::now_utc();
857 let nonce = if let Some(nonce) = request.payload.nonce {
858 nonce
859 } else {
860 self.peer
861 .respond_with_error(
862 receipt,
863 proto::Error {
864 message: "nonce can't be blank".to_string(),
865 },
866 )
867 .await?;
868 return Ok(());
869 };
870
871 let message_id = self
872 .app_state
873 .db
874 .create_channel_message(channel_id, user_id, &body, timestamp, nonce.clone().into())
875 .await?
876 .to_proto();
877 let message = proto::ChannelMessage {
878 sender_id: user_id.to_proto(),
879 id: message_id,
880 body,
881 timestamp: timestamp.unix_timestamp() as u64,
882 nonce: Some(nonce),
883 };
884 broadcast(request.sender_id, connection_ids, |conn_id| {
885 self.peer.send(
886 conn_id,
887 proto::ChannelMessageSent {
888 channel_id: channel_id.to_proto(),
889 message: Some(message.clone()),
890 },
891 )
892 })
893 .await?;
894 self.peer
895 .respond(
896 receipt,
897 proto::SendChannelMessageResponse {
898 message: Some(message),
899 },
900 )
901 .await?;
902 Ok(())
903 }
904
905 async fn get_channel_messages(
906 self: Arc<Self>,
907 request: TypedEnvelope<proto::GetChannelMessages>,
908 ) -> tide::Result<()> {
909 let user_id = self
910 .state
911 .read()
912 .await
913 .user_id_for_connection(request.sender_id)?;
914 let channel_id = ChannelId::from_proto(request.payload.channel_id);
915 if !self
916 .app_state
917 .db
918 .can_user_access_channel(user_id, channel_id)
919 .await?
920 {
921 Err(anyhow!("access denied"))?;
922 }
923
924 let messages = self
925 .app_state
926 .db
927 .get_channel_messages(
928 channel_id,
929 MESSAGE_COUNT_PER_PAGE,
930 Some(MessageId::from_proto(request.payload.before_message_id)),
931 )
932 .await?
933 .into_iter()
934 .map(|msg| proto::ChannelMessage {
935 id: msg.id.to_proto(),
936 body: msg.body,
937 timestamp: msg.sent_at.unix_timestamp() as u64,
938 sender_id: msg.sender_id.to_proto(),
939 nonce: Some(msg.nonce.as_u128().into()),
940 })
941 .collect::<Vec<_>>();
942 self.peer
943 .respond(
944 request.receipt(),
945 proto::GetChannelMessagesResponse {
946 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
947 messages,
948 },
949 )
950 .await?;
951 Ok(())
952 }
953
954 async fn broadcast_in_worktree<T: proto::EnvelopedMessage>(
955 &self,
956 worktree_id: u64,
957 message: &TypedEnvelope<T>,
958 ) -> tide::Result<()> {
959 let connection_ids = self
960 .state
961 .read()
962 .await
963 .read_worktree(worktree_id, message.sender_id)?
964 .connection_ids();
965
966 broadcast(message.sender_id, connection_ids, |conn_id| {
967 self.peer
968 .forward_send(message.sender_id, conn_id, message.payload.clone())
969 })
970 .await?;
971
972 Ok(())
973 }
974}
975
976pub async fn broadcast<F, T>(
977 sender_id: ConnectionId,
978 receiver_ids: Vec<ConnectionId>,
979 mut f: F,
980) -> anyhow::Result<()>
981where
982 F: FnMut(ConnectionId) -> T,
983 T: Future<Output = anyhow::Result<()>>,
984{
985 let futures = receiver_ids
986 .into_iter()
987 .filter(|id| *id != sender_id)
988 .map(|id| f(id));
989 futures::future::try_join_all(futures).await?;
990 Ok(())
991}
992
993impl ServerState {
994 fn join_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
995 if let Some(connection) = self.connections.get_mut(&connection_id) {
996 connection.channels.insert(channel_id);
997 self.channels
998 .entry(channel_id)
999 .or_default()
1000 .connection_ids
1001 .insert(connection_id);
1002 }
1003 }
1004
1005 fn leave_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
1006 if let Some(connection) = self.connections.get_mut(&connection_id) {
1007 connection.channels.remove(&channel_id);
1008 if let hash_map::Entry::Occupied(mut entry) = self.channels.entry(channel_id) {
1009 entry.get_mut().connection_ids.remove(&connection_id);
1010 if entry.get_mut().connection_ids.is_empty() {
1011 entry.remove();
1012 }
1013 }
1014 }
1015 }
1016
1017 fn user_id_for_connection(&self, connection_id: ConnectionId) -> tide::Result<UserId> {
1018 Ok(self
1019 .connections
1020 .get(&connection_id)
1021 .ok_or_else(|| anyhow!("unknown connection"))?
1022 .user_id)
1023 }
1024
1025 fn user_connection_ids<'a>(
1026 &'a self,
1027 user_id: UserId,
1028 ) -> impl 'a + Iterator<Item = ConnectionId> {
1029 self.connections_by_user_id
1030 .get(&user_id)
1031 .into_iter()
1032 .flatten()
1033 .copied()
1034 }
1035
1036 fn is_online(&self, user_id: UserId) -> bool {
1037 self.connections_by_user_id.contains_key(&user_id)
1038 }
1039
1040 // Add the given connection as a guest of the given worktree
1041 fn join_worktree(
1042 &mut self,
1043 connection_id: ConnectionId,
1044 user_id: UserId,
1045 worktree_id: u64,
1046 ) -> tide::Result<(ReplicaId, &Worktree)> {
1047 let connection = self
1048 .connections
1049 .get_mut(&connection_id)
1050 .ok_or_else(|| anyhow!("no such connection"))?;
1051 let worktree = self
1052 .worktrees
1053 .get_mut(&worktree_id)
1054 .ok_or_else(|| anyhow!("no such worktree"))?;
1055 if !worktree.collaborator_user_ids.contains(&user_id) {
1056 Err(anyhow!("no such worktree"))?;
1057 }
1058
1059 let share = worktree.share_mut()?;
1060 connection.worktrees.insert(worktree_id);
1061
1062 let mut replica_id = 1;
1063 while share.active_replica_ids.contains(&replica_id) {
1064 replica_id += 1;
1065 }
1066 share.active_replica_ids.insert(replica_id);
1067 share.guest_connection_ids.insert(connection_id, replica_id);
1068 return Ok((replica_id, worktree));
1069 }
1070
1071 fn read_worktree(
1072 &self,
1073 worktree_id: u64,
1074 connection_id: ConnectionId,
1075 ) -> tide::Result<&Worktree> {
1076 let worktree = self
1077 .worktrees
1078 .get(&worktree_id)
1079 .ok_or_else(|| anyhow!("worktree not found"))?;
1080
1081 if worktree.host_connection_id == connection_id
1082 || worktree
1083 .share()?
1084 .guest_connection_ids
1085 .contains_key(&connection_id)
1086 {
1087 Ok(worktree)
1088 } else {
1089 Err(anyhow!(
1090 "{} is not a member of worktree {}",
1091 connection_id,
1092 worktree_id
1093 ))?
1094 }
1095 }
1096
1097 fn write_worktree(
1098 &mut self,
1099 worktree_id: u64,
1100 connection_id: ConnectionId,
1101 ) -> tide::Result<&mut Worktree> {
1102 let worktree = self
1103 .worktrees
1104 .get_mut(&worktree_id)
1105 .ok_or_else(|| anyhow!("worktree not found"))?;
1106
1107 if worktree.host_connection_id == connection_id
1108 || worktree.share.as_ref().map_or(false, |share| {
1109 share.guest_connection_ids.contains_key(&connection_id)
1110 })
1111 {
1112 Ok(worktree)
1113 } else {
1114 Err(anyhow!(
1115 "{} is not a member of worktree {}",
1116 connection_id,
1117 worktree_id
1118 ))?
1119 }
1120 }
1121
1122 fn add_worktree(&mut self, worktree: Worktree) -> u64 {
1123 let worktree_id = self.next_worktree_id;
1124 for collaborator_user_id in &worktree.collaborator_user_ids {
1125 self.visible_worktrees_by_user_id
1126 .entry(*collaborator_user_id)
1127 .or_default()
1128 .insert(worktree_id);
1129 }
1130 self.next_worktree_id += 1;
1131 self.worktrees.insert(worktree_id, worktree);
1132 worktree_id
1133 }
1134
1135 fn remove_worktree(&mut self, worktree_id: u64) {
1136 let worktree = self.worktrees.remove(&worktree_id).unwrap();
1137 if let Some(connection) = self.connections.get_mut(&worktree.host_connection_id) {
1138 connection.worktrees.remove(&worktree_id);
1139 }
1140 if let Some(share) = worktree.share {
1141 for connection_id in share.guest_connection_ids.keys() {
1142 if let Some(connection) = self.connections.get_mut(connection_id) {
1143 connection.worktrees.remove(&worktree_id);
1144 }
1145 }
1146 }
1147 for collaborator_user_id in worktree.collaborator_user_ids {
1148 if let Some(visible_worktrees) = self
1149 .visible_worktrees_by_user_id
1150 .get_mut(&collaborator_user_id)
1151 {
1152 visible_worktrees.remove(&worktree_id);
1153 }
1154 }
1155 }
1156}
1157
1158impl Worktree {
1159 pub fn connection_ids(&self) -> Vec<ConnectionId> {
1160 if let Some(share) = &self.share {
1161 share
1162 .guest_connection_ids
1163 .keys()
1164 .copied()
1165 .chain(Some(self.host_connection_id))
1166 .collect()
1167 } else {
1168 vec![self.host_connection_id]
1169 }
1170 }
1171
1172 fn share(&self) -> tide::Result<&WorktreeShare> {
1173 Ok(self
1174 .share
1175 .as_ref()
1176 .ok_or_else(|| anyhow!("worktree is not shared"))?)
1177 }
1178
1179 fn share_mut(&mut self) -> tide::Result<&mut WorktreeShare> {
1180 Ok(self
1181 .share
1182 .as_mut()
1183 .ok_or_else(|| anyhow!("worktree is not shared"))?)
1184 }
1185}
1186
1187impl Channel {
1188 fn connection_ids(&self) -> Vec<ConnectionId> {
1189 self.connection_ids.iter().copied().collect()
1190 }
1191}
1192
/// Mounts the RPC endpoint at `GET /rpc`, behind the `auth::VerifyToken`
/// middleware.
///
/// A request that asks for a WebSocket upgrade is answered with
/// `101 Switching Protocols`, and a background task serves the upgraded
/// stream through `Server::handle_connection`. A request that does not ask
/// for an upgrade is answered with `426 Upgrade Required`.
pub fn add_routes(app: &mut tide::Server<Arc<AppState>>, rpc: &Arc<Peer>) {
    let server = Server::new(app.state().clone(), rpc.clone(), None);
    app.at("/rpc").with(auth::VerifyToken).get(move |request: Request<Arc<AppState>>| {
        // The user id is read from the request extensions; the error message
        // further down expects `auth::VerifyToken` to have put it there.
        let user_id = request.ext::<UserId>().copied();
        let server = server.clone();
        async move {
            // Fixed GUID that RFC 6455 appends to the client's key when
            // computing the `Sec-WebSocket-Accept` value.
            const WEBSOCKET_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

            let connection_upgrade = header_contains_ignore_case(&request, CONNECTION, "upgrade");
            let upgrade_to_websocket = header_contains_ignore_case(&request, UPGRADE, "websocket");
            let upgrade_requested = connection_upgrade && upgrade_to_websocket;

            if !upgrade_requested {
                return Ok(Response::new(StatusCode::UpgradeRequired));
            }

            let header = match request.header("Sec-Websocket-Key") {
                Some(h) => h.as_str(),
                None => return Err(anyhow!("expected sec-websocket-key"))?,
            };

            let mut response = Response::new(StatusCode::SwitchingProtocols);
            response.insert_header(UPGRADE, "websocket");
            response.insert_header(CONNECTION, "Upgrade");
            // Accept value: base64(SHA-1(client key + GUID)).
            let hash = Sha1::new().chain(header).chain(WEBSOCKET_GUID).finalize();
            response.insert_header("Sec-Websocket-Accept", base64::encode(&hash[..]));
            response.insert_header("Sec-Websocket-Version", "13");

            let http_res: &mut tide::http::Response = response.as_mut();
            let upgrade_receiver = http_res.recv_upgrade().await;
            let addr = request.remote().unwrap_or("unknown").to_string();
            let user_id = user_id.ok_or_else(|| anyhow!("user_id is not present on request. ensure auth::VerifyToken middleware is present"))?;
            // Serve the connection on its own task once the upgraded stream
            // becomes available; the HTTP response is returned right away.
            task::spawn(async move {
                if let Some(stream) = upgrade_receiver.await {
                    server.handle_connection(Connection::new(WebSocketStream::from_raw_socket(stream, Role::Server, None).await), addr, user_id).await;
                }
            });

            Ok(response)
        }
    });
}
1235
1236fn header_contains_ignore_case<T>(
1237 request: &tide::Request<T>,
1238 header_name: HeaderName,
1239 value: &str,
1240) -> bool {
1241 request
1242 .header(header_name)
1243 .map(|h| {
1244 h.as_str()
1245 .split(',')
1246 .any(|s| s.trim().eq_ignore_ascii_case(value.trim()))
1247 })
1248 .unwrap_or(false)
1249}
1250
1251#[cfg(test)]
1252mod tests {
1253 use super::*;
1254 use crate::{
1255 auth,
1256 db::{tests::TestDb, UserId},
1257 github, AppState, Config,
1258 };
1259 use async_std::{sync::RwLockReadGuard, task};
1260 use gpui::TestAppContext;
1261 use parking_lot::Mutex;
1262 use postage::{mpsc, watch};
1263 use serde_json::json;
1264 use sqlx::types::time::OffsetDateTime;
1265 use std::{
1266 path::Path,
1267 sync::{
1268 atomic::{AtomicBool, Ordering::SeqCst},
1269 Arc,
1270 },
1271 time::Duration,
1272 };
1273 use zed::{
1274 channel::{Channel, ChannelDetails, ChannelList},
1275 editor::{Editor, Insert},
1276 fs::{FakeFs, Fs as _},
1277 language::LanguageRegistry,
1278 rpc::{self, Client, Credentials, EstablishConnectionError},
1279 settings,
1280 test::FakeHttpClient,
1281 user::UserStore,
1282 worktree::Worktree,
1283 };
1284 use zrpc::Peer;
1285
    /// Exercises sharing a worktree between a host (client A) and a guest
    /// (client B): the guest joins, both open the same buffer, selections and
    /// edits made by the guest become visible on the host, and the guest is
    /// removed from the host's peers once it drops the worktree.
    #[gpui::test]
    async fn test_share_worktree(mut cx_a: TestAppContext, mut cx_b: TestAppContext) {
        let (window_b, _) = cx_b.add_window(|_| EmptyView);
        let settings = cx_b.read(settings::test).1;
        let lang_registry = Arc::new(LanguageRegistry::new());

        // Connect to a server as 2 clients.
        let mut server = TestServer::start().await;
        let (client_a, _) = server.create_client(&mut cx_a, "user_a").await;
        let (client_b, _) = server.create_client(&mut cx_b, "user_b").await;

        cx_a.foreground().forbid_parking();

        // Share a local worktree as client A
        let fs = Arc::new(FakeFs::new());
        fs.insert_tree(
            "/a",
            json!({
                ".zed.toml": r#"collaborators = ["user_b"]"#,
                "a.txt": "a-contents",
                "b.txt": "b-contents",
            }),
        )
        .await;
        let worktree_a = Worktree::open_local(
            client_a.clone(),
            "/a".as_ref(),
            fs,
            lang_registry.clone(),
            &mut cx_a.to_async(),
        )
        .await
        .unwrap();
        // Wait for the initial scan to finish before sharing.
        worktree_a
            .read_with(&cx_a, |tree, _| tree.as_local().unwrap().scan_complete())
            .await;
        let worktree_id = worktree_a
            .update(&mut cx_a, |tree, cx| tree.as_local_mut().unwrap().share(cx))
            .await
            .unwrap();

        // Join that worktree as client B, and see that a guest has joined as client A.
        let worktree_b = Worktree::open_remote(
            client_b.clone(),
            worktree_id,
            lang_registry.clone(),
            &mut cx_b.to_async(),
        )
        .await
        .unwrap();
        let replica_id_b = worktree_b.read_with(&cx_b, |tree, _| tree.replica_id());
        worktree_a
            .condition(&cx_a, |tree, _| {
                tree.peers()
                    .values()
                    .any(|replica_id| *replica_id == replica_id_b)
            })
            .await;

        // Open the same file as client B and client A.
        let buffer_b = worktree_b
            .update(&mut cx_b, |worktree, cx| worktree.open_buffer("b.txt", cx))
            .await
            .unwrap();
        buffer_b.read_with(&cx_b, |buf, _| assert_eq!(buf.text(), "b-contents"));
        // The guest opening the buffer is expected to have opened it on the host too.
        worktree_a.read_with(&cx_a, |tree, cx| assert!(tree.has_open_buffer("b.txt", cx)));
        let buffer_a = worktree_a
            .update(&mut cx_a, |tree, cx| tree.open_buffer("b.txt", cx))
            .await
            .unwrap();

        // Create a selection set as client B and see that selection set as client A.
        let editor_b = cx_b.add_view(window_b, |cx| Editor::for_buffer(buffer_b, settings, cx));
        buffer_a
            .condition(&cx_a, |buffer, _| buffer.selection_sets().count() == 1)
            .await;

        // Edit the buffer as client B and see that edit as client A.
        editor_b.update(&mut cx_b, |editor, cx| {
            editor.insert(&Insert("ok, ".into()), cx)
        });
        buffer_a
            .condition(&cx_a, |buffer, _| buffer.text() == "ok, b-contents")
            .await;

        // Remove the selection set as client B, see those selections disappear as client A.
        cx_b.update(move |_| drop(editor_b));
        buffer_a
            .condition(&cx_a, |buffer, _| buffer.selection_sets().count() == 0)
            .await;

        // Close the buffer as client A, see that the buffer is closed.
        cx_a.update(move |_| drop(buffer_a));
        worktree_a
            .condition(&cx_a, |tree, cx| !tree.has_open_buffer("b.txt", cx))
            .await;

        // Dropping the worktree removes client B from client A's peers.
        cx_b.update(move |_| drop(worktree_b));
        worktree_a
            .condition(&cx_a, |tree, _| tree.peers().is_empty())
            .await;
    }
1389
    /// Verifies, with one host (A) and two guests (B and C) on a shared
    /// worktree, that concurrent edits converge to the same text on all three
    /// replicas, that a save issued by a guest writes the host's file and
    /// clears the dirty state everywhere, and that file-system changes on the
    /// host show up in the guests' worktrees.
    #[gpui::test]
    async fn test_propagate_saves_and_fs_changes_in_shared_worktree(
        mut cx_a: TestAppContext,
        mut cx_b: TestAppContext,
        mut cx_c: TestAppContext,
    ) {
        cx_a.foreground().forbid_parking();
        let lang_registry = Arc::new(LanguageRegistry::new());

        // Connect to a server as 3 clients.
        let mut server = TestServer::start().await;
        let (client_a, _) = server.create_client(&mut cx_a, "user_a").await;
        let (client_b, _) = server.create_client(&mut cx_b, "user_b").await;
        let (client_c, _) = server.create_client(&mut cx_c, "user_c").await;

        let fs = Arc::new(FakeFs::new());

        // Share a worktree as client A.
        fs.insert_tree(
            "/a",
            json!({
                ".zed.toml": r#"collaborators = ["user_b", "user_c"]"#,
                "file1": "",
                "file2": ""
            }),
        )
        .await;

        let worktree_a = Worktree::open_local(
            client_a.clone(),
            "/a".as_ref(),
            fs.clone(),
            lang_registry.clone(),
            &mut cx_a.to_async(),
        )
        .await
        .unwrap();
        // Wait for the initial scan to finish before sharing.
        worktree_a
            .read_with(&cx_a, |tree, _| tree.as_local().unwrap().scan_complete())
            .await;
        let worktree_id = worktree_a
            .update(&mut cx_a, |tree, cx| tree.as_local_mut().unwrap().share(cx))
            .await
            .unwrap();

        // Join that worktree as clients B and C.
        let worktree_b = Worktree::open_remote(
            client_b.clone(),
            worktree_id,
            lang_registry.clone(),
            &mut cx_b.to_async(),
        )
        .await
        .unwrap();
        let worktree_c = Worktree::open_remote(
            client_c.clone(),
            worktree_id,
            lang_registry.clone(),
            &mut cx_c.to_async(),
        )
        .await
        .unwrap();

        // Open and edit a buffer as both guests B and C.
        let buffer_b = worktree_b
            .update(&mut cx_b, |tree, cx| tree.open_buffer("file1", cx))
            .await
            .unwrap();
        let buffer_c = worktree_c
            .update(&mut cx_c, |tree, cx| tree.open_buffer("file1", cx))
            .await
            .unwrap();
        buffer_b.update(&mut cx_b, |buf, cx| buf.edit([0..0], "i-am-b, ", cx));
        buffer_c.update(&mut cx_c, |buf, cx| buf.edit([0..0], "i-am-c, ", cx));

        // Open and edit that buffer as the host.
        let buffer_a = worktree_a
            .update(&mut cx_a, |tree, cx| tree.open_buffer("file1", cx))
            .await
            .unwrap();

        // Both guests inserted at offset 0; wait until the host has received both edits.
        buffer_a
            .condition(&mut cx_a, |buf, _| buf.text() == "i-am-c, i-am-b, ")
            .await;
        buffer_a.update(&mut cx_a, |buf, cx| {
            buf.edit([buf.len()..buf.len()], "i-am-a", cx)
        });

        // Wait for edits to propagate
        buffer_a
            .condition(&mut cx_a, |buf, _| buf.text() == "i-am-c, i-am-b, i-am-a")
            .await;
        buffer_b
            .condition(&mut cx_b, |buf, _| buf.text() == "i-am-c, i-am-b, i-am-a")
            .await;
        buffer_c
            .condition(&mut cx_c, |buf, _| buf.text() == "i-am-c, i-am-b, i-am-a")
            .await;

        // Edit the buffer as the host and concurrently save as guest B.
        let save_b = buffer_b.update(&mut cx_b, |buf, cx| buf.save(cx).unwrap());
        buffer_a.update(&mut cx_a, |buf, cx| buf.edit([0..0], "hi-a, ", cx));
        save_b.await.unwrap();
        // The file written on the host is expected to include the host's concurrent edit.
        assert_eq!(
            fs.load("/a/file1".as_ref()).await.unwrap(),
            "hi-a, i-am-c, i-am-b, i-am-a"
        );
        buffer_a.read_with(&cx_a, |buf, _| assert!(!buf.is_dirty()));
        buffer_b.read_with(&cx_b, |buf, _| assert!(!buf.is_dirty()));
        buffer_c.condition(&cx_c, |buf, _| !buf.is_dirty()).await;

        // Make changes on host's file system, see those changes on the guests.
        fs.rename("/a/file2".as_ref(), "/a/file3".as_ref())
            .await
            .unwrap();
        fs.insert_file(Path::new("/a/file4"), "4".into())
            .await
            .unwrap();

        worktree_b
            .condition(&cx_b, |tree, _| tree.file_count() == 4)
            .await;
        worktree_c
            .condition(&cx_c, |tree, _| tree.file_count() == 4)
            .await;
        worktree_b.read_with(&cx_b, |tree, _| {
            assert_eq!(
                tree.paths()
                    .map(|p| p.to_string_lossy())
                    .collect::<Vec<_>>(),
                &[".zed.toml", "file1", "file3", "file4"]
            )
        });
        worktree_c.read_with(&cx_c, |tree, _| {
            assert_eq!(
                tree.paths()
                    .map(|p| p.to_string_lossy())
                    .collect::<Vec<_>>(),
                &[".zed.toml", "file1", "file3", "file4"]
            )
        });
    }
1532
    /// Verifies that a guest's buffer is neither dirty nor in conflict after
    /// it is saved, and that a further edit after the save makes it dirty
    /// again without reporting a conflict.
    #[gpui::test]
    async fn test_buffer_conflict_after_save(mut cx_a: TestAppContext, mut cx_b: TestAppContext) {
        cx_a.foreground().forbid_parking();
        let lang_registry = Arc::new(LanguageRegistry::new());

        // Connect to a server as 2 clients.
        let mut server = TestServer::start().await;
        let (client_a, _) = server.create_client(&mut cx_a, "user_a").await;
        let (client_b, _) = server.create_client(&mut cx_b, "user_b").await;

        // Share a local worktree as client A
        // NOTE(review): "user_c" is listed as a collaborator although this test never
        // creates such a client — confirm that this is intentional.
        let fs = Arc::new(FakeFs::new());
        fs.insert_tree(
            "/dir",
            json!({
                ".zed.toml": r#"collaborators = ["user_b", "user_c"]"#,
                "a.txt": "a-contents",
            }),
        )
        .await;

        let worktree_a = Worktree::open_local(
            client_a.clone(),
            "/dir".as_ref(),
            fs,
            lang_registry.clone(),
            &mut cx_a.to_async(),
        )
        .await
        .unwrap();
        // Wait for the initial scan to finish before sharing.
        worktree_a
            .read_with(&cx_a, |tree, _| tree.as_local().unwrap().scan_complete())
            .await;
        let worktree_id = worktree_a
            .update(&mut cx_a, |tree, cx| tree.as_local_mut().unwrap().share(cx))
            .await
            .unwrap();

        // Join that worktree as client B.
        let worktree_b = Worktree::open_remote(
            client_b.clone(),
            worktree_id,
            lang_registry.clone(),
            &mut cx_b.to_async(),
        )
        .await
        .unwrap();

        let buffer_b = worktree_b
            .update(&mut cx_b, |worktree, cx| worktree.open_buffer("a.txt", cx))
            .await
            .unwrap();
        // Remember the file's mtime so the save can be observed below.
        let mtime = buffer_b.read_with(&cx_b, |buf, _| buf.file().unwrap().mtime);

        buffer_b.update(&mut cx_b, |buf, cx| buf.edit([0..0], "world ", cx));
        buffer_b.read_with(&cx_b, |buf, _| {
            assert!(buf.is_dirty());
            assert!(!buf.has_conflict());
        });

        buffer_b
            .update(&mut cx_b, |buf, cx| buf.save(cx))
            .unwrap()
            .await
            .unwrap();
        // Wait until the guest sees the file's mtime change.
        worktree_b
            .condition(&cx_b, |_, cx| {
                buffer_b.read(cx).file().unwrap().mtime != mtime
            })
            .await;
        buffer_b.read_with(&cx_b, |buf, _| {
            assert!(!buf.is_dirty());
            assert!(!buf.has_conflict());
        });

        // Editing again after the save must mark the buffer dirty but not conflicted.
        buffer_b.update(&mut cx_b, |buf, cx| buf.edit([0..0], "hello ", cx));
        buffer_b.read_with(&cx_b, |buf, _| {
            assert!(buf.is_dirty());
            assert!(!buf.has_conflict());
        });
    }
1614
    /// Verifies that an edit made by the host while a guest's request to open
    /// the same buffer is still in flight eventually reaches the guest, so
    /// that both replicas end up with the same text.
    #[gpui::test]
    async fn test_editing_while_guest_opens_buffer(
        mut cx_a: TestAppContext,
        mut cx_b: TestAppContext,
    ) {
        cx_a.foreground().forbid_parking();
        let lang_registry = Arc::new(LanguageRegistry::new());

        // Connect to a server as 2 clients.
        let mut server = TestServer::start().await;
        let (client_a, _) = server.create_client(&mut cx_a, "user_a").await;
        let (client_b, _) = server.create_client(&mut cx_b, "user_b").await;

        // Share a local worktree as client A
        let fs = Arc::new(FakeFs::new());
        fs.insert_tree(
            "/dir",
            json!({
                ".zed.toml": r#"collaborators = ["user_b"]"#,
                "a.txt": "a-contents",
            }),
        )
        .await;
        let worktree_a = Worktree::open_local(
            client_a.clone(),
            "/dir".as_ref(),
            fs,
            lang_registry.clone(),
            &mut cx_a.to_async(),
        )
        .await
        .unwrap();
        // Wait for the initial scan to finish before sharing.
        worktree_a
            .read_with(&cx_a, |tree, _| tree.as_local().unwrap().scan_complete())
            .await;
        let worktree_id = worktree_a
            .update(&mut cx_a, |tree, cx| tree.as_local_mut().unwrap().share(cx))
            .await
            .unwrap();

        // Join that worktree as client B.
        let worktree_b = Worktree::open_remote(
            client_b.clone(),
            worktree_id,
            lang_registry.clone(),
            &mut cx_b.to_async(),
        )
        .await
        .unwrap();

        let buffer_a = worktree_a
            .update(&mut cx_a, |tree, cx| tree.open_buffer("a.txt", cx))
            .await
            .unwrap();
        // Start opening the buffer as the guest in the background, without awaiting it yet.
        let buffer_b = cx_b
            .background()
            .spawn(worktree_b.update(&mut cx_b, |worktree, cx| worktree.open_buffer("a.txt", cx)));

        // Yield once so the spawned open can make progress, then edit as the host.
        task::yield_now().await;
        buffer_a.update(&mut cx_a, |buf, cx| buf.edit([0..0], "z", cx));

        let text = buffer_a.read_with(&cx_a, |buf, _| buf.text());
        let buffer_b = buffer_b.await.unwrap();
        buffer_b.condition(&cx_b, |buf, _| buf.text() == text).await;
    }
1680
1681 #[gpui::test]
1682 async fn test_peer_disconnection(mut cx_a: TestAppContext, cx_b: TestAppContext) {
1683 cx_a.foreground().forbid_parking();
1684 let lang_registry = Arc::new(LanguageRegistry::new());
1685
1686 // Connect to a server as 2 clients.
1687 let mut server = TestServer::start().await;
1688 let (client_a, _) = server.create_client(&mut cx_a, "user_a").await;
1689 let (client_b, _) = server.create_client(&mut cx_a, "user_b").await;
1690
1691 // Share a local worktree as client A
1692 let fs = Arc::new(FakeFs::new());
1693 fs.insert_tree(
1694 "/a",
1695 json!({
1696 ".zed.toml": r#"collaborators = ["user_b"]"#,
1697 "a.txt": "a-contents",
1698 "b.txt": "b-contents",
1699 }),
1700 )
1701 .await;
1702 let worktree_a = Worktree::open_local(
1703 client_a.clone(),
1704 "/a".as_ref(),
1705 fs,
1706 lang_registry.clone(),
1707 &mut cx_a.to_async(),
1708 )
1709 .await
1710 .unwrap();
1711 worktree_a
1712 .read_with(&cx_a, |tree, _| tree.as_local().unwrap().scan_complete())
1713 .await;
1714 let worktree_id = worktree_a
1715 .update(&mut cx_a, |tree, cx| tree.as_local_mut().unwrap().share(cx))
1716 .await
1717 .unwrap();
1718
1719 // Join that worktree as client B, and see that a guest has joined as client A.
1720 let _worktree_b = Worktree::open_remote(
1721 client_b.clone(),
1722 worktree_id,
1723 lang_registry.clone(),
1724 &mut cx_b.to_async(),
1725 )
1726 .await
1727 .unwrap();
1728 worktree_a
1729 .condition(&cx_a, |tree, _| tree.peers().len() == 1)
1730 .await;
1731
1732 // Drop client B's connection and ensure client A observes client B leaving the worktree.
1733 client_b.disconnect(&cx_b.to_async()).await.unwrap();
1734 worktree_a
1735 .condition(&cx_a, |tree, _| tree.peers().len() == 0)
1736 .await;
1737 }
1738
    /// Exercises basic chat between two clients: both see the channel and its
    /// existing message, messages sent by A appear as pending locally and are
    /// delivered to B, and the server drops its channel state once both
    /// clients have released the channel.
    #[gpui::test]
    async fn test_basic_chat(mut cx_a: TestAppContext, mut cx_b: TestAppContext) {
        cx_a.foreground().forbid_parking();

        // Connect to a server as 2 clients.
        let mut server = TestServer::start().await;
        let (client_a, user_store_a) = server.create_client(&mut cx_a, "user_a").await;
        let (client_b, user_store_b) = server.create_client(&mut cx_b, "user_b").await;

        // Create an org that includes these 2 users.
        let db = &server.app_state.db;
        let org_id = db.create_org("Test Org", "test-org").await.unwrap();
        db.add_org_member(org_id, current_user_id(&user_store_a), false)
            .await
            .unwrap();
        db.add_org_member(org_id, current_user_id(&user_store_b), false)
            .await
            .unwrap();

        // Create a channel that includes all the users.
        let channel_id = db.create_org_channel(org_id, "test-channel").await.unwrap();
        db.add_channel_member(channel_id, current_user_id(&user_store_a), false)
            .await
            .unwrap();
        db.add_channel_member(channel_id, current_user_id(&user_store_b), false)
            .await
            .unwrap();
        // Seed the channel with one message from B, written directly to the database.
        db.create_channel_message(
            channel_id,
            current_user_id(&user_store_b),
            "hello A, it's B.",
            OffsetDateTime::now_utc(),
            1,
        )
        .await
        .unwrap();

        let channels_a = cx_a.add_model(|cx| ChannelList::new(user_store_a, client_a, cx));
        channels_a
            .condition(&mut cx_a, |list, _| list.available_channels().is_some())
            .await;
        channels_a.read_with(&cx_a, |list, _| {
            assert_eq!(
                list.available_channels().unwrap(),
                &[ChannelDetails {
                    id: channel_id.to_proto(),
                    name: "test-channel".to_string()
                }]
            )
        });
        let channel_a = channels_a.update(&mut cx_a, |this, cx| {
            this.get_channel(channel_id.to_proto(), cx).unwrap()
        });
        // The channel starts out empty and then receives the seeded message.
        channel_a.read_with(&cx_a, |channel, _| assert!(channel.messages().is_empty()));
        channel_a
            .condition(&cx_a, |channel, _| {
                channel_messages(channel)
                    == [("user_b".to_string(), "hello A, it's B.".to_string(), false)]
            })
            .await;

        let channels_b = cx_b.add_model(|cx| ChannelList::new(user_store_b, client_b, cx));
        channels_b
            .condition(&mut cx_b, |list, _| list.available_channels().is_some())
            .await;
        channels_b.read_with(&cx_b, |list, _| {
            assert_eq!(
                list.available_channels().unwrap(),
                &[ChannelDetails {
                    id: channel_id.to_proto(),
                    name: "test-channel".to_string()
                }]
            )
        });

        let channel_b = channels_b.update(&mut cx_b, |this, cx| {
            this.get_channel(channel_id.to_proto(), cx).unwrap()
        });
        channel_b.read_with(&cx_b, |channel, _| assert!(channel.messages().is_empty()));
        channel_b
            .condition(&cx_b, |channel, _| {
                channel_messages(channel)
                    == [("user_b".to_string(), "hello A, it's B.".to_string(), false)]
            })
            .await;

        // Send two messages as A; immediately after sending, both are listed locally
        // with the third tuple element set to `true`.
        channel_a
            .update(&mut cx_a, |channel, cx| {
                channel
                    .send_message("oh, hi B.".to_string(), cx)
                    .unwrap()
                    .detach();
                let task = channel.send_message("sup".to_string(), cx).unwrap();
                assert_eq!(
                    channel_messages(channel),
                    &[
                        ("user_b".to_string(), "hello A, it's B.".to_string(), false),
                        ("user_a".to_string(), "oh, hi B.".to_string(), true),
                        ("user_a".to_string(), "sup".to_string(), true)
                    ]
                );
                task
            })
            .await
            .unwrap();

        // B eventually receives both of A's messages.
        channel_b
            .condition(&cx_b, |channel, _| {
                channel_messages(channel)
                    == [
                        ("user_b".to_string(), "hello A, it's B.".to_string(), false),
                        ("user_a".to_string(), "oh, hi B.".to_string(), false),
                        ("user_a".to_string(), "sup".to_string(), false),
                    ]
            })
            .await;

        // The server tracks one connection per joined client and forgets the channel
        // once the last client has left it.
        assert_eq!(
            server.state().await.channels[&channel_id]
                .connection_ids
                .len(),
            2
        );
        cx_b.update(|_| drop(channel_b));
        server
            .condition(|state| state.channels[&channel_id].connection_ids.len() == 1)
            .await;

        cx_a.update(|_| drop(channel_a));
        server
            .condition(|state| !state.channels.contains_key(&channel_id))
            .await;
    }
1872
    /// Verifies validation of chat messages: overly long messages are
    /// rejected, blank messages are rejected, and surrounding whitespace is
    /// trimmed from the body that ends up stored in the database.
    #[gpui::test]
    async fn test_chat_message_validation(mut cx_a: TestAppContext) {
        cx_a.foreground().forbid_parking();

        let mut server = TestServer::start().await;
        let (client_a, user_store_a) = server.create_client(&mut cx_a, "user_a").await;

        // Create an org and a channel, and make user A a member of both.
        let db = &server.app_state.db;
        let org_id = db.create_org("Test Org", "test-org").await.unwrap();
        let channel_id = db.create_org_channel(org_id, "test-channel").await.unwrap();
        db.add_org_member(org_id, current_user_id(&user_store_a), false)
            .await
            .unwrap();
        db.add_channel_member(channel_id, current_user_id(&user_store_a), false)
            .await
            .unwrap();

        let channels_a = cx_a.add_model(|cx| ChannelList::new(user_store_a, client_a, cx));
        channels_a
            .condition(&mut cx_a, |list, _| list.available_channels().is_some())
            .await;
        let channel_a = channels_a.update(&mut cx_a, |this, cx| {
            this.get_channel(channel_id.to_proto(), cx).unwrap()
        });

        // Messages aren't allowed to be too long.
        // Here `send_message` itself succeeds and the returned task fails.
        channel_a
            .update(&mut cx_a, |channel, cx| {
                let long_body = "this is long.\n".repeat(1024);
                channel.send_message(long_body, cx).unwrap()
            })
            .await
            .unwrap_err();

        // Messages aren't allowed to be blank.
        // Here `send_message` fails immediately, before any task is returned.
        channel_a.update(&mut cx_a, |channel, cx| {
            channel.send_message(String::new(), cx).unwrap_err()
        });

        // Leading and trailing whitespace are trimmed.
        channel_a
            .update(&mut cx_a, |channel, cx| {
                channel
                    .send_message("\n surrounded by whitespace  \n".to_string(), cx)
                    .unwrap()
            })
            .await
            .unwrap();
        assert_eq!(
            db.get_channel_messages(channel_id, 10, None)
                .await
                .unwrap()
                .iter()
                .map(|m| &m.body)
                .collect::<Vec<_>>(),
            &["surrounded by whitespace"]
        );
    }
1931
    /// Exercises chat across a disconnection: client B is forcibly
    /// disconnected and kept from reconnecting, its cached channel data stays
    /// readable, sending while offline fails, and after reconnecting B sees
    /// the messages A sent in the meantime as well as its own offline message.
    /// Finally both clients exchange messages normally again.
    #[gpui::test]
    async fn test_chat_reconnection(mut cx_a: TestAppContext, mut cx_b: TestAppContext) {
        cx_a.foreground().forbid_parking();
        let http = FakeHttpClient::new(|_| async move { Ok(surf::http::Response::new(404)) });

        // Connect to a server as 2 clients.
        let mut server = TestServer::start().await;
        let (client_a, user_store_a) = server.create_client(&mut cx_a, "user_a").await;
        let (client_b, user_store_b) = server.create_client(&mut cx_b, "user_b").await;
        let mut status_b = client_b.status();

        // Create an org that includes these 2 users.
        let db = &server.app_state.db;
        let org_id = db.create_org("Test Org", "test-org").await.unwrap();
        db.add_org_member(org_id, current_user_id(&user_store_a), false)
            .await
            .unwrap();
        db.add_org_member(org_id, current_user_id(&user_store_b), false)
            .await
            .unwrap();

        // Create a channel that includes all the users.
        let channel_id = db.create_org_channel(org_id, "test-channel").await.unwrap();
        db.add_channel_member(channel_id, current_user_id(&user_store_a), false)
            .await
            .unwrap();
        db.add_channel_member(channel_id, current_user_id(&user_store_b), false)
            .await
            .unwrap();
        // Seed the channel with one message from B, written directly to the database.
        db.create_channel_message(
            channel_id,
            current_user_id(&user_store_b),
            "hello A, it's B.",
            OffsetDateTime::now_utc(),
            2,
        )
        .await
        .unwrap();

        // NOTE(review): this shadows the user store returned by `create_client` with a
        // freshly built one — confirm that this is intentional.
        let user_store_a =
            UserStore::new(client_a.clone(), http.clone(), cx_a.background().as_ref());
        let channels_a = cx_a.add_model(|cx| ChannelList::new(user_store_a, client_a, cx));
        channels_a
            .condition(&mut cx_a, |list, _| list.available_channels().is_some())
            .await;

        channels_a.read_with(&cx_a, |list, _| {
            assert_eq!(
                list.available_channels().unwrap(),
                &[ChannelDetails {
                    id: channel_id.to_proto(),
                    name: "test-channel".to_string()
                }]
            )
        });
        let channel_a = channels_a.update(&mut cx_a, |this, cx| {
            this.get_channel(channel_id.to_proto(), cx).unwrap()
        });
        channel_a.read_with(&cx_a, |channel, _| assert!(channel.messages().is_empty()));
        channel_a
            .condition(&cx_a, |channel, _| {
                channel_messages(channel)
                    == [("user_b".to_string(), "hello A, it's B.".to_string(), false)]
            })
            .await;

        let channels_b = cx_b.add_model(|cx| ChannelList::new(user_store_b.clone(), client_b, cx));
        channels_b
            .condition(&mut cx_b, |list, _| list.available_channels().is_some())
            .await;
        channels_b.read_with(&cx_b, |list, _| {
            assert_eq!(
                list.available_channels().unwrap(),
                &[ChannelDetails {
                    id: channel_id.to_proto(),
                    name: "test-channel".to_string()
                }]
            )
        });

        let channel_b = channels_b.update(&mut cx_b, |this, cx| {
            this.get_channel(channel_id.to_proto(), cx).unwrap()
        });
        channel_b.read_with(&cx_b, |channel, _| assert!(channel.messages().is_empty()));
        channel_b
            .condition(&cx_b, |channel, _| {
                channel_messages(channel)
                    == [("user_b".to_string(), "hello A, it's B.".to_string(), false)]
            })
            .await;

        // Disconnect client B, ensuring we can still access its cached channel data.
        server.forbid_connections();
        server.disconnect_client(current_user_id(&user_store_b));
        // Wait until B's client reports that a reconnection attempt has failed.
        while !matches!(
            status_b.recv().await,
            Some(rpc::Status::ReconnectionError { .. })
        ) {}

        channels_b.read_with(&cx_b, |channels, _| {
            assert_eq!(
                channels.available_channels().unwrap(),
                [ChannelDetails {
                    id: channel_id.to_proto(),
                    name: "test-channel".to_string()
                }]
            )
        });
        channel_b.read_with(&cx_b, |channel, _| {
            assert_eq!(
                channel_messages(channel),
                [("user_b".to_string(), "hello A, it's B.".to_string(), false)]
            )
        });

        // Send a message from client B while it is disconnected.
        // The message is listed locally right away, but the send task itself fails.
        channel_b
            .update(&mut cx_b, |channel, cx| {
                let task = channel
                    .send_message("can you see this?".to_string(), cx)
                    .unwrap();
                assert_eq!(
                    channel_messages(channel),
                    &[
                        ("user_b".to_string(), "hello A, it's B.".to_string(), false),
                        ("user_b".to_string(), "can you see this?".to_string(), true)
                    ]
                );
                task
            })
            .await
            .unwrap_err();

        // Send a message from client A while B is disconnected.
        channel_a
            .update(&mut cx_a, |channel, cx| {
                channel
                    .send_message("oh, hi B.".to_string(), cx)
                    .unwrap()
                    .detach();
                let task = channel.send_message("sup".to_string(), cx).unwrap();
                assert_eq!(
                    channel_messages(channel),
                    &[
                        ("user_b".to_string(), "hello A, it's B.".to_string(), false),
                        ("user_a".to_string(), "oh, hi B.".to_string(), true),
                        ("user_a".to_string(), "sup".to_string(), true)
                    ]
                );
                task
            })
            .await
            .unwrap();

        // Give client B a chance to reconnect.
        server.allow_connections();
        cx_b.foreground().advance_clock(Duration::from_secs(10));

        // Verify that B sees the new messages upon reconnection, as well as the message client B
        // sent while offline.
        channel_b
            .condition(&cx_b, |channel, _| {
                channel_messages(channel)
                    == [
                        ("user_b".to_string(), "hello A, it's B.".to_string(), false),
                        ("user_a".to_string(), "oh, hi B.".to_string(), false),
                        ("user_a".to_string(), "sup".to_string(), false),
                        ("user_b".to_string(), "can you see this?".to_string(), false),
                    ]
            })
            .await;

        // Ensure client A and B can communicate normally after reconnection.
        channel_a
            .update(&mut cx_a, |channel, cx| {
                channel.send_message("you online?".to_string(), cx).unwrap()
            })
            .await
            .unwrap();
        channel_b
            .condition(&cx_b, |channel, _| {
                channel_messages(channel)
                    == [
                        ("user_b".to_string(), "hello A, it's B.".to_string(), false),
                        ("user_a".to_string(), "oh, hi B.".to_string(), false),
                        ("user_a".to_string(), "sup".to_string(), false),
                        ("user_b".to_string(), "can you see this?".to_string(), false),
                        ("user_a".to_string(), "you online?".to_string(), false),
                    ]
            })
            .await;

        channel_b
            .update(&mut cx_b, |channel, cx| {
                channel.send_message("yep".to_string(), cx).unwrap()
            })
            .await
            .unwrap();
        channel_a
            .condition(&cx_a, |channel, _| {
                channel_messages(channel)
                    == [
                        ("user_b".to_string(), "hello A, it's B.".to_string(), false),
                        ("user_a".to_string(), "oh, hi B.".to_string(), false),
                        ("user_a".to_string(), "sup".to_string(), false),
                        ("user_b".to_string(), "can you see this?".to_string(), false),
                        ("user_a".to_string(), "you online?".to_string(), false),
                        ("user_b".to_string(), "yep".to_string(), false),
                    ]
            })
            .await;
    }
2144
    /// In-process RPC server used by the tests in this module, together with
    /// the handles the tests need to observe and disrupt it.
    struct TestServer {
        // Peer handed to the server; it is reset when the test server is dropped.
        peer: Arc<Peer>,
        // Application state built on top of the test database.
        app_state: Arc<AppState>,
        // The server under test.
        server: Arc<Server>,
        // Receiving half of the channel whose sender is given to the server;
        // `condition` waits on it between checks of its predicate.
        notifications: mpsc::Receiver<()>,
        // Per-user senders that kill the in-memory connection of that user's client.
        connection_killers: Arc<Mutex<HashMap<UserId, watch::Sender<Option<()>>>>>,
        // While set, attempts by clients to establish a connection fail.
        forbid_connections: Arc<AtomicBool>,
        // Not read anywhere in this struct's methods; presumably held so the
        // database outlives the server — confirm against `TestDb`.
        _test_db: TestDb,
    }
2154
2155 impl TestServer {
2156 async fn start() -> Self {
2157 let test_db = TestDb::new();
2158 let app_state = Self::build_app_state(&test_db).await;
2159 let peer = Peer::new();
2160 let notifications = mpsc::channel(128);
2161 let server = Server::new(app_state.clone(), peer.clone(), Some(notifications.0));
2162 Self {
2163 peer,
2164 app_state,
2165 server,
2166 notifications: notifications.1,
2167 connection_killers: Default::default(),
2168 forbid_connections: Default::default(),
2169 _test_db: test_db,
2170 }
2171 }
2172
2173 async fn create_client(
2174 &mut self,
2175 cx: &mut TestAppContext,
2176 name: &str,
2177 ) -> (Arc<Client>, Arc<UserStore>) {
2178 let user_id = self.app_state.db.create_user(name, false).await.unwrap();
2179 let client_name = name.to_string();
2180 let mut client = Client::new();
2181 let server = self.server.clone();
2182 let connection_killers = self.connection_killers.clone();
2183 let forbid_connections = self.forbid_connections.clone();
2184 Arc::get_mut(&mut client)
2185 .unwrap()
2186 .override_authenticate(move |cx| {
2187 cx.spawn(|_| async move {
2188 let access_token = "the-token".to_string();
2189 Ok(Credentials {
2190 user_id: user_id.0 as u64,
2191 access_token,
2192 })
2193 })
2194 })
2195 .override_establish_connection(move |credentials, cx| {
2196 assert_eq!(credentials.user_id, user_id.0 as u64);
2197 assert_eq!(credentials.access_token, "the-token");
2198
2199 let server = server.clone();
2200 let connection_killers = connection_killers.clone();
2201 let forbid_connections = forbid_connections.clone();
2202 let client_name = client_name.clone();
2203 cx.spawn(move |cx| async move {
2204 if forbid_connections.load(SeqCst) {
2205 Err(EstablishConnectionError::other(anyhow!(
2206 "server is forbidding connections"
2207 )))
2208 } else {
2209 let (client_conn, server_conn, kill_conn) = Connection::in_memory();
2210 connection_killers.lock().insert(user_id, kill_conn);
2211 cx.background()
2212 .spawn(server.handle_connection(server_conn, client_name, user_id))
2213 .detach();
2214 Ok(client_conn)
2215 }
2216 })
2217 });
2218
2219 let http = FakeHttpClient::new(|_| async move { Ok(surf::http::Response::new(404)) });
2220 client
2221 .authenticate_and_connect(&cx.to_async())
2222 .await
2223 .unwrap();
2224
2225 let user_store = UserStore::new(client.clone(), http, &cx.background());
2226 let mut authed_user = user_store.watch_current_user();
2227 while authed_user.recv().await.unwrap().is_none() {}
2228
2229 (client, user_store)
2230 }
2231
2232 fn disconnect_client(&self, user_id: UserId) {
2233 if let Some(mut kill_conn) = self.connection_killers.lock().remove(&user_id) {
2234 let _ = kill_conn.try_send(Some(()));
2235 }
2236 }
2237
2238 fn forbid_connections(&self) {
2239 self.forbid_connections.store(true, SeqCst);
2240 }
2241
2242 fn allow_connections(&self) {
2243 self.forbid_connections.store(false, SeqCst);
2244 }
2245
2246 async fn build_app_state(test_db: &TestDb) -> Arc<AppState> {
2247 let mut config = Config::default();
2248 config.session_secret = "a".repeat(32);
2249 config.database_url = test_db.url.clone();
2250 let github_client = github::AppClient::test();
2251 Arc::new(AppState {
2252 db: test_db.db().clone(),
2253 handlebars: Default::default(),
2254 auth_client: auth::build_client("", ""),
2255 repo_client: github::RepoClient::test(&github_client),
2256 github_client,
2257 config,
2258 })
2259 }
2260
2261 async fn state<'a>(&'a self) -> RwLockReadGuard<'a, ServerState> {
2262 self.server.state.read().await
2263 }
2264
2265 async fn condition<F>(&mut self, mut predicate: F)
2266 where
2267 F: FnMut(&ServerState) -> bool,
2268 {
2269 async_std::future::timeout(Duration::from_millis(500), async {
2270 while !(predicate)(&*self.server.state.read().await) {
2271 self.notifications.recv().await;
2272 }
2273 })
2274 .await
2275 .expect("condition timed out");
2276 }
2277 }
2278
2279 impl Drop for TestServer {
2280 fn drop(&mut self) {
2281 task::block_on(self.peer.reset());
2282 }
2283 }
2284
2285 fn current_user_id(user_store: &Arc<UserStore>) -> UserId {
2286 UserId::from_proto(user_store.current_user().unwrap().id)
2287 }
2288
2289 fn channel_messages(channel: &Channel) -> Vec<(String, String, bool)> {
2290 channel
2291 .messages()
2292 .cursor::<(), ()>()
2293 .map(|m| {
2294 (
2295 m.sender.github_login.clone(),
2296 m.body.clone(),
2297 m.is_pending(),
2298 )
2299 })
2300 .collect()
2301 }
2302
2303 struct EmptyView;
2304
2305 impl gpui::Entity for EmptyView {
2306 type Event = ();
2307 }
2308
2309 impl gpui::View for EmptyView {
2310 fn ui_name() -> &'static str {
2311 "empty view"
2312 }
2313
2314 fn render(&mut self, _: &mut gpui::RenderContext<Self>) -> gpui::ElementBox {
2315 gpui::Element::boxed(gpui::elements::Empty)
2316 }
2317 }
2318}