1use super::{
2 auth,
3 db::{ChannelId, MessageId, UserId},
4 AppState,
5};
6use anyhow::anyhow;
7use async_std::{sync::RwLock, task};
8use async_tungstenite::{tungstenite::protocol::Role, WebSocketStream};
9use futures::{future::BoxFuture, FutureExt};
10use postage::{mpsc, prelude::Sink as _, prelude::Stream as _};
11use sha1::{Digest as _, Sha1};
12use std::{
13 any::TypeId,
14 collections::{hash_map, HashMap, HashSet},
15 future::Future,
16 mem,
17 sync::Arc,
18 time::Instant,
19};
20use surf::StatusCode;
21use tide::log;
22use tide::{
23 http::headers::{HeaderName, CONNECTION, UPGRADE},
24 Request, Response,
25};
26use time::OffsetDateTime;
27use zrpc::{
28 auth::random_token,
29 proto::{self, AnyTypedEnvelope, EnvelopedMessage},
30 Conn, ConnectionId, Peer, TypedEnvelope,
31};
32
/// Identifies one participant's replica of a shared worktree. The host is
/// reported with replica id 0 and guests are assigned ids starting at 1
/// (see `ServerState::join_worktree`).
type ReplicaId = u16;

/// Type-erased RPC message handler: receives the server and a boxed envelope
/// (downcast back to its concrete message type inside the closure built by
/// `Server::add_handler`) and returns a boxed future resolving to the
/// handler's result.
type MessageHandler = Box<
    dyn Send
        + Sync
        + Fn(Arc<Server>, Box<dyn AnyTypedEnvelope>) -> BoxFuture<'static, tide::Result<()>>,
