rpc.rs

   1use super::{
   2    auth,
   3    db::{ChannelId, MessageId, UserId},
   4    AppState,
   5};
   6use anyhow::anyhow;
   7use async_std::{sync::RwLock, task};
   8use async_tungstenite::{tungstenite::protocol::Role, WebSocketStream};
   9use futures::{future::BoxFuture, FutureExt};
  10use postage::{mpsc, prelude::Sink as _, prelude::Stream as _};
  11use sha1::{Digest as _, Sha1};
  12use std::{
  13    any::TypeId,
  14    collections::{hash_map, HashMap, HashSet},
  15    future::Future,
  16    mem,
  17    sync::Arc,
  18    time::Instant,
  19};
  20use surf::StatusCode;
  21use tide::log;
  22use tide::{
  23    http::headers::{HeaderName, CONNECTION, UPGRADE},
  24    Request, Response,
  25};
  26use time::OffsetDateTime;
  27use zrpc::{
  28    proto::{self, AnyTypedEnvelope, EnvelopedMessage},
  29    Connection, ConnectionId, Peer, TypedEnvelope,
  30};
  31
/// Identifier of a replica within a shared worktree. In this file the host is always
/// reported with replica id 0, while guests carry the ids stored in `WorktreeShare`.
type ReplicaId = u16;
  33
/// Type-erased message handler stored in `Server::handlers`: it receives the server and a
/// boxed envelope and returns a boxed future that resolves to the handler's result.
type MessageHandler = Box<
    dyn Send
        + Sync
        + Fn(Arc<Server>, Box<dyn AnyTypedEnvelope>) -> BoxFuture<'static, tide::Result<()>>,
>;
  39
/// RPC server: dispatches incoming messages to registered handlers and tracks which
/// connections, worktrees and channels are currently live.
pub struct Server {
    /// RPC peer used to register connections and to send, respond to and forward messages.
    peer: Arc<Peer>,
    /// In-memory bookkeeping, guarded by an async read/write lock.
    state: RwLock<ServerState>,
    /// Shared application state; the handlers in this file use it to reach the database.
    app_state: Arc<AppState>,
    /// Registered handlers, keyed by the `TypeId` of the message payload they accept.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Optional channel that receives a `()` after each dispatched message has been handled.
    notifications: Option<mpsc::Sender<()>>,
}
  47
  48#[derive(Default)]
  49struct ServerState {
  50    connections: HashMap<ConnectionId, ConnectionState>,
  51    connections_by_user_id: HashMap<UserId, HashSet<ConnectionId>>,
  52    pub worktrees: HashMap<u64, Worktree>,
  53    visible_worktrees_by_user_id: HashMap<UserId, HashSet<u64>>,
  54    channels: HashMap<ChannelId, Channel>,
  55    next_worktree_id: u64,
  56}
  57
/// What the server remembers about a single client connection.
struct ConnectionState {
    /// The user that this connection was registered for.
    user_id: UserId,
    /// Ids of the worktrees this connection is associated with; consulted on disconnect.
    worktrees: HashSet<u64>,
    /// Ids of the channels this connection is associated with; consulted on disconnect.
    channels: HashSet<ChannelId>,
}
  63
/// A worktree that some connection has opened on the server.
struct Worktree {
    /// Connection that opened the worktree and acts as its host.
    host_connection_id: ConnectionId,
    /// Users named as collaborators when the worktree was opened (the host's own user is
    /// filtered out at that point).
    collaborator_user_ids: Vec<UserId>,
    /// Root name supplied by the host when opening the worktree.
    root_name: String,
    /// Sharing state: `None` when opened, set by a share request, cleared on unshare.
    share: Option<WorktreeShare>,
}
  70
/// State that exists only while a worktree is shared.
struct WorktreeShare {
    /// Guest connections that joined the worktree, with the replica id assigned to each.
    guest_connection_ids: HashMap<ConnectionId, ReplicaId>,
    /// Replica ids currently in use; an id is removed again when its guest leaves.
    active_replica_ids: HashSet<ReplicaId>,
    /// Latest known worktree entries, keyed by entry id.
    entries: HashMap<u64, proto::Entry>,
}
  76
/// Live membership of a chat channel.
#[derive(Default)]
struct Channel {
    /// Connections tracked as members of the channel; a connection is removed again when it
    /// disconnects.
    connection_ids: HashSet<ConnectionId>,
}
  81
