rpc.rs

   1use super::{
   2    auth,
   3    db::{ChannelId, MessageId, UserId},
   4    AppState,
   5};
   6use anyhow::anyhow;
   7use async_std::{sync::RwLock, task};
   8use async_tungstenite::{tungstenite::protocol::Role, WebSocketStream};
   9use futures::{future::BoxFuture, FutureExt};
  10use postage::{mpsc, prelude::Sink as _, prelude::Stream as _};
  11use sha1::{Digest as _, Sha1};
  12use std::{
  13    any::TypeId,
  14    collections::{hash_map, HashMap, HashSet},
  15    future::Future,
  16    mem,
  17    sync::Arc,
  18    time::Instant,
  19};
  20use surf::StatusCode;
  21use tide::log;
  22use tide::{
  23    http::headers::{HeaderName, CONNECTION, UPGRADE},
  24    Request, Response,
  25};
  26use time::OffsetDateTime;
  27use zrpc::{
  28    proto::{self, AnyTypedEnvelope, EnvelopedMessage},
  29    Connection, ConnectionId, Peer, TypedEnvelope,
  30};
  31
  /// Identifier of a participant's replica within a shared worktree. Guests are tracked by
  /// this id in `WorktreeShare`; the host is reported as replica 0 in `join_worktree`.
  32type ReplicaId = u16;
  33
  /// Type-erased message handler as stored in `Server::handlers`: it receives the server and
  /// a boxed envelope of any message type, and returns a boxed future that resolves to
  /// `tide::Result<()>`.
  34type MessageHandler = Box<
  35    dyn Send
  36        + Sync
  37        + Fn(Arc<Server>, Box<dyn AnyTypedEnvelope>) -> BoxFuture<'static, tide::Result<()>>,
  38>;
  39
  /// RPC server: owns the peer used to talk to clients, the shared in-memory state, and the
  /// table of per-message-type handlers.
  40pub struct Server {
        /// Peer through which connections are added and messages are sent, forwarded and
        /// answered.
  41    peer: Arc<Peer>,
        /// In-memory bookkeeping (connections, worktrees, channels) behind an async `RwLock`.
  42    state: RwLock<ServerState>,
        /// Application state; the handlers reach the database through `app_state.db`.
  43    app_state: Arc<AppState>,
        /// Handlers keyed by the `TypeId` of the message type they accept (see `add_handler`).
  44    handlers: HashMap<TypeId, MessageHandler>,
        /// When present, receives a `()` after each handled message (see `handle_connection`).
  45    notifications: Option<mpsc::Sender<()>>,
  46}
  47
  /// Mutable server bookkeeping guarded by `Server::state`.
  48#[derive(Default)]
  49struct ServerState {
        /// Per-connection state, keyed by connection id.
  50    connections: HashMap<ConnectionId, ConnectionState>,
        /// The connection ids currently registered for each user.
  51    connections_by_user_id: HashMap<UserId, HashSet<ConnectionId>>,
        /// Open worktrees, keyed by worktree id.
  52    pub worktrees: HashMap<u64, Worktree>,
        /// Worktree ids that `update_collaborators` reports to each user.
  53    visible_worktrees_by_user_id: HashMap<UserId, HashSet<u64>>,
        /// Channels that have been joined, keyed by channel id.
  54    channels: HashMap<ChannelId, Channel>,
        /// NOTE(review): presumably the id handed out to the next added worktree; it is not
        /// read anywhere in the visible part of the file, so confirm against `add_worktree`.
  55    next_worktree_id: u64,
  56}
  57
  /// What the server tracks for a single client connection.
  58struct ConnectionState {
        /// The user this connection was registered for in `add_connection`.
  59    user_id: UserId,
        /// Ids of worktrees associated with this connection; walked when the connection is
        /// removed.
  60    worktrees: HashSet<u64>,
        /// Channels whose membership sets are cleaned up when the connection is removed.
  61    channels: HashSet<ChannelId>,
  62}
  63
  /// A worktree opened by a host connection.
  64struct Worktree {
        /// Connection that opened the worktree.
  65    host_connection_id: ConnectionId,
        /// Users that are sent collaborator updates when this worktree changes.
  66    collaborator_user_ids: Vec<UserId>,
        /// Root name supplied by the host in `OpenWorktree`.
  67    root_name: String,
        /// Sharing state: set by `share_worktree`, cleared by `unshare_worktree`.
  68    share: Option<WorktreeShare>,
  69}
  70
  /// State that exists only while a worktree is shared.
  71struct WorktreeShare {
        /// Replica id of each guest connection.
  72    guest_connection_ids: HashMap<ConnectionId, ReplicaId>,
        /// Replica ids currently taken by guests; an id is removed when its guest leaves.
  73    active_replica_ids: HashSet<ReplicaId>,
        /// The worktree's entries, keyed by entry id.
  74    entries: HashMap<u64, proto::Entry>,
  75}
  76
  /// Live membership of a chat channel.
  77#[derive(Default)]
  78struct Channel {
        /// Connections currently joined to the channel; a connection is removed from this set
        /// when it is dropped.
  79    connection_ids: HashSet<ConnectionId>,
  80}
  81
  /// Number of messages requested from the database per page; `join_channel` reports `done`
  /// when fewer than this many come back.
  82const MESSAGE_COUNT_PER_PAGE: usize = 100;
  /// Maximum length of a trimmed channel message body, measured with `String::len` (bytes).
  83const MAX_MESSAGE_LEN: usize = 1024;
  84
  85impl Server {
  86    pub fn new(
  87        app_state: Arc<AppState>,
  88        peer: Arc<Peer>,
  89        notifications: Option<mpsc::Sender<()>>,
  90    ) -> Arc<Self> {
  91        let mut server = Self {
  92            peer,
  93            app_state,
  94            state: Default::default(),
  95            handlers: Default::default(),
  96            notifications,
  97        };
  98
  99        server
 100            .add_handler(Server::ping)
 101            .add_handler(Server::open_worktree)
 102            .add_handler(Server::close_worktree)
 103            .add_handler(Server::share_worktree)
 104            .add_handler(Server::unshare_worktree)
 105            .add_handler(Server::join_worktree)
 106            .add_handler(Server::update_worktree)
 107            .add_handler(Server::open_buffer)
 108            .add_handler(Server::close_buffer)
 109            .add_handler(Server::update_buffer)
 110            .add_handler(Server::buffer_saved)
 111            .add_handler(Server::save_buffer)
 112            .add_handler(Server::get_channels)
 113            .add_handler(Server::get_users)
 114            .add_handler(Server::join_channel)
 115            .add_handler(Server::leave_channel)
 116            .add_handler(Server::send_channel_message)
 117            .add_handler(Server::get_channel_messages);
 118
 119        Arc::new(server)
 120    }
 121
 122    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 123    where
 124        F: 'static + Send + Sync + Fn(Arc<Self>, TypedEnvelope<M>) -> Fut,
 125        Fut: 'static + Send + Future<Output = tide::Result<()>>,
 126        M: EnvelopedMessage,
 127    {
 128        let prev_handler = self.handlers.insert(
 129            TypeId::of::<M>(),
 130            Box::new(move |server, envelope| {
 131                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 132                (handler)(server, *envelope).boxed()
 133            }),
 134        );
 135        if prev_handler.is_some() {
 136            panic!("registered a handler for the same message twice");
 137        }
 138        self
 139    }
 140
 141    pub fn handle_connection(
 142        self: &Arc<Self>,
 143        connection: Connection,
 144        addr: String,
 145        user_id: UserId,
 146    ) -> impl Future<Output = ()> {
 147        let this = self.clone();
 148        async move {
 149            let (connection_id, handle_io, mut incoming_rx) =
 150                this.peer.add_connection(connection).await;
 151            this.add_connection(connection_id, user_id).await;
 152
 153            let handle_io = handle_io.fuse();
 154            futures::pin_mut!(handle_io);
 155            loop {
 156                let next_message = incoming_rx.recv().fuse();
 157                futures::pin_mut!(next_message);
 158                futures::select_biased! {
 159                    message = next_message => {
 160                        if let Some(message) = message {
 161                            let start_time = Instant::now();
 162                            log::info!("RPC message received: {}", message.payload_type_name());
 163                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
 164                                if let Err(err) = (handler)(this.clone(), message).await {
 165                                    log::error!("error handling message: {:?}", err);
 166                                } else {
 167                                    log::info!("RPC message handled. duration:{:?}", start_time.elapsed());
 168                                }
 169
 170                                if let Some(mut notifications) = this.notifications.clone() {
 171                                    let _ = notifications.send(()).await;
 172                                }
 173                            } else {
 174                                log::warn!("unhandled message: {}", message.payload_type_name());
 175                            }
 176                        } else {
 177                            log::info!("rpc connection closed {:?}", addr);
 178                            break;
 179                        }
 180                    }
 181                    handle_io = handle_io => {
 182                        if let Err(err) = handle_io {
 183                            log::error!("error handling rpc connection {:?} - {:?}", addr, err);
 184                        }
 185                        break;
 186                    }
 187                }
 188            }
 189
 190            if let Err(err) = this.sign_out(connection_id).await {
 191                log::error!("error signing out connection {:?} - {:?}", addr, err);
 192            }
 193        }
 194    }
 195
 196    async fn sign_out(self: &Arc<Self>, connection_id: zrpc::ConnectionId) -> tide::Result<()> {
 197        self.peer.disconnect(connection_id).await;
 198        let worktree_ids = self.remove_connection(connection_id).await;
 199        for worktree_id in worktree_ids {
 200            let state = self.state.read().await;
 201            if let Some(worktree) = state.worktrees.get(&worktree_id) {
 202                broadcast(connection_id, worktree.connection_ids(), |conn_id| {
 203                    self.peer.send(
 204                        conn_id,
 205                        proto::RemovePeer {
 206                            worktree_id,
 207                            peer_id: connection_id.0,
 208                        },
 209                    )
 210                })
 211                .await?;
 212            }
 213        }
 214        Ok(())
 215    }
 216
 217    // Add a new connection associated with a given user.
 218    async fn add_connection(&self, connection_id: ConnectionId, user_id: UserId) {
 219        let mut state = self.state.write().await;
 220        state.connections.insert(
 221            connection_id,
 222            ConnectionState {
 223                user_id,
 224                worktrees: Default::default(),
 225                channels: Default::default(),
 226            },
 227        );
 228        state
 229            .connections_by_user_id
 230            .entry(user_id)
 231            .or_default()
 232            .insert(connection_id);
 233    }
 234
 235    // Remove the given connection and its association with any worktrees.
 236    async fn remove_connection(&self, connection_id: ConnectionId) -> Vec<u64> {
 237        let mut worktree_ids = Vec::new();
 238        let mut state = self.state.write().await;
 239        if let Some(connection) = state.connections.remove(&connection_id) {
 240            for channel_id in connection.channels {
 241                if let Some(channel) = state.channels.get_mut(&channel_id) {
 242                    channel.connection_ids.remove(&connection_id);
 243                }
 244            }
 245            for worktree_id in connection.worktrees {
 246                if let Some(worktree) = state.worktrees.get_mut(&worktree_id) {
 247                    if worktree.host_connection_id == connection_id {
 248                        worktree_ids.push(worktree_id);
 249                    } else if let Some(share_state) = worktree.share.as_mut() {
 250                        if let Some(replica_id) =
 251                            share_state.guest_connection_ids.remove(&connection_id)
 252                        {
 253                            share_state.active_replica_ids.remove(&replica_id);
 254                            worktree_ids.push(worktree_id);
 255                        }
 256                    }
 257                }
 258            }
 259
 260            let user_connections = state
 261                .connections_by_user_id
 262                .get_mut(&connection.user_id)
 263                .unwrap();
 264            user_connections.remove(&connection_id);
 265            if user_connections.is_empty() {
 266                state.connections_by_user_id.remove(&connection.user_id);
 267            }
 268        }
 269        worktree_ids
 270    }
 271
 272    async fn ping(self: Arc<Server>, request: TypedEnvelope<proto::Ping>) -> tide::Result<()> {
 273        self.peer.respond(request.receipt(), proto::Ack {}).await?;
 274        Ok(())
 275    }
 276
 277    async fn open_worktree(
 278        self: Arc<Server>,
 279        request: TypedEnvelope<proto::OpenWorktree>,
 280    ) -> tide::Result<()> {
 281        let receipt = request.receipt();
 282        let user_id = self
 283            .state
 284            .read()
 285            .await
 286            .user_id_for_connection(request.sender_id)?;
 287
 288        let mut collaborator_user_ids = Vec::new();
 289        for github_login in request.payload.collaborator_logins {
 290            match self.app_state.db.create_user(&github_login, false).await {
 291                Ok(collaborator_user_id) => {
 292                    if collaborator_user_id != user_id {
 293                        collaborator_user_ids.push(collaborator_user_id);
 294                    }
 295                }
 296                Err(err) => {
 297                    let message = err.to_string();
 298                    self.peer
 299                        .respond_with_error(receipt, proto::Error { message })
 300                        .await?;
 301                    return Ok(());
 302                }
 303            }
 304        }
 305
 306        let mut state = self.state.write().await;
 307        let worktree_id = state.add_worktree(Worktree {
 308            host_connection_id: request.sender_id,
 309            collaborator_user_ids: collaborator_user_ids.clone(),
 310            root_name: request.payload.root_name,
 311            share: None,
 312        });
 313
 314        self.peer
 315            .respond(receipt, proto::OpenWorktreeResponse { worktree_id })
 316            .await?;
 317        self.update_collaborators(&collaborator_user_ids).await?;
 318
 319        Ok(())
 320    }
 321
 322    async fn share_worktree(
 323        self: Arc<Server>,
 324        mut request: TypedEnvelope<proto::ShareWorktree>,
 325    ) -> tide::Result<()> {
 326        let worktree = request
 327            .payload
 328            .worktree
 329            .as_mut()
 330            .ok_or_else(|| anyhow!("missing worktree"))?;
 331        let entries = mem::take(&mut worktree.entries)
 332            .into_iter()
 333            .map(|entry| (entry.id, entry))
 334            .collect();
 335        let mut state = self.state.write().await;
 336        if let Some(worktree) = state.worktrees.get_mut(&worktree.id) {
 337            worktree.share = Some(WorktreeShare {
 338                guest_connection_ids: Default::default(),
 339                active_replica_ids: Default::default(),
 340                entries,
 341            });
 342            self.peer
 343                .respond(request.receipt(), proto::ShareWorktreeResponse {})
 344                .await?;
 345
 346            let collaborator_user_ids = worktree.collaborator_user_ids.clone();
 347            drop(state);
 348            self.update_collaborators(&collaborator_user_ids).await?;
 349        } else {
 350            self.peer
 351                .respond_with_error(
 352                    request.receipt(),
 353                    proto::Error {
 354                        message: "no such worktree".to_string(),
 355                    },
 356                )
 357                .await?;
 358        }
 359        Ok(())
 360    }
 361
 362    async fn unshare_worktree(
 363        self: Arc<Server>,
 364        request: TypedEnvelope<proto::UnshareWorktree>,
 365    ) -> tide::Result<()> {
 366        let worktree_id = request.payload.worktree_id;
 367
 368        let connection_ids;
 369        let collaborator_user_ids;
 370        {
 371            let mut state = self.state.write().await;
 372            let worktree = state.write_worktree(worktree_id, request.sender_id)?;
 373            if worktree.host_connection_id != request.sender_id {
 374                return Err(anyhow!("no such worktree"))?;
 375            }
 376
 377            connection_ids = worktree.connection_ids();
 378            collaborator_user_ids = worktree.collaborator_user_ids.clone();
 379            worktree.share.take();
 380            for connection_id in &connection_ids {
 381                if let Some(connection) = state.connections.get_mut(connection_id) {
 382                    connection.worktrees.remove(&worktree_id);
 383                }
 384            }
 385        }
 386
 387        broadcast(request.sender_id, connection_ids, |conn_id| {
 388            self.peer
 389                .send(conn_id, proto::UnshareWorktree { worktree_id })
 390        })
 391        .await?;
 392        self.update_collaborators(&collaborator_user_ids).await?;
 393
 394        Ok(())
 395    }
 396
 397    async fn join_worktree(
 398        self: Arc<Server>,
 399        request: TypedEnvelope<proto::JoinWorktree>,
 400    ) -> tide::Result<()> {
 401        let worktree_id = request.payload.worktree_id;
 402        let user_id = self
 403            .state
 404            .read()
 405            .await
 406            .user_id_for_connection(request.sender_id)?;
 407
 408        let response;
 409        let connection_ids;
 410        let collaborator_user_ids;
 411        let mut state = self.state.write().await;
 412        match state.join_worktree(request.sender_id, user_id, worktree_id) {
 413            Ok((peer_replica_id, worktree)) => {
 414                let share = worktree.share()?;
 415                let peer_count = share.guest_connection_ids.len();
 416                let mut peers = Vec::with_capacity(peer_count);
 417                peers.push(proto::Peer {
 418                    peer_id: worktree.host_connection_id.0,
 419                    replica_id: 0,
 420                });
 421                for (peer_conn_id, peer_replica_id) in &share.guest_connection_ids {
 422                    if *peer_conn_id != request.sender_id {
 423                        peers.push(proto::Peer {
 424                            peer_id: peer_conn_id.0,
 425                            replica_id: *peer_replica_id as u32,
 426                        });
 427                    }
 428                }
 429                connection_ids = worktree.connection_ids();
 430                collaborator_user_ids = worktree.collaborator_user_ids.clone();
 431                response = proto::JoinWorktreeResponse {
 432                    worktree: Some(proto::Worktree {
 433                        id: worktree_id,
 434                        root_name: worktree.root_name.clone(),
 435                        entries: share.entries.values().cloned().collect(),
 436                    }),
 437                    replica_id: peer_replica_id as u32,
 438                    peers,
 439                };
 440            }
 441            Err(error) => {
 442                self.peer
 443                    .respond_with_error(
 444                        request.receipt(),
 445                        proto::Error {
 446                            message: error.to_string(),
 447                        },
 448                    )
 449                    .await?;
 450                return Ok(());
 451            }
 452        }
 453
 454        broadcast(request.sender_id, connection_ids, |conn_id| {
 455            self.peer.send(
 456                conn_id,
 457                proto::AddPeer {
 458                    worktree_id,
 459                    peer: Some(proto::Peer {
 460                        peer_id: request.sender_id.0,
 461                        replica_id: response.replica_id,
 462                    }),
 463                },
 464            )
 465        })
 466        .await?;
 467        self.peer.respond(request.receipt(), response).await?;
 468        self.update_collaborators(&collaborator_user_ids).await?;
 469
 470        Ok(())
 471    }
 472
 473    async fn close_worktree(
 474        self: Arc<Server>,
 475        request: TypedEnvelope<proto::CloseWorktree>,
 476    ) -> tide::Result<()> {
 477        let worktree_id = request.payload.worktree_id;
 478        let connection_ids;
 479        let mut is_host = false;
 480        let mut is_guest = false;
 481        {
 482            let mut state = self.state.write().await;
 483            let worktree = state.write_worktree(worktree_id, request.sender_id)?;
 484            connection_ids = worktree.connection_ids();
 485
 486            if worktree.host_connection_id == request.sender_id {
 487                is_host = true;
 488                state.remove_worktree(worktree_id);
 489            } else {
 490                let share = worktree.share_mut()?;
 491                if let Some(replica_id) = share.guest_connection_ids.remove(&request.sender_id) {
 492                    is_guest = true;
 493                    share.active_replica_ids.remove(&replica_id);
 494                }
 495            }
 496        }
 497
 498        if is_host {
 499            broadcast(request.sender_id, connection_ids, |conn_id| {
 500                self.peer
 501                    .send(conn_id, proto::UnshareWorktree { worktree_id })
 502            })
 503            .await?;
 504        } else if is_guest {
 505            broadcast(request.sender_id, connection_ids, |conn_id| {
 506                self.peer.send(
 507                    conn_id,
 508                    proto::RemovePeer {
 509                        worktree_id,
 510                        peer_id: request.sender_id.0,
 511                    },
 512                )
 513            })
 514            .await?
 515        }
 516
 517        Ok(())
 518    }
 519
 520    async fn update_worktree(
 521        self: Arc<Server>,
 522        request: TypedEnvelope<proto::UpdateWorktree>,
 523    ) -> tide::Result<()> {
 524        {
 525            let mut state = self.state.write().await;
 526            let worktree = state.write_worktree(request.payload.worktree_id, request.sender_id)?;
 527            let share = worktree.share_mut()?;
 528
 529            for entry_id in &request.payload.removed_entries {
 530                share.entries.remove(&entry_id);
 531            }
 532
 533            for entry in &request.payload.updated_entries {
 534                share.entries.insert(entry.id, entry.clone());
 535            }
 536        }
 537
 538        self.broadcast_in_worktree(request.payload.worktree_id, &request)
 539            .await?;
 540        Ok(())
 541    }
 542
 543    async fn open_buffer(
 544        self: Arc<Server>,
 545        request: TypedEnvelope<proto::OpenBuffer>,
 546    ) -> tide::Result<()> {
 547        let receipt = request.receipt();
 548        let worktree_id = request.payload.worktree_id;
 549        let host_connection_id = self
 550            .state
 551            .read()
 552            .await
 553            .read_worktree(worktree_id, request.sender_id)?
 554            .host_connection_id;
 555
 556        let response = self
 557            .peer
 558            .forward_request(request.sender_id, host_connection_id, request.payload)
 559            .await?;
 560        self.peer.respond(receipt, response).await?;
 561        Ok(())
 562    }
 563
 564    async fn close_buffer(
 565        self: Arc<Server>,
 566        request: TypedEnvelope<proto::CloseBuffer>,
 567    ) -> tide::Result<()> {
 568        let host_connection_id = self
 569            .state
 570            .read()
 571            .await
 572            .read_worktree(request.payload.worktree_id, request.sender_id)?
 573            .host_connection_id;
 574
 575        self.peer
 576            .forward_send(request.sender_id, host_connection_id, request.payload)
 577            .await?;
 578
 579        Ok(())
 580    }
 581
 582    async fn save_buffer(
 583        self: Arc<Server>,
 584        request: TypedEnvelope<proto::SaveBuffer>,
 585    ) -> tide::Result<()> {
 586        let host;
 587        let guests;
 588        {
 589            let state = self.state.read().await;
 590            let worktree = state.read_worktree(request.payload.worktree_id, request.sender_id)?;
 591            host = worktree.host_connection_id;
 592            guests = worktree
 593                .share()?
 594                .guest_connection_ids
 595                .keys()
 596                .copied()
 597                .collect::<Vec<_>>();
 598        }
 599
 600        let sender = request.sender_id;
 601        let receipt = request.receipt();
 602        let response = self
 603            .peer
 604            .forward_request(sender, host, request.payload.clone())
 605            .await?;
 606
 607        broadcast(host, guests, |conn_id| {
 608            let response = response.clone();
 609            let peer = &self.peer;
 610            async move {
 611                if conn_id == sender {
 612                    peer.respond(receipt, response).await
 613                } else {
 614                    peer.forward_send(host, conn_id, response).await
 615                }
 616            }
 617        })
 618        .await?;
 619
 620        Ok(())
 621    }
 622
 623    async fn update_buffer(
 624        self: Arc<Server>,
 625        request: TypedEnvelope<proto::UpdateBuffer>,
 626    ) -> tide::Result<()> {
 627        self.broadcast_in_worktree(request.payload.worktree_id, &request)
 628            .await?;
 629        self.peer.respond(request.receipt(), proto::Ack {}).await?;
 630        Ok(())
 631    }
 632
 633    async fn buffer_saved(
 634        self: Arc<Server>,
 635        request: TypedEnvelope<proto::BufferSaved>,
 636    ) -> tide::Result<()> {
 637        self.broadcast_in_worktree(request.payload.worktree_id, &request)
 638            .await
 639    }
 640
 641    async fn get_channels(
 642        self: Arc<Server>,
 643        request: TypedEnvelope<proto::GetChannels>,
 644    ) -> tide::Result<()> {
 645        let user_id = self
 646            .state
 647            .read()
 648            .await
 649            .user_id_for_connection(request.sender_id)?;
 650        let channels = self.app_state.db.get_accessible_channels(user_id).await?;
 651        self.peer
 652            .respond(
 653                request.receipt(),
 654                proto::GetChannelsResponse {
 655                    channels: channels
 656                        .into_iter()
 657                        .map(|chan| proto::Channel {
 658                            id: chan.id.to_proto(),
 659                            name: chan.name,
 660                        })
 661                        .collect(),
 662                },
 663            )
 664            .await?;
 665        Ok(())
 666    }
 667
 668    async fn get_users(
 669        self: Arc<Server>,
 670        request: TypedEnvelope<proto::GetUsers>,
 671    ) -> tide::Result<()> {
 672        let user_id = self
 673            .state
 674            .read()
 675            .await
 676            .user_id_for_connection(request.sender_id)?;
 677        let receipt = request.receipt();
 678        let user_ids = request.payload.user_ids.into_iter().map(UserId::from_proto);
 679        let users = self
 680            .app_state
 681            .db
 682            .get_users_by_ids(user_id, user_ids)
 683            .await?
 684            .into_iter()
 685            .map(|user| proto::User {
 686                id: user.id.to_proto(),
 687                avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
 688                github_login: user.github_login,
 689            })
 690            .collect();
 691        self.peer
 692            .respond(receipt, proto::GetUsersResponse { users })
 693            .await?;
 694        Ok(())
 695    }
 696
 697    async fn update_collaborators(self: &Arc<Server>, user_ids: &[UserId]) -> tide::Result<()> {
 698        let mut send_futures = Vec::new();
 699
 700        let state = self.state.read().await;
 701        for user_id in user_ids {
 702            let mut collaborators = HashMap::new();
 703            for worktree_id in state
 704                .visible_worktrees_by_user_id
 705                .get(&user_id)
 706                .unwrap_or(&HashSet::new())
 707            {
 708                let worktree = &state.worktrees[worktree_id];
 709
 710                let mut participants = HashSet::new();
 711                if let Ok(share) = worktree.share() {
 712                    for guest_connection_id in share.guest_connection_ids.keys() {
 713                        let user_id = state.user_id_for_connection(*guest_connection_id)?;
 714                        participants.insert(user_id.to_proto());
 715                    }
 716                }
 717
 718                let host_user_id = state.user_id_for_connection(worktree.host_connection_id)?;
 719                let host =
 720                    collaborators
 721                        .entry(host_user_id)
 722                        .or_insert_with(|| proto::Collaborator {
 723                            user_id: host_user_id.to_proto(),
 724                            worktrees: Vec::new(),
 725                        });
 726                host.worktrees.push(proto::WorktreeMetadata {
 727                    root_name: worktree.root_name.clone(),
 728                    is_shared: worktree.share().is_ok(),
 729                    participants: participants.into_iter().collect(),
 730                });
 731            }
 732
 733            let connection_ids = self
 734                .state
 735                .read()
 736                .await
 737                .user_connection_ids(*user_id)
 738                .collect::<Vec<_>>();
 739
 740            let collaborators = collaborators.into_values().collect::<Vec<_>>();
 741            for connection_id in connection_ids {
 742                send_futures.push(self.peer.send(
 743                    connection_id,
 744                    proto::UpdateCollaborators {
 745                        collaborators: collaborators.clone(),
 746                    },
 747                ));
 748            }
 749        }
 750
 751        futures::future::try_join_all(send_futures).await?;
 752
 753        Ok(())
 754    }
 755
 756    async fn join_channel(
 757        self: Arc<Self>,
 758        request: TypedEnvelope<proto::JoinChannel>,
 759    ) -> tide::Result<()> {
 760        let user_id = self
 761            .state
 762            .read()
 763            .await
 764            .user_id_for_connection(request.sender_id)?;
 765        let channel_id = ChannelId::from_proto(request.payload.channel_id);
 766        if !self
 767            .app_state
 768            .db
 769            .can_user_access_channel(user_id, channel_id)
 770            .await?
 771        {
 772            Err(anyhow!("access denied"))?;
 773        }
 774
 775        self.state
 776            .write()
 777            .await
 778            .join_channel(request.sender_id, channel_id);
 779        let messages = self
 780            .app_state
 781            .db
 782            .get_channel_messages(channel_id, MESSAGE_COUNT_PER_PAGE, None)
 783            .await?
 784            .into_iter()
 785            .map(|msg| proto::ChannelMessage {
 786                id: msg.id.to_proto(),
 787                body: msg.body,
 788                timestamp: msg.sent_at.unix_timestamp() as u64,
 789                sender_id: msg.sender_id.to_proto(),
 790                nonce: Some(msg.nonce.as_u128().into()),
 791            })
 792            .collect::<Vec<_>>();
 793        self.peer
 794            .respond(
 795                request.receipt(),
 796                proto::JoinChannelResponse {
 797                    done: messages.len() < MESSAGE_COUNT_PER_PAGE,
 798                    messages,
 799                },
 800            )
 801            .await?;
 802        Ok(())
 803    }
 804
 805    async fn leave_channel(
 806        self: Arc<Self>,
 807        request: TypedEnvelope<proto::LeaveChannel>,
 808    ) -> tide::Result<()> {
 809        let user_id = self
 810            .state
 811            .read()
 812            .await
 813            .user_id_for_connection(request.sender_id)?;
 814        let channel_id = ChannelId::from_proto(request.payload.channel_id);
 815        if !self
 816            .app_state
 817            .db
 818            .can_user_access_channel(user_id, channel_id)
 819            .await?
 820        {
 821            Err(anyhow!("access denied"))?;
 822        }
 823
 824        self.state
 825            .write()
 826            .await
 827            .leave_channel(request.sender_id, channel_id);
 828
 829        Ok(())
 830    }
 831
 832    async fn send_channel_message(
 833        self: Arc<Self>,
 834        request: TypedEnvelope<proto::SendChannelMessage>,
 835    ) -> tide::Result<()> {
 836        let receipt = request.receipt();
 837        let channel_id = ChannelId::from_proto(request.payload.channel_id);
 838        let user_id;
 839        let connection_ids;
 840        {
 841            let state = self.state.read().await;
 842            user_id = state.user_id_for_connection(request.sender_id)?;
 843            if let Some(channel) = state.channels.get(&channel_id) {
 844                connection_ids = channel.connection_ids();
 845            } else {
 846                return Ok(());
 847            }
 848        }
 849
 850        // Validate the message body.
 851        let body = request.payload.body.trim().to_string();
 852        if body.len() > MAX_MESSAGE_LEN {
 853            self.peer
 854                .respond_with_error(
 855                    receipt,
 856                    proto::Error {
 857                        message: "message is too long".to_string(),
 858                    },
 859                )
 860                .await?;
 861            return Ok(());
 862        }
 863        if body.is_empty() {
 864            self.peer
 865                .respond_with_error(
 866                    receipt,
 867                    proto::Error {
 868                        message: "message can't be blank".to_string(),
 869                    },
 870                )
 871                .await?;
 872            return Ok(());
 873        }
 874
 875        let timestamp = OffsetDateTime::now_utc();
 876        let nonce = if let Some(nonce) = request.payload.nonce {
 877            nonce
 878        } else {
 879            self.peer
 880                .respond_with_error(
 881                    receipt,
 882                    proto::Error {
 883                        message: "nonce can't be blank".to_string(),
 884                    },
 885                )
 886                .await?;
 887            return Ok(());
 888        };
 889
 890        let message_id = self
 891            .app_state
 892            .db
 893            .create_channel_message(channel_id, user_id, &body, timestamp, nonce.clone().into())
 894            .await?
 895            .to_proto();
 896        let message = proto::ChannelMessage {
 897            sender_id: user_id.to_proto(),
 898            id: message_id,
 899            body,
 900            timestamp: timestamp.unix_timestamp() as u64,
 901            nonce: Some(nonce),
 902        };
 903        broadcast(request.sender_id, connection_ids, |conn_id| {
 904            self.peer.send(
 905                conn_id,
 906                proto::ChannelMessageSent {
 907                    channel_id: channel_id.to_proto(),
 908                    message: Some(message.clone()),
 909                },
 910            )
 911        })
 912        .await?;
 913        self.peer
 914            .respond(
 915                receipt,
 916                proto::SendChannelMessageResponse {
 917                    message: Some(message),
 918                },
 919            )
 920            .await?;
 921        Ok(())
 922    }
 923
 924    async fn get_channel_messages(
 925        self: Arc<Self>,
 926        request: TypedEnvelope<proto::GetChannelMessages>,
 927    ) -> tide::Result<()> {
 928        let user_id = self
 929            .state
 930            .read()
 931            .await
 932            .user_id_for_connection(request.sender_id)?;
 933        let channel_id = ChannelId::from_proto(request.payload.channel_id);
 934        if !self
 935            .app_state
 936            .db
 937            .can_user_access_channel(user_id, channel_id)
 938            .await?
 939        {
 940            Err(anyhow!("access denied"))?;
 941        }
 942
 943        let messages = self
 944            .app_state
 945            .db
 946            .get_channel_messages(
 947                channel_id,
 948                MESSAGE_COUNT_PER_PAGE,
 949                Some(MessageId::from_proto(request.payload.before_message_id)),
 950            )
 951            .await?
 952            .into_iter()
 953            .map(|msg| proto::ChannelMessage {
 954                id: msg.id.to_proto(),
 955                body: msg.body,
 956                timestamp: msg.sent_at.unix_timestamp() as u64,
 957                sender_id: msg.sender_id.to_proto(),
 958                nonce: Some(msg.nonce.as_u128().into()),
 959            })
 960            .collect::<Vec<_>>();
 961        self.peer
 962            .respond(
 963                request.receipt(),
 964                proto::GetChannelMessagesResponse {
 965                    done: messages.len() < MESSAGE_COUNT_PER_PAGE,
 966                    messages,
 967                },
 968            )
 969            .await?;
 970        Ok(())
 971    }
 972
 973    async fn broadcast_in_worktree<T: proto::EnvelopedMessage>(
 974        &self,
 975        worktree_id: u64,
 976        message: &TypedEnvelope<T>,
 977    ) -> tide::Result<()> {
 978        let connection_ids = self
 979            .state
 980            .read()
 981            .await
 982            .read_worktree(worktree_id, message.sender_id)?
 983            .connection_ids();
 984
 985        broadcast(message.sender_id, connection_ids, |conn_id| {
 986            self.peer
 987                .forward_send(message.sender_id, conn_id, message.payload.clone())
 988        })
 989        .await?;
 990
 991        Ok(())
 992    }
 993}
 994
 995pub async fn broadcast<F, T>(
 996    sender_id: ConnectionId,
 997    receiver_ids: Vec<ConnectionId>,
 998    mut f: F,
 999) -> anyhow::Result<()>