>;
40
/// RPC server shared by every client connection.
pub struct Server {
    /// Transport used to send, respond to, and forward RPC messages.
    peer: Arc<Peer>,
    /// In-memory registry of connections, shared worktrees, and joined channels.
    state: RwLock<ServerState>,
    /// Application-wide state; handlers use its `db` for channel, user, and
    /// message queries.
    app_state: Arc<AppState>,
    /// Message handlers keyed by the `TypeId` of the message type they accept.
    handlers: HashMap<TypeId, MessageHandler>,
    /// When present, receives a `()` after each message that had a handler.
    notifications: Option<mpsc::Sender<()>>,
}
48
49#[derive(Default)]
50struct ServerState {
51 connections: HashMap<ConnectionId, Connection>,
52 pub worktrees: HashMap<u64, Worktree>,
53 channels: HashMap<ChannelId, Channel>,
54 next_worktree_id: u64,
55}
56
/// Per-connection bookkeeping.
struct Connection {
    /// User this connection was authenticated as.
    user_id: UserId,
    /// Ids of worktrees this connection is attached to.
    worktrees: HashSet<u64>,
    /// Channels this connection has joined.
    channels: HashSet<ChannelId>,
}
62
/// A worktree shared by a host connection with zero or more guests.
struct Worktree {
    /// Connection that shared the worktree; `None` once the host has left it.
    host_connection_id: Option<ConnectionId>,
    /// Guest connections and the replica id assigned to each.
    guest_connection_ids: HashMap<ConnectionId, ReplicaId>,
    /// Replica ids currently held by guests (guest ids start at 1).
    active_replica_ids: HashSet<ReplicaId>,
    /// Token a guest must present to join.
    access_token: String,
    /// Root name reported by the host when sharing.
    root_name: String,
    /// Worktree entries keyed by entry id, kept current by `update_worktree`.
    entries: HashMap<u64, proto::Entry>,
}
71
/// Live membership of a chat channel.
#[derive(Default)]
struct Channel {
    /// Connections that have joined this channel.
    connection_ids: HashSet<ConnectionId>,
}
76
/// Number of channel messages requested per page; a shorter page marks `done`.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
/// Maximum length, in bytes, of a trimmed chat message body.
const MAX_MESSAGE_LEN: usize = 1024;
79
80impl Server {
81 pub fn new(
82 app_state: Arc<AppState>,
83 peer: Arc<Peer>,
84 notifications: Option<mpsc::Sender<()>>,
85 ) -> Arc<Self> {
86 let mut server = Self {
87 peer,
88 app_state,
89 state: Default::default(),
90 handlers: Default::default(),
91 notifications,
92 };
93
94 server
95 .add_handler(Server::ping)
96 .add_handler(Server::share_worktree)
97 .add_handler(Server::join_worktree)
98 .add_handler(Server::update_worktree)
99 .add_handler(Server::close_worktree)
100 .add_handler(Server::open_buffer)
101 .add_handler(Server::close_buffer)
102 .add_handler(Server::update_buffer)
103 .add_handler(Server::buffer_saved)
104 .add_handler(Server::save_buffer)
105 .add_handler(Server::get_channels)
106 .add_handler(Server::get_users)
107 .add_handler(Server::join_channel)
108 .add_handler(Server::leave_channel)
109 .add_handler(Server::send_channel_message)
110 .add_handler(Server::get_channel_messages);
111
112 Arc::new(server)
113 }
114
115 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
116 where
117 F: 'static + Send + Sync + Fn(Arc<Self>, TypedEnvelope<M>) -> Fut,
118 Fut: 'static + Send + Future<Output = tide::Result<()>>,
119 M: EnvelopedMessage,
120 {
121 let prev_handler = self.handlers.insert(
122 TypeId::of::<M>(),
123 Box::new(move |server, envelope| {
124 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
125 (handler)(server, *envelope).boxed()
126 }),
127 );
128 if prev_handler.is_some() {
129 panic!("registered a handler for the same message twice");
130 }
131 self
132 }
133
134 pub fn handle_connection(
135 self: &Arc<Self>,
136 connection: Conn,
137 addr: String,
138 user_id: UserId,
139 ) -> impl Future<Output = ()> {
140 let this = self.clone();
141 async move {
142 let (connection_id, handle_io, mut incoming_rx) =
143 this.peer.add_connection(connection).await;
144 this.add_connection(connection_id, user_id).await;
145
146 let handle_io = handle_io.fuse();
147 futures::pin_mut!(handle_io);
148 loop {
149 let next_message = incoming_rx.recv().fuse();
150 futures::pin_mut!(next_message);
151 futures::select_biased! {
152 message = next_message => {
153 if let Some(message) = message {
154 let start_time = Instant::now();
155 log::info!("RPC message received: {}", message.payload_type_name());
156 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
157 if let Err(err) = (handler)(this.clone(), message).await {
158 log::error!("error handling message: {:?}", err);
159 } else {
160 log::info!("RPC message handled. duration:{:?}", start_time.elapsed());
161 }
162
163 if let Some(mut notifications) = this.notifications.clone() {
164 let _ = notifications.send(()).await;
165 }
166 } else {
167 log::warn!("unhandled message: {}", message.payload_type_name());
168 }
169 } else {
170 log::info!("rpc connection closed {:?}", addr);
171 break;
172 }
173 }
174 handle_io = handle_io => {
175 if let Err(err) = handle_io {
176 log::error!("error handling rpc connection {:?} - {:?}", addr, err);
177 }
178 break;
179 }
180 }
181 }
182
183 if let Err(err) = this.sign_out(connection_id).await {
184 log::error!("error signing out connection {:?} - {:?}", addr, err);
185 }
186 }
187 }
188
189 async fn sign_out(self: &Arc<Self>, connection_id: zrpc::ConnectionId) -> tide::Result<()> {
190 self.peer.disconnect(connection_id).await;
191 let worktree_ids = self.remove_connection(connection_id).await;
192 for worktree_id in worktree_ids {
193 let state = self.state.read().await;
194 if let Some(worktree) = state.worktrees.get(&worktree_id) {
195 broadcast(connection_id, worktree.connection_ids(), |conn_id| {
196 self.peer.send(
197 conn_id,
198 proto::RemovePeer {
199 worktree_id,
200 peer_id: connection_id.0,
201 },
202 )
203 })
204 .await?;
205 }
206 }
207 Ok(())
208 }
209
210 // Add a new connection associated with a given user.
211 async fn add_connection(&self, connection_id: ConnectionId, user_id: UserId) {
212 self.state.write().await.connections.insert(
213 connection_id,
214 Connection {
215 user_id,
216 worktrees: Default::default(),
217 channels: Default::default(),
218 },
219 );
220 }
221
222 // Remove the given connection and its association with any worktrees.
223 async fn remove_connection(&self, connection_id: ConnectionId) -> Vec<u64> {
224 let mut worktree_ids = Vec::new();
225 let mut state = self.state.write().await;
226 if let Some(connection) = state.connections.remove(&connection_id) {
227 for channel_id in connection.channels {
228 if let Some(channel) = state.channels.get_mut(&channel_id) {
229 channel.connection_ids.remove(&connection_id);
230 }
231 }
232 for worktree_id in connection.worktrees {
233 if let Some(worktree) = state.worktrees.get_mut(&worktree_id) {
234 if worktree.host_connection_id == Some(connection_id) {
235 worktree_ids.push(worktree_id);
236 } else if let Some(replica_id) =
237 worktree.guest_connection_ids.remove(&connection_id)
238 {
239 worktree.active_replica_ids.remove(&replica_id);
240 worktree_ids.push(worktree_id);
241 }
242 }
243 }
244 }
245 worktree_ids
246 }
247
248 async fn ping(self: Arc<Server>, request: TypedEnvelope<proto::Ping>) -> tide::Result<()> {
249 self.peer
250 .respond(
251 request.receipt(),
252 proto::Pong {
253 id: request.payload.id,
254 },
255 )
256 .await?;
257 Ok(())
258 }
259
260 async fn share_worktree(
261 self: Arc<Server>,
262 mut request: TypedEnvelope<proto::ShareWorktree>,
263 ) -> tide::Result<()> {
264 let mut state = self.state.write().await;
265 let worktree_id = state.next_worktree_id;
266 state.next_worktree_id += 1;
267 let access_token = random_token();
268 let worktree = request
269 .payload
270 .worktree
271 .as_mut()
272 .ok_or_else(|| anyhow!("missing worktree"))?;
273 let entries = mem::take(&mut worktree.entries)
274 .into_iter()
275 .map(|entry| (entry.id, entry))
276 .collect();
277 state.worktrees.insert(
278 worktree_id,
279 Worktree {
280 host_connection_id: Some(request.sender_id),
281 guest_connection_ids: Default::default(),
282 active_replica_ids: Default::default(),
283 access_token: access_token.clone(),
284 root_name: mem::take(&mut worktree.root_name),
285 entries,
286 },
287 );
288
289 self.peer
290 .respond(
291 request.receipt(),
292 proto::ShareWorktreeResponse {
293 worktree_id,
294 access_token,
295 },
296 )
297 .await?;
298 Ok(())
299 }
300
301 async fn join_worktree(
302 self: Arc<Server>,
303 request: TypedEnvelope<proto::OpenWorktree>,
304 ) -> tide::Result<()> {
305 let worktree_id = request.payload.worktree_id;
306 let access_token = &request.payload.access_token;
307
308 let mut state = self.state.write().await;
309 if let Some((peer_replica_id, worktree)) =
310 state.join_worktree(request.sender_id, worktree_id, access_token)
311 {
312 let mut peers = Vec::new();
313 if let Some(host_connection_id) = worktree.host_connection_id {
314 peers.push(proto::Peer {
315 peer_id: host_connection_id.0,
316 replica_id: 0,
317 });
318 }
319 for (peer_conn_id, peer_replica_id) in &worktree.guest_connection_ids {
320 if *peer_conn_id != request.sender_id {
321 peers.push(proto::Peer {
322 peer_id: peer_conn_id.0,
323 replica_id: *peer_replica_id as u32,
324 });
325 }
326 }
327
328 broadcast(request.sender_id, worktree.connection_ids(), |conn_id| {
329 self.peer.send(
330 conn_id,
331 proto::AddPeer {
332 worktree_id,
333 peer: Some(proto::Peer {
334 peer_id: request.sender_id.0,
335 replica_id: peer_replica_id as u32,
336 }),
337 },
338 )
339 })
340 .await?;
341 self.peer
342 .respond(
343 request.receipt(),
344 proto::OpenWorktreeResponse {
345 worktree_id,
346 worktree: Some(proto::Worktree {
347 root_name: worktree.root_name.clone(),
348 entries: worktree.entries.values().cloned().collect(),
349 }),
350 replica_id: peer_replica_id as u32,
351 peers,
352 },
353 )
354 .await?;
355 } else {
356 self.peer
357 .respond(
358 request.receipt(),
359 proto::OpenWorktreeResponse {
360 worktree_id,
361 worktree: None,
362 replica_id: 0,
363 peers: Vec::new(),
364 },
365 )
366 .await?;
367 }
368
369 Ok(())
370 }
371
372 async fn update_worktree(
373 self: Arc<Server>,
374 request: TypedEnvelope<proto::UpdateWorktree>,
375 ) -> tide::Result<()> {
376 {
377 let mut state = self.state.write().await;
378 let worktree = state.write_worktree(request.payload.worktree_id, request.sender_id)?;
379 for entry_id in &request.payload.removed_entries {
380 worktree.entries.remove(&entry_id);
381 }
382
383 for entry in &request.payload.updated_entries {
384 worktree.entries.insert(entry.id, entry.clone());
385 }
386 }
387
388 self.broadcast_in_worktree(request.payload.worktree_id, &request)
389 .await?;
390 Ok(())
391 }
392
393 async fn close_worktree(
394 self: Arc<Server>,
395 request: TypedEnvelope<proto::CloseWorktree>,
396 ) -> tide::Result<()> {
397 let connection_ids;
398 {
399 let mut state = self.state.write().await;
400 let worktree = state.write_worktree(request.payload.worktree_id, request.sender_id)?;
401 connection_ids = worktree.connection_ids();
402 if worktree.host_connection_id == Some(request.sender_id) {
403 worktree.host_connection_id = None;
404 } else if let Some(replica_id) =
405 worktree.guest_connection_ids.remove(&request.sender_id)
406 {
407 worktree.active_replica_ids.remove(&replica_id);
408 }
409 }
410
411 broadcast(request.sender_id, connection_ids, |conn_id| {
412 self.peer.send(
413 conn_id,
414 proto::RemovePeer {
415 worktree_id: request.payload.worktree_id,
416 peer_id: request.sender_id.0,
417 },
418 )
419 })
420 .await?;
421
422 Ok(())
423 }
424
425 async fn open_buffer(
426 self: Arc<Server>,
427 request: TypedEnvelope<proto::OpenBuffer>,
428 ) -> tide::Result<()> {
429 let receipt = request.receipt();
430 let worktree_id = request.payload.worktree_id;
431 let host_connection_id = self
432 .state
433 .read()
434 .await
435 .read_worktree(worktree_id, request.sender_id)?
436 .host_connection_id()?;
437
438 let response = self
439 .peer
440 .forward_request(request.sender_id, host_connection_id, request.payload)
441 .await?;
442 self.peer.respond(receipt, response).await?;
443 Ok(())
444 }
445
446 async fn close_buffer(
447 self: Arc<Server>,
448 request: TypedEnvelope<proto::CloseBuffer>,
449 ) -> tide::Result<()> {
450 let host_connection_id = self
451 .state
452 .read()
453 .await
454 .read_worktree(request.payload.worktree_id, request.sender_id)?
455 .host_connection_id()?;
456
457 self.peer
458 .forward_send(request.sender_id, host_connection_id, request.payload)
459 .await?;
460
461 Ok(())
462 }
463
464 async fn save_buffer(
465 self: Arc<Server>,
466 request: TypedEnvelope<proto::SaveBuffer>,
467 ) -> tide::Result<()> {
468 let host;
469 let guests;
470 {
471 let state = self.state.read().await;
472 let worktree = state.read_worktree(request.payload.worktree_id, request.sender_id)?;
473 host = worktree.host_connection_id()?;
474 guests = worktree
475 .guest_connection_ids
476 .keys()
477 .copied()
478 .collect::<Vec<_>>();
479 }
480
481 let sender = request.sender_id;
482 let receipt = request.receipt();
483 let response = self
484 .peer
485 .forward_request(sender, host, request.payload.clone())
486 .await?;
487
488 broadcast(host, guests, |conn_id| {
489 let response = response.clone();
490 let peer = &self.peer;
491 async move {
492 if conn_id == sender {
493 peer.respond(receipt, response).await
494 } else {
495 peer.forward_send(host, conn_id, response).await
496 }
497 }
498 })
499 .await?;
500
501 Ok(())
502 }
503
504 async fn update_buffer(
505 self: Arc<Server>,
506 request: TypedEnvelope<proto::UpdateBuffer>,
507 ) -> tide::Result<()> {
508 self.broadcast_in_worktree(request.payload.worktree_id, &request)
509 .await
510 }
511
512 async fn buffer_saved(
513 self: Arc<Server>,
514 request: TypedEnvelope<proto::BufferSaved>,
515 ) -> tide::Result<()> {
516 self.broadcast_in_worktree(request.payload.worktree_id, &request)
517 .await
518 }
519
520 async fn get_channels(
521 self: Arc<Server>,
522 request: TypedEnvelope<proto::GetChannels>,
523 ) -> tide::Result<()> {
524 let user_id = self
525 .state
526 .read()
527 .await
528 .user_id_for_connection(request.sender_id)?;
529 let channels = self.app_state.db.get_accessible_channels(user_id).await?;
530 self.peer
531 .respond(
532 request.receipt(),
533 proto::GetChannelsResponse {
534 channels: channels
535 .into_iter()
536 .map(|chan| proto::Channel {
537 id: chan.id.to_proto(),
538 name: chan.name,
539 })
540 .collect(),
541 },
542 )
543 .await?;
544 Ok(())
545 }
546
547 async fn get_users(
548 self: Arc<Server>,
549 request: TypedEnvelope<proto::GetUsers>,
550 ) -> tide::Result<()> {
551 let user_id = self
552 .state
553 .read()
554 .await
555 .user_id_for_connection(request.sender_id)?;
556 let receipt = request.receipt();
557 let user_ids = request.payload.user_ids.into_iter().map(UserId::from_proto);
558 let users = self
559 .app_state
560 .db
561 .get_users_by_ids(user_id, user_ids)
562 .await?
563 .into_iter()
564 .map(|user| proto::User {
565 id: user.id.to_proto(),
566 github_login: user.github_login,
567 avatar_url: String::new(),
568 })
569 .collect();
570 self.peer
571 .respond(receipt, proto::GetUsersResponse { users })
572 .await?;
573 Ok(())
574 }
575
576 async fn join_channel(
577 self: Arc<Self>,
578 request: TypedEnvelope<proto::JoinChannel>,
579 ) -> tide::Result<()> {
580 let user_id = self
581 .state
582 .read()
583 .await
584 .user_id_for_connection(request.sender_id)?;
585 let channel_id = ChannelId::from_proto(request.payload.channel_id);
586 if !self
587 .app_state
588 .db
589 .can_user_access_channel(user_id, channel_id)
590 .await?
591 {
592 Err(anyhow!("access denied"))?;
593 }
594
595 self.state
596 .write()
597 .await
598 .join_channel(request.sender_id, channel_id);
599 let messages = self
600 .app_state
601 .db
602 .get_channel_messages(channel_id, MESSAGE_COUNT_PER_PAGE, None)
603 .await?
604 .into_iter()
605 .map(|msg| proto::ChannelMessage {
606 id: msg.id.to_proto(),
607 body: msg.body,
608 timestamp: msg.sent_at.unix_timestamp() as u64,
609 sender_id: msg.sender_id.to_proto(),
610 })
611 .collect::<Vec<_>>();
612 self.peer
613 .respond(
614 request.receipt(),
615 proto::JoinChannelResponse {
616 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
617 messages,
618 },
619 )
620 .await?;
621 Ok(())
622 }
623
624 async fn leave_channel(
625 self: Arc<Self>,
626 request: TypedEnvelope<proto::LeaveChannel>,
627 ) -> tide::Result<()> {
628 let user_id = self
629 .state
630 .read()
631 .await
632 .user_id_for_connection(request.sender_id)?;
633 let channel_id = ChannelId::from_proto(request.payload.channel_id);
634 if !self
635 .app_state
636 .db
637 .can_user_access_channel(user_id, channel_id)
638 .await?
639 {
640 Err(anyhow!("access denied"))?;
641 }
642
643 self.state
644 .write()
645 .await
646 .leave_channel(request.sender_id, channel_id);
647
648 Ok(())
649 }
650
651 async fn send_channel_message(
652 self: Arc<Self>,
653 request: TypedEnvelope<proto::SendChannelMessage>,
654 ) -> tide::Result<()> {
655 let receipt = request.receipt();
656 let channel_id = ChannelId::from_proto(request.payload.channel_id);
657 let user_id;
658 let connection_ids;
659 {
660 let state = self.state.read().await;
661 user_id = state.user_id_for_connection(request.sender_id)?;
662 if let Some(channel) = state.channels.get(&channel_id) {
663 connection_ids = channel.connection_ids();
664 } else {
665 return Ok(());
666 }
667 }
668
669 // Validate the message body.
670 let body = request.payload.body.trim().to_string();
671 if body.len() > MAX_MESSAGE_LEN {
672 self.peer
673 .respond_with_error(
674 receipt,
675 proto::Error {
676 message: "message is too long".to_string(),
677 },
678 )
679 .await?;
680 return Ok(());
681 }
682 if body.is_empty() {
683 self.peer
684 .respond_with_error(
685 receipt,
686 proto::Error {
687 message: "message can't be blank".to_string(),
688 },
689 )
690 .await?;
691 return Ok(());
692 }
693
694 let timestamp = OffsetDateTime::now_utc();
695 let message_id = self
696 .app_state
697 .db
698 .create_channel_message(channel_id, user_id, &body, timestamp)
699 .await?
700 .to_proto();
701 let message = proto::ChannelMessage {
702 sender_id: user_id.to_proto(),
703 id: message_id,
704 body,
705 timestamp: timestamp.unix_timestamp() as u64,
706 };
707 broadcast(request.sender_id, connection_ids, |conn_id| {
708 self.peer.send(
709 conn_id,
710 proto::ChannelMessageSent {
711 channel_id: channel_id.to_proto(),
712 message: Some(message.clone()),
713 },
714 )
715 })
716 .await?;
717 self.peer
718 .respond(
719 receipt,
720 proto::SendChannelMessageResponse {
721 message: Some(message),
722 },
723 )
724 .await?;
725 Ok(())
726 }
727
728 async fn get_channel_messages(
729 self: Arc<Self>,
730 request: TypedEnvelope<proto::GetChannelMessages>,
731 ) -> tide::Result<()> {
732 let user_id = self
733 .state
734 .read()
735 .await
736 .user_id_for_connection(request.sender_id)?;
737 let channel_id = ChannelId::from_proto(request.payload.channel_id);
738 if !self
739 .app_state
740 .db
741 .can_user_access_channel(user_id, channel_id)
742 .await?
743 {
744 Err(anyhow!("access denied"))?;
745 }
746
747 let messages = self
748 .app_state
749 .db
750 .get_channel_messages(
751 channel_id,
752 MESSAGE_COUNT_PER_PAGE,
753 Some(MessageId::from_proto(request.payload.before_message_id)),
754 )
755 .await?
756 .into_iter()
757 .map(|msg| proto::ChannelMessage {
758 id: msg.id.to_proto(),
759 body: msg.body,
760 timestamp: msg.sent_at.unix_timestamp() as u64,
761 sender_id: msg.sender_id.to_proto(),
762 })
763 .collect::<Vec<_>>();
764 self.peer
765 .respond(
766 request.receipt(),
767 proto::GetChannelMessagesResponse {
768 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
769 messages,
770 },
771 )
772 .await?;
773 Ok(())
774 }
775
776 async fn broadcast_in_worktree<T: proto::EnvelopedMessage>(
777 &self,
778 worktree_id: u64,
779 message: &TypedEnvelope<T>,
780 ) -> tide::Result<()> {
781 let connection_ids = self
782 .state
783 .read()
784 .await
785 .read_worktree(worktree_id, message.sender_id)?
786 .connection_ids();
787
788 broadcast(message.sender_id, connection_ids, |conn_id| {
789 self.peer
790 .forward_send(message.sender_id, conn_id, message.payload.clone())
791 })
792 .await?;
793
794 Ok(())
795 }
796}
797
798pub async fn broadcast<F, T>(
799 sender_id: ConnectionId,
800 receiver_ids: Vec<ConnectionId>,
801 mut f: F,
802) -> anyhow::Result<()>
803where
804 F: FnMut(ConnectionId) -> T,
805 T: Future<Output = anyhow::Result<()>>,
806{
807 let futures = receiver_ids
808 .into_iter()
809 .filter(|id| *id != sender_id)
810 .map(|id| f(id));
811 futures::future::try_join_all(futures).await?;
812 Ok(())
813}
814
815impl ServerState {
816 fn join_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
817 if let Some(connection) = self.connections.get_mut(&connection_id) {
818 connection.channels.insert(channel_id);
819 self.channels
820 .entry(channel_id)
821 .or_default()
822 .connection_ids
823 .insert(connection_id);
824 }
825 }
826
827 fn leave_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
828 if let Some(connection) = self.connections.get_mut(&connection_id) {
829 connection.channels.remove(&channel_id);
830 if let hash_map::Entry::Occupied(mut entry) = self.channels.entry(channel_id) {
831 entry.get_mut().connection_ids.remove(&connection_id);
832 if entry.get_mut().connection_ids.is_empty() {
833 entry.remove();
834 }
835 }
836 }
837 }
838
839 fn user_id_for_connection(&self, connection_id: ConnectionId) -> tide::Result<UserId> {
840 Ok(self
841 .connections
842 .get(&connection_id)
843 .ok_or_else(|| anyhow!("unknown connection"))?
844 .user_id)
845 }
846
847 // Add the given connection as a guest of the given worktree
848 fn join_worktree(
849 &mut self,
850 connection_id: ConnectionId,
851 worktree_id: u64,
852 access_token: &str,
853 ) -> Option<(ReplicaId, &Worktree)> {
854 if let Some(worktree) = self.worktrees.get_mut(&worktree_id) {
855 if access_token == worktree.access_token {
856 if let Some(connection) = self.connections.get_mut(&connection_id) {
857 connection.worktrees.insert(worktree_id);
858 }
859
860 let mut replica_id = 1;
861 while worktree.active_replica_ids.contains(&replica_id) {
862 replica_id += 1;
863 }
864 worktree.active_replica_ids.insert(replica_id);
865 worktree
866 .guest_connection_ids
867 .insert(connection_id, replica_id);
868 Some((replica_id, worktree))
869 } else {
870 None
871 }
872 } else {
873 None
874 }
875 }
876
877 fn read_worktree(
878 &self,
879 worktree_id: u64,
880 connection_id: ConnectionId,
881 ) -> tide::Result<&Worktree> {
882 let worktree = self
883 .worktrees
884 .get(&worktree_id)
885 .ok_or_else(|| anyhow!("worktree not found"))?;
886
887 if worktree.host_connection_id == Some(connection_id)
888 || worktree.guest_connection_ids.contains_key(&connection_id)
889 {
890 Ok(worktree)
891 } else {
892 Err(anyhow!(
893 "{} is not a member of worktree {}",
894 connection_id,
895 worktree_id
896 ))?
897 }
898 }
899
900 fn write_worktree(
901 &mut self,
902 worktree_id: u64,
903 connection_id: ConnectionId,
904 ) -> tide::Result<&mut Worktree> {
905 let worktree = self
906 .worktrees
907 .get_mut(&worktree_id)
908 .ok_or_else(|| anyhow!("worktree not found"))?;
909
910 if worktree.host_connection_id == Some(connection_id)
911 || worktree.guest_connection_ids.contains_key(&connection_id)
912 {
913 Ok(worktree)
914 } else {
915 Err(anyhow!(
916 "{} is not a member of worktree {}",
917 connection_id,
918 worktree_id
919 ))?
920 }
921 }
922}
923
924impl Worktree {
925 pub fn connection_ids(&self) -> Vec<ConnectionId> {
926 self.guest_connection_ids
927 .keys()
928 .copied()
929 .chain(self.host_connection_id)
930 .collect()
931 }
932
933 fn host_connection_id(&self) -> tide::Result<ConnectionId> {
934 Ok(self
935 .host_connection_id
936 .ok_or_else(|| anyhow!("host disconnected from worktree"))?)
937 }
938}
939
940impl Channel {
941 fn connection_ids(&self) -> Vec<ConnectionId> {
942 self.connection_ids.iter().copied().collect()
943 }
944}
945
946pub fn add_routes(app: &mut tide::Server<Arc<AppState>>, rpc: &Arc<Peer>) {
947 let server = Server::new(app.state().clone(), rpc.clone(), None);
948 app.at("/rpc").with(auth::VerifyToken).get(move |request: Request<Arc<AppState>>| {
949 let user_id = request.ext::<UserId>().copied();
950 let server = server.clone();
951 async move {
952 const WEBSOCKET_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
953
954 let connection_upgrade = header_contains_ignore_case(&request, CONNECTION, "upgrade");
955 let upgrade_to_websocket = header_contains_ignore_case(&request, UPGRADE, "websocket");
956 let upgrade_requested = connection_upgrade && upgrade_to_websocket;
957
958 if !upgrade_requested {
959 return Ok(Response::new(StatusCode::UpgradeRequired));
960 }
961
962 let header = match request.header("Sec-Websocket-Key") {
963 Some(h) => h.as_str(),
964 None => return Err(anyhow!("expected sec-websocket-key"))?,
965 };
966
967 let mut response = Response::new(StatusCode::SwitchingProtocols);
968 response.insert_header(UPGRADE, "websocket");
969 response.insert_header(CONNECTION, "Upgrade");
970 let hash = Sha1::new().chain(header).chain(WEBSOCKET_GUID).finalize();
971 response.insert_header("Sec-Websocket-Accept", base64::encode(&hash[..]));
972 response.insert_header("Sec-Websocket-Version", "13");
973
974 let http_res: &mut tide::http::Response = response.as_mut();
975 let upgrade_receiver = http_res.recv_upgrade().await;
976 let addr = request.remote().unwrap_or("unknown").to_string();
977 let user_id = user_id.ok_or_else(|| anyhow!("user_id is not present on request. ensure auth::VerifyToken middleware is present"))?;
978 task::spawn(async move {
979 if let Some(stream) = upgrade_receiver.await {
980 server.handle_connection(Conn::new(WebSocketStream::from_raw_socket(stream, Role::Server, None).await), addr, user_id).await;
981 }
982 });
983
984 Ok(response)
985 }
986 });
987}
988
989fn header_contains_ignore_case<T>(
990 request: &tide::Request<T>,
991 header_name: HeaderName,
992 value: &str,
993) -> bool {
994 request
995 .header(header_name)
996 .map(|h| {
997 h.as_str()
998 .split(',')
999 .any(|s| s.trim().eq_ignore_ascii_case(value.trim()))
1000 })
1001 .unwrap_or(false)
1002}
1003
1004#[cfg(test)]
1005mod tests {
1006 use super::*;
1007 use crate::{
1008 auth,
1009 db::{tests::TestDb, UserId},
1010 github, AppState, Config,
1011 };
1012 use async_std::{sync::RwLockReadGuard, task};
1013 use gpui::TestAppContext;
1014 use postage::mpsc;
1015 use serde_json::json;
1016 use sqlx::types::time::OffsetDateTime;
1017 use std::{path::Path, sync::Arc, time::Duration};
1018 use zed::{
1019 channel::{Channel, ChannelDetails, ChannelList},
1020 editor::{Editor, Insert},
1021 fs::{FakeFs, Fs as _},
1022 language::LanguageRegistry,
1023 rpc::Client,
1024 settings,
1025 user::UserStore,
1026 worktree::Worktree,
1027 };
1028 use zrpc::Peer;
1029
    /// End-to-end worktree sharing between two clients.
    ///
    /// Client A shares a local worktree and client B joins it. From A's side,
    /// the test then observes B appearing as a peer, B's buffer opening, B's
    /// selection set, B's edit, the removal of that selection set, the buffer
    /// closing, and B leaving.
    #[gpui::test]
    async fn test_share_worktree(mut cx_a: TestAppContext, mut cx_b: TestAppContext) {
        let (window_b, _) = cx_b.add_window(|_| EmptyView);
        let settings = cx_b.read(settings::test).1;
        let lang_registry = Arc::new(LanguageRegistry::new());

        // Connect to a server as 2 clients.
        let mut server = TestServer::start().await;
        let (_, client_a) = server.create_client(&mut cx_a, "user_a").await;
        let (_, client_b) = server.create_client(&mut cx_b, "user_b").await;

        cx_a.foreground().forbid_parking();

        // Share a local worktree as client A
        let fs = Arc::new(FakeFs::new());
        fs.insert_tree(
            "/a",
            json!({
                "a.txt": "a-contents",
                "b.txt": "b-contents",
            }),
        )
        .await;
        let worktree_a = Worktree::open_local(
            "/a".as_ref(),
            lang_registry.clone(),
            fs,
            &mut cx_a.to_async(),
        )
        .await
        .unwrap();
        // Wait for the initial scan before sharing (presumably so the shared
        // entries are complete).
        worktree_a
            .read_with(&cx_a, |tree, _| tree.as_local().unwrap().scan_complete())
            .await;
        let (worktree_id, worktree_token) = worktree_a
            .update(&mut cx_a, |tree, cx| {
                tree.as_local_mut().unwrap().share(client_a.clone(), cx)
            })
            .await
            .unwrap();

        // Join that worktree as client B, and see that a guest has joined as client A.
        let worktree_b = Worktree::open_remote(
            client_b.clone(),
            worktree_id,
            worktree_token,
            lang_registry.clone(),
            &mut cx_b.to_async(),
        )
        .await
        .unwrap();
        let replica_id_b = worktree_b.read_with(&cx_b, |tree, _| tree.replica_id());
        worktree_a
            .condition(&cx_a, |tree, _| {
                tree.peers()
                    .values()
                    .any(|replica_id| *replica_id == replica_id_b)
            })
            .await;

        // Open the same file as client B and client A.
        let buffer_b = worktree_b
            .update(&mut cx_b, |worktree, cx| worktree.open_buffer("b.txt", cx))
            .await
            .unwrap();
        buffer_b.read_with(&cx_b, |buf, _| assert_eq!(buf.text(), "b-contents"));
        // B's open request is forwarded to the host, so A now has the buffer open too.
        worktree_a.read_with(&cx_a, |tree, cx| assert!(tree.has_open_buffer("b.txt", cx)));
        let buffer_a = worktree_a
            .update(&mut cx_a, |tree, cx| tree.open_buffer("b.txt", cx))
            .await
            .unwrap();

        // Create a selection set as client B and see that selection set as client A.
        let editor_b = cx_b.add_view(window_b, |cx| Editor::for_buffer(buffer_b, settings, cx));
        buffer_a
            .condition(&cx_a, |buffer, _| buffer.selection_sets().count() == 1)
            .await;

        // Edit the buffer as client B and see that edit as client A.
        editor_b.update(&mut cx_b, |editor, cx| {
            editor.insert(&Insert("ok, ".into()), cx)
        });
        buffer_a
            .condition(&cx_a, |buffer, _| buffer.text() == "ok, b-contents")
            .await;

        // Remove the selection set as client B, see those selections disappear as client A.
        cx_b.update(move |_| drop(editor_b));
        buffer_a
            .condition(&cx_a, |buffer, _| buffer.selection_sets().count() == 0)
            .await;

        // Close the buffer as client A, see that the buffer is closed.
        drop(buffer_a);
        worktree_a
            .condition(&cx_a, |tree, cx| !tree.has_open_buffer("b.txt", cx))
            .await;

        // Dropping the worktree removes client B from client A's peers.
        cx_b.update(move |_| drop(worktree_b));
        worktree_a
            .condition(&cx_a, |tree, _| tree.peers().is_empty())
            .await;
    }
1134
    // Three clients share one worktree hosted by client A. Checks that edits from all three
    // converge to the same text, that a save started by guest B while the host is editing
    // writes the host's latest text to disk and leaves every replica clean, and that a rename
    // and a new file on the host's file system show up in both guests' worktrees.
    #[gpui::test]
    async fn test_propagate_saves_and_fs_changes_in_shared_worktree(
        mut cx_a: TestAppContext,
        mut cx_b: TestAppContext,
        mut cx_c: TestAppContext,
    ) {
        cx_a.foreground().forbid_parking();
        let lang_registry = Arc::new(LanguageRegistry::new());

        // Connect to a server as 3 clients.
        let mut server = TestServer::start().await;
        let (_, client_a) = server.create_client(&mut cx_a, "user_a").await;
        let (_, client_b) = server.create_client(&mut cx_b, "user_b").await;
        let (_, client_c) = server.create_client(&mut cx_c, "user_c").await;

        let fs = Arc::new(FakeFs::new());

        // Share a worktree as client A.
        fs.insert_tree(
            "/a",
            json!({
                "file1": "",
                "file2": ""
            }),
        )
        .await;

        let worktree_a = Worktree::open_local(
            "/a".as_ref(),
            lang_registry.clone(),
            fs.clone(),
            &mut cx_a.to_async(),
        )
        .await
        .unwrap();
        // Wait for the local scan to complete before sharing.
        worktree_a
            .read_with(&cx_a, |tree, _| tree.as_local().unwrap().scan_complete())
            .await;
        // `share` yields the id and access token that the guests below use to join.
        let (worktree_id, worktree_token) = worktree_a
            .update(&mut cx_a, |tree, cx| {
                tree.as_local_mut().unwrap().share(client_a.clone(), cx)
            })
            .await
            .unwrap();

        // Join that worktree as clients B and C.
        let worktree_b = Worktree::open_remote(
            client_b.clone(),
            worktree_id,
            worktree_token.clone(),
            lang_registry.clone(),
            &mut cx_b.to_async(),
        )
        .await
        .unwrap();
        let worktree_c = Worktree::open_remote(
            client_c.clone(),
            worktree_id,
            worktree_token,
            lang_registry.clone(),
            &mut cx_c.to_async(),
        )
        .await
        .unwrap();

        // Open and edit a buffer as both guests B and C. Both insert at offset 0; the host
        // below expects the merged text `i-am-c, i-am-b, `.
        let buffer_b = worktree_b
            .update(&mut cx_b, |tree, cx| tree.open_buffer("file1", cx))
            .await
            .unwrap();
        let buffer_c = worktree_c
            .update(&mut cx_c, |tree, cx| tree.open_buffer("file1", cx))
            .await
            .unwrap();
        buffer_b.update(&mut cx_b, |buf, cx| buf.edit([0..0], "i-am-b, ", cx));
        buffer_c.update(&mut cx_c, |buf, cx| buf.edit([0..0], "i-am-c, ", cx));

        // Open and edit that buffer as the host.
        let buffer_a = worktree_a
            .update(&mut cx_a, |tree, cx| tree.open_buffer("file1", cx))
            .await
            .unwrap();

        // The host's copy must reflect both guests' edits before it appends its own.
        buffer_a
            .condition(&mut cx_a, |buf, _| buf.text() == "i-am-c, i-am-b, ")
            .await;
        buffer_a.update(&mut cx_a, |buf, cx| {
            buf.edit([buf.len()..buf.len()], "i-am-a", cx)
        });

        // Wait for edits to propagate
        buffer_a
            .condition(&mut cx_a, |buf, _| buf.text() == "i-am-c, i-am-b, i-am-a")
            .await;
        buffer_b
            .condition(&mut cx_b, |buf, _| buf.text() == "i-am-c, i-am-b, i-am-a")
            .await;
        buffer_c
            .condition(&mut cx_c, |buf, _| buf.text() == "i-am-c, i-am-b, i-am-a")
            .await;

        // Edit the buffer as the host and concurrently save as guest B. B's save is started
        // before the host edit and only awaited after it, so the two overlap.
        let save_b = buffer_b.update(&mut cx_b, |buf, cx| buf.save(cx).unwrap());
        buffer_a.update(&mut cx_a, |buf, cx| buf.edit([0..0], "hi-a, ", cx));
        save_b.await.unwrap();
        // The saved file includes the host's `hi-a, ` edit even though B started saving first.
        assert_eq!(
            fs.load("/a/file1".as_ref()).await.unwrap(),
            "hi-a, i-am-c, i-am-b, i-am-a"
        );
        // Host and B are clean as soon as the save completes; C is waited on with `condition`
        // instead of being asserted immediately.
        buffer_a.read_with(&cx_a, |buf, _| assert!(!buf.is_dirty()));
        buffer_b.read_with(&cx_b, |buf, _| assert!(!buf.is_dirty()));
        buffer_c.condition(&cx_c, |buf, _| !buf.is_dirty()).await;

        // Make changes on host's file system, see those changes on the guests.
        fs.rename("/a/file2".as_ref(), "/a/file3".as_ref())
            .await
            .unwrap();
        fs.insert_file(Path::new("/a/file4"), "4".into())
            .await
            .unwrap();

        // file2 was renamed to file3 and file4 was added, so each guest should see 3 files.
        worktree_b
            .condition(&cx_b, |tree, _| tree.file_count() == 3)
            .await;
        worktree_c
            .condition(&cx_c, |tree, _| tree.file_count() == 3)
            .await;
        worktree_b.read_with(&cx_b, |tree, _| {
            assert_eq!(
                tree.paths()
                    .map(|p| p.to_string_lossy())
                    .collect::<Vec<_>>(),
                &["file1", "file3", "file4"]
            )
        });
        worktree_c.read_with(&cx_c, |tree, _| {
            assert_eq!(
                tree.paths()
                    .map(|p| p.to_string_lossy())
                    .collect::<Vec<_>>(),
                &["file1", "file3", "file4"]
            )
        });
    }
1279
1280 #[gpui::test]
1281 async fn test_buffer_conflict_after_save(mut cx_a: TestAppContext, mut cx_b: TestAppContext) {
1282 cx_a.foreground().forbid_parking();
1283 let lang_registry = Arc::new(LanguageRegistry::new());
1284
1285 // Connect to a server as 2 clients.
1286 let mut server = TestServer::start().await;
1287 let (_, client_a) = server.create_client(&mut cx_a, "user_a").await;
1288 let (_, client_b) = server.create_client(&mut cx_b, "user_b").await;
1289
1290 // Share a local worktree as client A
1291 let fs = Arc::new(FakeFs::new());
1292 fs.save(Path::new("/a.txt"), &"a-contents".into())
1293 .await
1294 .unwrap();
1295 let worktree_a = Worktree::open_local(
1296 "/".as_ref(),
1297 lang_registry.clone(),
1298 fs,
1299 &mut cx_a.to_async(),
1300 )
1301 .await
1302 .unwrap();
1303 worktree_a
1304 .read_with(&cx_a, |tree, _| tree.as_local().unwrap().scan_complete())
1305 .await;
1306 let (worktree_id, worktree_token) = worktree_a
1307 .update(&mut cx_a, |tree, cx| {
1308 tree.as_local_mut().unwrap().share(client_a.clone(), cx)
1309 })
1310 .await
1311 .unwrap();
1312
1313 // Join that worktree as client B, and see that a guest has joined as client A.
1314 let worktree_b = Worktree::open_remote(
1315 client_b.clone(),
1316 worktree_id,
1317 worktree_token,
1318 lang_registry.clone(),
1319 &mut cx_b.to_async(),
1320 )
1321 .await
1322 .unwrap();
1323
1324 let buffer_b = worktree_b
1325 .update(&mut cx_b, |worktree, cx| worktree.open_buffer("a.txt", cx))
1326 .await
1327 .unwrap();
1328 let mtime = buffer_b.read_with(&cx_b, |buf, _| buf.file().unwrap().mtime);
1329
1330 buffer_b.update(&mut cx_b, |buf, cx| buf.edit([0..0], "world ", cx));
1331 buffer_b.read_with(&cx_b, |buf, _| {
1332 assert!(buf.is_dirty());
1333 assert!(!buf.has_conflict());
1334 });
1335
1336 buffer_b
1337 .update(&mut cx_b, |buf, cx| buf.save(cx))
1338 .unwrap()
1339 .await
1340 .unwrap();
1341 worktree_b
1342 .condition(&cx_b, |_, cx| {
1343 buffer_b.read(cx).file().unwrap().mtime != mtime
1344 })
1345 .await;
1346 buffer_b.read_with(&cx_b, |buf, _| {
1347 assert!(!buf.is_dirty());
1348 assert!(!buf.has_conflict());
1349 });
1350
1351 buffer_b.update(&mut cx_b, |buf, cx| buf.edit([0..0], "hello ", cx));
1352 buffer_b.read_with(&cx_b, |buf, _| {
1353 assert!(buf.is_dirty());
1354 assert!(!buf.has_conflict());
1355 });
1356 }
1357
    // The host edits a buffer while a guest's request to open that same buffer is in flight.
    // The guest's buffer must end up with the host's post-edit text.
    #[gpui::test]
    async fn test_editing_while_guest_opens_buffer(
        mut cx_a: TestAppContext,
        mut cx_b: TestAppContext,
    ) {
        cx_a.foreground().forbid_parking();
        let lang_registry = Arc::new(LanguageRegistry::new());

        // Connect to a server as 2 clients.
        let mut server = TestServer::start().await;
        let (_, client_a) = server.create_client(&mut cx_a, "user_a").await;
        let (_, client_b) = server.create_client(&mut cx_b, "user_b").await;

        // Share a local worktree as client A
        let fs = Arc::new(FakeFs::new());
        fs.save(Path::new("/a.txt"), &"a-contents".into())
            .await
            .unwrap();
        let worktree_a = Worktree::open_local(
            "/".as_ref(),
            lang_registry.clone(),
            fs,
            &mut cx_a.to_async(),
        )
        .await
        .unwrap();
        worktree_a
            .read_with(&cx_a, |tree, _| tree.as_local().unwrap().scan_complete())
            .await;
        let (worktree_id, worktree_token) = worktree_a
            .update(&mut cx_a, |tree, cx| {
                tree.as_local_mut().unwrap().share(client_a.clone(), cx)
            })
            .await
            .unwrap();

        // Join that worktree as client B.
        let worktree_b = Worktree::open_remote(
            client_b.clone(),
            worktree_id,
            worktree_token,
            lang_registry.clone(),
            &mut cx_b.to_async(),
        )
        .await
        .unwrap();

        let buffer_a = worktree_a
            .update(&mut cx_a, |tree, cx| tree.open_buffer("a.txt", cx))
            .await
            .unwrap();
        // Start B's open on the background executor without awaiting it, so the request is
        // still outstanding when the host edits below.
        let buffer_b = cx_b
            .background()
            .spawn(worktree_b.update(&mut cx_b, |worktree, cx| worktree.open_buffer("a.txt", cx)));

        // NOTE(review): presumably yields so the spawned open gets a chance to start before the
        // host edit. The ordering of these statements is what this test exercises.
        task::yield_now().await;
        buffer_a.update(&mut cx_a, |buf, cx| buf.edit([0..0], "z", cx));

        // B's buffer must eventually match the host's text after the edit.
        let text = buffer_a.read_with(&cx_a, |buf, _| buf.text());
        let buffer_b = buffer_b.await.unwrap();
        buffer_b.condition(&cx_b, |buf, _| buf.text() == text).await;
    }
1420
1421 #[gpui::test]
1422 async fn test_peer_disconnection(mut cx_a: TestAppContext, cx_b: TestAppContext) {
1423 cx_a.foreground().forbid_parking();
1424 let lang_registry = Arc::new(LanguageRegistry::new());
1425
1426 // Connect to a server as 2 clients.
1427 let mut server = TestServer::start().await;
1428 let (_, client_a) = server.create_client(&mut cx_a, "user_a").await;
1429 let (_, client_b) = server.create_client(&mut cx_a, "user_b").await;
1430
1431 // Share a local worktree as client A
1432 let fs = Arc::new(FakeFs::new());
1433 fs.insert_tree(
1434 "/a",
1435 json!({
1436 "a.txt": "a-contents",
1437 "b.txt": "b-contents",
1438 }),
1439 )
1440 .await;
1441 let worktree_a = Worktree::open_local(
1442 "/a".as_ref(),
1443 lang_registry.clone(),
1444 fs,
1445 &mut cx_a.to_async(),
1446 )
1447 .await
1448 .unwrap();
1449 worktree_a
1450 .read_with(&cx_a, |tree, _| tree.as_local().unwrap().scan_complete())
1451 .await;
1452 let (worktree_id, worktree_token) = worktree_a
1453 .update(&mut cx_a, |tree, cx| {
1454 tree.as_local_mut().unwrap().share(client_a.clone(), cx)
1455 })
1456 .await
1457 .unwrap();
1458
1459 // Join that worktree as client B, and see that a guest has joined as client A.
1460 let _worktree_b = Worktree::open_remote(
1461 client_b.clone(),
1462 worktree_id,
1463 worktree_token,
1464 lang_registry.clone(),
1465 &mut cx_b.to_async(),
1466 )
1467 .await
1468 .unwrap();
1469 worktree_a
1470 .condition(&cx_a, |tree, _| tree.peers().len() == 1)
1471 .await;
1472
1473 // Drop client B's connection and ensure client A observes client B leaving the worktree.
1474 client_b.disconnect(&cx_b.to_async()).await.unwrap();
1475 worktree_a
1476 .condition(&cx_a, |tree, _| tree.peers().len() == 0)
1477 .await;
1478 }
1479
    // Two members of an org share a channel. Checks that each client sees the channel and the
    // message seeded directly in the database, that messages sent by A are listed as pending and
    // then delivered to B in order, and that the server tracks which connections are subscribed
    // to the channel as the clients drop their channel handles.
    #[gpui::test]
    async fn test_basic_chat(mut cx_a: TestAppContext, mut cx_b: TestAppContext) {
        cx_a.foreground().forbid_parking();

        // Connect to a server as 2 clients.
        let mut server = TestServer::start().await;
        let (user_id_a, client_a) = server.create_client(&mut cx_a, "user_a").await;
        let (user_id_b, client_b) = server.create_client(&mut cx_b, "user_b").await;

        // Create an org that includes these 2 users.
        let db = &server.app_state.db;
        let org_id = db.create_org("Test Org", "test-org").await.unwrap();
        db.add_org_member(org_id, user_id_a, false).await.unwrap();
        db.add_org_member(org_id, user_id_b, false).await.unwrap();

        // Create a channel that includes all the users, seeded with one message from user B.
        let channel_id = db.create_org_channel(org_id, "test-channel").await.unwrap();
        db.add_channel_member(channel_id, user_id_a, false)
            .await
            .unwrap();
        db.add_channel_member(channel_id, user_id_b, false)
            .await
            .unwrap();
        db.create_channel_message(
            channel_id,
            user_id_b,
            "hello A, it's B.",
            OffsetDateTime::now_utc(),
        )
        .await
        .unwrap();

        // Client A waits for its channel list to load, then opens the test channel.
        let user_store_a = Arc::new(UserStore::new(client_a.clone()));
        let channels_a = cx_a.add_model(|cx| ChannelList::new(user_store_a, client_a, cx));
        channels_a
            .condition(&mut cx_a, |list, _| list.available_channels().is_some())
            .await;
        channels_a.read_with(&cx_a, |list, _| {
            assert_eq!(
                list.available_channels().unwrap(),
                &[ChannelDetails {
                    id: channel_id.to_proto(),
                    name: "test-channel".to_string()
                }]
            )
        });
        let channel_a = channels_a.update(&mut cx_a, |this, cx| {
            this.get_channel(channel_id.to_proto(), cx).unwrap()
        });
        // The channel has no messages immediately after opening; wait for the seeded one.
        channel_a.read_with(&cx_a, |channel, _| assert!(channel.messages().is_empty()));
        channel_a
            .condition(&cx_a, |channel, _| {
                channel_messages(channel)
                    == [("user_b".to_string(), "hello A, it's B.".to_string())]
            })
            .await;

        // Client B does the same.
        let user_store_b = Arc::new(UserStore::new(client_b.clone()));
        let channels_b = cx_b.add_model(|cx| ChannelList::new(user_store_b, client_b, cx));
        channels_b
            .condition(&mut cx_b, |list, _| list.available_channels().is_some())
            .await;
        channels_b.read_with(&cx_b, |list, _| {
            assert_eq!(
                list.available_channels().unwrap(),
                &[ChannelDetails {
                    id: channel_id.to_proto(),
                    name: "test-channel".to_string()
                }]
            )
        });

        let channel_b = channels_b.update(&mut cx_b, |this, cx| {
            this.get_channel(channel_id.to_proto(), cx).unwrap()
        });
        channel_b.read_with(&cx_b, |channel, _| assert!(channel.messages().is_empty()));
        channel_b
            .condition(&cx_b, |channel, _| {
                channel_messages(channel)
                    == [("user_b".to_string(), "hello A, it's B.".to_string())]
            })
            .await;

        // A sends two messages. The first send's task is detached and the second's is returned
        // and awaited. Both are listed as pending right after being sent.
        channel_a
            .update(&mut cx_a, |channel, cx| {
                channel
                    .send_message("oh, hi B.".to_string(), cx)
                    .unwrap()
                    .detach();
                let task = channel.send_message("sup".to_string(), cx).unwrap();
                assert_eq!(
                    channel
                        .pending_messages()
                        .iter()
                        .map(|m| &m.body)
                        .collect::<Vec<_>>(),
                    &["oh, hi B.", "sup"]
                );
                task
            })
            .await
            .unwrap();

        // A's pending list drains, and B receives all three messages in order.
        channel_a
            .condition(&cx_a, |channel, _| channel.pending_messages().is_empty())
            .await;
        channel_b
            .condition(&cx_b, |channel, _| {
                channel_messages(channel)
                    == [
                        ("user_b".to_string(), "hello A, it's B.".to_string()),
                        ("user_a".to_string(), "oh, hi B.".to_string()),
                        ("user_a".to_string(), "sup".to_string()),
                    ]
            })
            .await;

        // Server side, both clients' connections are subscribed to the channel. Dropping a
        // channel handle unsubscribes that client's connection. Once none remain, the channel's
        // entry is removed from the server state.
        assert_eq!(
            server.state().await.channels[&channel_id]
                .connection_ids
                .len(),
            2
        );
        cx_b.update(|_| drop(channel_b));
        server
            .condition(|state| state.channels[&channel_id].connection_ids.len() == 1)
            .await;

        cx_a.update(|_| drop(channel_a));
        server
            .condition(|state| !state.channels.contains_key(&channel_id))
            .await;

        // Projects a channel's messages to (sender GitHub login, body) pairs.
        fn channel_messages(channel: &Channel) -> Vec<(String, String)> {
            channel
                .messages()
                .cursor::<(), ()>()
                .map(|m| (m.sender.github_login.clone(), m.body.clone()))
                .collect()
        }
    }
1621
1622 #[gpui::test]
1623 async fn test_chat_message_validation(mut cx_a: TestAppContext) {
1624 cx_a.foreground().forbid_parking();
1625
1626 let mut server = TestServer::start().await;
1627 let (user_id_a, client_a) = server.create_client(&mut cx_a, "user_a").await;
1628
1629 let db = &server.app_state.db;
1630 let org_id = db.create_org("Test Org", "test-org").await.unwrap();
1631 let channel_id = db.create_org_channel(org_id, "test-channel").await.unwrap();
1632 db.add_org_member(org_id, user_id_a, false).await.unwrap();
1633 db.add_channel_member(channel_id, user_id_a, false)
1634 .await
1635 .unwrap();
1636
1637 let user_store_a = Arc::new(UserStore::new(client_a.clone()));
1638 let channels_a = cx_a.add_model(|cx| ChannelList::new(user_store_a, client_a, cx));
1639 channels_a
1640 .condition(&mut cx_a, |list, _| list.available_channels().is_some())
1641 .await;
1642 let channel_a = channels_a.update(&mut cx_a, |this, cx| {
1643 this.get_channel(channel_id.to_proto(), cx).unwrap()
1644 });
1645
1646 // Messages aren't allowed to be too long.
1647 channel_a
1648 .update(&mut cx_a, |channel, cx| {
1649 let long_body = "this is long.\n".repeat(1024);
1650 channel.send_message(long_body, cx).unwrap()
1651 })
1652 .await
1653 .unwrap_err();
1654
1655 // Messages aren't allowed to be blank.
1656 channel_a.update(&mut cx_a, |channel, cx| {
1657 channel.send_message(String::new(), cx).unwrap_err()
1658 });
1659
1660 // Leading and trailing whitespace are trimmed.
1661 channel_a
1662 .update(&mut cx_a, |channel, cx| {
1663 channel
1664 .send_message("\n surrounded by whitespace \n".to_string(), cx)
1665 .unwrap()
1666 })
1667 .await
1668 .unwrap();
1669 assert_eq!(
1670 db.get_channel_messages(channel_id, 10, None)
1671 .await
1672 .unwrap()
1673 .iter()
1674 .map(|m| &m.body)
1675 .collect::<Vec<_>>(),
1676 &["surrounded by whitespace"]
1677 );
1678 }
1679
    // In-process server harness used by the tests in this module.
    //
    // Field order matters. After `Drop::drop` runs, fields are dropped in declaration order,
    // so `_test_db` is dropped last. That is after `app_state` and `server`, both of which hold
    // a handle to its database. Do not reorder these fields.
    struct TestServer {
        // Shared with `server` (passed to `Server::new`); reset when the harness is dropped.
        peer: Arc<Peer>,
        // State handed to the server; tests also use `app_state.db` directly to seed data.
        app_state: Arc<AppState>,
        server: Arc<Server>,
        // Receiving end of the channel whose sender was given to `Server::new`;
        // `condition` awaits it between predicate re-checks.
        notifications: mpsc::Receiver<()>,
        // Held only to keep the test database alive as long as the harness.
        // NOTE(review): presumably `TestDb` tears the database down on drop — confirm.
        _test_db: TestDb,
    }
1687
1688 impl TestServer {
1689 async fn start() -> Self {
1690 let test_db = TestDb::new();
1691 let app_state = Self::build_app_state(&test_db).await;
1692 let peer = Peer::new();
1693 let notifications = mpsc::channel(128);
1694 let server = Server::new(app_state.clone(), peer.clone(), Some(notifications.0));
1695 Self {
1696 peer,
1697 app_state,
1698 server,
1699 notifications: notifications.1,
1700 _test_db: test_db,
1701 }
1702 }
1703
1704 async fn create_client(
1705 &mut self,
1706 cx: &mut TestAppContext,
1707 name: &str,
1708 ) -> (UserId, Arc<Client>) {
1709 let client_user_id = self.app_state.db.create_user(name, false).await.unwrap();
1710 let client_name = name.to_string();
1711 let mut client = Client::new();
1712 let server = self.server.clone();
1713 Arc::get_mut(&mut client)
1714 .unwrap()
1715 .set_login_and_connect_callbacks(
1716 move |cx| {
1717 cx.spawn(|_| async move {
1718 let access_token = "the-token".to_string();
1719 Ok((client_user_id.0 as u64, access_token))
1720 })
1721 },
1722 {
1723 move |user_id, access_token, cx| {
1724 assert_eq!(user_id, client_user_id.0 as u64);
1725 assert_eq!(access_token, "the-token");
1726
1727 let server = server.clone();
1728 let client_name = client_name.clone();
1729 cx.spawn(move |cx| async move {
1730 let (client_conn, server_conn) = Conn::in_memory();
1731 cx.background()
1732 .spawn(server.handle_connection(
1733 server_conn,
1734 client_name,
1735 client_user_id,
1736 ))
1737 .detach();
1738 Ok(client_conn)
1739 })
1740 }
1741 },
1742 );
1743
1744 client
1745 .authenticate_and_connect(&cx.to_async())
1746 .await
1747 .unwrap();
1748 (client_user_id, client)
1749 }
1750
1751 async fn build_app_state(test_db: &TestDb) -> Arc<AppState> {
1752 let mut config = Config::default();
1753 config.session_secret = "a".repeat(32);
1754 config.database_url = test_db.url.clone();
1755 let github_client = github::AppClient::test();
1756 Arc::new(AppState {
1757 db: test_db.db().clone(),
1758 handlebars: Default::default(),
1759 auth_client: auth::build_client("", ""),
1760 repo_client: github::RepoClient::test(&github_client),
1761 github_client,
1762 config,
1763 })
1764 }
1765
1766 async fn state<'a>(&'a self) -> RwLockReadGuard<'a, ServerState> {
1767 self.server.state.read().await
1768 }
1769
1770 async fn condition<F>(&mut self, mut predicate: F)
1771 where
1772 F: FnMut(&ServerState) -> bool,
1773 {
1774 async_std::future::timeout(Duration::from_millis(500), async {
1775 while !(predicate)(&*self.server.state.read().await) {
1776 self.notifications.recv().await;
1777 }
1778 })
1779 .await
1780 .expect("condition timed out");
1781 }
1782 }
1783
1784 impl Drop for TestServer {
1785 fn drop(&mut self) {
1786 task::block_on(self.peer.reset());
1787 }
1788 }
1789
    // A stateless placeholder view whose `render` returns an empty element.
    struct EmptyView;
1791
    // `EmptyView` declares the unit type as its event type.
    impl gpui::Entity for EmptyView {
        type Event = ();
    }
1795
1796 impl gpui::View for EmptyView {
1797 fn ui_name() -> &'static str {
1798 "empty view"
1799 }
1800
1801 fn render(&mut self, _: &mut gpui::RenderContext<Self>) -> gpui::ElementBox {
1802 gpui::Element::boxed(gpui::elements::Empty)
1803 }
1804 }
1805}