/// Number of messages requested from the database when a connection joins a channel; a
/// shorter result is reported to the client as the end of the history.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
/// Maximum accepted length, in bytes, of a trimmed chat message body.
const MAX_MESSAGE_LEN: usize = 1024;
  84
  85impl Server {
  86    pub fn new(
  87        app_state: Arc<AppState>,
  88        peer: Arc<Peer>,
  89        notifications: Option<mpsc::Sender<()>>,
  90    ) -> Arc<Self> {
  91        let mut server = Self {
  92            peer,
  93            app_state,
  94            state: Default::default(),
  95            handlers: Default::default(),
  96            notifications,
  97        };
  98
  99        server
 100            .add_handler(Server::ping)
 101            .add_handler(Server::open_worktree)
 102            .add_handler(Server::close_worktree)
 103            .add_handler(Server::share_worktree)
 104            .add_handler(Server::unshare_worktree)
 105            .add_handler(Server::join_worktree)
 106            .add_handler(Server::update_worktree)
 107            .add_handler(Server::open_buffer)
 108            .add_handler(Server::close_buffer)
 109            .add_handler(Server::update_buffer)
 110            .add_handler(Server::buffer_saved)
 111            .add_handler(Server::save_buffer)
 112            .add_handler(Server::get_channels)
 113            .add_handler(Server::get_users)
 114            .add_handler(Server::join_channel)
 115            .add_handler(Server::leave_channel)
 116            .add_handler(Server::send_channel_message)
 117            .add_handler(Server::get_channel_messages)
 118            .add_handler(Server::get_collaborators);
 119
 120        Arc::new(server)
 121    }
 122
 123    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 124    where
 125        F: 'static + Send + Sync + Fn(Arc<Self>, TypedEnvelope<M>) -> Fut,
 126        Fut: 'static + Send + Future<Output = tide::Result<()>>,
 127        M: EnvelopedMessage,
 128    {
 129        let prev_handler = self.handlers.insert(
 130            TypeId::of::<M>(),
 131            Box::new(move |server, envelope| {
 132                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 133                (handler)(server, *envelope).boxed()
 134            }),
 135        );
 136        if prev_handler.is_some() {
 137            panic!("registered a handler for the same message twice");
 138        }
 139        self
 140    }
 141
 142    pub fn handle_connection(
 143        self: &Arc<Self>,
 144        connection: Connection,
 145        addr: String,
 146        user_id: UserId,
 147    ) -> impl Future<Output = ()> {
 148        let this = self.clone();
 149        async move {
 150            let (connection_id, handle_io, mut incoming_rx) =
 151                this.peer.add_connection(connection).await;
 152            this.add_connection(connection_id, user_id).await;
 153
 154            let handle_io = handle_io.fuse();
 155            futures::pin_mut!(handle_io);
 156            loop {
 157                let next_message = incoming_rx.recv().fuse();
 158                futures::pin_mut!(next_message);
 159                futures::select_biased! {
 160                    message = next_message => {
 161                        if let Some(message) = message {
 162                            let start_time = Instant::now();
 163                            log::info!("RPC message received: {}", message.payload_type_name());
 164                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
 165                                if let Err(err) = (handler)(this.clone(), message).await {
 166                                    log::error!("error handling message: {:?}", err);
 167                                } else {
 168                                    log::info!("RPC message handled. duration:{:?}", start_time.elapsed());
 169                                }
 170
 171                                if let Some(mut notifications) = this.notifications.clone() {
 172                                    let _ = notifications.send(()).await;
 173                                }
 174                            } else {
 175                                log::warn!("unhandled message: {}", message.payload_type_name());
 176                            }
 177                        } else {
 178                            log::info!("rpc connection closed {:?}", addr);
 179                            break;
 180                        }
 181                    }
 182                    handle_io = handle_io => {
 183                        if let Err(err) = handle_io {
 184                            log::error!("error handling rpc connection {:?} - {:?}", addr, err);
 185                        }
 186                        break;
 187                    }
 188                }
 189            }
 190
 191            if let Err(err) = this.sign_out(connection_id).await {
 192                log::error!("error signing out connection {:?} - {:?}", addr, err);
 193            }
 194        }
 195    }
 196
 197    async fn sign_out(self: &Arc<Self>, connection_id: zrpc::ConnectionId) -> tide::Result<()> {
 198        self.peer.disconnect(connection_id).await;
 199        let worktree_ids = self.remove_connection(connection_id).await;
 200        for worktree_id in worktree_ids {
 201            let state = self.state.read().await;
 202            if let Some(worktree) = state.worktrees.get(&worktree_id) {
 203                broadcast(connection_id, worktree.connection_ids(), |conn_id| {
 204                    self.peer.send(
 205                        conn_id,
 206                        proto::RemovePeer {
 207                            worktree_id,
 208                            peer_id: connection_id.0,
 209                        },
 210                    )
 211                })
 212                .await?;
 213            }
 214        }
 215        Ok(())
 216    }
 217
 218    // Add a new connection associated with a given user.
 219    async fn add_connection(&self, connection_id: ConnectionId, user_id: UserId) {
 220        let mut state = self.state.write().await;
 221        state.connections.insert(
 222            connection_id,
 223            ConnectionState {
 224                user_id,
 225                worktrees: Default::default(),
 226                channels: Default::default(),
 227            },
 228        );
 229        state
 230            .connections_by_user_id
 231            .entry(user_id)
 232            .or_default()
 233            .insert(connection_id);
 234    }
 235
 236    // Remove the given connection and its association with any worktrees.
 237    async fn remove_connection(&self, connection_id: ConnectionId) -> Vec<u64> {
 238        let mut worktree_ids = Vec::new();
 239        let mut state = self.state.write().await;
 240        if let Some(connection) = state.connections.remove(&connection_id) {
 241            for channel_id in connection.channels {
 242                if let Some(channel) = state.channels.get_mut(&channel_id) {
 243                    channel.connection_ids.remove(&connection_id);
 244                }
 245            }
 246            for worktree_id in connection.worktrees {
 247                if let Some(worktree) = state.worktrees.get_mut(&worktree_id) {
 248                    if worktree.host_connection_id == connection_id {
 249                        worktree_ids.push(worktree_id);
 250                    } else if let Some(share_state) = worktree.share.as_mut() {
 251                        if let Some(replica_id) =
 252                            share_state.guest_connection_ids.remove(&connection_id)
 253                        {
 254                            share_state.active_replica_ids.remove(&replica_id);
 255                            worktree_ids.push(worktree_id);
 256                        }
 257                    }
 258                }
 259            }
 260
 261            let user_connections = state
 262                .connections_by_user_id
 263                .get_mut(&connection.user_id)
 264                .unwrap();
 265            user_connections.remove(&connection_id);
 266            if user_connections.is_empty() {
 267                state.connections_by_user_id.remove(&connection.user_id);
 268            }
 269        }
 270        worktree_ids
 271    }
 272
 273    async fn ping(self: Arc<Server>, request: TypedEnvelope<proto::Ping>) -> tide::Result<()> {
 274        self.peer.respond(request.receipt(), proto::Ack {}).await?;
 275        Ok(())
 276    }
 277
 278    async fn open_worktree(
 279        self: Arc<Server>,
 280        request: TypedEnvelope<proto::OpenWorktree>,
 281    ) -> tide::Result<()> {
 282        let receipt = request.receipt();
 283        let user_id = self
 284            .state
 285            .read()
 286            .await
 287            .user_id_for_connection(request.sender_id)?;
 288
 289        let mut collaborator_user_ids = Vec::new();
 290        for github_login in request.payload.collaborator_logins {
 291            match self.app_state.db.create_user(&github_login, false).await {
 292                Ok(collaborator_user_id) => {
 293                    if collaborator_user_id != user_id {
 294                        collaborator_user_ids.push(collaborator_user_id);
 295                    }
 296                }
 297                Err(err) => {
 298                    let message = err.to_string();
 299                    self.peer
 300                        .respond_with_error(receipt, proto::Error { message })
 301                        .await?;
 302                    return Ok(());
 303                }
 304            }
 305        }
 306
 307        let mut state = self.state.write().await;
 308        let worktree_id = state.add_worktree(Worktree {
 309            host_connection_id: request.sender_id,
 310            collaborator_user_ids,
 311            root_name: request.payload.root_name,
 312            share: None,
 313        });
 314
 315        self.peer
 316            .respond(receipt, proto::OpenWorktreeResponse { worktree_id })
 317            .await?;
 318        Ok(())
 319    }
 320
 321    async fn share_worktree(
 322        self: Arc<Server>,
 323        mut request: TypedEnvelope<proto::ShareWorktree>,
 324    ) -> tide::Result<()> {
 325        let worktree = request
 326            .payload
 327            .worktree
 328            .as_mut()
 329            .ok_or_else(|| anyhow!("missing worktree"))?;
 330        let entries = mem::take(&mut worktree.entries)
 331            .into_iter()
 332            .map(|entry| (entry.id, entry))
 333            .collect();
 334        let mut state = self.state.write().await;
 335        if let Some(worktree) = state.worktrees.get_mut(&worktree.id) {
 336            worktree.share = Some(WorktreeShare {
 337                guest_connection_ids: Default::default(),
 338                active_replica_ids: Default::default(),
 339                entries,
 340            });
 341            self.peer
 342                .respond(request.receipt(), proto::ShareWorktreeResponse {})
 343                .await?;
 344        } else {
 345            self.peer
 346                .respond_with_error(
 347                    request.receipt(),
 348                    proto::Error {
 349                        message: "no such worktree".to_string(),
 350                    },
 351                )
 352                .await?;
 353        }
 354        Ok(())
 355    }
 356
 357    async fn unshare_worktree(
 358        self: Arc<Server>,
 359        request: TypedEnvelope<proto::UnshareWorktree>,
 360    ) -> tide::Result<()> {
 361        let worktree_id = request.payload.worktree_id;
 362
 363        let connection_ids;
 364        {
 365            let mut state = self.state.write().await;
 366            let worktree = state.write_worktree(worktree_id, request.sender_id)?;
 367            if worktree.host_connection_id != request.sender_id {
 368                return Err(anyhow!("no such worktree"))?;
 369            }
 370
 371            connection_ids = worktree.connection_ids();
 372            worktree.share.take();
 373            for connection_id in &connection_ids {
 374                if let Some(connection) = state.connections.get_mut(connection_id) {
 375                    connection.worktrees.remove(&worktree_id);
 376                }
 377            }
 378        }
 379
 380        broadcast(request.sender_id, connection_ids, |conn_id| {
 381            self.peer
 382                .send(conn_id, proto::UnshareWorktree { worktree_id })
 383        })
 384        .await?;
 385
 386        Ok(())
 387    }
 388
 389    async fn join_worktree(
 390        self: Arc<Server>,
 391        request: TypedEnvelope<proto::JoinWorktree>,
 392    ) -> tide::Result<()> {
 393        let worktree_id = request.payload.worktree_id;
 394        let user_id = self
 395            .state
 396            .read()
 397            .await
 398            .user_id_for_connection(request.sender_id)?;
 399
 400        let response;
 401        let connection_ids;
 402        let mut state = self.state.write().await;
 403        match state.join_worktree(request.sender_id, user_id, worktree_id) {
 404            Ok((peer_replica_id, worktree)) => {
 405                let share = worktree.share()?;
 406                let peer_count = share.guest_connection_ids.len();
 407                let mut peers = Vec::with_capacity(peer_count);
 408                peers.push(proto::Peer {
 409                    peer_id: worktree.host_connection_id.0,
 410                    replica_id: 0,
 411                });
 412                for (peer_conn_id, peer_replica_id) in &share.guest_connection_ids {
 413                    if *peer_conn_id != request.sender_id {
 414                        peers.push(proto::Peer {
 415                            peer_id: peer_conn_id.0,
 416                            replica_id: *peer_replica_id as u32,
 417                        });
 418                    }
 419                }
 420                connection_ids = worktree.connection_ids();
 421                response = proto::JoinWorktreeResponse {
 422                    worktree: Some(proto::Worktree {
 423                        id: worktree_id,
 424                        root_name: worktree.root_name.clone(),
 425                        entries: share.entries.values().cloned().collect(),
 426                    }),
 427                    replica_id: peer_replica_id as u32,
 428                    peers,
 429                };
 430            }
 431            Err(error) => {
 432                self.peer
 433                    .respond_with_error(
 434                        request.receipt(),
 435                        proto::Error {
 436                            message: error.to_string(),
 437                        },
 438                    )
 439                    .await?;
 440                return Ok(());
 441            }
 442        }
 443
 444        broadcast(request.sender_id, connection_ids, |conn_id| {
 445            self.peer.send(
 446                conn_id,
 447                proto::AddPeer {
 448                    worktree_id,
 449                    peer: Some(proto::Peer {
 450                        peer_id: request.sender_id.0,
 451                        replica_id: response.replica_id,
 452                    }),
 453                },
 454            )
 455        })
 456        .await?;
 457        self.peer.respond(request.receipt(), response).await?;
 458
 459        Ok(())
 460    }
 461
 462    async fn close_worktree(
 463        self: Arc<Server>,
 464        request: TypedEnvelope<proto::CloseWorktree>,
 465    ) -> tide::Result<()> {
 466        let worktree_id = request.payload.worktree_id;
 467        let connection_ids;
 468        let mut is_host = false;
 469        let mut is_guest = false;
 470        {
 471            let mut state = self.state.write().await;
 472            let worktree = state.write_worktree(worktree_id, request.sender_id)?;
 473            connection_ids = worktree.connection_ids();
 474
 475            if worktree.host_connection_id == request.sender_id {
 476                is_host = true;
 477                state.remove_worktree(worktree_id);
 478            } else {
 479                let share = worktree.share_mut()?;
 480                if let Some(replica_id) = share.guest_connection_ids.remove(&request.sender_id) {
 481                    is_guest = true;
 482                    share.active_replica_ids.remove(&replica_id);
 483                }
 484            }
 485        }
 486
 487        if is_host {
 488            broadcast(request.sender_id, connection_ids, |conn_id| {
 489                self.peer
 490                    .send(conn_id, proto::UnshareWorktree { worktree_id })
 491            })
 492            .await?;
 493        } else if is_guest {
 494            broadcast(request.sender_id, connection_ids, |conn_id| {
 495                self.peer.send(
 496                    conn_id,
 497                    proto::RemovePeer {
 498                        worktree_id,
 499                        peer_id: request.sender_id.0,
 500                    },
 501                )
 502            })
 503            .await?
 504        }
 505
 506        Ok(())
 507    }
 508
 509    async fn update_worktree(
 510        self: Arc<Server>,
 511        request: TypedEnvelope<proto::UpdateWorktree>,
 512    ) -> tide::Result<()> {
 513        {
 514            let mut state = self.state.write().await;
 515            let worktree = state.write_worktree(request.payload.worktree_id, request.sender_id)?;
 516            let share = worktree.share_mut()?;
 517
 518            for entry_id in &request.payload.removed_entries {
 519                share.entries.remove(&entry_id);
 520            }
 521
 522            for entry in &request.payload.updated_entries {
 523                share.entries.insert(entry.id, entry.clone());
 524            }
 525        }
 526
 527        self.broadcast_in_worktree(request.payload.worktree_id, &request)
 528            .await?;
 529        Ok(())
 530    }
 531
 532    async fn open_buffer(
 533        self: Arc<Server>,
 534        request: TypedEnvelope<proto::OpenBuffer>,
 535    ) -> tide::Result<()> {
 536        let receipt = request.receipt();
 537        let worktree_id = request.payload.worktree_id;
 538        let host_connection_id = self
 539            .state
 540            .read()
 541            .await
 542            .read_worktree(worktree_id, request.sender_id)?
 543            .host_connection_id;
 544
 545        let response = self
 546            .peer
 547            .forward_request(request.sender_id, host_connection_id, request.payload)
 548            .await?;
 549        self.peer.respond(receipt, response).await?;
 550        Ok(())
 551    }
 552
 553    async fn close_buffer(
 554        self: Arc<Server>,
 555        request: TypedEnvelope<proto::CloseBuffer>,
 556    ) -> tide::Result<()> {
 557        let host_connection_id = self
 558            .state
 559            .read()
 560            .await
 561            .read_worktree(request.payload.worktree_id, request.sender_id)?
 562            .host_connection_id;
 563
 564        self.peer
 565            .forward_send(request.sender_id, host_connection_id, request.payload)
 566            .await?;
 567
 568        Ok(())
 569    }
 570
 571    async fn save_buffer(
 572        self: Arc<Server>,
 573        request: TypedEnvelope<proto::SaveBuffer>,
 574    ) -> tide::Result<()> {
 575        let host;
 576        let guests;
 577        {
 578            let state = self.state.read().await;
 579            let worktree = state.read_worktree(request.payload.worktree_id, request.sender_id)?;
 580            host = worktree.host_connection_id;
 581            guests = worktree
 582                .share()?
 583                .guest_connection_ids
 584                .keys()
 585                .copied()
 586                .collect::<Vec<_>>();
 587        }
 588
 589        let sender = request.sender_id;
 590        let receipt = request.receipt();
 591        let response = self
 592            .peer
 593            .forward_request(sender, host, request.payload.clone())
 594            .await?;
 595
 596        broadcast(host, guests, |conn_id| {
 597            let response = response.clone();
 598            let peer = &self.peer;
 599            async move {
 600                if conn_id == sender {
 601                    peer.respond(receipt, response).await
 602                } else {
 603                    peer.forward_send(host, conn_id, response).await
 604                }
 605            }
 606        })
 607        .await?;
 608
 609        Ok(())
 610    }
 611
 612    async fn update_buffer(
 613        self: Arc<Server>,
 614        request: TypedEnvelope<proto::UpdateBuffer>,
 615    ) -> tide::Result<()> {
 616        self.broadcast_in_worktree(request.payload.worktree_id, &request)
 617            .await?;
 618        self.peer.respond(request.receipt(), proto::Ack {}).await?;
 619        Ok(())
 620    }
 621
 622    async fn buffer_saved(
 623        self: Arc<Server>,
 624        request: TypedEnvelope<proto::BufferSaved>,
 625    ) -> tide::Result<()> {
 626        self.broadcast_in_worktree(request.payload.worktree_id, &request)
 627            .await
 628    }
 629
 630    async fn get_channels(
 631        self: Arc<Server>,
 632        request: TypedEnvelope<proto::GetChannels>,
 633    ) -> tide::Result<()> {
 634        let user_id = self
 635            .state
 636            .read()
 637            .await
 638            .user_id_for_connection(request.sender_id)?;
 639        let channels = self.app_state.db.get_accessible_channels(user_id).await?;
 640        self.peer
 641            .respond(
 642                request.receipt(),
 643                proto::GetChannelsResponse {
 644                    channels: channels
 645                        .into_iter()
 646                        .map(|chan| proto::Channel {
 647                            id: chan.id.to_proto(),
 648                            name: chan.name,
 649                        })
 650                        .collect(),
 651                },
 652            )
 653            .await?;
 654        Ok(())
 655    }
 656
 657    async fn get_users(
 658        self: Arc<Server>,
 659        request: TypedEnvelope<proto::GetUsers>,
 660    ) -> tide::Result<()> {
 661        let user_id = self
 662            .state
 663            .read()
 664            .await
 665            .user_id_for_connection(request.sender_id)?;
 666        let receipt = request.receipt();
 667        let user_ids = request.payload.user_ids.into_iter().map(UserId::from_proto);
 668        let users = self
 669            .app_state
 670            .db
 671            .get_users_by_ids(user_id, user_ids)
 672            .await?
 673            .into_iter()
 674            .map(|user| proto::User {
 675                id: user.id.to_proto(),
 676                avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
 677                github_login: user.github_login,
 678            })
 679            .collect();
 680        self.peer
 681            .respond(receipt, proto::GetUsersResponse { users })
 682            .await?;
 683        Ok(())
 684    }
 685
 686    async fn get_collaborators(
 687        self: Arc<Server>,
 688        request: TypedEnvelope<proto::GetCollaborators>,
 689    ) -> tide::Result<()> {
 690        let mut collaborators = HashMap::new();
 691        {
 692            let state = self.state.read().await;
 693            let user_id = state.user_id_for_connection(request.sender_id)?;
 694            for worktree_id in state
 695                .visible_worktrees_by_user_id
 696                .get(&user_id)
 697                .unwrap_or(&HashSet::new())
 698            {
 699                let worktree = &state.worktrees[worktree_id];
 700
 701                let mut participants = Vec::new();
 702                for collaborator_user_id in &worktree.collaborator_user_ids {
 703                    collaborators
 704                        .entry(*collaborator_user_id)
 705                        .or_insert_with(|| proto::Collaborator {
 706                            user_id: collaborator_user_id.to_proto(),
 707                            worktrees: Vec::new(),
 708                            is_online: state.is_online(*collaborator_user_id),
 709                        });
 710
 711                    if let Ok(share) = worktree.share() {
 712                        let mut conn_ids = state.user_connection_ids(*collaborator_user_id);
 713                        if conn_ids.any(|c| share.guest_connection_ids.contains_key(&c)) {
 714                            participants.push(collaborator_user_id.to_proto());
 715                        }
 716                    }
 717                }
 718
 719                let host_user_id = state.user_id_for_connection(worktree.host_connection_id)?;
 720                let host =
 721                    collaborators
 722                        .entry(host_user_id)
 723                        .or_insert_with(|| proto::Collaborator {
 724                            user_id: host_user_id.to_proto(),
 725                            worktrees: Vec::new(),
 726                            is_online: true,
 727                        });
 728                host.worktrees.push(proto::CollaboratorWorktree {
 729                    root_name: worktree.root_name.clone(),
 730                    is_shared: worktree.share().is_ok(),
 731                    participants,
 732                });
 733            }
 734        }
 735
 736        self.peer
 737            .respond(
 738                request.receipt(),
 739                proto::GetCollaboratorsResponse {
 740                    collaborators: collaborators.into_values().collect(),
 741                },
 742            )
 743            .await?;
 744        Ok(())
 745    }
 746
 747    async fn join_channel(
 748        self: Arc<Self>,
 749        request: TypedEnvelope<proto::JoinChannel>,
 750    ) -> tide::Result<()> {
 751        let user_id = self
 752            .state
 753            .read()
 754            .await
 755            .user_id_for_connection(request.sender_id)?;
 756        let channel_id = ChannelId::from_proto(request.payload.channel_id);
 757        if !self
 758            .app_state
 759            .db
 760            .can_user_access_channel(user_id, channel_id)
 761            .await?
 762        {
 763            Err(anyhow!("access denied"))?;
 764        }
 765
 766        self.state
 767            .write()
 768            .await
 769            .join_channel(request.sender_id, channel_id);
 770        let messages = self
 771            .app_state
 772            .db
 773            .get_channel_messages(channel_id, MESSAGE_COUNT_PER_PAGE, None)
 774            .await?
 775            .into_iter()
 776            .map(|msg| proto::ChannelMessage {
 777                id: msg.id.to_proto(),
 778                body: msg.body,
 779                timestamp: msg.sent_at.unix_timestamp() as u64,
 780                sender_id: msg.sender_id.to_proto(),
 781                nonce: Some(msg.nonce.as_u128().into()),
 782            })
 783            .collect::<Vec<_>>();
 784        self.peer
 785            .respond(
 786                request.receipt(),
 787                proto::JoinChannelResponse {
 788                    done: messages.len() < MESSAGE_COUNT_PER_PAGE,
 789                    messages,
 790                },
 791            )
 792            .await?;
 793        Ok(())
 794    }
 795
 796    async fn leave_channel(
 797        self: Arc<Self>,
 798        request: TypedEnvelope<proto::LeaveChannel>,
 799    ) -> tide::Result<()> {
 800        let user_id = self
 801            .state
 802            .read()
 803            .await
 804            .user_id_for_connection(request.sender_id)?;
 805        let channel_id = ChannelId::from_proto(request.payload.channel_id);
 806        if !self
 807            .app_state
 808            .db
 809            .can_user_access_channel(user_id, channel_id)
 810            .await?
 811        {
 812            Err(anyhow!("access denied"))?;
 813        }
 814
 815        self.state
 816            .write()
 817            .await
 818            .leave_channel(request.sender_id, channel_id);
 819
 820        Ok(())
 821    }
 822
 823    async fn send_channel_message(
 824        self: Arc<Self>,
 825        request: TypedEnvelope<proto::SendChannelMessage>,
 826    ) -> tide::Result<()> {
 827        let receipt = request.receipt();
 828        let channel_id = ChannelId::from_proto(request.payload.channel_id);
 829        let user_id;
 830        let connection_ids;
 831        {
 832            let state = self.state.read().await;
 833            user_id = state.user_id_for_connection(request.sender_id)?;
 834            if let Some(channel) = state.channels.get(&channel_id) {
 835                connection_ids = channel.connection_ids();
 836            } else {
 837                return Ok(());
 838            }
 839        }
 840
 841        // Validate the message body.
 842        let body = request.payload.body.trim().to_string();
 843        if body.len() > MAX_MESSAGE_LEN {
 844            self.peer
 845                .respond_with_error(
 846                    receipt,
 847                    proto::Error {
 848                        message: "message is too long".to_string(),
 849                    },
 850                )
 851                .await?;
 852            return Ok(());
 853        }
 854        if body.is_empty() {
 855            self.peer
 856                .respond_with_error(
 857                    receipt,
 858                    proto::Error {
 859                        message: "message can't be blank".to_string(),
 860                    },
 861                )
 862                .await?;
 863            return Ok(());
 864        }
 865
 866        let timestamp = OffsetDateTime::now_utc();
 867        let nonce = if let Some(nonce) = request.payload.nonce {
 868            nonce
 869        } else {
 870            self.peer
 871                .respond_with_error(
 872                    receipt,
 873                    proto::Error {
 874                        message: "nonce can't be blank".to_string(),
 875                    },
 876                )
 877                .await?;
 878            return Ok(());
 879        };
 880
 881        let message_id = self
 882            .app_state
 883            .db
 884            .create_channel_message(channel_id, user_id, &body, timestamp, nonce.clone().into())
 885            .await?
 886            .to_proto();
 887        let message = proto::ChannelMessage {
 888            sender_id: user_id.to_proto(),
 889            id: message_id,
 890            body,
 891            timestamp: timestamp.unix_timestamp() as u64,
 892            nonce: Some(nonce),
 893        };
 894        broadcast(request.sender_id, connection_ids, |conn_id| {
 895            self.peer.send(
 896                conn_id,
 897                proto::ChannelMessageSent {
 898                    channel_id: channel_id.to_proto(),
 899                    message: Some(message.clone()),
 900                },
 901            )
 902        })
 903        .await?;
 904        self.peer
 905            .respond(
 906                receipt,
 907                proto::SendChannelMessageResponse {
 908                    message: Some(message),
 909                },
 910            )
 911            .await?;
 912        Ok(())
 913    }
 914
 915    async fn get_channel_messages(
 916        self: Arc<Self>,
 917        request: TypedEnvelope<proto::GetChannelMessages>,
 918    ) -> tide::Result<()> {
 919        let user_id = self
 920            .state
 921            .read()
 922            .await
 923            .user_id_for_connection(request.sender_id)?;
 924        let channel_id = ChannelId::from_proto(request.payload.channel_id);
 925        if !self
 926            .app_state
 927            .db
 928            .can_user_access_channel(user_id, channel_id)
 929            .await?
 930        {
 931            Err(anyhow!("access denied"))?;
 932        }
 933
 934        let messages = self
 935            .app_state
 936            .db
 937            .get_channel_messages(
 938                channel_id,
 939                MESSAGE_COUNT_PER_PAGE,
 940                Some(MessageId::from_proto(request.payload.before_message_id)),
 941            )
 942            .await?
 943            .into_iter()
 944            .map(|msg| proto::ChannelMessage {
 945                id: msg.id.to_proto(),
 946                body: msg.body,
 947                timestamp: msg.sent_at.unix_timestamp() as u64,
 948                sender_id: msg.sender_id.to_proto(),
 949                nonce: Some(msg.nonce.as_u128().into()),
 950            })
 951            .collect::<Vec<_>>();
 952        self.peer
 953            .respond(
 954                request.receipt(),
 955                proto::GetChannelMessagesResponse {
 956                    done: messages.len() < MESSAGE_COUNT_PER_PAGE,
 957                    messages,
 958                },
 959            )
 960            .await?;
 961        Ok(())
 962    }
 963
 964    async fn broadcast_in_worktree<T: proto::EnvelopedMessage>(
 965        &self,
 966        worktree_id: u64,
 967        message: &TypedEnvelope<T>,
 968    ) -> tide::Result<()> {
 969        let connection_ids = self
 970            .state
 971            .read()
 972            .await
 973            .read_worktree(worktree_id, message.sender_id)?
 974            .connection_ids();
 975
 976        broadcast(message.sender_id, connection_ids, |conn_id| {
 977            self.peer
 978                .forward_send(message.sender_id, conn_id, message.payload.clone())
 979        })
 980        .await?;
 981
 982        Ok(())
 983    }
 984}
 985
 986pub async fn broadcast<F, T>(
 987    sender_id: ConnectionId,
 988    receiver_ids: Vec<ConnectionId>,
 989    mut f: F,
 990) -> anyhow::Result<()>
 991where
 992    F: FnMut(ConnectionId) -> T,
 993    T: Future<Output = anyhow::Result<()>>,
 994{
 995    let futures = receiver_ids
 996        .into_iter()
 997        .filter(|id| *id != sender_id)
 998        .map(|id| f(id));
 999    futures::future::try_join_all(futures).await?;
1000    Ok(())
1001}
1002
1003impl ServerState {
1004    fn join_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
1005        if let Some(connection) = self.connections.get_mut(&connection_id) {
1006            connection.channels.insert(channel_id);
1007            self.channels
1008                .entry(channel_id)
1009                .or_default()
1010                .connection_ids
1011                .insert(connection_id);
1012        }
1013    }
1014
1015    fn leave_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
1016        if let Some(connection) = self.connections.get_mut(&connection_id) {
1017            connection.channels.remove(&channel_id);
1018            if let hash_map::Entry::Occupied(mut entry) = self.channels.entry(channel_id) {
1019                entry.get_mut().connection_ids.remove(&connection_id);
1020                if entry.get_mut().connection_ids.is_empty() {
1021                    entry.remove();
1022                }
1023            }
1024        }
1025    }
1026
1027    fn user_id_for_connection(&self, connection_id: ConnectionId) -> tide::Result<UserId> {
1028        Ok(self
1029            .connections
1030            .get(&connection_id)
1031            .ok_or_else(|| anyhow!("unknown connection"))?
1032            .user_id)
1033    }
1034
1035    fn user_connection_ids<'a>(
1036        &'a self,
1037        user_id: UserId,
1038    ) -> impl 'a + Iterator<Item = ConnectionId> {
1039        self.connections_by_user_id
1040            .get(&user_id)
1041            .into_iter()
1042            .flatten()
1043            .copied()
1044    }
1045
1046    fn is_online(&self, user_id: UserId) -> bool {
1047        self.connections_by_user_id.contains_key(&user_id)
1048    }
1049
1050    // Add the given connection as a guest of the given worktree
1051    fn join_worktree(
1052        &mut self,
1053        connection_id: ConnectionId,
1054        user_id: UserId,
1055        worktree_id: u64,
1056    ) -> tide::Result<(ReplicaId, &Worktree)> {
1057        let connection = self
1058            .connections
1059            .get_mut(&connection_id)
1060            .ok_or_else(|| anyhow!("no such connection"))?;
1061        let worktree = self
1062            .worktrees
1063            .get_mut(&worktree_id)
1064            .ok_or_else(|| anyhow!("no such worktree"))?;
1065        if !worktree.collaborator_user_ids.contains(&user_id) {
1066            Err(anyhow!("no such worktree"))?;
1067        }
1068
1069        let share = worktree.share_mut()?;
1070        connection.worktrees.insert(worktree_id);
1071
1072        let mut replica_id = 1;
1073        while share.active_replica_ids.contains(&replica_id) {
1074            replica_id += 1;
1075        }
1076        share.active_replica_ids.insert(replica_id);
1077        share.guest_connection_ids.insert(connection_id, replica_id);
1078        return Ok((replica_id, worktree));
1079    }
1080
1081    fn read_worktree(
1082        &self,
1083        worktree_id: u64,
1084        connection_id: ConnectionId,
1085    ) -> tide::Result<&Worktree> {
1086        let worktree = self
1087            .worktrees
1088            .get(&worktree_id)
1089            .ok_or_else(|| anyhow!("worktree not found"))?;
1090
1091        if worktree.host_connection_id == connection_id
1092            || worktree
1093                .share()?
1094                .guest_connection_ids
1095                .contains_key(&connection_id)
1096        {
1097            Ok(worktree)
1098        } else {
1099            Err(anyhow!(
1100                "{} is not a member of worktree {}",
1101                connection_id,
1102                worktree_id
1103            ))?
1104        }
1105    }
1106
1107    fn write_worktree(
1108        &mut self,
1109        worktree_id: u64,
1110        connection_id: ConnectionId,
1111    ) -> tide::Result<&mut Worktree> {
1112        let worktree = self
1113            .worktrees
1114            .get_mut(&worktree_id)
1115            .ok_or_else(|| anyhow!("worktree not found"))?;
1116
1117        if worktree.host_connection_id == connection_id
1118            || worktree.share.as_ref().map_or(false, |share| {
1119                share.guest_connection_ids.contains_key(&connection_id)
1120            })
1121        {
1122            Ok(worktree)
1123        } else {
1124            Err(anyhow!(
1125                "{} is not a member of worktree {}",
1126                connection_id,
1127                worktree_id
1128            ))?
1129        }
1130    }
1131
1132    fn add_worktree(&mut self, worktree: Worktree) -> u64 {
1133        let worktree_id = self.next_worktree_id;
1134        for collaborator_user_id in &worktree.collaborator_user_ids {
1135            self.visible_worktrees_by_user_id
1136                .entry(*collaborator_user_id)
1137                .or_default()
1138                .insert(worktree_id);
1139        }
1140        self.next_worktree_id += 1;
1141        self.worktrees.insert(worktree_id, worktree);
1142        worktree_id
1143    }
1144
1145    fn remove_worktree(&mut self, worktree_id: u64) {
1146        let worktree = self.worktrees.remove(&worktree_id).unwrap();
1147        if let Some(connection) = self.connections.get_mut(&worktree.host_connection_id) {
1148            connection.worktrees.remove(&worktree_id);
1149        }
1150        if let Some(share) = worktree.share {
1151            for connection_id in share.guest_connection_ids.keys() {
1152                if let Some(connection) = self.connections.get_mut(connection_id) {
1153                    connection.worktrees.remove(&worktree_id);
1154                }
1155            }
1156        }
1157        for collaborator_user_id in worktree.collaborator_user_ids {
1158            if let Some(visible_worktrees) = self
1159                .visible_worktrees_by_user_id
1160                .get_mut(&collaborator_user_id)
1161            {
1162                visible_worktrees.remove(&worktree_id);
1163            }
1164        }
1165    }
1166}
1167
1168impl Worktree {
1169    pub fn connection_ids(&self) -> Vec<ConnectionId> {
1170        if let Some(share) = &self.share {
1171            share
1172                .guest_connection_ids
1173                .keys()
1174                .copied()
1175                .chain(Some(self.host_connection_id))
1176                .collect()
1177        } else {
1178            vec![self.host_connection_id]
1179        }
1180    }
1181
1182    fn share(&self) -> tide::Result<&WorktreeShare> {
1183        Ok(self
1184            .share
1185            .as_ref()
1186            .ok_or_else(|| anyhow!("worktree is not shared"))?)
1187    }
1188
1189    fn share_mut(&mut self) -> tide::Result<&mut WorktreeShare> {
1190        Ok(self
1191            .share
1192            .as_mut()
1193            .ok_or_else(|| anyhow!("worktree is not shared"))?)
1194    }
1195}
1196
1197impl Channel {
1198    fn connection_ids(&self) -> Vec<ConnectionId> {
1199        self.connection_ids.iter().copied().collect()
1200    }
1201}
1202
1203pub fn add_routes(app: &mut tide::Server<Arc<AppState>>, rpc: &Arc<Peer>) {
1204    let server = Server::new(app.state().clone(), rpc.clone(), None);
1205    app.at("/rpc").with(auth::VerifyToken).get(move |request: Request<Arc<AppState>>| {
1206        let user_id = request.ext::<UserId>().copied();
1207        let server = server.clone();
1208        async move {
1209            const WEBSOCKET_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
1210
1211            let connection_upgrade = header_contains_ignore_case(&request, CONNECTION, "upgrade");
1212            let upgrade_to_websocket = header_contains_ignore_case(&request, UPGRADE, "websocket");
1213            let upgrade_requested = connection_upgrade && upgrade_to_websocket;
1214
1215            if !upgrade_requested {
1216                return Ok(Response::new(StatusCode::UpgradeRequired));
1217            }
1218
1219            let header = match request.header("Sec-Websocket-Key") {
1220                Some(h) => h.as_str(),
1221                None => return Err(anyhow!("expected sec-websocket-key"))?,
1222            };
1223
1224            let mut response = Response::new(StatusCode::SwitchingProtocols);
1225            response.insert_header(UPGRADE, "websocket");
1226            response.insert_header(CONNECTION, "Upgrade");
1227            let hash = Sha1::new().chain(header).chain(WEBSOCKET_GUID).finalize();
1228            response.insert_header("Sec-Websocket-Accept", base64::encode(&hash[..]));
1229            response.insert_header("Sec-Websocket-Version", "13");
1230
1231            let http_res: &mut tide::http::Response = response.as_mut();
1232            let upgrade_receiver = http_res.recv_upgrade().await;
1233            let addr = request.remote().unwrap_or("unknown").to_string();
1234            let user_id = user_id.ok_or_else(|| anyhow!("user_id is not present on request. ensure auth::VerifyToken middleware is present"))?;
1235            task::spawn(async move {
1236                if let Some(stream) = upgrade_receiver.await {
1237                    server.handle_connection(Connection::new(WebSocketStream::from_raw_socket(stream, Role::Server, None).await), addr, user_id).await;
1238                }
1239            });
1240
1241            Ok(response)
1242        }
1243    });
1244}
1245
1246fn header_contains_ignore_case<T>(
1247    request: &tide::Request<T>,
1248    header_name: HeaderName,
1249    value: &str,
1250) -> bool {
1251    request
1252        .header(header_name)
1253        .map(|h| {
1254            h.as_str()
1255                .split(',')
1256                .any(|s| s.trim().eq_ignore_ascii_case(value.trim()))
1257        })
1258        .unwrap_or(false)
1259}
1260
1261#[cfg(test)]
1262mod tests {
1263    use super::*;
1264    use crate::{
1265        auth,
1266        db::{tests::TestDb, UserId},
1267        github, AppState, Config,
1268    };
1269    use async_std::{sync::RwLockReadGuard, task};
1270    use gpui::TestAppContext;
1271    use parking_lot::Mutex;
1272    use postage::{mpsc, watch};
1273    use serde_json::json;
1274    use sqlx::types::time::OffsetDateTime;
1275    use std::{
1276        path::Path,
1277        sync::{
1278            atomic::{AtomicBool, Ordering::SeqCst},
1279            Arc,
1280        },
1281        time::Duration,
1282    };
1283    use zed::{
1284        channel::{Channel, ChannelDetails, ChannelList},
1285        editor::{Editor, Insert},
1286        fs::{FakeFs, Fs as _},
1287        language::LanguageRegistry,
1288        rpc::{self, Client, Credentials, EstablishConnectionError},
1289        settings,
1290        test::FakeHttpClient,
1291        user::UserStore,
1292        worktree::Worktree,
1293    };
1294    use zrpc::Peer;
1295
1296    #[gpui::test]
1297    async fn test_share_worktree(mut cx_a: TestAppContext, mut cx_b: TestAppContext) {
1298        let (window_b, _) = cx_b.add_window(|_| EmptyView);
1299        let settings = cx_b.read(settings::test).1;
1300        let lang_registry = Arc::new(LanguageRegistry::new());
1301
1302        // Connect to a server as 2 clients.
1303        let mut server = TestServer::start().await;
1304        let (client_a, _) = server.create_client(&mut cx_a, "user_a").await;
1305        let (client_b, _) = server.create_client(&mut cx_b, "user_b").await;
1306
1307        cx_a.foreground().forbid_parking();
1308
1309        // Share a local worktree as client A
1310        let fs = Arc::new(FakeFs::new());
1311        fs.insert_tree(
1312            "/a",
1313            json!({
1314                ".zed.toml": r#"collaborators = ["user_b"]"#,
1315                "a.txt": "a-contents",
1316                "b.txt": "b-contents",
1317            }),
1318        )
1319        .await;
1320        let worktree_a = Worktree::open_local(
1321            client_a.clone(),
1322            "/a".as_ref(),
1323            fs,
1324            lang_registry.clone(),
1325            &mut cx_a.to_async(),
1326        )
1327        .await
1328        .unwrap();
1329        worktree_a
1330            .read_with(&cx_a, |tree, _| tree.as_local().unwrap().scan_complete())
1331            .await;
1332        let worktree_id = worktree_a
1333            .update(&mut cx_a, |tree, cx| tree.as_local_mut().unwrap().share(cx))
1334            .await
1335            .unwrap();
1336
1337        // Join that worktree as client B, and see that a guest has joined as client A.
1338        let worktree_b = Worktree::open_remote(
1339            client_b.clone(),
1340            worktree_id,
1341            lang_registry.clone(),
1342            &mut cx_b.to_async(),
1343        )
1344        .await
1345        .unwrap();
1346        let replica_id_b = worktree_b.read_with(&cx_b, |tree, _| tree.replica_id());
1347        worktree_a
1348            .condition(&cx_a, |tree, _| {
1349                tree.peers()
1350                    .values()
1351                    .any(|replica_id| *replica_id == replica_id_b)
1352            })
1353            .await;
1354
1355        // Open the same file as client B and client A.
1356        let buffer_b = worktree_b
1357            .update(&mut cx_b, |worktree, cx| worktree.open_buffer("b.txt", cx))
1358            .await
1359            .unwrap();
1360        buffer_b.read_with(&cx_b, |buf, _| assert_eq!(buf.text(), "b-contents"));
1361        worktree_a.read_with(&cx_a, |tree, cx| assert!(tree.has_open_buffer("b.txt", cx)));
1362        let buffer_a = worktree_a
1363            .update(&mut cx_a, |tree, cx| tree.open_buffer("b.txt", cx))
1364            .await
1365            .unwrap();
1366
1367        // Create a selection set as client B and see that selection set as client A.
1368        let editor_b = cx_b.add_view(window_b, |cx| Editor::for_buffer(buffer_b, settings, cx));
1369        buffer_a
1370            .condition(&cx_a, |buffer, _| buffer.selection_sets().count() == 1)
1371            .await;
1372
1373        // Edit the buffer as client B and see that edit as client A.
1374        editor_b.update(&mut cx_b, |editor, cx| {
1375            editor.insert(&Insert("ok, ".into()), cx)
1376        });
1377        buffer_a
1378            .condition(&cx_a, |buffer, _| buffer.text() == "ok, b-contents")
1379            .await;
1380
1381        // Remove the selection set as client B, see those selections disappear as client A.
1382        cx_b.update(move |_| drop(editor_b));
1383        buffer_a
1384            .condition(&cx_a, |buffer, _| buffer.selection_sets().count() == 0)
1385            .await;
1386
1387        // Close the buffer as client A, see that the buffer is closed.
1388        cx_a.update(move |_| drop(buffer_a));
1389        worktree_a
1390            .condition(&cx_a, |tree, cx| !tree.has_open_buffer("b.txt", cx))
1391            .await;
1392
1393        // Dropping the worktree removes client B from client A's peers.
1394        cx_b.update(move |_| drop(worktree_b));
1395        worktree_a
1396            .condition(&cx_a, |tree, _| tree.peers().is_empty())
1397            .await;
1398    }
1399
1400    #[gpui::test]
1401    async fn test_propagate_saves_and_fs_changes_in_shared_worktree(
1402        mut cx_a: TestAppContext,
1403        mut cx_b: TestAppContext,
1404        mut cx_c: TestAppContext,
1405    ) {
1406        cx_a.foreground().forbid_parking();
1407        let lang_registry = Arc::new(LanguageRegistry::new());
1408
1409        // Connect to a server as 3 clients.
1410        let mut server = TestServer::start().await;
1411        let (client_a, _) = server.create_client(&mut cx_a, "user_a").await;
1412        let (client_b, _) = server.create_client(&mut cx_b, "user_b").await;
1413        let (client_c, _) = server.create_client(&mut cx_c, "user_c").await;
1414
1415        let fs = Arc::new(FakeFs::new());
1416
1417        // Share a worktree as client A.
1418        fs.insert_tree(
1419            "/a",
1420            json!({
1421                ".zed.toml": r#"collaborators = ["user_b", "user_c"]"#,
1422                "file1": "",
1423                "file2": ""
1424            }),
1425        )
1426        .await;
1427
1428        let worktree_a = Worktree::open_local(
1429            client_a.clone(),
1430            "/a".as_ref(),
1431            fs.clone(),
1432            lang_registry.clone(),
1433            &mut cx_a.to_async(),
1434        )
1435        .await
1436        .unwrap();
1437        worktree_a
1438            .read_with(&cx_a, |tree, _| tree.as_local().unwrap().scan_complete())
1439            .await;
1440        let worktree_id = worktree_a
1441            .update(&mut cx_a, |tree, cx| tree.as_local_mut().unwrap().share(cx))
1442            .await
1443            .unwrap();
1444
1445        // Join that worktree as clients B and C.
1446        let worktree_b = Worktree::open_remote(
1447            client_b.clone(),
1448            worktree_id,
1449            lang_registry.clone(),
1450            &mut cx_b.to_async(),
1451        )
1452        .await
1453        .unwrap();
1454        let worktree_c = Worktree::open_remote(
1455            client_c.clone(),
1456            worktree_id,
1457            lang_registry.clone(),
1458            &mut cx_c.to_async(),
1459        )
1460        .await
1461        .unwrap();
1462
1463        // Open and edit a buffer as both guests B and C.
1464        let buffer_b = worktree_b
1465            .update(&mut cx_b, |tree, cx| tree.open_buffer("file1", cx))
1466            .await
1467            .unwrap();
1468        let buffer_c = worktree_c
1469            .update(&mut cx_c, |tree, cx| tree.open_buffer("file1", cx))
1470            .await
1471            .unwrap();
1472        buffer_b.update(&mut cx_b, |buf, cx| buf.edit([0..0], "i-am-b, ", cx));
1473        buffer_c.update(&mut cx_c, |buf, cx| buf.edit([0..0], "i-am-c, ", cx));
1474
1475        // Open and edit that buffer as the host.
1476        let buffer_a = worktree_a
1477            .update(&mut cx_a, |tree, cx| tree.open_buffer("file1", cx))
1478            .await
1479            .unwrap();
1480
1481        buffer_a
1482            .condition(&mut cx_a, |buf, _| buf.text() == "i-am-c, i-am-b, ")
1483            .await;
1484        buffer_a.update(&mut cx_a, |buf, cx| {
1485            buf.edit([buf.len()..buf.len()], "i-am-a", cx)
1486        });
1487
1488        // Wait for edits to propagate
1489        buffer_a
1490            .condition(&mut cx_a, |buf, _| buf.text() == "i-am-c, i-am-b, i-am-a")
1491            .await;
1492        buffer_b
1493            .condition(&mut cx_b, |buf, _| buf.text() == "i-am-c, i-am-b, i-am-a")
1494            .await;
1495        buffer_c
1496            .condition(&mut cx_c, |buf, _| buf.text() == "i-am-c, i-am-b, i-am-a")
1497            .await;
1498
1499        // Edit the buffer as the host and concurrently save as guest B.
1500        let save_b = buffer_b.update(&mut cx_b, |buf, cx| buf.save(cx).unwrap());
1501        buffer_a.update(&mut cx_a, |buf, cx| buf.edit([0..0], "hi-a, ", cx));
1502        save_b.await.unwrap();
1503        assert_eq!(
1504            fs.load("/a/file1".as_ref()).await.unwrap(),
1505            "hi-a, i-am-c, i-am-b, i-am-a"
1506        );
1507        buffer_a.read_with(&cx_a, |buf, _| assert!(!buf.is_dirty()));
1508        buffer_b.read_with(&cx_b, |buf, _| assert!(!buf.is_dirty()));
1509        buffer_c.condition(&cx_c, |buf, _| !buf.is_dirty()).await;
1510
1511        // Make changes on host's file system, see those changes on the guests.
1512        fs.rename("/a/file2".as_ref(), "/a/file3".as_ref())
1513            .await
1514            .unwrap();
1515        fs.insert_file(Path::new("/a/file4"), "4".into())
1516            .await
1517            .unwrap();
1518
1519        worktree_b
1520            .condition(&cx_b, |tree, _| tree.file_count() == 4)
1521            .await;
1522        worktree_c
1523            .condition(&cx_c, |tree, _| tree.file_count() == 4)
1524            .await;
1525        worktree_b.read_with(&cx_b, |tree, _| {
1526            assert_eq!(
1527                tree.paths()
1528                    .map(|p| p.to_string_lossy())
1529                    .collect::<Vec<_>>(),
1530                &[".zed.toml", "file1", "file3", "file4"]
1531            )
1532        });
1533        worktree_c.read_with(&cx_c, |tree, _| {
1534            assert_eq!(
1535                tree.paths()
1536                    .map(|p| p.to_string_lossy())
1537                    .collect::<Vec<_>>(),
1538                &[".zed.toml", "file1", "file3", "file4"]
1539            )
1540        });
1541    }
1542
1543    #[gpui::test]
1544    async fn test_buffer_conflict_after_save(mut cx_a: TestAppContext, mut cx_b: TestAppContext) {
1545        cx_a.foreground().forbid_parking();
1546        let lang_registry = Arc::new(LanguageRegistry::new());
1547
1548        // Connect to a server as 2 clients.
1549        let mut server = TestServer::start().await;
1550        let (client_a, _) = server.create_client(&mut cx_a, "user_a").await;
1551        let (client_b, _) = server.create_client(&mut cx_b, "user_b").await;
1552
1553        // Share a local worktree as client A
1554        let fs = Arc::new(FakeFs::new());
1555        fs.insert_tree(
1556            "/dir",
1557            json!({
1558                ".zed.toml": r#"collaborators = ["user_b", "user_c"]"#,
1559                "a.txt": "a-contents",
1560            }),
1561        )
1562        .await;
1563
1564        let worktree_a = Worktree::open_local(
1565            client_a.clone(),
1566            "/dir".as_ref(),
1567            fs,
1568            lang_registry.clone(),
1569            &mut cx_a.to_async(),
1570        )
1571        .await
1572        .unwrap();
1573        worktree_a
1574            .read_with(&cx_a, |tree, _| tree.as_local().unwrap().scan_complete())
1575            .await;
1576        let worktree_id = worktree_a
1577            .update(&mut cx_a, |tree, cx| tree.as_local_mut().unwrap().share(cx))
1578            .await
1579            .unwrap();
1580
1581        // Join that worktree as client B, and see that a guest has joined as client A.
1582        let worktree_b = Worktree::open_remote(
1583            client_b.clone(),
1584            worktree_id,
1585            lang_registry.clone(),
1586            &mut cx_b.to_async(),
1587        )
1588        .await
1589        .unwrap();
1590
1591        let buffer_b = worktree_b
1592            .update(&mut cx_b, |worktree, cx| worktree.open_buffer("a.txt", cx))
1593            .await
1594            .unwrap();
1595        let mtime = buffer_b.read_with(&cx_b, |buf, _| buf.file().unwrap().mtime);
1596
1597        buffer_b.update(&mut cx_b, |buf, cx| buf.edit([0..0], "world ", cx));
1598        buffer_b.read_with(&cx_b, |buf, _| {
1599            assert!(buf.is_dirty());
1600            assert!(!buf.has_conflict());
1601        });
1602
1603        buffer_b
1604            .update(&mut cx_b, |buf, cx| buf.save(cx))
1605            .unwrap()
1606            .await
1607            .unwrap();
1608        worktree_b
1609            .condition(&cx_b, |_, cx| {
1610                buffer_b.read(cx).file().unwrap().mtime != mtime
1611            })
1612            .await;
1613        buffer_b.read_with(&cx_b, |buf, _| {
1614            assert!(!buf.is_dirty());
1615            assert!(!buf.has_conflict());
1616        });
1617
1618        buffer_b.update(&mut cx_b, |buf, cx| buf.edit([0..0], "hello ", cx));
1619        buffer_b.read_with(&cx_b, |buf, _| {
1620            assert!(buf.is_dirty());
1621            assert!(!buf.has_conflict());
1622        });
1623    }
1624
1625    #[gpui::test]
1626    async fn test_editing_while_guest_opens_buffer(
1627        mut cx_a: TestAppContext,
1628        mut cx_b: TestAppContext,
1629    ) {
1630        cx_a.foreground().forbid_parking();
1631        let lang_registry = Arc::new(LanguageRegistry::new());
1632
1633        // Connect to a server as 2 clients.
1634        let mut server = TestServer::start().await;
1635        let (client_a, _) = server.create_client(&mut cx_a, "user_a").await;
1636        let (client_b, _) = server.create_client(&mut cx_b, "user_b").await;
1637
1638        // Share a local worktree as client A
1639        let fs = Arc::new(FakeFs::new());
1640        fs.insert_tree(
1641            "/dir",
1642            json!({
1643                ".zed.toml": r#"collaborators = ["user_b"]"#,
1644                "a.txt": "a-contents",
1645            }),
1646        )
1647        .await;
1648        let worktree_a = Worktree::open_local(
1649            client_a.clone(),
1650            "/dir".as_ref(),
1651            fs,
1652            lang_registry.clone(),
1653            &mut cx_a.to_async(),
1654        )
1655        .await
1656        .unwrap();
1657        worktree_a
1658            .read_with(&cx_a, |tree, _| tree.as_local().unwrap().scan_complete())
1659            .await;
1660        let worktree_id = worktree_a
1661            .update(&mut cx_a, |tree, cx| tree.as_local_mut().unwrap().share(cx))
1662            .await
1663            .unwrap();
1664
1665        // Join that worktree as client B, and see that a guest has joined as client A.
1666        let worktree_b = Worktree::open_remote(
1667            client_b.clone(),
1668            worktree_id,
1669            lang_registry.clone(),
1670            &mut cx_b.to_async(),
1671        )
1672        .await
1673        .unwrap();
1674
1675        let buffer_a = worktree_a
1676            .update(&mut cx_a, |tree, cx| tree.open_buffer("a.txt", cx))
1677            .await
1678            .unwrap();
1679        let buffer_b = cx_b
1680            .background()
1681            .spawn(worktree_b.update(&mut cx_b, |worktree, cx| worktree.open_buffer("a.txt", cx)));
1682
1683        task::yield_now().await;
1684        buffer_a.update(&mut cx_a, |buf, cx| buf.edit([0..0], "z", cx));
1685
1686        let text = buffer_a.read_with(&cx_a, |buf, _| buf.text());
1687        let buffer_b = buffer_b.await.unwrap();
1688        buffer_b.condition(&cx_b, |buf, _| buf.text() == text).await;
1689    }
1690
1691    #[gpui::test]
    // Checks that when a guest's connection is dropped, the host's worktree stops
    // reporting that guest as a peer (peer count goes from 1 back to 0).
1692    async fn test_peer_disconnection(mut cx_a: TestAppContext, cx_b: TestAppContext) {
        // NOTE(review): `forbid_parking` presumably makes the test executor fail instead of
        // blocking when it runs out of work -- confirm against gpui.
1693        cx_a.foreground().forbid_parking();
1694        let lang_registry = Arc::new(LanguageRegistry::new());
1695
1696        // Connect to a server as 2 clients.
        // NOTE(review): client B is also created through `cx_a` below, while every later
        // B-side operation uses `cx_b` -- confirm this is intentional and not a copy-paste slip.
1697        let mut server = TestServer::start().await;
1698        let (client_a, _) = server.create_client(&mut cx_a, "user_a").await;
1699        let (client_b, _) = server.create_client(&mut cx_a, "user_b").await;
1700
1701        // Share a local worktree as client A
        // The fake filesystem's `.zed.toml` names `user_b` as a collaborator.
1702        let fs = Arc::new(FakeFs::new());
1703        fs.insert_tree(
1704            "/a",
1705            json!({
1706                ".zed.toml": r#"collaborators = ["user_b"]"#,
1707                "a.txt": "a-contents",
1708                "b.txt": "b-contents",
1709            }),
1710        )
1711        .await;
1712        let worktree_a = Worktree::open_local(
1713            client_a.clone(),
1714            "/a".as_ref(),
1715            fs,
1716            lang_registry.clone(),
1717            &mut cx_a.to_async(),
1718        )
1719        .await
1720        .unwrap();
        // Await the local worktree's `scan_complete()` future before sharing it.
1721        worktree_a
1722            .read_with(&cx_a, |tree, _| tree.as_local().unwrap().scan_complete())
1723            .await;
        // `share` yields the id that the remote side passes to `open_remote`.
1724        let worktree_id = worktree_a
1725            .update(&mut cx_a, |tree, cx| tree.as_local_mut().unwrap().share(cx))
1726            .await
1727            .unwrap();
1728
1729        // Join that worktree as client B, then wait until client A's worktree reports one peer.
        // The remote worktree handle is bound to `_worktree_b` so it stays alive for the
        // rest of the test even though it is never read.
1730        let _worktree_b = Worktree::open_remote(
1731            client_b.clone(),
1732            worktree_id,
1733            lang_registry.clone(),
1734            &mut cx_b.to_async(),
1735        )
1736        .await
1737        .unwrap();
1738        worktree_a
1739            .condition(&cx_a, |tree, _| tree.peers().len() == 1)
1740            .await;
1741
1742        // Drop client B's connection and ensure client A observes client B leaving the worktree.
1743        client_b.disconnect(&cx_b.to_async()).await.unwrap();
1744        worktree_a
1745            .condition(&cx_a, |tree, _| tree.peers().len() == 0)
1746            .await;
1747    }
1748
1749    #[gpui::test]
    // End-to-end chat test with two clients: each lists the available channels, loads the
    // stored history, exchanges messages, and finally drops its channel handle so the
    // server's per-channel connection bookkeeping can be checked.
1750    async fn test_basic_chat(mut cx_a: TestAppContext, mut cx_b: TestAppContext) {
1751        cx_a.foreground().forbid_parking();
1752
1753        // Connect to a server as 2 clients.
1754        let mut server = TestServer::start().await;
1755        let (client_a, user_store_a) = server.create_client(&mut cx_a, "user_a").await;
1756        let (client_b, user_store_b) = server.create_client(&mut cx_b, "user_b").await;
1757
1758        // Create an org that includes these 2 users.
1759        let db = &server.app_state.db;
1760        let org_id = db.create_org("Test Org", "test-org").await.unwrap();
1761        db.add_org_member(org_id, current_user_id(&user_store_a), false)
1762            .await
1763            .unwrap();
1764        db.add_org_member(org_id, current_user_id(&user_store_b), false)
1765            .await
1766            .unwrap();
1767
1768        // Create a channel that includes all the users.
1769        let channel_id = db.create_org_channel(org_id, "test-channel").await.unwrap();
1770        db.add_channel_member(channel_id, current_user_id(&user_store_a), false)
1771            .await
1772            .unwrap();
1773        db.add_channel_member(channel_id, current_user_id(&user_store_b), false)
1774            .await
1775            .unwrap();
        // Seed the channel with one message from user B, written directly to the database
        // before either client opens the channel.
        // NOTE(review): the trailing `1` is presumably a message nonce -- confirm against
        // `create_channel_message`.
1776        db.create_channel_message(
1777            channel_id,
1778            current_user_id(&user_store_b),
1779            "hello A, it's B.",
1780            OffsetDateTime::now_utc(),
1781            1,
1782        )
1783        .await
1784        .unwrap();
1785
        // Client A: wait for the channel list to load, then expect exactly the test channel.
1786        let channels_a = cx_a.add_model(|cx| ChannelList::new(user_store_a, client_a, cx));
1787        channels_a
1788            .condition(&mut cx_a, |list, _| list.available_channels().is_some())
1789            .await;
1790        channels_a.read_with(&cx_a, |list, _| {
1791            assert_eq!(
1792                list.available_channels().unwrap(),
1793                &[ChannelDetails {
1794                    id: channel_id.to_proto(),
1795                    name: "test-channel".to_string()
1796                }]
1797            )
1798        });
        // A freshly obtained channel has no messages yet; the seeded message shows up
        // later. Tuples are (sender login, body, is_pending) as built by `channel_messages`.
1799        let channel_a = channels_a.update(&mut cx_a, |this, cx| {
1800            this.get_channel(channel_id.to_proto(), cx).unwrap()
1801        });
1802        channel_a.read_with(&cx_a, |channel, _| assert!(channel.messages().is_empty()));
1803        channel_a
1804            .condition(&cx_a, |channel, _| {
1805                channel_messages(channel)
1806                    == [("user_b".to_string(), "hello A, it's B.".to_string(), false)]
1807            })
1808            .await;
1809
        // Client B: same checks as for client A.
1810        let channels_b = cx_b.add_model(|cx| ChannelList::new(user_store_b, client_b, cx));
1811        channels_b
1812            .condition(&mut cx_b, |list, _| list.available_channels().is_some())
1813            .await;
1814        channels_b.read_with(&cx_b, |list, _| {
1815            assert_eq!(
1816                list.available_channels().unwrap(),
1817                &[ChannelDetails {
1818                    id: channel_id.to_proto(),
1819                    name: "test-channel".to_string()
1820                }]
1821            )
1822        });
1823
1824        let channel_b = channels_b.update(&mut cx_b, |this, cx| {
1825            this.get_channel(channel_id.to_proto(), cx).unwrap()
1826        });
1827        channel_b.read_with(&cx_b, |channel, _| assert!(channel.messages().is_empty()));
1828        channel_b
1829            .condition(&cx_b, |channel, _| {
1830                channel_messages(channel)
1831                    == [("user_b".to_string(), "hello A, it's B.".to_string(), false)]
1832            })
1833            .await;
1834
        // Send two messages from A. Inside the same update, both already appear in A's
        // local message list flagged as pending; only the second send's task is awaited
        // (the first is detached).
1835        channel_a
1836            .update(&mut cx_a, |channel, cx| {
1837                channel
1838                    .send_message("oh, hi B.".to_string(), cx)
1839                    .unwrap()
1840                    .detach();
1841                let task = channel.send_message("sup".to_string(), cx).unwrap();
1842                assert_eq!(
1843                    channel_messages(channel),
1844                    &[
1845                        ("user_b".to_string(), "hello A, it's B.".to_string(), false),
1846                        ("user_a".to_string(), "oh, hi B.".to_string(), true),
1847                        ("user_a".to_string(), "sup".to_string(), true)
1848                    ]
1849                );
1850                task
1851            })
1852            .await
1853            .unwrap();
1854
        // B eventually sees both of A's messages, none of them pending.
1855        channel_b
1856            .condition(&cx_b, |channel, _| {
1857                channel_messages(channel)
1858                    == [
1859                        ("user_b".to_string(), "hello A, it's B.".to_string(), false),
1860                        ("user_a".to_string(), "oh, hi B.".to_string(), false),
1861                        ("user_a".to_string(), "sup".to_string(), false),
1862                    ]
1863            })
1864            .await;
1865
        // With both channel handles alive, the server tracks two connections on the channel.
1866        assert_eq!(
1867            server.state().await.channels[&channel_id]
1868                .connection_ids
1869                .len(),
1870            2
1871        );
        // Dropping B's handle should leave one connection on the channel...
1872        cx_b.update(|_| drop(channel_b));
1873        server
1874            .condition(|state| state.channels[&channel_id].connection_ids.len() == 1)
1875            .await;
1876
        // ...and dropping A's handle should remove the channel entry from server state.
1877        cx_a.update(|_| drop(channel_a));
1878        server
1879            .condition(|state| !state.channels.contains_key(&channel_id))
1880            .await;
1881    }
1882
1883    #[gpui::test]
    // Checks chat message validation with a single client: an over-long body and a blank
    // body are both rejected, and surrounding whitespace is stripped from the stored body.
1884    async fn test_chat_message_validation(mut cx_a: TestAppContext) {
1885        cx_a.foreground().forbid_parking();
1886
1887        let mut server = TestServer::start().await;
1888        let (client_a, user_store_a) = server.create_client(&mut cx_a, "user_a").await;
1889
        // Create an org and a channel, and make user A a member of both.
1890        let db = &server.app_state.db;
1891        let org_id = db.create_org("Test Org", "test-org").await.unwrap();
1892        let channel_id = db.create_org_channel(org_id, "test-channel").await.unwrap();
1893        db.add_org_member(org_id, current_user_id(&user_store_a), false)
1894            .await
1895            .unwrap();
1896        db.add_channel_member(channel_id, current_user_id(&user_store_a), false)
1897            .await
1898            .unwrap();
1899
1900        let channels_a = cx_a.add_model(|cx| ChannelList::new(user_store_a, client_a, cx));
1901        channels_a
1902            .condition(&mut cx_a, |list, _| list.available_channels().is_some())
1903            .await;
1904        let channel_a = channels_a.update(&mut cx_a, |this, cx| {
1905            this.get_channel(channel_id.to_proto(), cx).unwrap()
1906        });
1907
1908        // Messages aren't allowed to be too long.
        // The body is 14 bytes repeated 1024 times (14336 bytes). `send_message` itself
        // succeeds here; it is the awaited send task that must resolve to an error.
1909        channel_a
1910            .update(&mut cx_a, |channel, cx| {
1911                let long_body = "this is long.\n".repeat(1024);
1912                channel.send_message(long_body, cx).unwrap()
1913            })
1914            .await
1915            .unwrap_err();
1916
1917        // Messages aren't allowed to be blank.
        // Unlike the over-long case, the error is returned synchronously by `send_message`.
1918        channel_a.update(&mut cx_a, |channel, cx| {
1919            channel.send_message(String::new(), cx).unwrap_err()
1920        });
1921
1922        // Leading and trailing whitespace are trimmed.
1923        channel_a
1924            .update(&mut cx_a, |channel, cx| {
1925                channel
1926                    .send_message("\n surrounded by whitespace  \n".to_string(), cx)
1927                    .unwrap()
1928            })
1929            .await
1930            .unwrap();
        // Only the trimmed message is present in the database; the two rejected sends
        // above left nothing behind.
1931        assert_eq!(
1932            db.get_channel_messages(channel_id, 10, None)
1933                .await
1934                .unwrap()
1935                .iter()
1936                .map(|m| &m.body)
1937                .collect::<Vec<_>>(),
1938            &["surrounded by whitespace"]
1939        );
1940    }
1941
1942    #[gpui::test]
    // Checks chat behaviour across a dropped connection: client B keeps its cached channel
    // data while offline, a message B sends while offline fails at first but is visible to
    // both clients after B reconnects, and B catches up on messages A sent in the meantime.
1943    async fn test_chat_reconnection(mut cx_a: TestAppContext, mut cx_b: TestAppContext) {
1944        cx_a.foreground().forbid_parking();
        // Fake HTTP client that answers every request with a 404 response.
1945        let http = FakeHttpClient::new(|_| async move { Ok(surf::http::Response::new(404)) });
1946
1947        // Connect to a server as 2 clients.
1948        let mut server = TestServer::start().await;
1949        let (client_a, user_store_a) = server.create_client(&mut cx_a, "user_a").await;
1950        let (client_b, user_store_b) = server.create_client(&mut cx_b, "user_b").await;
        // Stream of client B's connection status, used below to detect the failed reconnect.
1951        let mut status_b = client_b.status();
1952
1953        // Create an org that includes these 2 users.
1954        let db = &server.app_state.db;
1955        let org_id = db.create_org("Test Org", "test-org").await.unwrap();
1956        db.add_org_member(org_id, current_user_id(&user_store_a), false)
1957            .await
1958            .unwrap();
1959        db.add_org_member(org_id, current_user_id(&user_store_b), false)
1960            .await
1961            .unwrap();
1962
1963        // Create a channel that includes all the users.
1964        let channel_id = db.create_org_channel(org_id, "test-channel").await.unwrap();
1965        db.add_channel_member(channel_id, current_user_id(&user_store_a), false)
1966            .await
1967            .unwrap();
1968        db.add_channel_member(channel_id, current_user_id(&user_store_b), false)
1969            .await
1970            .unwrap();
        // Seed the channel with one message from user B, written directly to the database.
        // NOTE(review): the trailing `2` is presumably a message nonce -- confirm against
        // `create_channel_message`.
1971        db.create_channel_message(
1972            channel_id,
1973            current_user_id(&user_store_b),
1974            "hello A, it's B.",
1975            OffsetDateTime::now_utc(),
1976            2,
1977        )
1978        .await
1979        .unwrap();
1980
        // NOTE(review): this shadows the `user_store_a` returned by `create_client` (which
        // was only used for user-id lookups above) with a new store built on the fake
        // HTTP client -- confirm the replacement is intentional.
1981        let user_store_a =
1982            UserStore::new(client_a.clone(), http.clone(), cx_a.background().as_ref());
1983        let channels_a = cx_a.add_model(|cx| ChannelList::new(user_store_a, client_a, cx));
1984        channels_a
1985            .condition(&mut cx_a, |list, _| list.available_channels().is_some())
1986            .await;
1987
1988        channels_a.read_with(&cx_a, |list, _| {
1989            assert_eq!(
1990                list.available_channels().unwrap(),
1991                &[ChannelDetails {
1992                    id: channel_id.to_proto(),
1993                    name: "test-channel".to_string()
1994                }]
1995            )
1996        });
        // A freshly obtained channel has no messages; the seeded one arrives afterwards.
        // Tuples are (sender login, body, is_pending) as built by `channel_messages`.
1997        let channel_a = channels_a.update(&mut cx_a, |this, cx| {
1998            this.get_channel(channel_id.to_proto(), cx).unwrap()
1999        });
2000        channel_a.read_with(&cx_a, |channel, _| assert!(channel.messages().is_empty()));
2001        channel_a
2002            .condition(&cx_a, |channel, _| {
2003                channel_messages(channel)
2004                    == [("user_b".to_string(), "hello A, it's B.".to_string(), false)]
2005            })
2006            .await;
2007
        // Client B: same setup and checks as for client A. The store is cloned because
        // `user_store_b` is needed again when disconnecting B below.
2008        let channels_b = cx_b.add_model(|cx| ChannelList::new(user_store_b.clone(), client_b, cx));
2009        channels_b
2010            .condition(&mut cx_b, |list, _| list.available_channels().is_some())
2011            .await;
2012        channels_b.read_with(&cx_b, |list, _| {
2013            assert_eq!(
2014                list.available_channels().unwrap(),
2015                &[ChannelDetails {
2016                    id: channel_id.to_proto(),
2017                    name: "test-channel".to_string()
2018                }]
2019            )
2020        });
2021
2022        let channel_b = channels_b.update(&mut cx_b, |this, cx| {
2023            this.get_channel(channel_id.to_proto(), cx).unwrap()
2024        });
2025        channel_b.read_with(&cx_b, |channel, _| assert!(channel.messages().is_empty()));
2026        channel_b
2027            .condition(&cx_b, |channel, _| {
2028                channel_messages(channel)
2029                    == [("user_b".to_string(), "hello A, it's B.".to_string(), false)]
2030            })
2031            .await;
2032
2033        // Disconnect client B, ensuring we can still access its cached channel data.
        // New connections are refused first, so B's reconnect attempt fails; the loop
        // then consumes status updates until a `ReconnectionError` is reported.
2034        server.forbid_connections();
2035        server.disconnect_client(current_user_id(&user_store_b));
2036        while !matches!(
2037            status_b.recv().await,
2038            Some(rpc::Status::ReconnectionError { .. })
2039        ) {}
2040
2041        channels_b.read_with(&cx_b, |channels, _| {
2042            assert_eq!(
2043                channels.available_channels().unwrap(),
2044                [ChannelDetails {
2045                    id: channel_id.to_proto(),
2046                    name: "test-channel".to_string()
2047                }]
2048            )
2049        });
2050        channel_b.read_with(&cx_b, |channel, _| {
2051            assert_eq!(
2052                channel_messages(channel),
2053                [("user_b".to_string(), "hello A, it's B.".to_string(), false)]
2054            )
2055        });
2056
2057        // Send a message from client B while it is disconnected.
        // The message is listed locally as pending, and the send task resolves to an error.
2058        channel_b
2059            .update(&mut cx_b, |channel, cx| {
2060                let task = channel
2061                    .send_message("can you see this?".to_string(), cx)
2062                    .unwrap();
2063                assert_eq!(
2064                    channel_messages(channel),
2065                    &[
2066                        ("user_b".to_string(), "hello A, it's B.".to_string(), false),
2067                        ("user_b".to_string(), "can you see this?".to_string(), true)
2068                    ]
2069                );
2070                task
2071            })
2072            .await
2073            .unwrap_err();
2074
2075        // Send a message from client A while B is disconnected.
        // Two messages are sent; the first send is detached and only the second is awaited.
2076        channel_a
2077            .update(&mut cx_a, |channel, cx| {
2078                channel
2079                    .send_message("oh, hi B.".to_string(), cx)
2080                    .unwrap()
2081                    .detach();
2082                let task = channel.send_message("sup".to_string(), cx).unwrap();
2083                assert_eq!(
2084                    channel_messages(channel),
2085                    &[
2086                        ("user_b".to_string(), "hello A, it's B.".to_string(), false),
2087                        ("user_a".to_string(), "oh, hi B.".to_string(), true),
2088                        ("user_a".to_string(), "sup".to_string(), true)
2089                    ]
2090                );
2091                task
2092            })
2093            .await
2094            .unwrap();
2095
2096        // Give client B a chance to reconnect.
        // NOTE(review): advancing B's clock by 10 seconds is presumably enough to pass the
        // client's reconnect delay -- confirm against the client's retry logic.
2097        server.allow_connections();
2098        cx_b.foreground().advance_clock(Duration::from_secs(10));
2099
2100        // Verify that B sees the new messages upon reconnection, as well as the message client B
2101        // sent while offline.
        // The expected order places A's two messages before B's offline message, and no
        // entry is pending any more.
2102        channel_b
2103            .condition(&cx_b, |channel, _| {
2104                channel_messages(channel)
2105                    == [
2106                        ("user_b".to_string(), "hello A, it's B.".to_string(), false),
2107                        ("user_a".to_string(), "oh, hi B.".to_string(), false),
2108                        ("user_a".to_string(), "sup".to_string(), false),
2109                        ("user_b".to_string(), "can you see this?".to_string(), false),
2110                    ]
2111            })
2112            .await;
2113
2114        // Ensure client A and B can communicate normally after reconnection.
2115        channel_a
2116            .update(&mut cx_a, |channel, cx| {
2117                channel.send_message("you online?".to_string(), cx).unwrap()
2118            })
2119            .await
2120            .unwrap();
2121        channel_b
2122            .condition(&cx_b, |channel, _| {
2123                channel_messages(channel)
2124                    == [
2125                        ("user_b".to_string(), "hello A, it's B.".to_string(), false),
2126                        ("user_a".to_string(), "oh, hi B.".to_string(), false),
2127                        ("user_a".to_string(), "sup".to_string(), false),
2128                        ("user_b".to_string(), "can you see this?".to_string(), false),
2129                        ("user_a".to_string(), "you online?".to_string(), false),
2130                    ]
2131            })
2132            .await;
2133
        // And in the other direction: a message from B reaches A, whose list now also
        // contains the message B originally sent while offline.
2134        channel_b
2135            .update(&mut cx_b, |channel, cx| {
2136                channel.send_message("yep".to_string(), cx).unwrap()
2137            })
2138            .await
2139            .unwrap();
2140        channel_a
2141            .condition(&cx_a, |channel, _| {
2142                channel_messages(channel)
2143                    == [
2144                        ("user_b".to_string(), "hello A, it's B.".to_string(), false),
2145                        ("user_a".to_string(), "oh, hi B.".to_string(), false),
2146                        ("user_a".to_string(), "sup".to_string(), false),
2147                        ("user_b".to_string(), "can you see this?".to_string(), false),
2148                        ("user_a".to_string(), "you online?".to_string(), false),
2149                        ("user_b".to_string(), "yep".to_string(), false),
2150                    ]
2151            })
2152            .await;
2153    }
2154
2155    struct TestServer {
        // Test harness around an in-process `Server`: it owns the server, the state it was
        // built from, and the switches the tests use to drop or refuse client connections.

        // Peer handed to `Server::new`; it is reset when the harness is dropped.
2156        peer: Arc<Peer>,
        // Application state built by `build_app_state`; tests reach the database through it.
2157        app_state: Arc<AppState>,
        // The server under test.
2158        server: Arc<Server>,
        // Receiving half of the channel whose sender is given to the server; `condition`
        // waits on it between checks of the server state.
2159        notifications: mpsc::Receiver<()>,
        // Per-user handle for killing that user's in-memory connection (see
        // `disconnect_client`).
2160        connection_killers: Arc<Mutex<HashMap<UserId, watch::Sender<Option<()>>>>>,
        // While `true`, attempts to establish a connection fail with an error.
2161        forbid_connections: Arc<AtomicBool>,
        // Never read; held so the test database lives as long as the harness.
        // NOTE(review): presumably `TestDb` cleans up on drop -- confirm.
2162        _test_db: TestDb,
2163    }
2164
2165    impl TestServer {
        /// Creates a test database, the app state on top of it, a peer, and a `Server`
        /// wired to a notification channel created with `mpsc::channel(128)`.
2166        async fn start() -> Self {
2167            let test_db = TestDb::new();
2168            let app_state = Self::build_app_state(&test_db).await;
2169            let peer = Peer::new();
2170            let notifications = mpsc::channel(128);
2171            let server = Server::new(app_state.clone(), peer.clone(), Some(notifications.0));
2172            Self {
2173                peer,
2174                app_state,
2175                server,
2176                notifications: notifications.1,
2177                connection_killers: Default::default(),
2178                forbid_connections: Default::default(),
2179                _test_db: test_db,
2180            }
2181        }
2182
        /// Creates a database user called `name` and a client connected to this server as
        /// that user, and returns the client together with a `UserStore` for it.
        ///
        /// The client's authentication and connection steps are overridden so that no
        /// real network is involved: authentication always yields the new user's id with
        /// the fixed token `"the-token"`, and connecting uses an in-memory connection
        /// whose server half is handed to `Server::handle_connection`.
        ///
        /// Returns only once the user store reports a current user. Panics if any step
        /// fails.
2183        async fn create_client(
2184            &mut self,
2185            cx: &mut TestAppContext,
2186            name: &str,
2187        ) -> (Arc<Client>, Arc<UserStore>) {
2188            let user_id = self.app_state.db.create_user(name, false).await.unwrap();
2189            let client_name = name.to_string();
2190            let mut client = Client::new();
2191            let server = self.server.clone();
2192            let connection_killers = self.connection_killers.clone();
2193            let forbid_connections = self.forbid_connections.clone();
            // `Arc::get_mut` needs the `Arc` to be unshared; the `unwrap` relies on the
            // client not having been cloned yet.
2194            Arc::get_mut(&mut client)
2195                .unwrap()
2196                .override_authenticate(move |cx| {
2197                    cx.spawn(|_| async move {
2198                        let access_token = "the-token".to_string();
2199                        Ok(Credentials {
2200                            user_id: user_id.0 as u64,
2201                            access_token,
2202                        })
2203                    })
2204                })
2205                .override_establish_connection(move |credentials, cx| {
                    // The credentials must be the ones produced by the override above.
2206                    assert_eq!(credentials.user_id, user_id.0 as u64);
2207                    assert_eq!(credentials.access_token, "the-token");
2208
                    // Clone the captured handles so each connection attempt gets its own
                    // copies to move into the spawned future.
2209                    let server = server.clone();
2210                    let connection_killers = connection_killers.clone();
2211                    let forbid_connections = forbid_connections.clone();
2212                    let client_name = client_name.clone();
2213                    cx.spawn(move |cx| async move {
2214                        if forbid_connections.load(SeqCst) {
2215                            Err(EstablishConnectionError::other(anyhow!(
2216                                "server is forbidding connections"
2217                            )))
2218                        } else {
                            // Record the kill handle under the user id (replacing any
                            // handle stored by an earlier connection), then let the
                            // server run its half on the background executor.
2219                            let (client_conn, server_conn, kill_conn) = Connection::in_memory();
2220                            connection_killers.lock().insert(user_id, kill_conn);
2221                            cx.background()
2222                                .spawn(server.handle_connection(server_conn, client_name, user_id))
2223                                .detach();
2224                            Ok(client_conn)
2225                        }
2226                    })
2227                });
2228
            // Fake HTTP client that answers every request with a 404 response.
2229            let http = FakeHttpClient::new(|_| async move { Ok(surf::http::Response::new(404)) });
2230            client
2231                .authenticate_and_connect(&cx.to_async())
2232                .await
2233                .unwrap();
2234
            // Wait until the store's watched current user is no longer `None`.
2235            let user_store = UserStore::new(client.clone(), http, &cx.background());
2236            let mut authed_user = user_store.watch_current_user();
2237            while authed_user.recv().await.unwrap().is_none() {}
2238
2239            (client, user_store)
2240        }
2241
        /// Signals the kill handle recorded for `user_id`, removing it from the map.
        /// Does nothing if no handle is recorded; a failed send is ignored.
2242        fn disconnect_client(&self, user_id: UserId) {
2243            if let Some(mut kill_conn) = self.connection_killers.lock().remove(&user_id) {
2244                let _ = kill_conn.try_send(Some(()));
2245            }
2246        }
2247
        /// Makes subsequent connection attempts fail until `allow_connections` is called.
2248        fn forbid_connections(&self) {
2249            self.forbid_connections.store(true, SeqCst);
2250        }
2251
        /// Lets connection attempts succeed again after `forbid_connections`.
2252        fn allow_connections(&self) {
2253            self.forbid_connections.store(false, SeqCst);
2254        }
2255
        /// Builds the `AppState` for the test server: a default `Config` with a
        /// 32-character session secret and the test database's URL, the test database
        /// handle, and test GitHub clients.
2256        async fn build_app_state(test_db: &TestDb) -> Arc<AppState> {
2257            let mut config = Config::default();
2258            config.session_secret = "a".repeat(32);
2259            config.database_url = test_db.url.clone();
2260            let github_client = github::AppClient::test();
2261            Arc::new(AppState {
2262                db: test_db.db().clone(),
2263                handlebars: Default::default(),
2264                auth_client: auth::build_client("", ""),
2265                repo_client: github::RepoClient::test(&github_client),
2266                github_client,
2267                config,
2268            })
2269        }
2270
        /// Acquires a read guard on the server's internal state.
2271        async fn state<'a>(&'a self) -> RwLockReadGuard<'a, ServerState> {
2272            self.server.state.read().await
2273        }
2274
        /// Waits until `predicate` returns `true` for the server state, re-checking after
        /// each item received on the notification channel.
        ///
        /// Panics with "condition timed out" if that takes longer than 500 milliseconds.
