rpc.rs

   1use super::{
   2    auth,
   3    db::{ChannelId, MessageId, User, UserId},
   4    AppState,
   5};
   6use anyhow::anyhow;
   7use async_std::{sync::RwLock, task};
   8use async_tungstenite::{tungstenite::protocol::Role, WebSocketStream};
   9use futures::{future::BoxFuture, FutureExt};
  10use postage::{mpsc, prelude::Sink as _, prelude::Stream as _};
  11use sha1::{Digest as _, Sha1};
  12use std::{
  13    any::TypeId,
  14    collections::{hash_map, HashMap, HashSet},
  15    future::Future,
  16    mem,
  17    sync::Arc,
  18    time::Instant,
  19};
  20use surf::StatusCode;
  21use tide::log;
  22use tide::{
  23    http::headers::{HeaderName, CONNECTION, UPGRADE},
  24    Request, Response,
  25};
  26use time::OffsetDateTime;
  27use zrpc::{
  28    proto::{self, AnyTypedEnvelope, EnvelopedMessage},
  29    Connection, ConnectionId, Peer, TypedEnvelope,
  30};
  31
  32type ReplicaId = u16;
  33
  34type MessageHandler = Box<
  35    dyn Send
  36        + Sync
  37        + Fn(Arc<Server>, Box<dyn AnyTypedEnvelope>) -> BoxFuture<'static, tide::Result<()>>,
  38>;
  39
  40pub struct Server {
  41    peer: Arc<Peer>,
  42    state: RwLock<ServerState>,
  43    app_state: Arc<AppState>,
  44    handlers: HashMap<TypeId, MessageHandler>,
  45    notifications: Option<mpsc::Sender<()>>,
  46}
  47
  48#[derive(Default)]
  49struct ServerState {
  50    connections: HashMap<ConnectionId, ConnectionState>,
  51    connections_by_user_id: HashMap<UserId, HashSet<ConnectionId>>,
  52    pub worktrees: HashMap<u64, Worktree>,
  53    visible_worktrees_by_user_id: HashMap<UserId, HashSet<u64>>,
  54    channels: HashMap<ChannelId, Channel>,
  55    next_worktree_id: u64,
  56}
  57
  58struct ConnectionState {
  59    user_id: UserId,
  60    worktrees: HashSet<u64>,
  61    channels: HashSet<ChannelId>,
  62}
  63
  64struct Worktree {
  65    host_connection_id: ConnectionId,
  66    collaborator_user_ids: Vec<UserId>,
  67    root_name: String,
  68    share: Option<WorktreeShare>,
  69}
  70
  71struct WorktreeShare {
  72    guest_connection_ids: HashMap<ConnectionId, ReplicaId>,
  73    active_replica_ids: HashSet<ReplicaId>,
  74    entries: HashMap<u64, proto::Entry>,
  75}
  76
  77#[derive(Default)]
  78struct Channel {
  79    connection_ids: HashSet<ConnectionId>,
  80}
  81
  82const MESSAGE_COUNT_PER_PAGE: usize = 100;
  83const MAX_MESSAGE_LEN: usize = 1024;
  84
  85impl Server {
  86    pub fn new(
  87        app_state: Arc<AppState>,
  88        peer: Arc<Peer>,
  89        notifications: Option<mpsc::Sender<()>>,
  90    ) -> Arc<Self> {
  91        let mut server = Self {
  92            peer,
  93            app_state,
  94            state: Default::default(),
  95            handlers: Default::default(),
  96            notifications,
  97        };
  98
  99        server
 100            .add_handler(Server::ping)
 101            .add_handler(Server::open_worktree)
 102            .add_handler(Server::close_worktree)
 103            .add_handler(Server::share_worktree)
 104            .add_handler(Server::unshare_worktree)
 105            .add_handler(Server::join_worktree)
 106            .add_handler(Server::update_worktree)
 107            .add_handler(Server::open_buffer)
 108            .add_handler(Server::close_buffer)
 109            .add_handler(Server::update_buffer)
 110            .add_handler(Server::buffer_saved)
 111            .add_handler(Server::save_buffer)
 112            .add_handler(Server::get_channels)
 113            .add_handler(Server::get_users)
 114            .add_handler(Server::join_channel)
 115            .add_handler(Server::leave_channel)
 116            .add_handler(Server::send_channel_message)
 117            .add_handler(Server::get_channel_messages)
 118            .add_handler(Server::get_collaborators);
 119
 120        Arc::new(server)
 121    }
 122
 123    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 124    where
 125        F: 'static + Send + Sync + Fn(Arc<Self>, TypedEnvelope<M>) -> Fut,
 126        Fut: 'static + Send + Future<Output = tide::Result<()>>,
 127        M: EnvelopedMessage,
 128    {
 129        let prev_handler = self.handlers.insert(
 130            TypeId::of::<M>(),
 131            Box::new(move |server, envelope| {
 132                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 133                (handler)(server, *envelope).boxed()
 134            }),
 135        );
 136        if prev_handler.is_some() {
 137            panic!("registered a handler for the same message twice");
 138        }
 139        self
 140    }
 141
 142    pub fn handle_connection(
 143        self: &Arc<Self>,
 144        connection: Connection,
 145        addr: String,
 146        user_id: UserId,
 147    ) -> impl Future<Output = ()> {
 148        let this = self.clone();
 149        async move {
 150            let (connection_id, handle_io, mut incoming_rx) =
 151                this.peer.add_connection(connection).await;
 152            this.add_connection(connection_id, user_id).await;
 153
 154            let handle_io = handle_io.fuse();
 155            futures::pin_mut!(handle_io);
 156            loop {
 157                let next_message = incoming_rx.recv().fuse();
 158                futures::pin_mut!(next_message);
 159                futures::select_biased! {
 160                    message = next_message => {
 161                        if let Some(message) = message {
 162                            let start_time = Instant::now();
 163                            log::info!("RPC message received: {}", message.payload_type_name());
 164                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
 165                                if let Err(err) = (handler)(this.clone(), message).await {
 166                                    log::error!("error handling message: {:?}", err);
 167                                } else {
 168                                    log::info!("RPC message handled. duration:{:?}", start_time.elapsed());
 169                                }
 170
 171                                if let Some(mut notifications) = this.notifications.clone() {
 172                                    let _ = notifications.send(()).await;
 173                                }
 174                            } else {
 175                                log::warn!("unhandled message: {}", message.payload_type_name());
 176                            }
 177                        } else {
 178                            log::info!("rpc connection closed {:?}", addr);
 179                            break;
 180                        }
 181                    }
 182                    handle_io = handle_io => {
 183                        if let Err(err) = handle_io {
 184                            log::error!("error handling rpc connection {:?} - {:?}", addr, err);
 185                        }
 186                        break;
 187                    }
 188                }
 189            }
 190
 191            if let Err(err) = this.sign_out(connection_id).await {
 192                log::error!("error signing out connection {:?} - {:?}", addr, err);
 193            }
 194        }
 195    }
 196
 197    async fn sign_out(self: &Arc<Self>, connection_id: zrpc::ConnectionId) -> tide::Result<()> {
 198        self.peer.disconnect(connection_id).await;
 199        let worktree_ids = self.remove_connection(connection_id).await;
 200        for worktree_id in worktree_ids {
 201            let state = self.state.read().await;
 202            if let Some(worktree) = state.worktrees.get(&worktree_id) {
 203                broadcast(connection_id, worktree.connection_ids(), |conn_id| {
 204                    self.peer.send(
 205                        conn_id,
 206                        proto::RemovePeer {
 207                            worktree_id,
 208                            peer_id: connection_id.0,
 209                        },
 210                    )
 211                })
 212                .await?;
 213            }
 214        }
 215        Ok(())
 216    }
 217
 218    // Add a new connection associated with a given user.
 219    async fn add_connection(&self, connection_id: ConnectionId, user_id: UserId) {
 220        let mut state = self.state.write().await;
 221        state.connections.insert(
 222            connection_id,
 223            ConnectionState {
 224                user_id,
 225                worktrees: Default::default(),
 226                channels: Default::default(),
 227            },
 228        );
 229        state
 230            .connections_by_user_id
 231            .entry(user_id)
 232            .or_default()
 233            .insert(connection_id);
 234    }
 235
 236    // Remove the given connection and its association with any worktrees.
 237    async fn remove_connection(&self, connection_id: ConnectionId) -> Vec<u64> {
 238        let mut worktree_ids = Vec::new();
 239        let mut state = self.state.write().await;
 240        if let Some(connection) = state.connections.remove(&connection_id) {
 241            for channel_id in connection.channels {
 242                if let Some(channel) = state.channels.get_mut(&channel_id) {
 243                    channel.connection_ids.remove(&connection_id);
 244                }
 245            }
 246            for worktree_id in connection.worktrees {
 247                if let Some(worktree) = state.worktrees.get_mut(&worktree_id) {
 248                    if worktree.host_connection_id == connection_id {
 249                        worktree_ids.push(worktree_id);
 250                    } else if let Some(share_state) = worktree.share.as_mut() {
 251                        if let Some(replica_id) =
 252                            share_state.guest_connection_ids.remove(&connection_id)
 253                        {
 254                            share_state.active_replica_ids.remove(&replica_id);
 255                            worktree_ids.push(worktree_id);
 256                        }
 257                    }
 258                }
 259            }
 260
 261            let user_connections = state
 262                .connections_by_user_id
 263                .get_mut(&connection.user_id)
 264                .unwrap();
 265            user_connections.remove(&connection_id);
 266            if user_connections.is_empty() {
 267                state.connections_by_user_id.remove(&connection.user_id);
 268            }
 269        }
 270        worktree_ids
 271    }
 272
 273    async fn ping(self: Arc<Server>, request: TypedEnvelope<proto::Ping>) -> tide::Result<()> {
 274        self.peer.respond(request.receipt(), proto::Ack {}).await?;
 275        Ok(())
 276    }
 277
 278    async fn open_worktree(
 279        self: Arc<Server>,
 280        request: TypedEnvelope<proto::OpenWorktree>,
 281    ) -> tide::Result<()> {
 282        let receipt = request.receipt();
 283
 284        let mut collaborator_user_ids = Vec::new();
 285        for github_login in request.payload.collaborator_logins {
 286            match self.app_state.db.create_user(&github_login, false).await {
 287                Ok(user_id) => collaborator_user_ids.push(user_id),
 288                Err(err) => {
 289                    let message = err.to_string();
 290                    self.peer
 291                        .respond_with_error(receipt, proto::Error { message })
 292                        .await?;
 293                    return Ok(());
 294                }
 295            }
 296        }
 297
 298        let mut state = self.state.write().await;
 299        let worktree_id = state.add_worktree(Worktree {
 300            host_connection_id: request.sender_id,
 301            collaborator_user_ids,
 302            root_name: request.payload.root_name,
 303            share: None,
 304        });
 305
 306        self.peer
 307            .respond(receipt, proto::OpenWorktreeResponse { worktree_id })
 308            .await?;
 309        Ok(())
 310    }
 311
 312    async fn share_worktree(
 313        self: Arc<Server>,
 314        mut request: TypedEnvelope<proto::ShareWorktree>,
 315    ) -> tide::Result<()> {
 316        let worktree = request
 317            .payload
 318            .worktree
 319            .as_mut()
 320            .ok_or_else(|| anyhow!("missing worktree"))?;
 321        let entries = mem::take(&mut worktree.entries)
 322            .into_iter()
 323            .map(|entry| (entry.id, entry))
 324            .collect();
 325        let mut state = self.state.write().await;
 326        if let Some(worktree) = state.worktrees.get_mut(&worktree.id) {
 327            worktree.share = Some(WorktreeShare {
 328                guest_connection_ids: Default::default(),
 329                active_replica_ids: Default::default(),
 330                entries,
 331            });
 332            self.peer
 333                .respond(request.receipt(), proto::ShareWorktreeResponse {})
 334                .await?;
 335        } else {
 336            self.peer
 337                .respond_with_error(
 338                    request.receipt(),
 339                    proto::Error {
 340                        message: "no such worktree".to_string(),
 341                    },
 342                )
 343                .await?;
 344        }
 345        Ok(())
 346    }
 347
 348    async fn unshare_worktree(
 349        self: Arc<Server>,
 350        request: TypedEnvelope<proto::UnshareWorktree>,
 351    ) -> tide::Result<()> {
 352        let worktree_id = request.payload.worktree_id;
 353
 354        let connection_ids;
 355        {
 356            let mut state = self.state.write().await;
 357            let worktree = state.write_worktree(worktree_id, request.sender_id)?;
 358            if worktree.host_connection_id != request.sender_id {
 359                return Err(anyhow!("no such worktree"))?;
 360            }
 361
 362            connection_ids = worktree.connection_ids();
 363            worktree.share.take();
 364            for connection_id in &connection_ids {
 365                if let Some(connection) = state.connections.get_mut(connection_id) {
 366                    connection.worktrees.remove(&worktree_id);
 367                }
 368            }
 369        }
 370
 371        broadcast(request.sender_id, connection_ids, |conn_id| {
 372            self.peer
 373                .send(conn_id, proto::UnshareWorktree { worktree_id })
 374        })
 375        .await?;
 376
 377        Ok(())
 378    }
 379
 380    async fn join_worktree(
 381        self: Arc<Server>,
 382        request: TypedEnvelope<proto::JoinWorktree>,
 383    ) -> tide::Result<()> {
 384        let worktree_id = request.payload.worktree_id;
 385        let user_id = self
 386            .state
 387            .read()
 388            .await
 389            .user_id_for_connection(request.sender_id)?;
 390
 391        let response;
 392        let connection_ids;
 393        let mut state = self.state.write().await;
 394        match state.join_worktree(request.sender_id, user_id, worktree_id) {
 395            Ok((peer_replica_id, worktree)) => {
 396                let share = worktree.share()?;
 397                let peer_count = share.guest_connection_ids.len();
 398                let mut peers = Vec::with_capacity(peer_count);
 399                peers.push(proto::Peer {
 400                    peer_id: worktree.host_connection_id.0,
 401                    replica_id: 0,
 402                });
 403                for (peer_conn_id, peer_replica_id) in &share.guest_connection_ids {
 404                    if *peer_conn_id != request.sender_id {
 405                        peers.push(proto::Peer {
 406                            peer_id: peer_conn_id.0,
 407                            replica_id: *peer_replica_id as u32,
 408                        });
 409                    }
 410                }
 411                connection_ids = worktree.connection_ids();
 412                response = proto::JoinWorktreeResponse {
 413                    worktree: Some(proto::Worktree {
 414                        id: worktree_id,
 415                        root_name: worktree.root_name.clone(),
 416                        entries: share.entries.values().cloned().collect(),
 417                    }),
 418                    replica_id: peer_replica_id as u32,
 419                    peers,
 420                };
 421            }
 422            Err(error) => {
 423                self.peer
 424                    .respond_with_error(
 425                        request.receipt(),
 426                        proto::Error {
 427                            message: error.to_string(),
 428                        },
 429                    )
 430                    .await?;
 431                return Ok(());
 432            }
 433        }
 434
 435        broadcast(request.sender_id, connection_ids, |conn_id| {
 436            self.peer.send(
 437                conn_id,
 438                proto::AddPeer {
 439                    worktree_id,
 440                    peer: Some(proto::Peer {
 441                        peer_id: request.sender_id.0,
 442                        replica_id: response.replica_id,
 443                    }),
 444                },
 445            )
 446        })
 447        .await?;
 448        self.peer.respond(request.receipt(), response).await?;
 449
 450        Ok(())
 451    }
 452
 453    async fn close_worktree(
 454        self: Arc<Server>,
 455        request: TypedEnvelope<proto::CloseWorktree>,
 456    ) -> tide::Result<()> {
 457        let worktree_id = request.payload.worktree_id;
 458        let connection_ids;
 459        let mut is_host = false;
 460        let mut is_guest = false;
 461        {
 462            let mut state = self.state.write().await;
 463            let worktree = state.write_worktree(worktree_id, request.sender_id)?;
 464            connection_ids = worktree.connection_ids();
 465
 466            if worktree.host_connection_id == request.sender_id {
 467                is_host = true;
 468                state.remove_worktree(worktree_id);
 469            } else {
 470                let share = worktree.share_mut()?;
 471                if let Some(replica_id) = share.guest_connection_ids.remove(&request.sender_id) {
 472                    is_guest = true;
 473                    share.active_replica_ids.remove(&replica_id);
 474                }
 475            }
 476        }
 477
 478        if is_host {
 479            broadcast(request.sender_id, connection_ids, |conn_id| {
 480                self.peer
 481                    .send(conn_id, proto::UnshareWorktree { worktree_id })
 482            })
 483            .await?;
 484        } else if is_guest {
 485            broadcast(request.sender_id, connection_ids, |conn_id| {
 486                self.peer.send(
 487                    conn_id,
 488                    proto::RemovePeer {
 489                        worktree_id,
 490                        peer_id: request.sender_id.0,
 491                    },
 492                )
 493            })
 494            .await?
 495        }
 496
 497        Ok(())
 498    }
 499
 500    async fn update_worktree(
 501        self: Arc<Server>,
 502        request: TypedEnvelope<proto::UpdateWorktree>,
 503    ) -> tide::Result<()> {
 504        {
 505            let mut state = self.state.write().await;
 506            let worktree = state.write_worktree(request.payload.worktree_id, request.sender_id)?;
 507            let share = worktree.share_mut()?;
 508
 509            for entry_id in &request.payload.removed_entries {
 510                share.entries.remove(&entry_id);
 511            }
 512
 513            for entry in &request.payload.updated_entries {
 514                share.entries.insert(entry.id, entry.clone());
 515            }
 516        }
 517
 518        self.broadcast_in_worktree(request.payload.worktree_id, &request)
 519            .await?;
 520        Ok(())
 521    }
 522
 523    async fn open_buffer(
 524        self: Arc<Server>,
 525        request: TypedEnvelope<proto::OpenBuffer>,
 526    ) -> tide::Result<()> {
 527        let receipt = request.receipt();
 528        let worktree_id = request.payload.worktree_id;
 529        let host_connection_id = self
 530            .state
 531            .read()
 532            .await
 533            .read_worktree(worktree_id, request.sender_id)?
 534            .host_connection_id;
 535
 536        let response = self
 537            .peer
 538            .forward_request(request.sender_id, host_connection_id, request.payload)
 539            .await?;
 540        self.peer.respond(receipt, response).await?;
 541        Ok(())
 542    }
 543
 544    async fn close_buffer(
 545        self: Arc<Server>,
 546        request: TypedEnvelope<proto::CloseBuffer>,
 547    ) -> tide::Result<()> {
 548        let host_connection_id = self
 549            .state
 550            .read()
 551            .await
 552            .read_worktree(request.payload.worktree_id, request.sender_id)?
 553            .host_connection_id;
 554
 555        self.peer
 556            .forward_send(request.sender_id, host_connection_id, request.payload)
 557            .await?;
 558
 559        Ok(())
 560    }
 561
 562    async fn save_buffer(
 563        self: Arc<Server>,
 564        request: TypedEnvelope<proto::SaveBuffer>,
 565    ) -> tide::Result<()> {
 566        let host;
 567        let guests;
 568        {
 569            let state = self.state.read().await;
 570            let worktree = state.read_worktree(request.payload.worktree_id, request.sender_id)?;
 571            host = worktree.host_connection_id;
 572            guests = worktree
 573                .share()?
 574                .guest_connection_ids
 575                .keys()
 576                .copied()
 577                .collect::<Vec<_>>();
 578        }
 579
 580        let sender = request.sender_id;
 581        let receipt = request.receipt();
 582        let response = self
 583            .peer
 584            .forward_request(sender, host, request.payload.clone())
 585            .await?;
 586
 587        broadcast(host, guests, |conn_id| {
 588            let response = response.clone();
 589            let peer = &self.peer;
 590            async move {
 591                if conn_id == sender {
 592                    peer.respond(receipt, response).await
 593                } else {
 594                    peer.forward_send(host, conn_id, response).await
 595                }
 596            }
 597        })
 598        .await?;
 599
 600        Ok(())
 601    }
 602
 603    async fn update_buffer(
 604        self: Arc<Server>,
 605        request: TypedEnvelope<proto::UpdateBuffer>,
 606    ) -> tide::Result<()> {
 607        self.broadcast_in_worktree(request.payload.worktree_id, &request)
 608            .await?;
 609        self.peer.respond(request.receipt(), proto::Ack {}).await?;
 610        Ok(())
 611    }
 612
 613    async fn buffer_saved(
 614        self: Arc<Server>,
 615        request: TypedEnvelope<proto::BufferSaved>,
 616    ) -> tide::Result<()> {
 617        self.broadcast_in_worktree(request.payload.worktree_id, &request)
 618            .await
 619    }
 620
 621    async fn get_channels(
 622        self: Arc<Server>,
 623        request: TypedEnvelope<proto::GetChannels>,
 624    ) -> tide::Result<()> {
 625        let user_id = self
 626            .state
 627            .read()
 628            .await
 629            .user_id_for_connection(request.sender_id)?;
 630        let channels = self.app_state.db.get_accessible_channels(user_id).await?;
 631        self.peer
 632            .respond(
 633                request.receipt(),
 634                proto::GetChannelsResponse {
 635                    channels: channels
 636                        .into_iter()
 637                        .map(|chan| proto::Channel {
 638                            id: chan.id.to_proto(),
 639                            name: chan.name,
 640                        })
 641                        .collect(),
 642                },
 643            )
 644            .await?;
 645        Ok(())
 646    }
 647
 648    async fn get_users(
 649        self: Arc<Server>,
 650        request: TypedEnvelope<proto::GetUsers>,
 651    ) -> tide::Result<()> {
 652        let user_id = self
 653            .state
 654            .read()
 655            .await
 656            .user_id_for_connection(request.sender_id)?;
 657        let receipt = request.receipt();
 658        let user_ids = request.payload.user_ids.into_iter().map(UserId::from_proto);
 659        let users = self
 660            .app_state
 661            .db
 662            .get_users_by_ids(user_id, user_ids)
 663            .await?
 664            .into_iter()
 665            .map(|user| proto::User {
 666                id: user.id.to_proto(),
 667                avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
 668                github_login: user.github_login,
 669            })
 670            .collect();
 671        self.peer
 672            .respond(receipt, proto::GetUsersResponse { users })
 673            .await?;
 674        Ok(())
 675    }
 676
 677    async fn get_collaborators(
 678        self: Arc<Server>,
 679        request: TypedEnvelope<proto::GetCollaborators>,
 680    ) -> tide::Result<()> {
 681        let mut collaborators = HashMap::new();
 682        {
 683            let state = self.state.read().await;
 684            let user_id = state.user_id_for_connection(request.sender_id)?;
 685            for worktree_id in state
 686                .visible_worktrees_by_user_id
 687                .get(&user_id)
 688                .unwrap_or(&HashSet::new())
 689            {
 690                let worktree = &state.worktrees[worktree_id];
 691
 692                let mut participants = Vec::new();
 693                for collaborator_user_id in &worktree.collaborator_user_ids {
 694                    collaborators
 695                        .entry(*collaborator_user_id)
 696                        .or_insert_with(|| proto::Collaborator {
 697                            user_id: collaborator_user_id.to_proto(),
 698                            worktrees: Vec::new(),
 699                            is_online: state.is_online(*collaborator_user_id),
 700                        });
 701
 702                    if let Ok(share) = worktree.share() {
 703                        let mut conn_ids = state.user_connection_ids(*collaborator_user_id);
 704                        if conn_ids.any(|c| share.guest_connection_ids.contains_key(&c)) {
 705                            participants.push(collaborator_user_id.to_proto());
 706                        }
 707                    }
 708                }
 709
 710                let host_user_id = state.user_id_for_connection(worktree.host_connection_id)?;
 711                let host =
 712                    collaborators
 713                        .entry(host_user_id)
 714                        .or_insert_with(|| proto::Collaborator {
 715                            user_id: host_user_id.to_proto(),
 716                            worktrees: Vec::new(),
 717                            is_online: true,
 718                        });
 719                host.worktrees.push(proto::CollaboratorWorktree {
 720                    is_shared: worktree.share().is_ok(),
 721                    participants,
 722                });
 723            }
 724        }
 725
 726        self.peer
 727            .respond(
 728                request.receipt(),
 729                proto::GetCollaboratorsResponse {
 730                    collaborators: collaborators.into_values().collect(),
 731                },
 732            )
 733            .await?;
 734        Ok(())
 735    }
 736
 737    async fn join_channel(
 738        self: Arc<Self>,
 739        request: TypedEnvelope<proto::JoinChannel>,
 740    ) -> tide::Result<()> {
 741        let user_id = self
 742            .state
 743            .read()
 744            .await
 745            .user_id_for_connection(request.sender_id)?;
 746        let channel_id = ChannelId::from_proto(request.payload.channel_id);
 747        if !self
 748            .app_state
 749            .db
 750            .can_user_access_channel(user_id, channel_id)
 751            .await?
 752        {
 753            Err(anyhow!("access denied"))?;
 754        }
 755
 756        self.state
 757            .write()
 758            .await
 759            .join_channel(request.sender_id, channel_id);
 760        let messages = self
 761            .app_state
 762            .db
 763            .get_channel_messages(channel_id, MESSAGE_COUNT_PER_PAGE, None)
 764            .await?
 765            .into_iter()
 766            .map(|msg| proto::ChannelMessage {
 767                id: msg.id.to_proto(),
 768                body: msg.body,
 769                timestamp: msg.sent_at.unix_timestamp() as u64,
 770                sender_id: msg.sender_id.to_proto(),
 771                nonce: Some(msg.nonce.as_u128().into()),
 772            })
 773            .collect::<Vec<_>>();
 774        self.peer
 775            .respond(
 776                request.receipt(),
 777                proto::JoinChannelResponse {
 778                    done: messages.len() < MESSAGE_COUNT_PER_PAGE,
 779                    messages,
 780                },
 781            )
 782            .await?;
 783        Ok(())
 784    }
 785
 786    async fn leave_channel(
 787        self: Arc<Self>,
 788        request: TypedEnvelope<proto::LeaveChannel>,
 789    ) -> tide::Result<()> {
 790        let user_id = self
 791            .state
 792            .read()
 793            .await
 794            .user_id_for_connection(request.sender_id)?;
 795        let channel_id = ChannelId::from_proto(request.payload.channel_id);
 796        if !self
 797            .app_state
 798            .db
 799            .can_user_access_channel(user_id, channel_id)
 800            .await?
 801        {
 802            Err(anyhow!("access denied"))?;
 803        }
 804
 805        self.state
 806            .write()
 807            .await
 808            .leave_channel(request.sender_id, channel_id);
 809
 810        Ok(())
 811    }
 812
 813    async fn send_channel_message(
 814        self: Arc<Self>,
 815        request: TypedEnvelope<proto::SendChannelMessage>,
 816    ) -> tide::Result<()> {
 817        let receipt = request.receipt();
 818        let channel_id = ChannelId::from_proto(request.payload.channel_id);
 819        let user_id;
 820        let connection_ids;
 821        {
 822            let state = self.state.read().await;
 823            user_id = state.user_id_for_connection(request.sender_id)?;
 824            if let Some(channel) = state.channels.get(&channel_id) {
 825                connection_ids = channel.connection_ids();
 826            } else {
 827                return Ok(());
 828            }
 829        }
 830
 831        // Validate the message body.
 832        let body = request.payload.body.trim().to_string();
 833        if body.len() > MAX_MESSAGE_LEN {
 834            self.peer
 835                .respond_with_error(
 836                    receipt,
 837                    proto::Error {
 838                        message: "message is too long".to_string(),
 839                    },
 840                )
 841                .await?;
 842            return Ok(());
 843        }
 844        if body.is_empty() {
 845            self.peer
 846                .respond_with_error(
 847                    receipt,
 848                    proto::Error {
 849                        message: "message can't be blank".to_string(),
 850                    },
 851                )
 852                .await?;
 853            return Ok(());
 854        }
 855
 856        let timestamp = OffsetDateTime::now_utc();
 857        let nonce = if let Some(nonce) = request.payload.nonce {
 858            nonce
 859        } else {
 860            self.peer
 861                .respond_with_error(
 862                    receipt,
 863                    proto::Error {
 864                        message: "nonce can't be blank".to_string(),
 865                    },
 866                )
 867                .await?;
 868            return Ok(());
 869        };
 870
 871        let message_id = self
 872            .app_state
 873            .db
 874            .create_channel_message(channel_id, user_id, &body, timestamp, nonce.clone().into())
 875            .await?
 876            .to_proto();
 877        let message = proto::ChannelMessage {
 878            sender_id: user_id.to_proto(),
 879            id: message_id,
 880            body,
 881            timestamp: timestamp.unix_timestamp() as u64,
 882            nonce: Some(nonce),
 883        };
 884        broadcast(request.sender_id, connection_ids, |conn_id| {
 885            self.peer.send(
 886                conn_id,
 887                proto::ChannelMessageSent {
 888                    channel_id: channel_id.to_proto(),
 889                    message: Some(message.clone()),
 890                },
 891            )
 892        })
 893        .await?;
 894        self.peer
 895            .respond(
 896                receipt,
 897                proto::SendChannelMessageResponse {
 898                    message: Some(message),
 899                },
 900            )
 901            .await?;
 902        Ok(())
 903    }
 904
 905    async fn get_channel_messages(
 906        self: Arc<Self>,
 907        request: TypedEnvelope<proto::GetChannelMessages>,
 908    ) -> tide::Result<()> {
 909        let user_id = self
 910            .state
 911            .read()
 912            .await
 913            .user_id_for_connection(request.sender_id)?;
 914        let channel_id = ChannelId::from_proto(request.payload.channel_id);
 915        if !self
 916            .app_state
 917            .db
 918            .can_user_access_channel(user_id, channel_id)
 919            .await?
 920        {
 921            Err(anyhow!("access denied"))?;
 922        }
 923
 924        let messages = self
 925            .app_state
 926            .db
 927            .get_channel_messages(
 928                channel_id,
 929                MESSAGE_COUNT_PER_PAGE,
 930                Some(MessageId::from_proto(request.payload.before_message_id)),
 931            )
 932            .await?
 933            .into_iter()
 934            .map(|msg| proto::ChannelMessage {
 935                id: msg.id.to_proto(),
 936                body: msg.body,
 937                timestamp: msg.sent_at.unix_timestamp() as u64,
 938                sender_id: msg.sender_id.to_proto(),
 939                nonce: Some(msg.nonce.as_u128().into()),
 940            })
 941            .collect::<Vec<_>>();
 942        self.peer
 943            .respond(
 944                request.receipt(),
 945                proto::GetChannelMessagesResponse {
 946                    done: messages.len() < MESSAGE_COUNT_PER_PAGE,
 947                    messages,
 948                },
 949            )
 950            .await?;
 951        Ok(())
 952    }
 953
 954    async fn broadcast_in_worktree<T: proto::EnvelopedMessage>(
 955        &self,
 956        worktree_id: u64,
 957        message: &TypedEnvelope<T>,
 958    ) -> tide::Result<()> {
 959        let connection_ids = self
 960            .state
 961            .read()
 962            .await
 963            .read_worktree(worktree_id, message.sender_id)?
 964            .connection_ids();
 965
 966        broadcast(message.sender_id, connection_ids, |conn_id| {
 967            self.peer
 968                .forward_send(message.sender_id, conn_id, message.payload.clone())
 969        })
 970        .await?;
 971
 972        Ok(())
 973    }
 974}
 975
 976pub async fn broadcast<F, T>(
 977    sender_id: ConnectionId,
 978    receiver_ids: Vec<ConnectionId>,
 979    mut f: F,
 980) -> anyhow::Result<()>
 981where
 982    F: FnMut(ConnectionId) -> T,
 983    T: Future<Output = anyhow::Result<()>>,
 984{
 985    let futures = receiver_ids
 986        .into_iter()
 987        .filter(|id| *id != sender_id)
 988        .map(|id| f(id));
 989    futures::future::try_join_all(futures).await?;
 990    Ok(())
 991}
 992
 993impl ServerState {
 994    fn join_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
 995        if let Some(connection) = self.connections.get_mut(&connection_id) {
 996            connection.channels.insert(channel_id);
 997            self.channels
 998                .entry(channel_id)
 999                .or_default()
1000                .connection_ids
1001                .insert(connection_id);
1002        }
1003    }
1004
1005    fn leave_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
1006        if let Some(connection) = self.connections.get_mut(&connection_id) {
1007            connection.channels.remove(&channel_id);
1008            if let hash_map::Entry::Occupied(mut entry) = self.channels.entry(channel_id) {
1009                entry.get_mut().connection_ids.remove(&connection_id);
1010                if entry.get_mut().connection_ids.is_empty() {
1011                    entry.remove();
1012                }
1013            }
1014        }
1015    }
1016
1017    fn user_id_for_connection(&self, connection_id: ConnectionId) -> tide::Result<UserId> {
1018        Ok(self
1019            .connections
1020            .get(&connection_id)
1021            .ok_or_else(|| anyhow!("unknown connection"))?
1022            .user_id)
1023    }
1024
1025    fn user_connection_ids<'a>(
1026        &'a self,
1027        user_id: UserId,
1028    ) -> impl 'a + Iterator<Item = ConnectionId> {
1029        self.connections_by_user_id
1030            .get(&user_id)
1031            .into_iter()
1032            .flatten()
1033            .copied()
1034    }
1035
1036    fn is_online(&self, user_id: UserId) -> bool {
1037        self.connections_by_user_id.contains_key(&user_id)
1038    }
1039
1040    // Add the given connection as a guest of the given worktree
1041    fn join_worktree(
1042        &mut self,
1043        connection_id: ConnectionId,
1044        user_id: UserId,
1045        worktree_id: u64,
1046    ) -> tide::Result<(ReplicaId, &Worktree)> {
1047        let connection = self
1048            .connections
1049            .get_mut(&connection_id)
1050            .ok_or_else(|| anyhow!("no such connection"))?;
1051        let worktree = self
1052            .worktrees
1053            .get_mut(&worktree_id)
1054            .ok_or_else(|| anyhow!("no such worktree"))?;
1055        if !worktree.collaborator_user_ids.contains(&user_id) {
1056            Err(anyhow!("no such worktree"))?;
1057        }
1058
1059        let share = worktree.share_mut()?;
1060        connection.worktrees.insert(worktree_id);
1061
1062        let mut replica_id = 1;
1063        while share.active_replica_ids.contains(&replica_id) {
1064            replica_id += 1;
1065        }
1066        share.active_replica_ids.insert(replica_id);
1067        share.guest_connection_ids.insert(connection_id, replica_id);
1068        return Ok((replica_id, worktree));
1069    }
1070
1071    fn read_worktree(
1072        &self,
1073        worktree_id: u64,
1074        connection_id: ConnectionId,
1075    ) -> tide::Result<&Worktree> {
1076        let worktree = self
1077            .worktrees
1078            .get(&worktree_id)
1079            .ok_or_else(|| anyhow!("worktree not found"))?;
1080
1081        if worktree.host_connection_id == connection_id
1082            || worktree
1083                .share()?
1084                .guest_connection_ids
1085                .contains_key(&connection_id)
1086        {
1087            Ok(worktree)
1088        } else {
1089            Err(anyhow!(
1090                "{} is not a member of worktree {}",
1091                connection_id,
1092                worktree_id
1093            ))?
1094        }
1095    }
1096
1097    fn write_worktree(
1098        &mut self,
1099        worktree_id: u64,
1100        connection_id: ConnectionId,
1101    ) -> tide::Result<&mut Worktree> {
1102        let worktree = self
1103            .worktrees
1104            .get_mut(&worktree_id)
1105            .ok_or_else(|| anyhow!("worktree not found"))?;
1106
1107        if worktree.host_connection_id == connection_id
1108            || worktree.share.as_ref().map_or(false, |share| {
1109                share.guest_connection_ids.contains_key(&connection_id)
1110            })
1111        {
1112            Ok(worktree)
1113        } else {
1114            Err(anyhow!(
1115                "{} is not a member of worktree {}",
1116                connection_id,
1117                worktree_id
1118            ))?
1119        }
1120    }
1121
1122    fn add_worktree(&mut self, worktree: Worktree) -> u64 {
1123        let worktree_id = self.next_worktree_id;
1124        for collaborator_user_id in &worktree.collaborator_user_ids {
1125            self.visible_worktrees_by_user_id
1126                .entry(*collaborator_user_id)
1127                .or_default()
1128                .insert(worktree_id);
1129        }
1130        self.next_worktree_id += 1;
1131        self.worktrees.insert(worktree_id, worktree);
1132        worktree_id
1133    }
1134
1135    fn remove_worktree(&mut self, worktree_id: u64) {
1136        let worktree = self.worktrees.remove(&worktree_id).unwrap();
1137        if let Some(connection) = self.connections.get_mut(&worktree.host_connection_id) {
1138            connection.worktrees.remove(&worktree_id);
1139        }
1140        if let Some(share) = worktree.share {
1141            for connection_id in share.guest_connection_ids.keys() {
1142                if let Some(connection) = self.connections.get_mut(connection_id) {
1143                    connection.worktrees.remove(&worktree_id);
1144                }
1145            }
1146        }
1147        for collaborator_user_id in worktree.collaborator_user_ids {
1148            if let Some(visible_worktrees) = self
1149                .visible_worktrees_by_user_id
1150                .get_mut(&collaborator_user_id)
1151            {
1152                visible_worktrees.remove(&worktree_id);
1153            }
1154        }
1155    }
1156}
1157
1158impl Worktree {
1159    pub fn connection_ids(&self) -> Vec<ConnectionId> {
1160        if let Some(share) = &self.share {
1161            share
1162                .guest_connection_ids
1163                .keys()
1164                .copied()
1165                .chain(Some(self.host_connection_id))
1166                .collect()
1167        } else {
1168            vec![self.host_connection_id]
1169        }
1170    }
1171
1172    fn share(&self) -> tide::Result<&WorktreeShare> {
1173        Ok(self
1174            .share
1175            .as_ref()
1176            .ok_or_else(|| anyhow!("worktree is not shared"))?)
1177    }
1178
1179    fn share_mut(&mut self) -> tide::Result<&mut WorktreeShare> {
1180        Ok(self
1181            .share
1182            .as_mut()
1183            .ok_or_else(|| anyhow!("worktree is not shared"))?)
1184    }
1185}
1186
1187impl Channel {
1188    fn connection_ids(&self) -> Vec<ConnectionId> {
1189        self.connection_ids.iter().copied().collect()
1190    }
1191}
1192
1193pub fn add_routes(app: &mut tide::Server<Arc<AppState>>, rpc: &Arc<Peer>) {
1194    let server = Server::new(app.state().clone(), rpc.clone(), None);
1195    app.at("/rpc").with(auth::VerifyToken).get(move |request: Request<Arc<AppState>>| {
1196        let user_id = request.ext::<UserId>().copied();
1197        let server = server.clone();
1198        async move {
1199            const WEBSOCKET_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
1200
1201            let connection_upgrade = header_contains_ignore_case(&request, CONNECTION, "upgrade");
1202            let upgrade_to_websocket = header_contains_ignore_case(&request, UPGRADE, "websocket");
1203            let upgrade_requested = connection_upgrade && upgrade_to_websocket;
1204
1205            if !upgrade_requested {
1206                return Ok(Response::new(StatusCode::UpgradeRequired));
1207            }
1208
1209            let header = match request.header("Sec-Websocket-Key") {
1210                Some(h) => h.as_str(),
1211                None => return Err(anyhow!("expected sec-websocket-key"))?,
1212            };
1213
1214            let mut response = Response::new(StatusCode::SwitchingProtocols);
1215            response.insert_header(UPGRADE, "websocket");
1216            response.insert_header(CONNECTION, "Upgrade");
1217            let hash = Sha1::new().chain(header).chain(WEBSOCKET_GUID).finalize();
1218            response.insert_header("Sec-Websocket-Accept", base64::encode(&hash[..]));
1219            response.insert_header("Sec-Websocket-Version", "13");
1220
1221            let http_res: &mut tide::http::Response = response.as_mut();
1222            let upgrade_receiver = http_res.recv_upgrade().await;
1223            let addr = request.remote().unwrap_or("unknown").to_string();
1224            let user_id = user_id.ok_or_else(|| anyhow!("user_id is not present on request. ensure auth::VerifyToken middleware is present"))?;
1225            task::spawn(async move {
1226                if let Some(stream) = upgrade_receiver.await {
1227                    server.handle_connection(Connection::new(WebSocketStream::from_raw_socket(stream, Role::Server, None).await), addr, user_id).await;
1228                }
1229            });
1230
1231            Ok(response)
1232        }
1233    });
1234}
1235
1236fn header_contains_ignore_case<T>(
1237    request: &tide::Request<T>,
1238    header_name: HeaderName,
1239    value: &str,
1240) -> bool {
1241    request
1242        .header(header_name)
1243        .map(|h| {
1244            h.as_str()
1245                .split(',')
1246                .any(|s| s.trim().eq_ignore_ascii_case(value.trim()))
1247        })
1248        .unwrap_or(false)
1249}
1250
1251#[cfg(test)]
1252mod tests {
1253    use super::*;
1254    use crate::{
1255        auth,
1256        db::{tests::TestDb, UserId},
1257        github, AppState, Config,
1258    };
1259    use async_std::{sync::RwLockReadGuard, task};
1260    use gpui::TestAppContext;
1261    use parking_lot::Mutex;
1262    use postage::{mpsc, watch};
1263    use serde_json::json;
1264    use sqlx::types::time::OffsetDateTime;
1265    use std::{
1266        path::Path,
1267        sync::{
1268            atomic::{AtomicBool, Ordering::SeqCst},
1269            Arc,
1270        },
1271        time::Duration,
1272    };
1273    use zed::{
1274        channel::{Channel, ChannelDetails, ChannelList},
1275        editor::{Editor, Insert},
1276        fs::{FakeFs, Fs as _},
1277        language::LanguageRegistry,
1278        rpc::{self, Client, Credentials, EstablishConnectionError},
1279        settings,
1280        test::FakeHttpClient,
1281        user::UserStore,
1282        worktree::Worktree,
1283    };
1284    use zrpc::Peer;
1285
    /// End-to-end sharing between two clients.
    ///
    /// Client A shares a local worktree and client B joins it. The test then
    /// checks, in order:
    /// - B's replica shows up in A's peers.
    /// - Both clients can open the same buffer.
    /// - B's editor selections and edits are observed by A.
    /// - Dropping B's editor, A's buffer, and B's worktree is reflected on A.
    #[gpui::test]
    async fn test_share_worktree(mut cx_a: TestAppContext, mut cx_b: TestAppContext) {
        let (window_b, _) = cx_b.add_window(|_| EmptyView);
        let settings = cx_b.read(settings::test).1;
        let lang_registry = Arc::new(LanguageRegistry::new());

        // Connect to a server as 2 clients.
        let mut server = TestServer::start().await;
        let (client_a, _) = server.create_client(&mut cx_a, "user_a").await;
        let (client_b, _) = server.create_client(&mut cx_b, "user_b").await;

        cx_a.foreground().forbid_parking();

        // Share a local worktree as client A. The `.zed.toml` lists user_b as
        // a collaborator, presumably what lets B pass the server's
        // collaborator check when joining.
        let fs = Arc::new(FakeFs::new());
        fs.insert_tree(
            "/a",
            json!({
                ".zed.toml": r#"collaborators = ["user_b"]"#,
                "a.txt": "a-contents",
                "b.txt": "b-contents",
            }),
        )
        .await;
        let worktree_a = Worktree::open_local(
            client_a.clone(),
            "/a".as_ref(),
            fs,
            lang_registry.clone(),
            &mut cx_a.to_async(),
        )
        .await
        .unwrap();
        worktree_a
            .read_with(&cx_a, |tree, _| tree.as_local().unwrap().scan_complete())
            .await;
        let worktree_id = worktree_a
            .update(&mut cx_a, |tree, cx| tree.as_local_mut().unwrap().share(cx))
            .await
            .unwrap();

        // Join that worktree as client B, and see that a guest has joined as client A.
        let worktree_b = Worktree::open_remote(
            client_b.clone(),
            worktree_id,
            lang_registry.clone(),
            &mut cx_b.to_async(),
        )
        .await
        .unwrap();
        let replica_id_b = worktree_b.read_with(&cx_b, |tree, _| tree.replica_id());
        worktree_a
            .condition(&cx_a, |tree, _| {
                tree.peers()
                    .values()
                    .any(|replica_id| *replica_id == replica_id_b)
            })
            .await;

        // Open the same file as client B and client A.
        let buffer_b = worktree_b
            .update(&mut cx_b, |worktree, cx| worktree.open_buffer("b.txt", cx))
            .await
            .unwrap();
        buffer_b.read_with(&cx_b, |buf, _| assert_eq!(buf.text(), "b-contents"));
        worktree_a.read_with(&cx_a, |tree, cx| assert!(tree.has_open_buffer("b.txt", cx)));
        let buffer_a = worktree_a
            .update(&mut cx_a, |tree, cx| tree.open_buffer("b.txt", cx))
            .await
            .unwrap();

        // Create a selection set as client B and see that selection set as client A.
        let editor_b = cx_b.add_view(window_b, |cx| Editor::for_buffer(buffer_b, settings, cx));
        buffer_a
            .condition(&cx_a, |buffer, _| buffer.selection_sets().count() == 1)
            .await;

        // Edit the buffer as client B and see that edit as client A.
        editor_b.update(&mut cx_b, |editor, cx| {
            editor.insert(&Insert("ok, ".into()), cx)
        });
        buffer_a
            .condition(&cx_a, |buffer, _| buffer.text() == "ok, b-contents")
            .await;

        // Remove the selection set as client B, see those selections disappear as client A.
        cx_b.update(move |_| drop(editor_b));
        buffer_a
            .condition(&cx_a, |buffer, _| buffer.selection_sets().count() == 0)
            .await;

        // Close the buffer as client A, see that the buffer is closed.
        cx_a.update(move |_| drop(buffer_a));
        worktree_a
            .condition(&cx_a, |tree, cx| !tree.has_open_buffer("b.txt", cx))
            .await;

        // Dropping the worktree removes client B from client A's peers.
        cx_b.update(move |_| drop(worktree_b));
        worktree_a
            .condition(&cx_a, |tree, _| tree.peers().is_empty())
            .await;
    }
1389
    /// Edits, saves, and file-system changes propagate across a shared worktree.
    ///
    /// Host A shares a worktree with guests B and C. The test checks that:
    /// - Concurrent edits from all three clients converge to the same text.
    /// - A save issued by guest B, racing with a host edit, writes the host's
    ///   file and leaves every replica clean.
    /// - Renames and new files on the host's file system become visible to
    ///   both guests.
    #[gpui::test]
    async fn test_propagate_saves_and_fs_changes_in_shared_worktree(
        mut cx_a: TestAppContext,
        mut cx_b: TestAppContext,
        mut cx_c: TestAppContext,
    ) {
        cx_a.foreground().forbid_parking();
        let lang_registry = Arc::new(LanguageRegistry::new());

        // Connect to a server as 3 clients.
        let mut server = TestServer::start().await;
        let (client_a, _) = server.create_client(&mut cx_a, "user_a").await;
        let (client_b, _) = server.create_client(&mut cx_b, "user_b").await;
        let (client_c, _) = server.create_client(&mut cx_c, "user_c").await;

        let fs = Arc::new(FakeFs::new());

        // Share a worktree as client A.
        fs.insert_tree(
            "/a",
            json!({
                ".zed.toml": r#"collaborators = ["user_b", "user_c"]"#,
                "file1": "",
                "file2": ""
            }),
        )
        .await;

        let worktree_a = Worktree::open_local(
            client_a.clone(),
            "/a".as_ref(),
            fs.clone(),
            lang_registry.clone(),
            &mut cx_a.to_async(),
        )
        .await
        .unwrap();
        worktree_a
            .read_with(&cx_a, |tree, _| tree.as_local().unwrap().scan_complete())
            .await;
        let worktree_id = worktree_a
            .update(&mut cx_a, |tree, cx| tree.as_local_mut().unwrap().share(cx))
            .await
            .unwrap();

        // Join that worktree as clients B and C.
        let worktree_b = Worktree::open_remote(
            client_b.clone(),
            worktree_id,
            lang_registry.clone(),
            &mut cx_b.to_async(),
        )
        .await
        .unwrap();
        let worktree_c = Worktree::open_remote(
            client_c.clone(),
            worktree_id,
            lang_registry.clone(),
            &mut cx_c.to_async(),
        )
        .await
        .unwrap();

        // Open and edit a buffer as both guests B and C.
        let buffer_b = worktree_b
            .update(&mut cx_b, |tree, cx| tree.open_buffer("file1", cx))
            .await
            .unwrap();
        let buffer_c = worktree_c
            .update(&mut cx_c, |tree, cx| tree.open_buffer("file1", cx))
            .await
            .unwrap();
        buffer_b.update(&mut cx_b, |buf, cx| buf.edit([0..0], "i-am-b, ", cx));
        buffer_c.update(&mut cx_c, |buf, cx| buf.edit([0..0], "i-am-c, ", cx));

        // Open and edit that buffer as the host.
        let buffer_a = worktree_a
            .update(&mut cx_a, |tree, cx| tree.open_buffer("file1", cx))
            .await
            .unwrap();

        // Both guests inserted at offset 0; the expected text encodes the
        // order in which the concurrent inserts are resolved.
        buffer_a
            .condition(&mut cx_a, |buf, _| buf.text() == "i-am-c, i-am-b, ")
            .await;
        buffer_a.update(&mut cx_a, |buf, cx| {
            buf.edit([buf.len()..buf.len()], "i-am-a", cx)
        });

        // Wait for edits to propagate
        buffer_a
            .condition(&mut cx_a, |buf, _| buf.text() == "i-am-c, i-am-b, i-am-a")
            .await;
        buffer_b
            .condition(&mut cx_b, |buf, _| buf.text() == "i-am-c, i-am-b, i-am-a")
            .await;
        buffer_c
            .condition(&mut cx_c, |buf, _| buf.text() == "i-am-c, i-am-b, i-am-a")
            .await;

        // Edit the buffer as the host and concurrently save as guest B.
        // The saved file contains the host's edit, so the save is applied on
        // the host's side after that edit.
        let save_b = buffer_b.update(&mut cx_b, |buf, cx| buf.save(cx).unwrap());
        buffer_a.update(&mut cx_a, |buf, cx| buf.edit([0..0], "hi-a, ", cx));
        save_b.await.unwrap();
        assert_eq!(
            fs.load("/a/file1".as_ref()).await.unwrap(),
            "hi-a, i-am-c, i-am-b, i-am-a"
        );
        buffer_a.read_with(&cx_a, |buf, _| assert!(!buf.is_dirty()));
        buffer_b.read_with(&cx_b, |buf, _| assert!(!buf.is_dirty()));
        buffer_c.condition(&cx_c, |buf, _| !buf.is_dirty()).await;

        // Make changes on host's file system, see those changes on the guests.
        fs.rename("/a/file2".as_ref(), "/a/file3".as_ref())
            .await
            .unwrap();
        fs.insert_file(Path::new("/a/file4"), "4".into())
            .await
            .unwrap();

        worktree_b
            .condition(&cx_b, |tree, _| tree.file_count() == 4)
            .await;
        worktree_c
            .condition(&cx_c, |tree, _| tree.file_count() == 4)
            .await;
        worktree_b.read_with(&cx_b, |tree, _| {
            assert_eq!(
                tree.paths()
                    .map(|p| p.to_string_lossy())
                    .collect::<Vec<_>>(),
                &[".zed.toml", "file1", "file3", "file4"]
            )
        });
        worktree_c.read_with(&cx_c, |tree, _| {
            assert_eq!(
                tree.paths()
                    .map(|p| p.to_string_lossy())
                    .collect::<Vec<_>>(),
                &[".zed.toml", "file1", "file3", "file4"]
            )
        });
    }
1532
    /// A guest's save must not leave its buffer in a conflicted state.
    ///
    /// Guest B edits and saves a shared file. After the save, B's buffer must
    /// be clean and conflict-free, and further edits must make it dirty without
    /// reporting a conflict.
    #[gpui::test]
    async fn test_buffer_conflict_after_save(mut cx_a: TestAppContext, mut cx_b: TestAppContext) {
        cx_a.foreground().forbid_parking();
        let lang_registry = Arc::new(LanguageRegistry::new());

        // Connect to a server as 2 clients.
        let mut server = TestServer::start().await;
        let (client_a, _) = server.create_client(&mut cx_a, "user_a").await;
        let (client_b, _) = server.create_client(&mut cx_b, "user_b").await;

        // Share a local worktree as client A
        let fs = Arc::new(FakeFs::new());
        fs.insert_tree(
            "/dir",
            json!({
                ".zed.toml": r#"collaborators = ["user_b", "user_c"]"#,
                "a.txt": "a-contents",
            }),
        )
        .await;

        let worktree_a = Worktree::open_local(
            client_a.clone(),
            "/dir".as_ref(),
            fs,
            lang_registry.clone(),
            &mut cx_a.to_async(),
        )
        .await
        .unwrap();
        worktree_a
            .read_with(&cx_a, |tree, _| tree.as_local().unwrap().scan_complete())
            .await;
        let worktree_id = worktree_a
            .update(&mut cx_a, |tree, cx| tree.as_local_mut().unwrap().share(cx))
            .await
            .unwrap();

        // Join that worktree as client B.
        let worktree_b = Worktree::open_remote(
            client_b.clone(),
            worktree_id,
            lang_registry.clone(),
            &mut cx_b.to_async(),
        )
        .await
        .unwrap();

        let buffer_b = worktree_b
            .update(&mut cx_b, |worktree, cx| worktree.open_buffer("a.txt", cx))
            .await
            .unwrap();
        // Remember the pre-save mtime so we can tell when the save's new mtime
        // has reached B's copy of the file.
        let mtime = buffer_b.read_with(&cx_b, |buf, _| buf.file().unwrap().mtime);

        buffer_b.update(&mut cx_b, |buf, cx| buf.edit([0..0], "world ", cx));
        buffer_b.read_with(&cx_b, |buf, _| {
            assert!(buf.is_dirty());
            assert!(!buf.has_conflict());
        });

        buffer_b
            .update(&mut cx_b, |buf, cx| buf.save(cx))
            .unwrap()
            .await
            .unwrap();
        worktree_b
            .condition(&cx_b, |_, cx| {
                buffer_b.read(cx).file().unwrap().mtime != mtime
            })
            .await;
        buffer_b.read_with(&cx_b, |buf, _| {
            assert!(!buf.is_dirty());
            assert!(!buf.has_conflict());
        });

        buffer_b.update(&mut cx_b, |buf, cx| buf.edit([0..0], "hello ", cx));
        buffer_b.read_with(&cx_b, |buf, _| {
            assert!(buf.is_dirty());
            assert!(!buf.has_conflict());
        });
    }
1614
    /// A host edit made while a guest is still opening the buffer must reach
    /// the guest.
    ///
    /// Guest B starts opening `a.txt` on a background task. Host A then edits
    /// the buffer before that open completes. B's buffer must eventually
    /// match A's text.
    #[gpui::test]
    async fn test_editing_while_guest_opens_buffer(
        mut cx_a: TestAppContext,
        mut cx_b: TestAppContext,
    ) {
        cx_a.foreground().forbid_parking();
        let lang_registry = Arc::new(LanguageRegistry::new());

        // Connect to a server as 2 clients.
        let mut server = TestServer::start().await;
        let (client_a, _) = server.create_client(&mut cx_a, "user_a").await;
        let (client_b, _) = server.create_client(&mut cx_b, "user_b").await;

        // Share a local worktree as client A
        let fs = Arc::new(FakeFs::new());
        fs.insert_tree(
            "/dir",
            json!({
                ".zed.toml": r#"collaborators = ["user_b"]"#,
                "a.txt": "a-contents",
            }),
        )
        .await;
        let worktree_a = Worktree::open_local(
            client_a.clone(),
            "/dir".as_ref(),
            fs,
            lang_registry.clone(),
            &mut cx_a.to_async(),
        )
        .await
        .unwrap();
        worktree_a
            .read_with(&cx_a, |tree, _| tree.as_local().unwrap().scan_complete())
            .await;
        let worktree_id = worktree_a
            .update(&mut cx_a, |tree, cx| tree.as_local_mut().unwrap().share(cx))
            .await
            .unwrap();

        // Join that worktree as client B.
        let worktree_b = Worktree::open_remote(
            client_b.clone(),
            worktree_id,
            lang_registry.clone(),
            &mut cx_b.to_async(),
        )
        .await
        .unwrap();

        let buffer_a = worktree_a
            .update(&mut cx_a, |tree, cx| tree.open_buffer("a.txt", cx))
            .await
            .unwrap();
        // Start B's open in the background so the host edit below can race it.
        let buffer_b = cx_b
            .background()
            .spawn(worktree_b.update(&mut cx_b, |worktree, cx| worktree.open_buffer("a.txt", cx)));

        // Yield so the background open gets a chance to start before the edit.
        task::yield_now().await;
        buffer_a.update(&mut cx_a, |buf, cx| buf.edit([0..0], "z", cx));

        let text = buffer_a.read_with(&cx_a, |buf, _| buf.text());
        let buffer_b = buffer_b.await.unwrap();
        buffer_b.condition(&cx_b, |buf, _| buf.text() == text).await;
    }
1680
1681    #[gpui::test]
1682    async fn test_peer_disconnection(mut cx_a: TestAppContext, cx_b: TestAppContext) {
1683        cx_a.foreground().forbid_parking();
1684        let lang_registry = Arc::new(LanguageRegistry::new());
1685
1686        // Connect to a server as 2 clients.
1687        let mut server = TestServer::start().await;
1688        let (client_a, _) = server.create_client(&mut cx_a, "user_a").await;
1689        let (client_b, _) = server.create_client(&mut cx_a, "user_b").await;
1690
1691        // Share a local worktree as client A
1692        let fs = Arc::new(FakeFs::new());
1693        fs.insert_tree(
1694            "/a",
1695            json!({
1696                ".zed.toml": r#"collaborators = ["user_b"]"#,
1697                "a.txt": "a-contents",
1698                "b.txt": "b-contents",
1699            }),
1700        )
1701        .await;
1702        let worktree_a = Worktree::open_local(
1703            client_a.clone(),
1704            "/a".as_ref(),
1705            fs,
1706            lang_registry.clone(),
1707            &mut cx_a.to_async(),
1708        )
1709        .await
1710        .unwrap();
1711        worktree_a
1712            .read_with(&cx_a, |tree, _| tree.as_local().unwrap().scan_complete())
1713            .await;
1714        let worktree_id = worktree_a
1715            .update(&mut cx_a, |tree, cx| tree.as_local_mut().unwrap().share(cx))
1716            .await
1717            .unwrap();
1718
1719        // Join that worktree as client B, and see that a guest has joined as client A.
1720        let _worktree_b = Worktree::open_remote(
1721            client_b.clone(),
1722            worktree_id,
1723            lang_registry.clone(),
1724            &mut cx_b.to_async(),
1725        )
1726        .await
1727        .unwrap();
1728        worktree_a
1729            .condition(&cx_a, |tree, _| tree.peers().len() == 1)
1730            .await;
1731
1732        // Drop client B's connection and ensure client A observes client B leaving the worktree.
1733        client_b.disconnect(&cx_b.to_async()).await.unwrap();
1734        worktree_a
1735            .condition(&cx_a, |tree, _| tree.peers().len() == 0)
1736            .await;
1737    }
1738
    // End-to-end chat test: two org members load the same channel, see its existing history,
    // exchange messages, and the server drops its per-channel connection bookkeeping as each
    // client releases its channel handle.
    //
    // Tuples compared below are `(sender login, body, is_pending)` as produced by
    // `channel_messages`.
    #[gpui::test]
    async fn test_basic_chat(mut cx_a: TestAppContext, mut cx_b: TestAppContext) {
        cx_a.foreground().forbid_parking();

        // Connect to a server as 2 clients.
        let mut server = TestServer::start().await;
        let (client_a, user_store_a) = server.create_client(&mut cx_a, "user_a").await;
        let (client_b, user_store_b) = server.create_client(&mut cx_b, "user_b").await;

        // Create an org that includes these 2 users.
        let db = &server.app_state.db;
        let org_id = db.create_org("Test Org", "test-org").await.unwrap();
        db.add_org_member(org_id, current_user_id(&user_store_a), false)
            .await
            .unwrap();
        db.add_org_member(org_id, current_user_id(&user_store_b), false)
            .await
            .unwrap();

        // Create a channel that includes all the users, seeded with one message from B that
        // is written straight to the database (not sent through a client).
        let channel_id = db.create_org_channel(org_id, "test-channel").await.unwrap();
        db.add_channel_member(channel_id, current_user_id(&user_store_a), false)
            .await
            .unwrap();
        db.add_channel_member(channel_id, current_user_id(&user_store_b), false)
            .await
            .unwrap();
        db.create_channel_message(
            channel_id,
            current_user_id(&user_store_b),
            "hello A, it's B.",
            OffsetDateTime::now_utc(),
            1,
        )
        .await
        .unwrap();

        // Client A: wait for the channel list to load, then check it contains only our channel.
        let channels_a = cx_a.add_model(|cx| ChannelList::new(user_store_a, client_a, cx));
        channels_a
            .condition(&mut cx_a, |list, _| list.available_channels().is_some())
            .await;
        channels_a.read_with(&cx_a, |list, _| {
            assert_eq!(
                list.available_channels().unwrap(),
                &[ChannelDetails {
                    id: channel_id.to_proto(),
                    name: "test-channel".to_string()
                }]
            )
        });
        // A freshly opened channel starts empty; its history arrives asynchronously.
        let channel_a = channels_a.update(&mut cx_a, |this, cx| {
            this.get_channel(channel_id.to_proto(), cx).unwrap()
        });
        channel_a.read_with(&cx_a, |channel, _| assert!(channel.messages().is_empty()));
        channel_a
            .condition(&cx_a, |channel, _| {
                channel_messages(channel)
                    == [("user_b".to_string(), "hello A, it's B.".to_string(), false)]
            })
            .await;

        // Client B: same checks as client A above.
        let channels_b = cx_b.add_model(|cx| ChannelList::new(user_store_b, client_b, cx));
        channels_b
            .condition(&mut cx_b, |list, _| list.available_channels().is_some())
            .await;
        channels_b.read_with(&cx_b, |list, _| {
            assert_eq!(
                list.available_channels().unwrap(),
                &[ChannelDetails {
                    id: channel_id.to_proto(),
                    name: "test-channel".to_string()
                }]
            )
        });

        let channel_b = channels_b.update(&mut cx_b, |this, cx| {
            this.get_channel(channel_id.to_proto(), cx).unwrap()
        });
        channel_b.read_with(&cx_b, |channel, _| assert!(channel.messages().is_empty()));
        channel_b
            .condition(&cx_b, |channel, _| {
                channel_messages(channel)
                    == [("user_b".to_string(), "hello A, it's B.".to_string(), false)]
            })
            .await;

        // A sends two messages. Both appear locally right away, marked pending (`true`),
        // before the send completes. Only the second send's task is awaited; the first is
        // detached.
        channel_a
            .update(&mut cx_a, |channel, cx| {
                channel
                    .send_message("oh, hi B.".to_string(), cx)
                    .unwrap()
                    .detach();
                let task = channel.send_message("sup".to_string(), cx).unwrap();
                assert_eq!(
                    channel_messages(channel),
                    &[
                        ("user_b".to_string(), "hello A, it's B.".to_string(), false),
                        ("user_a".to_string(), "oh, hi B.".to_string(), true),
                        ("user_a".to_string(), "sup".to_string(), true)
                    ]
                );
                task
            })
            .await
            .unwrap();

        // B receives both of A's messages, in order, as non-pending.
        channel_b
            .condition(&cx_b, |channel, _| {
                channel_messages(channel)
                    == [
                        ("user_b".to_string(), "hello A, it's B.".to_string(), false),
                        ("user_a".to_string(), "oh, hi B.".to_string(), false),
                        ("user_a".to_string(), "sup".to_string(), false),
                    ]
            })
            .await;

        // The server tracks one connection id per client that has the channel open; releasing
        // a channel handle should remove that client's entry, and the channel's entry entirely
        // once nobody has it open.
        assert_eq!(
            server.state().await.channels[&channel_id]
                .connection_ids
                .len(),
            2
        );
        cx_b.update(|_| drop(channel_b));
        server
            .condition(|state| state.channels[&channel_id].connection_ids.len() == 1)
            .await;

        cx_a.update(|_| drop(channel_a));
        server
            .condition(|state| !state.channels.contains_key(&channel_id))
            .await;
    }
1872
1873    #[gpui::test]
1874    async fn test_chat_message_validation(mut cx_a: TestAppContext) {
1875        cx_a.foreground().forbid_parking();
1876
1877        let mut server = TestServer::start().await;
1878        let (client_a, user_store_a) = server.create_client(&mut cx_a, "user_a").await;
1879
1880        let db = &server.app_state.db;
1881        let org_id = db.create_org("Test Org", "test-org").await.unwrap();
1882        let channel_id = db.create_org_channel(org_id, "test-channel").await.unwrap();
1883        db.add_org_member(org_id, current_user_id(&user_store_a), false)
1884            .await
1885            .unwrap();
1886        db.add_channel_member(channel_id, current_user_id(&user_store_a), false)
1887            .await
1888            .unwrap();
1889
1890        let channels_a = cx_a.add_model(|cx| ChannelList::new(user_store_a, client_a, cx));
1891        channels_a
1892            .condition(&mut cx_a, |list, _| list.available_channels().is_some())
1893            .await;
1894        let channel_a = channels_a.update(&mut cx_a, |this, cx| {
1895            this.get_channel(channel_id.to_proto(), cx).unwrap()
1896        });
1897
1898        // Messages aren't allowed to be too long.
1899        channel_a
1900            .update(&mut cx_a, |channel, cx| {
1901                let long_body = "this is long.\n".repeat(1024);
1902                channel.send_message(long_body, cx).unwrap()
1903            })
1904            .await
1905            .unwrap_err();
1906
1907        // Messages aren't allowed to be blank.
1908        channel_a.update(&mut cx_a, |channel, cx| {
1909            channel.send_message(String::new(), cx).unwrap_err()
1910        });
1911
1912        // Leading and trailing whitespace are trimmed.
1913        channel_a
1914            .update(&mut cx_a, |channel, cx| {
1915                channel
1916                    .send_message("\n surrounded by whitespace  \n".to_string(), cx)
1917                    .unwrap()
1918            })
1919            .await
1920            .unwrap();
1921        assert_eq!(
1922            db.get_channel_messages(channel_id, 10, None)
1923                .await
1924                .unwrap()
1925                .iter()
1926                .map(|m| &m.body)
1927                .collect::<Vec<_>>(),
1928            &["surrounded by whitespace"]
1929        );
1930    }
1931
1932    #[gpui::test]
1933    async fn test_chat_reconnection(mut cx_a: TestAppContext, mut cx_b: TestAppContext) {
1934        cx_a.foreground().forbid_parking();
1935        let http = FakeHttpClient::new(|_| async move { Ok(surf::http::Response::new(404)) });
1936
1937        // Connect to a server as 2 clients.
1938        let mut server = TestServer::start().await;
1939        let (client_a, user_store_a) = server.create_client(&mut cx_a, "user_a").await;
1940        let (client_b, user_store_b) = server.create_client(&mut cx_b, "user_b").await;
1941        let mut status_b = client_b.status();
1942
1943        // Create an org that includes these 2 users.
1944        let db = &server.app_state.db;
1945        let org_id = db.create_org("Test Org", "test-org").await.unwrap();
1946        db.add_org_member(org_id, current_user_id(&user_store_a), false)
1947            .await
1948            .unwrap();
1949        db.add_org_member(org_id, current_user_id(&user_store_b), false)
1950            .await
1951            .unwrap();
1952
1953        // Create a channel that includes all the users.
1954        let channel_id = db.create_org_channel(org_id, "test-channel").await.unwrap();
1955        db.add_channel_member(channel_id, current_user_id(&user_store_a), false)
1956            .await
1957            .unwrap();
1958        db.add_channel_member(channel_id, current_user_id(&user_store_b), false)
1959            .await
1960            .unwrap();
1961        db.create_channel_message(
1962            channel_id,
1963            current_user_id(&user_store_b),
1964            "hello A, it's B.",
1965            OffsetDateTime::now_utc(),
1966            2,
1967        )
1968        .await
1969        .unwrap();
1970
1971        let user_store_a =
1972            UserStore::new(client_a.clone(), http.clone(), cx_a.background().as_ref());
1973        let channels_a = cx_a.add_model(|cx| ChannelList::new(user_store_a, client_a, cx));
1974        channels_a
1975            .condition(&mut cx_a, |list, _| list.available_channels().is_some())
1976            .await;
1977
1978        channels_a.read_with(&cx_a, |list, _| {
1979            assert_eq!(
1980                list.available_channels().unwrap(),
1981                &[ChannelDetails {
1982                    id: channel_id.to_proto(),
1983                    name: "test-channel".to_string()
1984                }]
1985            )
1986        });
1987        let channel_a = channels_a.update(&mut cx_a, |this, cx| {
1988            this.get_channel(channel_id.to_proto(), cx).unwrap()
1989        });
1990        channel_a.read_with(&cx_a, |channel, _| assert!(channel.messages().is_empty()));
1991        channel_a
1992            .condition(&cx_a, |channel, _| {
1993                channel_messages(channel)
1994                    == [("user_b".to_string(), "hello A, it's B.".to_string(), false)]
1995            })
1996            .await;
1997
1998        let channels_b = cx_b.add_model(|cx| ChannelList::new(user_store_b.clone(), client_b, cx));
1999        channels_b
2000            .condition(&mut cx_b, |list, _| list.available_channels().is_some())
2001            .await;
2002        channels_b.read_with(&cx_b, |list, _| {
2003            assert_eq!(
2004                list.available_channels().unwrap(),
2005                &[ChannelDetails {
2006                    id: channel_id.to_proto(),
2007                    name: "test-channel".to_string()
2008                }]
2009            )
2010        });
2011
2012        let channel_b = channels_b.update(&mut cx_b, |this, cx| {
2013            this.get_channel(channel_id.to_proto(), cx).unwrap()
2014        });
2015        channel_b.read_with(&cx_b, |channel, _| assert!(channel.messages().is_empty()));
2016        channel_b
2017            .condition(&cx_b, |channel, _| {
2018                channel_messages(channel)
2019                    == [("user_b".to_string(), "hello A, it's B.".to_string(), false)]
2020            })
2021            .await;
2022
2023        // Disconnect client B, ensuring we can still access its cached channel data.
2024        server.forbid_connections();
2025        server.disconnect_client(current_user_id(&user_store_b));
2026        while !matches!(
2027            status_b.recv().await,
2028            Some(rpc::Status::ReconnectionError { .. })
2029        ) {}
2030
2031        channels_b.read_with(&cx_b, |channels, _| {
2032            assert_eq!(
2033                channels.available_channels().unwrap(),
2034                [ChannelDetails {
2035                    id: channel_id.to_proto(),
2036                    name: "test-channel".to_string()
2037                }]
2038            )
2039        });
2040        channel_b.read_with(&cx_b, |channel, _| {
2041            assert_eq!(
2042                channel_messages(channel),
2043                [("user_b".to_string(), "hello A, it's B.".to_string(), false)]
2044            )
2045        });
2046
2047        // Send a message from client B while it is disconnected.
2048        channel_b
2049            .update(&mut cx_b, |channel, cx| {
2050                let task = channel
2051                    .send_message("can you see this?".to_string(), cx)
2052                    .unwrap();
2053                assert_eq!(
2054                    channel_messages(channel),
2055                    &[
2056                        ("user_b".to_string(), "hello A, it's B.".to_string(), false),
2057                        ("user_b".to_string(), "can you see this?".to_string(), true)
2058                    ]
2059                );
2060                task
2061            })
2062            .await
2063            .unwrap_err();
2064
2065        // Send a message from client A while B is disconnected.
2066        channel_a
2067            .update(&mut cx_a, |channel, cx| {
2068                channel
2069                    .send_message("oh, hi B.".to_string(), cx)
2070                    .unwrap()
2071                    .detach();
2072                let task = channel.send_message("sup".to_string(), cx).unwrap();
2073                assert_eq!(
2074                    channel_messages(channel),
2075                    &[
2076                        ("user_b".to_string(), "hello A, it's B.".to_string(), false),
2077                        ("user_a".to_string(), "oh, hi B.".to_string(), true),
2078                        ("user_a".to_string(), "sup".to_string(), true)
2079                    ]
2080                );
2081                task
2082            })
2083            .await
2084            .unwrap();
2085
2086        // Give client B a chance to reconnect.
2087        server.allow_connections();
2088        cx_b.foreground().advance_clock(Duration::from_secs(10));
2089
2090        // Verify that B sees the new messages upon reconnection, as well as the message client B
2091        // sent while offline.
2092        channel_b
2093            .condition(&cx_b, |channel, _| {
2094                channel_messages(channel)
2095                    == [
2096                        ("user_b".to_string(), "hello A, it's B.".to_string(), false),
2097                        ("user_a".to_string(), "oh, hi B.".to_string(), false),
2098                        ("user_a".to_string(), "sup".to_string(), false),
2099                        ("user_b".to_string(), "can you see this?".to_string(), false),
2100                    ]
2101            })
2102            .await;
2103
2104        // Ensure client A and B can communicate normally after reconnection.
2105        channel_a
2106            .update(&mut cx_a, |channel, cx| {
2107                channel.send_message("you online?".to_string(), cx).unwrap()
2108            })
2109            .await
2110            .unwrap();
2111        channel_b
2112            .condition(&cx_b, |channel, _| {
2113                channel_messages(channel)
2114                    == [
2115                        ("user_b".to_string(), "hello A, it's B.".to_string(), false),
2116                        ("user_a".to_string(), "oh, hi B.".to_string(), false),
2117                        ("user_a".to_string(), "sup".to_string(), false),
2118                        ("user_b".to_string(), "can you see this?".to_string(), false),
2119                        ("user_a".to_string(), "you online?".to_string(), false),
2120                    ]
2121            })
2122            .await;
2123
2124        channel_b
2125            .update(&mut cx_b, |channel, cx| {
2126                channel.send_message("yep".to_string(), cx).unwrap()
2127            })
2128            .await
2129            .unwrap();
2130        channel_a
2131            .condition(&cx_a, |channel, _| {
2132                channel_messages(channel)
2133                    == [
2134                        ("user_b".to_string(), "hello A, it's B.".to_string(), false),
2135                        ("user_a".to_string(), "oh, hi B.".to_string(), false),
2136                        ("user_a".to_string(), "sup".to_string(), false),
2137                        ("user_b".to_string(), "can you see this?".to_string(), false),
2138                        ("user_a".to_string(), "you online?".to_string(), false),
2139                        ("user_b".to_string(), "yep".to_string(), false),
2140                    ]
2141            })
2142            .await;
2143    }
2144
    /// In-process test harness: a `Server` backed by a throwaway test database, plus hooks for
    /// severing individual client connections and refusing new ones.
    ///
    /// Fields are dropped in declaration order, so `_test_db` (declared last) outlives the
    /// server and app state that were built from it.
    struct TestServer {
        /// The peer shared with `server`; reset synchronously when the harness is dropped.
        peer: Arc<Peer>,
        /// App state (database handle, config, API clients) the server was constructed with.
        app_state: Arc<AppState>,
        server: Arc<Server>,
        /// Receiving end of the notification channel handed to `Server::new`; `condition`
        /// waits on it between predicate checks (presumably signalled on server state changes).
        notifications: mpsc::Receiver<()>,
        /// Per-user kill switch for that user's current in-memory connection; sending
        /// `Some(())` through it is how `disconnect_client` drops a client.
        connection_killers: Arc<Mutex<HashMap<UserId, watch::Sender<Option<()>>>>>,
        /// When `true`, the test clients' connection override refuses to connect.
        forbid_connections: Arc<AtomicBool>,
        /// Keeps the test database alive for the lifetime of the harness.
        _test_db: TestDb,
    }
2154
2155    impl TestServer {
2156        async fn start() -> Self {
2157            let test_db = TestDb::new();
2158            let app_state = Self::build_app_state(&test_db).await;
2159            let peer = Peer::new();
2160            let notifications = mpsc::channel(128);
2161            let server = Server::new(app_state.clone(), peer.clone(), Some(notifications.0));
2162            Self {
2163                peer,
2164                app_state,
2165                server,
2166                notifications: notifications.1,
2167                connection_killers: Default::default(),
2168                forbid_connections: Default::default(),
2169                _test_db: test_db,
2170            }
2171        }
2172
2173        async fn create_client(
2174            &mut self,
2175            cx: &mut TestAppContext,
2176            name: &str,
2177        ) -> (Arc<Client>, Arc<UserStore>) {
2178            let user_id = self.app_state.db.create_user(name, false).await.unwrap();
2179            let client_name = name.to_string();
2180            let mut client = Client::new();
2181            let server = self.server.clone();
2182            let connection_killers = self.connection_killers.clone();
2183            let forbid_connections = self.forbid_connections.clone();
2184            Arc::get_mut(&mut client)
2185                .unwrap()
2186                .override_authenticate(move |cx| {
2187                    cx.spawn(|_| async move {
2188                        let access_token = "the-token".to_string();
2189                        Ok(Credentials {
2190                            user_id: user_id.0 as u64,
2191                            access_token,
2192                        })
2193                    })
2194                })
2195                .override_establish_connection(move |credentials, cx| {
2196                    assert_eq!(credentials.user_id, user_id.0 as u64);
2197                    assert_eq!(credentials.access_token, "the-token");
2198
2199                    let server = server.clone();
2200                    let connection_killers = connection_killers.clone();
2201                    let forbid_connections = forbid_connections.clone();
2202                    let client_name = client_name.clone();
2203                    cx.spawn(move |cx| async move {
2204                        if forbid_connections.load(SeqCst) {
2205                            Err(EstablishConnectionError::other(anyhow!(
2206                                "server is forbidding connections"
2207                            )))
2208                        } else {
2209                            let (client_conn, server_conn, kill_conn) = Connection::in_memory();
2210                            connection_killers.lock().insert(user_id, kill_conn);
2211                            cx.background()
2212                                .spawn(server.handle_connection(server_conn, client_name, user_id))
2213                                .detach();
2214                            Ok(client_conn)
2215                        }
2216                    })
2217                });
2218
2219            let http = FakeHttpClient::new(|_| async move { Ok(surf::http::Response::new(404)) });
2220            client
2221                .authenticate_and_connect(&cx.to_async())
2222                .await
2223                .unwrap();
2224
2225            let user_store = UserStore::new(client.clone(), http, &cx.background());
2226            let mut authed_user = user_store.watch_current_user();
2227            while authed_user.recv().await.unwrap().is_none() {}
2228
2229            (client, user_store)
2230        }
2231
2232        fn disconnect_client(&self, user_id: UserId) {
2233            if let Some(mut kill_conn) = self.connection_killers.lock().remove(&user_id) {
2234                let _ = kill_conn.try_send(Some(()));
2235            }
2236        }
2237
2238        fn forbid_connections(&self) {
2239            self.forbid_connections.store(true, SeqCst);
2240        }
2241
2242        fn allow_connections(&self) {
2243            self.forbid_connections.store(false, SeqCst);
2244        }
2245
2246        async fn build_app_state(test_db: &TestDb) -> Arc<AppState> {
2247            let mut config = Config::default();
2248            config.session_secret = "a".repeat(32);
2249            config.database_url = test_db.url.clone();
2250            let github_client = github::AppClient::test();
2251            Arc::new(AppState {
2252                db: test_db.db().clone(),
2253                handlebars: Default::default(),
2254                auth_client: auth::build_client("", ""),
2255                repo_client: github::RepoClient::test(&github_client),
2256                github_client,
2257                config,
2258            })
2259        }
2260
2261        async fn state<'a>(&'a self) -> RwLockReadGuard<'a, ServerState> {
2262            self.server.state.read().await
2263        }
2264
2265        async fn condition<F>(&mut self, mut predicate: F)
2266        where
2267            F: FnMut(&ServerState) -> bool,
2268        {
2269            async_std::future::timeout(Duration::from_millis(500), async {
2270                while !(predicate)(&*self.server.state.read().await) {
2271                    self.notifications.recv().await;
2272                }
2273            })
2274            .await
2275            .expect("condition timed out");
2276        }
2277    }
2278
2279    impl Drop for TestServer {
2280        fn drop(&mut self) {
2281            task::block_on(self.peer.reset());
2282        }
2283    }
2284
2285    fn current_user_id(user_store: &Arc<UserStore>) -> UserId {
2286        UserId::from_proto(user_store.current_user().unwrap().id)
2287    }
2288
2289    fn channel_messages(channel: &Channel) -> Vec<(String, String, bool)> {
2290        channel
2291            .messages()
2292            .cursor::<(), ()>()
2293            .map(|m| {
2294                (
2295                    m.sender.github_login.clone(),
2296                    m.body.clone(),
2297                    m.is_pending(),
2298                )
2299            })
2300            .collect()
2301    }
2302
2303    struct EmptyView;
2304
2305    impl gpui::Entity for EmptyView {
2306        type Event = ();
2307    }
2308
2309    impl gpui::View for EmptyView {
2310        fn ui_name() -> &'static str {
2311            "empty view"
2312        }
2313
2314        fn render(&mut self, _: &mut gpui::RenderContext<Self>) -> gpui::ElementBox {
2315            gpui::Element::boxed(gpui::elements::Empty)
2316        }
2317    }
2318}