rpc.rs

   1use super::{
   2    auth,
   3    db::{ChannelId, MessageId, UserId},
   4    AppState,
   5};
   6use anyhow::anyhow;
   7use async_std::{sync::RwLock, task};
   8use async_tungstenite::{tungstenite::protocol::Role, WebSocketStream};
   9use futures::{future::BoxFuture, FutureExt};
  10use postage::{mpsc, prelude::Sink as _, prelude::Stream as _};
  11use sha1::{Digest as _, Sha1};
  12use std::{
  13    any::TypeId,
  14    collections::{hash_map, HashMap, HashSet},
  15    future::Future,
  16    mem,
  17    sync::Arc,
  18    time::Instant,
  19};
  20use surf::StatusCode;
  21use tide::log;
  22use tide::{
  23    http::headers::{HeaderName, CONNECTION, UPGRADE},
  24    Request, Response,
  25};
  26use time::OffsetDateTime;
  27use zrpc::{
  28    proto::{self, AnyTypedEnvelope, EnvelopedMessage},
  29    Connection, ConnectionId, Peer, TypedEnvelope,
  30};
  31
    /// Numeric id of one participant's replica within a shared worktree.
    /// `join_worktree` reports the host as replica `0`; guest ids are the values
    /// stored in `WorktreeShare::guest_connection_ids`.
  32type ReplicaId = u16;
  33
    /// Type-erased handler for one kind of incoming message.
    ///
    /// It receives the server and the still-untyped envelope and returns a boxed
    /// future resolving to `tide::Result<()>`. `Server::add_handler` builds these
    /// by wrapping a typed handler and downcasting the envelope before the call.
  34type MessageHandler = Box<
  35    dyn Send
  36        + Sync
  37        + Fn(Arc<Server>, Box<dyn AnyTypedEnvelope>) -> BoxFuture<'static, tide::Result<()>>,
  38>;
  39
    /// RPC server: owns the shared connection/worktree/channel state and the
    /// table of per-message-type handlers.
  40pub struct Server {
        /// RPC peer used to register connections and to send, respond to and
        /// forward messages.
  41    peer: Arc<Peer>,
        /// Mutable server state, guarded by an async read/write lock.
  42    state: RwLock<ServerState>,
        /// Application state; the handlers use its `db` for users and channels.
  43    app_state: Arc<AppState>,
        /// Handlers keyed by the `TypeId` of the message type they accept.
  44    handlers: HashMap<TypeId, MessageHandler>,
        /// Optional channel that `handle_connection` sends `()` on after each
        /// handled message (whether the handler succeeded or failed).
  45    notifications: Option<mpsc::Sender<()>>,
  46}
  47
    /// In-memory bookkeeping shared by all connections; lives behind
    /// `Server::state`.
  48#[derive(Default)]
  49struct ServerState {
        /// Per-connection record, inserted by `add_connection` and removed by
        /// `remove_connection`.
  50    connections: HashMap<ConnectionId, ConnectionState>,
        /// Reverse index: all live connection ids of a user. An entry is dropped
        /// once its set becomes empty.
  51    connections_by_user_id: HashMap<UserId, HashSet<ConnectionId>>,
        /// Open worktrees keyed by worktree id.
  52    pub worktrees: HashMap<u64, Worktree>,
        /// Worktree ids reported to each user by
        /// `update_collaborators_for_users`.
  53    visible_worktrees_by_user_id: HashMap<UserId, HashSet<u64>>,
        /// Chat channels that currently have tracked connections.
  54    channels: HashMap<ChannelId, Channel>,
        /// NOTE(review): presumably the next id handed out by `add_worktree`;
        /// that helper is not visible here — confirm.
  55    next_worktree_id: u64,
  56}
  57
    /// What the server tracks for a single live connection.
  58struct ConnectionState {
        /// User the connection was authenticated as.
  59    user_id: UserId,
        /// Worktree ids associated with this connection; each one is closed via
        /// `close_worktree` when the connection is removed.
  60    worktrees: HashSet<u64>,
        /// Channels whose `connection_ids` set must be purged of this connection
        /// when it is removed.
  61    channels: HashSet<ChannelId>,
  62}
  63
    /// A worktree opened by a host connection.
  64struct Worktree {
        /// Connection that opened the worktree.
  65    host_connection_id: ConnectionId,
        /// The host's user id plus the users resolved from the collaborator logins
        /// supplied in `OpenWorktree`.
  66    collaborator_user_ids: Vec<UserId>,
        /// Root name taken from the `OpenWorktree` request.
  67    root_name: String,
        /// `None` until `share_worktree` runs; taken back out by
        /// `unshare_worktree`.
  68    share: Option<WorktreeShare>,
  69}
  70
    /// State that exists only while a worktree is shared.
  71struct WorktreeShare {
        /// Guest connections and the replica id recorded for each.
  72    guest_connection_ids: HashMap<ConnectionId, ReplicaId>,
        /// Replica ids currently in use; a guest's id is removed when it closes
        /// the worktree.
  73    active_replica_ids: HashSet<ReplicaId>,
        /// Worktree entries keyed by entry id, seeded by `share_worktree` and kept
        /// current by `update_worktree`.
  74    entries: HashMap<u64, proto::Entry>,
  75}
  76
    /// Server-side record for a chat channel.
  77#[derive(Default)]
  78struct Channel {
        /// Connections tracked for this channel; a connection is removed from the
        /// set when it disconnects.
  79    connection_ids: HashSet<ConnectionId>,
  80}
  81
    /// Number of messages requested from the database per page; a page shorter
    /// than this is reported to the client as the last one (`done`).
  82const MESSAGE_COUNT_PER_PAGE: usize = 100;
    /// Upper bound on a trimmed channel message body, compared against
    /// `String::len` (i.e. measured in bytes).
  83const MAX_MESSAGE_LEN: usize = 1024;
  84
  85impl Server {
  86    pub fn new(
  87        app_state: Arc<AppState>,
  88        peer: Arc<Peer>,
  89        notifications: Option<mpsc::Sender<()>>,
  90    ) -> Arc<Self> {
  91        let mut server = Self {
  92            peer,
  93            app_state,
  94            state: Default::default(),
  95            handlers: Default::default(),
  96            notifications,
  97        };
  98
  99        server
 100            .add_handler(Server::ping)
 101            .add_handler(Server::open_worktree)
 102            .add_handler(Server::handle_close_worktree)
 103            .add_handler(Server::share_worktree)
 104            .add_handler(Server::unshare_worktree)
 105            .add_handler(Server::join_worktree)
 106            .add_handler(Server::update_worktree)
 107            .add_handler(Server::open_buffer)
 108            .add_handler(Server::close_buffer)
 109            .add_handler(Server::update_buffer)
 110            .add_handler(Server::buffer_saved)
 111            .add_handler(Server::save_buffer)
 112            .add_handler(Server::get_channels)
 113            .add_handler(Server::get_users)
 114            .add_handler(Server::join_channel)
 115            .add_handler(Server::leave_channel)
 116            .add_handler(Server::send_channel_message)
 117            .add_handler(Server::get_channel_messages);
 118
 119        Arc::new(server)
 120    }
 121
 122    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 123    where
 124        F: 'static + Send + Sync + Fn(Arc<Self>, TypedEnvelope<M>) -> Fut,
 125        Fut: 'static + Send + Future<Output = tide::Result<()>>,
 126        M: EnvelopedMessage,
 127    {
 128        let prev_handler = self.handlers.insert(
 129            TypeId::of::<M>(),
 130            Box::new(move |server, envelope| {
 131                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 132                (handler)(server, *envelope).boxed()
 133            }),
 134        );
 135        if prev_handler.is_some() {
 136            panic!("registered a handler for the same message twice");
 137        }
 138        self
 139    }
 140
 141    pub fn handle_connection(
 142        self: &Arc<Self>,
 143        connection: Connection,
 144        addr: String,
 145        user_id: UserId,
 146    ) -> impl Future<Output = ()> {
 147        let this = self.clone();
 148        async move {
 149            let (connection_id, handle_io, mut incoming_rx) =
 150                this.peer.add_connection(connection).await;
 151            this.add_connection(connection_id, user_id).await;
 152
 153            let handle_io = handle_io.fuse();
 154            futures::pin_mut!(handle_io);
 155            loop {
 156                let next_message = incoming_rx.recv().fuse();
 157                futures::pin_mut!(next_message);
 158                futures::select_biased! {
 159                    message = next_message => {
 160                        if let Some(message) = message {
 161                            let start_time = Instant::now();
 162                            log::info!("RPC message received: {}", message.payload_type_name());
 163                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
 164                                if let Err(err) = (handler)(this.clone(), message).await {
 165                                    log::error!("error handling message: {:?}", err);
 166                                } else {
 167                                    log::info!("RPC message handled. duration:{:?}", start_time.elapsed());
 168                                }
 169
 170                                if let Some(mut notifications) = this.notifications.clone() {
 171                                    let _ = notifications.send(()).await;
 172                                }
 173                            } else {
 174                                log::warn!("unhandled message: {}", message.payload_type_name());
 175                            }
 176                        } else {
 177                            log::info!("rpc connection closed {:?}", addr);
 178                            break;
 179                        }
 180                    }
 181                    handle_io = handle_io => {
 182                        if let Err(err) = handle_io {
 183                            log::error!("error handling rpc connection {:?} - {:?}", addr, err);
 184                        }
 185                        break;
 186                    }
 187                }
 188            }
 189
 190            if let Err(err) = this.sign_out(connection_id).await {
 191                log::error!("error signing out connection {:?} - {:?}", addr, err);
 192            }
 193        }
 194    }
 195
 196    async fn sign_out(self: &Arc<Self>, connection_id: zrpc::ConnectionId) -> tide::Result<()> {
 197        self.peer.disconnect(connection_id).await;
 198        self.remove_connection(connection_id).await?;
 199        Ok(())
 200    }
 201
 202    // Add a new connection associated with a given user.
 203    async fn add_connection(&self, connection_id: ConnectionId, user_id: UserId) {
 204        let mut state = self.state.write().await;
 205        state.connections.insert(
 206            connection_id,
 207            ConnectionState {
 208                user_id,
 209                worktrees: Default::default(),
 210                channels: Default::default(),
 211            },
 212        );
 213        state
 214            .connections_by_user_id
 215            .entry(user_id)
 216            .or_default()
 217            .insert(connection_id);
 218    }
 219
 220    // Remove the given connection and its association with any worktrees.
 221    async fn remove_connection(
 222        self: &Arc<Server>,
 223        connection_id: ConnectionId,
 224    ) -> tide::Result<()> {
 225        let mut worktree_ids = Vec::new();
 226        let mut state = self.state.write().await;
 227        if let Some(connection) = state.connections.remove(&connection_id) {
 228            worktree_ids = connection.worktrees.into_iter().collect();
 229
 230            for channel_id in connection.channels {
 231                if let Some(channel) = state.channels.get_mut(&channel_id) {
 232                    channel.connection_ids.remove(&connection_id);
 233                }
 234            }
 235
 236            let user_connections = state
 237                .connections_by_user_id
 238                .get_mut(&connection.user_id)
 239                .unwrap();
 240            user_connections.remove(&connection_id);
 241            if user_connections.is_empty() {
 242                state.connections_by_user_id.remove(&connection.user_id);
 243            }
 244        }
 245
 246        drop(state);
 247        for worktree_id in worktree_ids {
 248            self.close_worktree(worktree_id, connection_id).await?;
 249        }
 250
 251        Ok(())
 252    }
 253
 254    async fn ping(self: Arc<Server>, request: TypedEnvelope<proto::Ping>) -> tide::Result<()> {
 255        self.peer.respond(request.receipt(), proto::Ack {}).await?;
 256        Ok(())
 257    }
 258
 259    async fn open_worktree(
 260        self: Arc<Server>,
 261        request: TypedEnvelope<proto::OpenWorktree>,
 262    ) -> tide::Result<()> {
 263        let receipt = request.receipt();
 264        let host_user_id = self
 265            .state
 266            .read()
 267            .await
 268            .user_id_for_connection(request.sender_id)?;
 269
 270        let mut collaborator_user_ids = HashSet::new();
 271        collaborator_user_ids.insert(host_user_id);
 272        for github_login in request.payload.collaborator_logins {
 273            match self.app_state.db.create_user(&github_login, false).await {
 274                Ok(collaborator_user_id) => {
 275                    collaborator_user_ids.insert(collaborator_user_id);
 276                }
 277                Err(err) => {
 278                    let message = err.to_string();
 279                    self.peer
 280                        .respond_with_error(receipt, proto::Error { message })
 281                        .await?;
 282                    return Ok(());
 283                }
 284            }
 285        }
 286
 287        let collaborator_user_ids = collaborator_user_ids.into_iter().collect::<Vec<_>>();
 288        let worktree_id = self.state.write().await.add_worktree(Worktree {
 289            host_connection_id: request.sender_id,
 290            collaborator_user_ids: collaborator_user_ids.clone(),
 291            root_name: request.payload.root_name,
 292            share: None,
 293        });
 294
 295        self.peer
 296            .respond(receipt, proto::OpenWorktreeResponse { worktree_id })
 297            .await?;
 298        self.update_collaborators_for_users(&collaborator_user_ids)
 299            .await?;
 300
 301        Ok(())
 302    }
 303
 304    async fn share_worktree(
 305        self: Arc<Server>,
 306        mut request: TypedEnvelope<proto::ShareWorktree>,
 307    ) -> tide::Result<()> {
 308        let worktree = request
 309            .payload
 310            .worktree
 311            .as_mut()
 312            .ok_or_else(|| anyhow!("missing worktree"))?;
 313        let entries = mem::take(&mut worktree.entries)
 314            .into_iter()
 315            .map(|entry| (entry.id, entry))
 316            .collect();
 317
 318        let mut state = self.state.write().await;
 319        if let Some(worktree) = state.worktrees.get_mut(&worktree.id) {
 320            worktree.share = Some(WorktreeShare {
 321                guest_connection_ids: Default::default(),
 322                active_replica_ids: Default::default(),
 323                entries,
 324            });
 325            let collaborator_user_ids = worktree.collaborator_user_ids.clone();
 326
 327            drop(state);
 328            self.peer
 329                .respond(request.receipt(), proto::ShareWorktreeResponse {})
 330                .await?;
 331            self.update_collaborators_for_users(&collaborator_user_ids)
 332                .await?;
 333        } else {
 334            self.peer
 335                .respond_with_error(
 336                    request.receipt(),
 337                    proto::Error {
 338                        message: "no such worktree".to_string(),
 339                    },
 340                )
 341                .await?;
 342        }
 343        Ok(())
 344    }
 345
 346    async fn unshare_worktree(
 347        self: Arc<Server>,
 348        request: TypedEnvelope<proto::UnshareWorktree>,
 349    ) -> tide::Result<()> {
 350        let worktree_id = request.payload.worktree_id;
 351
 352        let connection_ids;
 353        let collaborator_user_ids;
 354        {
 355            let mut state = self.state.write().await;
 356            let worktree = state.write_worktree(worktree_id, request.sender_id)?;
 357            if worktree.host_connection_id != request.sender_id {
 358                return Err(anyhow!("no such worktree"))?;
 359            }
 360
 361            connection_ids = worktree.connection_ids();
 362            collaborator_user_ids = worktree.collaborator_user_ids.clone();
 363
 364            worktree.share.take();
 365            for connection_id in &connection_ids {
 366                if let Some(connection) = state.connections.get_mut(connection_id) {
 367                    connection.worktrees.remove(&worktree_id);
 368                }
 369            }
 370        }
 371
 372        broadcast(request.sender_id, connection_ids, |conn_id| {
 373            self.peer
 374                .send(conn_id, proto::UnshareWorktree { worktree_id })
 375        })
 376        .await?;
 377        self.update_collaborators_for_users(&collaborator_user_ids)
 378            .await?;
 379
 380        Ok(())
 381    }
 382
 383    async fn join_worktree(
 384        self: Arc<Server>,
 385        request: TypedEnvelope<proto::JoinWorktree>,
 386    ) -> tide::Result<()> {
 387        let worktree_id = request.payload.worktree_id;
 388        let user_id = self
 389            .state
 390            .read()
 391            .await
 392            .user_id_for_connection(request.sender_id)?;
 393
 394        let response;
 395        let connection_ids;
 396        let collaborator_user_ids;
 397        let mut state = self.state.write().await;
 398        match state.join_worktree(request.sender_id, user_id, worktree_id) {
 399            Ok((peer_replica_id, worktree)) => {
 400                let share = worktree.share()?;
 401                let peer_count = share.guest_connection_ids.len();
 402                let mut peers = Vec::with_capacity(peer_count);
 403                peers.push(proto::Peer {
 404                    peer_id: worktree.host_connection_id.0,
 405                    replica_id: 0,
 406                });
 407                for (peer_conn_id, peer_replica_id) in &share.guest_connection_ids {
 408                    if *peer_conn_id != request.sender_id {
 409                        peers.push(proto::Peer {
 410                            peer_id: peer_conn_id.0,
 411                            replica_id: *peer_replica_id as u32,
 412                        });
 413                    }
 414                }
 415                response = proto::JoinWorktreeResponse {
 416                    worktree: Some(proto::Worktree {
 417                        id: worktree_id,
 418                        root_name: worktree.root_name.clone(),
 419                        entries: share.entries.values().cloned().collect(),
 420                    }),
 421                    replica_id: peer_replica_id as u32,
 422                    peers,
 423                };
 424                connection_ids = worktree.connection_ids();
 425                collaborator_user_ids = worktree.collaborator_user_ids.clone();
 426            }
 427            Err(error) => {
 428                self.peer
 429                    .respond_with_error(
 430                        request.receipt(),
 431                        proto::Error {
 432                            message: error.to_string(),
 433                        },
 434                    )
 435                    .await?;
 436                return Ok(());
 437            }
 438        }
 439
 440        drop(state);
 441        broadcast(request.sender_id, connection_ids, |conn_id| {
 442            self.peer.send(
 443                conn_id,
 444                proto::AddPeer {
 445                    worktree_id,
 446                    peer: Some(proto::Peer {
 447                        peer_id: request.sender_id.0,
 448                        replica_id: response.replica_id,
 449                    }),
 450                },
 451            )
 452        })
 453        .await?;
 454        self.peer.respond(request.receipt(), response).await?;
 455        self.update_collaborators_for_users(&collaborator_user_ids)
 456            .await?;
 457
 458        Ok(())
 459    }
 460
 461    async fn handle_close_worktree(
 462        self: Arc<Server>,
 463        request: TypedEnvelope<proto::CloseWorktree>,
 464    ) -> tide::Result<()> {
 465        self.close_worktree(request.payload.worktree_id, request.sender_id)
 466            .await
 467    }
 468
 469    async fn close_worktree(
 470        self: &Arc<Server>,
 471        worktree_id: u64,
 472        sender_conn_id: ConnectionId,
 473    ) -> tide::Result<()> {
 474        let connection_ids;
 475        let collaborator_user_ids;
 476        let mut is_host = false;
 477        let mut is_guest = false;
 478        {
 479            let mut state = self.state.write().await;
 480            let worktree = state.write_worktree(worktree_id, sender_conn_id)?;
 481            connection_ids = worktree.connection_ids();
 482            collaborator_user_ids = worktree.collaborator_user_ids.clone();
 483
 484            if worktree.host_connection_id == sender_conn_id {
 485                is_host = true;
 486                state.remove_worktree(worktree_id);
 487            } else {
 488                let share = worktree.share_mut()?;
 489                if let Some(replica_id) = share.guest_connection_ids.remove(&sender_conn_id) {
 490                    is_guest = true;
 491                    share.active_replica_ids.remove(&replica_id);
 492                }
 493            }
 494        }
 495
 496        if is_host {
 497            broadcast(sender_conn_id, connection_ids, |conn_id| {
 498                self.peer
 499                    .send(conn_id, proto::UnshareWorktree { worktree_id })
 500            })
 501            .await?;
 502        } else if is_guest {
 503            broadcast(sender_conn_id, connection_ids, |conn_id| {
 504                self.peer.send(
 505                    conn_id,
 506                    proto::RemovePeer {
 507                        worktree_id,
 508                        peer_id: sender_conn_id.0,
 509                    },
 510                )
 511            })
 512            .await?
 513        }
 514        self.update_collaborators_for_users(&collaborator_user_ids)
 515            .await?;
 516        Ok(())
 517    }
 518
 519    async fn update_worktree(
 520        self: Arc<Server>,
 521        request: TypedEnvelope<proto::UpdateWorktree>,
 522    ) -> tide::Result<()> {
 523        {
 524            let mut state = self.state.write().await;
 525            let worktree = state.write_worktree(request.payload.worktree_id, request.sender_id)?;
 526            let share = worktree.share_mut()?;
 527
 528            for entry_id in &request.payload.removed_entries {
 529                share.entries.remove(&entry_id);
 530            }
 531
 532            for entry in &request.payload.updated_entries {
 533                share.entries.insert(entry.id, entry.clone());
 534            }
 535        }
 536
 537        self.broadcast_in_worktree(request.payload.worktree_id, &request)
 538            .await?;
 539        Ok(())
 540    }
 541
 542    async fn open_buffer(
 543        self: Arc<Server>,
 544        request: TypedEnvelope<proto::OpenBuffer>,
 545    ) -> tide::Result<()> {
 546        let receipt = request.receipt();
 547        let worktree_id = request.payload.worktree_id;
 548        let host_connection_id = self
 549            .state
 550            .read()
 551            .await
 552            .read_worktree(worktree_id, request.sender_id)?
 553            .host_connection_id;
 554
 555        let response = self
 556            .peer
 557            .forward_request(request.sender_id, host_connection_id, request.payload)
 558            .await?;
 559        self.peer.respond(receipt, response).await?;
 560        Ok(())
 561    }
 562
 563    async fn close_buffer(
 564        self: Arc<Server>,
 565        request: TypedEnvelope<proto::CloseBuffer>,
 566    ) -> tide::Result<()> {
 567        let host_connection_id = self
 568            .state
 569            .read()
 570            .await
 571            .read_worktree(request.payload.worktree_id, request.sender_id)?
 572            .host_connection_id;
 573
 574        self.peer
 575            .forward_send(request.sender_id, host_connection_id, request.payload)
 576            .await?;
 577
 578        Ok(())
 579    }
 580
 581    async fn save_buffer(
 582        self: Arc<Server>,
 583        request: TypedEnvelope<proto::SaveBuffer>,
 584    ) -> tide::Result<()> {
 585        let host;
 586        let guests;
 587        {
 588            let state = self.state.read().await;
 589            let worktree = state.read_worktree(request.payload.worktree_id, request.sender_id)?;
 590            host = worktree.host_connection_id;
 591            guests = worktree
 592                .share()?
 593                .guest_connection_ids
 594                .keys()
 595                .copied()
 596                .collect::<Vec<_>>();
 597        }
 598
 599        let sender = request.sender_id;
 600        let receipt = request.receipt();
 601        let response = self
 602            .peer
 603            .forward_request(sender, host, request.payload.clone())
 604            .await?;
 605
 606        broadcast(host, guests, |conn_id| {
 607            let response = response.clone();
 608            let peer = &self.peer;
 609            async move {
 610                if conn_id == sender {
 611                    peer.respond(receipt, response).await
 612                } else {
 613                    peer.forward_send(host, conn_id, response).await
 614                }
 615            }
 616        })
 617        .await?;
 618
 619        Ok(())
 620    }
 621
 622    async fn update_buffer(
 623        self: Arc<Server>,
 624        request: TypedEnvelope<proto::UpdateBuffer>,
 625    ) -> tide::Result<()> {
 626        self.broadcast_in_worktree(request.payload.worktree_id, &request)
 627            .await?;
 628        self.peer.respond(request.receipt(), proto::Ack {}).await?;
 629        Ok(())
 630    }
 631
 632    async fn buffer_saved(
 633        self: Arc<Server>,
 634        request: TypedEnvelope<proto::BufferSaved>,
 635    ) -> tide::Result<()> {
 636        self.broadcast_in_worktree(request.payload.worktree_id, &request)
 637            .await
 638    }
 639
 640    async fn get_channels(
 641        self: Arc<Server>,
 642        request: TypedEnvelope<proto::GetChannels>,
 643    ) -> tide::Result<()> {
 644        let user_id = self
 645            .state
 646            .read()
 647            .await
 648            .user_id_for_connection(request.sender_id)?;
 649        let channels = self.app_state.db.get_accessible_channels(user_id).await?;
 650        self.peer
 651            .respond(
 652                request.receipt(),
 653                proto::GetChannelsResponse {
 654                    channels: channels
 655                        .into_iter()
 656                        .map(|chan| proto::Channel {
 657                            id: chan.id.to_proto(),
 658                            name: chan.name,
 659                        })
 660                        .collect(),
 661                },
 662            )
 663            .await?;
 664        Ok(())
 665    }
 666
 667    async fn get_users(
 668        self: Arc<Server>,
 669        request: TypedEnvelope<proto::GetUsers>,
 670    ) -> tide::Result<()> {
 671        let user_id = self
 672            .state
 673            .read()
 674            .await
 675            .user_id_for_connection(request.sender_id)?;
 676        let receipt = request.receipt();
 677        let user_ids = request.payload.user_ids.into_iter().map(UserId::from_proto);
 678        let users = self
 679            .app_state
 680            .db
 681            .get_users_by_ids(user_id, user_ids)
 682            .await?
 683            .into_iter()
 684            .map(|user| proto::User {
 685                id: user.id.to_proto(),
 686                avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
 687                github_login: user.github_login,
 688            })
 689            .collect();
 690        self.peer
 691            .respond(receipt, proto::GetUsersResponse { users })
 692            .await?;
 693        Ok(())
 694    }
 695
 696    async fn update_collaborators_for_users<'a>(
 697        self: &Arc<Server>,
 698        user_ids: impl IntoIterator<Item = &'a UserId>,
 699    ) -> tide::Result<()> {
 700        let mut send_futures = Vec::new();
 701
 702        let state = self.state.read().await;
 703        for user_id in user_ids {
 704            let mut collaborators = HashMap::new();
 705            for worktree_id in state
 706                .visible_worktrees_by_user_id
 707                .get(&user_id)
 708                .unwrap_or(&HashSet::new())
 709            {
 710                let worktree = &state.worktrees[worktree_id];
 711
 712                let mut participants = HashSet::new();
 713                if let Ok(share) = worktree.share() {
 714                    for guest_connection_id in share.guest_connection_ids.keys() {
 715                        let user_id = state.user_id_for_connection(*guest_connection_id)?;
 716                        participants.insert(user_id.to_proto());
 717                    }
 718                }
 719
 720                let host_user_id = state.user_id_for_connection(worktree.host_connection_id)?;
 721                let host =
 722                    collaborators
 723                        .entry(host_user_id)
 724                        .or_insert_with(|| proto::Collaborator {
 725                            user_id: host_user_id.to_proto(),
 726                            worktrees: Vec::new(),
 727                        });
 728                host.worktrees.push(proto::WorktreeMetadata {
 729                    root_name: worktree.root_name.clone(),
 730                    is_shared: worktree.share().is_ok(),
 731                    participants: participants.into_iter().collect(),
 732                });
 733            }
 734
 735            let collaborators = collaborators.into_values().collect::<Vec<_>>();
 736            for connection_id in state.user_connection_ids(*user_id) {
 737                send_futures.push(self.peer.send(
 738                    connection_id,
 739                    proto::UpdateCollaborators {
 740                        collaborators: collaborators.clone(),
 741                    },
 742                ));
 743            }
 744        }
 745
 746        drop(state);
 747        futures::future::try_join_all(send_futures).await?;
 748
 749        Ok(())
 750    }
 751
 752    async fn join_channel(
 753        self: Arc<Self>,
 754        request: TypedEnvelope<proto::JoinChannel>,
 755    ) -> tide::Result<()> {
 756        let user_id = self
 757            .state
 758            .read()
 759            .await
 760            .user_id_for_connection(request.sender_id)?;
 761        let channel_id = ChannelId::from_proto(request.payload.channel_id);
 762        if !self
 763            .app_state
 764            .db
 765            .can_user_access_channel(user_id, channel_id)
 766            .await?
 767        {
 768            Err(anyhow!("access denied"))?;
 769        }
 770
 771        self.state
 772            .write()
 773            .await
 774            .join_channel(request.sender_id, channel_id);
 775        let messages = self
 776            .app_state
 777            .db
 778            .get_channel_messages(channel_id, MESSAGE_COUNT_PER_PAGE, None)
 779            .await?
 780            .into_iter()
 781            .map(|msg| proto::ChannelMessage {
 782                id: msg.id.to_proto(),
 783                body: msg.body,
 784                timestamp: msg.sent_at.unix_timestamp() as u64,
 785                sender_id: msg.sender_id.to_proto(),
 786                nonce: Some(msg.nonce.as_u128().into()),
 787            })
 788            .collect::<Vec<_>>();
 789        self.peer
 790            .respond(
 791                request.receipt(),
 792                proto::JoinChannelResponse {
 793                    done: messages.len() < MESSAGE_COUNT_PER_PAGE,
 794                    messages,
 795                },
 796            )
 797            .await?;
 798        Ok(())
 799    }
 800
 801    async fn leave_channel(
 802        self: Arc<Self>,
 803        request: TypedEnvelope<proto::LeaveChannel>,
 804    ) -> tide::Result<()> {
 805        let user_id = self
 806            .state
 807            .read()
 808            .await
 809            .user_id_for_connection(request.sender_id)?;
 810        let channel_id = ChannelId::from_proto(request.payload.channel_id);
 811        if !self
 812            .app_state
 813            .db
 814            .can_user_access_channel(user_id, channel_id)
 815            .await?
 816        {
 817            Err(anyhow!("access denied"))?;
 818        }
 819
 820        self.state
 821            .write()
 822            .await
 823            .leave_channel(request.sender_id, channel_id);
 824
 825        Ok(())
 826    }
 827
 828    async fn send_channel_message(
 829        self: Arc<Self>,
 830        request: TypedEnvelope<proto::SendChannelMessage>,
 831    ) -> tide::Result<()> {
 832        let receipt = request.receipt();
 833        let channel_id = ChannelId::from_proto(request.payload.channel_id);
 834        let user_id;
 835        let connection_ids;
 836        {
 837            let state = self.state.read().await;
 838            user_id = state.user_id_for_connection(request.sender_id)?;
 839            if let Some(channel) = state.channels.get(&channel_id) {
 840                connection_ids = channel.connection_ids();
 841            } else {
 842                return Ok(());
 843            }
 844        }
 845
 846        // Validate the message body.
 847        let body = request.payload.body.trim().to_string();
 848        if body.len() > MAX_MESSAGE_LEN {
 849            self.peer
 850                .respond_with_error(
 851                    receipt,
 852                    proto::Error {
 853                        message: "message is too long".to_string(),
 854                    },
 855                )
 856                .await?;
 857            return Ok(());
 858        }
 859        if body.is_empty() {
 860            self.peer
 861                .respond_with_error(
 862                    receipt,
 863                    proto::Error {
 864                        message: "message can't be blank".to_string(),
 865                    },
 866                )
 867                .await?;
 868            return Ok(());
 869        }
 870
 871        let timestamp = OffsetDateTime::now_utc();
 872        let nonce = if let Some(nonce) = request.payload.nonce {
 873            nonce
 874        } else {
 875            self.peer
 876                .respond_with_error(
 877                    receipt,
 878                    proto::Error {
 879                        message: "nonce can't be blank".to_string(),
 880                    },
 881                )
 882                .await?;
 883            return Ok(());
 884        };
 885
 886        let message_id = self
 887            .app_state
 888            .db
 889            .create_channel_message(channel_id, user_id, &body, timestamp, nonce.clone().into())
 890            .await?
 891            .to_proto();
 892        let message = proto::ChannelMessage {
 893            sender_id: user_id.to_proto(),
 894            id: message_id,
 895            body,
 896            timestamp: timestamp.unix_timestamp() as u64,
 897            nonce: Some(nonce),
 898        };
 899        broadcast(request.sender_id, connection_ids, |conn_id| {
 900            self.peer.send(
 901                conn_id,
 902                proto::ChannelMessageSent {
 903                    channel_id: channel_id.to_proto(),
 904                    message: Some(message.clone()),
 905                },
 906            )
 907        })
 908        .await?;
 909        self.peer
 910            .respond(
 911                receipt,
 912                proto::SendChannelMessageResponse {
 913                    message: Some(message),
 914                },
 915            )
 916            .await?;
 917        Ok(())
 918    }
 919
 920    async fn get_channel_messages(
 921        self: Arc<Self>,
 922        request: TypedEnvelope<proto::GetChannelMessages>,
 923    ) -> tide::Result<()> {
 924        let user_id = self
 925            .state
 926            .read()
 927            .await
 928            .user_id_for_connection(request.sender_id)?;
 929        let channel_id = ChannelId::from_proto(request.payload.channel_id);
 930        if !self
 931            .app_state
 932            .db
 933            .can_user_access_channel(user_id, channel_id)
 934            .await?
 935        {
 936            Err(anyhow!("access denied"))?;
 937        }
 938
 939        let messages = self
 940            .app_state
 941            .db
 942            .get_channel_messages(
 943                channel_id,
 944                MESSAGE_COUNT_PER_PAGE,
 945                Some(MessageId::from_proto(request.payload.before_message_id)),
 946            )
 947            .await?
 948            .into_iter()
 949            .map(|msg| proto::ChannelMessage {
 950                id: msg.id.to_proto(),
 951                body: msg.body,
 952                timestamp: msg.sent_at.unix_timestamp() as u64,
 953                sender_id: msg.sender_id.to_proto(),
 954                nonce: Some(msg.nonce.as_u128().into()),
 955            })
 956            .collect::<Vec<_>>();
 957        self.peer
 958            .respond(
 959                request.receipt(),
 960                proto::GetChannelMessagesResponse {
 961                    done: messages.len() < MESSAGE_COUNT_PER_PAGE,
 962                    messages,
 963                },
 964            )
 965            .await?;
 966        Ok(())
 967    }
 968
 969    async fn broadcast_in_worktree<T: proto::EnvelopedMessage>(
 970        &self,
 971        worktree_id: u64,
 972        message: &TypedEnvelope<T>,
 973    ) -> tide::Result<()> {
 974        let connection_ids = self
 975            .state
 976            .read()
 977            .await
 978            .read_worktree(worktree_id, message.sender_id)?
 979            .connection_ids();
 980
 981        broadcast(message.sender_id, connection_ids, |conn_id| {
 982            self.peer
 983                .forward_send(message.sender_id, conn_id, message.payload.clone())
 984        })
 985        .await?;
 986
 987        Ok(())
 988    }
 989}
 990
 991pub async fn broadcast<F, T>(
 992    sender_id: ConnectionId,
 993    receiver_ids: Vec<ConnectionId>,
 994    mut f: F,
 995) -> anyhow::Result<()>
 996where
 997    F: FnMut(ConnectionId) -> T,
 998    T: Future<Output = anyhow::Result<()>>,
 999{
1000    let futures = receiver_ids
1001        .into_iter()
1002        .filter(|id| *id != sender_id)
1003        .map(|id| f(id));
1004    futures::future::try_join_all(futures).await?;
1005    Ok(())
1006}
1007
1008impl ServerState {
1009    fn join_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
1010        if let Some(connection) = self.connections.get_mut(&connection_id) {
1011            connection.channels.insert(channel_id);
1012            self.channels
1013                .entry(channel_id)
1014                .or_default()
1015                .connection_ids
1016                .insert(connection_id);
1017        }
1018    }
1019
1020    fn leave_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
1021        if let Some(connection) = self.connections.get_mut(&connection_id) {
1022            connection.channels.remove(&channel_id);
1023            if let hash_map::Entry::Occupied(mut entry) = self.channels.entry(channel_id) {
1024                entry.get_mut().connection_ids.remove(&connection_id);
1025                if entry.get_mut().connection_ids.is_empty() {
1026                    entry.remove();
1027                }
1028            }
1029        }
1030    }
1031
1032    fn user_id_for_connection(&self, connection_id: ConnectionId) -> tide::Result<UserId> {
1033        Ok(self
1034            .connections
1035            .get(&connection_id)
1036            .ok_or_else(|| anyhow!("unknown connection"))?
1037            .user_id)
1038    }
1039
1040    fn user_connection_ids<'a>(
1041        &'a self,
1042        user_id: UserId,
1043    ) -> impl 'a + Iterator<Item = ConnectionId> {
1044        self.connections_by_user_id
1045            .get(&user_id)
1046            .into_iter()
1047            .flatten()
1048            .copied()
1049    }
1050
1051    // Add the given connection as a guest of the given worktree
1052    fn join_worktree(
1053        &mut self,
1054        connection_id: ConnectionId,
1055        user_id: UserId,
1056        worktree_id: u64,
1057    ) -> tide::Result<(ReplicaId, &Worktree)> {
1058        let connection = self
1059            .connections
1060            .get_mut(&connection_id)
1061            .ok_or_else(|| anyhow!("no such connection"))?;
1062        let worktree = self
1063            .worktrees
1064            .get_mut(&worktree_id)
1065            .ok_or_else(|| anyhow!("no such worktree"))?;
1066        if !worktree.collaborator_user_ids.contains(&user_id) {
1067            Err(anyhow!("no such worktree"))?;
1068        }
1069
1070        let share = worktree.share_mut()?;
1071        connection.worktrees.insert(worktree_id);
1072
1073        let mut replica_id = 1;
1074        while share.active_replica_ids.contains(&replica_id) {
1075            replica_id += 1;
1076        }
1077        share.active_replica_ids.insert(replica_id);
1078        share.guest_connection_ids.insert(connection_id, replica_id);
1079        return Ok((replica_id, worktree));
1080    }
1081
1082    fn read_worktree(
1083        &self,
1084        worktree_id: u64,
1085        connection_id: ConnectionId,
1086    ) -> tide::Result<&Worktree> {
1087        let worktree = self
1088            .worktrees
1089            .get(&worktree_id)
1090            .ok_or_else(|| anyhow!("worktree not found"))?;
1091
1092        if worktree.host_connection_id == connection_id
1093            || worktree
1094                .share()?
1095                .guest_connection_ids
1096                .contains_key(&connection_id)
1097        {
1098            Ok(worktree)
1099        } else {
1100            Err(anyhow!(
1101                "{} is not a member of worktree {}",
1102                connection_id,
1103                worktree_id
1104            ))?
1105        }
1106    }
1107
1108    fn write_worktree(
1109        &mut self,
1110        worktree_id: u64,
1111        connection_id: ConnectionId,
1112    ) -> tide::Result<&mut Worktree> {
1113        let worktree = self
1114            .worktrees
1115            .get_mut(&worktree_id)
1116            .ok_or_else(|| anyhow!("worktree not found"))?;
1117
1118        if worktree.host_connection_id == connection_id
1119            || worktree.share.as_ref().map_or(false, |share| {
1120                share.guest_connection_ids.contains_key(&connection_id)
1121            })
1122        {
1123            Ok(worktree)
1124        } else {
1125            Err(anyhow!(
1126                "{} is not a member of worktree {}",
1127                connection_id,
1128                worktree_id
1129            ))?
1130        }
1131    }
1132
1133    fn add_worktree(&mut self, worktree: Worktree) -> u64 {
1134        let worktree_id = self.next_worktree_id;
1135        for collaborator_user_id in &worktree.collaborator_user_ids {
1136            self.visible_worktrees_by_user_id
1137                .entry(*collaborator_user_id)
1138                .or_default()
1139                .insert(worktree_id);
1140        }
1141        self.next_worktree_id += 1;
1142        self.worktrees.insert(worktree_id, worktree);
1143        worktree_id
1144    }
1145
1146    fn remove_worktree(&mut self, worktree_id: u64) {
1147        let worktree = self.worktrees.remove(&worktree_id).unwrap();
1148        if let Some(connection) = self.connections.get_mut(&worktree.host_connection_id) {
1149            connection.worktrees.remove(&worktree_id);
1150        }
1151        if let Some(share) = worktree.share {
1152            for connection_id in share.guest_connection_ids.keys() {
1153                if let Some(connection) = self.connections.get_mut(connection_id) {
1154                    connection.worktrees.remove(&worktree_id);
1155                }
1156            }
1157        }
1158        for collaborator_user_id in worktree.collaborator_user_ids {
1159            if let Some(visible_worktrees) = self
1160                .visible_worktrees_by_user_id
1161                .get_mut(&collaborator_user_id)
1162            {
1163                visible_worktrees.remove(&worktree_id);
1164            }
1165        }
1166    }
1167}
1168
1169impl Worktree {
1170    pub fn connection_ids(&self) -> Vec<ConnectionId> {
1171        if let Some(share) = &self.share {
1172            share
1173                .guest_connection_ids
1174                .keys()
1175                .copied()
1176                .chain(Some(self.host_connection_id))
1177                .collect()
1178        } else {
1179            vec![self.host_connection_id]
1180        }
1181    }
1182
1183    fn share(&self) -> tide::Result<&WorktreeShare> {
1184        Ok(self
1185            .share
1186            .as_ref()
1187            .ok_or_else(|| anyhow!("worktree is not shared"))?)
1188    }
1189
1190    fn share_mut(&mut self) -> tide::Result<&mut WorktreeShare> {
1191        Ok(self
1192            .share
1193            .as_mut()
1194            .ok_or_else(|| anyhow!("worktree is not shared"))?)
1195    }
1196}
1197
1198impl Channel {
1199    fn connection_ids(&self) -> Vec<ConnectionId> {
1200        self.connection_ids.iter().copied().collect()
1201    }
1202}
1203
1204pub fn add_routes(app: &mut tide::Server<Arc<AppState>>, rpc: &Arc<Peer>) {
1205    let server = Server::new(app.state().clone(), rpc.clone(), None);
1206    app.at("/rpc").with(auth::VerifyToken).get(move |request: Request<Arc<AppState>>| {
1207        let user_id = request.ext::<UserId>().copied();
1208        let server = server.clone();
1209        async move {
1210            const WEBSOCKET_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
1211
1212            let connection_upgrade = header_contains_ignore_case(&request, CONNECTION, "upgrade");
1213            let upgrade_to_websocket = header_contains_ignore_case(&request, UPGRADE, "websocket");
1214            let upgrade_requested = connection_upgrade && upgrade_to_websocket;
1215
1216            if !upgrade_requested {
1217                return Ok(Response::new(StatusCode::UpgradeRequired));
1218            }
1219
1220            let header = match request.header("Sec-Websocket-Key") {
1221                Some(h) => h.as_str(),
1222                None => return Err(anyhow!("expected sec-websocket-key"))?,
1223            };
1224
1225            let mut response = Response::new(StatusCode::SwitchingProtocols);
1226            response.insert_header(UPGRADE, "websocket");
1227            response.insert_header(CONNECTION, "Upgrade");
1228            let hash = Sha1::new().chain(header).chain(WEBSOCKET_GUID).finalize();
1229            response.insert_header("Sec-Websocket-Accept", base64::encode(&hash[..]));
1230            response.insert_header("Sec-Websocket-Version", "13");
1231
1232            let http_res: &mut tide::http::Response = response.as_mut();
1233            let upgrade_receiver = http_res.recv_upgrade().await;
1234            let addr = request.remote().unwrap_or("unknown").to_string();
1235            let user_id = user_id.ok_or_else(|| anyhow!("user_id is not present on request. ensure auth::VerifyToken middleware is present"))?;
1236            task::spawn(async move {
1237                if let Some(stream) = upgrade_receiver.await {
1238                    server.handle_connection(Connection::new(WebSocketStream::from_raw_socket(stream, Role::Server, None).await), addr, user_id).await;
1239                }
1240            });
1241
1242            Ok(response)
1243        }
1244    });
1245}
1246
1247fn header_contains_ignore_case<T>(
1248    request: &tide::Request<T>,
1249    header_name: HeaderName,
1250    value: &str,
1251) -> bool {
1252    request
1253        .header(header_name)
1254        .map(|h| {
1255            h.as_str()
1256                .split(',')
1257                .any(|s| s.trim().eq_ignore_ascii_case(value.trim()))
1258        })
1259        .unwrap_or(false)
1260}
1261
1262#[cfg(test)]
1263mod tests {
1264    use super::*;
1265    use crate::{
1266        auth,
1267        db::{tests::TestDb, UserId},
1268        github, AppState, Config,
1269    };
1270    use async_std::{sync::RwLockReadGuard, task};
1271    use gpui::{ModelHandle, TestAppContext};
1272    use parking_lot::Mutex;
1273    use postage::{mpsc, watch};
1274    use serde_json::json;
1275    use sqlx::types::time::OffsetDateTime;
1276    use std::{
1277        path::Path,
1278        sync::{
1279            atomic::{AtomicBool, Ordering::SeqCst},
1280            Arc,
1281        },
1282        time::Duration,
1283    };
1284    use zed::{
1285        channel::{Channel, ChannelDetails, ChannelList},
1286        editor::{Editor, Insert},
1287        fs::{FakeFs, Fs as _},
1288        language::LanguageRegistry,
1289        rpc::{self, Client, Credentials, EstablishConnectionError},
1290        settings,
1291        test::FakeHttpClient,
1292        user::UserStore,
1293        worktree::Worktree,
1294    };
1295    use zrpc::Peer;
1296
    // End-to-end check of worktree sharing between two clients: client A shares a local
    // worktree, client B joins it remotely, and buffer contents, edits, selection sets,
    // buffer closure and B's departure are then observed on A's side.
    #[gpui::test]
    async fn test_share_worktree(mut cx_a: TestAppContext, mut cx_b: TestAppContext) {
        // Client B needs a window because an editor view is created for it further down.
        let (window_b, _) = cx_b.add_window(|_| EmptyView);
        let settings = cx_b.read(settings::test).1;
        let lang_registry = Arc::new(LanguageRegistry::new());

        // Connect to a server as 2 clients.
        let mut server = TestServer::start().await;
        let (client_a, _) = server.create_client(&mut cx_a, "user_a").await;
        let (client_b, _) = server.create_client(&mut cx_b, "user_b").await;

        cx_a.foreground().forbid_parking();

        // Share a local worktree as client A
        // The `.zed.toml` file lists `user_b` as a collaborator of this worktree.
        let fs = Arc::new(FakeFs::new());
        fs.insert_tree(
            "/a",
            json!({
                ".zed.toml": r#"collaborators = ["user_b"]"#,
                "a.txt": "a-contents",
                "b.txt": "b-contents",
            }),
        )
        .await;
        let worktree_a = Worktree::open_local(
            client_a.clone(),
            "/a".as_ref(),
            fs,
            lang_registry.clone(),
            &mut cx_a.to_async(),
        )
        .await
        .unwrap();
        // Wait for the local scan to complete before sharing.
        worktree_a
            .read_with(&cx_a, |tree, _| tree.as_local().unwrap().scan_complete())
            .await;
        let worktree_id = worktree_a
            .update(&mut cx_a, |tree, cx| tree.as_local_mut().unwrap().share(cx))
            .await
            .unwrap();

        // Join that worktree as client B, and see that a guest has joined as client A.
        let worktree_b = Worktree::open_remote(
            client_b.clone(),
            worktree_id,
            lang_registry.clone(),
            &mut cx_b.to_async(),
        )
        .await
        .unwrap();
        let replica_id_b = worktree_b.read_with(&cx_b, |tree, _| tree.replica_id());
        // A's peer list must eventually contain the replica id that B was assigned.
        worktree_a
            .condition(&cx_a, |tree, _| {
                tree.peers()
                    .values()
                    .any(|replica_id| *replica_id == replica_id_b)
            })
            .await;

        // Open the same file as client B and client A.
        let buffer_b = worktree_b
            .update(&mut cx_b, |worktree, cx| worktree.open_buffer("b.txt", cx))
            .await
            .unwrap();
        buffer_b.read_with(&cx_b, |buf, _| assert_eq!(buf.text(), "b-contents"));
        // B opening the file must already have left the buffer open on the host.
        worktree_a.read_with(&cx_a, |tree, cx| assert!(tree.has_open_buffer("b.txt", cx)));
        let buffer_a = worktree_a
            .update(&mut cx_a, |tree, cx| tree.open_buffer("b.txt", cx))
            .await
            .unwrap();

        // Create a selection set as client B and see that selection set as client A.
        let editor_b = cx_b.add_view(window_b, |cx| Editor::for_buffer(buffer_b, settings, cx));
        buffer_a
            .condition(&cx_a, |buffer, _| buffer.selection_sets().count() == 1)
            .await;

        // Edit the buffer as client B and see that edit as client A.
        editor_b.update(&mut cx_b, |editor, cx| {
            editor.insert(&Insert("ok, ".into()), cx)
        });
        buffer_a
            .condition(&cx_a, |buffer, _| buffer.text() == "ok, b-contents")
            .await;

        // Remove the selection set as client B, see those selections disappear as client A.
        cx_b.update(move |_| drop(editor_b));
        buffer_a
            .condition(&cx_a, |buffer, _| buffer.selection_sets().count() == 0)
            .await;

        // Close the buffer as client A, see that the buffer is closed.
        cx_a.update(move |_| drop(buffer_a));
        worktree_a
            .condition(&cx_a, |tree, cx| !tree.has_open_buffer("b.txt", cx))
            .await;

        // Dropping the worktree removes client B from client A's peers.
        cx_b.update(move |_| drop(worktree_b));
        worktree_a
            .condition(&cx_a, |tree, _| tree.peers().is_empty())
            .await;
    }
1400
    // Three-client scenario: host A shares a worktree with guests B and C. Verifies that
    // concurrent edits converge on all three buffers, that a save issued by a guest is
    // written to the host's file system and clears the dirty state everywhere, and that
    // file-system changes on the host become visible in the guests' worktrees.
    #[gpui::test]
    async fn test_propagate_saves_and_fs_changes_in_shared_worktree(
        mut cx_a: TestAppContext,
        mut cx_b: TestAppContext,
        mut cx_c: TestAppContext,
    ) {
        cx_a.foreground().forbid_parking();
        let lang_registry = Arc::new(LanguageRegistry::new());

        // Connect to a server as 3 clients.
        let mut server = TestServer::start().await;
        let (client_a, _) = server.create_client(&mut cx_a, "user_a").await;
        let (client_b, _) = server.create_client(&mut cx_b, "user_b").await;
        let (client_c, _) = server.create_client(&mut cx_c, "user_c").await;

        // `fs` is kept (cloned below) so the test can inspect and mutate the host's files.
        let fs = Arc::new(FakeFs::new());

        // Share a worktree as client A.
        fs.insert_tree(
            "/a",
            json!({
                ".zed.toml": r#"collaborators = ["user_b", "user_c"]"#,
                "file1": "",
                "file2": ""
            }),
        )
        .await;

        let worktree_a = Worktree::open_local(
            client_a.clone(),
            "/a".as_ref(),
            fs.clone(),
            lang_registry.clone(),
            &mut cx_a.to_async(),
        )
        .await
        .unwrap();
        // Wait for the local scan to complete before sharing.
        worktree_a
            .read_with(&cx_a, |tree, _| tree.as_local().unwrap().scan_complete())
            .await;
        let worktree_id = worktree_a
            .update(&mut cx_a, |tree, cx| tree.as_local_mut().unwrap().share(cx))
            .await
            .unwrap();

        // Join that worktree as clients B and C.
        let worktree_b = Worktree::open_remote(
            client_b.clone(),
            worktree_id,
            lang_registry.clone(),
            &mut cx_b.to_async(),
        )
        .await
        .unwrap();
        let worktree_c = Worktree::open_remote(
            client_c.clone(),
            worktree_id,
            lang_registry.clone(),
            &mut cx_c.to_async(),
        )
        .await
        .unwrap();

        // Open and edit a buffer as both guests B and C.
        let buffer_b = worktree_b
            .update(&mut cx_b, |tree, cx| tree.open_buffer("file1", cx))
            .await
            .unwrap();
        let buffer_c = worktree_c
            .update(&mut cx_c, |tree, cx| tree.open_buffer("file1", cx))
            .await
            .unwrap();
        // Both guests insert at offset 0 without waiting for each other.
        buffer_b.update(&mut cx_b, |buf, cx| buf.edit([0..0], "i-am-b, ", cx));
        buffer_c.update(&mut cx_c, |buf, cx| buf.edit([0..0], "i-am-c, ", cx));

        // Open and edit that buffer as the host.
        let buffer_a = worktree_a
            .update(&mut cx_a, |tree, cx| tree.open_buffer("file1", cx))
            .await
            .unwrap();

        // Wait until the host has received both guests' edits, then append at the end.
        buffer_a
            .condition(&mut cx_a, |buf, _| buf.text() == "i-am-c, i-am-b, ")
            .await;
        buffer_a.update(&mut cx_a, |buf, cx| {
            buf.edit([buf.len()..buf.len()], "i-am-a", cx)
        });

        // Wait for edits to propagate
        buffer_a
            .condition(&mut cx_a, |buf, _| buf.text() == "i-am-c, i-am-b, i-am-a")
            .await;
        buffer_b
            .condition(&mut cx_b, |buf, _| buf.text() == "i-am-c, i-am-b, i-am-a")
            .await;
        buffer_c
            .condition(&mut cx_c, |buf, _| buf.text() == "i-am-c, i-am-b, i-am-a")
            .await;

        // Edit the buffer as the host and concurrently save as guest B.
        // The save is started first but only awaited after the host's edit has been made.
        let save_b = buffer_b.update(&mut cx_b, |buf, cx| buf.save(cx).unwrap());
        buffer_a.update(&mut cx_a, |buf, cx| buf.edit([0..0], "hi-a, ", cx));
        save_b.await.unwrap();
        // The file on the host's file system is expected to include the host's edit.
        assert_eq!(
            fs.load("/a/file1".as_ref()).await.unwrap(),
            "hi-a, i-am-c, i-am-b, i-am-a"
        );
        buffer_a.read_with(&cx_a, |buf, _| assert!(!buf.is_dirty()));
        buffer_b.read_with(&cx_b, |buf, _| assert!(!buf.is_dirty()));
        // Guest C took no part in the save, so its clean state is awaited, not asserted.
        buffer_c.condition(&cx_c, |buf, _| !buf.is_dirty()).await;

        // Make changes on host's file system, see those changes on the guests.
        fs.rename("/a/file2".as_ref(), "/a/file3".as_ref())
            .await
            .unwrap();
        fs.insert_file(Path::new("/a/file4"), "4".into())
            .await
            .unwrap();

        // Four files are expected afterwards: .zed.toml, file1, file3 and file4.
        worktree_b
            .condition(&cx_b, |tree, _| tree.file_count() == 4)
            .await;
        worktree_c
            .condition(&cx_c, |tree, _| tree.file_count() == 4)
            .await;
        worktree_b.read_with(&cx_b, |tree, _| {
            assert_eq!(
                tree.paths()
                    .map(|p| p.to_string_lossy())
                    .collect::<Vec<_>>(),
                &[".zed.toml", "file1", "file3", "file4"]
            )
        });
        worktree_c.read_with(&cx_c, |tree, _| {
            assert_eq!(
                tree.paths()
                    .map(|p| p.to_string_lossy())
                    .collect::<Vec<_>>(),
                &[".zed.toml", "file1", "file3", "file4"]
            )
        });
    }
1543
    // Guest-side check of the dirty/conflict flags around a save: after guest B edits and
    // saves a shared buffer, and the buffer's file mtime has changed, the buffer must be
    // clean and conflict-free, and a further edit must make it dirty without a conflict.
    #[gpui::test]
    async fn test_buffer_conflict_after_save(mut cx_a: TestAppContext, mut cx_b: TestAppContext) {
        cx_a.foreground().forbid_parking();
        let lang_registry = Arc::new(LanguageRegistry::new());

        // Connect to a server as 2 clients.
        let mut server = TestServer::start().await;
        let (client_a, _) = server.create_client(&mut cx_a, "user_a").await;
        let (client_b, _) = server.create_client(&mut cx_b, "user_b").await;

        // Share a local worktree as client A
        // Note: `user_c` is listed as a collaborator, but this test creates no such client.
        let fs = Arc::new(FakeFs::new());
        fs.insert_tree(
            "/dir",
            json!({
                ".zed.toml": r#"collaborators = ["user_b", "user_c"]"#,
                "a.txt": "a-contents",
            }),
        )
        .await;

        let worktree_a = Worktree::open_local(
            client_a.clone(),
            "/dir".as_ref(),
            fs,
            lang_registry.clone(),
            &mut cx_a.to_async(),
        )
        .await
        .unwrap();
        // Wait for the local scan to complete before sharing.
        worktree_a
            .read_with(&cx_a, |tree, _| tree.as_local().unwrap().scan_complete())
            .await;
        let worktree_id = worktree_a
            .update(&mut cx_a, |tree, cx| tree.as_local_mut().unwrap().share(cx))
            .await
            .unwrap();

        // Join that worktree as client B.
        let worktree_b = Worktree::open_remote(
            client_b.clone(),
            worktree_id,
            lang_registry.clone(),
            &mut cx_b.to_async(),
        )
        .await
        .unwrap();

        // Open the file as client B and remember its mtime from before the save.
        let buffer_b = worktree_b
            .update(&mut cx_b, |worktree, cx| worktree.open_buffer("a.txt", cx))
            .await
            .unwrap();
        let mtime = buffer_b.read_with(&cx_b, |buf, _| buf.file().unwrap().mtime);

        // An unsaved edit makes the buffer dirty but not conflicted.
        buffer_b.update(&mut cx_b, |buf, cx| buf.edit([0..0], "world ", cx));
        buffer_b.read_with(&cx_b, |buf, _| {
            assert!(buf.is_dirty());
            assert!(!buf.has_conflict());
        });

        // Save as client B, then wait until the buffer's file reports a different mtime.
        buffer_b
            .update(&mut cx_b, |buf, cx| buf.save(cx))
            .unwrap()
            .await
            .unwrap();
        worktree_b
            .condition(&cx_b, |_, cx| {
                buffer_b.read(cx).file().unwrap().mtime != mtime
            })
            .await;
        buffer_b.read_with(&cx_b, |buf, _| {
            assert!(!buf.is_dirty());
            assert!(!buf.has_conflict());
        });

        // Editing again after the save dirties the buffer without reporting a conflict.
        buffer_b.update(&mut cx_b, |buf, cx| buf.edit([0..0], "hello ", cx));
        buffer_b.read_with(&cx_b, |buf, _| {
            assert!(buf.is_dirty());
            assert!(!buf.has_conflict());
        });
    }
1625
    // Checks that a guest's buffer converges on the host's text when the host edits the
    // buffer while the guest's request to open it has been started but not yet awaited.
    #[gpui::test]
    async fn test_editing_while_guest_opens_buffer(
        mut cx_a: TestAppContext,
        mut cx_b: TestAppContext,
    ) {
        cx_a.foreground().forbid_parking();
        let lang_registry = Arc::new(LanguageRegistry::new());

        // Connect to a server as 2 clients.
        let mut server = TestServer::start().await;
        let (client_a, _) = server.create_client(&mut cx_a, "user_a").await;
        let (client_b, _) = server.create_client(&mut cx_b, "user_b").await;

        // Share a local worktree as client A
        let fs = Arc::new(FakeFs::new());
        fs.insert_tree(
            "/dir",
            json!({
                ".zed.toml": r#"collaborators = ["user_b"]"#,
                "a.txt": "a-contents",
            }),
        )
        .await;
        let worktree_a = Worktree::open_local(
            client_a.clone(),
            "/dir".as_ref(),
            fs,
            lang_registry.clone(),
            &mut cx_a.to_async(),
        )
        .await
        .unwrap();
        // Wait for the local scan to complete before sharing.
        worktree_a
            .read_with(&cx_a, |tree, _| tree.as_local().unwrap().scan_complete())
            .await;
        let worktree_id = worktree_a
            .update(&mut cx_a, |tree, cx| tree.as_local_mut().unwrap().share(cx))
            .await
            .unwrap();

        // Join that worktree as client B.
        let worktree_b = Worktree::open_remote(
            client_b.clone(),
            worktree_id,
            lang_registry.clone(),
            &mut cx_b.to_async(),
        )
        .await
        .unwrap();

        // Open the buffer on the host, and start opening it as client B on the
        // background executor without awaiting the result yet.
        let buffer_a = worktree_a
            .update(&mut cx_a, |tree, cx| tree.open_buffer("a.txt", cx))
            .await
            .unwrap();
        let buffer_b = cx_b
            .background()
            .spawn(worktree_b.update(&mut cx_b, |worktree, cx| worktree.open_buffer("a.txt", cx)));

        // Yield once so other tasks get a chance to run, then edit the buffer as the host.
        task::yield_now().await;
        buffer_a.update(&mut cx_a, |buf, cx| buf.edit([0..0], "z", cx));

        // Whatever B received when opening, it must end up with the host's current text.
        let text = buffer_a.read_with(&cx_a, |buf, _| buf.text());
        let buffer_b = buffer_b.await.unwrap();
        buffer_b.condition(&cx_b, |buf, _| buf.text() == text).await;
    }
1691
1692    #[gpui::test]
1693    async fn test_peer_disconnection(mut cx_a: TestAppContext, cx_b: TestAppContext) {
1694        cx_a.foreground().forbid_parking();
1695        let lang_registry = Arc::new(LanguageRegistry::new());
1696
1697        // Connect to a server as 2 clients.
1698        let mut server = TestServer::start().await;
1699        let (client_a, _) = server.create_client(&mut cx_a, "user_a").await;
1700        let (client_b, _) = server.create_client(&mut cx_a, "user_b").await;
1701
1702        // Share a local worktree as client A
1703        let fs = Arc::new(FakeFs::new());
1704        fs.insert_tree(
1705            "/a",
1706            json!({
1707                ".zed.toml": r#"collaborators = ["user_b"]"#,
1708                "a.txt": "a-contents",
1709                "b.txt": "b-contents",
1710            }),
1711        )
1712        .await;
1713        let worktree_a = Worktree::open_local(
1714            client_a.clone(),
1715            "/a".as_ref(),
1716            fs,
1717            lang_registry.clone(),
1718            &mut cx_a.to_async(),
1719        )
1720        .await
1721        .unwrap();
1722        worktree_a
1723            .read_with(&cx_a, |tree, _| tree.as_local().unwrap().scan_complete())
1724            .await;
1725        let worktree_id = worktree_a
1726            .update(&mut cx_a, |tree, cx| tree.as_local_mut().unwrap().share(cx))
1727            .await
1728            .unwrap();
1729
1730        // Join that worktree as client B, and see that a guest has joined as client A.
1731        let _worktree_b = Worktree::open_remote(
1732            client_b.clone(),
1733            worktree_id,
1734            lang_registry.clone(),
1735            &mut cx_b.to_async(),
1736        )
1737        .await
1738        .unwrap();
1739        worktree_a
1740            .condition(&cx_a, |tree, _| tree.peers().len() == 1)
1741            .await;
1742
1743        // Drop client B's connection and ensure client A observes client B leaving the worktree.
1744        client_b.disconnect(&cx_b.to_async()).await.unwrap();
1745        worktree_a
1746            .condition(&cx_a, |tree, _| tree.peers().len() == 0)
1747            .await;
1748    }
1749
1750    #[gpui::test]
1751    async fn test_basic_chat(mut cx_a: TestAppContext, mut cx_b: TestAppContext) {
1752        cx_a.foreground().forbid_parking();
1753
1754        // Connect to a server as 2 clients.
1755        let mut server = TestServer::start().await;
1756        let (client_a, user_store_a) = server.create_client(&mut cx_a, "user_a").await;
1757        let (client_b, user_store_b) = server.create_client(&mut cx_b, "user_b").await;
1758
1759        // Create an org that includes these 2 users.
1760        let db = &server.app_state.db;
1761        let org_id = db.create_org("Test Org", "test-org").await.unwrap();
1762        db.add_org_member(org_id, current_user_id(&user_store_a, &cx_a), false)
1763            .await
1764            .unwrap();
1765        db.add_org_member(org_id, current_user_id(&user_store_b, &cx_b), false)
1766            .await
1767            .unwrap();
1768
1769        // Create a channel that includes all the users.
1770        let channel_id = db.create_org_channel(org_id, "test-channel").await.unwrap();
1771        db.add_channel_member(channel_id, current_user_id(&user_store_a, &cx_a), false)
1772            .await
1773            .unwrap();
1774        db.add_channel_member(channel_id, current_user_id(&user_store_b, &cx_b), false)
1775            .await
1776            .unwrap();
1777        db.create_channel_message(
1778            channel_id,
1779            current_user_id(&user_store_b, &cx_b),
1780            "hello A, it's B.",
1781            OffsetDateTime::now_utc(),
1782            1,
1783        )
1784        .await
1785        .unwrap();
1786
1787        let channels_a = cx_a.add_model(|cx| ChannelList::new(user_store_a, client_a, cx));
1788        channels_a
1789            .condition(&mut cx_a, |list, _| list.available_channels().is_some())
1790            .await;
1791        channels_a.read_with(&cx_a, |list, _| {
1792            assert_eq!(
1793                list.available_channels().unwrap(),
1794                &[ChannelDetails {
1795                    id: channel_id.to_proto(),
1796                    name: "test-channel".to_string()
1797                }]
1798            )
1799        });
1800        let channel_a = channels_a.update(&mut cx_a, |this, cx| {
1801            this.get_channel(channel_id.to_proto(), cx).unwrap()
1802        });
1803        channel_a.read_with(&cx_a, |channel, _| assert!(channel.messages().is_empty()));
1804        channel_a
1805            .condition(&cx_a, |channel, _| {
1806                channel_messages(channel)
1807                    == [("user_b".to_string(), "hello A, it's B.".to_string(), false)]
1808            })
1809            .await;
1810
1811        let channels_b = cx_b.add_model(|cx| ChannelList::new(user_store_b, client_b, cx));
1812        channels_b
1813            .condition(&mut cx_b, |list, _| list.available_channels().is_some())
1814            .await;
1815        channels_b.read_with(&cx_b, |list, _| {
1816            assert_eq!(
1817                list.available_channels().unwrap(),
1818                &[ChannelDetails {
1819                    id: channel_id.to_proto(),
1820                    name: "test-channel".to_string()
1821                }]
1822            )
1823        });
1824
1825        let channel_b = channels_b.update(&mut cx_b, |this, cx| {
1826            this.get_channel(channel_id.to_proto(), cx).unwrap()
1827        });
1828        channel_b.read_with(&cx_b, |channel, _| assert!(channel.messages().is_empty()));
1829        channel_b
1830            .condition(&cx_b, |channel, _| {
1831                channel_messages(channel)
1832                    == [("user_b".to_string(), "hello A, it's B.".to_string(), false)]
1833            })
1834            .await;
1835
1836        channel_a
1837            .update(&mut cx_a, |channel, cx| {
1838                channel
1839                    .send_message("oh, hi B.".to_string(), cx)
1840                    .unwrap()
1841                    .detach();
1842                let task = channel.send_message("sup".to_string(), cx).unwrap();
1843                assert_eq!(
1844                    channel_messages(channel),
1845                    &[
1846                        ("user_b".to_string(), "hello A, it's B.".to_string(), false),
1847                        ("user_a".to_string(), "oh, hi B.".to_string(), true),
1848                        ("user_a".to_string(), "sup".to_string(), true)
1849                    ]
1850                );
1851                task
1852            })
1853            .await
1854            .unwrap();
1855
1856        channel_b
1857            .condition(&cx_b, |channel, _| {
1858                channel_messages(channel)
1859                    == [
1860                        ("user_b".to_string(), "hello A, it's B.".to_string(), false),
1861                        ("user_a".to_string(), "oh, hi B.".to_string(), false),
1862                        ("user_a".to_string(), "sup".to_string(), false),
1863                    ]
1864            })
1865            .await;
1866
1867        assert_eq!(
1868            server.state().await.channels[&channel_id]
1869                .connection_ids
1870                .len(),
1871            2
1872        );
1873        cx_b.update(|_| drop(channel_b));
1874        server
1875            .condition(|state| state.channels[&channel_id].connection_ids.len() == 1)
1876            .await;
1877
1878        cx_a.update(|_| drop(channel_a));
1879        server
1880            .condition(|state| !state.channels.contains_key(&channel_id))
1881            .await;
1882    }
1883
1884    #[gpui::test]
1885    async fn test_chat_message_validation(mut cx_a: TestAppContext) {
1886        cx_a.foreground().forbid_parking();
1887
1888        let mut server = TestServer::start().await;
1889        let (client_a, user_store_a) = server.create_client(&mut cx_a, "user_a").await;
1890
1891        let db = &server.app_state.db;
1892        let org_id = db.create_org("Test Org", "test-org").await.unwrap();
1893        let channel_id = db.create_org_channel(org_id, "test-channel").await.unwrap();
1894        db.add_org_member(org_id, current_user_id(&user_store_a, &cx_a), false)
1895            .await
1896            .unwrap();
1897        db.add_channel_member(channel_id, current_user_id(&user_store_a, &cx_a), false)
1898            .await
1899            .unwrap();
1900
1901        let channels_a = cx_a.add_model(|cx| ChannelList::new(user_store_a, client_a, cx));
1902        channels_a
1903            .condition(&mut cx_a, |list, _| list.available_channels().is_some())
1904            .await;
1905        let channel_a = channels_a.update(&mut cx_a, |this, cx| {
1906            this.get_channel(channel_id.to_proto(), cx).unwrap()
1907        });
1908
1909        // Messages aren't allowed to be too long.
1910        channel_a
1911            .update(&mut cx_a, |channel, cx| {
1912                let long_body = "this is long.\n".repeat(1024);
1913                channel.send_message(long_body, cx).unwrap()
1914            })
1915            .await
1916            .unwrap_err();
1917
1918        // Messages aren't allowed to be blank.
1919        channel_a.update(&mut cx_a, |channel, cx| {
1920            channel.send_message(String::new(), cx).unwrap_err()
1921        });
1922
1923        // Leading and trailing whitespace are trimmed.
1924        channel_a
1925            .update(&mut cx_a, |channel, cx| {
1926                channel
1927                    .send_message("\n surrounded by whitespace  \n".to_string(), cx)
1928                    .unwrap()
1929            })
1930            .await
1931            .unwrap();
1932        assert_eq!(
1933            db.get_channel_messages(channel_id, 10, None)
1934                .await
1935                .unwrap()
1936                .iter()
1937                .map(|m| &m.body)
1938                .collect::<Vec<_>>(),
1939            &["surrounded by whitespace"]
1940        );
1941    }
1942
1943    #[gpui::test]
1944    async fn test_chat_reconnection(mut cx_a: TestAppContext, mut cx_b: TestAppContext) {
1945        cx_a.foreground().forbid_parking();
1946
1947        // Connect to a server as 2 clients.
1948        let mut server = TestServer::start().await;
1949        let (client_a, user_store_a) = server.create_client(&mut cx_a, "user_a").await;
1950        let (client_b, user_store_b) = server.create_client(&mut cx_b, "user_b").await;
1951        let mut status_b = client_b.status();
1952
1953        // Create an org that includes these 2 users.
1954        let db = &server.app_state.db;
1955        let org_id = db.create_org("Test Org", "test-org").await.unwrap();
1956        db.add_org_member(org_id, current_user_id(&user_store_a, &cx_a), false)
1957            .await
1958            .unwrap();
1959        db.add_org_member(org_id, current_user_id(&user_store_b, &cx_b), false)
1960            .await
1961            .unwrap();
1962
1963        // Create a channel that includes all the users.
1964        let channel_id = db.create_org_channel(org_id, "test-channel").await.unwrap();
1965        db.add_channel_member(channel_id, current_user_id(&user_store_a, &cx_a), false)
1966            .await
1967            .unwrap();
1968        db.add_channel_member(channel_id, current_user_id(&user_store_b, &cx_b), false)
1969            .await
1970            .unwrap();
1971        db.create_channel_message(
1972            channel_id,
1973            current_user_id(&user_store_b, &cx_b),
1974            "hello A, it's B.",
1975            OffsetDateTime::now_utc(),
1976            2,
1977        )
1978        .await
1979        .unwrap();
1980
1981        let channels_a = cx_a.add_model(|cx| ChannelList::new(user_store_a, client_a, cx));
1982        channels_a
1983            .condition(&mut cx_a, |list, _| list.available_channels().is_some())
1984            .await;
1985
1986        channels_a.read_with(&cx_a, |list, _| {
1987            assert_eq!(
1988                list.available_channels().unwrap(),
1989                &[ChannelDetails {
1990                    id: channel_id.to_proto(),
1991                    name: "test-channel".to_string()
1992                }]
1993            )
1994        });
1995        let channel_a = channels_a.update(&mut cx_a, |this, cx| {
1996            this.get_channel(channel_id.to_proto(), cx).unwrap()
1997        });
1998        channel_a.read_with(&cx_a, |channel, _| assert!(channel.messages().is_empty()));
1999        channel_a
2000            .condition(&cx_a, |channel, _| {
2001                channel_messages(channel)
2002                    == [("user_b".to_string(), "hello A, it's B.".to_string(), false)]
2003            })
2004            .await;
2005
2006        let channels_b = cx_b.add_model(|cx| ChannelList::new(user_store_b.clone(), client_b, cx));
2007        channels_b
2008            .condition(&mut cx_b, |list, _| list.available_channels().is_some())
2009            .await;
2010        channels_b.read_with(&cx_b, |list, _| {
2011            assert_eq!(
2012                list.available_channels().unwrap(),
2013                &[ChannelDetails {
2014                    id: channel_id.to_proto(),
2015                    name: "test-channel".to_string()
2016                }]
2017            )
2018        });
2019
2020        let channel_b = channels_b.update(&mut cx_b, |this, cx| {
2021            this.get_channel(channel_id.to_proto(), cx).unwrap()
2022        });
2023        channel_b.read_with(&cx_b, |channel, _| assert!(channel.messages().is_empty()));
2024        channel_b
2025            .condition(&cx_b, |channel, _| {
2026                channel_messages(channel)
2027                    == [("user_b".to_string(), "hello A, it's B.".to_string(), false)]
2028            })
2029            .await;
2030
2031        // Disconnect client B, ensuring we can still access its cached channel data.
2032        server.forbid_connections();
2033        server.disconnect_client(current_user_id(&user_store_b, &cx_b));
2034        while !matches!(
2035            status_b.recv().await,
2036            Some(rpc::Status::ReconnectionError { .. })
2037        ) {}
2038
2039        channels_b.read_with(&cx_b, |channels, _| {
2040            assert_eq!(
2041                channels.available_channels().unwrap(),
2042                [ChannelDetails {
2043                    id: channel_id.to_proto(),
2044                    name: "test-channel".to_string()
2045                }]
2046            )
2047        });
2048        channel_b.read_with(&cx_b, |channel, _| {
2049            assert_eq!(
2050                channel_messages(channel),
2051                [("user_b".to_string(), "hello A, it's B.".to_string(), false)]
2052            )
2053        });
2054
2055        // Send a message from client B while it is disconnected.
2056        channel_b
2057            .update(&mut cx_b, |channel, cx| {
2058                let task = channel
2059                    .send_message("can you see this?".to_string(), cx)
2060                    .unwrap();
2061                assert_eq!(
2062                    channel_messages(channel),
2063                    &[
2064                        ("user_b".to_string(), "hello A, it's B.".to_string(), false),
2065                        ("user_b".to_string(), "can you see this?".to_string(), true)
2066                    ]
2067                );
2068                task
2069            })
2070            .await
2071            .unwrap_err();
2072
2073        // Send a message from client A while B is disconnected.
2074        channel_a
2075            .update(&mut cx_a, |channel, cx| {
2076                channel
2077                    .send_message("oh, hi B.".to_string(), cx)
2078                    .unwrap()
2079                    .detach();
2080                let task = channel.send_message("sup".to_string(), cx).unwrap();
2081                assert_eq!(
2082                    channel_messages(channel),
2083                    &[
2084                        ("user_b".to_string(), "hello A, it's B.".to_string(), false),
2085                        ("user_a".to_string(), "oh, hi B.".to_string(), true),
2086                        ("user_a".to_string(), "sup".to_string(), true)
2087                    ]
2088                );
2089                task
2090            })
2091            .await
2092            .unwrap();
2093
2094        // Give client B a chance to reconnect.
2095        server.allow_connections();
2096        cx_b.foreground().advance_clock(Duration::from_secs(10));
2097
2098        // Verify that B sees the new messages upon reconnection, as well as the message client B
2099        // sent while offline.
2100        channel_b
2101            .condition(&cx_b, |channel, _| {
2102                channel_messages(channel)
2103                    == [
2104                        ("user_b".to_string(), "hello A, it's B.".to_string(), false),
2105                        ("user_a".to_string(), "oh, hi B.".to_string(), false),
2106                        ("user_a".to_string(), "sup".to_string(), false),
2107                        ("user_b".to_string(), "can you see this?".to_string(), false),
2108                    ]
2109            })
2110            .await;
2111
2112        // Ensure client A and B can communicate normally after reconnection.
2113        channel_a
2114            .update(&mut cx_a, |channel, cx| {
2115                channel.send_message("you online?".to_string(), cx).unwrap()
2116            })
2117            .await
2118            .unwrap();
2119        channel_b
2120            .condition(&cx_b, |channel, _| {
2121                channel_messages(channel)
2122                    == [
2123                        ("user_b".to_string(), "hello A, it's B.".to_string(), false),
2124                        ("user_a".to_string(), "oh, hi B.".to_string(), false),
2125                        ("user_a".to_string(), "sup".to_string(), false),
2126                        ("user_b".to_string(), "can you see this?".to_string(), false),
2127                        ("user_a".to_string(), "you online?".to_string(), false),
2128                    ]
2129            })
2130            .await;
2131
2132        channel_b
2133            .update(&mut cx_b, |channel, cx| {
2134                channel.send_message("yep".to_string(), cx).unwrap()
2135            })
2136            .await
2137            .unwrap();
2138        channel_a
2139            .condition(&cx_a, |channel, _| {
2140                channel_messages(channel)
2141                    == [
2142                        ("user_b".to_string(), "hello A, it's B.".to_string(), false),
2143                        ("user_a".to_string(), "oh, hi B.".to_string(), false),
2144                        ("user_a".to_string(), "sup".to_string(), false),
2145                        ("user_b".to_string(), "can you see this?".to_string(), false),
2146                        ("user_a".to_string(), "you online?".to_string(), false),
2147                        ("user_b".to_string(), "yep".to_string(), false),
2148                    ]
2149            })
2150            .await;
2151    }
2152
    // Test harness that owns a `Server` plus the pieces needed to connect fake clients to it
    // and to observe or disrupt their connections.
    struct TestServer {
        // Peer handed to `Server::new`; reset when the `TestServer` is dropped.
        peer: Arc<Peer>,
        // Application state built by `build_app_state`; tests reach the database through it.
        app_state: Arc<AppState>,
        // The server under test.
        server: Arc<Server>,
        // Receiving half of the channel whose sender is passed to `Server::new`; `condition`
        // waits on it between predicate checks.
        notifications: mpsc::Receiver<()>,
        // Kill handles for the in-memory connections, keyed by user; `disconnect_client`
        // removes and fires them.
        connection_killers: Arc<Mutex<HashMap<UserId, watch::Sender<Option<()>>>>>,
        // While set, the connection hook installed by `create_client` returns an error
        // instead of opening a connection.
        forbid_connections: Arc<AtomicBool>,
        // Test database the app state was built from.
        // NOTE(review): presumably held only so the database outlives the server — confirm.
        _test_db: TestDb,
    }
2162
2163    impl TestServer {
2164        async fn start() -> Self {
2165            let test_db = TestDb::new();
2166            let app_state = Self::build_app_state(&test_db).await;
2167            let peer = Peer::new();
2168            let notifications = mpsc::channel(128);
2169            let server = Server::new(app_state.clone(), peer.clone(), Some(notifications.0));
2170            Self {
2171                peer,
2172                app_state,
2173                server,
2174                notifications: notifications.1,
2175                connection_killers: Default::default(),
2176                forbid_connections: Default::default(),
2177                _test_db: test_db,
2178            }
2179        }
2180
2181        async fn create_client(
2182            &mut self,
2183            cx: &mut TestAppContext,
2184            name: &str,
2185        ) -> (Arc<Client>, ModelHandle<UserStore>) {
2186            let user_id = self.app_state.db.create_user(name, false).await.unwrap();
2187            let client_name = name.to_string();
2188            let mut client = Client::new();
2189            let server = self.server.clone();
2190            let connection_killers = self.connection_killers.clone();
2191            let forbid_connections = self.forbid_connections.clone();
2192            Arc::get_mut(&mut client)
2193                .unwrap()
2194                .override_authenticate(move |cx| {
2195                    cx.spawn(|_| async move {
2196                        let access_token = "the-token".to_string();
2197                        Ok(Credentials {
2198                            user_id: user_id.0 as u64,
2199                            access_token,
2200                        })
2201                    })
2202                })
2203                .override_establish_connection(move |credentials, cx| {
2204                    assert_eq!(credentials.user_id, user_id.0 as u64);
2205                    assert_eq!(credentials.access_token, "the-token");
2206
2207                    let server = server.clone();
2208                    let connection_killers = connection_killers.clone();
2209                    let forbid_connections = forbid_connections.clone();
2210                    let client_name = client_name.clone();
2211                    cx.spawn(move |cx| async move {
2212                        if forbid_connections.load(SeqCst) {
2213                            Err(EstablishConnectionError::other(anyhow!(
2214                                "server is forbidding connections"
2215                            )))
2216                        } else {
2217                            let (client_conn, server_conn, kill_conn) = Connection::in_memory();
2218                            connection_killers.lock().insert(user_id, kill_conn);
2219                            cx.background()
2220                                .spawn(server.handle_connection(server_conn, client_name, user_id))
2221                                .detach();
2222                            Ok(client_conn)
2223                        }
2224                    })
2225                });
2226
2227            let http = FakeHttpClient::new(|_| async move { Ok(surf::http::Response::new(404)) });
2228            client
2229                .authenticate_and_connect(&cx.to_async())
2230                .await
2231                .unwrap();
2232
2233            let user_store = cx.add_model(|cx| UserStore::new(client.clone(), http, cx));
2234            let mut authed_user =
2235                user_store.read_with(cx, |user_store, _| user_store.watch_current_user());
2236            while authed_user.recv().await.unwrap().is_none() {}
2237
2238            (client, user_store)
2239        }
2240
2241        fn disconnect_client(&self, user_id: UserId) {
2242            if let Some(mut kill_conn) = self.connection_killers.lock().remove(&user_id) {
2243                let _ = kill_conn.try_send(Some(()));
2244            }
2245        }
2246
2247        fn forbid_connections(&self) {
2248            self.forbid_connections.store(true, SeqCst);
2249        }
2250
2251        fn allow_connections(&self) {
2252            self.forbid_connections.store(false, SeqCst);
2253        }
2254
2255        async fn build_app_state(test_db: &TestDb) -> Arc<AppState> {
2256            let mut config = Config::default();
2257            config.session_secret = "a".repeat(32);
2258            config.database_url = test_db.url.clone();
2259            let github_client = github::AppClient::test();
2260            Arc::new(AppState {
2261                db: test_db.db().clone(),
2262                handlebars: Default::default(),
2263                auth_client: auth::build_client("", ""),
2264                repo_client: github::RepoClient::test(&github_client),
2265                github_client,
2266                config,
2267            })
2268        }
2269
2270        async fn state<'a>(&'a self) -> RwLockReadGuard<'a, ServerState> {
2271            self.server.state.read().await
2272        }
2273
2274        async fn condition<F>(&mut self, mut predicate: F)
2275        where
2276            F: FnMut(&ServerState) -> bool,
2277        {
2278            async_std::future::timeout(Duration::from_millis(500), async {
2279                while !(predicate)(&*self.server.state.read().await) {
2280                    self.notifications.recv().await;
2281                }
2282            })
2283            .await
2284            .expect("condition timed out");
2285        }
2286    }
2287
2288    impl Drop for TestServer {
2289        fn drop(&mut self) {
2290            task::block_on(self.peer.reset());
2291        }
2292    }
2293
2294    fn current_user_id(user_store: &ModelHandle<UserStore>, cx: &TestAppContext) -> UserId {
2295        UserId::from_proto(
2296            user_store.read_with(cx, |user_store, _| user_store.current_user().unwrap().id),
2297        )
2298    }
2299
2300    fn channel_messages(channel: &Channel) -> Vec<(String, String, bool)> {
2301        channel
2302            .messages()
2303            .cursor::<(), ()>()
2304            .map(|m| {
2305                (
2306                    m.sender.github_login.clone(),
2307                    m.body.clone(),
2308                    m.is_pending(),
2309                )
2310            })
2311            .collect()
2312    }
2313
2314    struct EmptyView;
2315
2316    impl gpui::Entity for EmptyView {
2317        type Event = ();
2318    }
2319
2320    impl gpui::View for EmptyView {
2321        fn ui_name() -> &'static str {
2322            "empty view"
2323        }
2324
2325        fn render(&mut self, _: &mut gpui::RenderContext<Self>) -> gpui::ElementBox {
2326            gpui::Element::boxed(gpui::elements::Empty)
2327        }
2328    }
2329}