2275        async fn condition<F>(&mut self, mut predicate: F)
2276        where
2277            F: FnMut(&ServerState) -> bool,
2278        {
2279            async_std::future::timeout(Duration::from_millis(500), async {
2280                while !(predicate)(&*self.server.state.read().await) {
2281                    self.notifications.recv().await;
2282                }
2283            })
2284            .await
2285            .expect("condition timed out");
2286        }
2287    }
2288
2289    impl Drop for TestServer {
        // Runs `peer.reset()` to completion when the test server goes out of scope,
        // blocking the dropping thread while it does so.
        // NOTE(review): `reset` presumably tears down the peer's connections -- confirm
        // against `zrpc::Peer`.
2290        fn drop(&mut self) {
2291            task::block_on(self.peer.reset());
2292        }
2293    }
2294
2295    fn current_user_id(user_store: &Arc<UserStore>) -> UserId {
        // Converts the store's current user's proto id into a database `UserId`.
        // Panics if the store has no current user.
2296        UserId::from_proto(user_store.current_user().unwrap().id)
2297    }
2298
2299    fn channel_messages(channel: &Channel) -> Vec<(String, String, bool)> {
        // Flattens the channel's messages, in cursor order, into
        // `(sender GitHub login, body, is_pending)` tuples so tests can compare them
        // against plain literals.
2300        channel
2301            .messages()
2302            .cursor::<(), ()>()
2303            .map(|m| {
2304                (
2305                    m.sender.github_login.clone(),
2306                    m.body.clone(),
2307                    m.is_pending(),
2308                )
2309            })
2310            .collect()
2311    }
2312
2313    struct EmptyView; // Field-less view type for tests; it renders `gpui::elements::Empty`.
2314
2315    impl gpui::Entity for EmptyView {
        // The associated event type is the unit type.
2316        type Event = ();
2317    }
2318
2319    impl gpui::View for EmptyView {
        // Name reported for this view type.
2320        fn ui_name() -> &'static str {
2321            "empty view"
2322        }
2323
        // Renders a boxed `gpui::elements::Empty` element; the render context is unused.
2324        fn render(&mut self, _: &mut gpui::RenderContext<Self>) -> gpui::ElementBox {
2325            gpui::Element::boxed(gpui::elements::Empty)
2326        }
2327    }
2328}