rpc.rs

   1use super::{
   2    auth,
   3    db::{ChannelId, MessageId, UserId},
   4    AppState,
   5};
   6use anyhow::anyhow;
   7use async_std::{sync::RwLock, task};
   8use async_tungstenite::{tungstenite::protocol::Role, WebSocketStream};
   9use futures::{future::BoxFuture, FutureExt};
  10use postage::{mpsc, prelude::Sink as _, prelude::Stream as _};
  11use sha1::{Digest as _, Sha1};
  12use std::{
  13    any::TypeId,
  14    collections::{hash_map, HashMap, HashSet},
  15    future::Future,
  16    mem,
  17    sync::Arc,
  18    time::Instant,
  19};
  20use surf::StatusCode;
  21use tide::log;
  22use tide::{
  23    http::headers::{HeaderName, CONNECTION, UPGRADE},
  24    Request, Response,
  25};
  26use time::OffsetDateTime;
  27use zrpc::{
  28    auth::random_token,
  29    proto::{self, AnyTypedEnvelope, EnvelopedMessage},
  30    Conn, ConnectionId, Peer, TypedEnvelope,
  31};
  32
  33type ReplicaId = u16;
  34
  35type MessageHandler = Box<
  36    dyn Send
  37        + Sync
  38        + Fn(Arc<Server>, Box<dyn AnyTypedEnvelope>) -> BoxFuture<'static, tide::Result<()>>,
  39>;
  40
/// RPC server: owns the per-message-type handler table and the in-memory state
/// that the message handlers read and mutate.
pub struct Server {
    /// Transport used to send, respond to, and forward messages between connections.
    peer: Arc<Peer>,
    /// Connection, worktree, and channel bookkeeping behind an async-std `RwLock`.
    state: RwLock<ServerState>,
    /// Shared application state; handlers use its `db` field for persistent data.
    app_state: Arc<AppState>,
    /// Handler for each registered message type, keyed by the payload's `TypeId`.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Optional sender that receives a `()` after each message that had a handler.
    notifications: Option<mpsc::Sender<()>>,
}
  48
  49#[derive(Default)]
  50struct ServerState {
  51    connections: HashMap<ConnectionId, Connection>,
  52    pub worktrees: HashMap<u64, Worktree>,
  53    channels: HashMap<ChannelId, Channel>,
  54    next_worktree_id: u64,
  55}
  56
/// Bookkeeping for one live RPC connection.
struct Connection {
    /// User the connection was opened for.
    user_id: UserId,
    /// Worktree ids recorded for this connection; `remove_connection` walks these on disconnect.
    worktrees: HashSet<u64>,
    /// Channels this connection has joined.
    channels: HashSet<ChannelId>,
}
  62
