rpc.rs

   1use super::{
   2    auth,
   3    db::{ChannelId, MessageId, UserId},
   4    AppState,
   5};
   6use anyhow::anyhow;
   7use async_std::{sync::RwLock, task};
   8use async_tungstenite::{tungstenite::protocol::Role, WebSocketStream};
   9use futures::{future::BoxFuture, FutureExt};
  10use postage::{mpsc, prelude::Sink as _, prelude::Stream as _};
  11use sha1::{Digest as _, Sha1};
  12use std::{
  13    any::TypeId,
  14    collections::{hash_map, HashMap, HashSet},
  15    future::Future,
  16    mem,
  17    sync::Arc,
  18    time::Instant,
  19};
  20use surf::StatusCode;
  21use tide::log;
  22use tide::{
  23    http::headers::{HeaderName, CONNECTION, UPGRADE},
  24    Request, Response,
  25};
  26use time::OffsetDateTime;
  27use zrpc::{
  28    proto::{self, AnyTypedEnvelope, EnvelopedMessage},
  29    Connection, ConnectionId, Peer, TypedEnvelope,
  30};
  31
/// Replica id of a participant in a shared worktree. The host is always
/// reported as replica 0 (see `join_worktree`); guest ids are assigned by
/// `ServerState::join_worktree`.
type ReplicaId = u16;
  33
/// Type-erased message handler stored in `Server::handlers`, keyed by the
/// `TypeId` of the payload it accepts. It receives the server and the boxed
/// envelope and returns a boxed future that resolves when handling finishes.
type MessageHandler = Box<
    dyn Send
        + Sync
        + Fn(Arc<Server>, Box<dyn AnyTypedEnvelope>) -> BoxFuture<'static, tide::Result<()>>,
>;
  39
/// RPC server: dispatches incoming client messages to per-payload-type
/// handlers and keeps the in-memory collaboration state (connections,
/// worktrees, channels).
pub struct Server {
    /// Connection multiplexer used to send, respond to, and forward messages.
    peer: Arc<Peer>,
    /// Mutable collaboration state shared by all handlers.
    state: RwLock<ServerState>,
    /// Shared application state; handlers reach the database through `app_state.db`.
    app_state: Arc<AppState>,
    /// Registered handlers keyed by the `TypeId` of the payload they accept.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Optional channel signalled after every handled message
    /// (presumably used by tests to observe progress — confirm).
    notifications: Option<mpsc::Sender<()>>,
}
  47
/// In-memory bookkeeping guarded by `Server::state`.
#[derive(Default)]
struct ServerState {
    /// Per-connection state for every live connection.
    connections: HashMap<ConnectionId, ConnectionState>,
    /// Reverse index from a user to all of that user's live connections.
    connections_by_user_id: HashMap<UserId, HashSet<ConnectionId>>,
    /// Open worktrees keyed by the id returned in `OpenWorktreeResponse`.
    pub worktrees: HashMap<u64, Worktree>,
    /// Worktree ids visible to each user; read by `update_collaborators_for_users`.
    visible_worktrees_by_user_id: HashMap<UserId, HashSet<u64>>,
    /// Chat channels with their live membership.
    channels: HashMap<ChannelId, Channel>,
    /// Counter for allocating worktree ids (presumably consumed by `add_worktree`).
    next_worktree_id: u64,
}
  57
/// State tracked for one live connection.
struct ConnectionState {
    /// User that owns this connection.
    user_id: UserId,
    /// Worktrees this connection participates in; each is closed when the
    /// connection is removed.
    worktrees: HashSet<u64>,
    /// Channels this connection has joined; it is removed from each of them
    /// when the connection is removed.
    channels: HashSet<ChannelId>,
}
  63
/// A worktree opened by a host connection, optionally shared with guests.
struct Worktree {
    /// Connection that opened the worktree; buffer requests are forwarded to it.
    host_connection_id: ConnectionId,
    /// Users (never including the host) named as collaborators when the worktree was opened.
    collaborator_user_ids: Vec<UserId>,
    /// Name of the worktree's root, shown to collaborators.
    root_name: String,
    /// `Some` while the worktree is shared, `None` otherwise.
    share: Option<WorktreeShare>,
}
  70
/// State that exists only while a worktree is shared.
struct WorktreeShare {
    /// Guest connections and the replica id assigned to each.
    guest_connection_ids: HashMap<ConnectionId, ReplicaId>,
    /// Replica ids currently held by guests.
    active_replica_ids: HashSet<ReplicaId>,
    /// Latest known worktree entries, keyed by entry id.
    entries: HashMap<u64, proto::Entry>,
}
  76
/// Live membership of a chat channel.
#[derive(Default)]
struct Channel {
    /// Connections currently joined to the channel.
    connection_ids: HashSet<ConnectionId>,
}
  81
