rpc.rs

   1use super::{
   2    auth,
   3    db::{ChannelId, MessageId, UserId},
   4    AppState,
   5};
   6use anyhow::anyhow;
   7use async_std::{sync::RwLock, task};
   8use async_tungstenite::{tungstenite::protocol::Role, WebSocketStream};
   9use futures::{future::BoxFuture, FutureExt};
  10use postage::{mpsc, prelude::Sink as _, prelude::Stream as _};
  11use sha1::{Digest as _, Sha1};
  12use std::{
  13    any::TypeId,
  14    collections::{hash_map, HashMap, HashSet},
  15    future::Future,
  16    mem,
  17    sync::Arc,
  18    time::Instant,
  19};
  20use surf::StatusCode;
  21use tide::log;
  22use tide::{
  23    http::headers::{HeaderName, CONNECTION, UPGRADE},
  24    Request, Response,
  25};
  26use time::OffsetDateTime;
  27use zrpc::{
  28    proto::{self, AnyTypedEnvelope, EnvelopedMessage},
  29    Connection, ConnectionId, Peer, TypedEnvelope,
  30};
  31
/// Identifier assigned to each participant's replica of a shared worktree.
/// Guests are tracked by replica id in `WorktreeShare`; the host is reported
/// to joining guests with replica id 0.
type ReplicaId = u16;
  33
  34type MessageHandler = Box<
  35    dyn Send
  36        + Sync
  37        + Fn(Arc<Server>, Box<dyn AnyTypedEnvelope>) -> BoxFuture<'static, tide::Result<()>>,
  38>;
  39
/// The RPC server: owns the peer used to talk to clients, the shared
/// in-memory collaboration state, and the table of message handlers.
pub struct Server {
    /// Transport-level peer used to send, respond to and forward messages.
    peer: Arc<Peer>,
    /// In-memory bookkeeping of connections, worktrees and channels.
    state: RwLock<ServerState>,
    /// Application-wide state; handlers use it to reach the database.
    app_state: Arc<AppState>,
    /// Handlers keyed by the `TypeId` of the message type they accept.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Optional channel that receives a `()` after each dispatched message.
    notifications: Option<mpsc::Sender<()>>,
}
  47
/// Mutable server bookkeeping, guarded by `Server::state`.
#[derive(Default)]
struct ServerState {
    /// Per-connection state, keyed by connection id.
    connections: HashMap<ConnectionId, ConnectionState>,
    /// All live connection ids belonging to each user.
    connections_by_user_id: HashMap<UserId, HashSet<ConnectionId>>,
    /// Open worktrees keyed by worktree id.
    pub worktrees: HashMap<u64, Worktree>,
    /// Ids of the worktrees each user is shown in collaborator updates.
    visible_worktrees_by_user_id: HashMap<UserId, HashSet<u64>>,
    /// Chat channels that currently have state on this server.
    channels: HashMap<ChannelId, Channel>,
    // NOTE(review): presumably the id handed to the next opened worktree;
    // the code that reads it is not visible here, so confirm.
    next_worktree_id: u64,
}
  57
/// State tracked for a single client connection.
struct ConnectionState {
    /// The user this connection was registered under.
    user_id: UserId,
    /// Worktree ids associated with this connection; each is closed when the
    /// connection is removed.
    worktrees: HashSet<u64>,
    /// Channels this connection is removed from when it goes away.
    channels: HashSet<ChannelId>,
}
  63
/// A worktree opened by a host connection.
struct Worktree {
    /// Connection that opened the worktree.
    host_connection_id: ConnectionId,
    /// Users notified about this worktree; `open_worktree` includes the
    /// host's own user id.
    collaborator_user_ids: Vec<UserId>,
    /// Root name supplied by the host when opening the worktree.
    root_name: String,
    /// Sharing state; `None` until the worktree is shared, and reset to
    /// `None` when it is unshared.
    share: Option<WorktreeShare>,
}
  70
/// State that exists only while a worktree is shared.
struct WorktreeShare {
    /// Guest connections and the replica id assigned to each.
    guest_connection_ids: HashMap<ConnectionId, ReplicaId>,
    /// Replica ids currently in use; a guest's id is removed when it leaves.
    active_replica_ids: HashSet<ReplicaId>,
    /// Worktree entries keyed by entry id.
    entries: HashMap<u64, proto::Entry>,
}
  76
/// In-memory state for one chat channel.
#[derive(Default)]
struct Channel {
    /// Connections tracked for this channel; a connection is dropped from
    /// the set when it is removed from the server.
    connection_ids: HashSet<ConnectionId>,
}
  81