1000where
1001    F: FnMut(ConnectionId) -> T,
1002    T: Future<Output = anyhow::Result<()>>,
1003{
1004    let futures = receiver_ids
1005        .into_iter()
1006        .filter(|id| *id != sender_id)
1007        .map(|id| f(id));
1008    futures::future::try_join_all(futures).await?;
1009    Ok(())
1010}
1011
1012impl ServerState {
1013    fn join_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
1014        if let Some(connection) = self.connections.get_mut(&connection_id) {
1015            connection.channels.insert(channel_id);
1016            self.channels
1017                .entry(channel_id)
1018                .or_default()
1019                .connection_ids
1020                .insert(connection_id);
1021        }
1022    }
1023
1024    fn leave_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
1025        if let Some(connection) = self.connections.get_mut(&connection_id) {
1026            connection.channels.remove(&channel_id);
1027            if let hash_map::Entry::Occupied(mut entry) = self.channels.entry(channel_id) {
1028                entry.get_mut().connection_ids.remove(&connection_id);
1029                if entry.get_mut().connection_ids.is_empty() {
1030                    entry.remove();
1031                }
1032            }
1033        }
1034    }
1035
1036    fn user_id_for_connection(&self, connection_id: ConnectionId) -> tide::Result<UserId> {
1037        Ok(self
1038            .connections
1039            .get(&connection_id)
1040            .ok_or_else(|| anyhow!("unknown connection"))?
1041            .user_id)
1042    }
1043
1044    fn user_connection_ids<'a>(
1045        &'a self,
1046        user_id: UserId,
1047    ) -> impl 'a + Iterator<Item = ConnectionId> {
1048        self.connections_by_user_id
1049            .get(&user_id)
1050            .into_iter()
1051            .flatten()
1052            .copied()
1053    }
1054
1055    fn is_online(&self, user_id: UserId) -> bool {
1056        self.connections_by_user_id.contains_key(&user_id)
1057    }
1058
1059    // Add the given connection as a guest of the given worktree
1060    fn join_worktree(
1061        &mut self,
1062        connection_id: ConnectionId,
1063        user_id: UserId,
1064        worktree_id: u64,
1065    ) -> tide::Result<(ReplicaId, &Worktree)> {
1066        let connection = self
1067            .connections
1068            .get_mut(&connection_id)
1069            .ok_or_else(|| anyhow!("no such connection"))?;
1070        let worktree = self
1071            .worktrees
1072            .get_mut(&worktree_id)
1073            .ok_or_else(|| anyhow!("no such worktree"))?;
1074        if !worktree.collaborator_user_ids.contains(&user_id) {
1075            Err(anyhow!("no such worktree"))?;
1076        }
1077
1078        let share = worktree.share_mut()?;
1079        connection.worktrees.insert(worktree_id);
1080
1081        let mut replica_id = 1;
1082        while share.active_replica_ids.contains(&replica_id) {
1083            replica_id += 1;
1084        }
1085        share.active_replica_ids.insert(replica_id);
1086        share.guest_connection_ids.insert(connection_id, replica_id);
1087        return Ok((replica_id, worktree));
1088    }
1089
1090    fn read_worktree(
1091        &self,
1092        worktree_id: u64,
1093        connection_id: ConnectionId,
1094    ) -> tide::Result<&Worktree> {
1095        let worktree = self
1096            .worktrees
1097            .get(&worktree_id)
1098            .ok_or_else(|| anyhow!("worktree not found"))?;
1099
1100        if worktree.host_connection_id == connection_id
1101            || worktree
1102                .share()?
1103                .guest_connection_ids
1104                .contains_key(&connection_id)
1105        {
1106            Ok(worktree)
1107        } else {
1108            Err(anyhow!(
1109                "{} is not a member of worktree {}",
1110                connection_id,
1111                worktree_id
1112            ))?
1113        }
1114    }
1115
1116    fn write_worktree(
1117        &mut self,
1118        worktree_id: u64,
1119        connection_id: ConnectionId,
1120    ) -> tide::Result<&mut Worktree> {
1121        let worktree = self
1122            .worktrees
1123            .get_mut(&worktree_id)
1124            .ok_or_else(|| anyhow!("worktree not found"))?;
1125
1126        if worktree.host_connection_id == connection_id
1127            || worktree.share.as_ref().map_or(false, |share| {
1128                share.guest_connection_ids.contains_key(&connection_id)
1129            })
1130        {
1131            Ok(worktree)
1132        } else {
1133            Err(anyhow!(
1134                "{} is not a member of worktree {}",
1135                connection_id,
1136                worktree_id
1137            ))?
1138        }
1139    }
1140
1141    fn add_worktree(&mut self, worktree: Worktree) -> u64 {
1142        let worktree_id = self.next_worktree_id;
1143        for collaborator_user_id in &worktree.collaborator_user_ids {
1144            self.visible_worktrees_by_user_id
1145                .entry(*collaborator_user_id)
1146                .or_default()
1147                .insert(worktree_id);
1148        }
1149        self.next_worktree_id += 1;
1150        self.worktrees.insert(worktree_id, worktree);
1151        worktree_id
1152    }
1153
1154    fn remove_worktree(&mut self, worktree_id: u64) {
1155        let worktree = self.worktrees.remove(&worktree_id).unwrap();
1156        if let Some(connection) = self.connections.get_mut(&worktree.host_connection_id) {
1157            connection.worktrees.remove(&worktree_id);
1158        }
1159        if let Some(share) = worktree.share {
1160            for connection_id in share.guest_connection_ids.keys() {
1161                if let Some(connection) = self.connections.get_mut(connection_id) {
1162                    connection.worktrees.remove(&worktree_id);
1163                }
1164            }
1165        }
1166        for collaborator_user_id in worktree.collaborator_user_ids {
1167            if let Some(visible_worktrees) = self
1168                .visible_worktrees_by_user_id
1169                .get_mut(&collaborator_user_id)
1170            {
1171                visible_worktrees.remove(&worktree_id);
1172            }
1173        }
1174    }
1175}
1176
1177impl Worktree {
1178    pub fn connection_ids(&self) -> Vec<ConnectionId> {
1179        if let Some(share) = &self.share {
1180            share
1181                .guest_connection_ids
1182                .keys()
1183                .copied()
1184                .chain(Some(self.host_connection_id))
1185                .collect()
1186        } else {
1187            vec![self.host_connection_id]
1188        }
1189    }
1190
1191    fn share(&self) -> tide::Result<&WorktreeShare> {
1192        Ok(self
1193            .share
1194            .as_ref()
1195            .ok_or_else(|| anyhow!("worktree is not shared"))?)
1196    }
1197
1198    fn share_mut(&mut self) -> tide::Result<&mut WorktreeShare> {
1199        Ok(self
1200            .share
1201            .as_mut()
1202            .ok_or_else(|| anyhow!("worktree is not shared"))?)
1203    }
1204}
1205
1206impl Channel {
1207    fn connection_ids(&self) -> Vec<ConnectionId> {
1208        self.connection_ids.iter().copied().collect()
1209    }
1210}
1211
1212pub fn add_routes(app: &mut tide::Server<Arc<AppState>>, rpc: &Arc<Peer>) {
1213    let server = Server::new(app.state().clone(), rpc.clone(), None);
1214    app.at("/rpc").with(auth::VerifyToken).get(move |request: Request<Arc<AppState>>| {
1215        let user_id = request.ext::<UserId>().copied();
1216        let server = server.clone();
1217        async move {
1218            const WEBSOCKET_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
1219
1220            let connection_upgrade = header_contains_ignore_case(&request, CONNECTION, "upgrade");
1221            let upgrade_to_websocket = header_contains_ignore_case(&request, UPGRADE, "websocket");
1222            let upgrade_requested = connection_upgrade && upgrade_to_websocket;
1223
1224            if !upgrade_requested {
1225                return Ok(Response::new(StatusCode::UpgradeRequired));
1226            }
1227
1228            let header = match request.header("Sec-Websocket-Key") {
1229                Some(h) => h.as_str(),
1230                None => return Err(anyhow!("expected sec-websocket-key"))?,
1231            };
1232
1233            let mut response = Response::new(StatusCode::SwitchingProtocols);
1234            response.insert_header(UPGRADE, "websocket");
1235            response.insert_header(CONNECTION, "Upgrade");
1236            let hash = Sha1::new().chain(header).chain(WEBSOCKET_GUID).finalize();
1237            response.insert_header("Sec-Websocket-Accept", base64::encode(&hash[..]));
1238            response.insert_header("Sec-Websocket-Version", "13");
1239
1240            let http_res: &mut tide::http::Response = response.as_mut();
1241            let upgrade_receiver = http_res.recv_upgrade().await;
1242            let addr = request.remote().unwrap_or("unknown").to_string();
1243            let user_id = user_id.ok_or_else(|| anyhow!("user_id is not present on request. ensure auth::VerifyToken middleware is present"))?;
1244            task::spawn(async move {
1245                if let Some(stream) = upgrade_receiver.await {
1246                    server.handle_connection(Connection::new(WebSocketStream::from_raw_socket(stream, Role::Server, None).await), addr, user_id).await;
1247                }
1248            });
1249
1250            Ok(response)
1251        }
1252    });
1253}
1254
1255fn header_contains_ignore_case<T>(
1256    request: &tide::Request<T>,
1257    header_name: HeaderName,
1258    value: &str,
1259) -> bool {
1260    request
1261        .header(header_name)
1262        .map(|h| {
1263            h.as_str()
1264                .split(',')
1265                .any(|s| s.trim().eq_ignore_ascii_case(value.trim()))
1266        })
1267        .unwrap_or(false)
1268}
1269
1270#[cfg(test)]
1271mod tests {
1272    use super::*;
1273    use crate::{
1274        auth,
1275        db::{tests::TestDb, UserId},
1276        github, AppState, Config,
1277    };
1278    use async_std::{sync::RwLockReadGuard, task};
1279    use gpui::TestAppContext;
1280    use parking_lot::Mutex;
1281    use postage::{mpsc, watch};
1282    use serde_json::json;
1283    use sqlx::types::time::OffsetDateTime;
1284    use std::{
1285        path::Path,
1286        sync::{
1287            atomic::{AtomicBool, Ordering::SeqCst},
1288            Arc,
1289        },
1290        time::Duration,
1291    };
1292    use zed::{
1293        channel::{Channel, ChannelDetails, ChannelList},
1294        editor::{Editor, Insert},
1295        fs::{FakeFs, Fs as _},
1296        language::LanguageRegistry,
1297        rpc::{self, Client, Credentials, EstablishConnectionError},
1298        settings,
1299        test::FakeHttpClient,
1300        user::UserStore,
1301        worktree::Worktree,
1302    };
1303    use zrpc::Peer;
1304
1305    #[gpui::test]
        /// End-to-end check of worktree sharing between a host (client A) and a guest (client B):
        /// joining, opening the same buffer on both sides, propagating selections and edits from
        /// the guest to the host, and cleaning up when the guest lets go of its handles.
1306    async fn test_share_worktree(mut cx_a: TestAppContext, mut cx_b: TestAppContext) {
1307        let (window_b, _) = cx_b.add_window(|_| EmptyView);
1308        let settings = cx_b.read(settings::test).1;
1309        let lang_registry = Arc::new(LanguageRegistry::new());
1310
1311        // Connect to a server as 2 clients.
1312        let mut server = TestServer::start().await;
1313        let (client_a, _) = server.create_client(&mut cx_a, "user_a").await;
1314        let (client_b, _) = server.create_client(&mut cx_b, "user_b").await;
1315
            // NOTE(review): presumably makes the test fail instead of blocking when client A's
            // foreground executor would have to park; confirm against gpui.
1316        cx_a.foreground().forbid_parking();
1317
1318        // Share a local worktree as client A
1319        let fs = Arc::new(FakeFs::new());
1320        fs.insert_tree(
1321            "/a",
1322            json!({
1323                ".zed.toml": r#"collaborators = ["user_b"]"#,
1324                "a.txt": "a-contents",
1325                "b.txt": "b-contents",
1326            }),
1327        )
1328        .await;
1329        let worktree_a = Worktree::open_local(
1330            client_a.clone(),
1331            "/a".as_ref(),
1332            fs,
1333            lang_registry.clone(),
1334            &mut cx_a.to_async(),
1335        )
1336        .await
1337        .unwrap();
1338        worktree_a
1339            .read_with(&cx_a, |tree, _| tree.as_local().unwrap().scan_complete())
1340            .await;
1341        let worktree_id = worktree_a
1342            .update(&mut cx_a, |tree, cx| tree.as_local_mut().unwrap().share(cx))
1343            .await
1344            .unwrap();
1345
1346        // Join that worktree as client B, and see that a guest has joined as client A.
1347        let worktree_b = Worktree::open_remote(
1348            client_b.clone(),
1349            worktree_id,
1350            lang_registry.clone(),
1351            &mut cx_b.to_async(),
1352        )
1353        .await
1354        .unwrap();
1355        let replica_id_b = worktree_b.read_with(&cx_b, |tree, _| tree.replica_id());
1356        worktree_a
1357            .condition(&cx_a, |tree, _| {
1358                tree.peers()
1359                    .values()
1360                    .any(|replica_id| *replica_id == replica_id_b)
1361            })
1362            .await;
1363
1364        // Open the same file as client B and client A.
1365        let buffer_b = worktree_b
1366            .update(&mut cx_b, |worktree, cx| worktree.open_buffer("b.txt", cx))
1367            .await
1368            .unwrap();
1369        buffer_b.read_with(&cx_b, |buf, _| assert_eq!(buf.text(), "b-contents"));
            // The host is expected to already hold the buffer open on the guest's behalf, before
            // client A opens it itself below.
1370        worktree_a.read_with(&cx_a, |tree, cx| assert!(tree.has_open_buffer("b.txt", cx)));
1371        let buffer_a = worktree_a
1372            .update(&mut cx_a, |tree, cx| tree.open_buffer("b.txt", cx))
1373            .await
1374            .unwrap();
1375
1376        // Create a selection set as client B and see that selection set as client A.
1377        let editor_b = cx_b.add_view(window_b, |cx| Editor::for_buffer(buffer_b, settings, cx));
1378        buffer_a
1379            .condition(&cx_a, |buffer, _| buffer.selection_sets().count() == 1)
1380            .await;
1381
1382        // Edit the buffer as client B and see that edit as client A.
1383        editor_b.update(&mut cx_b, |editor, cx| {
1384            editor.insert(&Insert("ok, ".into()), cx)
1385        });
1386        buffer_a
1387            .condition(&cx_a, |buffer, _| buffer.text() == "ok, b-contents")
1388            .await;
1389
1390        // Remove the selection set as client B, see those selections disappear as client A.
1391        cx_b.update(move |_| drop(editor_b));
1392        buffer_a
1393            .condition(&cx_a, |buffer, _| buffer.selection_sets().count() == 0)
1394            .await;
1395
1396        // Close the buffer as client A, see that the buffer is closed.
1397        cx_a.update(move |_| drop(buffer_a));
1398        worktree_a
1399            .condition(&cx_a, |tree, cx| !tree.has_open_buffer("b.txt", cx))
1400            .await;
1401
1402        // Dropping the worktree removes client B from client A's peers.
1403        cx_b.update(move |_| drop(worktree_b));
1404        worktree_a
1405            .condition(&cx_a, |tree, _| tree.peers().is_empty())
1406            .await;
1407    }
1408
1409    #[gpui::test]
        /// Shares a worktree from host A with guests B and C and checks that concurrent edits
        /// converge on all three replicas, that a save issued by a guest writes the host's file and
        /// clears the dirty flag everywhere, and that changes made on the host's file system show
        /// up in the guests' worktrees.
1410    async fn test_propagate_saves_and_fs_changes_in_shared_worktree(
1411        mut cx_a: TestAppContext,
1412        mut cx_b: TestAppContext,
1413        mut cx_c: TestAppContext,
1414    ) {
1415        cx_a.foreground().forbid_parking();
1416        let lang_registry = Arc::new(LanguageRegistry::new());
1417
1418        // Connect to a server as 3 clients.
1419        let mut server = TestServer::start().await;
1420        let (client_a, _) = server.create_client(&mut cx_a, "user_a").await;
1421        let (client_b, _) = server.create_client(&mut cx_b, "user_b").await;
1422        let (client_c, _) = server.create_client(&mut cx_c, "user_c").await;
1423
1424        let fs = Arc::new(FakeFs::new());
1425
1426        // Share a worktree as client A.
1427        fs.insert_tree(
1428            "/a",
1429            json!({
1430                ".zed.toml": r#"collaborators = ["user_b", "user_c"]"#,
1431                "file1": "",
1432                "file2": ""
1433            }),
1434        )
1435        .await;
1436
1437        let worktree_a = Worktree::open_local(
1438            client_a.clone(),
1439            "/a".as_ref(),
1440            fs.clone(),
1441            lang_registry.clone(),
1442            &mut cx_a.to_async(),
1443        )
1444        .await
1445        .unwrap();
1446        worktree_a
1447            .read_with(&cx_a, |tree, _| tree.as_local().unwrap().scan_complete())
1448            .await;
1449        let worktree_id = worktree_a
1450            .update(&mut cx_a, |tree, cx| tree.as_local_mut().unwrap().share(cx))
1451            .await
1452            .unwrap();
1453
1454        // Join that worktree as clients B and C.
1455        let worktree_b = Worktree::open_remote(
1456            client_b.clone(),
1457            worktree_id,
1458            lang_registry.clone(),
1459            &mut cx_b.to_async(),
1460        )
1461        .await
1462        .unwrap();
1463        let worktree_c = Worktree::open_remote(
1464            client_c.clone(),
1465            worktree_id,
1466            lang_registry.clone(),
1467            &mut cx_c.to_async(),
1468        )
1469        .await
1470        .unwrap();
1471
1472        // Open and edit a buffer as both guests B and C.
1473        let buffer_b = worktree_b
1474            .update(&mut cx_b, |tree, cx| tree.open_buffer("file1", cx))
1475            .await
1476            .unwrap();
1477        let buffer_c = worktree_c
1478            .update(&mut cx_c, |tree, cx| tree.open_buffer("file1", cx))
1479            .await
1480            .unwrap();
1481        buffer_b.update(&mut cx_b, |buf, cx| buf.edit([0..0], "i-am-b, ", cx));
1482        buffer_c.update(&mut cx_c, |buf, cx| buf.edit([0..0], "i-am-c, ", cx));
1483
1484        // Open and edit that buffer as the host.
1485        let buffer_a = worktree_a
1486            .update(&mut cx_a, |tree, cx| tree.open_buffer("file1", cx))
1487            .await
1488            .unwrap();
1489
            // Both guests inserted at offset 0; wait until the host has merged the two inserts
            // into the expected order before appending its own text.
1490        buffer_a
1491            .condition(&mut cx_a, |buf, _| buf.text() == "i-am-c, i-am-b, ")
1492            .await;
1493        buffer_a.update(&mut cx_a, |buf, cx| {
1494            buf.edit([buf.len()..buf.len()], "i-am-a", cx)
1495        });
1496
1497        // Wait for edits to propagate
1498        buffer_a
1499            .condition(&mut cx_a, |buf, _| buf.text() == "i-am-c, i-am-b, i-am-a")
1500            .await;
1501        buffer_b
1502            .condition(&mut cx_b, |buf, _| buf.text() == "i-am-c, i-am-b, i-am-a")
1503            .await;
1504        buffer_c
1505            .condition(&mut cx_c, |buf, _| buf.text() == "i-am-c, i-am-b, i-am-a")
1506            .await;
1507
1508        // Edit the buffer as the host and concurrently save as guest B.
1509        let save_b = buffer_b.update(&mut cx_b, |buf, cx| buf.save(cx).unwrap());
1510        buffer_a.update(&mut cx_a, |buf, cx| buf.edit([0..0], "hi-a, ", cx));
1511        save_b.await.unwrap();
            // The file written by the guest's save is expected to include the host's concurrent
            // "hi-a, " edit.
1512        assert_eq!(
1513            fs.load("/a/file1".as_ref()).await.unwrap(),
1514            "hi-a, i-am-c, i-am-b, i-am-a"
1515        );
1516        buffer_a.read_with(&cx_a, |buf, _| assert!(!buf.is_dirty()));
1517        buffer_b.read_with(&cx_b, |buf, _| assert!(!buf.is_dirty()));
            // Guest C did not take part in the save, so its dirty flag is awaited rather than
            // asserted immediately.
1518        buffer_c.condition(&cx_c, |buf, _| !buf.is_dirty()).await;
1519
1520        // Make changes on host's file system, see those changes on the guests.
1521        fs.rename("/a/file2".as_ref(), "/a/file3".as_ref())
1522            .await
1523            .unwrap();
1524        fs.insert_file(Path::new("/a/file4"), "4".into())
1525            .await
1526            .unwrap();
1527
1528        worktree_b
1529            .condition(&cx_b, |tree, _| tree.file_count() == 4)
1530            .await;
1531        worktree_c
1532            .condition(&cx_c, |tree, _| tree.file_count() == 4)
1533            .await;
1534        worktree_b.read_with(&cx_b, |tree, _| {
1535            assert_eq!(
1536                tree.paths()
1537                    .map(|p| p.to_string_lossy())
1538                    .collect::<Vec<_>>(),
1539                &[".zed.toml", "file1", "file3", "file4"]
1540            )
1541        });
1542        worktree_c.read_with(&cx_c, |tree, _| {
1543            assert_eq!(
1544                tree.paths()
1545                    .map(|p| p.to_string_lossy())
1546                    .collect::<Vec<_>>(),
1547                &[".zed.toml", "file1", "file3", "file4"]
1548            )
1549        });
1550    }
1551
1552    #[gpui::test]
        /// Checks the dirty/conflict flags of a guest's buffer around a save: dirty but not in
        /// conflict after an edit, clean after the save once the file's mtime has changed, and dirty
        /// (still without a conflict) after a further edit.
1553    async fn test_buffer_conflict_after_save(mut cx_a: TestAppContext, mut cx_b: TestAppContext) {
1554        cx_a.foreground().forbid_parking();
1555        let lang_registry = Arc::new(LanguageRegistry::new());
1556
1557        // Connect to a server as 2 clients.
1558        let mut server = TestServer::start().await;
1559        let (client_a, _) = server.create_client(&mut cx_a, "user_a").await;
1560        let (client_b, _) = server.create_client(&mut cx_b, "user_b").await;
1561
1562        // Share a local worktree as client A
            // NOTE(review): the collaborator list names "user_c", but this test only creates
            // clients for user_a and user_b; confirm whether that entry is intentional.
1563        let fs = Arc::new(FakeFs::new());
1564        fs.insert_tree(
1565            "/dir",
1566            json!({
1567                ".zed.toml": r#"collaborators = ["user_b", "user_c"]"#,
1568                "a.txt": "a-contents",
1569            }),
1570        )
1571        .await;
1572
1573        let worktree_a = Worktree::open_local(
1574            client_a.clone(),
1575            "/dir".as_ref(),
1576            fs,
1577            lang_registry.clone(),
1578            &mut cx_a.to_async(),
1579        )
1580        .await
1581        .unwrap();
1582        worktree_a
1583            .read_with(&cx_a, |tree, _| tree.as_local().unwrap().scan_complete())
1584            .await;
1585        let worktree_id = worktree_a
1586            .update(&mut cx_a, |tree, cx| tree.as_local_mut().unwrap().share(cx))
1587            .await
1588            .unwrap();
1589
1590        // Join that worktree as client B, and see that a guest has joined as client A.
1591        let worktree_b = Worktree::open_remote(
1592            client_b.clone(),
1593            worktree_id,
1594            lang_registry.clone(),
1595            &mut cx_b.to_async(),
1596        )
1597        .await
1598        .unwrap();
1599
1600        let buffer_b = worktree_b
1601            .update(&mut cx_b, |worktree, cx| worktree.open_buffer("a.txt", cx))
1602            .await
1603            .unwrap();
            // Remember the file's mtime as seen before saving, to detect the update below.
1604        let mtime = buffer_b.read_with(&cx_b, |buf, _| buf.file().unwrap().mtime);
1605
1606        buffer_b.update(&mut cx_b, |buf, cx| buf.edit([0..0], "world ", cx));
1607        buffer_b.read_with(&cx_b, |buf, _| {
1608            assert!(buf.is_dirty());
1609            assert!(!buf.has_conflict());
1610        });
1611
1612        buffer_b
1613            .update(&mut cx_b, |buf, cx| buf.save(cx))
1614            .unwrap()
1615            .await
1616            .unwrap();
            // Wait until the guest observes a new mtime for the file before checking the flags.
1617        worktree_b
1618            .condition(&cx_b, |_, cx| {
1619                buffer_b.read(cx).file().unwrap().mtime != mtime
1620            })
1621            .await;
1622        buffer_b.read_with(&cx_b, |buf, _| {
1623            assert!(!buf.is_dirty());
1624            assert!(!buf.has_conflict());
1625        });
1626
            // Editing after the save must mark the buffer dirty again without a conflict.
1627        buffer_b.update(&mut cx_b, |buf, cx| buf.edit([0..0], "hello ", cx));
1628        buffer_b.read_with(&cx_b, |buf, _| {
1629            assert!(buf.is_dirty());
1630            assert!(!buf.has_conflict());
1631        });
1632    }
1633
1634    #[gpui::test]
        /// Has the host edit a buffer while the guest's request to open that same buffer is still
        /// in flight, then checks that the guest's buffer ends up with the host's text.
1635    async fn test_editing_while_guest_opens_buffer(
1636        mut cx_a: TestAppContext,
1637        mut cx_b: TestAppContext,
1638    ) {
1639        cx_a.foreground().forbid_parking();
1640        let lang_registry = Arc::new(LanguageRegistry::new());
1641
1642        // Connect to a server as 2 clients.
1643        let mut server = TestServer::start().await;
1644        let (client_a, _) = server.create_client(&mut cx_a, "user_a").await;
1645        let (client_b, _) = server.create_client(&mut cx_b, "user_b").await;
1646
1647        // Share a local worktree as client A
1648        let fs = Arc::new(FakeFs::new());
1649        fs.insert_tree(
1650            "/dir",
1651            json!({
1652                ".zed.toml": r#"collaborators = ["user_b"]"#,
1653                "a.txt": "a-contents",
1654            }),
1655        )
1656        .await;
1657        let worktree_a = Worktree::open_local(
1658            client_a.clone(),
1659            "/dir".as_ref(),
1660            fs,
1661            lang_registry.clone(),
1662            &mut cx_a.to_async(),
1663        )
1664        .await
1665        .unwrap();
1666        worktree_a
1667            .read_with(&cx_a, |tree, _| tree.as_local().unwrap().scan_complete())
1668            .await;
1669        let worktree_id = worktree_a
1670            .update(&mut cx_a, |tree, cx| tree.as_local_mut().unwrap().share(cx))
1671            .await
1672            .unwrap();
1673
1674        // Join that worktree as client B, and see that a guest has joined as client A.
1675        let worktree_b = Worktree::open_remote(
1676            client_b.clone(),
1677            worktree_id,
1678            lang_registry.clone(),
1679            &mut cx_b.to_async(),
1680        )
1681        .await
1682        .unwrap();
1683
1684        let buffer_a = worktree_a
1685            .update(&mut cx_a, |tree, cx| tree.open_buffer("a.txt", cx))
1686            .await
1687            .unwrap();
            // Start opening the buffer on the guest in the background without awaiting it yet.
1688        let buffer_b = cx_b
1689            .background()
1690            .spawn(worktree_b.update(&mut cx_b, |worktree, cx| worktree.open_buffer("a.txt", cx)));
1691
            // Yield once so the spawned open can make progress, then edit on the host while the
            // guest's open is (presumably) still pending.
1692        task::yield_now().await;
1693        buffer_a.update(&mut cx_a, |buf, cx| buf.edit([0..0], "z", cx));
1694
1695        let text = buffer_a.read_with(&cx_a, |buf, _| buf.text());
1696        let buffer_b = buffer_b.await.unwrap();
1697        buffer_b.condition(&cx_b, |buf, _| buf.text() == text).await;
1698    }
1699
1700    #[gpui::test]
1701    async fn test_peer_disconnection(mut cx_a: TestAppContext, cx_b: TestAppContext) {
1702        cx_a.foreground().forbid_parking();
1703        let lang_registry = Arc::new(LanguageRegistry::new());
1704
1705        // Connect to a server as 2 clients.
1706        let mut server = TestServer::start().await;
1707        let (client_a, _) = server.create_client(&mut cx_a, "user_a").await;
1708        let (client_b, _) = server.create_client(&mut cx_a, "user_b").await;
1709
1710        // Share a local worktree as client A
1711        let fs = Arc::new(FakeFs::new());
1712        fs.insert_tree(
1713            "/a",
1714            json!({
1715                ".zed.toml": r#"collaborators = ["user_b"]"#,
1716                "a.txt": "a-contents",
1717                "b.txt": "b-contents",
1718            }),
1719        )
1720        .await;
1721        let worktree_a = Worktree::open_local(
1722            client_a.clone(),
1723            "/a".as_ref(),
1724            fs,
1725            lang_registry.clone(),
1726            &mut cx_a.to_async(),
1727        )
1728        .await
1729        .unwrap();
1730        worktree_a
1731            .read_with(&cx_a, |tree, _| tree.as_local().unwrap().scan_complete())
1732            .await;
1733        let worktree_id = worktree_a
1734            .update(&mut cx_a, |tree, cx| tree.as_local_mut().unwrap().share(cx))
1735            .await
1736            .unwrap();
1737
1738        // Join that worktree as client B, and see that a guest has joined as client A.
1739        let _worktree_b = Worktree::open_remote(
1740            client_b.clone(),
1741            worktree_id,
1742            lang_registry.clone(),
1743            &mut cx_b.to_async(),
1744        )
1745        .await
1746        .unwrap();
1747        worktree_a
1748            .condition(&cx_a, |tree, _| tree.peers().len() == 1)
1749            .await;
1750
1751        // Drop client B's connection and ensure client A observes client B leaving the worktree.
1752        client_b.disconnect(&cx_b.to_async()).await.unwrap();
1753        worktree_a
1754            .condition(&cx_a, |tree, _| tree.peers().len() == 0)
1755            .await;
1756    }
1757
1758    #[gpui::test]
1759    async fn test_basic_chat(mut cx_a: TestAppContext, mut cx_b: TestAppContext) {
1760        cx_a.foreground().forbid_parking();
1761
1762        // Connect to a server as 2 clients.
1763        let mut server = TestServer::start().await;
1764        let (client_a, user_store_a) = server.create_client(&mut cx_a, "user_a").await;
1765        let (client_b, user_store_b) = server.create_client(&mut cx_b, "user_b").await;
1766
1767        // Create an org that includes these 2 users.
1768        let db = &server.app_state.db;
1769        let org_id = db.create_org("Test Org", "test-org").await.unwrap();
1770        db.add_org_member(org_id, current_user_id(&user_store_a), false)
1771            .await
1772            .unwrap();
1773        db.add_org_member(org_id, current_user_id(&user_store_b), false)
1774            .await
1775            .unwrap();
1776
1777        // Create a channel that includes all the users.
1778        let channel_id = db.create_org_channel(org_id, "test-channel").await.unwrap();
1779        db.add_channel_member(channel_id, current_user_id(&user_store_a), false)
1780            .await
1781            .unwrap();
1782        db.add_channel_member(channel_id, current_user_id(&user_store_b), false)
1783            .await
1784            .unwrap();
1785        db.create_channel_message(
1786            channel_id,
1787            current_user_id(&user_store_b),
1788            "hello A, it's B.",
1789            OffsetDateTime::now_utc(),
1790            1,
1791        )
1792        .await
1793        .unwrap();
1794
1795        let channels_a = cx_a.add_model(|cx| ChannelList::new(user_store_a, client_a, cx));
1796        channels_a
1797            .condition(&mut cx_a, |list, _| list.available_channels().is_some())
1798            .await;
1799        channels_a.read_with(&cx_a, |list, _| {
1800            assert_eq!(
1801                list.available_channels().unwrap(),
1802                &[ChannelDetails {
1803                    id: channel_id.to_proto(),
1804                    name: "test-channel".to_string()
1805                }]
1806            )
1807        });
1808        let channel_a = channels_a.update(&mut cx_a, |this, cx| {
1809            this.get_channel(channel_id.to_proto(), cx).unwrap()
1810        });
1811        channel_a.read_with(&cx_a, |channel, _| assert!(channel.messages().is_empty()));
1812        channel_a
1813            .condition(&cx_a, |channel, _| {
1814                channel_messages(channel)
1815                    == [("user_b".to_string(), "hello A, it's B.".to_string(), false)]
1816            })
1817            .await;
1818
1819        let channels_b = cx_b.add_model(|cx| ChannelList::new(user_store_b, client_b, cx));
1820        channels_b
1821            .condition(&mut cx_b, |list, _| list.available_channels().is_some())
1822            .await;
1823        channels_b.read_with(&cx_b, |list, _| {
1824            assert_eq!(
1825                list.available_channels().unwrap(),
1826                &[ChannelDetails {
1827                    id: channel_id.to_proto(),
1828                    name: "test-channel".to_string()
1829                }]
1830            )
1831        });
1832
1833        let channel_b = channels_b.update(&mut cx_b, |this, cx| {
1834            this.get_channel(channel_id.to_proto(), cx).unwrap()
1835        });
1836        channel_b.read_with(&cx_b, |channel, _| assert!(channel.messages().is_empty()));
1837        channel_b
1838            .condition(&cx_b, |channel, _| {
1839                channel_messages(channel)
1840                    == [("user_b".to_string(), "hello A, it's B.".to_string(), false)]
1841            })
1842            .await;
1843
1844        channel_a
1845            .update(&mut cx_a, |channel, cx| {
1846                channel
1847                    .send_message("oh, hi B.".to_string(), cx)
1848                    .unwrap()
1849                    .detach();
1850                let task = channel.send_message("sup".to_string(), cx).unwrap();
1851                assert_eq!(
1852                    channel_messages(channel),
1853                    &[
1854                        ("user_b".to_string(), "hello A, it's B.".to_string(), false),
1855                        ("user_a".to_string(), "oh, hi B.".to_string(), true),
1856                        ("user_a".to_string(), "sup".to_string(), true)
1857                    ]
1858                );
1859                task
1860            })
1861            .await
1862            .unwrap();
1863
1864        channel_b
1865            .condition(&cx_b, |channel, _| {
1866                channel_messages(channel)
1867                    == [
1868                        ("user_b".to_string(), "hello A, it's B.".to_string(), false),
1869                        ("user_a".to_string(), "oh, hi B.".to_string(), false),
1870                        ("user_a".to_string(), "sup".to_string(), false),
1871                    ]
1872            })
1873            .await;
1874
1875        assert_eq!(
1876            server.state().await.channels[&channel_id]
1877                .connection_ids
1878                .len(),
1879            2
1880        );
1881        cx_b.update(|_| drop(channel_b));
1882        server
1883            .condition(|state| state.channels[&channel_id].connection_ids.len() == 1)
1884            .await;
1885
1886        cx_a.update(|_| drop(channel_a));
1887        server
1888            .condition(|state| !state.channels.contains_key(&channel_id))
1889            .await;
1890    }
1891
1892    #[gpui::test]
1893    async fn test_chat_message_validation(mut cx_a: TestAppContext) {
1894        cx_a.foreground().forbid_parking();
1895
1896        let mut server = TestServer::start().await;
1897        let (client_a, user_store_a) = server.create_client(&mut cx_a, "user_a").await;
1898
1899        let db = &server.app_state.db;
1900        let org_id = db.create_org("Test Org", "test-org").await.unwrap();
1901        let channel_id = db.create_org_channel(org_id, "test-channel").await.unwrap();
1902        db.add_org_member(org_id, current_user_id(&user_store_a), false)
1903            .await
1904            .unwrap();
1905        db.add_channel_member(channel_id, current_user_id(&user_store_a), false)
1906            .await
1907            .unwrap();
1908
1909        let channels_a = cx_a.add_model(|cx| ChannelList::new(user_store_a, client_a, cx));
1910        channels_a
1911            .condition(&mut cx_a, |list, _| list.available_channels().is_some())
1912            .await;
1913        let channel_a = channels_a.update(&mut cx_a, |this, cx| {
1914            this.get_channel(channel_id.to_proto(), cx).unwrap()
1915        });
1916
1917        // Messages aren't allowed to be too long.
1918        channel_a
1919            .update(&mut cx_a, |channel, cx| {
1920                let long_body = "this is long.\n".repeat(1024);
1921                channel.send_message(long_body, cx).unwrap()
1922            })
1923            .await
1924            .unwrap_err();
1925
1926        // Messages aren't allowed to be blank.
1927        channel_a.update(&mut cx_a, |channel, cx| {
1928            channel.send_message(String::new(), cx).unwrap_err()
1929        });
1930
1931        // Leading and trailing whitespace are trimmed.
1932        channel_a
1933            .update(&mut cx_a, |channel, cx| {
1934                channel
1935                    .send_message("\n surrounded by whitespace  \n".to_string(), cx)
1936                    .unwrap()
1937            })
1938            .await
1939            .unwrap();
1940        assert_eq!(
1941            db.get_channel_messages(channel_id, 10, None)
1942                .await
1943                .unwrap()
1944                .iter()
1945                .map(|m| &m.body)
1946                .collect::<Vec<_>>(),
1947            &["surrounded by whitespace"]
1948        );
1949    }
1950
1951    #[gpui::test]
1952    async fn test_chat_reconnection(mut cx_a: TestAppContext, mut cx_b: TestAppContext) {
1953        cx_a.foreground().forbid_parking();
1954        let http = FakeHttpClient::new(|_| async move { Ok(surf::http::Response::new(404)) });
1955
1956        // Connect to a server as 2 clients.
1957        let mut server = TestServer::start().await;
1958        let (client_a, user_store_a) = server.create_client(&mut cx_a, "user_a").await;
1959        let (client_b, user_store_b) = server.create_client(&mut cx_b, "user_b").await;
1960        let mut status_b = client_b.status();
1961
1962        // Create an org that includes these 2 users.
1963        let db = &server.app_state.db;
1964        let org_id = db.create_org("Test Org", "test-org").await.unwrap();
1965        db.add_org_member(org_id, current_user_id(&user_store_a), false)
1966            .await
1967            .unwrap();
1968        db.add_org_member(org_id, current_user_id(&user_store_b), false)
1969            .await
1970            .unwrap();
1971
1972        // Create a channel that includes all the users.
1973        let channel_id = db.create_org_channel(org_id, "test-channel").await.unwrap();
1974        db.add_channel_member(channel_id, current_user_id(&user_store_a), false)
1975            .await
1976            .unwrap();
1977        db.add_channel_member(channel_id, current_user_id(&user_store_b), false)
1978            .await
1979            .unwrap();
1980        db.create_channel_message(
1981            channel_id,
1982            current_user_id(&user_store_b),
1983            "hello A, it's B.",
1984            OffsetDateTime::now_utc(),
1985            2,
1986        )
1987        .await
1988        .unwrap();
1989
1990        let user_store_a =
1991            UserStore::new(client_a.clone(), http.clone(), cx_a.background().as_ref());
1992        let channels_a = cx_a.add_model(|cx| ChannelList::new(user_store_a, client_a, cx));
1993        channels_a
1994            .condition(&mut cx_a, |list, _| list.available_channels().is_some())
1995            .await;
1996
1997        channels_a.read_with(&cx_a, |list, _| {
1998            assert_eq!(
1999                list.available_channels().unwrap(),
2000                &[ChannelDetails {
2001                    id: channel_id.to_proto(),
2002                    name: "test-channel".to_string()
2003                }]
2004            )
2005        });
2006        let channel_a = channels_a.update(&mut cx_a, |this, cx| {
2007            this.get_channel(channel_id.to_proto(), cx).unwrap()
2008        });
2009        channel_a.read_with(&cx_a, |channel, _| assert!(channel.messages().is_empty()));
2010        channel_a
2011            .condition(&cx_a, |channel, _| {
2012                channel_messages(channel)
2013                    == [("user_b".to_string(), "hello A, it's B.".to_string(), false)]
2014            })
2015            .await;
2016
2017        let channels_b = cx_b.add_model(|cx| ChannelList::new(user_store_b.clone(), client_b, cx));
2018        channels_b
2019            .condition(&mut cx_b, |list, _| list.available_channels().is_some())
2020            .await;
2021        channels_b.read_with(&cx_b, |list, _| {
2022            assert_eq!(
2023                list.available_channels().unwrap(),
2024                &[ChannelDetails {
2025                    id: channel_id.to_proto(),
2026                    name: "test-channel".to_string()
2027                }]
2028            )
2029        });
2030
2031        let channel_b = channels_b.update(&mut cx_b, |this, cx| {
2032            this.get_channel(channel_id.to_proto(), cx).unwrap()
2033        });
2034        channel_b.read_with(&cx_b, |channel, _| assert!(channel.messages().is_empty()));
2035        channel_b
2036            .condition(&cx_b, |channel, _| {
2037                channel_messages(channel)
2038                    == [("user_b".to_string(), "hello A, it's B.".to_string(), false)]
2039            })
2040            .await;
2041
2042        // Disconnect client B, ensuring we can still access its cached channel data.
2043        server.forbid_connections();
2044        server.disconnect_client(current_user_id(&user_store_b));
2045        while !matches!(
2046            status_b.recv().await,
2047            Some(rpc::Status::ReconnectionError { .. })
2048        ) {}
2049
2050        channels_b.read_with(&cx_b, |channels, _| {
2051            assert_eq!(
2052                channels.available_channels().unwrap(),
2053                [ChannelDetails {
2054                    id: channel_id.to_proto(),
2055                    name: "test-channel".to_string()
2056                }]
2057            )
2058        });
2059        channel_b.read_with(&cx_b, |channel, _| {
2060            assert_eq!(
2061                channel_messages(channel),
2062                [("user_b".to_string(), "hello A, it's B.".to_string(), false)]
2063            )
2064        });
2065
2066        // Send a message from client B while it is disconnected.
2067        channel_b
2068            .update(&mut cx_b, |channel, cx| {
2069                let task = channel
2070                    .send_message("can you see this?".to_string(), cx)
2071                    .unwrap();
2072                assert_eq!(
2073                    channel_messages(channel),
2074                    &[
2075                        ("user_b".to_string(), "hello A, it's B.".to_string(), false),
2076                        ("user_b".to_string(), "can you see this?".to_string(), true)
2077                    ]
2078                );
2079                task
2080            })
2081            .await
2082            .unwrap_err();
2083
2084        // Send a message from client A while B is disconnected.
2085        channel_a
2086            .update(&mut cx_a, |channel, cx| {
2087                channel
2088                    .send_message("oh, hi B.".to_string(), cx)
2089                    .unwrap()
2090                    .detach();
2091                let task = channel.send_message("sup".to_string(), cx).unwrap();
2092                assert_eq!(
2093                    channel_messages(channel),
2094                    &[
2095                        ("user_b".to_string(), "hello A, it's B.".to_string(), false),
2096                        ("user_a".to_string(), "oh, hi B.".to_string(), true),
2097                        ("user_a".to_string(), "sup".to_string(), true)
2098                    ]
2099                );
2100                task
2101            })
2102            .await
2103            .unwrap();
2104
2105        // Give client B a chance to reconnect.
2106        server.allow_connections();
2107        cx_b.foreground().advance_clock(Duration::from_secs(10));
2108
2109        // Verify that B sees the new messages upon reconnection, as well as the message client B
2110        // sent while offline.
2111        channel_b
2112            .condition(&cx_b, |channel, _| {
2113                channel_messages(channel)
2114                    == [
2115                        ("user_b".to_string(), "hello A, it's B.".to_string(), false),
2116                        ("user_a".to_string(), "oh, hi B.".to_string(), false),
2117                        ("user_a".to_string(), "sup".to_string(), false),
2118                        ("user_b".to_string(), "can you see this?".to_string(), false),
2119                    ]
2120            })
2121            .await;
2122
2123        // Ensure client A and B can communicate normally after reconnection.
2124        channel_a
2125            .update(&mut cx_a, |channel, cx| {
2126                channel.send_message("you online?".to_string(), cx).unwrap()
2127            })
2128            .await
2129            .unwrap();
2130        channel_b
2131            .condition(&cx_b, |channel, _| {
2132                channel_messages(channel)
2133                    == [
2134                        ("user_b".to_string(), "hello A, it's B.".to_string(), false),
2135                        ("user_a".to_string(), "oh, hi B.".to_string(), false),
2136                        ("user_a".to_string(), "sup".to_string(), false),
2137                        ("user_b".to_string(), "can you see this?".to_string(), false),
2138                        ("user_a".to_string(), "you online?".to_string(), false),
2139                    ]
2140            })
2141            .await;
2142
2143        channel_b
2144            .update(&mut cx_b, |channel, cx| {
2145                channel.send_message("yep".to_string(), cx).unwrap()
2146            })
2147            .await
2148            .unwrap();
2149        channel_a
2150            .condition(&cx_a, |channel, _| {
2151                channel_messages(channel)
2152                    == [
2153                        ("user_b".to_string(), "hello A, it's B.".to_string(), false),
2154                        ("user_a".to_string(), "oh, hi B.".to_string(), false),
2155                        ("user_a".to_string(), "sup".to_string(), false),
2156                        ("user_b".to_string(), "can you see this?".to_string(), false),
2157                        ("user_a".to_string(), "you online?".to_string(), false),
2158                        ("user_b".to_string(), "yep".to_string(), false),
2159                    ]
2160            })
2161            .await;
2162    }
2163
2164    struct TestServer {
2165        peer: Arc<Peer>,
2166        app_state: Arc<AppState>,
2167        server: Arc<Server>,
2168        notifications: mpsc::Receiver<()>,
2169        connection_killers: Arc<Mutex<HashMap<UserId, watch::Sender<Option<()>>>>>,
2170        forbid_connections: Arc<AtomicBool>,
2171        _test_db: TestDb,
2172    }
2173
2174    impl TestServer {
2175        async fn start() -> Self {
2176            let test_db = TestDb::new();
2177            let app_state = Self::build_app_state(&test_db).await;
2178            let peer = Peer::new();
2179            let notifications = mpsc::channel(128);
2180            let server = Server::new(app_state.clone(), peer.clone(), Some(notifications.0));
2181            Self {
2182                peer,
2183                app_state,
2184                server,
2185                notifications: notifications.1,
2186                connection_killers: Default::default(),
2187                forbid_connections: Default::default(),
2188                _test_db: test_db,
2189            }
2190        }
2191
2192        async fn create_client(
2193            &mut self,
2194            cx: &mut TestAppContext,
2195            name: &str,
2196        ) -> (Arc<Client>, Arc<UserStore>) {
2197            let user_id = self.app_state.db.create_user(name, false).await.unwrap();
2198            let client_name = name.to_string();
2199            let mut client = Client::new();
2200            let server = self.server.clone();
2201            let connection_killers = self.connection_killers.clone();
2202            let forbid_connections = self.forbid_connections.clone();
2203            Arc::get_mut(&mut client)
2204                .unwrap()
2205                .override_authenticate(move |cx| {
2206                    cx.spawn(|_| async move {
2207                        let access_token = "the-token".to_string();
2208                        Ok(Credentials {
2209                            user_id: user_id.0 as u64,
2210                            access_token,
2211                        })
2212                    })
2213                })
2214                .override_establish_connection(move |credentials, cx| {
2215                    assert_eq!(credentials.user_id, user_id.0 as u64);
2216                    assert_eq!(credentials.access_token, "the-token");
2217
2218                    let server = server.clone();
2219                    let connection_killers = connection_killers.clone();
2220                    let forbid_connections = forbid_connections.clone();
2221                    let client_name = client_name.clone();
2222                    cx.spawn(move |cx| async move {
2223                        if forbid_connections.load(SeqCst) {
2224                            Err(EstablishConnectionError::other(anyhow!(
2225                                "server is forbidding connections"
2226                            )))
2227                        } else {
2228                            let (client_conn, server_conn, kill_conn) = Connection::in_memory();
2229                            connection_killers.lock().insert(user_id, kill_conn);
2230                            cx.background()
2231                                .spawn(server.handle_connection(server_conn, client_name, user_id))
2232                                .detach();
2233                            Ok(client_conn)
2234                        }
2235                    })
2236                });
2237
2238            let http = FakeHttpClient::new(|_| async move { Ok(surf::http::Response::new(404)) });
2239            client
2240                .authenticate_and_connect(&cx.to_async())
2241                .await
2242                .unwrap();
2243
2244            let user_store = UserStore::new(client.clone(), http, &cx.background());
2245            let mut authed_user = user_store.watch_current_user();
2246            while authed_user.recv().await.unwrap().is_none() {}
2247
2248            (client, user_store)
2249        }
2250
2251        fn disconnect_client(&self, user_id: UserId) {
2252            if let Some(mut kill_conn) = self.connection_killers.lock().remove(&user_id) {
2253                let _ = kill_conn.try_send(Some(()));
2254            }
2255        }
2256
2257        fn forbid_connections(&self) {
2258            self.forbid_connections.store(true, SeqCst);
2259        }
2260
2261        fn allow_connections(&self) {
2262            self.forbid_connections.store(false, SeqCst);
2263        }
2264
2265        async fn build_app_state(test_db: &TestDb) -> Arc<AppState> {
2266            let mut config = Config::default();
2267            config.session_secret = "a".repeat(32);
2268            config.database_url = test_db.url.clone();
2269            let github_client = github::AppClient::test();
2270            Arc::new(AppState {
2271                db: test_db.db().clone(),
2272                handlebars: Default::default(),
2273                auth_client: auth::build_client("", ""),
2274                repo_client: github::RepoClient::test(&github_client),
2275                github_client,
2276                config,
2277            })
2278        }
2279
2280        async fn state<'a>(&'a self) -> RwLockReadGuard<'a, ServerState> {
2281            self.server.state.read().await
2282        }
2283
2284        async fn condition<F>(&mut self, mut predicate: F)
2285        where
2286            F: FnMut(&ServerState) -> bool,
2287        {
2288            async_std::future::timeout(Duration::from_millis(500), async {
2289                while !(predicate)(&*self.server.state.read().await) {
2290                    self.notifications.recv().await;
2291                }
2292            })
2293            .await
2294            .expect("condition timed out");
2295        }
2296    }
2297
2298    impl Drop for TestServer {
2299        fn drop(&mut self) {
2300            task::block_on(self.peer.reset());
2301        }
2302    }
2303
2304    fn current_user_id(user_store: &Arc<UserStore>) -> UserId {
2305        UserId::from_proto(user_store.current_user().unwrap().id)
2306    }
2307
2308    fn channel_messages(channel: &Channel) -> Vec<(String, String, bool)> {
2309        channel
2310            .messages()
2311            .cursor::<(), ()>()
2312            .map(|m| {
2313                (
2314                    m.sender.github_login.clone(),
2315                    m.body.clone(),
2316                    m.is_pending(),
2317                )
2318            })
2319            .collect()
2320    }
2321
2322    struct EmptyView;
2323
2324    impl gpui::Entity for EmptyView {
2325        type Event = ();
2326    }
2327
2328    impl gpui::View for EmptyView {
2329        fn ui_name() -> &'static str {
2330            "empty view"
2331        }
2332
2333        fn render(&mut self, _: &mut gpui::RenderContext<Self>) -> gpui::ElementBox {
2334            gpui::Element::boxed(gpui::elements::Empty)
2335        }
2336    }
2337}