/// Number of messages returned per page of channel history.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
/// Maximum length, in bytes, of a channel message body after trimming.
const MAX_MESSAGE_LEN: usize = 1024;
  84
  85impl Server {
  86    pub fn new(
  87        app_state: Arc<AppState>,
  88        peer: Arc<Peer>,
  89        notifications: Option<mpsc::Sender<()>>,
  90    ) -> Arc<Self> {
  91        let mut server = Self {
  92            peer,
  93            app_state,
  94            state: Default::default(),
  95            handlers: Default::default(),
  96            notifications,
  97        };
  98
  99        server
 100            .add_handler(Server::ping)
 101            .add_handler(Server::open_worktree)
 102            .add_handler(Server::handle_close_worktree)
 103            .add_handler(Server::share_worktree)
 104            .add_handler(Server::unshare_worktree)
 105            .add_handler(Server::join_worktree)
 106            .add_handler(Server::update_worktree)
 107            .add_handler(Server::open_buffer)
 108            .add_handler(Server::close_buffer)
 109            .add_handler(Server::update_buffer)
 110            .add_handler(Server::buffer_saved)
 111            .add_handler(Server::save_buffer)
 112            .add_handler(Server::get_channels)
 113            .add_handler(Server::get_users)
 114            .add_handler(Server::join_channel)
 115            .add_handler(Server::leave_channel)
 116            .add_handler(Server::send_channel_message)
 117            .add_handler(Server::get_channel_messages);
 118
 119        Arc::new(server)
 120    }
 121
 122    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 123    where
 124        F: 'static + Send + Sync + Fn(Arc<Self>, TypedEnvelope<M>) -> Fut,
 125        Fut: 'static + Send + Future<Output = tide::Result<()>>,
 126        M: EnvelopedMessage,
 127    {
 128        let prev_handler = self.handlers.insert(
 129            TypeId::of::<M>(),
 130            Box::new(move |server, envelope| {
 131                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 132                (handler)(server, *envelope).boxed()
 133            }),
 134        );
 135        if prev_handler.is_some() {
 136            panic!("registered a handler for the same message twice");
 137        }
 138        self
 139    }
 140
    /// Drives one client connection until it closes or its I/O fails.
    ///
    /// The connection is registered with the peer and with the server state.
    /// The loop then runs as follows:
    /// - Each incoming message is dispatched to the handler registered for
    ///   its payload type.
    /// - Handlers are awaited inline, so messages from one connection are
    ///   processed one at a time.
    /// - `notifications` is signalled after every dispatched message, whether
    ///   its handler succeeded or failed.
    ///
    /// The loop stops when the incoming stream ends or the connection's I/O
    /// future completes, and the connection is then signed out. `addr` is
    /// used only in log output.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        addr: String,
        user_id: UserId,
    ) -> impl Future<Output = ()> {
        // Clone the `Arc` so the returned future does not borrow `self`.
        let this = self.clone();
        async move {
            let (connection_id, handle_io, mut incoming_rx) =
                this.peer.add_connection(connection).await;
            this.add_connection(connection_id, user_id).await;

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);
            loop {
                let next_message = incoming_rx.recv().fuse();
                futures::pin_mut!(next_message);
                // `select_biased!` polls branches in the order written, so a
                // ready message is handled before I/O completion is observed.
                futures::select_biased! {
                    message = next_message => {
                        if let Some(message) = message {
                            let start_time = Instant::now();
                            log::info!("RPC message received: {}", message.payload_type_name());
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                if let Err(err) = (handler)(this.clone(), message).await {
                                    log::error!("error handling message: {:?}", err);
                                } else {
                                    log::info!("RPC message handled. duration:{:?}", start_time.elapsed());
                                }

                                // Best-effort signal; a closed receiver is ignored.
                                if let Some(mut notifications) = this.notifications.clone() {
                                    let _ = notifications.send(()).await;
                                }
                            } else {
                                log::warn!("unhandled message: {}", message.payload_type_name());
                            }
                        } else {
                            // Incoming stream ended: the client went away.
                            log::info!("rpc connection closed {:?}", addr);
                            break;
                        }
                    }
                    handle_io = handle_io => {
                        if let Err(err) = handle_io {
                            log::error!("error handling rpc connection {:?} - {:?}", addr, err);
                        }
                        break;
                    }
                }
            }

            if let Err(err) = this.sign_out(connection_id).await {
                log::error!("error signing out connection {:?} - {:?}", addr, err);
            }
        }
    }
 195
 196    async fn sign_out(self: &Arc<Self>, connection_id: zrpc::ConnectionId) -> tide::Result<()> {
 197        self.peer.disconnect(connection_id).await;
 198        self.remove_connection(connection_id).await?;
 199        Ok(())
 200    }
 201
 202    // Add a new connection associated with a given user.
 203    async fn add_connection(&self, connection_id: ConnectionId, user_id: UserId) {
 204        let mut state = self.state.write().await;
 205        state.connections.insert(
 206            connection_id,
 207            ConnectionState {
 208                user_id,
 209                worktrees: Default::default(),
 210                channels: Default::default(),
 211            },
 212        );
 213        state
 214            .connections_by_user_id
 215            .entry(user_id)
 216            .or_default()
 217            .insert(connection_id);
 218    }
 219
 220    // Remove the given connection and its association with any worktrees.
 221    async fn remove_connection(
 222        self: &Arc<Server>,
 223        connection_id: ConnectionId,
 224    ) -> tide::Result<()> {
 225        let mut worktree_ids = Vec::new();
 226        let mut state = self.state.write().await;
 227        if let Some(connection) = state.connections.remove(&connection_id) {
 228            worktree_ids = connection.worktrees.into_iter().collect();
 229
 230            for channel_id in connection.channels {
 231                if let Some(channel) = state.channels.get_mut(&channel_id) {
 232                    channel.connection_ids.remove(&connection_id);
 233                }
 234            }
 235
 236            let user_connections = state
 237                .connections_by_user_id
 238                .get_mut(&connection.user_id)
 239                .unwrap();
 240            user_connections.remove(&connection_id);
 241            if user_connections.is_empty() {
 242                state.connections_by_user_id.remove(&connection.user_id);
 243            }
 244        }
 245
 246        for worktree_id in worktree_ids {
 247            self.close_worktree(worktree_id, connection_id).await?;
 248        }
 249
 250        Ok(())
 251    }
 252
 253    async fn ping(self: Arc<Server>, request: TypedEnvelope<proto::Ping>) -> tide::Result<()> {
 254        self.peer.respond(request.receipt(), proto::Ack {}).await?;
 255        Ok(())
 256    }
 257
    /// Handles `OpenWorktree`: registers a new, unshared worktree hosted by the
    /// sender and replies with its id.
    ///
    /// Each login in `collaborator_logins` is resolved to a user id via
    /// `db.create_user(login, false)`, and the host's own id is filtered out.
    /// If any lookup fails, the error text is sent back to the client and no
    /// worktree is created. After the response, collaborator lists are
    /// refreshed for the collaborators and the host.
    async fn open_worktree(
        self: Arc<Server>,
        request: TypedEnvelope<proto::OpenWorktree>,
    ) -> tide::Result<()> {
        let receipt = request.receipt();
        let host_user_id = self
            .state
            .read()
            .await
            .user_id_for_connection(request.sender_id)?;

        let mut collaborator_user_ids = Vec::new();
        for github_login in request.payload.collaborator_logins {
            match self.app_state.db.create_user(&github_login, false).await {
                Ok(collaborator_user_id) => {
                    // The host is tracked separately via `host_connection_id`.
                    if collaborator_user_id != host_user_id {
                        collaborator_user_ids.push(collaborator_user_id);
                    }
                }
                Err(err) => {
                    let message = err.to_string();
                    self.peer
                        .respond_with_error(receipt, proto::Error { message })
                        .await?;
                    return Ok(());
                }
            }
        }

        let worktree_id;
        let mut user_ids;
        {
            // Scoped so the write lock is released before responding and
            // before `update_collaborators_for_users` takes a read lock.
            let mut state = self.state.write().await;
            worktree_id = state.add_worktree(Worktree {
                host_connection_id: request.sender_id,
                collaborator_user_ids: collaborator_user_ids.clone(),
                root_name: request.payload.root_name,
                share: None,
            });
            user_ids = collaborator_user_ids;
            user_ids.push(host_user_id);
        }

        self.peer
            .respond(receipt, proto::OpenWorktreeResponse { worktree_id })
            .await?;
        self.update_collaborators_for_users(&user_ids).await?;

        Ok(())
    }
 308
 309    async fn share_worktree(
 310        self: Arc<Server>,
 311        mut request: TypedEnvelope<proto::ShareWorktree>,
 312    ) -> tide::Result<()> {
 313        let host_user_id = self
 314            .state
 315            .read()
 316            .await
 317            .user_id_for_connection(request.sender_id)?;
 318        let worktree = request
 319            .payload
 320            .worktree
 321            .as_mut()
 322            .ok_or_else(|| anyhow!("missing worktree"))?;
 323        let entries = mem::take(&mut worktree.entries)
 324            .into_iter()
 325            .map(|entry| (entry.id, entry))
 326            .collect();
 327
 328        let mut state = self.state.write().await;
 329        if let Some(worktree) = state.worktrees.get_mut(&worktree.id) {
 330            worktree.share = Some(WorktreeShare {
 331                guest_connection_ids: Default::default(),
 332                active_replica_ids: Default::default(),
 333                entries,
 334            });
 335
 336            let mut user_ids = worktree.collaborator_user_ids.clone();
 337            user_ids.push(host_user_id);
 338
 339            drop(state);
 340            self.peer
 341                .respond(request.receipt(), proto::ShareWorktreeResponse {})
 342                .await?;
 343            self.update_collaborators_for_users(&user_ids).await?;
 344        } else {
 345            self.peer
 346                .respond_with_error(
 347                    request.receipt(),
 348                    proto::Error {
 349                        message: "no such worktree".to_string(),
 350                    },
 351                )
 352                .await?;
 353        }
 354        Ok(())
 355    }
 356
    /// Handles `UnshareWorktree` from the worktree's host.
    ///
    /// The handler:
    /// - drops the share state,
    /// - detaches the worktree from every connection returned by
    ///   `Worktree::connection_ids`,
    /// - sends `UnshareWorktree` to those connections other than the sender,
    /// - refreshes collaborator lists.
    ///
    /// Returns an error ("no such worktree") if the sender is not the host.
    async fn unshare_worktree(
        self: Arc<Server>,
        request: TypedEnvelope<proto::UnshareWorktree>,
    ) -> tide::Result<()> {
        let worktree_id = request.payload.worktree_id;
        let host_user_id = self
            .state
            .read()
            .await
            .user_id_for_connection(request.sender_id)?;

        let connection_ids;
        let mut user_ids;
        {
            // Scoped so the write lock is released before broadcasting.
            let mut state = self.state.write().await;
            let worktree = state.write_worktree(worktree_id, request.sender_id)?;
            if worktree.host_connection_id != request.sender_id {
                return Err(anyhow!("no such worktree"))?;
            }

            connection_ids = worktree.connection_ids();
            user_ids = worktree.collaborator_user_ids.clone();
            user_ids.push(host_user_id);
            worktree.share.take();
            // NOTE(review): if `connection_ids()` includes the host, this also
            // removes the worktree from the host's own `ConnectionState::worktrees`.
            // `remove_connection` would then no longer close it when the host
            // disconnects. Confirm this is intended.
            for connection_id in &connection_ids {
                if let Some(connection) = state.connections.get_mut(connection_id) {
                    connection.worktrees.remove(&worktree_id);
                }
            }
        }

        broadcast(request.sender_id, connection_ids, |conn_id| {
            self.peer
                .send(conn_id, proto::UnshareWorktree { worktree_id })
        })
        .await?;
        self.update_collaborators_for_users(&user_ids).await?;

        Ok(())
    }
 397
 398    async fn join_worktree(
 399        self: Arc<Server>,
 400        request: TypedEnvelope<proto::JoinWorktree>,
 401    ) -> tide::Result<()> {
 402        let worktree_id = request.payload.worktree_id;
 403        let user_id = self
 404            .state
 405            .read()
 406            .await
 407            .user_id_for_connection(request.sender_id)?;
 408
 409        let response;
 410        let connection_ids;
 411        let mut user_ids;
 412        let mut state = self.state.write().await;
 413        match state.join_worktree(request.sender_id, user_id, worktree_id) {
 414            Ok((peer_replica_id, worktree)) => {
 415                let share = worktree.share()?;
 416                let peer_count = share.guest_connection_ids.len();
 417                let mut peers = Vec::with_capacity(peer_count);
 418                peers.push(proto::Peer {
 419                    peer_id: worktree.host_connection_id.0,
 420                    replica_id: 0,
 421                });
 422                for (peer_conn_id, peer_replica_id) in &share.guest_connection_ids {
 423                    if *peer_conn_id != request.sender_id {
 424                        peers.push(proto::Peer {
 425                            peer_id: peer_conn_id.0,
 426                            replica_id: *peer_replica_id as u32,
 427                        });
 428                    }
 429                }
 430                response = proto::JoinWorktreeResponse {
 431                    worktree: Some(proto::Worktree {
 432                        id: worktree_id,
 433                        root_name: worktree.root_name.clone(),
 434                        entries: share.entries.values().cloned().collect(),
 435                    }),
 436                    replica_id: peer_replica_id as u32,
 437                    peers,
 438                };
 439
 440                let host_connection_id = worktree.host_connection_id;
 441                connection_ids = worktree.connection_ids();
 442                user_ids = worktree.collaborator_user_ids.clone();
 443                user_ids.push(state.user_id_for_connection(host_connection_id)?);
 444            }
 445            Err(error) => {
 446                self.peer
 447                    .respond_with_error(
 448                        request.receipt(),
 449                        proto::Error {
 450                            message: error.to_string(),
 451                        },
 452                    )
 453                    .await?;
 454                return Ok(());
 455            }
 456        }
 457
 458        broadcast(request.sender_id, connection_ids, |conn_id| {
 459            self.peer.send(
 460                conn_id,
 461                proto::AddPeer {
 462                    worktree_id,
 463                    peer: Some(proto::Peer {
 464                        peer_id: request.sender_id.0,
 465                        replica_id: response.replica_id,
 466                    }),
 467                },
 468            )
 469        })
 470        .await?;
 471        self.peer.respond(request.receipt(), response).await?;
 472        self.update_collaborators_for_users(&user_ids).await?;
 473
 474        Ok(())
 475    }
 476
 477    async fn handle_close_worktree(
 478        self: Arc<Server>,
 479        request: TypedEnvelope<proto::CloseWorktree>,
 480    ) -> tide::Result<()> {
 481        self.close_worktree(request.payload.worktree_id, request.sender_id)
 482            .await
 483    }
 484
 485    async fn close_worktree(
 486        self: &Arc<Server>,
 487        worktree_id: u64,
 488        conn_id: ConnectionId,
 489    ) -> tide::Result<()> {
 490        let connection_ids;
 491        let mut user_ids;
 492
 493        let mut is_host = false;
 494        let mut is_guest = false;
 495        {
 496            let mut state = self.state.write().await;
 497            let worktree = state.write_worktree(worktree_id, conn_id)?;
 498            let host_connection_id = worktree.host_connection_id;
 499            connection_ids = worktree.connection_ids();
 500            user_ids = worktree.collaborator_user_ids.clone();
 501
 502            if worktree.host_connection_id == conn_id {
 503                is_host = true;
 504                state.remove_worktree(worktree_id);
 505            } else {
 506                let share = worktree.share_mut()?;
 507                if let Some(replica_id) = share.guest_connection_ids.remove(&conn_id) {
 508                    is_guest = true;
 509                    share.active_replica_ids.remove(&replica_id);
 510                }
 511            }
 512
 513            user_ids.push(state.user_id_for_connection(host_connection_id)?);
 514        }
 515
 516        if is_host {
 517            broadcast(conn_id, connection_ids, |conn_id| {
 518                self.peer
 519                    .send(conn_id, proto::UnshareWorktree { worktree_id })
 520            })
 521            .await?;
 522        } else if is_guest {
 523            broadcast(conn_id, connection_ids, |conn_id| {
 524                self.peer.send(
 525                    conn_id,
 526                    proto::RemovePeer {
 527                        worktree_id,
 528                        peer_id: conn_id.0,
 529                    },
 530                )
 531            })
 532            .await?
 533        }
 534        self.update_collaborators_for_users(&user_ids).await?;
 535        Ok(())
 536    }
 537
 538    async fn update_worktree(
 539        self: Arc<Server>,
 540        request: TypedEnvelope<proto::UpdateWorktree>,
 541    ) -> tide::Result<()> {
 542        {
 543            let mut state = self.state.write().await;
 544            let worktree = state.write_worktree(request.payload.worktree_id, request.sender_id)?;
 545            let share = worktree.share_mut()?;
 546
 547            for entry_id in &request.payload.removed_entries {
 548                share.entries.remove(&entry_id);
 549            }
 550
 551            for entry in &request.payload.updated_entries {
 552                share.entries.insert(entry.id, entry.clone());
 553            }
 554        }
 555
 556        self.broadcast_in_worktree(request.payload.worktree_id, &request)
 557            .await?;
 558        Ok(())
 559    }
 560
 561    async fn open_buffer(
 562        self: Arc<Server>,
 563        request: TypedEnvelope<proto::OpenBuffer>,
 564    ) -> tide::Result<()> {
 565        let receipt = request.receipt();
 566        let worktree_id = request.payload.worktree_id;
 567        let host_connection_id = self
 568            .state
 569            .read()
 570            .await
 571            .read_worktree(worktree_id, request.sender_id)?
 572            .host_connection_id;
 573
 574        let response = self
 575            .peer
 576            .forward_request(request.sender_id, host_connection_id, request.payload)
 577            .await?;
 578        self.peer.respond(receipt, response).await?;
 579        Ok(())
 580    }
 581
 582    async fn close_buffer(
 583        self: Arc<Server>,
 584        request: TypedEnvelope<proto::CloseBuffer>,
 585    ) -> tide::Result<()> {
 586        let host_connection_id = self
 587            .state
 588            .read()
 589            .await
 590            .read_worktree(request.payload.worktree_id, request.sender_id)?
 591            .host_connection_id;
 592
 593        self.peer
 594            .forward_send(request.sender_id, host_connection_id, request.payload)
 595            .await?;
 596
 597        Ok(())
 598    }
 599
    /// Handles `SaveBuffer`.
    ///
    /// The request is forwarded to the worktree's host, and the host's
    /// response is then delivered to every guest of the shared worktree:
    /// - the sender receives it as the reply to its original request;
    /// - every other guest receives it as a message forwarded from the host.
    ///
    /// Fails if the worktree is not currently shared.
    async fn save_buffer(
        self: Arc<Server>,
        request: TypedEnvelope<proto::SaveBuffer>,
    ) -> tide::Result<()> {
        let host;
        let guests;
        {
            // Snapshot host and guests under a short read lock.
            let state = self.state.read().await;
            let worktree = state.read_worktree(request.payload.worktree_id, request.sender_id)?;
            host = worktree.host_connection_id;
            guests = worktree
                .share()?
                .guest_connection_ids
                .keys()
                .copied()
                .collect::<Vec<_>>();
        }

        let sender = request.sender_id;
        let receipt = request.receipt();
        let response = self
            .peer
            .forward_request(sender, host, request.payload.clone())
            .await?;

        // NOTE(review): `host` is passed as the first argument of `broadcast`
        // (presumably the connection to skip); the host is not in `guests` anyway.
        broadcast(host, guests, |conn_id| {
            let response = response.clone();
            let peer = &self.peer;
            async move {
                if conn_id == sender {
                    peer.respond(receipt, response).await
                } else {
                    peer.forward_send(host, conn_id, response).await
                }
            }
        })
        .await?;

        Ok(())
    }
 640
 641    async fn update_buffer(
 642        self: Arc<Server>,
 643        request: TypedEnvelope<proto::UpdateBuffer>,
 644    ) -> tide::Result<()> {
 645        self.broadcast_in_worktree(request.payload.worktree_id, &request)
 646            .await?;
 647        self.peer.respond(request.receipt(), proto::Ack {}).await?;
 648        Ok(())
 649    }
 650
 651    async fn buffer_saved(
 652        self: Arc<Server>,
 653        request: TypedEnvelope<proto::BufferSaved>,
 654    ) -> tide::Result<()> {
 655        self.broadcast_in_worktree(request.payload.worktree_id, &request)
 656            .await
 657    }
 658
 659    async fn get_channels(
 660        self: Arc<Server>,
 661        request: TypedEnvelope<proto::GetChannels>,
 662    ) -> tide::Result<()> {
 663        let user_id = self
 664            .state
 665            .read()
 666            .await
 667            .user_id_for_connection(request.sender_id)?;
 668        let channels = self.app_state.db.get_accessible_channels(user_id).await?;
 669        self.peer
 670            .respond(
 671                request.receipt(),
 672                proto::GetChannelsResponse {
 673                    channels: channels
 674                        .into_iter()
 675                        .map(|chan| proto::Channel {
 676                            id: chan.id.to_proto(),
 677                            name: chan.name,
 678                        })
 679                        .collect(),
 680                },
 681            )
 682            .await?;
 683        Ok(())
 684    }
 685
 686    async fn get_users(
 687        self: Arc<Server>,
 688        request: TypedEnvelope<proto::GetUsers>,
 689    ) -> tide::Result<()> {
 690        let user_id = self
 691            .state
 692            .read()
 693            .await
 694            .user_id_for_connection(request.sender_id)?;
 695        let receipt = request.receipt();
 696        let user_ids = request.payload.user_ids.into_iter().map(UserId::from_proto);
 697        let users = self
 698            .app_state
 699            .db
 700            .get_users_by_ids(user_id, user_ids)
 701            .await?
 702            .into_iter()
 703            .map(|user| proto::User {
 704                id: user.id.to_proto(),
 705                avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
 706                github_login: user.github_login,
 707            })
 708            .collect();
 709        self.peer
 710            .respond(receipt, proto::GetUsersResponse { users })
 711            .await?;
 712        Ok(())
 713    }
 714
    /// Sends each given user an `UpdateCollaborators` message on every one of
    /// their connections.
    ///
    /// For each user, the worktrees in `visible_worktrees_by_user_id` are
    /// grouped by host user. Each group lists, per worktree:
    /// - the root name,
    /// - whether the worktree is shared,
    /// - the user ids of its current guests, in unspecified order because they
    ///   are collected from a `HashSet`.
    ///
    /// All messages are built under a single read lock. The lock is released
    /// before the sends are awaited together. If any user lookup fails, `?`
    /// returns before anything is sent.
    async fn update_collaborators_for_users<'a>(
        self: &Arc<Server>,
        user_ids: impl IntoIterator<Item = &'a UserId>,
    ) -> tide::Result<()> {
        let mut send_futures = Vec::new();

        let state = self.state.read().await;
        for user_id in user_ids {
            let mut collaborators = HashMap::new();
            for worktree_id in state
                .visible_worktrees_by_user_id
                .get(&user_id)
                .unwrap_or(&HashSet::new())
            {
                // NOTE(review): indexing panics if the visibility index still
                // references a removed worktree. This relies on
                // `remove_worktree` keeping both maps in sync (not shown here).
                let worktree = &state.worktrees[worktree_id];

                let mut participants = HashSet::new();
                if let Ok(share) = worktree.share() {
                    for guest_connection_id in share.guest_connection_ids.keys() {
                        let user_id = state.user_id_for_connection(*guest_connection_id)?;
                        participants.insert(user_id.to_proto());
                    }
                }

                let host_user_id = state.user_id_for_connection(worktree.host_connection_id)?;
                let host =
                    collaborators
                        .entry(host_user_id)
                        .or_insert_with(|| proto::Collaborator {
                            user_id: host_user_id.to_proto(),
                            worktrees: Vec::new(),
                        });
                host.worktrees.push(proto::WorktreeMetadata {
                    root_name: worktree.root_name.clone(),
                    is_shared: worktree.share().is_ok(),
                    participants: participants.into_iter().collect(),
                });
            }

            let collaborators = collaborators.into_values().collect::<Vec<_>>();
            for connection_id in state.user_connection_ids(*user_id) {
                send_futures.push(self.peer.send(
                    connection_id,
                    proto::UpdateCollaborators {
                        collaborators: collaborators.clone(),
                    },
                ));
            }
        }

        // Release the lock before awaiting network sends.
        drop(state);
        futures::future::try_join_all(send_futures).await?;

        Ok(())
    }
 770
 771    async fn join_channel(
 772        self: Arc<Self>,
 773        request: TypedEnvelope<proto::JoinChannel>,
 774    ) -> tide::Result<()> {
 775        let user_id = self
 776            .state
 777            .read()
 778            .await
 779            .user_id_for_connection(request.sender_id)?;
 780        let channel_id = ChannelId::from_proto(request.payload.channel_id);
 781        if !self
 782            .app_state
 783            .db
 784            .can_user_access_channel(user_id, channel_id)
 785            .await?
 786        {
 787            Err(anyhow!("access denied"))?;
 788        }
 789
 790        self.state
 791            .write()
 792            .await
 793            .join_channel(request.sender_id, channel_id);
 794        let messages = self
 795            .app_state
 796            .db
 797            .get_channel_messages(channel_id, MESSAGE_COUNT_PER_PAGE, None)
 798            .await?
 799            .into_iter()
 800            .map(|msg| proto::ChannelMessage {
 801                id: msg.id.to_proto(),
 802                body: msg.body,
 803                timestamp: msg.sent_at.unix_timestamp() as u64,
 804                sender_id: msg.sender_id.to_proto(),
 805                nonce: Some(msg.nonce.as_u128().into()),
 806            })
 807            .collect::<Vec<_>>();
 808        self.peer
 809            .respond(
 810                request.receipt(),
 811                proto::JoinChannelResponse {
 812                    done: messages.len() < MESSAGE_COUNT_PER_PAGE,
 813                    messages,
 814                },
 815            )
 816            .await?;
 817        Ok(())
 818    }
 819
 820    async fn leave_channel(
 821        self: Arc<Self>,
 822        request: TypedEnvelope<proto::LeaveChannel>,
 823    ) -> tide::Result<()> {
 824        let user_id = self
 825            .state
 826            .read()
 827            .await
 828            .user_id_for_connection(request.sender_id)?;
 829        let channel_id = ChannelId::from_proto(request.payload.channel_id);
 830        if !self
 831            .app_state
 832            .db
 833            .can_user_access_channel(user_id, channel_id)
 834            .await?
 835        {
 836            Err(anyhow!("access denied"))?;
 837        }
 838
 839        self.state
 840            .write()
 841            .await
 842            .leave_channel(request.sender_id, channel_id);
 843
 844        Ok(())
 845    }
 846
 847    async fn send_channel_message(
 848        self: Arc<Self>,
 849        request: TypedEnvelope<proto::SendChannelMessage>,
 850    ) -> tide::Result<()> {
 851        let receipt = request.receipt();
 852        let channel_id = ChannelId::from_proto(request.payload.channel_id);
 853        let user_id;
 854        let connection_ids;
 855        {
 856            let state = self.state.read().await;
 857            user_id = state.user_id_for_connection(request.sender_id)?;
 858            if let Some(channel) = state.channels.get(&channel_id) {
 859                connection_ids = channel.connection_ids();
 860            } else {
 861                return Ok(());
 862            }
 863        }
 864
 865        // Validate the message body.
 866        let body = request.payload.body.trim().to_string();
 867        if body.len() > MAX_MESSAGE_LEN {
 868            self.peer
 869                .respond_with_error(
 870                    receipt,
 871                    proto::Error {
 872                        message: "message is too long".to_string(),
 873                    },
 874                )
 875                .await?;
 876            return Ok(());
 877        }
 878        if body.is_empty() {
 879            self.peer
 880                .respond_with_error(
 881                    receipt,
 882                    proto::Error {
 883                        message: "message can't be blank".to_string(),
 884                    },
 885                )
 886                .await?;
 887            return Ok(());
 888        }
 889
 890        let timestamp = OffsetDateTime::now_utc();
 891        let nonce = if let Some(nonce) = request.payload.nonce {
 892            nonce
 893        } else {
 894            self.peer
 895                .respond_with_error(
 896                    receipt,
 897                    proto::Error {
 898                        message: "nonce can't be blank".to_string(),
 899                    },
 900                )
 901                .await?;
 902            return Ok(());
 903        };
 904
 905        let message_id = self
 906            .app_state
 907            .db
 908            .create_channel_message(channel_id, user_id, &body, timestamp, nonce.clone().into())
 909            .await?
 910            .to_proto();
 911        let message = proto::ChannelMessage {
 912            sender_id: user_id.to_proto(),
 913            id: message_id,
 914            body,
 915            timestamp: timestamp.unix_timestamp() as u64,
 916            nonce: Some(nonce),
 917        };
 918        broadcast(request.sender_id, connection_ids, |conn_id| {
 919            self.peer.send(
 920                conn_id,
 921                proto::ChannelMessageSent {
 922                    channel_id: channel_id.to_proto(),
 923                    message: Some(message.clone()),
 924                },
 925            )
 926        })
 927        .await?;
 928        self.peer
 929            .respond(
 930                receipt,
 931                proto::SendChannelMessageResponse {
 932                    message: Some(message),
 933                },
 934            )
 935            .await?;
 936        Ok(())
 937    }
 938
 939    async fn get_channel_messages(
 940        self: Arc<Self>,
 941        request: TypedEnvelope<proto::GetChannelMessages>,
 942    ) -> tide::Result<()> {
 943        let user_id = self
 944            .state
 945            .read()
 946            .await
 947            .user_id_for_connection(request.sender_id)?;
 948        let channel_id = ChannelId::from_proto(request.payload.channel_id);
 949        if !self
 950            .app_state
 951            .db
 952            .can_user_access_channel(user_id, channel_id)
 953            .await?
 954        {
 955            Err(anyhow!("access denied"))?;
 956        }
 957
 958        let messages = self
 959            .app_state
 960            .db
 961            .get_channel_messages(
 962                channel_id,
 963                MESSAGE_COUNT_PER_PAGE,
 964                Some(MessageId::from_proto(request.payload.before_message_id)),
 965            )
 966            .await?
 967            .into_iter()
 968            .map(|msg| proto::ChannelMessage {
 969                id: msg.id.to_proto(),
 970                body: msg.body,
 971                timestamp: msg.sent_at.unix_timestamp() as u64,
 972                sender_id: msg.sender_id.to_proto(),
 973                nonce: Some(msg.nonce.as_u128().into()),
 974            })
 975            .collect::<Vec<_>>();
 976        self.peer
 977            .respond(
 978                request.receipt(),
 979                proto::GetChannelMessagesResponse {
 980                    done: messages.len() < MESSAGE_COUNT_PER_PAGE,
 981                    messages,
 982                },
 983            )
 984            .await?;
 985        Ok(())
 986    }
 987
 988    async fn broadcast_in_worktree<T: proto::EnvelopedMessage>(
 989        &self,
 990        worktree_id: u64,
 991        message: &TypedEnvelope<T>,
 992    ) -> tide::Result<()> {
 993        let connection_ids = self
 994            .state
 995            .read()
 996            .await
 997            .read_worktree(worktree_id, message.sender_id)?
 998            .connection_ids();
 999
1000        broadcast(message.sender_id, connection_ids, |conn_id| {
1001            self.peer
1002                .forward_send(message.sender_id, conn_id, message.payload.clone())
1003        })
1004        .await?;
1005
1006        Ok(())
1007    }
1008}
1009
1010pub async fn broadcast<F, T>(
1011    sender_id: ConnectionId,
1012    receiver_ids: Vec<ConnectionId>,
1013    mut f: F,
1014) -> anyhow::Result<()>
1015where
1016    F: FnMut(ConnectionId) -> T,
1017    T: Future<Output = anyhow::Result<()>>,
1018{
1019    let futures = receiver_ids
1020        .into_iter()
1021        .filter(|id| *id != sender_id)
1022        .map(|id| f(id));
1023    futures::future::try_join_all(futures).await?;
1024    Ok(())
1025}
1026
1027impl ServerState {
1028    fn join_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
1029        if let Some(connection) = self.connections.get_mut(&connection_id) {
1030            connection.channels.insert(channel_id);
1031            self.channels
1032                .entry(channel_id)
1033                .or_default()
1034                .connection_ids
1035                .insert(connection_id);
1036        }
1037    }
1038
1039    fn leave_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
1040        if let Some(connection) = self.connections.get_mut(&connection_id) {
1041            connection.channels.remove(&channel_id);
1042            if let hash_map::Entry::Occupied(mut entry) = self.channels.entry(channel_id) {
1043                entry.get_mut().connection_ids.remove(&connection_id);
1044                if entry.get_mut().connection_ids.is_empty() {
1045                    entry.remove();
1046                }
1047            }
1048        }
1049    }
1050
1051    fn user_id_for_connection(&self, connection_id: ConnectionId) -> tide::Result<UserId> {
1052        Ok(self
1053            .connections
1054            .get(&connection_id)
1055            .ok_or_else(|| anyhow!("unknown connection"))?
1056            .user_id)
1057    }
1058
1059    fn user_connection_ids<'a>(
1060        &'a self,
1061        user_id: UserId,
1062    ) -> impl 'a + Iterator<Item = ConnectionId> {
1063        self.connections_by_user_id
1064            .get(&user_id)
1065            .into_iter()
1066            .flatten()
1067            .copied()
1068    }
1069
1070    // Add the given connection as a guest of the given worktree
1071    fn join_worktree(
1072        &mut self,
1073        connection_id: ConnectionId,
1074        user_id: UserId,
1075        worktree_id: u64,
1076    ) -> tide::Result<(ReplicaId, &Worktree)> {
1077        let connection = self
1078            .connections
1079            .get_mut(&connection_id)
1080            .ok_or_else(|| anyhow!("no such connection"))?;
1081        let worktree = self
1082            .worktrees
1083            .get_mut(&worktree_id)
1084            .ok_or_else(|| anyhow!("no such worktree"))?;
1085        if !worktree.collaborator_user_ids.contains(&user_id) {
1086            Err(anyhow!("no such worktree"))?;
1087        }
1088
1089        let share = worktree.share_mut()?;
1090        connection.worktrees.insert(worktree_id);
1091
1092        let mut replica_id = 1;
1093        while share.active_replica_ids.contains(&replica_id) {
1094            replica_id += 1;
1095        }
1096        share.active_replica_ids.insert(replica_id);
1097        share.guest_connection_ids.insert(connection_id, replica_id);
1098        return Ok((replica_id, worktree));
1099    }
1100
1101    fn read_worktree(
1102        &self,
1103        worktree_id: u64,
1104        connection_id: ConnectionId,
1105    ) -> tide::Result<&Worktree> {
1106        let worktree = self
1107            .worktrees
1108            .get(&worktree_id)
1109            .ok_or_else(|| anyhow!("worktree not found"))?;
1110
1111        if worktree.host_connection_id == connection_id
1112            || worktree
1113                .share()?
1114                .guest_connection_ids
1115                .contains_key(&connection_id)
1116        {
1117            Ok(worktree)
1118        } else {
1119            Err(anyhow!(
1120                "{} is not a member of worktree {}",
1121                connection_id,
1122                worktree_id
1123            ))?
1124        }
1125    }
1126
1127    fn write_worktree(
1128        &mut self,
1129        worktree_id: u64,
1130        connection_id: ConnectionId,
1131    ) -> tide::Result<&mut Worktree> {
1132        let worktree = self
1133            .worktrees
1134            .get_mut(&worktree_id)
1135            .ok_or_else(|| anyhow!("worktree not found"))?;
1136
1137        if worktree.host_connection_id == connection_id
1138            || worktree.share.as_ref().map_or(false, |share| {
1139                share.guest_connection_ids.contains_key(&connection_id)
1140            })
1141        {
1142            Ok(worktree)
1143        } else {
1144            Err(anyhow!(
1145                "{} is not a member of worktree {}",
1146                connection_id,
1147                worktree_id
1148            ))?
1149        }
1150    }
1151
1152    fn add_worktree(&mut self, worktree: Worktree) -> u64 {
1153        let worktree_id = self.next_worktree_id;
1154        for collaborator_user_id in &worktree.collaborator_user_ids {
1155            self.visible_worktrees_by_user_id
1156                .entry(*collaborator_user_id)
1157                .or_default()
1158                .insert(worktree_id);
1159        }
1160        self.next_worktree_id += 1;
1161        self.worktrees.insert(worktree_id, worktree);
1162        worktree_id
1163    }
1164
1165    fn remove_worktree(&mut self, worktree_id: u64) {
1166        let worktree = self.worktrees.remove(&worktree_id).unwrap();
1167        if let Some(connection) = self.connections.get_mut(&worktree.host_connection_id) {
1168            connection.worktrees.remove(&worktree_id);
1169        }
1170        if let Some(share) = worktree.share {
1171            for connection_id in share.guest_connection_ids.keys() {
1172                if let Some(connection) = self.connections.get_mut(connection_id) {
1173                    connection.worktrees.remove(&worktree_id);
1174                }
1175            }
1176        }
1177        for collaborator_user_id in worktree.collaborator_user_ids {
1178            if let Some(visible_worktrees) = self
1179                .visible_worktrees_by_user_id
1180                .get_mut(&collaborator_user_id)
1181            {
1182                visible_worktrees.remove(&worktree_id);
1183            }
1184        }
1185    }
1186}
1187
1188impl Worktree {
1189    pub fn connection_ids(&self) -> Vec<ConnectionId> {
1190        if let Some(share) = &self.share {
1191            share
1192                .guest_connection_ids
1193                .keys()
1194                .copied()
1195                .chain(Some(self.host_connection_id))
1196                .collect()
1197        } else {
1198            vec![self.host_connection_id]
1199        }
1200    }
1201
1202    fn share(&self) -> tide::Result<&WorktreeShare> {
1203        Ok(self
1204            .share
1205            .as_ref()
1206            .ok_or_else(|| anyhow!("worktree is not shared"))?)
1207    }
1208
1209    fn share_mut(&mut self) -> tide::Result<&mut WorktreeShare> {
1210        Ok(self
1211            .share
1212            .as_mut()
1213            .ok_or_else(|| anyhow!("worktree is not shared"))?)
1214    }
1215}
1216
1217impl Channel {
1218    fn connection_ids(&self) -> Vec<ConnectionId> {
1219        self.connection_ids.iter().copied().collect()
1220    }
1221}
1222
1223pub fn add_routes(app: &mut tide::Server<Arc<AppState>>, rpc: &Arc<Peer>) {
1224    let server = Server::new(app.state().clone(), rpc.clone(), None);
1225    app.at("/rpc").with(auth::VerifyToken).get(move |request: Request<Arc<AppState>>| {
1226        let user_id = request.ext::<UserId>().copied();
1227        let server = server.clone();
1228        async move {
1229            const WEBSOCKET_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
1230
1231            let connection_upgrade = header_contains_ignore_case(&request, CONNECTION, "upgrade");
1232            let upgrade_to_websocket = header_contains_ignore_case(&request, UPGRADE, "websocket");
1233            let upgrade_requested = connection_upgrade && upgrade_to_websocket;
1234
1235            if !upgrade_requested {
1236                return Ok(Response::new(StatusCode::UpgradeRequired));
1237            }
1238
1239            let header = match request.header("Sec-Websocket-Key") {
1240                Some(h) => h.as_str(),
1241                None => return Err(anyhow!("expected sec-websocket-key"))?,
1242            };
1243
1244            let mut response = Response::new(StatusCode::SwitchingProtocols);
1245            response.insert_header(UPGRADE, "websocket");
1246            response.insert_header(CONNECTION, "Upgrade");
1247            let hash = Sha1::new().chain(header).chain(WEBSOCKET_GUID).finalize();
1248            response.insert_header("Sec-Websocket-Accept", base64::encode(&hash[..]));
1249            response.insert_header("Sec-Websocket-Version", "13");
1250
1251            let http_res: &mut tide::http::Response = response.as_mut();
1252            let upgrade_receiver = http_res.recv_upgrade().await;
1253            let addr = request.remote().unwrap_or("unknown").to_string();
1254            let user_id = user_id.ok_or_else(|| anyhow!("user_id is not present on request. ensure auth::VerifyToken middleware is present"))?;
1255            task::spawn(async move {
1256                if let Some(stream) = upgrade_receiver.await {
1257                    server.handle_connection(Connection::new(WebSocketStream::from_raw_socket(stream, Role::Server, None).await), addr, user_id).await;
1258                }
1259            });
1260
1261            Ok(response)
1262        }
1263    });
1264}
1265
1266fn header_contains_ignore_case<T>(
1267    request: &tide::Request<T>,
1268    header_name: HeaderName,
1269    value: &str,
1270) -> bool {
1271    request
1272        .header(header_name)
1273        .map(|h| {
1274            h.as_str()
1275                .split(',')
1276                .any(|s| s.trim().eq_ignore_ascii_case(value.trim()))
1277        })
1278        .unwrap_or(false)
1279}
1280
1281#[cfg(test)]
1282mod tests {
1283    use super::*;
1284    use crate::{
1285        auth,
1286        db::{tests::TestDb, UserId},
1287        github, AppState, Config,
1288    };
1289    use async_std::{sync::RwLockReadGuard, task};
1290    use gpui::TestAppContext;
1291    use parking_lot::Mutex;
1292    use postage::{mpsc, watch};
1293    use serde_json::json;
1294    use sqlx::types::time::OffsetDateTime;
1295    use std::{
1296        path::Path,
1297        sync::{
1298            atomic::{AtomicBool, Ordering::SeqCst},
1299            Arc,
1300        },
1301        time::Duration,
1302    };
1303    use zed::{
1304        channel::{Channel, ChannelDetails, ChannelList},
1305        editor::{Editor, Insert},
1306        fs::{FakeFs, Fs as _},
1307        language::LanguageRegistry,
1308        rpc::{self, Client, Credentials, EstablishConnectionError},
1309        settings,
1310        test::FakeHttpClient,
1311        user::UserStore,
1312        worktree::Worktree,
1313    };
1314    use zrpc::Peer;
1315
    // End-to-end host/guest flow for a shared worktree: B joins A's worktree,
    // both open the same buffer, B's selections and edits reach A, and closing
    // the buffer / dropping B's worktree is reflected on A.
    #[gpui::test]
    async fn test_share_worktree(mut cx_a: TestAppContext, mut cx_b: TestAppContext) {
        let (window_b, _) = cx_b.add_window(|_| EmptyView);
        let settings = cx_b.read(settings::test).1;
        let lang_registry = Arc::new(LanguageRegistry::new());

        // Connect to a server as 2 clients.
        let mut server = TestServer::start().await;
        let (client_a, _) = server.create_client(&mut cx_a, "user_a").await;
        let (client_b, _) = server.create_client(&mut cx_b, "user_b").await;

        cx_a.foreground().forbid_parking();

        // Share a local worktree as client A. The `.zed.toml` lists user_b as a
        // collaborator so that B is allowed to join.
        let fs = Arc::new(FakeFs::new());
        fs.insert_tree(
            "/a",
            json!({
                ".zed.toml": r#"collaborators = ["user_b"]"#,
                "a.txt": "a-contents",
                "b.txt": "b-contents",
            }),
        )
        .await;
        let worktree_a = Worktree::open_local(
            client_a.clone(),
            "/a".as_ref(),
            fs,
            lang_registry.clone(),
            &mut cx_a.to_async(),
        )
        .await
        .unwrap();
        // Wait for the initial scan before sharing.
        worktree_a
            .read_with(&cx_a, |tree, _| tree.as_local().unwrap().scan_complete())
            .await;
        let worktree_id = worktree_a
            .update(&mut cx_a, |tree, cx| tree.as_local_mut().unwrap().share(cx))
            .await
            .unwrap();

        // Join that worktree as client B, and see that a guest has joined as client A.
        let worktree_b = Worktree::open_remote(
            client_b.clone(),
            worktree_id,
            lang_registry.clone(),
            &mut cx_b.to_async(),
        )
        .await
        .unwrap();
        // B's replica id must show up among A's peers.
        let replica_id_b = worktree_b.read_with(&cx_b, |tree, _| tree.replica_id());
        worktree_a
            .condition(&cx_a, |tree, _| {
                tree.peers()
                    .values()
                    .any(|replica_id| *replica_id == replica_id_b)
            })
            .await;

        // Open the same file as client B and client A.
        let buffer_b = worktree_b
            .update(&mut cx_b, |worktree, cx| worktree.open_buffer("b.txt", cx))
            .await
            .unwrap();
        buffer_b.read_with(&cx_b, |buf, _| assert_eq!(buf.text(), "b-contents"));
        // B opening the buffer makes it open on the host as well.
        worktree_a.read_with(&cx_a, |tree, cx| assert!(tree.has_open_buffer("b.txt", cx)));
        let buffer_a = worktree_a
            .update(&mut cx_a, |tree, cx| tree.open_buffer("b.txt", cx))
            .await
            .unwrap();

        // Create a selection set as client B and see that selection set as client A.
        let editor_b = cx_b.add_view(window_b, |cx| Editor::for_buffer(buffer_b, settings, cx));
        buffer_a
            .condition(&cx_a, |buffer, _| buffer.selection_sets().count() == 1)
            .await;

        // Edit the buffer as client B and see that edit as client A.
        editor_b.update(&mut cx_b, |editor, cx| {
            editor.insert(&Insert("ok, ".into()), cx)
        });
        buffer_a
            .condition(&cx_a, |buffer, _| buffer.text() == "ok, b-contents")
            .await;

        // Remove the selection set as client B, see those selections disappear as client A.
        cx_b.update(move |_| drop(editor_b));
        buffer_a
            .condition(&cx_a, |buffer, _| buffer.selection_sets().count() == 0)
            .await;

        // Close the buffer as client A, see that the buffer is closed.
        cx_a.update(move |_| drop(buffer_a));
        worktree_a
            .condition(&cx_a, |tree, cx| !tree.has_open_buffer("b.txt", cx))
            .await;

        // Dropping the worktree removes client B from client A's peers.
        cx_b.update(move |_| drop(worktree_b));
        worktree_a
            .condition(&cx_a, |tree, _| tree.peers().is_empty())
            .await;
    }
1419
    // Three-party collaboration: concurrent edits from two guests and the host
    // converge on every replica, a guest-initiated save writes the host's file
    // and leaves every replica clean, and host file-system changes (rename,
    // new file) propagate to both guests.
    #[gpui::test]
    async fn test_propagate_saves_and_fs_changes_in_shared_worktree(
        mut cx_a: TestAppContext,
        mut cx_b: TestAppContext,
        mut cx_c: TestAppContext,
    ) {
        cx_a.foreground().forbid_parking();
        let lang_registry = Arc::new(LanguageRegistry::new());

        // Connect to a server as 3 clients.
        let mut server = TestServer::start().await;
        let (client_a, _) = server.create_client(&mut cx_a, "user_a").await;
        let (client_b, _) = server.create_client(&mut cx_b, "user_b").await;
        let (client_c, _) = server.create_client(&mut cx_c, "user_c").await;

        // Kept (cloned below) so the test can later inspect and mutate the host's files.
        let fs = Arc::new(FakeFs::new());

        // Share a worktree as client A.
        fs.insert_tree(
            "/a",
            json!({
                ".zed.toml": r#"collaborators = ["user_b", "user_c"]"#,
                "file1": "",
                "file2": ""
            }),
        )
        .await;

        let worktree_a = Worktree::open_local(
            client_a.clone(),
            "/a".as_ref(),
            fs.clone(),
            lang_registry.clone(),
            &mut cx_a.to_async(),
        )
        .await
        .unwrap();
        worktree_a
            .read_with(&cx_a, |tree, _| tree.as_local().unwrap().scan_complete())
            .await;
        let worktree_id = worktree_a
            .update(&mut cx_a, |tree, cx| tree.as_local_mut().unwrap().share(cx))
            .await
            .unwrap();

        // Join that worktree as clients B and C.
        let worktree_b = Worktree::open_remote(
            client_b.clone(),
            worktree_id,
            lang_registry.clone(),
            &mut cx_b.to_async(),
        )
        .await
        .unwrap();
        let worktree_c = Worktree::open_remote(
            client_c.clone(),
            worktree_id,
            lang_registry.clone(),
            &mut cx_c.to_async(),
        )
        .await
        .unwrap();

        // Open and edit a buffer as both guests B and C.
        let buffer_b = worktree_b
            .update(&mut cx_b, |tree, cx| tree.open_buffer("file1", cx))
            .await
            .unwrap();
        let buffer_c = worktree_c
            .update(&mut cx_c, |tree, cx| tree.open_buffer("file1", cx))
            .await
            .unwrap();
        buffer_b.update(&mut cx_b, |buf, cx| buf.edit([0..0], "i-am-b, ", cx));
        buffer_c.update(&mut cx_c, |buf, cx| buf.edit([0..0], "i-am-c, ", cx));

        // Open and edit that buffer as the host.
        let buffer_a = worktree_a
            .update(&mut cx_a, |tree, cx| tree.open_buffer("file1", cx))
            .await
            .unwrap();

        // Both guests inserted at offset 0 concurrently; the expected merged
        // text places C's insertion before B's.
        buffer_a
            .condition(&mut cx_a, |buf, _| buf.text() == "i-am-c, i-am-b, ")
            .await;
        buffer_a.update(&mut cx_a, |buf, cx| {
            buf.edit([buf.len()..buf.len()], "i-am-a", cx)
        });

        // Wait for edits to propagate to every replica.
        buffer_a
            .condition(&mut cx_a, |buf, _| buf.text() == "i-am-c, i-am-b, i-am-a")
            .await;
        buffer_b
            .condition(&mut cx_b, |buf, _| buf.text() == "i-am-c, i-am-b, i-am-a")
            .await;
        buffer_c
            .condition(&mut cx_c, |buf, _| buf.text() == "i-am-c, i-am-b, i-am-a")
            .await;

        // Edit the buffer as the host and concurrently save as guest B. The
        // saved file includes the host's edit, and afterwards no replica is dirty.
        let save_b = buffer_b.update(&mut cx_b, |buf, cx| buf.save(cx).unwrap());
        buffer_a.update(&mut cx_a, |buf, cx| buf.edit([0..0], "hi-a, ", cx));
        save_b.await.unwrap();
        assert_eq!(
            fs.load("/a/file1".as_ref()).await.unwrap(),
            "hi-a, i-am-c, i-am-b, i-am-a"
        );
        buffer_a.read_with(&cx_a, |buf, _| assert!(!buf.is_dirty()));
        buffer_b.read_with(&cx_b, |buf, _| assert!(!buf.is_dirty()));
        buffer_c.condition(&cx_c, |buf, _| !buf.is_dirty()).await;

        // Make changes on host's file system, see those changes on the guests.
        fs.rename("/a/file2".as_ref(), "/a/file3".as_ref())
            .await
            .unwrap();
        fs.insert_file(Path::new("/a/file4"), "4".into())
            .await
            .unwrap();

        // Expected entries: .zed.toml, file1, file3 (renamed) and file4 (new).
        worktree_b
            .condition(&cx_b, |tree, _| tree.file_count() == 4)
            .await;
        worktree_c
            .condition(&cx_c, |tree, _| tree.file_count() == 4)
            .await;
        worktree_b.read_with(&cx_b, |tree, _| {
            assert_eq!(
                tree.paths()
                    .map(|p| p.to_string_lossy())
                    .collect::<Vec<_>>(),
                &[".zed.toml", "file1", "file3", "file4"]
            )
        });
        worktree_c.read_with(&cx_c, |tree, _| {
            assert_eq!(
                tree.paths()
                    .map(|p| p.to_string_lossy())
                    .collect::<Vec<_>>(),
                &[".zed.toml", "file1", "file3", "file4"]
            )
        });
    }
1562
    // A guest saves a remote buffer and then edits it again: after the save
    // the buffer is clean and not conflicted, and the follow-up edit makes it
    // dirty without flagging a conflict.
    #[gpui::test]
    async fn test_buffer_conflict_after_save(mut cx_a: TestAppContext, mut cx_b: TestAppContext) {
        cx_a.foreground().forbid_parking();
        let lang_registry = Arc::new(LanguageRegistry::new());

        // Connect to a server as 2 clients.
        let mut server = TestServer::start().await;
        let (client_a, _) = server.create_client(&mut cx_a, "user_a").await;
        let (client_b, _) = server.create_client(&mut cx_b, "user_b").await;

        // Share a local worktree as client A.
        // NOTE(review): user_c is listed as a collaborator but never connects here.
        let fs = Arc::new(FakeFs::new());
        fs.insert_tree(
            "/dir",
            json!({
                ".zed.toml": r#"collaborators = ["user_b", "user_c"]"#,
                "a.txt": "a-contents",
            }),
        )
        .await;

        let worktree_a = Worktree::open_local(
            client_a.clone(),
            "/dir".as_ref(),
            fs,
            lang_registry.clone(),
            &mut cx_a.to_async(),
        )
        .await
        .unwrap();
        worktree_a
            .read_with(&cx_a, |tree, _| tree.as_local().unwrap().scan_complete())
            .await;
        let worktree_id = worktree_a
            .update(&mut cx_a, |tree, cx| tree.as_local_mut().unwrap().share(cx))
            .await
            .unwrap();

        // Join that worktree as client B.
        let worktree_b = Worktree::open_remote(
            client_b.clone(),
            worktree_id,
            lang_registry.clone(),
            &mut cx_b.to_async(),
        )
        .await
        .unwrap();

        let buffer_b = worktree_b
            .update(&mut cx_b, |worktree, cx| worktree.open_buffer("a.txt", cx))
            .await
            .unwrap();
        // Remember the pre-save mtime so we can wait for the save to be observed.
        let mtime = buffer_b.read_with(&cx_b, |buf, _| buf.file().unwrap().mtime);

        buffer_b.update(&mut cx_b, |buf, cx| buf.edit([0..0], "world ", cx));
        buffer_b.read_with(&cx_b, |buf, _| {
            assert!(buf.is_dirty());
            assert!(!buf.has_conflict());
        });

        buffer_b
            .update(&mut cx_b, |buf, cx| buf.save(cx))
            .unwrap()
            .await
            .unwrap();
        // Wait until B's view of the file reflects the new mtime.
        worktree_b
            .condition(&cx_b, |_, cx| {
                buffer_b.read(cx).file().unwrap().mtime != mtime
            })
            .await;
        buffer_b.read_with(&cx_b, |buf, _| {
            assert!(!buf.is_dirty());
            assert!(!buf.has_conflict());
        });

        // Editing again after the save must not be treated as a conflict.
        buffer_b.update(&mut cx_b, |buf, cx| buf.edit([0..0], "hello ", cx));
        buffer_b.read_with(&cx_b, |buf, _| {
            assert!(buf.is_dirty());
            assert!(!buf.has_conflict());
        });
    }
1644
    // The host edits a buffer while a guest's request to open that same
    // buffer is still in flight; the guest's buffer must still converge to the
    // host's text.
    #[gpui::test]
    async fn test_editing_while_guest_opens_buffer(
        mut cx_a: TestAppContext,
        mut cx_b: TestAppContext,
    ) {
        cx_a.foreground().forbid_parking();
        let lang_registry = Arc::new(LanguageRegistry::new());

        // Connect to a server as 2 clients.
        let mut server = TestServer::start().await;
        let (client_a, _) = server.create_client(&mut cx_a, "user_a").await;
        let (client_b, _) = server.create_client(&mut cx_b, "user_b").await;

        // Share a local worktree as client A.
        let fs = Arc::new(FakeFs::new());
        fs.insert_tree(
            "/dir",
            json!({
                ".zed.toml": r#"collaborators = ["user_b"]"#,
                "a.txt": "a-contents",
            }),
        )
        .await;
        let worktree_a = Worktree::open_local(
            client_a.clone(),
            "/dir".as_ref(),
            fs,
            lang_registry.clone(),
            &mut cx_a.to_async(),
        )
        .await
        .unwrap();
        worktree_a
            .read_with(&cx_a, |tree, _| tree.as_local().unwrap().scan_complete())
            .await;
        let worktree_id = worktree_a
            .update(&mut cx_a, |tree, cx| tree.as_local_mut().unwrap().share(cx))
            .await
            .unwrap();

        // Join that worktree as client B.
        let worktree_b = Worktree::open_remote(
            client_b.clone(),
            worktree_id,
            lang_registry.clone(),
            &mut cx_b.to_async(),
        )
        .await
        .unwrap();

        let buffer_a = worktree_a
            .update(&mut cx_a, |tree, cx| tree.open_buffer("a.txt", cx))
            .await
            .unwrap();
        // Start opening the buffer as B on a background task without awaiting it.
        let buffer_b = cx_b
            .background()
            .spawn(worktree_b.update(&mut cx_b, |worktree, cx| worktree.open_buffer("a.txt", cx)));

        // Yield so B's open request can get underway, then edit as A.
        task::yield_now().await;
        buffer_a.update(&mut cx_a, |buf, cx| buf.edit([0..0], "z", cx));

        let text = buffer_a.read_with(&cx_a, |buf, _| buf.text());
        let buffer_b = buffer_b.await.unwrap();
        buffer_b.condition(&cx_b, |buf, _| buf.text() == text).await;
    }
1710
1711    #[gpui::test]
1712    async fn test_peer_disconnection(mut cx_a: TestAppContext, cx_b: TestAppContext) {
1713        cx_a.foreground().forbid_parking();
1714        let lang_registry = Arc::new(LanguageRegistry::new());
1715
1716        // Connect to a server as 2 clients.
1717        let mut server = TestServer::start().await;
1718        let (client_a, _) = server.create_client(&mut cx_a, "user_a").await;
1719        let (client_b, _) = server.create_client(&mut cx_a, "user_b").await;
1720
1721        // Share a local worktree as client A
1722        let fs = Arc::new(FakeFs::new());
1723        fs.insert_tree(
1724            "/a",
1725            json!({
1726                ".zed.toml": r#"collaborators = ["user_b"]"#,
1727                "a.txt": "a-contents",
1728                "b.txt": "b-contents",
1729            }),
1730        )
1731        .await;
1732        let worktree_a = Worktree::open_local(
1733            client_a.clone(),
1734            "/a".as_ref(),
1735            fs,
1736            lang_registry.clone(),
1737            &mut cx_a.to_async(),
1738        )
1739        .await
1740        .unwrap();
1741        worktree_a
1742            .read_with(&cx_a, |tree, _| tree.as_local().unwrap().scan_complete())
1743            .await;
1744        let worktree_id = worktree_a
1745            .update(&mut cx_a, |tree, cx| tree.as_local_mut().unwrap().share(cx))
1746            .await
1747            .unwrap();
1748
1749        // Join that worktree as client B, and see that a guest has joined as client A.
1750        let _worktree_b = Worktree::open_remote(
1751            client_b.clone(),
1752            worktree_id,
1753            lang_registry.clone(),
1754            &mut cx_b.to_async(),
1755        )
1756        .await
1757        .unwrap();
1758        worktree_a
1759            .condition(&cx_a, |tree, _| tree.peers().len() == 1)
1760            .await;
1761
1762        // Drop client B's connection and ensure client A observes client B leaving the worktree.
1763        client_b.disconnect(&cx_b.to_async()).await.unwrap();
1764        worktree_a
1765            .condition(&cx_a, |tree, _| tree.peers().len() == 0)
1766            .await;
1767    }
1768
    // End-to-end chat test with two clients in the same channel. It checks that:
    // - both clients see the channel and its pre-existing message;
    // - messages sent by client A show up locally as pending and reach client B as confirmed;
    // - dropping each client's channel handle removes that connection from the server's channel
    //   state, and the channel entry disappears once no connections remain.
    #[gpui::test]
    async fn test_basic_chat(mut cx_a: TestAppContext, mut cx_b: TestAppContext) {
        cx_a.foreground().forbid_parking();

        // Connect to a server as 2 clients.
        let mut server = TestServer::start().await;
        let (client_a, user_store_a) = server.create_client(&mut cx_a, "user_a").await;
        let (client_b, user_store_b) = server.create_client(&mut cx_b, "user_b").await;

        // Create an org that includes these 2 users.
        let db = &server.app_state.db;
        let org_id = db.create_org("Test Org", "test-org").await.unwrap();
        db.add_org_member(org_id, current_user_id(&user_store_a), false)
            .await
            .unwrap();
        db.add_org_member(org_id, current_user_id(&user_store_b), false)
            .await
            .unwrap();

        // Create a channel that includes all the users.
        let channel_id = db.create_org_channel(org_id, "test-channel").await.unwrap();
        db.add_channel_member(channel_id, current_user_id(&user_store_a), false)
            .await
            .unwrap();
        db.add_channel_member(channel_id, current_user_id(&user_store_b), false)
            .await
            .unwrap();
        // Seed the channel with one message from user B, written straight to the database.
        db.create_channel_message(
            channel_id,
            current_user_id(&user_store_b),
            "hello A, it's B.",
            OffsetDateTime::now_utc(),
            1,
        )
        .await
        .unwrap();

        // Client A's channel list loads asynchronously; wait for it, then check its contents.
        let channels_a = cx_a.add_model(|cx| ChannelList::new(user_store_a, client_a, cx));
        channels_a
            .condition(&mut cx_a, |list, _| list.available_channels().is_some())
            .await;
        channels_a.read_with(&cx_a, |list, _| {
            assert_eq!(
                list.available_channels().unwrap(),
                &[ChannelDetails {
                    id: channel_id.to_proto(),
                    name: "test-channel".to_string()
                }]
            )
        });
        // A freshly opened channel starts out empty, then receives the seeded message.
        let channel_a = channels_a.update(&mut cx_a, |this, cx| {
            this.get_channel(channel_id.to_proto(), cx).unwrap()
        });
        channel_a.read_with(&cx_a, |channel, _| assert!(channel.messages().is_empty()));
        channel_a
            .condition(&cx_a, |channel, _| {
                channel_messages(channel)
                    == [("user_b".to_string(), "hello A, it's B.".to_string(), false)]
            })
            .await;

        // Same checks from client B's side.
        let channels_b = cx_b.add_model(|cx| ChannelList::new(user_store_b, client_b, cx));
        channels_b
            .condition(&mut cx_b, |list, _| list.available_channels().is_some())
            .await;
        channels_b.read_with(&cx_b, |list, _| {
            assert_eq!(
                list.available_channels().unwrap(),
                &[ChannelDetails {
                    id: channel_id.to_proto(),
                    name: "test-channel".to_string()
                }]
            )
        });

        let channel_b = channels_b.update(&mut cx_b, |this, cx| {
            this.get_channel(channel_id.to_proto(), cx).unwrap()
        });
        channel_b.read_with(&cx_b, |channel, _| assert!(channel.messages().is_empty()));
        channel_b
            .condition(&cx_b, |channel, _| {
                channel_messages(channel)
                    == [("user_b".to_string(), "hello A, it's B.".to_string(), false)]
            })
            .await;

        // Client A sends two messages. Both are appended locally right away with the pending
        // flag (third tuple field) set; only the second send's task is awaited.
        channel_a
            .update(&mut cx_a, |channel, cx| {
                channel
                    .send_message("oh, hi B.".to_string(), cx)
                    .unwrap()
                    .detach();
                let task = channel.send_message("sup".to_string(), cx).unwrap();
                assert_eq!(
                    channel_messages(channel),
                    &[
                        ("user_b".to_string(), "hello A, it's B.".to_string(), false),
                        ("user_a".to_string(), "oh, hi B.".to_string(), true),
                        ("user_a".to_string(), "sup".to_string(), true)
                    ]
                );
                task
            })
            .await
            .unwrap();

        // Client B receives both messages as non-pending, in send order.
        channel_b
            .condition(&cx_b, |channel, _| {
                channel_messages(channel)
                    == [
                        ("user_b".to_string(), "hello A, it's B.".to_string(), false),
                        ("user_a".to_string(), "oh, hi B.".to_string(), false),
                        ("user_a".to_string(), "sup".to_string(), false),
                    ]
            })
            .await;

        // The server tracks one connection per client that has the channel open. Dropping a
        // channel handle should remove that client's connection; once none remain, the channel
        // entry itself is removed from server state.
        assert_eq!(
            server.state().await.channels[&channel_id]
                .connection_ids
                .len(),
            2
        );
        cx_b.update(|_| drop(channel_b));
        server
            .condition(|state| state.channels[&channel_id].connection_ids.len() == 1)
            .await;

        cx_a.update(|_| drop(channel_a));
        server
            .condition(|state| !state.channels.contains_key(&channel_id))
            .await;
    }
1902
    // Checks the body validation applied when sending chat messages: overly long bodies fail,
    // empty bodies are rejected up front, and surrounding whitespace is trimmed before storage.
    #[gpui::test]
    async fn test_chat_message_validation(mut cx_a: TestAppContext) {
        cx_a.foreground().forbid_parking();

        let mut server = TestServer::start().await;
        let (client_a, user_store_a) = server.create_client(&mut cx_a, "user_a").await;

        // Put user A in an org and a channel so it can post there.
        let db = &server.app_state.db;
        let org_id = db.create_org("Test Org", "test-org").await.unwrap();
        let channel_id = db.create_org_channel(org_id, "test-channel").await.unwrap();
        db.add_org_member(org_id, current_user_id(&user_store_a), false)
            .await
            .unwrap();
        db.add_channel_member(channel_id, current_user_id(&user_store_a), false)
            .await
            .unwrap();

        let channels_a = cx_a.add_model(|cx| ChannelList::new(user_store_a, client_a, cx));
        channels_a
            .condition(&mut cx_a, |list, _| list.available_channels().is_some())
            .await;
        let channel_a = channels_a.update(&mut cx_a, |this, cx| {
            this.get_channel(channel_id.to_proto(), cx).unwrap()
        });

        // Messages aren't allowed to be too long.
        // The body is 14 * 1024 bytes. `send_message` accepts it locally, but the returned task
        // must resolve to an error.
        channel_a
            .update(&mut cx_a, |channel, cx| {
                let long_body = "this is long.\n".repeat(1024);
                channel.send_message(long_body, cx).unwrap()
            })
            .await
            .unwrap_err();

        // Messages aren't allowed to be blank.
        // Here `send_message` itself returns an error, so no send task is created.
        channel_a.update(&mut cx_a, |channel, cx| {
            channel.send_message(String::new(), cx).unwrap_err()
        });

        // Leading and trailing whitespace are trimmed.
        // Verified by reading the stored body back from the database.
        channel_a
            .update(&mut cx_a, |channel, cx| {
                channel
                    .send_message("\n surrounded by whitespace  \n".to_string(), cx)
                    .unwrap()
            })
            .await
            .unwrap();
        assert_eq!(
            db.get_channel_messages(channel_id, 10, None)
                .await
                .unwrap()
                .iter()
                .map(|m| &m.body)
                .collect::<Vec<_>>(),
            &["surrounded by whitespace"]
        );
    }
1961
1962    #[gpui::test]
1963    async fn test_chat_reconnection(mut cx_a: TestAppContext, mut cx_b: TestAppContext) {
1964        cx_a.foreground().forbid_parking();
1965        let http = FakeHttpClient::new(|_| async move { Ok(surf::http::Response::new(404)) });
1966
1967        // Connect to a server as 2 clients.
1968        let mut server = TestServer::start().await;
1969        let (client_a, user_store_a) = server.create_client(&mut cx_a, "user_a").await;
1970        let (client_b, user_store_b) = server.create_client(&mut cx_b, "user_b").await;
1971        let mut status_b = client_b.status();
1972
1973        // Create an org that includes these 2 users.
1974        let db = &server.app_state.db;
1975        let org_id = db.create_org("Test Org", "test-org").await.unwrap();
1976        db.add_org_member(org_id, current_user_id(&user_store_a), false)
1977            .await
1978            .unwrap();
1979        db.add_org_member(org_id, current_user_id(&user_store_b), false)
1980            .await
1981            .unwrap();
1982
1983        // Create a channel that includes all the users.
1984        let channel_id = db.create_org_channel(org_id, "test-channel").await.unwrap();
1985        db.add_channel_member(channel_id, current_user_id(&user_store_a), false)
1986            .await
1987            .unwrap();
1988        db.add_channel_member(channel_id, current_user_id(&user_store_b), false)
1989            .await
1990            .unwrap();
1991        db.create_channel_message(
1992            channel_id,
1993            current_user_id(&user_store_b),
1994            "hello A, it's B.",
1995            OffsetDateTime::now_utc(),
1996            2,
1997        )
1998        .await
1999        .unwrap();
2000
2001        let user_store_a =
2002            UserStore::new(client_a.clone(), http.clone(), cx_a.background().as_ref());
2003        let channels_a = cx_a.add_model(|cx| ChannelList::new(user_store_a, client_a, cx));
2004        channels_a
2005            .condition(&mut cx_a, |list, _| list.available_channels().is_some())
2006            .await;
2007
2008        channels_a.read_with(&cx_a, |list, _| {
2009            assert_eq!(
2010                list.available_channels().unwrap(),
2011                &[ChannelDetails {
2012                    id: channel_id.to_proto(),
2013                    name: "test-channel".to_string()
2014                }]
2015            )
2016        });
2017        let channel_a = channels_a.update(&mut cx_a, |this, cx| {
2018            this.get_channel(channel_id.to_proto(), cx).unwrap()
2019        });
2020        channel_a.read_with(&cx_a, |channel, _| assert!(channel.messages().is_empty()));
2021        channel_a
2022            .condition(&cx_a, |channel, _| {
2023                channel_messages(channel)
2024                    == [("user_b".to_string(), "hello A, it's B.".to_string(), false)]
2025            })
2026            .await;
2027
2028        let channels_b = cx_b.add_model(|cx| ChannelList::new(user_store_b.clone(), client_b, cx));
2029        channels_b
2030            .condition(&mut cx_b, |list, _| list.available_channels().is_some())
2031            .await;
2032        channels_b.read_with(&cx_b, |list, _| {
2033            assert_eq!(
2034                list.available_channels().unwrap(),
2035                &[ChannelDetails {
2036                    id: channel_id.to_proto(),
2037                    name: "test-channel".to_string()
2038                }]
2039            )
2040        });
2041
2042        let channel_b = channels_b.update(&mut cx_b, |this, cx| {
2043            this.get_channel(channel_id.to_proto(), cx).unwrap()
2044        });
2045        channel_b.read_with(&cx_b, |channel, _| assert!(channel.messages().is_empty()));
2046        channel_b
2047            .condition(&cx_b, |channel, _| {
2048                channel_messages(channel)
2049                    == [("user_b".to_string(), "hello A, it's B.".to_string(), false)]
2050            })
2051            .await;
2052
2053        // Disconnect client B, ensuring we can still access its cached channel data.
2054        server.forbid_connections();
2055        server.disconnect_client(current_user_id(&user_store_b));
2056        while !matches!(
2057            status_b.recv().await,
2058            Some(rpc::Status::ReconnectionError { .. })
2059        ) {}
2060
2061        channels_b.read_with(&cx_b, |channels, _| {
2062            assert_eq!(
2063                channels.available_channels().unwrap(),
2064                [ChannelDetails {
2065                    id: channel_id.to_proto(),
2066                    name: "test-channel".to_string()
2067                }]
2068            )
2069        });
2070        channel_b.read_with(&cx_b, |channel, _| {
2071            assert_eq!(
2072                channel_messages(channel),
2073                [("user_b".to_string(), "hello A, it's B.".to_string(), false)]
2074            )
2075        });
2076
2077        // Send a message from client B while it is disconnected.
2078        channel_b
2079            .update(&mut cx_b, |channel, cx| {
2080                let task = channel
2081                    .send_message("can you see this?".to_string(), cx)
2082                    .unwrap();
2083                assert_eq!(
2084                    channel_messages(channel),
2085                    &[
2086                        ("user_b".to_string(), "hello A, it's B.".to_string(), false),
2087                        ("user_b".to_string(), "can you see this?".to_string(), true)
2088                    ]
2089                );
2090                task
2091            })
2092            .await
2093            .unwrap_err();
2094
2095        // Send a message from client A while B is disconnected.
2096        channel_a
2097            .update(&mut cx_a, |channel, cx| {
2098                channel
2099                    .send_message("oh, hi B.".to_string(), cx)
2100                    .unwrap()
2101                    .detach();
2102                let task = channel.send_message("sup".to_string(), cx).unwrap();
2103                assert_eq!(
2104                    channel_messages(channel),
2105                    &[
2106                        ("user_b".to_string(), "hello A, it's B.".to_string(), false),
2107                        ("user_a".to_string(), "oh, hi B.".to_string(), true),
2108                        ("user_a".to_string(), "sup".to_string(), true)
2109                    ]
2110                );
2111                task
2112            })
2113            .await
2114            .unwrap();
2115
2116        // Give client B a chance to reconnect.
2117        server.allow_connections();
2118        cx_b.foreground().advance_clock(Duration::from_secs(10));
2119
2120        // Verify that B sees the new messages upon reconnection, as well as the message client B
2121        // sent while offline.
2122        channel_b
2123            .condition(&cx_b, |channel, _| {
2124                channel_messages(channel)
2125                    == [
2126                        ("user_b".to_string(), "hello A, it's B.".to_string(), false),
2127                        ("user_a".to_string(), "oh, hi B.".to_string(), false),
2128                        ("user_a".to_string(), "sup".to_string(), false),
2129                        ("user_b".to_string(), "can you see this?".to_string(), false),
2130                    ]
2131            })
2132            .await;
2133
2134        // Ensure client A and B can communicate normally after reconnection.
2135        channel_a
2136            .update(&mut cx_a, |channel, cx| {
2137                channel.send_message("you online?".to_string(), cx).unwrap()
2138            })
2139            .await
2140            .unwrap();
2141        channel_b
2142            .condition(&cx_b, |channel, _| {
2143                channel_messages(channel)
2144                    == [
2145                        ("user_b".to_string(), "hello A, it's B.".to_string(), false),
2146                        ("user_a".to_string(), "oh, hi B.".to_string(), false),
2147                        ("user_a".to_string(), "sup".to_string(), false),
2148                        ("user_b".to_string(), "can you see this?".to_string(), false),
2149                        ("user_a".to_string(), "you online?".to_string(), false),
2150                    ]
2151            })
2152            .await;
2153
2154        channel_b
2155            .update(&mut cx_b, |channel, cx| {
2156                channel.send_message("yep".to_string(), cx).unwrap()
2157            })
2158            .await
2159            .unwrap();
2160        channel_a
2161            .condition(&cx_a, |channel, _| {
2162                channel_messages(channel)
2163                    == [
2164                        ("user_b".to_string(), "hello A, it's B.".to_string(), false),
2165                        ("user_a".to_string(), "oh, hi B.".to_string(), false),
2166                        ("user_a".to_string(), "sup".to_string(), false),
2167                        ("user_b".to_string(), "can you see this?".to_string(), false),
2168                        ("user_a".to_string(), "you online?".to_string(), false),
2169                        ("user_b".to_string(), "yep".to_string(), false),
2170                    ]
2171            })
2172            .await;
2173    }
2174
    /// In-process chat/collaboration server used by these tests. Clients connect to it over
    /// in-memory connections created by `create_client`, and tests can inspect or wait on its
    /// `ServerState`.
    struct TestServer {
        peer: Arc<Peer>,
        app_state: Arc<AppState>,
        server: Arc<Server>,
        /// Receiving half of the notification channel whose sender was handed to `Server::new`;
        /// `condition` waits on it between checks of the server state.
        notifications: mpsc::Receiver<()>,
        /// Kill handle for each user's most recent connection, used by `disconnect_client`.
        connection_killers: Arc<Mutex<HashMap<UserId, watch::Sender<Option<()>>>>>,
        /// While `true`, clients' connection attempts are refused.
        forbid_connections: Arc<AtomicBool>,
        /// Never read; presumably held so the test database lives as long as the server
        /// (TODO confirm `TestDb` cleans up on drop).
        _test_db: TestDb,
    }
2184
2185    impl TestServer {
2186        async fn start() -> Self {
2187            let test_db = TestDb::new();
2188            let app_state = Self::build_app_state(&test_db).await;
2189            let peer = Peer::new();
2190            let notifications = mpsc::channel(128);
2191            let server = Server::new(app_state.clone(), peer.clone(), Some(notifications.0));
2192            Self {
2193                peer,
2194                app_state,
2195                server,
2196                notifications: notifications.1,
2197                connection_killers: Default::default(),
2198                forbid_connections: Default::default(),
2199                _test_db: test_db,
2200            }
2201        }
2202
2203        async fn create_client(
2204            &mut self,
2205            cx: &mut TestAppContext,
2206            name: &str,
2207        ) -> (Arc<Client>, Arc<UserStore>) {
2208            let user_id = self.app_state.db.create_user(name, false).await.unwrap();
2209            let client_name = name.to_string();
2210            let mut client = Client::new();
2211            let server = self.server.clone();
2212            let connection_killers = self.connection_killers.clone();
2213            let forbid_connections = self.forbid_connections.clone();
2214            Arc::get_mut(&mut client)
2215                .unwrap()
2216                .override_authenticate(move |cx| {
2217                    cx.spawn(|_| async move {
2218                        let access_token = "the-token".to_string();
2219                        Ok(Credentials {
2220                            user_id: user_id.0 as u64,
2221                            access_token,
2222                        })
2223                    })
2224                })
2225                .override_establish_connection(move |credentials, cx| {
2226                    assert_eq!(credentials.user_id, user_id.0 as u64);
2227                    assert_eq!(credentials.access_token, "the-token");
2228
2229                    let server = server.clone();
2230                    let connection_killers = connection_killers.clone();
2231                    let forbid_connections = forbid_connections.clone();
2232                    let client_name = client_name.clone();
2233                    cx.spawn(move |cx| async move {
2234                        if forbid_connections.load(SeqCst) {
2235                            Err(EstablishConnectionError::other(anyhow!(
2236                                "server is forbidding connections"
2237                            )))
2238                        } else {
2239                            let (client_conn, server_conn, kill_conn) = Connection::in_memory();
2240                            connection_killers.lock().insert(user_id, kill_conn);
2241                            cx.background()
2242                                .spawn(server.handle_connection(server_conn, client_name, user_id))
2243                                .detach();
2244                            Ok(client_conn)
2245                        }
2246                    })
2247                });
2248
2249            let http = FakeHttpClient::new(|_| async move { Ok(surf::http::Response::new(404)) });
2250            client
2251                .authenticate_and_connect(&cx.to_async())
2252                .await
2253                .unwrap();
2254
2255            let user_store = UserStore::new(client.clone(), http, &cx.background());
2256            let mut authed_user = user_store.watch_current_user();
2257            while authed_user.recv().await.unwrap().is_none() {}
2258
2259            (client, user_store)
2260        }
2261
2262        fn disconnect_client(&self, user_id: UserId) {
2263            if let Some(mut kill_conn) = self.connection_killers.lock().remove(&user_id) {
2264                let _ = kill_conn.try_send(Some(()));
2265            }
2266        }
2267
2268        fn forbid_connections(&self) {
2269            self.forbid_connections.store(true, SeqCst);
2270        }
2271
2272        fn allow_connections(&self) {
2273            self.forbid_connections.store(false, SeqCst);
2274        }
2275
2276        async fn build_app_state(test_db: &TestDb) -> Arc<AppState> {
2277            let mut config = Config::default();
2278            config.session_secret = "a".repeat(32);
2279            config.database_url = test_db.url.clone();
2280            let github_client = github::AppClient::test();
2281            Arc::new(AppState {
2282                db: test_db.db().clone(),
2283                handlebars: Default::default(),
2284                auth_client: auth::build_client("", ""),
2285                repo_client: github::RepoClient::test(&github_client),
2286                github_client,
2287                config,
2288            })
2289        }
2290
2291        async fn state<'a>(&'a self) -> RwLockReadGuard<'a, ServerState> {
2292            self.server.state.read().await
2293        }
2294
2295        async fn condition<F>(&mut self, mut predicate: F)
2296        where
2297            F: FnMut(&ServerState) -> bool,
2298        {
2299            async_std::future::timeout(Duration::from_millis(500), async {
2300                while !(predicate)(&*self.server.state.read().await) {
2301                    self.notifications.recv().await;
2302                }
2303            })
2304            .await
2305            .expect("condition timed out");
2306        }
2307    }
2308
2309    impl Drop for TestServer {
2310        fn drop(&mut self) {
2311            task::block_on(self.peer.reset());
2312        }
2313    }
2314
2315    fn current_user_id(user_store: &Arc<UserStore>) -> UserId {
2316        UserId::from_proto(user_store.current_user().unwrap().id)
2317    }
2318
2319    fn channel_messages(channel: &Channel) -> Vec<(String, String, bool)> {
2320        channel
2321            .messages()
2322            .cursor::<(), ()>()
2323            .map(|m| {
2324                (
2325                    m.sender.github_login.clone(),
2326                    m.body.clone(),
2327                    m.is_pending(),
2328                )
2329            })
2330            .collect()
2331    }
2332
2333    struct EmptyView;
2334
2335    impl gpui::Entity for EmptyView {
2336        type Event = ();
2337    }
2338
2339    impl gpui::View for EmptyView {
2340        fn ui_name() -> &'static str {
2341            "empty view"
2342        }
2343
2344        fn render(&mut self, _: &mut gpui::RenderContext<Self>) -> gpui::ElementBox {
2345            gpui::Element::boxed(gpui::elements::Empty)
2346        }
2347    }
2348}