/// A worktree shared by a host connection with zero or more guest connections.
struct Worktree {
    /// Connection that shared the worktree; set to `None` when the host closes it.
    host_connection_id: Option<ConnectionId>,
    /// Replica id assigned to each guest connection.
    guest_connection_ids: HashMap<ConnectionId, ReplicaId>,
    /// Replica ids currently assigned to guests. Guests are numbered from 1; the host
    /// is reported to peers as replica 0.
    active_replica_ids: HashSet<ReplicaId>,
    /// Token (from `random_token`) a guest must present to join.
    access_token: String,
    /// Root name supplied by the host when sharing.
    root_name: String,
    /// Worktree entries keyed by `proto::Entry::id`.
    entries: HashMap<u64, proto::Entry>,
}
  71
  72#[derive(Default)]
  73struct Channel {
  74    connection_ids: HashSet<ConnectionId>,
  75}
  76
  77const MESSAGE_COUNT_PER_PAGE: usize = 100;
  78const MAX_MESSAGE_LEN: usize = 1024;
  79
  80impl Server {
  81    pub fn new(
  82        app_state: Arc<AppState>,
  83        peer: Arc<Peer>,
  84        notifications: Option<mpsc::Sender<()>>,
  85    ) -> Arc<Self> {
  86        let mut server = Self {
  87            peer,
  88            app_state,
  89            state: Default::default(),
  90            handlers: Default::default(),
  91            notifications,
  92        };
  93
  94        server
  95            .add_handler(Server::share_worktree)
  96            .add_handler(Server::join_worktree)
  97            .add_handler(Server::update_worktree)
  98            .add_handler(Server::close_worktree)
  99            .add_handler(Server::open_buffer)
 100            .add_handler(Server::close_buffer)
 101            .add_handler(Server::update_buffer)
 102            .add_handler(Server::buffer_saved)
 103            .add_handler(Server::save_buffer)
 104            .add_handler(Server::get_channels)
 105            .add_handler(Server::get_users)
 106            .add_handler(Server::join_channel)
 107            .add_handler(Server::leave_channel)
 108            .add_handler(Server::send_channel_message)
 109            .add_handler(Server::get_channel_messages);
 110
 111        Arc::new(server)
 112    }
 113
 114    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 115    where
 116        F: 'static + Send + Sync + Fn(Arc<Self>, TypedEnvelope<M>) -> Fut,
 117        Fut: 'static + Send + Future<Output = tide::Result<()>>,
 118        M: EnvelopedMessage,
 119    {
 120        let prev_handler = self.handlers.insert(
 121            TypeId::of::<M>(),
 122            Box::new(move |server, envelope| {
 123                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 124                (handler)(server, *envelope).boxed()
 125            }),
 126        );
 127        if prev_handler.is_some() {
 128            panic!("registered a handler for the same message twice");
 129        }
 130        self
 131    }
 132
 133    pub fn handle_connection(
 134        self: &Arc<Self>,
 135        connection: Conn,
 136        addr: String,
 137        user_id: UserId,
 138    ) -> impl Future<Output = ()> {
 139        let this = self.clone();
 140        async move {
 141            let (connection_id, handle_io, mut incoming_rx) =
 142                this.peer.add_connection(connection).await;
 143            this.add_connection(connection_id, user_id).await;
 144
 145            let handle_io = handle_io.fuse();
 146            futures::pin_mut!(handle_io);
 147            loop {
 148                let next_message = incoming_rx.recv().fuse();
 149                futures::pin_mut!(next_message);
 150                futures::select_biased! {
 151                    message = next_message => {
 152                        if let Some(message) = message {
 153                            let start_time = Instant::now();
 154                            log::info!("RPC message received: {}", message.payload_type_name());
 155                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
 156                                if let Err(err) = (handler)(this.clone(), message).await {
 157                                    log::error!("error handling message: {:?}", err);
 158                                } else {
 159                                    log::info!("RPC message handled. duration:{:?}", start_time.elapsed());
 160                                }
 161
 162                                if let Some(mut notifications) = this.notifications.clone() {
 163                                    let _ = notifications.send(()).await;
 164                                }
 165                            } else {
 166                                log::warn!("unhandled message: {}", message.payload_type_name());
 167                            }
 168                        } else {
 169                            log::info!("rpc connection closed {:?}", addr);
 170                            break;
 171                        }
 172                    }
 173                    handle_io = handle_io => {
 174                        if let Err(err) = handle_io {
 175                            log::error!("error handling rpc connection {:?} - {:?}", addr, err);
 176                        }
 177                        break;
 178                    }
 179                }
 180            }
 181
 182            if let Err(err) = this.sign_out(connection_id).await {
 183                log::error!("error signing out connection {:?} - {:?}", addr, err);
 184            }
 185        }
 186    }
 187
 188    async fn sign_out(self: &Arc<Self>, connection_id: zrpc::ConnectionId) -> tide::Result<()> {
 189        self.peer.disconnect(connection_id).await;
 190        let worktree_ids = self.remove_connection(connection_id).await;
 191        for worktree_id in worktree_ids {
 192            let state = self.state.read().await;
 193            if let Some(worktree) = state.worktrees.get(&worktree_id) {
 194                broadcast(connection_id, worktree.connection_ids(), |conn_id| {
 195                    self.peer.send(
 196                        conn_id,
 197                        proto::RemovePeer {
 198                            worktree_id,
 199                            peer_id: connection_id.0,
 200                        },
 201                    )
 202                })
 203                .await?;
 204            }
 205        }
 206        Ok(())
 207    }
 208
 209    // Add a new connection associated with a given user.
 210    async fn add_connection(&self, connection_id: ConnectionId, user_id: UserId) {
 211        self.state.write().await.connections.insert(
 212            connection_id,
 213            Connection {
 214                user_id,
 215                worktrees: Default::default(),
 216                channels: Default::default(),
 217            },
 218        );
 219    }
 220
 221    // Remove the given connection and its association with any worktrees.
 222    async fn remove_connection(&self, connection_id: ConnectionId) -> Vec<u64> {
 223        let mut worktree_ids = Vec::new();
 224        let mut state = self.state.write().await;
 225        if let Some(connection) = state.connections.remove(&connection_id) {
 226            for channel_id in connection.channels {
 227                if let Some(channel) = state.channels.get_mut(&channel_id) {
 228                    channel.connection_ids.remove(&connection_id);
 229                }
 230            }
 231            for worktree_id in connection.worktrees {
 232                if let Some(worktree) = state.worktrees.get_mut(&worktree_id) {
 233                    if worktree.host_connection_id == Some(connection_id) {
 234                        worktree_ids.push(worktree_id);
 235                    } else if let Some(replica_id) =
 236                        worktree.guest_connection_ids.remove(&connection_id)
 237                    {
 238                        worktree.active_replica_ids.remove(&replica_id);
 239                        worktree_ids.push(worktree_id);
 240                    }
 241                }
 242            }
 243        }
 244        worktree_ids
 245    }
 246
 247    async fn share_worktree(
 248        self: Arc<Server>,
 249        mut request: TypedEnvelope<proto::ShareWorktree>,
 250    ) -> tide::Result<()> {
 251        let mut state = self.state.write().await;
 252        let worktree_id = state.next_worktree_id;
 253        state.next_worktree_id += 1;
 254        let access_token = random_token();
 255        let worktree = request
 256            .payload
 257            .worktree
 258            .as_mut()
 259            .ok_or_else(|| anyhow!("missing worktree"))?;
 260        let entries = mem::take(&mut worktree.entries)
 261            .into_iter()
 262            .map(|entry| (entry.id, entry))
 263            .collect();
 264        state.worktrees.insert(
 265            worktree_id,
 266            Worktree {
 267                host_connection_id: Some(request.sender_id),
 268                guest_connection_ids: Default::default(),
 269                active_replica_ids: Default::default(),
 270                access_token: access_token.clone(),
 271                root_name: mem::take(&mut worktree.root_name),
 272                entries,
 273            },
 274        );
 275
 276        self.peer
 277            .respond(
 278                request.receipt(),
 279                proto::ShareWorktreeResponse {
 280                    worktree_id,
 281                    access_token,
 282                },
 283            )
 284            .await?;
 285        Ok(())
 286    }
 287
 288    async fn join_worktree(
 289        self: Arc<Server>,
 290        request: TypedEnvelope<proto::OpenWorktree>,
 291    ) -> tide::Result<()> {
 292        let worktree_id = request.payload.worktree_id;
 293        let access_token = &request.payload.access_token;
 294
 295        let mut state = self.state.write().await;
 296        if let Some((peer_replica_id, worktree)) =
 297            state.join_worktree(request.sender_id, worktree_id, access_token)
 298        {
 299            let mut peers = Vec::new();
 300            if let Some(host_connection_id) = worktree.host_connection_id {
 301                peers.push(proto::Peer {
 302                    peer_id: host_connection_id.0,
 303                    replica_id: 0,
 304                });
 305            }
 306            for (peer_conn_id, peer_replica_id) in &worktree.guest_connection_ids {
 307                if *peer_conn_id != request.sender_id {
 308                    peers.push(proto::Peer {
 309                        peer_id: peer_conn_id.0,
 310                        replica_id: *peer_replica_id as u32,
 311                    });
 312                }
 313            }
 314
 315            broadcast(request.sender_id, worktree.connection_ids(), |conn_id| {
 316                self.peer.send(
 317                    conn_id,
 318                    proto::AddPeer {
 319                        worktree_id,
 320                        peer: Some(proto::Peer {
 321                            peer_id: request.sender_id.0,
 322                            replica_id: peer_replica_id as u32,
 323                        }),
 324                    },
 325                )
 326            })
 327            .await?;
 328            self.peer
 329                .respond(
 330                    request.receipt(),
 331                    proto::OpenWorktreeResponse {
 332                        worktree_id,
 333                        worktree: Some(proto::Worktree {
 334                            root_name: worktree.root_name.clone(),
 335                            entries: worktree.entries.values().cloned().collect(),
 336                        }),
 337                        replica_id: peer_replica_id as u32,
 338                        peers,
 339                    },
 340                )
 341                .await?;
 342        } else {
 343            self.peer
 344                .respond(
 345                    request.receipt(),
 346                    proto::OpenWorktreeResponse {
 347                        worktree_id,
 348                        worktree: None,
 349                        replica_id: 0,
 350                        peers: Vec::new(),
 351                    },
 352                )
 353                .await?;
 354        }
 355
 356        Ok(())
 357    }
 358
 359    async fn update_worktree(
 360        self: Arc<Server>,
 361        request: TypedEnvelope<proto::UpdateWorktree>,
 362    ) -> tide::Result<()> {
 363        {
 364            let mut state = self.state.write().await;
 365            let worktree = state.write_worktree(request.payload.worktree_id, request.sender_id)?;
 366            for entry_id in &request.payload.removed_entries {
 367                worktree.entries.remove(&entry_id);
 368            }
 369
 370            for entry in &request.payload.updated_entries {
 371                worktree.entries.insert(entry.id, entry.clone());
 372            }
 373        }
 374
 375        self.broadcast_in_worktree(request.payload.worktree_id, &request)
 376            .await?;
 377        Ok(())
 378    }
 379
 380    async fn close_worktree(
 381        self: Arc<Server>,
 382        request: TypedEnvelope<proto::CloseWorktree>,
 383    ) -> tide::Result<()> {
 384        let connection_ids;
 385        {
 386            let mut state = self.state.write().await;
 387            let worktree = state.write_worktree(request.payload.worktree_id, request.sender_id)?;
 388            connection_ids = worktree.connection_ids();
 389            if worktree.host_connection_id == Some(request.sender_id) {
 390                worktree.host_connection_id = None;
 391            } else if let Some(replica_id) =
 392                worktree.guest_connection_ids.remove(&request.sender_id)
 393            {
 394                worktree.active_replica_ids.remove(&replica_id);
 395            }
 396        }
 397
 398        broadcast(request.sender_id, connection_ids, |conn_id| {
 399            self.peer.send(
 400                conn_id,
 401                proto::RemovePeer {
 402                    worktree_id: request.payload.worktree_id,
 403                    peer_id: request.sender_id.0,
 404                },
 405            )
 406        })
 407        .await?;
 408
 409        Ok(())
 410    }
 411
 412    async fn open_buffer(
 413        self: Arc<Server>,
 414        request: TypedEnvelope<proto::OpenBuffer>,
 415    ) -> tide::Result<()> {
 416        let receipt = request.receipt();
 417        let worktree_id = request.payload.worktree_id;
 418        let host_connection_id = self
 419            .state
 420            .read()
 421            .await
 422            .read_worktree(worktree_id, request.sender_id)?
 423            .host_connection_id()?;
 424
 425        let response = self
 426            .peer
 427            .forward_request(request.sender_id, host_connection_id, request.payload)
 428            .await?;
 429        self.peer.respond(receipt, response).await?;
 430        Ok(())
 431    }
 432
 433    async fn close_buffer(
 434        self: Arc<Server>,
 435        request: TypedEnvelope<proto::CloseBuffer>,
 436    ) -> tide::Result<()> {
 437        let host_connection_id = self
 438            .state
 439            .read()
 440            .await
 441            .read_worktree(request.payload.worktree_id, request.sender_id)?
 442            .host_connection_id()?;
 443
 444        self.peer
 445            .forward_send(request.sender_id, host_connection_id, request.payload)
 446            .await?;
 447
 448        Ok(())
 449    }
 450
 451    async fn save_buffer(
 452        self: Arc<Server>,
 453        request: TypedEnvelope<proto::SaveBuffer>,
 454    ) -> tide::Result<()> {
 455        let host;
 456        let guests;
 457        {
 458            let state = self.state.read().await;
 459            let worktree = state.read_worktree(request.payload.worktree_id, request.sender_id)?;
 460            host = worktree.host_connection_id()?;
 461            guests = worktree
 462                .guest_connection_ids
 463                .keys()
 464                .copied()
 465                .collect::<Vec<_>>();
 466        }
 467
 468        let sender = request.sender_id;
 469        let receipt = request.receipt();
 470        let response = self
 471            .peer
 472            .forward_request(sender, host, request.payload.clone())
 473            .await?;
 474
 475        broadcast(host, guests, |conn_id| {
 476            let response = response.clone();
 477            let peer = &self.peer;
 478            async move {
 479                if conn_id == sender {
 480                    peer.respond(receipt, response).await
 481                } else {
 482                    peer.forward_send(host, conn_id, response).await
 483                }
 484            }
 485        })
 486        .await?;
 487
 488        Ok(())
 489    }
 490
 491    async fn update_buffer(
 492        self: Arc<Server>,
 493        request: TypedEnvelope<proto::UpdateBuffer>,
 494    ) -> tide::Result<()> {
 495        self.broadcast_in_worktree(request.payload.worktree_id, &request)
 496            .await
 497    }
 498
 499    async fn buffer_saved(
 500        self: Arc<Server>,
 501        request: TypedEnvelope<proto::BufferSaved>,
 502    ) -> tide::Result<()> {
 503        self.broadcast_in_worktree(request.payload.worktree_id, &request)
 504            .await
 505    }
 506
 507    async fn get_channels(
 508        self: Arc<Server>,
 509        request: TypedEnvelope<proto::GetChannels>,
 510    ) -> tide::Result<()> {
 511        let user_id = self
 512            .state
 513            .read()
 514            .await
 515            .user_id_for_connection(request.sender_id)?;
 516        let channels = self.app_state.db.get_accessible_channels(user_id).await?;
 517        self.peer
 518            .respond(
 519                request.receipt(),
 520                proto::GetChannelsResponse {
 521                    channels: channels
 522                        .into_iter()
 523                        .map(|chan| proto::Channel {
 524                            id: chan.id.to_proto(),
 525                            name: chan.name,
 526                        })
 527                        .collect(),
 528                },
 529            )
 530            .await?;
 531        Ok(())
 532    }
 533
 534    async fn get_users(
 535        self: Arc<Server>,
 536        request: TypedEnvelope<proto::GetUsers>,
 537    ) -> tide::Result<()> {
 538        let user_id = self
 539            .state
 540            .read()
 541            .await
 542            .user_id_for_connection(request.sender_id)?;
 543        let receipt = request.receipt();
 544        let user_ids = request.payload.user_ids.into_iter().map(UserId::from_proto);
 545        let users = self
 546            .app_state
 547            .db
 548            .get_users_by_ids(user_id, user_ids)
 549            .await?
 550            .into_iter()
 551            .map(|user| proto::User {
 552                id: user.id.to_proto(),
 553                github_login: user.github_login,
 554                avatar_url: String::new(),
 555            })
 556            .collect();
 557        self.peer
 558            .respond(receipt, proto::GetUsersResponse { users })
 559            .await?;
 560        Ok(())
 561    }
 562
 563    async fn join_channel(
 564        self: Arc<Self>,
 565        request: TypedEnvelope<proto::JoinChannel>,
 566    ) -> tide::Result<()> {
 567        let user_id = self
 568            .state
 569            .read()
 570            .await
 571            .user_id_for_connection(request.sender_id)?;
 572        let channel_id = ChannelId::from_proto(request.payload.channel_id);
 573        if !self
 574            .app_state
 575            .db
 576            .can_user_access_channel(user_id, channel_id)
 577            .await?
 578        {
 579            Err(anyhow!("access denied"))?;
 580        }
 581
 582        self.state
 583            .write()
 584            .await
 585            .join_channel(request.sender_id, channel_id);
 586        let messages = self
 587            .app_state
 588            .db
 589            .get_channel_messages(channel_id, MESSAGE_COUNT_PER_PAGE, None)
 590            .await?
 591            .into_iter()
 592            .map(|msg| proto::ChannelMessage {
 593                id: msg.id.to_proto(),
 594                body: msg.body,
 595                timestamp: msg.sent_at.unix_timestamp() as u64,
 596                sender_id: msg.sender_id.to_proto(),
 597            })
 598            .collect::<Vec<_>>();
 599        self.peer
 600            .respond(
 601                request.receipt(),
 602                proto::JoinChannelResponse {
 603                    done: messages.len() < MESSAGE_COUNT_PER_PAGE,
 604                    messages,
 605                },
 606            )
 607            .await?;
 608        Ok(())
 609    }
 610
 611    async fn leave_channel(
 612        self: Arc<Self>,
 613        request: TypedEnvelope<proto::LeaveChannel>,
 614    ) -> tide::Result<()> {
 615        let user_id = self
 616            .state
 617            .read()
 618            .await
 619            .user_id_for_connection(request.sender_id)?;
 620        let channel_id = ChannelId::from_proto(request.payload.channel_id);
 621        if !self
 622            .app_state
 623            .db
 624            .can_user_access_channel(user_id, channel_id)
 625            .await?
 626        {
 627            Err(anyhow!("access denied"))?;
 628        }
 629
 630        self.state
 631            .write()
 632            .await
 633            .leave_channel(request.sender_id, channel_id);
 634
 635        Ok(())
 636    }
 637
 638    async fn send_channel_message(
 639        self: Arc<Self>,
 640        request: TypedEnvelope<proto::SendChannelMessage>,
 641    ) -> tide::Result<()> {
 642        let receipt = request.receipt();
 643        let channel_id = ChannelId::from_proto(request.payload.channel_id);
 644        let user_id;
 645        let connection_ids;
 646        {
 647            let state = self.state.read().await;
 648            user_id = state.user_id_for_connection(request.sender_id)?;
 649            if let Some(channel) = state.channels.get(&channel_id) {
 650                connection_ids = channel.connection_ids();
 651            } else {
 652                return Ok(());
 653            }
 654        }
 655
 656        // Validate the message body.
 657        let body = request.payload.body.trim().to_string();
 658        if body.len() > MAX_MESSAGE_LEN {
 659            self.peer
 660                .respond_with_error(
 661                    receipt,
 662                    proto::Error {
 663                        message: "message is too long".to_string(),
 664                    },
 665                )
 666                .await?;
 667            return Ok(());
 668        }
 669        if body.is_empty() {
 670            self.peer
 671                .respond_with_error(
 672                    receipt,
 673                    proto::Error {
 674                        message: "message can't be blank".to_string(),
 675                    },
 676                )
 677                .await?;
 678            return Ok(());
 679        }
 680
 681        let timestamp = OffsetDateTime::now_utc();
 682        let message_id = self
 683            .app_state
 684            .db
 685            .create_channel_message(channel_id, user_id, &body, timestamp)
 686            .await?
 687            .to_proto();
 688        let message = proto::ChannelMessage {
 689            sender_id: user_id.to_proto(),
 690            id: message_id,
 691            body,
 692            timestamp: timestamp.unix_timestamp() as u64,
 693        };
 694        broadcast(request.sender_id, connection_ids, |conn_id| {
 695            self.peer.send(
 696                conn_id,
 697                proto::ChannelMessageSent {
 698                    channel_id: channel_id.to_proto(),
 699                    message: Some(message.clone()),
 700                },
 701            )
 702        })
 703        .await?;
 704        self.peer
 705            .respond(
 706                receipt,
 707                proto::SendChannelMessageResponse {
 708                    message: Some(message),
 709                },
 710            )
 711            .await?;
 712        Ok(())
 713    }
 714
 715    async fn get_channel_messages(
 716        self: Arc<Self>,
 717        request: TypedEnvelope<proto::GetChannelMessages>,
 718    ) -> tide::Result<()> {
 719        let user_id = self
 720            .state
 721            .read()
 722            .await
 723            .user_id_for_connection(request.sender_id)?;
 724        let channel_id = ChannelId::from_proto(request.payload.channel_id);
 725        if !self
 726            .app_state
 727            .db
 728            .can_user_access_channel(user_id, channel_id)
 729            .await?
 730        {
 731            Err(anyhow!("access denied"))?;
 732        }
 733
 734        let messages = self
 735            .app_state
 736            .db
 737            .get_channel_messages(
 738                channel_id,
 739                MESSAGE_COUNT_PER_PAGE,
 740                Some(MessageId::from_proto(request.payload.before_message_id)),
 741            )
 742            .await?
 743            .into_iter()
 744            .map(|msg| proto::ChannelMessage {
 745                id: msg.id.to_proto(),
 746                body: msg.body,
 747                timestamp: msg.sent_at.unix_timestamp() as u64,
 748                sender_id: msg.sender_id.to_proto(),
 749            })
 750            .collect::<Vec<_>>();
 751        self.peer
 752            .respond(
 753                request.receipt(),
 754                proto::GetChannelMessagesResponse {
 755                    done: messages.len() < MESSAGE_COUNT_PER_PAGE,
 756                    messages,
 757                },
 758            )
 759            .await?;
 760        Ok(())
 761    }
 762
 763    async fn broadcast_in_worktree<T: proto::EnvelopedMessage>(
 764        &self,
 765        worktree_id: u64,
 766        message: &TypedEnvelope<T>,
 767    ) -> tide::Result<()> {
 768        let connection_ids = self
 769            .state
 770            .read()
 771            .await
 772            .read_worktree(worktree_id, message.sender_id)?
 773            .connection_ids();
 774
 775        broadcast(message.sender_id, connection_ids, |conn_id| {
 776            self.peer
 777                .forward_send(message.sender_id, conn_id, message.payload.clone())
 778        })
 779        .await?;
 780
 781        Ok(())
 782    }
 783}
 784
 785pub async fn broadcast<F, T>(
 786    sender_id: ConnectionId,
 787    receiver_ids: Vec<ConnectionId>,
 788    mut f: F,
 789) -> anyhow::Result<()>
 790where
 791    F: FnMut(ConnectionId) -> T,
 792    T: Future<Output = anyhow::Result<()>>,
 793{
 794    let futures = receiver_ids
 795        .into_iter()
 796        .filter(|id| *id != sender_id)
 797        .map(|id| f(id));
 798    futures::future::try_join_all(futures).await?;
 799    Ok(())
 800}
 801
 802impl ServerState {
 803    fn join_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
 804        if let Some(connection) = self.connections.get_mut(&connection_id) {
 805            connection.channels.insert(channel_id);
 806            self.channels
 807                .entry(channel_id)
 808                .or_default()
 809                .connection_ids
 810                .insert(connection_id);
 811        }
 812    }
 813
 814    fn leave_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
 815        if let Some(connection) = self.connections.get_mut(&connection_id) {
 816            connection.channels.remove(&channel_id);
 817            if let hash_map::Entry::Occupied(mut entry) = self.channels.entry(channel_id) {
 818                entry.get_mut().connection_ids.remove(&connection_id);
 819                if entry.get_mut().connection_ids.is_empty() {
 820                    entry.remove();
 821                }
 822            }
 823        }
 824    }
 825
 826    fn user_id_for_connection(&self, connection_id: ConnectionId) -> tide::Result<UserId> {
 827        Ok(self
 828            .connections
 829            .get(&connection_id)
 830            .ok_or_else(|| anyhow!("unknown connection"))?
 831            .user_id)
 832    }
 833
 834    // Add the given connection as a guest of the given worktree
 835    fn join_worktree(
 836        &mut self,
 837        connection_id: ConnectionId,
 838        worktree_id: u64,
 839        access_token: &str,
 840    ) -> Option<(ReplicaId, &Worktree)> {
 841        if let Some(worktree) = self.worktrees.get_mut(&worktree_id) {
 842            if access_token == worktree.access_token {
 843                if let Some(connection) = self.connections.get_mut(&connection_id) {
 844                    connection.worktrees.insert(worktree_id);
 845                }
 846
 847                let mut replica_id = 1;
 848                while worktree.active_replica_ids.contains(&replica_id) {
 849                    replica_id += 1;
 850                }
 851                worktree.active_replica_ids.insert(replica_id);
 852                worktree
 853                    .guest_connection_ids
 854                    .insert(connection_id, replica_id);
 855                Some((replica_id, worktree))
 856            } else {
 857                None
 858            }
 859        } else {
 860            None
 861        }
 862    }
 863
 864    fn read_worktree(
 865        &self,
 866        worktree_id: u64,
 867        connection_id: ConnectionId,
 868    ) -> tide::Result<&Worktree> {
 869        let worktree = self
 870            .worktrees
 871            .get(&worktree_id)
 872            .ok_or_else(|| anyhow!("worktree not found"))?;
 873
 874        if worktree.host_connection_id == Some(connection_id)
 875            || worktree.guest_connection_ids.contains_key(&connection_id)
 876        {
 877            Ok(worktree)
 878        } else {
 879            Err(anyhow!(
 880                "{} is not a member of worktree {}",
 881                connection_id,
 882                worktree_id
 883            ))?
 884        }
 885    }
 886
 887    fn write_worktree(
 888        &mut self,
 889        worktree_id: u64,
 890        connection_id: ConnectionId,
 891    ) -> tide::Result<&mut Worktree> {
 892        let worktree = self
 893            .worktrees
 894            .get_mut(&worktree_id)
 895            .ok_or_else(|| anyhow!("worktree not found"))?;
 896
 897        if worktree.host_connection_id == Some(connection_id)
 898            || worktree.guest_connection_ids.contains_key(&connection_id)
 899        {
 900            Ok(worktree)
 901        } else {
 902            Err(anyhow!(
 903                "{} is not a member of worktree {}",
 904                connection_id,
 905                worktree_id
 906            ))?
 907        }
 908    }
 909}
 910
 911impl Worktree {
 912    pub fn connection_ids(&self) -> Vec<ConnectionId> {
 913        self.guest_connection_ids
 914            .keys()
 915            .copied()
 916            .chain(self.host_connection_id)
 917            .collect()
 918    }
 919
 920    fn host_connection_id(&self) -> tide::Result<ConnectionId> {
 921        Ok(self
 922            .host_connection_id
 923            .ok_or_else(|| anyhow!("host disconnected from worktree"))?)
 924    }
 925}
 926
 927impl Channel {
 928    fn connection_ids(&self) -> Vec<ConnectionId> {
 929        self.connection_ids.iter().copied().collect()
 930    }
 931}
 932
 933pub fn add_routes(app: &mut tide::Server<Arc<AppState>>, rpc: &Arc<Peer>) {
 934    let server = Server::new(app.state().clone(), rpc.clone(), None);
 935    app.at("/rpc").with(auth::VerifyToken).get(move |request: Request<Arc<AppState>>| {
 936        let user_id = request.ext::<UserId>().copied();
 937        let server = server.clone();
 938        async move {
 939            const WEBSOCKET_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
 940
 941            let connection_upgrade = header_contains_ignore_case(&request, CONNECTION, "upgrade");
 942            let upgrade_to_websocket = header_contains_ignore_case(&request, UPGRADE, "websocket");
 943            let upgrade_requested = connection_upgrade && upgrade_to_websocket;
 944
 945            if !upgrade_requested {
 946                return Ok(Response::new(StatusCode::UpgradeRequired));
 947            }
 948
 949            let header = match request.header("Sec-Websocket-Key") {
 950                Some(h) => h.as_str(),
 951                None => return Err(anyhow!("expected sec-websocket-key"))?,
 952            };
 953
 954            let mut response = Response::new(StatusCode::SwitchingProtocols);
 955            response.insert_header(UPGRADE, "websocket");
 956            response.insert_header(CONNECTION, "Upgrade");
 957            let hash = Sha1::new().chain(header).chain(WEBSOCKET_GUID).finalize();
 958            response.insert_header("Sec-Websocket-Accept", base64::encode(&hash[..]));
 959            response.insert_header("Sec-Websocket-Version", "13");
 960
 961            let http_res: &mut tide::http::Response = response.as_mut();
 962            let upgrade_receiver = http_res.recv_upgrade().await;
 963            let addr = request.remote().unwrap_or("unknown").to_string();
 964            let user_id = user_id.ok_or_else(|| anyhow!("user_id is not present on request. ensure auth::VerifyToken middleware is present"))?;
 965            task::spawn(async move {
 966                if let Some(stream) = upgrade_receiver.await {
 967                    server.handle_connection(Conn::new(WebSocketStream::from_raw_socket(stream, Role::Server, None).await), addr, user_id).await;
 968                }
 969            });
 970
 971            Ok(response)
 972        }
 973    });
 974}
 975
 976fn header_contains_ignore_case<T>(
 977    request: &tide::Request<T>,
 978    header_name: HeaderName,
 979    value: &str,
 980) -> bool {
 981    request
 982        .header(header_name)
 983        .map(|h| {
 984            h.as_str()
 985                .split(',')
 986                .any(|s| s.trim().eq_ignore_ascii_case(value.trim()))
 987        })
 988        .unwrap_or(false)
 989}
 990
 991#[cfg(test)]
 992mod tests {
 993    use super::*;
 994    use crate::{
 995        auth,
 996        db::{tests::TestDb, UserId},
 997        github, AppState, Config,
 998    };
 999    use async_std::{sync::RwLockReadGuard, task};
1000    use gpui::TestAppContext;
1001    use postage::mpsc;
1002    use serde_json::json;
1003    use sqlx::types::time::OffsetDateTime;
1004    use std::{path::Path, sync::Arc, time::Duration};
1005    use zed::{
1006        channel::{Channel, ChannelDetails, ChannelList},
1007        editor::{Editor, Insert},
1008        fs::{FakeFs, Fs as _},
1009        language::LanguageRegistry,
1010        rpc::Client,
1011        settings,
1012        user::UserStore,
1013        worktree::Worktree,
1014    };
1015    use zrpc::Peer;
1016
    // End-to-end host (A) / guest (B) flow on a shared worktree:
    // - A shares a worktree and B joins it.
    // - Buffer contents, selection sets and edits made by B become visible to A.
    // - Dropping the selection set, the buffer and the guest worktree is reflected on A.
    #[gpui::test]
    async fn test_share_worktree(mut cx_a: TestAppContext, mut cx_b: TestAppContext) {
        let (window_b, _) = cx_b.add_window(|_| EmptyView);
        let settings = cx_b.read(settings::test).1;
        let lang_registry = Arc::new(LanguageRegistry::new());

        // Connect to a server as 2 clients.
        let mut server = TestServer::start().await;
        let (_, client_a) = server.create_client(&mut cx_a, "user_a").await;
        let (_, client_b) = server.create_client(&mut cx_b, "user_b").await;

        cx_a.foreground().forbid_parking();

        // Share a local worktree as client A
        let fs = Arc::new(FakeFs::new());
        fs.insert_tree(
            "/a",
            json!({
                "a.txt": "a-contents",
                "b.txt": "b-contents",
            }),
        )
        .await;
        let worktree_a = Worktree::open_local(
            "/a".as_ref(),
            lang_registry.clone(),
            fs,
            &mut cx_a.to_async(),
        )
        .await
        .unwrap();
        worktree_a
            .read_with(&cx_a, |tree, _| tree.as_local().unwrap().scan_complete())
            .await;
        let (worktree_id, worktree_token) = worktree_a
            .update(&mut cx_a, |tree, cx| {
                tree.as_local_mut().unwrap().share(client_a.clone(), cx)
            })
            .await
            .unwrap();

        // Join that worktree as client B, and see that a guest has joined as client A.
        let worktree_b = Worktree::open_remote(
            client_b.clone(),
            worktree_id,
            worktree_token,
            lang_registry.clone(),
            &mut cx_b.to_async(),
        )
        .await
        .unwrap();
        let replica_id_b = worktree_b.read_with(&cx_b, |tree, _| tree.replica_id());
        worktree_a
            .condition(&cx_a, |tree, _| {
                tree.peers()
                    .values()
                    .any(|replica_id| *replica_id == replica_id_b)
            })
            .await;

        // Open the same file as client B and client A.
        let buffer_b = worktree_b
            .update(&mut cx_b, |worktree, cx| worktree.open_buffer("b.txt", cx))
            .await
            .unwrap();
        buffer_b.read_with(&cx_b, |buf, _| assert_eq!(buf.text(), "b-contents"));
        worktree_a.read_with(&cx_a, |tree, cx| assert!(tree.has_open_buffer("b.txt", cx)));
        let buffer_a = worktree_a
            .update(&mut cx_a, |tree, cx| tree.open_buffer("b.txt", cx))
            .await
            .unwrap();

        // Create a selection set as client B and see that selection set as client A.
        let editor_b = cx_b.add_view(window_b, |cx| Editor::for_buffer(buffer_b, settings, cx));
        buffer_a
            .condition(&cx_a, |buffer, _| buffer.selection_sets().count() == 1)
            .await;

        // Edit the buffer as client B and see that edit as client A.
        editor_b.update(&mut cx_b, |editor, cx| {
            editor.insert(&Insert("ok, ".into()), cx)
        });
        buffer_a
            .condition(&cx_a, |buffer, _| buffer.text() == "ok, b-contents")
            .await;

        // Remove the selection set as client B, see those selections disappear as client A.
        cx_b.update(move |_| drop(editor_b));
        buffer_a
            .condition(&cx_a, |buffer, _| buffer.selection_sets().count() == 0)
            .await;

        // Close the buffer as client A, see that the buffer is closed.
        drop(buffer_a);
        worktree_a
            .condition(&cx_a, |tree, cx| !tree.has_open_buffer("b.txt", cx))
            .await;

        // Dropping the worktree removes client B from client A's peers.
        cx_b.update(move |_| drop(worktree_b));
        worktree_a
            .condition(&cx_a, |tree, _| tree.peers().is_empty())
            .await;
    }
1121
    // One host (A) and two guests (B, C) share a worktree. The test checks that:
    // - concurrent edits to the same buffer converge on all three replicas;
    // - a save issued by guest B, racing a host edit, writes the host's file system and
    //   leaves every replica clean;
    // - host file-system changes (a rename and a new file) show up in the guests' trees.
    #[gpui::test]
    async fn test_propagate_saves_and_fs_changes_in_shared_worktree(
        mut cx_a: TestAppContext,
        mut cx_b: TestAppContext,
        mut cx_c: TestAppContext,
    ) {
        cx_a.foreground().forbid_parking();
        let lang_registry = Arc::new(LanguageRegistry::new());

        // Connect to a server as 3 clients.
        let mut server = TestServer::start().await;
        let (_, client_a) = server.create_client(&mut cx_a, "user_a").await;
        let (_, client_b) = server.create_client(&mut cx_b, "user_b").await;
        let (_, client_c) = server.create_client(&mut cx_c, "user_c").await;

        let fs = Arc::new(FakeFs::new());

        // Share a worktree as client A.
        fs.insert_tree(
            "/a",
            json!({
                "file1": "",
                "file2": ""
            }),
        )
        .await;

        let worktree_a = Worktree::open_local(
            "/a".as_ref(),
            lang_registry.clone(),
            fs.clone(),
            &mut cx_a.to_async(),
        )
        .await
        .unwrap();
        worktree_a
            .read_with(&cx_a, |tree, _| tree.as_local().unwrap().scan_complete())
            .await;
        let (worktree_id, worktree_token) = worktree_a
            .update(&mut cx_a, |tree, cx| {
                tree.as_local_mut().unwrap().share(client_a.clone(), cx)
            })
            .await
            .unwrap();

        // Join that worktree as clients B and C.
        let worktree_b = Worktree::open_remote(
            client_b.clone(),
            worktree_id,
            worktree_token.clone(),
            lang_registry.clone(),
            &mut cx_b.to_async(),
        )
        .await
        .unwrap();
        let worktree_c = Worktree::open_remote(
            client_c.clone(),
            worktree_id,
            worktree_token,
            lang_registry.clone(),
            &mut cx_c.to_async(),
        )
        .await
        .unwrap();

        // Open and edit a buffer as both guests B and C.
        let buffer_b = worktree_b
            .update(&mut cx_b, |tree, cx| tree.open_buffer("file1", cx))
            .await
            .unwrap();
        let buffer_c = worktree_c
            .update(&mut cx_c, |tree, cx| tree.open_buffer("file1", cx))
            .await
            .unwrap();
        buffer_b.update(&mut cx_b, |buf, cx| buf.edit([0..0], "i-am-b, ", cx));
        buffer_c.update(&mut cx_c, |buf, cx| buf.edit([0..0], "i-am-c, ", cx));

        // Open and edit that buffer as the host.
        let buffer_a = worktree_a
            .update(&mut cx_a, |tree, cx| tree.open_buffer("file1", cx))
            .await
            .unwrap();

        buffer_a
            .condition(&mut cx_a, |buf, _| buf.text() == "i-am-c, i-am-b, ")
            .await;
        buffer_a.update(&mut cx_a, |buf, cx| {
            buf.edit([buf.len()..buf.len()], "i-am-a", cx)
        });

        // Wait for edits to propagate
        buffer_a
            .condition(&mut cx_a, |buf, _| buf.text() == "i-am-c, i-am-b, i-am-a")
            .await;
        buffer_b
            .condition(&mut cx_b, |buf, _| buf.text() == "i-am-c, i-am-b, i-am-a")
            .await;
        buffer_c
            .condition(&mut cx_c, |buf, _| buf.text() == "i-am-c, i-am-b, i-am-a")
            .await;

        // Edit the buffer as the host and concurrently save as guest B.
        let save_b = buffer_b.update(&mut cx_b, |buf, cx| buf.save(cx).unwrap());
        buffer_a.update(&mut cx_a, |buf, cx| buf.edit([0..0], "hi-a, ", cx));
        save_b.await.unwrap();
        // The saved file includes the host's concurrent edit.
        assert_eq!(
            fs.load("/a/file1".as_ref()).await.unwrap(),
            "hi-a, i-am-c, i-am-b, i-am-a"
        );
        buffer_a.read_with(&cx_a, |buf, _| assert!(!buf.is_dirty()));
        buffer_b.read_with(&cx_b, |buf, _| assert!(!buf.is_dirty()));
        buffer_c.condition(&cx_c, |buf, _| !buf.is_dirty()).await;

        // Make changes on host's file system, see those changes on the guests.
        fs.rename("/a/file2".as_ref(), "/a/file3".as_ref())
            .await
            .unwrap();
        fs.insert_file(Path::new("/a/file4"), "4".into())
            .await
            .unwrap();

        worktree_b
            .condition(&cx_b, |tree, _| tree.file_count() == 3)
            .await;
        worktree_c
            .condition(&cx_c, |tree, _| tree.file_count() == 3)
            .await;
        worktree_b.read_with(&cx_b, |tree, _| {
            assert_eq!(
                tree.paths()
                    .map(|p| p.to_string_lossy())
                    .collect::<Vec<_>>(),
                &["file1", "file3", "file4"]
            )
        });
        worktree_c.read_with(&cx_c, |tree, _| {
            assert_eq!(
                tree.paths()
                    .map(|p| p.to_string_lossy())
                    .collect::<Vec<_>>(),
                &["file1", "file3", "file4"]
            )
        });
    }
1266
    // A guest's own save must not mark its buffer as conflicted. After B saves and
    // observes the file's new mtime, further edits make the buffer dirty but not
    // conflicted.
    #[gpui::test]
    async fn test_buffer_conflict_after_save(mut cx_a: TestAppContext, mut cx_b: TestAppContext) {
        cx_a.foreground().forbid_parking();
        let lang_registry = Arc::new(LanguageRegistry::new());

        // Connect to a server as 2 clients.
        let mut server = TestServer::start().await;
        let (_, client_a) = server.create_client(&mut cx_a, "user_a").await;
        let (_, client_b) = server.create_client(&mut cx_b, "user_b").await;

        // Share a local worktree as client A
        let fs = Arc::new(FakeFs::new());
        fs.save(Path::new("/a.txt"), &"a-contents".into())
            .await
            .unwrap();
        let worktree_a = Worktree::open_local(
            "/".as_ref(),
            lang_registry.clone(),
            fs,
            &mut cx_a.to_async(),
        )
        .await
        .unwrap();
        worktree_a
            .read_with(&cx_a, |tree, _| tree.as_local().unwrap().scan_complete())
            .await;
        let (worktree_id, worktree_token) = worktree_a
            .update(&mut cx_a, |tree, cx| {
                tree.as_local_mut().unwrap().share(client_a.clone(), cx)
            })
            .await
            .unwrap();

        // Join that worktree as client B.
        let worktree_b = Worktree::open_remote(
            client_b.clone(),
            worktree_id,
            worktree_token,
            lang_registry.clone(),
            &mut cx_b.to_async(),
        )
        .await
        .unwrap();

        let buffer_b = worktree_b
            .update(&mut cx_b, |worktree, cx| worktree.open_buffer("a.txt", cx))
            .await
            .unwrap();
        let mtime = buffer_b.read_with(&cx_b, |buf, _| buf.file().unwrap().mtime);

        buffer_b.update(&mut cx_b, |buf, cx| buf.edit([0..0], "world ", cx));
        buffer_b.read_with(&cx_b, |buf, _| {
            assert!(buf.is_dirty());
            assert!(!buf.has_conflict());
        });

        buffer_b
            .update(&mut cx_b, |buf, cx| buf.save(cx))
            .unwrap()
            .await
            .unwrap();
        // Wait until B's file metadata reflects the save before checking conflict state.
        worktree_b
            .condition(&cx_b, |_, cx| {
                buffer_b.read(cx).file().unwrap().mtime != mtime
            })
            .await;
        buffer_b.read_with(&cx_b, |buf, _| {
            assert!(!buf.is_dirty());
            assert!(!buf.has_conflict());
        });

        buffer_b.update(&mut cx_b, |buf, cx| buf.edit([0..0], "hello ", cx));
        buffer_b.read_with(&cx_b, |buf, _| {
            assert!(buf.is_dirty());
            assert!(!buf.has_conflict());
        });
    }
1344
    // The host edits a buffer after the guest has started opening it (the open is spawned
    // and the test yields once before the edit). The guest must still converge to the
    // host's text.
    #[gpui::test]
    async fn test_editing_while_guest_opens_buffer(
        mut cx_a: TestAppContext,
        mut cx_b: TestAppContext,
    ) {
        cx_a.foreground().forbid_parking();
        let lang_registry = Arc::new(LanguageRegistry::new());

        // Connect to a server as 2 clients.
        let mut server = TestServer::start().await;
        let (_, client_a) = server.create_client(&mut cx_a, "user_a").await;
        let (_, client_b) = server.create_client(&mut cx_b, "user_b").await;

        // Share a local worktree as client A
        let fs = Arc::new(FakeFs::new());
        fs.save(Path::new("/a.txt"), &"a-contents".into())
            .await
            .unwrap();
        let worktree_a = Worktree::open_local(
            "/".as_ref(),
            lang_registry.clone(),
            fs,
            &mut cx_a.to_async(),
        )
        .await
        .unwrap();
        worktree_a
            .read_with(&cx_a, |tree, _| tree.as_local().unwrap().scan_complete())
            .await;
        let (worktree_id, worktree_token) = worktree_a
            .update(&mut cx_a, |tree, cx| {
                tree.as_local_mut().unwrap().share(client_a.clone(), cx)
            })
            .await
            .unwrap();

        // Join that worktree as client B.
        let worktree_b = Worktree::open_remote(
            client_b.clone(),
            worktree_id,
            worktree_token,
            lang_registry.clone(),
            &mut cx_b.to_async(),
        )
        .await
        .unwrap();

        let buffer_a = worktree_a
            .update(&mut cx_a, |tree, cx| tree.open_buffer("a.txt", cx))
            .await
            .unwrap();
        // Start B's open without awaiting it, so A's edit below races the open.
        let buffer_b = cx_b
            .background()
            .spawn(worktree_b.update(&mut cx_b, |worktree, cx| worktree.open_buffer("a.txt", cx)));

        task::yield_now().await;
        buffer_a.update(&mut cx_a, |buf, cx| buf.edit([0..0], "z", cx));

        let text = buffer_a.read_with(&cx_a, |buf, _| buf.text());
        let buffer_b = buffer_b.await.unwrap();
        buffer_b.condition(&cx_b, |buf, _| buf.text() == text).await;
    }
1407
1408    #[gpui::test]
1409    async fn test_peer_disconnection(mut cx_a: TestAppContext, cx_b: TestAppContext) {
1410        cx_a.foreground().forbid_parking();
1411        let lang_registry = Arc::new(LanguageRegistry::new());
1412
1413        // Connect to a server as 2 clients.
1414        let mut server = TestServer::start().await;
1415        let (_, client_a) = server.create_client(&mut cx_a, "user_a").await;
1416        let (_, client_b) = server.create_client(&mut cx_a, "user_b").await;
1417
1418        // Share a local worktree as client A
1419        let fs = Arc::new(FakeFs::new());
1420        fs.insert_tree(
1421            "/a",
1422            json!({
1423                "a.txt": "a-contents",
1424                "b.txt": "b-contents",
1425            }),
1426        )
1427        .await;
1428        let worktree_a = Worktree::open_local(
1429            "/a".as_ref(),
1430            lang_registry.clone(),
1431            fs,
1432            &mut cx_a.to_async(),
1433        )
1434        .await
1435        .unwrap();
1436        worktree_a
1437            .read_with(&cx_a, |tree, _| tree.as_local().unwrap().scan_complete())
1438            .await;
1439        let (worktree_id, worktree_token) = worktree_a
1440            .update(&mut cx_a, |tree, cx| {
1441                tree.as_local_mut().unwrap().share(client_a.clone(), cx)
1442            })
1443            .await
1444            .unwrap();
1445
1446        // Join that worktree as client B, and see that a guest has joined as client A.
1447        let _worktree_b = Worktree::open_remote(
1448            client_b.clone(),
1449            worktree_id,
1450            worktree_token,
1451            lang_registry.clone(),
1452            &mut cx_b.to_async(),
1453        )
1454        .await
1455        .unwrap();
1456        worktree_a
1457            .condition(&cx_a, |tree, _| tree.peers().len() == 1)
1458            .await;
1459
1460        // Drop client B's connection and ensure client A observes client B leaving the worktree.
1461        client_b.disconnect(&cx_b.to_async()).await.unwrap();
1462        worktree_a
1463            .condition(&cx_a, |tree, _| tree.peers().len() == 0)
1464            .await;
1465    }
1466
    // Two members of an org channel load its history and exchange messages. The test also
    // checks the sender's pending-message bookkeeping, and that the server's per-channel
    // connection tracking shrinks, then disappears, as each client drops its channel handle.
    #[gpui::test]
    async fn test_basic_chat(mut cx_a: TestAppContext, mut cx_b: TestAppContext) {
        cx_a.foreground().forbid_parking();

        // Connect to a server as 2 clients.
        let mut server = TestServer::start().await;
        let (user_id_a, client_a) = server.create_client(&mut cx_a, "user_a").await;
        let (user_id_b, client_b) = server.create_client(&mut cx_b, "user_b").await;

        // Create an org that includes these 2 users.
        let db = &server.app_state.db;
        let org_id = db.create_org("Test Org", "test-org").await.unwrap();
        db.add_org_member(org_id, user_id_a, false).await.unwrap();
        db.add_org_member(org_id, user_id_b, false).await.unwrap();

        // Create a channel that includes all the users.
        let channel_id = db.create_org_channel(org_id, "test-channel").await.unwrap();
        db.add_channel_member(channel_id, user_id_a, false)
            .await
            .unwrap();
        db.add_channel_member(channel_id, user_id_b, false)
            .await
            .unwrap();
        db.create_channel_message(
            channel_id,
            user_id_b,
            "hello A, it's B.",
            OffsetDateTime::now_utc(),
        )
        .await
        .unwrap();

        let user_store_a = Arc::new(UserStore::new(client_a.clone()));
        let channels_a = cx_a.add_model(|cx| ChannelList::new(user_store_a, client_a, cx));
        channels_a
            .condition(&mut cx_a, |list, _| list.available_channels().is_some())
            .await;
        channels_a.read_with(&cx_a, |list, _| {
            assert_eq!(
                list.available_channels().unwrap(),
                &[ChannelDetails {
                    id: channel_id.to_proto(),
                    name: "test-channel".to_string()
                }]
            )
        });
        let channel_a = channels_a.update(&mut cx_a, |this, cx| {
            this.get_channel(channel_id.to_proto(), cx).unwrap()
        });
        // History is loaded asynchronously: empty at first, then the stored message arrives.
        channel_a.read_with(&cx_a, |channel, _| assert!(channel.messages().is_empty()));
        channel_a
            .condition(&cx_a, |channel, _| {
                channel_messages(channel)
                    == [("user_b".to_string(), "hello A, it's B.".to_string())]
            })
            .await;

        let user_store_b = Arc::new(UserStore::new(client_b.clone()));
        let channels_b = cx_b.add_model(|cx| ChannelList::new(user_store_b, client_b, cx));
        channels_b
            .condition(&mut cx_b, |list, _| list.available_channels().is_some())
            .await;
        channels_b.read_with(&cx_b, |list, _| {
            assert_eq!(
                list.available_channels().unwrap(),
                &[ChannelDetails {
                    id: channel_id.to_proto(),
                    name: "test-channel".to_string()
                }]
            )
        });

        let channel_b = channels_b.update(&mut cx_b, |this, cx| {
            this.get_channel(channel_id.to_proto(), cx).unwrap()
        });
        channel_b.read_with(&cx_b, |channel, _| assert!(channel.messages().is_empty()));
        channel_b
            .condition(&cx_b, |channel, _| {
                channel_messages(channel)
                    == [("user_b".to_string(), "hello A, it's B.".to_string())]
            })
            .await;

        // Send two messages from A; both are pending until the server acknowledges them.
        channel_a
            .update(&mut cx_a, |channel, cx| {
                channel
                    .send_message("oh, hi B.".to_string(), cx)
                    .unwrap()
                    .detach();
                let task = channel.send_message("sup".to_string(), cx).unwrap();
                assert_eq!(
                    channel
                        .pending_messages()
                        .iter()
                        .map(|m| &m.body)
                        .collect::<Vec<_>>(),
                    &["oh, hi B.", "sup"]
                );
                task
            })
            .await
            .unwrap();

        channel_a
            .condition(&cx_a, |channel, _| channel.pending_messages().is_empty())
            .await;
        channel_b
            .condition(&cx_b, |channel, _| {
                channel_messages(channel)
                    == [
                        ("user_b".to_string(), "hello A, it's B.".to_string()),
                        ("user_a".to_string(), "oh, hi B.".to_string()),
                        ("user_a".to_string(), "sup".to_string()),
                    ]
            })
            .await;

        // Dropping each client's channel handle removes its connection from the server's
        // channel state; the entry disappears once no connections remain.
        assert_eq!(
            server.state().await.channels[&channel_id]
                .connection_ids
                .len(),
            2
        );
        cx_b.update(|_| drop(channel_b));
        server
            .condition(|state| state.channels[&channel_id].connection_ids.len() == 1)
            .await;

        cx_a.update(|_| drop(channel_a));
        server
            .condition(|state| !state.channels.contains_key(&channel_id))
            .await;

        // Flattens a channel's loaded messages into (sender login, body) pairs.
        fn channel_messages(channel: &Channel) -> Vec<(String, String)> {
            channel
                .messages()
                .cursor::<(), ()>()
                .map(|m| (m.sender.github_login.clone(), m.body.clone()))
                .collect()
        }
    }
1608
    // Chat message validation:
    // - an overly long message is rejected (its send task fails);
    // - a blank message is rejected immediately by `send_message`;
    // - surrounding whitespace is trimmed before the message is stored in the database.
    #[gpui::test]
    async fn test_chat_message_validation(mut cx_a: TestAppContext) {
        cx_a.foreground().forbid_parking();

        let mut server = TestServer::start().await;
        let (user_id_a, client_a) = server.create_client(&mut cx_a, "user_a").await;

        let db = &server.app_state.db;
        let org_id = db.create_org("Test Org", "test-org").await.unwrap();
        let channel_id = db.create_org_channel(org_id, "test-channel").await.unwrap();
        db.add_org_member(org_id, user_id_a, false).await.unwrap();
        db.add_channel_member(channel_id, user_id_a, false)
            .await
            .unwrap();

        let user_store_a = Arc::new(UserStore::new(client_a.clone()));
        let channels_a = cx_a.add_model(|cx| ChannelList::new(user_store_a, client_a, cx));
        channels_a
            .condition(&mut cx_a, |list, _| list.available_channels().is_some())
            .await;
        let channel_a = channels_a.update(&mut cx_a, |this, cx| {
            this.get_channel(channel_id.to_proto(), cx).unwrap()
        });

        // Messages aren't allowed to be too long.
        channel_a
            .update(&mut cx_a, |channel, cx| {
                let long_body = "this is long.\n".repeat(1024);
                channel.send_message(long_body, cx).unwrap()
            })
            .await
            .unwrap_err();

        // Messages aren't allowed to be blank.
        channel_a.update(&mut cx_a, |channel, cx| {
            channel.send_message(String::new(), cx).unwrap_err()
        });

        // Leading and trailing whitespace are trimmed.
        channel_a
            .update(&mut cx_a, |channel, cx| {
                channel
                    .send_message("\n surrounded by whitespace  \n".to_string(), cx)
                    .unwrap()
            })
            .await
            .unwrap();
        assert_eq!(
            db.get_channel_messages(channel_id, 10, None)
                .await
                .unwrap()
                .iter()
                .map(|m| &m.body)
                .collect::<Vec<_>>(),
            &["surrounded by whitespace"]
        );
    }
1666
    // In-process server harness for these tests. It wraps a real `Server` backed by a
    // throwaway test database; clients connect to it over in-memory connections.
    struct TestServer {
        peer: Arc<Peer>,
        app_state: Arc<AppState>,
        server: Arc<Server>,
        // Receiving end of the channel handed to `Server::new`. `condition` waits on it
        // between re-checks of server state. Presumably the server sends on it after
        // handling messages; confirm in `Server`.
        notifications: mpsc::Receiver<()>,
        // Held only to keep the test database alive. Fields drop in declaration order,
        // so this is dropped after the server and app state.
        _test_db: TestDb,
    }
1674
1675    impl TestServer {
1676        async fn start() -> Self {
1677            let test_db = TestDb::new();
1678            let app_state = Self::build_app_state(&test_db).await;
1679            let peer = Peer::new();
1680            let notifications = mpsc::channel(128);
1681            let server = Server::new(app_state.clone(), peer.clone(), Some(notifications.0));
1682            Self {
1683                peer,
1684                app_state,
1685                server,
1686                notifications: notifications.1,
1687                _test_db: test_db,
1688            }
1689        }
1690
1691        async fn create_client(
1692            &mut self,
1693            cx: &mut TestAppContext,
1694            name: &str,
1695        ) -> (UserId, Arc<Client>) {
1696            let user_id = self.app_state.db.create_user(name, false).await.unwrap();
1697            let client = Client::new();
1698            let (client_conn, server_conn) = Conn::in_memory();
1699            cx.background()
1700                .spawn(
1701                    self.server
1702                        .handle_connection(server_conn, name.to_string(), user_id),
1703                )
1704                .detach();
1705            client
1706                .set_connection(user_id.to_proto(), client_conn, &cx.to_async())
1707                .await
1708                .unwrap();
1709            (user_id, client)
1710        }
1711
1712        async fn build_app_state(test_db: &TestDb) -> Arc<AppState> {
1713            let mut config = Config::default();
1714            config.session_secret = "a".repeat(32);
1715            config.database_url = test_db.url.clone();
1716            let github_client = github::AppClient::test();
1717            Arc::new(AppState {
1718                db: test_db.db().clone(),
1719                handlebars: Default::default(),
1720                auth_client: auth::build_client("", ""),
1721                repo_client: github::RepoClient::test(&github_client),
1722                github_client,
1723                config,
1724            })
1725        }
1726
1727        async fn state<'a>(&'a self) -> RwLockReadGuard<'a, ServerState> {
1728            self.server.state.read().await
1729        }
1730
1731        async fn condition<F>(&mut self, mut predicate: F)
1732        where
1733            F: FnMut(&ServerState) -> bool,
1734        {
1735            async_std::future::timeout(Duration::from_millis(500), async {
1736                while !(predicate)(&*self.server.state.read().await) {
1737                    self.notifications.recv().await;
1738                }
1739            })
1740            .await
1741            .expect("condition timed out");
1742        }
1743    }
1744
1745    impl Drop for TestServer {
1746        fn drop(&mut self) {
1747            task::block_on(self.peer.reset());
1748        }
1749    }
1750
    /// Test placeholder view that renders nothing.
    struct EmptyView;
1752
    /// `EmptyView` defines no events of its own, so its event type is `()`.
    impl gpui::Entity for EmptyView {
        type Event = ();
    }
1756
1757    impl gpui::View for EmptyView {
1758        fn ui_name() -> &'static str {
1759            "empty view"
1760        }
1761
1762        fn render(&mut self, _: &mut gpui::RenderContext<Self>) -> gpui::ElementBox {
1763            gpui::Element::boxed(gpui::elements::Empty)
1764        }
1765    }
1766}