/// Number of channel messages requested per page of history; a page shorter
/// than this is reported to the client as the last one (`done`).
const MESSAGE_COUNT_PER_PAGE: usize = 100;
/// Upper bound on a chat message body, compared against the trimmed body's
/// `len()` (i.e. bytes, not characters).
const MAX_MESSAGE_LEN: usize = 1024;
  84
  85impl Server {
  86    pub fn new(
  87        app_state: Arc<AppState>,
  88        peer: Arc<Peer>,
  89        notifications: Option<mpsc::Sender<()>>,
  90    ) -> Arc<Self> {
  91        let mut server = Self {
  92            peer,
  93            app_state,
  94            state: Default::default(),
  95            handlers: Default::default(),
  96            notifications,
  97        };
  98
  99        server
 100            .add_handler(Server::ping)
 101            .add_handler(Server::open_worktree)
 102            .add_handler(Server::handle_close_worktree)
 103            .add_handler(Server::share_worktree)
 104            .add_handler(Server::unshare_worktree)
 105            .add_handler(Server::join_worktree)
 106            .add_handler(Server::update_worktree)
 107            .add_handler(Server::open_buffer)
 108            .add_handler(Server::close_buffer)
 109            .add_handler(Server::update_buffer)
 110            .add_handler(Server::buffer_saved)
 111            .add_handler(Server::save_buffer)
 112            .add_handler(Server::get_channels)
 113            .add_handler(Server::get_users)
 114            .add_handler(Server::join_channel)
 115            .add_handler(Server::leave_channel)
 116            .add_handler(Server::send_channel_message)
 117            .add_handler(Server::get_channel_messages);
 118
 119        Arc::new(server)
 120    }
 121
 122    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 123    where
 124        F: 'static + Send + Sync + Fn(Arc<Self>, TypedEnvelope<M>) -> Fut,
 125        Fut: 'static + Send + Future<Output = tide::Result<()>>,
 126        M: EnvelopedMessage,
 127    {
 128        let prev_handler = self.handlers.insert(
 129            TypeId::of::<M>(),
 130            Box::new(move |server, envelope| {
 131                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 132                (handler)(server, *envelope).boxed()
 133            }),
 134        );
 135        if prev_handler.is_some() {
 136            panic!("registered a handler for the same message twice");
 137        }
 138        self
 139    }
 140
    /// Returns a future that services one client connection until it closes.
    ///
    /// The future registers the connection with the peer and the server
    /// state, pushes an initial collaborator update to `user_id`, then
    /// dispatches incoming messages to the registered handlers one at a time
    /// (each handler is awaited before the next message is read). When the
    /// incoming stream ends or the I/O future completes, the connection is
    /// signed out. Errors along the way are logged rather than returned;
    /// `addr` is only used in log messages.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        addr: String,
        user_id: UserId,
    ) -> impl Future<Output = ()> {
        // Clone the Arc so the returned future owns its own handle.
        let this = self.clone();
        async move {
            let (connection_id, handle_io, mut incoming_rx) =
                this.peer.add_connection(connection).await;
            this.add_connection(connection_id, user_id).await;
            if let Err(err) = this.update_collaborators_for_users(&[user_id]).await {
                log::error!("error updating collaborators for {:?}: {}", user_id, err);
            }

            // The I/O future is fused and pinned once, outside the loop, so
            // the same future is polled on every iteration.
            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);
            loop {
                let next_message = incoming_rx.recv().fuse();
                futures::pin_mut!(next_message);
                // The message branch is listed first in the biased select.
                futures::select_biased! {
                    message = next_message => {
                        if let Some(message) = message {
                            let start_time = Instant::now();
                            log::info!("RPC message received: {}", message.payload_type_name());
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                if let Err(err) = (handler)(this.clone(), message).await {
                                    log::error!("error handling message: {:?}", err);
                                } else {
                                    log::info!("RPC message handled. duration:{:?}", start_time.elapsed());
                                }

                                // Signal observers after every dispatched
                                // message, whether or not the handler failed;
                                // a failed send is deliberately ignored.
                                if let Some(mut notifications) = this.notifications.clone() {
                                    let _ = notifications.send(()).await;
                                }
                            } else {
                                log::warn!("unhandled message: {}", message.payload_type_name());
                            }
                        } else {
                            // `None` means the incoming stream has ended.
                            log::info!("rpc connection closed {:?}", addr);
                            break;
                        }
                    }
                    handle_io = handle_io => {
                        // The I/O future finishing, with or without an error,
                        // ends the connection.
                        if let Err(err) = handle_io {
                            log::error!("error handling rpc connection {:?} - {:?}", addr, err);
                        }
                        break;
                    }
                }
            }

            if let Err(err) = this.sign_out(connection_id).await {
                log::error!("error signing out connection {:?} - {:?}", addr, err);
            }
        }
    }
 198
 199    async fn sign_out(self: &Arc<Self>, connection_id: zrpc::ConnectionId) -> tide::Result<()> {
 200        self.peer.disconnect(connection_id).await;
 201        self.remove_connection(connection_id).await?;
 202        Ok(())
 203    }
 204
 205    // Add a new connection associated with a given user.
 206    async fn add_connection(&self, connection_id: ConnectionId, user_id: UserId) {
 207        let mut state = self.state.write().await;
 208        state.connections.insert(
 209            connection_id,
 210            ConnectionState {
 211                user_id,
 212                worktrees: Default::default(),
 213                channels: Default::default(),
 214            },
 215        );
 216        state
 217            .connections_by_user_id
 218            .entry(user_id)
 219            .or_default()
 220            .insert(connection_id);
 221    }
 222
 223    // Remove the given connection and its association with any worktrees.
 224    async fn remove_connection(
 225        self: &Arc<Server>,
 226        connection_id: ConnectionId,
 227    ) -> tide::Result<()> {
 228        let mut worktree_ids = Vec::new();
 229        let mut state = self.state.write().await;
 230        if let Some(connection) = state.connections.remove(&connection_id) {
 231            worktree_ids = connection.worktrees.into_iter().collect();
 232
 233            for channel_id in connection.channels {
 234                if let Some(channel) = state.channels.get_mut(&channel_id) {
 235                    channel.connection_ids.remove(&connection_id);
 236                }
 237            }
 238
 239            let user_connections = state
 240                .connections_by_user_id
 241                .get_mut(&connection.user_id)
 242                .unwrap();
 243            user_connections.remove(&connection_id);
 244            if user_connections.is_empty() {
 245                state.connections_by_user_id.remove(&connection.user_id);
 246            }
 247        }
 248
 249        drop(state);
 250        for worktree_id in worktree_ids {
 251            self.close_worktree(worktree_id, connection_id).await?;
 252        }
 253
 254        Ok(())
 255    }
 256
 257    async fn ping(self: Arc<Server>, request: TypedEnvelope<proto::Ping>) -> tide::Result<()> {
 258        self.peer.respond(request.receipt(), proto::Ack {}).await?;
 259        Ok(())
 260    }
 261
 262    async fn open_worktree(
 263        self: Arc<Server>,
 264        request: TypedEnvelope<proto::OpenWorktree>,
 265    ) -> tide::Result<()> {
 266        let receipt = request.receipt();
 267        let host_user_id = self
 268            .state
 269            .read()
 270            .await
 271            .user_id_for_connection(request.sender_id)?;
 272
 273        let mut collaborator_user_ids = HashSet::new();
 274        collaborator_user_ids.insert(host_user_id);
 275        for github_login in request.payload.collaborator_logins {
 276            match self.app_state.db.create_user(&github_login, false).await {
 277                Ok(collaborator_user_id) => {
 278                    collaborator_user_ids.insert(collaborator_user_id);
 279                }
 280                Err(err) => {
 281                    let message = err.to_string();
 282                    self.peer
 283                        .respond_with_error(receipt, proto::Error { message })
 284                        .await?;
 285                    return Ok(());
 286                }
 287            }
 288        }
 289
 290        let collaborator_user_ids = collaborator_user_ids.into_iter().collect::<Vec<_>>();
 291        let worktree_id = self.state.write().await.add_worktree(Worktree {
 292            host_connection_id: request.sender_id,
 293            collaborator_user_ids: collaborator_user_ids.clone(),
 294            root_name: request.payload.root_name,
 295            share: None,
 296        });
 297
 298        self.peer
 299            .respond(receipt, proto::OpenWorktreeResponse { worktree_id })
 300            .await?;
 301        self.update_collaborators_for_users(&collaborator_user_ids)
 302            .await?;
 303
 304        Ok(())
 305    }
 306
 307    async fn share_worktree(
 308        self: Arc<Server>,
 309        mut request: TypedEnvelope<proto::ShareWorktree>,
 310    ) -> tide::Result<()> {
 311        let worktree = request
 312            .payload
 313            .worktree
 314            .as_mut()
 315            .ok_or_else(|| anyhow!("missing worktree"))?;
 316        let entries = mem::take(&mut worktree.entries)
 317            .into_iter()
 318            .map(|entry| (entry.id, entry))
 319            .collect();
 320
 321        let mut state = self.state.write().await;
 322        if let Some(worktree) = state.worktrees.get_mut(&worktree.id) {
 323            worktree.share = Some(WorktreeShare {
 324                guest_connection_ids: Default::default(),
 325                active_replica_ids: Default::default(),
 326                entries,
 327            });
 328            let collaborator_user_ids = worktree.collaborator_user_ids.clone();
 329
 330            drop(state);
 331            self.peer
 332                .respond(request.receipt(), proto::ShareWorktreeResponse {})
 333                .await?;
 334            self.update_collaborators_for_users(&collaborator_user_ids)
 335                .await?;
 336        } else {
 337            self.peer
 338                .respond_with_error(
 339                    request.receipt(),
 340                    proto::Error {
 341                        message: "no such worktree".to_string(),
 342                    },
 343                )
 344                .await?;
 345        }
 346        Ok(())
 347    }
 348
 349    async fn unshare_worktree(
 350        self: Arc<Server>,
 351        request: TypedEnvelope<proto::UnshareWorktree>,
 352    ) -> tide::Result<()> {
 353        let worktree_id = request.payload.worktree_id;
 354
 355        let connection_ids;
 356        let collaborator_user_ids;
 357        {
 358            let mut state = self.state.write().await;
 359            let worktree = state.write_worktree(worktree_id, request.sender_id)?;
 360            if worktree.host_connection_id != request.sender_id {
 361                return Err(anyhow!("no such worktree"))?;
 362            }
 363
 364            connection_ids = worktree.connection_ids();
 365            collaborator_user_ids = worktree.collaborator_user_ids.clone();
 366
 367            worktree.share.take();
 368            for connection_id in &connection_ids {
 369                if let Some(connection) = state.connections.get_mut(connection_id) {
 370                    connection.worktrees.remove(&worktree_id);
 371                }
 372            }
 373        }
 374
 375        broadcast(request.sender_id, connection_ids, |conn_id| {
 376            self.peer
 377                .send(conn_id, proto::UnshareWorktree { worktree_id })
 378        })
 379        .await?;
 380        self.update_collaborators_for_users(&collaborator_user_ids)
 381            .await?;
 382
 383        Ok(())
 384    }
 385
    /// Adds the sender to a worktree as a guest.
    ///
    /// On success the joiner receives the worktree's root name, its current
    /// entries, its own replica id and the list of existing peers (the host
    /// plus every other guest). The worktree's other connections are sent
    /// `AddPeer`, and all collaborators receive an updated collaborator list.
    /// If the state refuses the join, the client gets an error response
    /// carrying the error text and the handler returns `Ok(())`.
    async fn join_worktree(
        self: Arc<Server>,
        request: TypedEnvelope<proto::JoinWorktree>,
    ) -> tide::Result<()> {
        let worktree_id = request.payload.worktree_id;
        let user_id = self
            .state
            .read()
            .await
            .user_id_for_connection(request.sender_id)?;

        let response;
        let connection_ids;
        let collaborator_user_ids;
        let mut state = self.state.write().await;
        // NOTE(review): `ServerState::join_worktree` is not visible here; it
        // presumably performs the access check and assigns the replica id —
        // confirm.
        match state.join_worktree(request.sender_id, user_id, worktree_id) {
            Ok((peer_replica_id, worktree)) => {
                let share = worktree.share()?;
                let peer_count = share.guest_connection_ids.len();
                let mut peers = Vec::with_capacity(peer_count);
                // The host is always listed first, with replica id 0.
                peers.push(proto::Peer {
                    peer_id: worktree.host_connection_id.0,
                    replica_id: 0,
                });
                // List every guest except the connection that is joining.
                for (peer_conn_id, peer_replica_id) in &share.guest_connection_ids {
                    if *peer_conn_id != request.sender_id {
                        peers.push(proto::Peer {
                            peer_id: peer_conn_id.0,
                            replica_id: *peer_replica_id as u32,
                        });
                    }
                }
                response = proto::JoinWorktreeResponse {
                    worktree: Some(proto::Worktree {
                        id: worktree_id,
                        root_name: worktree.root_name.clone(),
                        entries: share.entries.values().cloned().collect(),
                    }),
                    replica_id: peer_replica_id as u32,
                    peers,
                };
                connection_ids = worktree.connection_ids();
                collaborator_user_ids = worktree.collaborator_user_ids.clone();
            }
            Err(error) => {
                // The write guard is still alive while this response is sent.
                self.peer
                    .respond_with_error(
                        request.receipt(),
                        proto::Error {
                            message: error.to_string(),
                        },
                    )
                    .await?;
                return Ok(());
            }
        }

        // Release the lock before the network sends below.
        drop(state);
        broadcast(request.sender_id, connection_ids, |conn_id| {
            self.peer.send(
                conn_id,
                proto::AddPeer {
                    worktree_id,
                    peer: Some(proto::Peer {
                        peer_id: request.sender_id.0,
                        replica_id: response.replica_id,
                    }),
                },
            )
        })
        .await?;
        self.peer.respond(request.receipt(), response).await?;
        self.update_collaborators_for_users(&collaborator_user_ids)
            .await?;

        Ok(())
    }
 463
 464    async fn handle_close_worktree(
 465        self: Arc<Server>,
 466        request: TypedEnvelope<proto::CloseWorktree>,
 467    ) -> tide::Result<()> {
 468        self.close_worktree(request.payload.worktree_id, request.sender_id)
 469            .await
 470    }
 471
 472    async fn close_worktree(
 473        self: &Arc<Server>,
 474        worktree_id: u64,
 475        sender_conn_id: ConnectionId,
 476    ) -> tide::Result<()> {
 477        let connection_ids;
 478        let collaborator_user_ids;
 479        let mut is_host = false;
 480        let mut is_guest = false;
 481        {
 482            let mut state = self.state.write().await;
 483            let worktree = state.write_worktree(worktree_id, sender_conn_id)?;
 484            connection_ids = worktree.connection_ids();
 485            collaborator_user_ids = worktree.collaborator_user_ids.clone();
 486
 487            if worktree.host_connection_id == sender_conn_id {
 488                is_host = true;
 489                state.remove_worktree(worktree_id);
 490            } else {
 491                let share = worktree.share_mut()?;
 492                if let Some(replica_id) = share.guest_connection_ids.remove(&sender_conn_id) {
 493                    is_guest = true;
 494                    share.active_replica_ids.remove(&replica_id);
 495                }
 496            }
 497        }
 498
 499        if is_host {
 500            broadcast(sender_conn_id, connection_ids, |conn_id| {
 501                self.peer
 502                    .send(conn_id, proto::UnshareWorktree { worktree_id })
 503            })
 504            .await?;
 505        } else if is_guest {
 506            broadcast(sender_conn_id, connection_ids, |conn_id| {
 507                self.peer.send(
 508                    conn_id,
 509                    proto::RemovePeer {
 510                        worktree_id,
 511                        peer_id: sender_conn_id.0,
 512                    },
 513                )
 514            })
 515            .await?
 516        }
 517        self.update_collaborators_for_users(&collaborator_user_ids)
 518            .await?;
 519        Ok(())
 520    }
 521
 522    async fn update_worktree(
 523        self: Arc<Server>,
 524        request: TypedEnvelope<proto::UpdateWorktree>,
 525    ) -> tide::Result<()> {
 526        {
 527            let mut state = self.state.write().await;
 528            let worktree = state.write_worktree(request.payload.worktree_id, request.sender_id)?;
 529            let share = worktree.share_mut()?;
 530
 531            for entry_id in &request.payload.removed_entries {
 532                share.entries.remove(&entry_id);
 533            }
 534
 535            for entry in &request.payload.updated_entries {
 536                share.entries.insert(entry.id, entry.clone());
 537            }
 538        }
 539
 540        self.broadcast_in_worktree(request.payload.worktree_id, &request)
 541            .await?;
 542        Ok(())
 543    }
 544
 545    async fn open_buffer(
 546        self: Arc<Server>,
 547        request: TypedEnvelope<proto::OpenBuffer>,
 548    ) -> tide::Result<()> {
 549        let receipt = request.receipt();
 550        let worktree_id = request.payload.worktree_id;
 551        let host_connection_id = self
 552            .state
 553            .read()
 554            .await
 555            .read_worktree(worktree_id, request.sender_id)?
 556            .host_connection_id;
 557
 558        let response = self
 559            .peer
 560            .forward_request(request.sender_id, host_connection_id, request.payload)
 561            .await?;
 562        self.peer.respond(receipt, response).await?;
 563        Ok(())
 564    }
 565
 566    async fn close_buffer(
 567        self: Arc<Server>,
 568        request: TypedEnvelope<proto::CloseBuffer>,
 569    ) -> tide::Result<()> {
 570        let host_connection_id = self
 571            .state
 572            .read()
 573            .await
 574            .read_worktree(request.payload.worktree_id, request.sender_id)?
 575            .host_connection_id;
 576
 577        self.peer
 578            .forward_send(request.sender_id, host_connection_id, request.payload)
 579            .await?;
 580
 581        Ok(())
 582    }
 583
 584    async fn save_buffer(
 585        self: Arc<Server>,
 586        request: TypedEnvelope<proto::SaveBuffer>,
 587    ) -> tide::Result<()> {
 588        let host;
 589        let guests;
 590        {
 591            let state = self.state.read().await;
 592            let worktree = state.read_worktree(request.payload.worktree_id, request.sender_id)?;
 593            host = worktree.host_connection_id;
 594            guests = worktree
 595                .share()?
 596                .guest_connection_ids
 597                .keys()
 598                .copied()
 599                .collect::<Vec<_>>();
 600        }
 601
 602        let sender = request.sender_id;
 603        let receipt = request.receipt();
 604        let response = self
 605            .peer
 606            .forward_request(sender, host, request.payload.clone())
 607            .await?;
 608
 609        broadcast(host, guests, |conn_id| {
 610            let response = response.clone();
 611            let peer = &self.peer;
 612            async move {
 613                if conn_id == sender {
 614                    peer.respond(receipt, response).await
 615                } else {
 616                    peer.forward_send(host, conn_id, response).await
 617                }
 618            }
 619        })
 620        .await?;
 621
 622        Ok(())
 623    }
 624
 625    async fn update_buffer(
 626        self: Arc<Server>,
 627        request: TypedEnvelope<proto::UpdateBuffer>,
 628    ) -> tide::Result<()> {
 629        self.broadcast_in_worktree(request.payload.worktree_id, &request)
 630            .await?;
 631        self.peer.respond(request.receipt(), proto::Ack {}).await?;
 632        Ok(())
 633    }
 634
 635    async fn buffer_saved(
 636        self: Arc<Server>,
 637        request: TypedEnvelope<proto::BufferSaved>,
 638    ) -> tide::Result<()> {
 639        self.broadcast_in_worktree(request.payload.worktree_id, &request)
 640            .await
 641    }
 642
 643    async fn get_channels(
 644        self: Arc<Server>,
 645        request: TypedEnvelope<proto::GetChannels>,
 646    ) -> tide::Result<()> {
 647        let user_id = self
 648            .state
 649            .read()
 650            .await
 651            .user_id_for_connection(request.sender_id)?;
 652        let channels = self.app_state.db.get_accessible_channels(user_id).await?;
 653        self.peer
 654            .respond(
 655                request.receipt(),
 656                proto::GetChannelsResponse {
 657                    channels: channels
 658                        .into_iter()
 659                        .map(|chan| proto::Channel {
 660                            id: chan.id.to_proto(),
 661                            name: chan.name,
 662                        })
 663                        .collect(),
 664                },
 665            )
 666            .await?;
 667        Ok(())
 668    }
 669
 670    async fn get_users(
 671        self: Arc<Server>,
 672        request: TypedEnvelope<proto::GetUsers>,
 673    ) -> tide::Result<()> {
 674        let receipt = request.receipt();
 675        let user_ids = request.payload.user_ids.into_iter().map(UserId::from_proto);
 676        let users = self
 677            .app_state
 678            .db
 679            .get_users_by_ids(user_ids)
 680            .await?
 681            .into_iter()
 682            .map(|user| proto::User {
 683                id: user.id.to_proto(),
 684                avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
 685                github_login: user.github_login,
 686            })
 687            .collect();
 688        self.peer
 689            .respond(receipt, proto::GetUsersResponse { users })
 690            .await?;
 691        Ok(())
 692    }
 693
    /// Sends each of the given users an up-to-date `UpdateCollaborators`
    /// message on every one of their connections.
    ///
    /// For each user, the worktrees visible to them are grouped by the user
    /// hosting them. Each worktree is described by its root name, whether it
    /// is currently shared, and the user ids of its guests.
    ///
    /// # Errors
    ///
    /// Fails if the user of a worktree's host or guest connection cannot be
    /// resolved, or if any send fails.
    async fn update_collaborators_for_users<'a>(
        self: &Arc<Server>,
        user_ids: impl IntoIterator<Item = &'a UserId>,
    ) -> tide::Result<()> {
        let mut send_futures = Vec::new();

        let state = self.state.read().await;
        for user_id in user_ids {
            // Collaborators for this user, keyed by the hosting user's id.
            let mut collaborators = HashMap::new();
            for worktree_id in state
                .visible_worktrees_by_user_id
                .get(&user_id)
                .unwrap_or(&HashSet::new())
            {
                // Indexing panics if a visible worktree id has no worktree.
                let worktree = &state.worktrees[worktree_id];

                // Guests exist only while the worktree is shared.
                let mut participants = HashSet::new();
                if let Ok(share) = worktree.share() {
                    for guest_connection_id in share.guest_connection_ids.keys() {
                        let user_id = state.user_id_for_connection(*guest_connection_id)?;
                        participants.insert(user_id.to_proto());
                    }
                }

                let host_user_id = state.user_id_for_connection(worktree.host_connection_id)?;
                let host =
                    collaborators
                        .entry(host_user_id)
                        .or_insert_with(|| proto::Collaborator {
                            user_id: host_user_id.to_proto(),
                            worktrees: Vec::new(),
                        });
                host.worktrees.push(proto::WorktreeMetadata {
                    root_name: worktree.root_name.clone(),
                    is_shared: worktree.share().is_ok(),
                    participants: participants.into_iter().collect(),
                });
            }

            let collaborators = collaborators.into_values().collect::<Vec<_>>();
            // The send futures are only collected here; they are awaited
            // after the read guard has been dropped.
            for connection_id in state.user_connection_ids(*user_id) {
                send_futures.push(self.peer.send(
                    connection_id,
                    proto::UpdateCollaborators {
                        collaborators: collaborators.clone(),
                    },
                ));
            }
        }

        drop(state);
        futures::future::try_join_all(send_futures).await?;

        Ok(())
    }
 749
 750    async fn join_channel(
 751        self: Arc<Self>,
 752        request: TypedEnvelope<proto::JoinChannel>,
 753    ) -> tide::Result<()> {
 754        let user_id = self
 755            .state
 756            .read()
 757            .await
 758            .user_id_for_connection(request.sender_id)?;
 759        let channel_id = ChannelId::from_proto(request.payload.channel_id);
 760        if !self
 761            .app_state
 762            .db
 763            .can_user_access_channel(user_id, channel_id)
 764            .await?
 765        {
 766            Err(anyhow!("access denied"))?;
 767        }
 768
 769        self.state
 770            .write()
 771            .await
 772            .join_channel(request.sender_id, channel_id);
 773        let messages = self
 774            .app_state
 775            .db
 776            .get_channel_messages(channel_id, MESSAGE_COUNT_PER_PAGE, None)
 777            .await?
 778            .into_iter()
 779            .map(|msg| proto::ChannelMessage {
 780                id: msg.id.to_proto(),
 781                body: msg.body,
 782                timestamp: msg.sent_at.unix_timestamp() as u64,
 783                sender_id: msg.sender_id.to_proto(),
 784                nonce: Some(msg.nonce.as_u128().into()),
 785            })
 786            .collect::<Vec<_>>();
 787        self.peer
 788            .respond(
 789                request.receipt(),
 790                proto::JoinChannelResponse {
 791                    done: messages.len() < MESSAGE_COUNT_PER_PAGE,
 792                    messages,
 793                },
 794            )
 795            .await?;
 796        Ok(())
 797    }
 798
 799    async fn leave_channel(
 800        self: Arc<Self>,
 801        request: TypedEnvelope<proto::LeaveChannel>,
 802    ) -> tide::Result<()> {
 803        let user_id = self
 804            .state
 805            .read()
 806            .await
 807            .user_id_for_connection(request.sender_id)?;
 808        let channel_id = ChannelId::from_proto(request.payload.channel_id);
 809        if !self
 810            .app_state
 811            .db
 812            .can_user_access_channel(user_id, channel_id)
 813            .await?
 814        {
 815            Err(anyhow!("access denied"))?;
 816        }
 817
 818        self.state
 819            .write()
 820            .await
 821            .leave_channel(request.sender_id, channel_id);
 822
 823        Ok(())
 824    }
 825
 826    async fn send_channel_message(
 827        self: Arc<Self>,
 828        request: TypedEnvelope<proto::SendChannelMessage>,
 829    ) -> tide::Result<()> {
 830        let receipt = request.receipt();
 831        let channel_id = ChannelId::from_proto(request.payload.channel_id);
 832        let user_id;
 833        let connection_ids;
 834        {
 835            let state = self.state.read().await;
 836            user_id = state.user_id_for_connection(request.sender_id)?;
 837            if let Some(channel) = state.channels.get(&channel_id) {
 838                connection_ids = channel.connection_ids();
 839            } else {
 840                return Ok(());
 841            }
 842        }
 843
 844        // Validate the message body.
 845        let body = request.payload.body.trim().to_string();
 846        if body.len() > MAX_MESSAGE_LEN {
 847            self.peer
 848                .respond_with_error(
 849                    receipt,
 850                    proto::Error {
 851                        message: "message is too long".to_string(),
 852                    },
 853                )
 854                .await?;
 855            return Ok(());
 856        }
 857        if body.is_empty() {
 858            self.peer
 859                .respond_with_error(
 860                    receipt,
 861                    proto::Error {
 862                        message: "message can't be blank".to_string(),
 863                    },
 864                )
 865                .await?;
 866            return Ok(());
 867        }
 868
 869        let timestamp = OffsetDateTime::now_utc();
 870        let nonce = if let Some(nonce) = request.payload.nonce {
 871            nonce
 872        } else {
 873            self.peer
 874                .respond_with_error(
 875                    receipt,
 876                    proto::Error {
 877                        message: "nonce can't be blank".to_string(),
 878                    },
 879                )
 880                .await?;
 881            return Ok(());
 882        };
 883
 884        let message_id = self
 885            .app_state
 886            .db
 887            .create_channel_message(channel_id, user_id, &body, timestamp, nonce.clone().into())
 888            .await?
 889            .to_proto();
 890        let message = proto::ChannelMessage {
 891            sender_id: user_id.to_proto(),
 892            id: message_id,
 893            body,
 894            timestamp: timestamp.unix_timestamp() as u64,
 895            nonce: Some(nonce),
 896        };
 897        broadcast(request.sender_id, connection_ids, |conn_id| {
 898            self.peer.send(
 899                conn_id,
 900                proto::ChannelMessageSent {
 901                    channel_id: channel_id.to_proto(),
 902                    message: Some(message.clone()),
 903                },
 904            )
 905        })
 906        .await?;
 907        self.peer
 908            .respond(
 909                receipt,
 910                proto::SendChannelMessageResponse {
 911                    message: Some(message),
 912                },
 913            )
 914            .await?;
 915        Ok(())
 916    }
 917
 918    async fn get_channel_messages(
 919        self: Arc<Self>,
 920        request: TypedEnvelope<proto::GetChannelMessages>,
 921    ) -> tide::Result<()> {
 922        let user_id = self
 923            .state
 924            .read()
 925            .await
 926            .user_id_for_connection(request.sender_id)?;
 927        let channel_id = ChannelId::from_proto(request.payload.channel_id);
 928        if !self
 929            .app_state
 930            .db
 931            .can_user_access_channel(user_id, channel_id)
 932            .await?
 933        {
 934            Err(anyhow!("access denied"))?;
 935        }
 936
 937        let messages = self
 938            .app_state
 939            .db
 940            .get_channel_messages(
 941                channel_id,
 942                MESSAGE_COUNT_PER_PAGE,
 943                Some(MessageId::from_proto(request.payload.before_message_id)),
 944            )
 945            .await?
 946            .into_iter()
 947            .map(|msg| proto::ChannelMessage {
 948                id: msg.id.to_proto(),
 949                body: msg.body,
 950                timestamp: msg.sent_at.unix_timestamp() as u64,
 951                sender_id: msg.sender_id.to_proto(),
 952                nonce: Some(msg.nonce.as_u128().into()),
 953            })
 954            .collect::<Vec<_>>();
 955        self.peer
 956            .respond(
 957                request.receipt(),
 958                proto::GetChannelMessagesResponse {
 959                    done: messages.len() < MESSAGE_COUNT_PER_PAGE,
 960                    messages,
 961                },
 962            )
 963            .await?;
 964        Ok(())
 965    }
 966
 967    async fn broadcast_in_worktree<T: proto::EnvelopedMessage>(
 968        &self,
 969        worktree_id: u64,
 970        message: &TypedEnvelope<T>,
 971    ) -> tide::Result<()> {
 972        let connection_ids = self
 973            .state
 974            .read()
 975            .await
 976            .read_worktree(worktree_id, message.sender_id)?
 977            .connection_ids();
 978
 979        broadcast(message.sender_id, connection_ids, |conn_id| {
 980            self.peer
 981                .forward_send(message.sender_id, conn_id, message.payload.clone())
 982        })
 983        .await?;
 984
 985        Ok(())
 986    }
 987}
 988
 989pub async fn broadcast<F, T>(
 990    sender_id: ConnectionId,
 991    receiver_ids: Vec<ConnectionId>,
 992    mut f: F,
 993) -> anyhow::Result<()>
 994where
 995    F: FnMut(ConnectionId) -> T,
 996    T: Future<Output = anyhow::Result<()>>,
 997{
 998    let futures = receiver_ids
 999        .into_iter()
1000        .filter(|id| *id != sender_id)
1001        .map(|id| f(id));
1002    futures::future::try_join_all(futures).await?;
1003    Ok(())
1004}
1005
1006impl ServerState {
1007    fn join_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
1008        if let Some(connection) = self.connections.get_mut(&connection_id) {
1009            connection.channels.insert(channel_id);
1010            self.channels
1011                .entry(channel_id)
1012                .or_default()
1013                .connection_ids
1014                .insert(connection_id);
1015        }
1016    }
1017
1018    fn leave_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
1019        if let Some(connection) = self.connections.get_mut(&connection_id) {
1020            connection.channels.remove(&channel_id);
1021            if let hash_map::Entry::Occupied(mut entry) = self.channels.entry(channel_id) {
1022                entry.get_mut().connection_ids.remove(&connection_id);
1023                if entry.get_mut().connection_ids.is_empty() {
1024                    entry.remove();
1025                }
1026            }
1027        }
1028    }
1029
1030    fn user_id_for_connection(&self, connection_id: ConnectionId) -> tide::Result<UserId> {
1031        Ok(self
1032            .connections
1033            .get(&connection_id)
1034            .ok_or_else(|| anyhow!("unknown connection"))?
1035            .user_id)
1036    }
1037
1038    fn user_connection_ids<'a>(
1039        &'a self,
1040        user_id: UserId,
1041    ) -> impl 'a + Iterator<Item = ConnectionId> {
1042        self.connections_by_user_id
1043            .get(&user_id)
1044            .into_iter()
1045            .flatten()
1046            .copied()
1047    }
1048
1049    // Add the given connection as a guest of the given worktree
1050    fn join_worktree(
1051        &mut self,
1052        connection_id: ConnectionId,
1053        user_id: UserId,
1054        worktree_id: u64,
1055    ) -> tide::Result<(ReplicaId, &Worktree)> {
1056        let connection = self
1057            .connections
1058            .get_mut(&connection_id)
1059            .ok_or_else(|| anyhow!("no such connection"))?;
1060        let worktree = self
1061            .worktrees
1062            .get_mut(&worktree_id)
1063            .ok_or_else(|| anyhow!("no such worktree"))?;
1064        if !worktree.collaborator_user_ids.contains(&user_id) {
1065            Err(anyhow!("no such worktree"))?;
1066        }
1067
1068        let share = worktree.share_mut()?;
1069        connection.worktrees.insert(worktree_id);
1070
1071        let mut replica_id = 1;
1072        while share.active_replica_ids.contains(&replica_id) {
1073            replica_id += 1;
1074        }
1075        share.active_replica_ids.insert(replica_id);
1076        share.guest_connection_ids.insert(connection_id, replica_id);
1077        return Ok((replica_id, worktree));
1078    }
1079
1080    fn read_worktree(
1081        &self,
1082        worktree_id: u64,
1083        connection_id: ConnectionId,
1084    ) -> tide::Result<&Worktree> {
1085        let worktree = self
1086            .worktrees
1087            .get(&worktree_id)
1088            .ok_or_else(|| anyhow!("worktree not found"))?;
1089
1090        if worktree.host_connection_id == connection_id
1091            || worktree
1092                .share()?
1093                .guest_connection_ids
1094                .contains_key(&connection_id)
1095        {
1096            Ok(worktree)
1097        } else {
1098            Err(anyhow!(
1099                "{} is not a member of worktree {}",
1100                connection_id,
1101                worktree_id
1102            ))?
1103        }
1104    }
1105
1106    fn write_worktree(
1107        &mut self,
1108        worktree_id: u64,
1109        connection_id: ConnectionId,
1110    ) -> tide::Result<&mut Worktree> {
1111        let worktree = self
1112            .worktrees
1113            .get_mut(&worktree_id)
1114            .ok_or_else(|| anyhow!("worktree not found"))?;
1115
1116        if worktree.host_connection_id == connection_id
1117            || worktree.share.as_ref().map_or(false, |share| {
1118                share.guest_connection_ids.contains_key(&connection_id)
1119            })
1120        {
1121            Ok(worktree)
1122        } else {
1123            Err(anyhow!(
1124                "{} is not a member of worktree {}",
1125                connection_id,
1126                worktree_id
1127            ))?
1128        }
1129    }
1130
1131    fn add_worktree(&mut self, worktree: Worktree) -> u64 {
1132        let worktree_id = self.next_worktree_id;
1133        for collaborator_user_id in &worktree.collaborator_user_ids {
1134            self.visible_worktrees_by_user_id
1135                .entry(*collaborator_user_id)
1136                .or_default()
1137                .insert(worktree_id);
1138        }
1139        self.next_worktree_id += 1;
1140        self.worktrees.insert(worktree_id, worktree);
1141        worktree_id
1142    }
1143
1144    fn remove_worktree(&mut self, worktree_id: u64) {
1145        let worktree = self.worktrees.remove(&worktree_id).unwrap();
1146        if let Some(connection) = self.connections.get_mut(&worktree.host_connection_id) {
1147            connection.worktrees.remove(&worktree_id);
1148        }
1149        if let Some(share) = worktree.share {
1150            for connection_id in share.guest_connection_ids.keys() {
1151                if let Some(connection) = self.connections.get_mut(connection_id) {
1152                    connection.worktrees.remove(&worktree_id);
1153                }
1154            }
1155        }
1156        for collaborator_user_id in worktree.collaborator_user_ids {
1157            if let Some(visible_worktrees) = self
1158                .visible_worktrees_by_user_id
1159                .get_mut(&collaborator_user_id)
1160            {
1161                visible_worktrees.remove(&worktree_id);
1162            }
1163        }
1164    }
1165}
1166
1167impl Worktree {
1168    pub fn connection_ids(&self) -> Vec<ConnectionId> {
1169        if let Some(share) = &self.share {
1170            share
1171                .guest_connection_ids
1172                .keys()
1173                .copied()
1174                .chain(Some(self.host_connection_id))
1175                .collect()
1176        } else {
1177            vec![self.host_connection_id]
1178        }
1179    }
1180
1181    fn share(&self) -> tide::Result<&WorktreeShare> {
1182        Ok(self
1183            .share
1184            .as_ref()
1185            .ok_or_else(|| anyhow!("worktree is not shared"))?)
1186    }
1187
1188    fn share_mut(&mut self) -> tide::Result<&mut WorktreeShare> {
1189        Ok(self
1190            .share
1191            .as_mut()
1192            .ok_or_else(|| anyhow!("worktree is not shared"))?)
1193    }
1194}
1195
1196impl Channel {
1197    fn connection_ids(&self) -> Vec<ConnectionId> {
1198        self.connection_ids.iter().copied().collect()
1199    }
1200}
1201
    /// Mounts the `/rpc` endpoint on `app`.
    ///
    /// The route runs behind `auth::VerifyToken`, performs the WebSocket
    /// opening handshake by hand, and then hands the upgraded stream to a
    /// `Server` built from the app state and `rpc`.
1202pub fn add_routes(app: &mut tide::Server<Arc<AppState>>, rpc: &Arc<Peer>) {
1203    let server = Server::new(app.state().clone(), rpc.clone(), None);
1204    app.at("/rpc").with(auth::VerifyToken).get(move |request: Request<Arc<AppState>>| {
            // Read the user id from the request extensions; presumably it is
            // put there by the auth middleware — see the error message below.
1205        let user_id = request.ext::<UserId>().copied();
1206        let server = server.clone();
1207        async move {
                // Fixed GUID that the WebSocket protocol (RFC 6455) appends to
                // the client's key when computing the accept hash.
1208            const WEBSOCKET_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
1209
1210            let connection_upgrade = header_contains_ignore_case(&request, CONNECTION, "upgrade");
1211            let upgrade_to_websocket = header_contains_ignore_case(&request, UPGRADE, "websocket");
1212            let upgrade_requested = connection_upgrade && upgrade_to_websocket;
1213
                // Anything other than a WebSocket upgrade request is answered
                // with "Upgrade Required".
1214            if !upgrade_requested {
1215                return Ok(Response::new(StatusCode::UpgradeRequired));
1216            }
1217
                // NOTE(review): a missing key goes through the generic error
                // path; presumably that yields a 500 rather than a 400 — confirm.
1218            let header = match request.header("Sec-Websocket-Key") {
1219                Some(h) => h.as_str(),
1220                None => return Err(anyhow!("expected sec-websocket-key"))?,
1221            };
1222
1223            let mut response = Response::new(StatusCode::SwitchingProtocols);
1224            response.insert_header(UPGRADE, "websocket");
1225            response.insert_header(CONNECTION, "Upgrade");
                // Accept value = base64(SHA-1(client key + GUID)).
1226            let hash = Sha1::new().chain(header).chain(WEBSOCKET_GUID).finalize();
1227            response.insert_header("Sec-Websocket-Accept", base64::encode(&hash[..]));
1228            response.insert_header("Sec-Websocket-Version", "13");
1229
                // Obtain the receiver for the upgraded connection before the
                // response is returned; the spawned task below waits on it.
1230            let http_res: &mut tide::http::Response = response.as_mut();
1231            let upgrade_receiver = http_res.recv_upgrade().await;
1232            let addr = request.remote().unwrap_or("unknown").to_string();
1233            let user_id = user_id.ok_or_else(|| anyhow!("user_id is not present on request. ensure auth::VerifyToken middleware is present"))?;
                // Serve the connection on its own task so this handler can
                // return the 101 response.
1234            task::spawn(async move {
1235                if let Some(stream) = upgrade_receiver.await {
1236                    server.handle_connection(Connection::new(WebSocketStream::from_raw_socket(stream, Role::Server, None).await), addr, user_id).await;
1237                }
1238            });
1239
1240            Ok(response)
1241        }
1242    });
1243}
1244
1245fn header_contains_ignore_case<T>(
1246    request: &tide::Request<T>,
1247    header_name: HeaderName,
1248    value: &str,
1249) -> bool {
1250    request
1251        .header(header_name)
1252        .map(|h| {
1253            h.as_str()
1254                .split(',')
1255                .any(|s| s.trim().eq_ignore_ascii_case(value.trim()))
1256        })
1257        .unwrap_or(false)
1258}
1259
1260#[cfg(test)]
1261mod tests {
1262    use super::*;
1263    use crate::{
1264        auth,
1265        db::{tests::TestDb, UserId},
1266        github, AppState, Config,
1267    };
1268    use async_std::{sync::RwLockReadGuard, task};
1269    use gpui::{ModelHandle, TestAppContext};
1270    use parking_lot::Mutex;
1271    use postage::{mpsc, watch};
1272    use serde_json::json;
1273    use sqlx::types::time::OffsetDateTime;
1274    use std::{
1275        path::Path,
1276        sync::{
1277            atomic::{AtomicBool, Ordering::SeqCst},
1278            Arc,
1279        },
1280        time::Duration,
1281    };
1282    use zed::{
1283        channel::{Channel, ChannelDetails, ChannelList},
1284        editor::{Editor, Insert},
1285        fs::{FakeFs, Fs as _},
1286        language::LanguageRegistry,
1287        rpc::{self, Client, Credentials, EstablishConnectionError},
1288        settings,
1289        test::FakeHttpClient,
1290        user::UserStore,
1291        worktree::Worktree,
1292    };
1293    use zrpc::Peer;
1294
        // Walks through sharing a worktree between a host (client A) and a
        // guest (client B): joining, opening the same buffer on both sides,
        // propagating selections and edits from guest to host, and cleaning
        // up when the guest's editor, the host's buffer and the guest's
        // worktree are dropped.
1295    #[gpui::test]
1296    async fn test_share_worktree(mut cx_a: TestAppContext, mut cx_b: TestAppContext) {
1297        let (window_b, _) = cx_b.add_window(|_| EmptyView);
1298        let settings = cx_b.read(settings::test).1;
1299        let lang_registry = Arc::new(LanguageRegistry::new());
1300
1301        // Connect to a server as 2 clients.
1302        let mut server = TestServer::start().await;
1303        let (client_a, _) = server.create_client(&mut cx_a, "user_a").await;
1304        let (client_b, _) = server.create_client(&mut cx_b, "user_b").await;
1305
            // Presumably makes the test fail instead of blocking when the
            // foreground executor has nothing left to run — TODO confirm.
1306        cx_a.foreground().forbid_parking();
1307
1308        // Share a local worktree as client A
1309        let fs = Arc::new(FakeFs::new());
1310        fs.insert_tree(
1311            "/a",
1312            json!({
1313                ".zed.toml": r#"collaborators = ["user_b"]"#,
1314                "a.txt": "a-contents",
1315                "b.txt": "b-contents",
1316            }),
1317        )
1318        .await;
1319        let worktree_a = Worktree::open_local(
1320            client_a.clone(),
1321            "/a".as_ref(),
1322            fs,
1323            lang_registry.clone(),
1324            &mut cx_a.to_async(),
1325        )
1326        .await
1327        .unwrap();
1328        worktree_a
1329            .read_with(&cx_a, |tree, _| tree.as_local().unwrap().scan_complete())
1330            .await;
1331        let worktree_id = worktree_a
1332            .update(&mut cx_a, |tree, cx| tree.as_local_mut().unwrap().share(cx))
1333            .await
1334            .unwrap();
1335
1336        // Join that worktree as client B, and see that a guest has joined as client A.
1337        let worktree_b = Worktree::open_remote(
1338            client_b.clone(),
1339            worktree_id,
1340            lang_registry.clone(),
1341            &mut cx_b.to_async(),
1342        )
1343        .await
1344        .unwrap();
1345        let replica_id_b = worktree_b.read_with(&cx_b, |tree, _| tree.replica_id());
1346        worktree_a
1347            .condition(&cx_a, |tree, _| {
1348                tree.peers()
1349                    .values()
1350                    .any(|replica_id| *replica_id == replica_id_b)
1351            })
1352            .await;
1353
1354        // Open the same file as client B and client A.
1355        let buffer_b = worktree_b
1356            .update(&mut cx_b, |worktree, cx| worktree.open_buffer("b.txt", cx))
1357            .await
1358            .unwrap();
1359        buffer_b.read_with(&cx_b, |buf, _| assert_eq!(buf.text(), "b-contents"));
            // After the guest has opened the file, the host must report an
            // open buffer for it as well.
1360        worktree_a.read_with(&cx_a, |tree, cx| assert!(tree.has_open_buffer("b.txt", cx)));
1361        let buffer_a = worktree_a
1362            .update(&mut cx_a, |tree, cx| tree.open_buffer("b.txt", cx))
1363            .await
1364            .unwrap();
1365
1366        // Create a selection set as client B and see that selection set as client A.
1367        let editor_b = cx_b.add_view(window_b, |cx| Editor::for_buffer(buffer_b, settings, cx));
1368        buffer_a
1369            .condition(&cx_a, |buffer, _| buffer.selection_sets().count() == 1)
1370            .await;
1371
1372        // Edit the buffer as client B and see that edit as client A.
1373        editor_b.update(&mut cx_b, |editor, cx| {
1374            editor.insert(&Insert("ok, ".into()), cx)
1375        });
1376        buffer_a
1377            .condition(&cx_a, |buffer, _| buffer.text() == "ok, b-contents")
1378            .await;
1379
1380        // Remove the selection set as client B, see those selections disappear as client A.
1381        cx_b.update(move |_| drop(editor_b));
1382        buffer_a
1383            .condition(&cx_a, |buffer, _| buffer.selection_sets().count() == 0)
1384            .await;
1385
1386        // Close the buffer as client A, see that the buffer is closed.
1387        cx_a.update(move |_| drop(buffer_a));
1388        worktree_a
1389            .condition(&cx_a, |tree, cx| !tree.has_open_buffer("b.txt", cx))
1390            .await;
1391
1392        // Dropping the worktree removes client B from client A's peers.
1393        cx_b.update(move |_| drop(worktree_b));
1394        worktree_a
1395            .condition(&cx_a, |tree, _| tree.peers().is_empty())
1396            .await;
1397    }
1398
        // With one host (client A) and two guests (clients B and C), checks
        // that concurrent edits converge on all three replicas, that a save
        // issued by a guest writes the host's file and clears the dirty flag
        // everywhere, and that file-system changes on the host show up in
        // the guests' worktrees.
1399    #[gpui::test]
1400    async fn test_propagate_saves_and_fs_changes_in_shared_worktree(
1401        mut cx_a: TestAppContext,
1402        mut cx_b: TestAppContext,
1403        mut cx_c: TestAppContext,
1404    ) {
1405        cx_a.foreground().forbid_parking();
1406        let lang_registry = Arc::new(LanguageRegistry::new());
1407
1408        // Connect to a server as 3 clients.
1409        let mut server = TestServer::start().await;
1410        let (client_a, _) = server.create_client(&mut cx_a, "user_a").await;
1411        let (client_b, _) = server.create_client(&mut cx_b, "user_b").await;
1412        let (client_c, _) = server.create_client(&mut cx_c, "user_c").await;
1413
1414        let fs = Arc::new(FakeFs::new());
1415
1416        // Share a worktree as client A.
1417        fs.insert_tree(
1418            "/a",
1419            json!({
1420                ".zed.toml": r#"collaborators = ["user_b", "user_c"]"#,
1421                "file1": "",
1422                "file2": ""
1423            }),
1424        )
1425        .await;
1426
1427        let worktree_a = Worktree::open_local(
1428            client_a.clone(),
1429            "/a".as_ref(),
1430            fs.clone(),
1431            lang_registry.clone(),
1432            &mut cx_a.to_async(),
1433        )
1434        .await
1435        .unwrap();
1436        worktree_a
1437            .read_with(&cx_a, |tree, _| tree.as_local().unwrap().scan_complete())
1438            .await;
1439        let worktree_id = worktree_a
1440            .update(&mut cx_a, |tree, cx| tree.as_local_mut().unwrap().share(cx))
1441            .await
1442            .unwrap();
1443
1444        // Join that worktree as clients B and C.
1445        let worktree_b = Worktree::open_remote(
1446            client_b.clone(),
1447            worktree_id,
1448            lang_registry.clone(),
1449            &mut cx_b.to_async(),
1450        )
1451        .await
1452        .unwrap();
1453        let worktree_c = Worktree::open_remote(
1454            client_c.clone(),
1455            worktree_id,
1456            lang_registry.clone(),
1457            &mut cx_c.to_async(),
1458        )
1459        .await
1460        .unwrap();
1461
1462        // Open and edit a buffer as both guests B and C.
1463        let buffer_b = worktree_b
1464            .update(&mut cx_b, |tree, cx| tree.open_buffer("file1", cx))
1465            .await
1466            .unwrap();
1467        let buffer_c = worktree_c
1468            .update(&mut cx_c, |tree, cx| tree.open_buffer("file1", cx))
1469            .await
1470            .unwrap();
1471        buffer_b.update(&mut cx_b, |buf, cx| buf.edit([0..0], "i-am-b, ", cx));
1472        buffer_c.update(&mut cx_c, |buf, cx| buf.edit([0..0], "i-am-c, ", cx));
1473
1474        // Open and edit that buffer as the host.
1475        let buffer_a = worktree_a
1476            .update(&mut cx_a, |tree, cx| tree.open_buffer("file1", cx))
1477            .await
1478            .unwrap();
1479
            // Wait until both guests' insertions at offset 0 have reached the
            // host before appending the host's own text.
1480        buffer_a
1481            .condition(&mut cx_a, |buf, _| buf.text() == "i-am-c, i-am-b, ")
1482            .await;
1483        buffer_a.update(&mut cx_a, |buf, cx| {
1484            buf.edit([buf.len()..buf.len()], "i-am-a", cx)
1485        });
1486
1487        // Wait for edits to propagate
1488        buffer_a
1489            .condition(&mut cx_a, |buf, _| buf.text() == "i-am-c, i-am-b, i-am-a")
1490            .await;
1491        buffer_b
1492            .condition(&mut cx_b, |buf, _| buf.text() == "i-am-c, i-am-b, i-am-a")
1493            .await;
1494        buffer_c
1495            .condition(&mut cx_c, |buf, _| buf.text() == "i-am-c, i-am-b, i-am-a")
1496            .await;
1497
1498        // Edit the buffer as the host and concurrently save as guest B.
1499        let save_b = buffer_b.update(&mut cx_b, |buf, cx| buf.save(cx).unwrap());
1500        buffer_a.update(&mut cx_a, |buf, cx| buf.edit([0..0], "hi-a, ", cx));
1501        save_b.await.unwrap();
            // The saved file is expected to contain the host's edit even
            // though it was made after guest B started the save.
1502        assert_eq!(
1503            fs.load("/a/file1".as_ref()).await.unwrap(),
1504            "hi-a, i-am-c, i-am-b, i-am-a"
1505        );
1506        buffer_a.read_with(&cx_a, |buf, _| assert!(!buf.is_dirty()));
1507        buffer_b.read_with(&cx_b, |buf, _| assert!(!buf.is_dirty()));
1508        buffer_c.condition(&cx_c, |buf, _| !buf.is_dirty()).await;
1509
1510        // Make changes on host's file system, see those changes on the guests.
1511        fs.rename("/a/file2".as_ref(), "/a/file3".as_ref())
1512            .await
1513            .unwrap();
1514        fs.insert_file(Path::new("/a/file4"), "4".into())
1515            .await
1516            .unwrap();
1517
            // Four files are expected: .zed.toml, file1, file3 (renamed from
            // file2) and the newly inserted file4.
1518        worktree_b
1519            .condition(&cx_b, |tree, _| tree.file_count() == 4)
1520            .await;
1521        worktree_c
1522            .condition(&cx_c, |tree, _| tree.file_count() == 4)
1523            .await;
1524        worktree_b.read_with(&cx_b, |tree, _| {
1525            assert_eq!(
1526                tree.paths()
1527                    .map(|p| p.to_string_lossy())
1528                    .collect::<Vec<_>>(),
1529                &[".zed.toml", "file1", "file3", "file4"]
1530            )
1531        });
1532        worktree_c.read_with(&cx_c, |tree, _| {
1533            assert_eq!(
1534                tree.paths()
1535                    .map(|p| p.to_string_lossy())
1536                    .collect::<Vec<_>>(),
1537                &[".zed.toml", "file1", "file3", "file4"]
1538            )
1539        });
1540    }
1541
        // Checks the dirty/conflict flags of a guest's buffer around a save:
        // an edit makes it dirty without a conflict, a save clears the dirty
        // flag once the file's mtime has changed, and a further edit makes it
        // dirty again, still without a conflict.
1542    #[gpui::test]
1543    async fn test_buffer_conflict_after_save(mut cx_a: TestAppContext, mut cx_b: TestAppContext) {
1544        cx_a.foreground().forbid_parking();
1545        let lang_registry = Arc::new(LanguageRegistry::new());
1546
1547        // Connect to a server as 2 clients.
1548        let mut server = TestServer::start().await;
1549        let (client_a, _) = server.create_client(&mut cx_a, "user_a").await;
1550        let (client_b, _) = server.create_client(&mut cx_b, "user_b").await;
1551
1552        // Share a local worktree as client A
            // NOTE(review): "user_c" is listed as a collaborator below, but
            // this test never creates such a client.
1553        let fs = Arc::new(FakeFs::new());
1554        fs.insert_tree(
1555            "/dir",
1556            json!({
1557                ".zed.toml": r#"collaborators = ["user_b", "user_c"]"#,
1558                "a.txt": "a-contents",
1559            }),
1560        )
1561        .await;
1562
1563        let worktree_a = Worktree::open_local(
1564            client_a.clone(),
1565            "/dir".as_ref(),
1566            fs,
1567            lang_registry.clone(),
1568            &mut cx_a.to_async(),
1569        )
1570        .await
1571        .unwrap();
1572        worktree_a
1573            .read_with(&cx_a, |tree, _| tree.as_local().unwrap().scan_complete())
1574            .await;
1575        let worktree_id = worktree_a
1576            .update(&mut cx_a, |tree, cx| tree.as_local_mut().unwrap().share(cx))
1577            .await
1578            .unwrap();
1579
1580        // Join that worktree as client B.
1581        let worktree_b = Worktree::open_remote(
1582            client_b.clone(),
1583            worktree_id,
1584            lang_registry.clone(),
1585            &mut cx_b.to_async(),
1586        )
1587        .await
1588        .unwrap();
1589
1590        let buffer_b = worktree_b
1591            .update(&mut cx_b, |worktree, cx| worktree.open_buffer("a.txt", cx))
1592            .await
1593            .unwrap();
            // Remember the mtime seen before saving so the save can be
            // detected below.
1594        let mtime = buffer_b.read_with(&cx_b, |buf, _| buf.file().unwrap().mtime);
1595
1596        buffer_b.update(&mut cx_b, |buf, cx| buf.edit([0..0], "world ", cx));
1597        buffer_b.read_with(&cx_b, |buf, _| {
1598            assert!(buf.is_dirty());
1599            assert!(!buf.has_conflict());
1600        });
1601
1602        buffer_b
1603            .update(&mut cx_b, |buf, cx| buf.save(cx))
1604            .unwrap()
1605            .await
1606            .unwrap();
            // Wait until the guest observes the file's new mtime.
1607        worktree_b
1608            .condition(&cx_b, |_, cx| {
1609                buffer_b.read(cx).file().unwrap().mtime != mtime
1610            })
1611            .await;
1612        buffer_b.read_with(&cx_b, |buf, _| {
1613            assert!(!buf.is_dirty());
1614            assert!(!buf.has_conflict());
1615        });
1616
            // Editing after the save must not be reported as a conflict.
1617        buffer_b.update(&mut cx_b, |buf, cx| buf.edit([0..0], "hello ", cx));
1618        buffer_b.read_with(&cx_b, |buf, _| {
1619            assert!(buf.is_dirty());
1620            assert!(!buf.has_conflict());
1621        });
1622    }
1623
        // Checks that a guest still ends up with the host's text when the
        // host edits a buffer while the guest's request to open that buffer
        // has been started but not yet awaited.
1624    #[gpui::test]
1625    async fn test_editing_while_guest_opens_buffer(
1626        mut cx_a: TestAppContext,
1627        mut cx_b: TestAppContext,
1628    ) {
1629        cx_a.foreground().forbid_parking();
1630        let lang_registry = Arc::new(LanguageRegistry::new());
1631
1632        // Connect to a server as 2 clients.
1633        let mut server = TestServer::start().await;
1634        let (client_a, _) = server.create_client(&mut cx_a, "user_a").await;
1635        let (client_b, _) = server.create_client(&mut cx_b, "user_b").await;
1636
1637        // Share a local worktree as client A
1638        let fs = Arc::new(FakeFs::new());
1639        fs.insert_tree(
1640            "/dir",
1641            json!({
1642                ".zed.toml": r#"collaborators = ["user_b"]"#,
1643                "a.txt": "a-contents",
1644            }),
1645        )
1646        .await;
1647        let worktree_a = Worktree::open_local(
1648            client_a.clone(),
1649            "/dir".as_ref(),
1650            fs,
1651            lang_registry.clone(),
1652            &mut cx_a.to_async(),
1653        )
1654        .await
1655        .unwrap();
1656        worktree_a
1657            .read_with(&cx_a, |tree, _| tree.as_local().unwrap().scan_complete())
1658            .await;
1659        let worktree_id = worktree_a
1660            .update(&mut cx_a, |tree, cx| tree.as_local_mut().unwrap().share(cx))
1661            .await
1662            .unwrap();
1663
1664        // Join that worktree as client B.
1665        let worktree_b = Worktree::open_remote(
1666            client_b.clone(),
1667            worktree_id,
1668            lang_registry.clone(),
1669            &mut cx_b.to_async(),
1670        )
1671        .await
1672        .unwrap();
1673
1674        let buffer_a = worktree_a
1675            .update(&mut cx_a, |tree, cx| tree.open_buffer("a.txt", cx))
1676            .await
1677            .unwrap();
            // Start opening the buffer as the guest on the background
            // executor without awaiting it yet.
1678        let buffer_b = cx_b
1679            .background()
1680            .spawn(worktree_b.update(&mut cx_b, |worktree, cx| worktree.open_buffer("a.txt", cx)));
1681
            // Yield once so the spawned open can presumably make progress,
            // then edit on the host while it may still be in flight.
1682        task::yield_now().await;
1683        buffer_a.update(&mut cx_a, |buf, cx| buf.edit([0..0], "z", cx));
1684
            // The guest's buffer must eventually match the host's edited text.
1685        let text = buffer_a.read_with(&cx_a, |buf, _| buf.text());
1686        let buffer_b = buffer_b.await.unwrap();
1687        buffer_b.condition(&cx_b, |buf, _| buf.text() == text).await;
1688    }
1689
        // Checks that the host of a shared worktree sees a guest disappear
        // from its peers when that guest's client disconnects.
1690    #[gpui::test]
1691    async fn test_peer_disconnection(mut cx_a: TestAppContext, cx_b: TestAppContext) {
1692        cx_a.foreground().forbid_parking();
1693        let lang_registry = Arc::new(LanguageRegistry::new());
1694
1695        // Connect to a server as 2 clients.
            // NOTE(review): client B is created with `cx_a` rather than
            // `cx_b`, unlike the other tests here — confirm this is intended.
1696        let mut server = TestServer::start().await;
1697        let (client_a, _) = server.create_client(&mut cx_a, "user_a").await;
1698        let (client_b, _) = server.create_client(&mut cx_a, "user_b").await;
1699
1700        // Share a local worktree as client A
1701        let fs = Arc::new(FakeFs::new());
1702        fs.insert_tree(
1703            "/a",
1704            json!({
1705                ".zed.toml": r#"collaborators = ["user_b"]"#,
1706                "a.txt": "a-contents",
1707                "b.txt": "b-contents",
1708            }),
1709        )
1710        .await;
1711        let worktree_a = Worktree::open_local(
1712            client_a.clone(),
1713            "/a".as_ref(),
1714            fs,
1715            lang_registry.clone(),
1716            &mut cx_a.to_async(),
1717        )
1718        .await
1719        .unwrap();
1720        worktree_a
1721            .read_with(&cx_a, |tree, _| tree.as_local().unwrap().scan_complete())
1722            .await;
1723        let worktree_id = worktree_a
1724            .update(&mut cx_a, |tree, cx| tree.as_local_mut().unwrap().share(cx))
1725            .await
1726            .unwrap();
1727
1728        // Join that worktree as client B, and see that a guest has joined as client A.
            // The handle is kept in `_worktree_b` so the remote worktree is
            // not dropped before the disconnect below.
1729        let _worktree_b = Worktree::open_remote(
1730            client_b.clone(),
1731            worktree_id,
1732            lang_registry.clone(),
1733            &mut cx_b.to_async(),
1734        )
1735        .await
1736        .unwrap();
1737        worktree_a
1738            .condition(&cx_a, |tree, _| tree.peers().len() == 1)
1739            .await;
1740
1741        // Drop client B's connection and ensure client A observes client B leaving the worktree.
1742        client_b.disconnect(&cx_b.to_async()).await.unwrap();
1743        worktree_a
1744            .condition(&cx_a, |tree, _| tree.peers().len() == 0)
1745            .await;
1746    }
1747
1748    #[gpui::test]
1749    async fn test_basic_chat(mut cx_a: TestAppContext, mut cx_b: TestAppContext) {
1750        cx_a.foreground().forbid_parking();
1751
1752        // Connect to a server as 2 clients.
1753        let mut server = TestServer::start().await;
1754        let (client_a, user_store_a) = server.create_client(&mut cx_a, "user_a").await;
1755        let (client_b, user_store_b) = server.create_client(&mut cx_b, "user_b").await;
1756
1757        // Create an org that includes these 2 users.
1758        let db = &server.app_state.db;
1759        let org_id = db.create_org("Test Org", "test-org").await.unwrap();
1760        db.add_org_member(org_id, current_user_id(&user_store_a, &cx_a), false)
1761            .await
1762            .unwrap();
1763        db.add_org_member(org_id, current_user_id(&user_store_b, &cx_b), false)
1764            .await
1765            .unwrap();
1766
1767        // Create a channel that includes all the users.
1768        let channel_id = db.create_org_channel(org_id, "test-channel").await.unwrap();
1769        db.add_channel_member(channel_id, current_user_id(&user_store_a, &cx_a), false)
1770            .await
1771            .unwrap();
1772        db.add_channel_member(channel_id, current_user_id(&user_store_b, &cx_b), false)
1773            .await
1774            .unwrap();
1775        db.create_channel_message(
1776            channel_id,
1777            current_user_id(&user_store_b, &cx_b),
1778            "hello A, it's B.",
1779            OffsetDateTime::now_utc(),
1780            1,
1781        )
1782        .await
1783        .unwrap();
1784
1785        let channels_a = cx_a.add_model(|cx| ChannelList::new(user_store_a, client_a, cx));
1786        channels_a
1787            .condition(&mut cx_a, |list, _| list.available_channels().is_some())
1788            .await;
1789        channels_a.read_with(&cx_a, |list, _| {
1790            assert_eq!(
1791                list.available_channels().unwrap(),
1792                &[ChannelDetails {
1793                    id: channel_id.to_proto(),
1794                    name: "test-channel".to_string()
1795                }]
1796            )
1797        });
1798        let channel_a = channels_a.update(&mut cx_a, |this, cx| {
1799            this.get_channel(channel_id.to_proto(), cx).unwrap()
1800        });
1801        channel_a.read_with(&cx_a, |channel, _| assert!(channel.messages().is_empty()));
1802        channel_a
1803            .condition(&cx_a, |channel, _| {
1804                channel_messages(channel)
1805                    == [("user_b".to_string(), "hello A, it's B.".to_string(), false)]
1806            })
1807            .await;
1808
1809        let channels_b = cx_b.add_model(|cx| ChannelList::new(user_store_b, client_b, cx));
1810        channels_b
1811            .condition(&mut cx_b, |list, _| list.available_channels().is_some())
1812            .await;
1813        channels_b.read_with(&cx_b, |list, _| {
1814            assert_eq!(
1815                list.available_channels().unwrap(),
1816                &[ChannelDetails {
1817                    id: channel_id.to_proto(),
1818                    name: "test-channel".to_string()
1819                }]
1820            )
1821        });
1822
1823        let channel_b = channels_b.update(&mut cx_b, |this, cx| {
1824            this.get_channel(channel_id.to_proto(), cx).unwrap()
1825        });
1826        channel_b.read_with(&cx_b, |channel, _| assert!(channel.messages().is_empty()));
1827        channel_b
1828            .condition(&cx_b, |channel, _| {
1829                channel_messages(channel)
1830                    == [("user_b".to_string(), "hello A, it's B.".to_string(), false)]
1831            })
1832            .await;
1833
1834        channel_a
1835            .update(&mut cx_a, |channel, cx| {
1836                channel
1837                    .send_message("oh, hi B.".to_string(), cx)
1838                    .unwrap()
1839                    .detach();
1840                let task = channel.send_message("sup".to_string(), cx).unwrap();
1841                assert_eq!(
1842                    channel_messages(channel),
1843                    &[
1844                        ("user_b".to_string(), "hello A, it's B.".to_string(), false),
1845                        ("user_a".to_string(), "oh, hi B.".to_string(), true),
1846                        ("user_a".to_string(), "sup".to_string(), true)
1847                    ]
1848                );
1849                task
1850            })
1851            .await
1852            .unwrap();
1853
1854        channel_b
1855            .condition(&cx_b, |channel, _| {
1856                channel_messages(channel)
1857                    == [
1858                        ("user_b".to_string(), "hello A, it's B.".to_string(), false),
1859                        ("user_a".to_string(), "oh, hi B.".to_string(), false),
1860                        ("user_a".to_string(), "sup".to_string(), false),
1861                    ]
1862            })
1863            .await;
1864
1865        assert_eq!(
1866            server.state().await.channels[&channel_id]
1867                .connection_ids
1868                .len(),
1869            2
1870        );
1871        cx_b.update(|_| drop(channel_b));
1872        server
1873            .condition(|state| state.channels[&channel_id].connection_ids.len() == 1)
1874            .await;
1875
1876        cx_a.update(|_| drop(channel_a));
1877        server
1878            .condition(|state| !state.channels.contains_key(&channel_id))
1879            .await;
1880    }
1881
1882    #[gpui::test]
1883    async fn test_chat_message_validation(mut cx_a: TestAppContext) {
1884        cx_a.foreground().forbid_parking();
1885
1886        let mut server = TestServer::start().await;
1887        let (client_a, user_store_a) = server.create_client(&mut cx_a, "user_a").await;
1888
1889        let db = &server.app_state.db;
1890        let org_id = db.create_org("Test Org", "test-org").await.unwrap();
1891        let channel_id = db.create_org_channel(org_id, "test-channel").await.unwrap();
1892        db.add_org_member(org_id, current_user_id(&user_store_a, &cx_a), false)
1893            .await
1894            .unwrap();
1895        db.add_channel_member(channel_id, current_user_id(&user_store_a, &cx_a), false)
1896            .await
1897            .unwrap();
1898
1899        let channels_a = cx_a.add_model(|cx| ChannelList::new(user_store_a, client_a, cx));
1900        channels_a
1901            .condition(&mut cx_a, |list, _| list.available_channels().is_some())
1902            .await;
1903        let channel_a = channels_a.update(&mut cx_a, |this, cx| {
1904            this.get_channel(channel_id.to_proto(), cx).unwrap()
1905        });
1906
1907        // Messages aren't allowed to be too long.
1908        channel_a
1909            .update(&mut cx_a, |channel, cx| {
1910                let long_body = "this is long.\n".repeat(1024);
1911                channel.send_message(long_body, cx).unwrap()
1912            })
1913            .await
1914            .unwrap_err();
1915
1916        // Messages aren't allowed to be blank.
1917        channel_a.update(&mut cx_a, |channel, cx| {
1918            channel.send_message(String::new(), cx).unwrap_err()
1919        });
1920
1921        // Leading and trailing whitespace are trimmed.
1922        channel_a
1923            .update(&mut cx_a, |channel, cx| {
1924                channel
1925                    .send_message("\n surrounded by whitespace  \n".to_string(), cx)
1926                    .unwrap()
1927            })
1928            .await
1929            .unwrap();
1930        assert_eq!(
1931            db.get_channel_messages(channel_id, 10, None)
1932                .await
1933                .unwrap()
1934                .iter()
1935                .map(|m| &m.body)
1936                .collect::<Vec<_>>(),
1937            &["surrounded by whitespace"]
1938        );
1939    }
1940
1941    #[gpui::test]
1942    async fn test_chat_reconnection(mut cx_a: TestAppContext, mut cx_b: TestAppContext) {
1943        cx_a.foreground().forbid_parking();
1944
1945        // Connect to a server as 2 clients.
1946        let mut server = TestServer::start().await;
1947        let (client_a, user_store_a) = server.create_client(&mut cx_a, "user_a").await;
1948        let (client_b, user_store_b) = server.create_client(&mut cx_b, "user_b").await;
1949        let mut status_b = client_b.status();
1950
1951        // Create an org that includes these 2 users.
1952        let db = &server.app_state.db;
1953        let org_id = db.create_org("Test Org", "test-org").await.unwrap();
1954        db.add_org_member(org_id, current_user_id(&user_store_a, &cx_a), false)
1955            .await
1956            .unwrap();
1957        db.add_org_member(org_id, current_user_id(&user_store_b, &cx_b), false)
1958            .await
1959            .unwrap();
1960
1961        // Create a channel that includes all the users.
1962        let channel_id = db.create_org_channel(org_id, "test-channel").await.unwrap();
1963        db.add_channel_member(channel_id, current_user_id(&user_store_a, &cx_a), false)
1964            .await
1965            .unwrap();
1966        db.add_channel_member(channel_id, current_user_id(&user_store_b, &cx_b), false)
1967            .await
1968            .unwrap();
1969        db.create_channel_message(
1970            channel_id,
1971            current_user_id(&user_store_b, &cx_b),
1972            "hello A, it's B.",
1973            OffsetDateTime::now_utc(),
1974            2,
1975        )
1976        .await
1977        .unwrap();
1978
1979        let channels_a = cx_a.add_model(|cx| ChannelList::new(user_store_a, client_a, cx));
1980        channels_a
1981            .condition(&mut cx_a, |list, _| list.available_channels().is_some())
1982            .await;
1983
1984        channels_a.read_with(&cx_a, |list, _| {
1985            assert_eq!(
1986                list.available_channels().unwrap(),
1987                &[ChannelDetails {
1988                    id: channel_id.to_proto(),
1989                    name: "test-channel".to_string()
1990                }]
1991            )
1992        });
1993        let channel_a = channels_a.update(&mut cx_a, |this, cx| {
1994            this.get_channel(channel_id.to_proto(), cx).unwrap()
1995        });
1996        channel_a.read_with(&cx_a, |channel, _| assert!(channel.messages().is_empty()));
1997        channel_a
1998            .condition(&cx_a, |channel, _| {
1999                channel_messages(channel)
2000                    == [("user_b".to_string(), "hello A, it's B.".to_string(), false)]
2001            })
2002            .await;
2003
2004        let channels_b = cx_b.add_model(|cx| ChannelList::new(user_store_b.clone(), client_b, cx));
2005        channels_b
2006            .condition(&mut cx_b, |list, _| list.available_channels().is_some())
2007            .await;
2008        channels_b.read_with(&cx_b, |list, _| {
2009            assert_eq!(
2010                list.available_channels().unwrap(),
2011                &[ChannelDetails {
2012                    id: channel_id.to_proto(),
2013                    name: "test-channel".to_string()
2014                }]
2015            )
2016        });
2017
2018        let channel_b = channels_b.update(&mut cx_b, |this, cx| {
2019            this.get_channel(channel_id.to_proto(), cx).unwrap()
2020        });
2021        channel_b.read_with(&cx_b, |channel, _| assert!(channel.messages().is_empty()));
2022        channel_b
2023            .condition(&cx_b, |channel, _| {
2024                channel_messages(channel)
2025                    == [("user_b".to_string(), "hello A, it's B.".to_string(), false)]
2026            })
2027            .await;
2028
2029        // Disconnect client B, ensuring we can still access its cached channel data.
2030        server.forbid_connections();
2031        server.disconnect_client(current_user_id(&user_store_b, &cx_b));
2032        while !matches!(
2033            status_b.recv().await,
2034            Some(rpc::Status::ReconnectionError { .. })
2035        ) {}
2036
2037        channels_b.read_with(&cx_b, |channels, _| {
2038            assert_eq!(
2039                channels.available_channels().unwrap(),
2040                [ChannelDetails {
2041                    id: channel_id.to_proto(),
2042                    name: "test-channel".to_string()
2043                }]
2044            )
2045        });
2046        channel_b.read_with(&cx_b, |channel, _| {
2047            assert_eq!(
2048                channel_messages(channel),
2049                [("user_b".to_string(), "hello A, it's B.".to_string(), false)]
2050            )
2051        });
2052
2053        // Send a message from client B while it is disconnected.
2054        channel_b
2055            .update(&mut cx_b, |channel, cx| {
2056                let task = channel
2057                    .send_message("can you see this?".to_string(), cx)
2058                    .unwrap();
2059                assert_eq!(
2060                    channel_messages(channel),
2061                    &[
2062                        ("user_b".to_string(), "hello A, it's B.".to_string(), false),
2063                        ("user_b".to_string(), "can you see this?".to_string(), true)
2064                    ]
2065                );
2066                task
2067            })
2068            .await
2069            .unwrap_err();
2070
2071        // Send a message from client A while B is disconnected.
2072        channel_a
2073            .update(&mut cx_a, |channel, cx| {
2074                channel
2075                    .send_message("oh, hi B.".to_string(), cx)
2076                    .unwrap()
2077                    .detach();
2078                let task = channel.send_message("sup".to_string(), cx).unwrap();
2079                assert_eq!(
2080                    channel_messages(channel),
2081                    &[
2082                        ("user_b".to_string(), "hello A, it's B.".to_string(), false),
2083                        ("user_a".to_string(), "oh, hi B.".to_string(), true),
2084                        ("user_a".to_string(), "sup".to_string(), true)
2085                    ]
2086                );
2087                task
2088            })
2089            .await
2090            .unwrap();
2091
2092        // Give client B a chance to reconnect.
2093        server.allow_connections();
2094        cx_b.foreground().advance_clock(Duration::from_secs(10));
2095
2096        // Verify that B sees the new messages upon reconnection, as well as the message client B
2097        // sent while offline.
2098        channel_b
2099            .condition(&cx_b, |channel, _| {
2100                channel_messages(channel)
2101                    == [
2102                        ("user_b".to_string(), "hello A, it's B.".to_string(), false),
2103                        ("user_a".to_string(), "oh, hi B.".to_string(), false),
2104                        ("user_a".to_string(), "sup".to_string(), false),
2105                        ("user_b".to_string(), "can you see this?".to_string(), false),
2106                    ]
2107            })
2108            .await;
2109
2110        // Ensure client A and B can communicate normally after reconnection.
2111        channel_a
2112            .update(&mut cx_a, |channel, cx| {
2113                channel.send_message("you online?".to_string(), cx).unwrap()
2114            })
2115            .await
2116            .unwrap();
2117        channel_b
2118            .condition(&cx_b, |channel, _| {
2119                channel_messages(channel)
2120                    == [
2121                        ("user_b".to_string(), "hello A, it's B.".to_string(), false),
2122                        ("user_a".to_string(), "oh, hi B.".to_string(), false),
2123                        ("user_a".to_string(), "sup".to_string(), false),
2124                        ("user_b".to_string(), "can you see this?".to_string(), false),
2125                        ("user_a".to_string(), "you online?".to_string(), false),
2126                    ]
2127            })
2128            .await;
2129
2130        channel_b
2131            .update(&mut cx_b, |channel, cx| {
2132                channel.send_message("yep".to_string(), cx).unwrap()
2133            })
2134            .await
2135            .unwrap();
2136        channel_a
2137            .condition(&cx_a, |channel, _| {
2138                channel_messages(channel)
2139                    == [
2140                        ("user_b".to_string(), "hello A, it's B.".to_string(), false),
2141                        ("user_a".to_string(), "oh, hi B.".to_string(), false),
2142                        ("user_a".to_string(), "sup".to_string(), false),
2143                        ("user_b".to_string(), "can you see this?".to_string(), false),
2144                        ("user_a".to_string(), "you online?".to_string(), false),
2145                        ("user_b".to_string(), "yep".to_string(), false),
2146                    ]
2147            })
2148            .await;
2149    }
2150
2151    #[gpui::test]
2152    async fn test_collaborators(
2153        mut cx_a: TestAppContext,
2154        mut cx_b: TestAppContext,
2155        mut cx_c: TestAppContext,
2156    ) {
2157        cx_a.foreground().forbid_parking();
2158        let lang_registry = Arc::new(LanguageRegistry::new());
2159
2160        // Connect to a server as 3 clients.
2161        let mut server = TestServer::start().await;
2162        let (client_a, user_store_a) = server.create_client(&mut cx_a, "user_a").await;
2163        let (client_b, user_store_b) = server.create_client(&mut cx_b, "user_b").await;
2164        let (_client_c, user_store_c) = server.create_client(&mut cx_c, "user_c").await;
2165
2166        let fs = Arc::new(FakeFs::new());
2167
2168        // Share a worktree as client A.
2169        fs.insert_tree(
2170            "/a",
2171            json!({
2172                ".zed.toml": r#"collaborators = ["user_b", "user_c"]"#,
2173            }),
2174        )
2175        .await;
2176
2177        let worktree_a = Worktree::open_local(
2178            client_a.clone(),
2179            "/a".as_ref(),
2180            fs.clone(),
2181            lang_registry.clone(),
2182            &mut cx_a.to_async(),
2183        )
2184        .await
2185        .unwrap();
2186
2187        user_store_a
2188            .condition(&cx_a, |user_store, _| {
2189                collaborators(user_store) == vec![("user_a", vec![("a", vec![])])]
2190            })
2191            .await;
2192        user_store_b
2193            .condition(&cx_b, |user_store, _| {
2194                collaborators(user_store) == vec![("user_a", vec![("a", vec![])])]
2195            })
2196            .await;
2197        user_store_c
2198            .condition(&cx_c, |user_store, _| {
2199                collaborators(user_store) == vec![("user_a", vec![("a", vec![])])]
2200            })
2201            .await;
2202
2203        let worktree_id = worktree_a
2204            .update(&mut cx_a, |tree, cx| tree.as_local_mut().unwrap().share(cx))
2205            .await
2206            .unwrap();
2207
2208        let _worktree_b = Worktree::open_remote(
2209            client_b.clone(),
2210            worktree_id,
2211            lang_registry.clone(),
2212            &mut cx_b.to_async(),
2213        )
2214        .await
2215        .unwrap();
2216
2217        user_store_a
2218            .condition(&cx_a, |user_store, _| {
2219                collaborators(user_store) == vec![("user_a", vec![("a", vec!["user_b"])])]
2220            })
2221            .await;
2222        user_store_b
2223            .condition(&cx_b, |user_store, _| {
2224                collaborators(user_store) == vec![("user_a", vec![("a", vec!["user_b"])])]
2225            })
2226            .await;
2227        user_store_c
2228            .condition(&cx_c, |user_store, _| {
2229                collaborators(user_store) == vec![("user_a", vec![("a", vec!["user_b"])])]
2230            })
2231            .await;
2232
2233        cx_a.update(move |_| drop(worktree_a));
2234        user_store_a
2235            .condition(&cx_a, |user_store, _| collaborators(user_store) == vec![])
2236            .await;
2237        user_store_b
2238            .condition(&cx_b, |user_store, _| collaborators(user_store) == vec![])
2239            .await;
2240        user_store_c
2241            .condition(&cx_c, |user_store, _| collaborators(user_store) == vec![])
2242            .await;
2243
2244        fn collaborators(user_store: &UserStore) -> Vec<(&str, Vec<(&str, Vec<&str>)>)> {
2245            user_store
2246                .collaborators()
2247                .iter()
2248                .map(|collaborator| {
2249                    let worktrees = collaborator
2250                        .worktrees
2251                        .iter()
2252                        .map(|w| {
2253                            (
2254                                w.root_name.as_str(),
2255                                w.participants
2256                                    .iter()
2257                                    .map(|p| p.github_login.as_str())
2258                                    .collect(),
2259                            )
2260                        })
2261                        .collect();
2262                    (collaborator.user.github_login.as_str(), worktrees)
2263                })
2264                .collect()
2265        }
2266    }
2267
2268    struct TestServer {
2269        peer: Arc<Peer>,
2270        app_state: Arc<AppState>,
2271        server: Arc<Server>,
2272        notifications: mpsc::Receiver<()>,
2273        connection_killers: Arc<Mutex<HashMap<UserId, watch::Sender<Option<()>>>>>,
2274        forbid_connections: Arc<AtomicBool>,
2275        _test_db: TestDb,
2276    }
2277
2278    impl TestServer {
2279        async fn start() -> Self {
2280            let test_db = TestDb::new();
2281            let app_state = Self::build_app_state(&test_db).await;
2282            let peer = Peer::new();
2283            let notifications = mpsc::channel(128);
2284            let server = Server::new(app_state.clone(), peer.clone(), Some(notifications.0));
2285            Self {
2286                peer,
2287                app_state,
2288                server,
2289                notifications: notifications.1,
2290                connection_killers: Default::default(),
2291                forbid_connections: Default::default(),
2292                _test_db: test_db,
2293            }
2294        }
2295
2296        async fn create_client(
2297            &mut self,
2298            cx: &mut TestAppContext,
2299            name: &str,
2300        ) -> (Arc<Client>, ModelHandle<UserStore>) {
2301            let user_id = self.app_state.db.create_user(name, false).await.unwrap();
2302            let client_name = name.to_string();
2303            let mut client = Client::new();
2304            let server = self.server.clone();
2305            let connection_killers = self.connection_killers.clone();
2306            let forbid_connections = self.forbid_connections.clone();
2307            Arc::get_mut(&mut client)
2308                .unwrap()
2309                .override_authenticate(move |cx| {
2310                    cx.spawn(|_| async move {
2311                        let access_token = "the-token".to_string();
2312                        Ok(Credentials {
2313                            user_id: user_id.0 as u64,
2314                            access_token,
2315                        })
2316                    })
2317                })
2318                .override_establish_connection(move |credentials, cx| {
2319                    assert_eq!(credentials.user_id, user_id.0 as u64);
2320                    assert_eq!(credentials.access_token, "the-token");
2321
2322                    let server = server.clone();
2323                    let connection_killers = connection_killers.clone();
2324                    let forbid_connections = forbid_connections.clone();
2325                    let client_name = client_name.clone();
2326                    cx.spawn(move |cx| async move {
2327                        if forbid_connections.load(SeqCst) {
2328                            Err(EstablishConnectionError::other(anyhow!(
2329                                "server is forbidding connections"
2330                            )))
2331                        } else {
2332                            let (client_conn, server_conn, kill_conn) = Connection::in_memory();
2333                            connection_killers.lock().insert(user_id, kill_conn);
2334                            cx.background()
2335                                .spawn(server.handle_connection(server_conn, client_name, user_id))
2336                                .detach();
2337                            Ok(client_conn)
2338                        }
2339                    })
2340                });
2341
2342            let http = FakeHttpClient::new(|_| async move { Ok(surf::http::Response::new(404)) });
2343            client
2344                .authenticate_and_connect(&cx.to_async())
2345                .await
2346                .unwrap();
2347
2348            let user_store = cx.add_model(|cx| UserStore::new(client.clone(), http, cx));
2349            let mut authed_user =
2350                user_store.read_with(cx, |user_store, _| user_store.watch_current_user());
2351            while authed_user.recv().await.unwrap().is_none() {}
2352
2353            (client, user_store)
2354        }
2355
2356        fn disconnect_client(&self, user_id: UserId) {
2357            if let Some(mut kill_conn) = self.connection_killers.lock().remove(&user_id) {
2358                let _ = kill_conn.try_send(Some(()));
2359            }
2360        }
2361
2362        fn forbid_connections(&self) {
2363            self.forbid_connections.store(true, SeqCst);
2364        }
2365
2366        fn allow_connections(&self) {
2367            self.forbid_connections.store(false, SeqCst);
2368        }
2369
2370        async fn build_app_state(test_db: &TestDb) -> Arc<AppState> {
2371            let mut config = Config::default();
2372            config.session_secret = "a".repeat(32);
2373            config.database_url = test_db.url.clone();
2374            let github_client = github::AppClient::test();
2375            Arc::new(AppState {
2376                db: test_db.db().clone(),
2377                handlebars: Default::default(),
2378                auth_client: auth::build_client("", ""),
2379                repo_client: github::RepoClient::test(&github_client),
2380                github_client,
2381                config,
2382            })
2383        }
2384
2385        async fn state<'a>(&'a self) -> RwLockReadGuard<'a, ServerState> {
2386            self.server.state.read().await
2387        }
2388
2389        async fn condition<F>(&mut self, mut predicate: F)
2390        where
2391            F: FnMut(&ServerState) -> bool,
2392        {
2393            async_std::future::timeout(Duration::from_millis(500), async {
2394                while !(predicate)(&*self.server.state.read().await) {
2395                    self.notifications.recv().await;
2396                }
2397            })
2398            .await
2399            .expect("condition timed out");
2400        }
2401    }
2402
2403    impl Drop for TestServer {
2404        fn drop(&mut self) {
2405            task::block_on(self.peer.reset());
2406        }
2407    }
2408
2409    fn current_user_id(user_store: &ModelHandle<UserStore>, cx: &TestAppContext) -> UserId {
2410        UserId::from_proto(
2411            user_store.read_with(cx, |user_store, _| user_store.current_user().unwrap().id),
2412        )
2413    }
2414
2415    fn channel_messages(channel: &Channel) -> Vec<(String, String, bool)> {
2416        channel
2417            .messages()
2418            .cursor::<(), ()>()
2419            .map(|m| {
2420                (
2421                    m.sender.github_login.clone(),
2422                    m.body.clone(),
2423                    m.is_pending(),
2424                )
2425            })
2426            .collect()
2427    }
2428
2429    struct EmptyView;
2430
2431    impl gpui::Entity for EmptyView {
2432        type Event = ();
2433    }
2434
2435    impl gpui::View for EmptyView {
2436        fn ui_name() -> &'static str {
2437            "empty view"
2438        }
2439
2440        fn render(&mut self, _: &mut gpui::RenderContext<Self>) -> gpui::ElementBox {
2441            gpui::Element::boxed(gpui::elements::Empty)
2442        }
2443    }
2444}