rpc.rs

   1mod store;
   2
   3use super::{
   4    auth::process_auth_header,
   5    db::{ChannelId, MessageId, UserId},
   6    AppState,
   7};
   8use anyhow::anyhow;
   9use async_std::task;
  10use async_tungstenite::{tungstenite::protocol::Role, WebSocketStream};
  11use futures::{future::BoxFuture, FutureExt};
  12use parking_lot::{RwLock, RwLockReadGuard, RwLockWriteGuard};
  13use postage::{mpsc, prelude::Sink as _, prelude::Stream as _};
  14use rpc::{
  15    proto::{self, AnyTypedEnvelope, EnvelopedMessage},
  16    Connection, ConnectionId, Peer, TypedEnvelope,
  17};
  18use sha1::{Digest as _, Sha1};
  19use std::{
  20    any::TypeId,
  21    collections::{HashMap, HashSet},
  22    future::Future,
  23    mem,
  24    sync::Arc,
  25    time::Instant,
  26};
  27use store::{Store, Worktree};
  28use surf::StatusCode;
  29use tide::log;
  30use tide::{
  31    http::headers::{HeaderName, CONNECTION, UPGRADE},
  32    Request, Response,
  33};
  34use time::OffsetDateTime;
  35
  36type MessageHandler = Box<
  37    dyn Send
  38        + Sync
  39        + Fn(Arc<Server>, Box<dyn AnyTypedEnvelope>) -> BoxFuture<'static, tide::Result<()>>,
  40>;
  41
/// RPC server state shared by every client connection.
///
/// Incoming envelopes are dispatched by the `TypeId` of their payload to the
/// handlers registered in `Server::new`.
pub struct Server {
    /// Transport used to receive messages from connections and to send,
    /// forward, and respond to them.
    peer: Arc<Peer>,
    /// In-memory record of connections, worktrees, and channel membership,
    /// behind a `parking_lot` read/write lock.
    store: RwLock<Store>,
    /// Shared application state; handlers reach the database via `app_state.db`.
    app_state: Arc<AppState>,
    /// Registered message handlers, keyed by the `TypeId` of the payload type.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Optional channel signalled with `()` after each message that had a
    /// handler (NOTE(review): presumably used by tests to observe progress —
    /// confirm).
    notifications: Option<mpsc::Sender<()>>,
}
  49
  50const MESSAGE_COUNT_PER_PAGE: usize = 100;
  51const MAX_MESSAGE_LEN: usize = 1024;
  52
  53impl Server {
  54    pub fn new(
  55        app_state: Arc<AppState>,
  56        peer: Arc<Peer>,
  57        notifications: Option<mpsc::Sender<()>>,
  58    ) -> Arc<Self> {
  59        let mut server = Self {
  60            peer,
  61            app_state,
  62            store: Default::default(),
  63            handlers: Default::default(),
  64            notifications,
  65        };
  66
  67        server
  68            .add_handler(Server::ping)
  69            .add_handler(Server::open_worktree)
  70            .add_handler(Server::close_worktree)
  71            .add_handler(Server::share_worktree)
  72            .add_handler(Server::unshare_worktree)
  73            .add_handler(Server::join_worktree)
  74            .add_handler(Server::leave_worktree)
  75            .add_handler(Server::update_worktree)
  76            .add_handler(Server::open_buffer)
  77            .add_handler(Server::close_buffer)
  78            .add_handler(Server::update_buffer)
  79            .add_handler(Server::buffer_saved)
  80            .add_handler(Server::save_buffer)
  81            .add_handler(Server::get_channels)
  82            .add_handler(Server::get_users)
  83            .add_handler(Server::join_channel)
  84            .add_handler(Server::leave_channel)
  85            .add_handler(Server::send_channel_message)
  86            .add_handler(Server::get_channel_messages);
  87
  88        Arc::new(server)
  89    }
  90
  91    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
  92    where
  93        F: 'static + Send + Sync + Fn(Arc<Self>, TypedEnvelope<M>) -> Fut,
  94        Fut: 'static + Send + Future<Output = tide::Result<()>>,
  95        M: EnvelopedMessage,
  96    {
  97        let prev_handler = self.handlers.insert(
  98            TypeId::of::<M>(),
  99            Box::new(move |server, envelope| {
 100                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 101                (handler)(server, *envelope).boxed()
 102            }),
 103        );
 104        if prev_handler.is_some() {
 105            panic!("registered a handler for the same message twice");
 106        }
 107        self
 108    }
 109
 110    pub fn handle_connection(
 111        self: &Arc<Self>,
 112        connection: Connection,
 113        addr: String,
 114        user_id: UserId,
 115    ) -> impl Future<Output = ()> {
 116        let mut this = self.clone();
 117        async move {
 118            let (connection_id, handle_io, mut incoming_rx) =
 119                this.peer.add_connection(connection).await;
 120            this.state_mut().add_connection(connection_id, user_id);
 121            if let Err(err) = this.update_collaborators_for_users(&[user_id]).await {
 122                log::error!("error updating collaborators for {:?}: {}", user_id, err);
 123            }
 124
 125            let handle_io = handle_io.fuse();
 126            futures::pin_mut!(handle_io);
 127            loop {
 128                let next_message = incoming_rx.recv().fuse();
 129                futures::pin_mut!(next_message);
 130                futures::select_biased! {
 131                    message = next_message => {
 132                        if let Some(message) = message {
 133                            let start_time = Instant::now();
 134                            log::info!("RPC message received: {}", message.payload_type_name());
 135                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
 136                                if let Err(err) = (handler)(this.clone(), message).await {
 137                                    log::error!("error handling message: {:?}", err);
 138                                } else {
 139                                    log::info!("RPC message handled. duration:{:?}", start_time.elapsed());
 140                                }
 141
 142                                if let Some(mut notifications) = this.notifications.clone() {
 143                                    let _ = notifications.send(()).await;
 144                                }
 145                            } else {
 146                                log::warn!("unhandled message: {}", message.payload_type_name());
 147                            }
 148                        } else {
 149                            log::info!("rpc connection closed {:?}", addr);
 150                            break;
 151                        }
 152                    }
 153                    handle_io = handle_io => {
 154                        if let Err(err) = handle_io {
 155                            log::error!("error handling rpc connection {:?} - {:?}", addr, err);
 156                        }
 157                        break;
 158                    }
 159                }
 160            }
 161
 162            if let Err(err) = this.sign_out(connection_id).await {
 163                log::error!("error signing out connection {:?} - {:?}", addr, err);
 164            }
 165        }
 166    }
 167
 168    async fn sign_out(self: &mut Arc<Self>, connection_id: ConnectionId) -> tide::Result<()> {
 169        self.peer.disconnect(connection_id).await;
 170        let removed_connection = self.state_mut().remove_connection(connection_id)?;
 171
 172        for (worktree_id, worktree) in removed_connection.hosted_worktrees {
 173            if let Some(share) = worktree.share {
 174                broadcast(
 175                    connection_id,
 176                    share.guests.keys().copied().collect(),
 177                    |conn_id| {
 178                        self.peer
 179                            .send(conn_id, proto::UnshareWorktree { worktree_id })
 180                    },
 181                )
 182                .await?;
 183            }
 184        }
 185
 186        for (worktree_id, peer_ids) in removed_connection.guest_worktree_ids {
 187            broadcast(connection_id, peer_ids, |conn_id| {
 188                self.peer.send(
 189                    conn_id,
 190                    proto::RemovePeer {
 191                        worktree_id,
 192                        peer_id: connection_id.0,
 193                    },
 194                )
 195            })
 196            .await?;
 197        }
 198
 199        self.update_collaborators_for_users(removed_connection.collaborator_ids.iter())
 200            .await?;
 201
 202        Ok(())
 203    }
 204
 205    async fn ping(self: Arc<Server>, request: TypedEnvelope<proto::Ping>) -> tide::Result<()> {
 206        self.peer.respond(request.receipt(), proto::Ack {}).await?;
 207        Ok(())
 208    }
 209
 210    async fn open_worktree(
 211        mut self: Arc<Server>,
 212        request: TypedEnvelope<proto::OpenWorktree>,
 213    ) -> tide::Result<()> {
 214        let receipt = request.receipt();
 215        let host_user_id = self.state().user_id_for_connection(request.sender_id)?;
 216
 217        let mut collaborator_user_ids = HashSet::new();
 218        collaborator_user_ids.insert(host_user_id);
 219        for github_login in request.payload.collaborator_logins {
 220            match self.app_state.db.create_user(&github_login, false).await {
 221                Ok(collaborator_user_id) => {
 222                    collaborator_user_ids.insert(collaborator_user_id);
 223                }
 224                Err(err) => {
 225                    let message = err.to_string();
 226                    self.peer
 227                        .respond_with_error(receipt, proto::Error { message })
 228                        .await?;
 229                    return Ok(());
 230                }
 231            }
 232        }
 233
 234        let collaborator_user_ids = collaborator_user_ids.into_iter().collect::<Vec<_>>();
 235        let worktree_id = self.state_mut().add_worktree(Worktree {
 236            host_connection_id: request.sender_id,
 237            host_user_id,
 238            collaborator_user_ids: collaborator_user_ids.clone(),
 239            root_name: request.payload.root_name,
 240            share: None,
 241        });
 242
 243        self.peer
 244            .respond(receipt, proto::OpenWorktreeResponse { worktree_id })
 245            .await?;
 246        self.update_collaborators_for_users(&collaborator_user_ids)
 247            .await?;
 248
 249        Ok(())
 250    }
 251
 252    async fn close_worktree(
 253        mut self: Arc<Server>,
 254        request: TypedEnvelope<proto::CloseWorktree>,
 255    ) -> tide::Result<()> {
 256        let worktree_id = request.payload.worktree_id;
 257        let worktree = self
 258            .state_mut()
 259            .remove_worktree(worktree_id, request.sender_id)?;
 260
 261        if let Some(share) = worktree.share {
 262            broadcast(
 263                request.sender_id,
 264                share.guests.keys().copied().collect(),
 265                |conn_id| {
 266                    self.peer
 267                        .send(conn_id, proto::UnshareWorktree { worktree_id })
 268                },
 269            )
 270            .await?;
 271        }
 272        self.update_collaborators_for_users(&worktree.collaborator_user_ids)
 273            .await?;
 274        Ok(())
 275    }
 276
 277    async fn share_worktree(
 278        mut self: Arc<Server>,
 279        mut request: TypedEnvelope<proto::ShareWorktree>,
 280    ) -> tide::Result<()> {
 281        let worktree = request
 282            .payload
 283            .worktree
 284            .as_mut()
 285            .ok_or_else(|| anyhow!("missing worktree"))?;
 286        let entries = mem::take(&mut worktree.entries)
 287            .into_iter()
 288            .map(|entry| (entry.id, entry))
 289            .collect();
 290
 291        let collaborator_user_ids =
 292            self.state_mut()
 293                .share_worktree(worktree.id, request.sender_id, entries);
 294        if let Some(collaborator_user_ids) = collaborator_user_ids {
 295            self.peer
 296                .respond(request.receipt(), proto::ShareWorktreeResponse {})
 297                .await?;
 298            self.update_collaborators_for_users(&collaborator_user_ids)
 299                .await?;
 300        } else {
 301            self.peer
 302                .respond_with_error(
 303                    request.receipt(),
 304                    proto::Error {
 305                        message: "no such worktree".to_string(),
 306                    },
 307                )
 308                .await?;
 309        }
 310        Ok(())
 311    }
 312
 313    async fn unshare_worktree(
 314        mut self: Arc<Server>,
 315        request: TypedEnvelope<proto::UnshareWorktree>,
 316    ) -> tide::Result<()> {
 317        let worktree_id = request.payload.worktree_id;
 318        let worktree = self
 319            .state_mut()
 320            .unshare_worktree(worktree_id, request.sender_id)?;
 321
 322        broadcast(request.sender_id, worktree.connection_ids, |conn_id| {
 323            self.peer
 324                .send(conn_id, proto::UnshareWorktree { worktree_id })
 325        })
 326        .await?;
 327        self.update_collaborators_for_users(&worktree.collaborator_ids)
 328            .await?;
 329
 330        Ok(())
 331    }
 332
 333    async fn join_worktree(
 334        mut self: Arc<Server>,
 335        request: TypedEnvelope<proto::JoinWorktree>,
 336    ) -> tide::Result<()> {
 337        let worktree_id = request.payload.worktree_id;
 338
 339        let user_id = self.state().user_id_for_connection(request.sender_id)?;
 340        let response_data = self
 341            .state_mut()
 342            .join_worktree(request.sender_id, user_id, worktree_id)
 343            .and_then(|joined| {
 344                let share = joined.worktree.share()?;
 345                let peer_count = share.guests.len();
 346                let mut peers = Vec::with_capacity(peer_count);
 347                peers.push(proto::Peer {
 348                    peer_id: joined.worktree.host_connection_id.0,
 349                    replica_id: 0,
 350                    user_id: joined.worktree.host_user_id.to_proto(),
 351                });
 352                for (peer_conn_id, (peer_replica_id, peer_user_id)) in &share.guests {
 353                    if *peer_conn_id != request.sender_id {
 354                        peers.push(proto::Peer {
 355                            peer_id: peer_conn_id.0,
 356                            replica_id: *peer_replica_id as u32,
 357                            user_id: peer_user_id.to_proto(),
 358                        });
 359                    }
 360                }
 361                let response = proto::JoinWorktreeResponse {
 362                    worktree: Some(proto::Worktree {
 363                        id: worktree_id,
 364                        root_name: joined.worktree.root_name.clone(),
 365                        entries: share.entries.values().cloned().collect(),
 366                    }),
 367                    replica_id: joined.replica_id as u32,
 368                    peers,
 369                };
 370                let connection_ids = joined.worktree.connection_ids();
 371                let collaborator_user_ids = joined.worktree.collaborator_user_ids.clone();
 372                Ok((response, connection_ids, collaborator_user_ids))
 373            });
 374
 375        match response_data {
 376            Ok((response, connection_ids, collaborator_user_ids)) => {
 377                broadcast(request.sender_id, connection_ids, |conn_id| {
 378                    self.peer.send(
 379                        conn_id,
 380                        proto::AddPeer {
 381                            worktree_id,
 382                            peer: Some(proto::Peer {
 383                                peer_id: request.sender_id.0,
 384                                replica_id: response.replica_id,
 385                                user_id: user_id.to_proto(),
 386                            }),
 387                        },
 388                    )
 389                })
 390                .await?;
 391                self.peer.respond(request.receipt(), response).await?;
 392                self.update_collaborators_for_users(&collaborator_user_ids)
 393                    .await?;
 394            }
 395            Err(error) => {
 396                self.peer
 397                    .respond_with_error(
 398                        request.receipt(),
 399                        proto::Error {
 400                            message: error.to_string(),
 401                        },
 402                    )
 403                    .await?;
 404            }
 405        }
 406
 407        Ok(())
 408    }
 409
 410    async fn leave_worktree(
 411        mut self: Arc<Server>,
 412        request: TypedEnvelope<proto::LeaveWorktree>,
 413    ) -> tide::Result<()> {
 414        let sender_id = request.sender_id;
 415        let worktree_id = request.payload.worktree_id;
 416        let worktree = self.state_mut().leave_worktree(sender_id, worktree_id);
 417        if let Some(worktree) = worktree {
 418            broadcast(sender_id, worktree.connection_ids, |conn_id| {
 419                self.peer.send(
 420                    conn_id,
 421                    proto::RemovePeer {
 422                        worktree_id,
 423                        peer_id: sender_id.0,
 424                    },
 425                )
 426            })
 427            .await?;
 428            self.update_collaborators_for_users(&worktree.collaborator_ids)
 429                .await?;
 430        }
 431        Ok(())
 432    }
 433
 434    async fn update_worktree(
 435        mut self: Arc<Server>,
 436        request: TypedEnvelope<proto::UpdateWorktree>,
 437    ) -> tide::Result<()> {
 438        let connection_ids = self.state_mut().update_worktree(
 439            request.sender_id,
 440            request.payload.worktree_id,
 441            &request.payload.removed_entries,
 442            &request.payload.updated_entries,
 443        )?;
 444
 445        broadcast(request.sender_id, connection_ids, |connection_id| {
 446            self.peer
 447                .forward_send(request.sender_id, connection_id, request.payload.clone())
 448        })
 449        .await?;
 450
 451        Ok(())
 452    }
 453
 454    async fn open_buffer(
 455        self: Arc<Server>,
 456        request: TypedEnvelope<proto::OpenBuffer>,
 457    ) -> tide::Result<()> {
 458        let receipt = request.receipt();
 459        let host_connection_id = self
 460            .state()
 461            .worktree_host_connection_id(request.sender_id, request.payload.worktree_id)?;
 462        let response = self
 463            .peer
 464            .forward_request(request.sender_id, host_connection_id, request.payload)
 465            .await?;
 466        self.peer.respond(receipt, response).await?;
 467        Ok(())
 468    }
 469
 470    async fn close_buffer(
 471        self: Arc<Server>,
 472        request: TypedEnvelope<proto::CloseBuffer>,
 473    ) -> tide::Result<()> {
 474        let host_connection_id = self
 475            .state()
 476            .worktree_host_connection_id(request.sender_id, request.payload.worktree_id)?;
 477        self.peer
 478            .forward_send(request.sender_id, host_connection_id, request.payload)
 479            .await?;
 480        Ok(())
 481    }
 482
 483    async fn save_buffer(
 484        self: Arc<Server>,
 485        request: TypedEnvelope<proto::SaveBuffer>,
 486    ) -> tide::Result<()> {
 487        let host;
 488        let guests;
 489        {
 490            let state = self.state();
 491            host = state
 492                .worktree_host_connection_id(request.sender_id, request.payload.worktree_id)?;
 493            guests = state
 494                .worktree_guest_connection_ids(request.sender_id, request.payload.worktree_id)?;
 495        }
 496
 497        let sender = request.sender_id;
 498        let receipt = request.receipt();
 499        let response = self
 500            .peer
 501            .forward_request(sender, host, request.payload.clone())
 502            .await?;
 503
 504        broadcast(host, guests, |conn_id| {
 505            let response = response.clone();
 506            let peer = &self.peer;
 507            async move {
 508                if conn_id == sender {
 509                    peer.respond(receipt, response).await
 510                } else {
 511                    peer.forward_send(host, conn_id, response).await
 512                }
 513            }
 514        })
 515        .await?;
 516
 517        Ok(())
 518    }
 519
 520    async fn update_buffer(
 521        self: Arc<Server>,
 522        request: TypedEnvelope<proto::UpdateBuffer>,
 523    ) -> tide::Result<()> {
 524        let receiver_ids = self
 525            .state()
 526            .worktree_connection_ids(request.sender_id, request.payload.worktree_id)?;
 527        broadcast(request.sender_id, receiver_ids, |connection_id| {
 528            self.peer
 529                .forward_send(request.sender_id, connection_id, request.payload.clone())
 530        })
 531        .await?;
 532        self.peer.respond(request.receipt(), proto::Ack {}).await?;
 533        Ok(())
 534    }
 535
 536    async fn buffer_saved(
 537        self: Arc<Server>,
 538        request: TypedEnvelope<proto::BufferSaved>,
 539    ) -> tide::Result<()> {
 540        let receiver_ids = self
 541            .state()
 542            .worktree_connection_ids(request.sender_id, request.payload.worktree_id)?;
 543        broadcast(request.sender_id, receiver_ids, |connection_id| {
 544            self.peer
 545                .forward_send(request.sender_id, connection_id, request.payload.clone())
 546        })
 547        .await?;
 548        Ok(())
 549    }
 550
 551    async fn get_channels(
 552        self: Arc<Server>,
 553        request: TypedEnvelope<proto::GetChannels>,
 554    ) -> tide::Result<()> {
 555        let user_id = self.state().user_id_for_connection(request.sender_id)?;
 556        let channels = self.app_state.db.get_accessible_channels(user_id).await?;
 557        self.peer
 558            .respond(
 559                request.receipt(),
 560                proto::GetChannelsResponse {
 561                    channels: channels
 562                        .into_iter()
 563                        .map(|chan| proto::Channel {
 564                            id: chan.id.to_proto(),
 565                            name: chan.name,
 566                        })
 567                        .collect(),
 568                },
 569            )
 570            .await?;
 571        Ok(())
 572    }
 573
 574    async fn get_users(
 575        self: Arc<Server>,
 576        request: TypedEnvelope<proto::GetUsers>,
 577    ) -> tide::Result<()> {
 578        let receipt = request.receipt();
 579        let user_ids = request.payload.user_ids.into_iter().map(UserId::from_proto);
 580        let users = self
 581            .app_state
 582            .db
 583            .get_users_by_ids(user_ids)
 584            .await?
 585            .into_iter()
 586            .map(|user| proto::User {
 587                id: user.id.to_proto(),
 588                avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
 589                github_login: user.github_login,
 590            })
 591            .collect();
 592        self.peer
 593            .respond(receipt, proto::GetUsersResponse { users })
 594            .await?;
 595        Ok(())
 596    }
 597
 598    async fn update_collaborators_for_users<'a>(
 599        self: &Arc<Server>,
 600        user_ids: impl IntoIterator<Item = &'a UserId>,
 601    ) -> tide::Result<()> {
 602        let mut send_futures = Vec::new();
 603
 604        {
 605            let state = self.state();
 606            for user_id in user_ids {
 607                let collaborators = state.collaborators_for_user(*user_id);
 608                for connection_id in state.connection_ids_for_user(*user_id) {
 609                    send_futures.push(self.peer.send(
 610                        connection_id,
 611                        proto::UpdateCollaborators {
 612                            collaborators: collaborators.clone(),
 613                        },
 614                    ));
 615                }
 616            }
 617        }
 618        futures::future::try_join_all(send_futures).await?;
 619
 620        Ok(())
 621    }
 622
 623    async fn join_channel(
 624        mut self: Arc<Self>,
 625        request: TypedEnvelope<proto::JoinChannel>,
 626    ) -> tide::Result<()> {
 627        let user_id = self.state().user_id_for_connection(request.sender_id)?;
 628        let channel_id = ChannelId::from_proto(request.payload.channel_id);
 629        if !self
 630            .app_state
 631            .db
 632            .can_user_access_channel(user_id, channel_id)
 633            .await?
 634        {
 635            Err(anyhow!("access denied"))?;
 636        }
 637
 638        self.state_mut().join_channel(request.sender_id, channel_id);
 639        let messages = self
 640            .app_state
 641            .db
 642            .get_channel_messages(channel_id, MESSAGE_COUNT_PER_PAGE, None)
 643            .await?
 644            .into_iter()
 645            .map(|msg| proto::ChannelMessage {
 646                id: msg.id.to_proto(),
 647                body: msg.body,
 648                timestamp: msg.sent_at.unix_timestamp() as u64,
 649                sender_id: msg.sender_id.to_proto(),
 650                nonce: Some(msg.nonce.as_u128().into()),
 651            })
 652            .collect::<Vec<_>>();
 653        self.peer
 654            .respond(
 655                request.receipt(),
 656                proto::JoinChannelResponse {
 657                    done: messages.len() < MESSAGE_COUNT_PER_PAGE,
 658                    messages,
 659                },
 660            )
 661            .await?;
 662        Ok(())
 663    }
 664
 665    async fn leave_channel(
 666        mut self: Arc<Self>,
 667        request: TypedEnvelope<proto::LeaveChannel>,
 668    ) -> tide::Result<()> {
 669        let user_id = self.state().user_id_for_connection(request.sender_id)?;
 670        let channel_id = ChannelId::from_proto(request.payload.channel_id);
 671        if !self
 672            .app_state
 673            .db
 674            .can_user_access_channel(user_id, channel_id)
 675            .await?
 676        {
 677            Err(anyhow!("access denied"))?;
 678        }
 679
 680        self.state_mut()
 681            .leave_channel(request.sender_id, channel_id);
 682
 683        Ok(())
 684    }
 685
 686    async fn send_channel_message(
 687        self: Arc<Self>,
 688        request: TypedEnvelope<proto::SendChannelMessage>,
 689    ) -> tide::Result<()> {
 690        let receipt = request.receipt();
 691        let channel_id = ChannelId::from_proto(request.payload.channel_id);
 692        let user_id;
 693        let connection_ids;
 694        {
 695            let state = self.state();
 696            user_id = state.user_id_for_connection(request.sender_id)?;
 697            if let Some(ids) = state.channel_connection_ids(channel_id) {
 698                connection_ids = ids;
 699            } else {
 700                return Ok(());
 701            }
 702        }
 703
 704        // Validate the message body.
 705        let body = request.payload.body.trim().to_string();
 706        if body.len() > MAX_MESSAGE_LEN {
 707            self.peer
 708                .respond_with_error(
 709                    receipt,
 710                    proto::Error {
 711                        message: "message is too long".to_string(),
 712                    },
 713                )
 714                .await?;
 715            return Ok(());
 716        }
 717        if body.is_empty() {
 718            self.peer
 719                .respond_with_error(
 720                    receipt,
 721                    proto::Error {
 722                        message: "message can't be blank".to_string(),
 723                    },
 724                )
 725                .await?;
 726            return Ok(());
 727        }
 728
 729        let timestamp = OffsetDateTime::now_utc();
 730        let nonce = if let Some(nonce) = request.payload.nonce {
 731            nonce
 732        } else {
 733            self.peer
 734                .respond_with_error(
 735                    receipt,
 736                    proto::Error {
 737                        message: "nonce can't be blank".to_string(),
 738                    },
 739                )
 740                .await?;
 741            return Ok(());
 742        };
 743
 744        let message_id = self
 745            .app_state
 746            .db
 747            .create_channel_message(channel_id, user_id, &body, timestamp, nonce.clone().into())
 748            .await?
 749            .to_proto();
 750        let message = proto::ChannelMessage {
 751            sender_id: user_id.to_proto(),
 752            id: message_id,
 753            body,
 754            timestamp: timestamp.unix_timestamp() as u64,
 755            nonce: Some(nonce),
 756        };
 757        broadcast(request.sender_id, connection_ids, |conn_id| {
 758            self.peer.send(
 759                conn_id,
 760                proto::ChannelMessageSent {
 761                    channel_id: channel_id.to_proto(),
 762                    message: Some(message.clone()),
 763                },
 764            )
 765        })
 766        .await?;
 767        self.peer
 768            .respond(
 769                receipt,
 770                proto::SendChannelMessageResponse {
 771                    message: Some(message),
 772                },
 773            )
 774            .await?;
 775        Ok(())
 776    }
 777
 778    async fn get_channel_messages(
 779        self: Arc<Self>,
 780        request: TypedEnvelope<proto::GetChannelMessages>,
 781    ) -> tide::Result<()> {
 782        let user_id = self.state().user_id_for_connection(request.sender_id)?;
 783        let channel_id = ChannelId::from_proto(request.payload.channel_id);
 784        if !self
 785            .app_state
 786            .db
 787            .can_user_access_channel(user_id, channel_id)
 788            .await?
 789        {
 790            Err(anyhow!("access denied"))?;
 791        }
 792
 793        let messages = self
 794            .app_state
 795            .db
 796            .get_channel_messages(
 797                channel_id,
 798                MESSAGE_COUNT_PER_PAGE,
 799                Some(MessageId::from_proto(request.payload.before_message_id)),
 800            )
 801            .await?
 802            .into_iter()
 803            .map(|msg| proto::ChannelMessage {
 804                id: msg.id.to_proto(),
 805                body: msg.body,
 806                timestamp: msg.sent_at.unix_timestamp() as u64,
 807                sender_id: msg.sender_id.to_proto(),
 808                nonce: Some(msg.nonce.as_u128().into()),
 809            })
 810            .collect::<Vec<_>>();
 811        self.peer
 812            .respond(
 813                request.receipt(),
 814                proto::GetChannelMessagesResponse {
 815                    done: messages.len() < MESSAGE_COUNT_PER_PAGE,
 816                    messages,
 817                },
 818            )
 819            .await?;
 820        Ok(())
 821    }
 822
 823    fn state<'a>(self: &'a Arc<Self>) -> RwLockReadGuard<'a, Store> {
 824        self.store.read()
 825    }
 826
 827    fn state_mut<'a>(self: &'a mut Arc<Self>) -> RwLockWriteGuard<'a, Store> {
 828        self.store.write()
 829    }
 830}
 831
 832pub async fn broadcast<F, T>(
 833    sender_id: ConnectionId,
 834    receiver_ids: Vec<ConnectionId>,
 835    mut f: F,
 836) -> anyhow::Result<()>
 837where
 838    F: FnMut(ConnectionId) -> T,
 839    T: Future<Output = anyhow::Result<()>>,
 840{
 841    let futures = receiver_ids
 842        .into_iter()
 843        .filter(|id| *id != sender_id)
 844        .map(|id| f(id));
 845    futures::future::try_join_all(futures).await?;
 846    Ok(())
 847}
 848
 849pub fn add_routes(app: &mut tide::Server<Arc<AppState>>, rpc: &Arc<Peer>) {
 850    let server = Server::new(app.state().clone(), rpc.clone(), None);
 851    app.at("/rpc").get(move |request: Request<Arc<AppState>>| {
 852        let server = server.clone();
 853        async move {
 854            const WEBSOCKET_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
 855
 856            let connection_upgrade = header_contains_ignore_case(&request, CONNECTION, "upgrade");
 857            let upgrade_to_websocket = header_contains_ignore_case(&request, UPGRADE, "websocket");
 858            let upgrade_requested = connection_upgrade && upgrade_to_websocket;
 859            let client_protocol_version: Option<u32> = request
 860                .header("X-Zed-Protocol-Version")
 861                .and_then(|v| v.as_str().parse().ok());
 862
 863            if !upgrade_requested || client_protocol_version != Some(rpc::PROTOCOL_VERSION) {
 864                return Ok(Response::new(StatusCode::UpgradeRequired));
 865            }
 866
 867            let header = match request.header("Sec-Websocket-Key") {
 868                Some(h) => h.as_str(),
 869                None => return Err(anyhow!("expected sec-websocket-key"))?,
 870            };
 871
 872            let user_id = process_auth_header(&request).await?;
 873
 874            let mut response = Response::new(StatusCode::SwitchingProtocols);
 875            response.insert_header(UPGRADE, "websocket");
 876            response.insert_header(CONNECTION, "Upgrade");
 877            let hash = Sha1::new().chain(header).chain(WEBSOCKET_GUID).finalize();
 878            response.insert_header("Sec-Websocket-Accept", base64::encode(&hash[..]));
 879            response.insert_header("Sec-Websocket-Version", "13");
 880
 881            let http_res: &mut tide::http::Response = response.as_mut();
 882            let upgrade_receiver = http_res.recv_upgrade().await;
 883            let addr = request.remote().unwrap_or("unknown").to_string();
 884            task::spawn(async move {
 885                if let Some(stream) = upgrade_receiver.await {
 886                    server
 887                        .handle_connection(
 888                            Connection::new(
 889                                WebSocketStream::from_raw_socket(stream, Role::Server, None).await,
 890                            ),
 891                            addr,
 892                            user_id,
 893                        )
 894                        .await;
 895                }
 896            });
 897
 898            Ok(response)
 899        }
 900    });
 901}
 902
 903fn header_contains_ignore_case<T>(
 904    request: &tide::Request<T>,
 905    header_name: HeaderName,
 906    value: &str,
 907) -> bool {
 908    request
 909        .header(header_name)
 910        .map(|h| {
 911            h.as_str()
 912                .split(',')
 913                .any(|s| s.trim().eq_ignore_ascii_case(value.trim()))
 914        })
 915        .unwrap_or(false)
 916}
 917
 918#[cfg(test)]
 919mod tests {
 920    use super::*;
 921    use crate::{
 922        auth,
 923        db::{tests::TestDb, UserId},
 924        github, AppState, Config,
 925    };
 926    use ::rpc::Peer;
 927    use async_std::task;
 928    use gpui::{ModelHandle, TestAppContext};
 929    use parking_lot::Mutex;
 930    use postage::{mpsc, watch};
 931    use serde_json::json;
 932    use sqlx::types::time::OffsetDateTime;
 933    use std::{
 934        path::Path,
 935        sync::{
 936            atomic::{AtomicBool, Ordering::SeqCst},
 937            Arc,
 938        },
 939        time::Duration,
 940    };
 941    use zed::{
 942        client::{
 943            self, test::FakeHttpClient, Channel, ChannelDetails, ChannelList, Client, Credentials,
 944            EstablishConnectionError, UserStore,
 945        },
 946        editor::{Editor, EditorSettings, Input},
 947        fs::{FakeFs, Fs as _},
 948        language::{
 949            tree_sitter_rust, Diagnostic, Language, LanguageConfig, LanguageRegistry,
 950            LanguageServerConfig, Point,
 951        },
 952        lsp,
 953        people_panel::JoinWorktree,
 954        project::{ProjectPath, Worktree},
 955        test::test_app_state,
 956        workspace::Workspace,
 957    };
 958
    // Two clients on one shared worktree. Host A shares a local directory and guest B joins
    // it. Each step checks that a change on one side is observed on the other: peer
    // membership, buffer opening, selection sets, edits, buffer closing, and B leaving.
    #[gpui::test]
    async fn test_share_worktree(mut cx_a: TestAppContext, mut cx_b: TestAppContext) {
        // Window on B's app that hosts the editor created further down.
        let (window_b, _) = cx_b.add_window(|_| EmptyView);
        let lang_registry = Arc::new(LanguageRegistry::new());

        // Connect to a server as 2 clients.
        let mut server = TestServer::start().await;
        let (client_a, _) = server.create_client(&mut cx_a, "user_a").await;
        let (client_b, _) = server.create_client(&mut cx_b, "user_b").await;

        cx_a.foreground().forbid_parking();

        // Share a local worktree as client A. `.zed.toml` lists user_b as a collaborator,
        // presumably what authorizes B to join (TODO: confirm against the share handler).
        let fs = Arc::new(FakeFs::new());
        fs.insert_tree(
            "/a",
            json!({
                ".zed.toml": r#"collaborators = ["user_b"]"#,
                "a.txt": "a-contents",
                "b.txt": "b-contents",
            }),
        )
        .await;
        let worktree_a = Worktree::open_local(
            client_a.clone(),
            "/a".as_ref(),
            fs,
            lang_registry.clone(),
            &mut cx_a.to_async(),
        )
        .await
        .unwrap();
        // Wait until A's initial scan reports complete before sharing.
        worktree_a
            .read_with(&cx_a, |tree, _| tree.as_local().unwrap().scan_complete())
            .await;
        let worktree_id = worktree_a
            .update(&mut cx_a, |tree, cx| tree.as_local_mut().unwrap().share(cx))
            .await
            .unwrap();

        // Join that worktree as client B, and see that a guest has joined as client A.
        let worktree_b = Worktree::open_remote(
            client_b.clone(),
            worktree_id,
            lang_registry.clone(),
            &mut cx_b.to_async(),
        )
        .await
        .unwrap();
        let replica_id_b = worktree_b.read_with(&cx_b, |tree, _| tree.replica_id());
        worktree_a
            .condition(&cx_a, |tree, _| {
                tree.peers()
                    .values()
                    .any(|replica_id| *replica_id == replica_id_b)
            })
            .await;

        // Open the same file as client B and client A.
        let buffer_b = worktree_b
            .update(&mut cx_b, |worktree, cx| worktree.open_buffer("b.txt", cx))
            .await
            .unwrap();
        buffer_b.read_with(&cx_b, |buf, _| assert_eq!(buf.text(), "b-contents"));
        // B's open is already visible on the host before A opens the buffer itself.
        worktree_a.read_with(&cx_a, |tree, cx| assert!(tree.has_open_buffer("b.txt", cx)));
        let buffer_a = worktree_a
            .update(&mut cx_a, |tree, cx| tree.open_buffer("b.txt", cx))
            .await
            .unwrap();

        // Create a selection set as client B and see that selection set as client A.
        let editor_b = cx_b.add_view(window_b, |cx| {
            Editor::for_buffer(buffer_b, |cx| EditorSettings::test(cx), cx)
        });
        buffer_a
            .condition(&cx_a, |buffer, _| buffer.selection_sets().count() == 1)
            .await;

        // Edit the buffer as client B and see that edit as client A.
        editor_b.update(&mut cx_b, |editor, cx| {
            editor.handle_input(&Input("ok, ".into()), cx)
        });
        buffer_a
            .condition(&cx_a, |buffer, _| buffer.text() == "ok, b-contents")
            .await;

        // Remove the selection set as client B, see those selections disappear as client A.
        cx_b.update(move |_| drop(editor_b));
        buffer_a
            .condition(&cx_a, |buffer, _| buffer.selection_sets().count() == 0)
            .await;

        // Close the buffer as client A, see that the buffer is closed.
        cx_a.update(move |_| drop(buffer_a));
        worktree_a
            .condition(&cx_a, |tree, cx| !tree.has_open_buffer("b.txt", cx))
            .await;

        // Dropping the worktree removes client B from client A's peers.
        cx_b.update(move |_| drop(worktree_b));
        worktree_a
            .condition(&cx_a, |tree, _| tree.peers().is_empty())
            .await;
    }
1063
    // Guest B joins A's shared worktree through its workspace (the `JoinWorktree` action) and
    // opens a file from it. Host A then unshares. B's workspace must lose the worktree, and
    // the item opened from it must disappear as well.
    #[gpui::test]
    async fn test_unshare_worktree(mut cx_a: TestAppContext, mut cx_b: TestAppContext) {
        // Presumably registers the people panel's action handlers (including the
        // `JoinWorktree` dispatched below) on B's app. TODO: confirm.
        cx_b.update(zed::people_panel::init);
        let mut app_state_a = cx_a.update(test_app_state);
        let mut app_state_b = cx_b.update(test_app_state);

        // Connect to a server as 2 clients.
        let mut server = TestServer::start().await;
        let (client_a, user_store_a) = server.create_client(&mut cx_a, "user_a").await;
        let (client_b, user_store_b) = server.create_client(&mut cx_b, "user_b").await;
        // Swap the server-connected clients and user stores into each app state.
        // `Arc::get_mut` only succeeds while an `Arc` has no other strong or weak handles,
        // so these assignments must happen before the app states are shared.
        Arc::get_mut(&mut app_state_a).unwrap().client = client_a;
        Arc::get_mut(&mut app_state_a).unwrap().user_store = user_store_a;
        Arc::get_mut(&mut app_state_b).unwrap().client = client_b;
        Arc::get_mut(&mut app_state_b).unwrap().user_store = user_store_b;

        cx_a.foreground().forbid_parking();

        // Share a local worktree as client A
        let fs = Arc::new(FakeFs::new());
        fs.insert_tree(
            "/a",
            json!({
                ".zed.toml": r#"collaborators = ["user_b"]"#,
                "a.txt": "a-contents",
                "b.txt": "b-contents",
            }),
        )
        .await;
        let worktree_a = Worktree::open_local(
            app_state_a.client.clone(),
            "/a".as_ref(),
            fs,
            app_state_a.languages.clone(),
            &mut cx_a.to_async(),
        )
        .await
        .unwrap();
        worktree_a
            .read_with(&cx_a, |tree, _| tree.as_local().unwrap().scan_complete())
            .await;

        let remote_worktree_id = worktree_a
            .update(&mut cx_a, |tree, cx| tree.as_local_mut().unwrap().share(cx))
            .await
            .unwrap();

        // Join the shared worktree from a workspace on B by dispatching `JoinWorktree`, then
        // wait for it to show up in that workspace.
        let (window_b, workspace_b) =
            cx_b.add_window(|cx| Workspace::new(&app_state_b.as_ref().into(), cx));
        cx_b.update(|cx| {
            cx.dispatch_action(
                window_b,
                vec![workspace_b.id()],
                &JoinWorktree(remote_worktree_id),
            );
        });
        workspace_b
            .condition(&cx_b, |workspace, cx| workspace.worktrees(cx).len() == 1)
            .await;

        // Nothing is open yet. Open `a.txt` from the joined worktree as B.
        let local_worktree_id_b = workspace_b.read_with(&cx_b, |workspace, cx| {
            let active_pane = workspace.active_pane().read(cx);
            assert!(active_pane.active_item().is_none());
            workspace.worktrees(cx).first().unwrap().id()
        });
        workspace_b
            .update(&mut cx_b, |workspace, cx| {
                workspace.open_entry(
                    ProjectPath {
                        worktree_id: local_worktree_id_b,
                        path: Path::new("a.txt").into(),
                    },
                    cx,
                )
            })
            .unwrap()
            .await;
        workspace_b.read_with(&cx_b, |workspace, cx| {
            let active_pane = workspace.active_pane().read(cx);
            assert!(active_pane.active_item().is_some());
        });

        // Unshare as A: B's workspace should lose the worktree and the item opened from it.
        worktree_a.update(&mut cx_a, |tree, cx| {
            tree.as_local_mut().unwrap().unshare(cx);
        });
        workspace_b
            .condition(&cx_b, |workspace, cx| workspace.worktrees(cx).len() == 0)
            .await;
        workspace_b.read_with(&cx_b, |workspace, cx| {
            let active_pane = workspace.active_pane().read(cx);
            assert!(active_pane.active_item().is_none());
        });
    }
1156
    // Three clients on one worktree: host A and guests B and C.
    // 1. Concurrent edits to one buffer from all three must converge.
    // 2. A guest save racing a host edit must leave everyone clean.
    // 3. File-system changes on the host (a rename plus a new file) must reach both guests.
    #[gpui::test]
    async fn test_propagate_saves_and_fs_changes_in_shared_worktree(
        mut cx_a: TestAppContext,
        mut cx_b: TestAppContext,
        mut cx_c: TestAppContext,
    ) {
        cx_a.foreground().forbid_parking();
        let lang_registry = Arc::new(LanguageRegistry::new());

        // Connect to a server as 3 clients.
        let mut server = TestServer::start().await;
        let (client_a, _) = server.create_client(&mut cx_a, "user_a").await;
        let (client_b, _) = server.create_client(&mut cx_b, "user_b").await;
        let (client_c, _) = server.create_client(&mut cx_c, "user_c").await;

        // Kept by the test (only a clone goes to the worktree) so it can later read and
        // mutate the host's files directly.
        let fs = Arc::new(FakeFs::new());

        // Share a worktree as client A.
        fs.insert_tree(
            "/a",
            json!({
                ".zed.toml": r#"collaborators = ["user_b", "user_c"]"#,
                "file1": "",
                "file2": ""
            }),
        )
        .await;

        let worktree_a = Worktree::open_local(
            client_a.clone(),
            "/a".as_ref(),
            fs.clone(),
            lang_registry.clone(),
            &mut cx_a.to_async(),
        )
        .await
        .unwrap();
        worktree_a
            .read_with(&cx_a, |tree, _| tree.as_local().unwrap().scan_complete())
            .await;
        let worktree_id = worktree_a
            .update(&mut cx_a, |tree, cx| tree.as_local_mut().unwrap().share(cx))
            .await
            .unwrap();

        // Join that worktree as clients B and C.
        let worktree_b = Worktree::open_remote(
            client_b.clone(),
            worktree_id,
            lang_registry.clone(),
            &mut cx_b.to_async(),
        )
        .await
        .unwrap();
        let worktree_c = Worktree::open_remote(
            client_c.clone(),
            worktree_id,
            lang_registry.clone(),
            &mut cx_c.to_async(),
        )
        .await
        .unwrap();

        // Open and edit a buffer as both guests B and C.
        let buffer_b = worktree_b
            .update(&mut cx_b, |tree, cx| tree.open_buffer("file1", cx))
            .await
            .unwrap();
        let buffer_c = worktree_c
            .update(&mut cx_c, |tree, cx| tree.open_buffer("file1", cx))
            .await
            .unwrap();
        buffer_b.update(&mut cx_b, |buf, cx| buf.edit([0..0], "i-am-b, ", cx));
        buffer_c.update(&mut cx_c, |buf, cx| buf.edit([0..0], "i-am-c, ", cx));

        // Open and edit that buffer as the host.
        let buffer_a = worktree_a
            .update(&mut cx_a, |tree, cx| tree.open_buffer("file1", cx))
            .await
            .unwrap();

        // Both guests inserted at offset 0. The expected text places C's insertion ahead of B's.
        buffer_a
            .condition(&mut cx_a, |buf, _| buf.text() == "i-am-c, i-am-b, ")
            .await;
        buffer_a.update(&mut cx_a, |buf, cx| {
            buf.edit([buf.len()..buf.len()], "i-am-a", cx)
        });

        // Wait for edits to propagate
        buffer_a
            .condition(&mut cx_a, |buf, _| buf.text() == "i-am-c, i-am-b, i-am-a")
            .await;
        buffer_b
            .condition(&mut cx_b, |buf, _| buf.text() == "i-am-c, i-am-b, i-am-a")
            .await;
        buffer_c
            .condition(&mut cx_c, |buf, _| buf.text() == "i-am-c, i-am-b, i-am-a")
            .await;

        // Edit the buffer as the host and concurrently save as guest B.
        let save_b = buffer_b.update(&mut cx_b, |buf, cx| buf.save(cx).unwrap());
        buffer_a.update(&mut cx_a, |buf, cx| buf.edit([0..0], "hi-a, ", cx));
        save_b.await.unwrap();
        // The saved file already contains the host's concurrent "hi-a, " edit, and all three
        // buffers end up clean.
        assert_eq!(
            fs.load("/a/file1".as_ref()).await.unwrap(),
            "hi-a, i-am-c, i-am-b, i-am-a"
        );
        buffer_a.read_with(&cx_a, |buf, _| assert!(!buf.is_dirty()));
        buffer_b.read_with(&cx_b, |buf, _| assert!(!buf.is_dirty()));
        buffer_c.condition(&cx_c, |buf, _| !buf.is_dirty()).await;

        // Make changes on host's file system, see those changes on the guests.
        fs.rename("/a/file2".as_ref(), "/a/file3".as_ref())
            .await
            .unwrap();
        fs.insert_file(Path::new("/a/file4"), "4".into())
            .await
            .unwrap();

        // Guests should end up with `.zed.toml`, file1, file3 (renamed from file2) and file4.
        worktree_b
            .condition(&cx_b, |tree, _| tree.file_count() == 4)
            .await;
        worktree_c
            .condition(&cx_c, |tree, _| tree.file_count() == 4)
            .await;
        worktree_b.read_with(&cx_b, |tree, _| {
            assert_eq!(
                tree.paths()
                    .map(|p| p.to_string_lossy())
                    .collect::<Vec<_>>(),
                &[".zed.toml", "file1", "file3", "file4"]
            )
        });
        worktree_c.read_with(&cx_c, |tree, _| {
            assert_eq!(
                tree.paths()
                    .map(|p| p.to_string_lossy())
                    .collect::<Vec<_>>(),
                &[".zed.toml", "file1", "file3", "file4"]
            )
        });
    }
1299
    // A guest edits and saves a shared buffer. Once the save is reflected (the file's mtime
    // changes), the buffer must be clean. A further edit makes it dirty again, but it must
    // not be flagged as a conflict.
    #[gpui::test]
    async fn test_buffer_conflict_after_save(mut cx_a: TestAppContext, mut cx_b: TestAppContext) {
        cx_a.foreground().forbid_parking();
        let lang_registry = Arc::new(LanguageRegistry::new());

        // Connect to a server as 2 clients.
        let mut server = TestServer::start().await;
        let (client_a, _) = server.create_client(&mut cx_a, "user_a").await;
        let (client_b, _) = server.create_client(&mut cx_b, "user_b").await;

        // Share a local worktree as client A
        // NOTE(review): "user_c" is listed as a collaborator but no client C exists in this
        // test. It looks copied from the three-client test and could probably be dropped.
        let fs = Arc::new(FakeFs::new());
        fs.insert_tree(
            "/dir",
            json!({
                ".zed.toml": r#"collaborators = ["user_b", "user_c"]"#,
                "a.txt": "a-contents",
            }),
        )
        .await;

        let worktree_a = Worktree::open_local(
            client_a.clone(),
            "/dir".as_ref(),
            fs,
            lang_registry.clone(),
            &mut cx_a.to_async(),
        )
        .await
        .unwrap();
        worktree_a
            .read_with(&cx_a, |tree, _| tree.as_local().unwrap().scan_complete())
            .await;
        let worktree_id = worktree_a
            .update(&mut cx_a, |tree, cx| tree.as_local_mut().unwrap().share(cx))
            .await
            .unwrap();

        // Join that worktree as client B (this test does not wait for A to observe the guest).
        let worktree_b = Worktree::open_remote(
            client_b.clone(),
            worktree_id,
            lang_registry.clone(),
            &mut cx_b.to_async(),
        )
        .await
        .unwrap();

        let buffer_b = worktree_b
            .update(&mut cx_b, |worktree, cx| worktree.open_buffer("a.txt", cx))
            .await
            .unwrap();
        // Remember the pre-save mtime so the save's effect can be detected below.
        let mtime = buffer_b.read_with(&cx_b, |buf, _| buf.file().unwrap().mtime());

        buffer_b.update(&mut cx_b, |buf, cx| buf.edit([0..0], "world ", cx));
        buffer_b.read_with(&cx_b, |buf, _| {
            assert!(buf.is_dirty());
            assert!(!buf.has_conflict());
        });

        // Save as B, then wait until B's view of the file reports a different mtime.
        buffer_b
            .update(&mut cx_b, |buf, cx| buf.save(cx))
            .unwrap()
            .await
            .unwrap();
        worktree_b
            .condition(&cx_b, |_, cx| {
                buffer_b.read(cx).file().unwrap().mtime() != mtime
            })
            .await;
        buffer_b.read_with(&cx_b, |buf, _| {
            assert!(!buf.is_dirty());
            assert!(!buf.has_conflict());
        });

        // Editing after the save dirties the buffer but must not report a conflict.
        buffer_b.update(&mut cx_b, |buf, cx| buf.edit([0..0], "hello ", cx));
        buffer_b.read_with(&cx_b, |buf, _| {
            assert!(buf.is_dirty());
            assert!(!buf.has_conflict());
        });
    }
1381
    // The host edits a buffer while a guest's request to open that same buffer is still in
    // flight. Once the guest's open completes, its buffer must converge to the host's text.
    #[gpui::test]
    async fn test_editing_while_guest_opens_buffer(
        mut cx_a: TestAppContext,
        mut cx_b: TestAppContext,
    ) {
        cx_a.foreground().forbid_parking();
        let lang_registry = Arc::new(LanguageRegistry::new());

        // Connect to a server as 2 clients.
        let mut server = TestServer::start().await;
        let (client_a, _) = server.create_client(&mut cx_a, "user_a").await;
        let (client_b, _) = server.create_client(&mut cx_b, "user_b").await;

        // Share a local worktree as client A
        let fs = Arc::new(FakeFs::new());
        fs.insert_tree(
            "/dir",
            json!({
                ".zed.toml": r#"collaborators = ["user_b"]"#,
                "a.txt": "a-contents",
            }),
        )
        .await;
        let worktree_a = Worktree::open_local(
            client_a.clone(),
            "/dir".as_ref(),
            fs,
            lang_registry.clone(),
            &mut cx_a.to_async(),
        )
        .await
        .unwrap();
        worktree_a
            .read_with(&cx_a, |tree, _| tree.as_local().unwrap().scan_complete())
            .await;
        let worktree_id = worktree_a
            .update(&mut cx_a, |tree, cx| tree.as_local_mut().unwrap().share(cx))
            .await
            .unwrap();

        // Join that worktree as client B (this test does not wait for A to observe the guest).
        let worktree_b = Worktree::open_remote(
            client_b.clone(),
            worktree_id,
            lang_registry.clone(),
            &mut cx_b.to_async(),
        )
        .await
        .unwrap();

        let buffer_a = worktree_a
            .update(&mut cx_a, |tree, cx| tree.open_buffer("a.txt", cx))
            .await
            .unwrap();
        // Start B's open on the background executor without awaiting it yet.
        let buffer_b = cx_b
            .background()
            .spawn(worktree_b.update(&mut cx_b, |worktree, cx| worktree.open_buffer("a.txt", cx)));

        // Yield once (presumably to let B's request get under way), then edit as A while
        // B's open is still pending.
        task::yield_now().await;
        buffer_a.update(&mut cx_a, |buf, cx| buf.edit([0..0], "z", cx));

        let text = buffer_a.read_with(&cx_a, |buf, _| buf.text());
        let buffer_b = buffer_b.await.unwrap();
        buffer_b.condition(&cx_b, |buf, _| buf.text() == text).await;
    }
1447
    // A guest leaves the worktree while its request to open a buffer is still pending. The
    // host must nonetheless see the guest removed from its peers.
    #[gpui::test]
    async fn test_leaving_worktree_while_opening_buffer(
        mut cx_a: TestAppContext,
        mut cx_b: TestAppContext,
    ) {
        cx_a.foreground().forbid_parking();
        let lang_registry = Arc::new(LanguageRegistry::new());

        // Connect to a server as 2 clients.
        let mut server = TestServer::start().await;
        let (client_a, _) = server.create_client(&mut cx_a, "user_a").await;
        let (client_b, _) = server.create_client(&mut cx_b, "user_b").await;

        // Share a local worktree as client A
        let fs = Arc::new(FakeFs::new());
        fs.insert_tree(
            "/dir",
            json!({
                ".zed.toml": r#"collaborators = ["user_b"]"#,
                "a.txt": "a-contents",
            }),
        )
        .await;
        let worktree_a = Worktree::open_local(
            client_a.clone(),
            "/dir".as_ref(),
            fs,
            lang_registry.clone(),
            &mut cx_a.to_async(),
        )
        .await
        .unwrap();
        worktree_a
            .read_with(&cx_a, |tree, _| tree.as_local().unwrap().scan_complete())
            .await;
        let worktree_id = worktree_a
            .update(&mut cx_a, |tree, cx| tree.as_local_mut().unwrap().share(cx))
            .await
            .unwrap();

        // Join that worktree as client B, and see that a guest has joined as client A.
        let worktree_b = Worktree::open_remote(
            client_b.clone(),
            worktree_id,
            lang_registry.clone(),
            &mut cx_b.to_async(),
        )
        .await
        .unwrap();
        worktree_a
            .condition(&cx_a, |tree, _| tree.peers().len() == 1)
            .await;

        // Start opening a buffer as B, then drop B's worktree and the pending open task
        // without ever awaiting it.
        let buffer_b = cx_b
            .background()
            .spawn(worktree_b.update(&mut cx_b, |worktree, cx| worktree.open_buffer("a.txt", cx)));
        cx_b.update(|_| drop(worktree_b));
        drop(buffer_b);
        worktree_a
            .condition(&cx_a, |tree, _| tree.peers().len() == 0)
            .await;
    }
1510
1511    #[gpui::test]
1512    async fn test_peer_disconnection(mut cx_a: TestAppContext, cx_b: TestAppContext) {
1513        cx_a.foreground().forbid_parking();
1514        let lang_registry = Arc::new(LanguageRegistry::new());
1515
1516        // Connect to a server as 2 clients.
1517        let mut server = TestServer::start().await;
1518        let (client_a, _) = server.create_client(&mut cx_a, "user_a").await;
1519        let (client_b, _) = server.create_client(&mut cx_a, "user_b").await;
1520
1521        // Share a local worktree as client A
1522        let fs = Arc::new(FakeFs::new());
1523        fs.insert_tree(
1524            "/a",
1525            json!({
1526                ".zed.toml": r#"collaborators = ["user_b"]"#,
1527                "a.txt": "a-contents",
1528                "b.txt": "b-contents",
1529            }),
1530        )
1531        .await;
1532        let worktree_a = Worktree::open_local(
1533            client_a.clone(),
1534            "/a".as_ref(),
1535            fs,
1536            lang_registry.clone(),
1537            &mut cx_a.to_async(),
1538        )
1539        .await
1540        .unwrap();
1541        worktree_a
1542            .read_with(&cx_a, |tree, _| tree.as_local().unwrap().scan_complete())
1543            .await;
1544        let worktree_id = worktree_a
1545            .update(&mut cx_a, |tree, cx| tree.as_local_mut().unwrap().share(cx))
1546            .await
1547            .unwrap();
1548
1549        // Join that worktree as client B, and see that a guest has joined as client A.
1550        let _worktree_b = Worktree::open_remote(
1551            client_b.clone(),
1552            worktree_id,
1553            lang_registry.clone(),
1554            &mut cx_b.to_async(),
1555        )
1556        .await
1557        .unwrap();
1558        worktree_a
1559            .condition(&cx_a, |tree, _| tree.peers().len() == 1)
1560            .await;
1561
1562        // Drop client B's connection and ensure client A observes client B leaving the worktree.
1563        client_b.disconnect(&cx_b.to_async()).await.unwrap();
1564        worktree_a
1565            .condition(&cx_a, |tree, _| tree.peers().len() == 0)
1566            .await;
1567    }
1568
1569    #[gpui::test]
1570    async fn test_collaborating_with_diagnostics(
1571        mut cx_a: TestAppContext,
1572        mut cx_b: TestAppContext,
1573    ) {
1574        cx_a.foreground().forbid_parking();
1575        let (language_server_config, mut fake_language_server) =
1576            LanguageServerConfig::fake(cx_a.background()).await;
1577        let mut lang_registry = LanguageRegistry::new();
1578        lang_registry.add(Arc::new(Language::new(
1579            LanguageConfig {
1580                name: "Rust".to_string(),
1581                path_suffixes: vec!["rs".to_string()],
1582                language_server: Some(language_server_config),
1583                ..Default::default()
1584            },
1585            tree_sitter_rust::language(),
1586        )));
1587
1588        let lang_registry = Arc::new(lang_registry);
1589
1590        // Connect to a server as 2 clients.
1591        let mut server = TestServer::start().await;
1592        let (client_a, _) = server.create_client(&mut cx_a, "user_a").await;
1593        let (client_b, _) = server.create_client(&mut cx_a, "user_b").await;
1594
1595        // Share a local worktree as client A
1596        let fs = Arc::new(FakeFs::new());
1597        fs.insert_tree(
1598            "/a",
1599            json!({
1600                ".zed.toml": r#"collaborators = ["user_b"]"#,
1601                "a.rs": "let one = two",
1602                "other.rs": "",
1603            }),
1604        )
1605        .await;
1606        let worktree_a = Worktree::open_local(
1607            client_a.clone(),
1608            "/a".as_ref(),
1609            fs,
1610            lang_registry.clone(),
1611            &mut cx_a.to_async(),
1612        )
1613        .await
1614        .unwrap();
1615        worktree_a
1616            .read_with(&cx_a, |tree, _| tree.as_local().unwrap().scan_complete())
1617            .await;
1618        let worktree_id = worktree_a
1619            .update(&mut cx_a, |tree, cx| tree.as_local_mut().unwrap().share(cx))
1620            .await
1621            .unwrap();
1622
1623        // Cause language server to start.
1624        let _ = cx_a
1625            .background()
1626            .spawn(worktree_a.update(&mut cx_a, |worktree, cx| {
1627                worktree.open_buffer("other.rs", cx)
1628            }))
1629            .await
1630            .unwrap();
1631
1632        // Simulate a language server reporting errors for a file.
1633        fake_language_server
1634            .notify::<lsp::notification::PublishDiagnostics>(lsp::PublishDiagnosticsParams {
1635                uri: lsp::Url::from_file_path("/a/a.rs").unwrap(),
1636                version: None,
1637                diagnostics: vec![
1638                    lsp::Diagnostic {
1639                        severity: Some(lsp::DiagnosticSeverity::ERROR),
1640                        range: lsp::Range::new(lsp::Position::new(0, 4), lsp::Position::new(0, 7)),
1641                        message: "message 1".to_string(),
1642                        ..Default::default()
1643                    },
1644                    lsp::Diagnostic {
1645                        severity: Some(lsp::DiagnosticSeverity::WARNING),
1646                        range: lsp::Range::new(
1647                            lsp::Position::new(0, 10),
1648                            lsp::Position::new(0, 13),
1649                        ),
1650                        message: "message 2".to_string(),
1651                        ..Default::default()
1652                    },
1653                ],
1654            })
1655            .await;
1656
1657        // Join the worktree as client B.
1658        let worktree_b = Worktree::open_remote(
1659            client_b.clone(),
1660            worktree_id,
1661            lang_registry.clone(),
1662            &mut cx_b.to_async(),
1663        )
1664        .await
1665        .unwrap();
1666
1667        // Open the file with the errors.
1668        let buffer_b = cx_b
1669            .background()
1670            .spawn(worktree_b.update(&mut cx_b, |worktree, cx| worktree.open_buffer("a.rs", cx)))
1671            .await
1672            .unwrap();
1673
1674        buffer_b.read_with(&cx_b, |buffer, _| {
1675            assert_eq!(
1676                buffer
1677                    .diagnostics_in_range(0..buffer.len())
1678                    .collect::<Vec<_>>(),
1679                &[
1680                    (
1681                        Point::new(0, 4)..Point::new(0, 7),
1682                        &Diagnostic {
1683                            group_id: 0,
1684                            message: "message 1".to_string(),
1685                            severity: lsp::DiagnosticSeverity::ERROR,
1686                            is_primary: true
1687                        }
1688                    ),
1689                    (
1690                        Point::new(0, 10)..Point::new(0, 13),
1691                        &Diagnostic {
1692                            group_id: 1,
1693                            severity: lsp::DiagnosticSeverity::WARNING,
1694                            message: "message 2".to_string(),
1695                            is_primary: true
1696                        }
1697                    )
1698                ]
1699            );
1700        });
1701    }
1702
    // End-to-end chat between two clients: both see a message seeded directly in the
    // database, messages sent by A reach B, and the server stops tracking the channel
    // once every client has dropped its handle to it.
    #[gpui::test]
    async fn test_basic_chat(mut cx_a: TestAppContext, mut cx_b: TestAppContext) {
        cx_a.foreground().forbid_parking();

        // Connect to a server as 2 clients.
        let mut server = TestServer::start().await;
        let (client_a, user_store_a) = server.create_client(&mut cx_a, "user_a").await;
        let (client_b, user_store_b) = server.create_client(&mut cx_b, "user_b").await;

        // Create an org that includes these 2 users.
        let db = &server.app_state.db;
        let org_id = db.create_org("Test Org", "test-org").await.unwrap();
        db.add_org_member(org_id, current_user_id(&user_store_a, &cx_a), false)
            .await
            .unwrap();
        db.add_org_member(org_id, current_user_id(&user_store_b, &cx_b), false)
            .await
            .unwrap();

        // Create a channel that includes all the users.
        let channel_id = db.create_org_channel(org_id, "test-channel").await.unwrap();
        db.add_channel_member(channel_id, current_user_id(&user_store_a, &cx_a), false)
            .await
            .unwrap();
        db.add_channel_member(channel_id, current_user_id(&user_store_b, &cx_b), false)
            .await
            .unwrap();
        // Seed the channel with one message from B, written straight to the database.
        // NOTE(review): the trailing `1` looks like a per-message nonce — confirm against
        // `Db::create_channel_message`.
        db.create_channel_message(
            channel_id,
            current_user_id(&user_store_b, &cx_b),
            "hello A, it's B.",
            OffsetDateTime::now_utc(),
            1,
        )
        .await
        .unwrap();

        // A's channel list eventually loads and contains exactly the test channel.
        let channels_a = cx_a.add_model(|cx| ChannelList::new(user_store_a, client_a, cx));
        channels_a
            .condition(&mut cx_a, |list, _| list.available_channels().is_some())
            .await;
        channels_a.read_with(&cx_a, |list, _| {
            assert_eq!(
                list.available_channels().unwrap(),
                &[ChannelDetails {
                    id: channel_id.to_proto(),
                    name: "test-channel".to_string()
                }]
            )
        });
        // A freshly opened channel has no messages; the seeded one arrives asynchronously
        // and is not pending (the `false` in the tuple).
        let channel_a = channels_a.update(&mut cx_a, |this, cx| {
            this.get_channel(channel_id.to_proto(), cx).unwrap()
        });
        channel_a.read_with(&cx_a, |channel, _| assert!(channel.messages().is_empty()));
        channel_a
            .condition(&cx_a, |channel, _| {
                channel_messages(channel)
                    == [("user_b".to_string(), "hello A, it's B.".to_string(), false)]
            })
            .await;

        // Same checks from B's side.
        let channels_b = cx_b.add_model(|cx| ChannelList::new(user_store_b, client_b, cx));
        channels_b
            .condition(&mut cx_b, |list, _| list.available_channels().is_some())
            .await;
        channels_b.read_with(&cx_b, |list, _| {
            assert_eq!(
                list.available_channels().unwrap(),
                &[ChannelDetails {
                    id: channel_id.to_proto(),
                    name: "test-channel".to_string()
                }]
            )
        });

        let channel_b = channels_b.update(&mut cx_b, |this, cx| {
            this.get_channel(channel_id.to_proto(), cx).unwrap()
        });
        channel_b.read_with(&cx_b, |channel, _| assert!(channel.messages().is_empty()));
        channel_b
            .condition(&cx_b, |channel, _| {
                channel_messages(channel)
                    == [("user_b".to_string(), "hello A, it's B.".to_string(), false)]
            })
            .await;

        // A sends two messages. Both show up in A's channel immediately, still pending
        // (`true`). The first task is detached; only the second one is awaited.
        channel_a
            .update(&mut cx_a, |channel, cx| {
                channel
                    .send_message("oh, hi B.".to_string(), cx)
                    .unwrap()
                    .detach();
                let task = channel.send_message("sup".to_string(), cx).unwrap();
                assert_eq!(
                    channel_messages(channel),
                    &[
                        ("user_b".to_string(), "hello A, it's B.".to_string(), false),
                        ("user_a".to_string(), "oh, hi B.".to_string(), true),
                        ("user_a".to_string(), "sup".to_string(), true)
                    ]
                );
                task
            })
            .await
            .unwrap();

        // B eventually receives both of A's messages, in order and not pending.
        channel_b
            .condition(&cx_b, |channel, _| {
                channel_messages(channel)
                    == [
                        ("user_b".to_string(), "hello A, it's B.".to_string(), false),
                        ("user_a".to_string(), "oh, hi B.".to_string(), false),
                        ("user_a".to_string(), "sup".to_string(), false),
                    ]
            })
            .await;

        // Both clients have the channel open, so the server records two connections for it.
        assert_eq!(
            server
                .state()
                .await
                .channel(channel_id)
                .unwrap()
                .connection_ids
                .len(),
            2
        );
        // After B drops its handle, only A's connection remains.
        cx_b.update(|_| drop(channel_b));
        server
            .condition(|state| state.channel(channel_id).unwrap().connection_ids.len() == 1)
            .await;

        // After A drops its handle too, the server no longer has the channel at all.
        cx_a.update(|_| drop(channel_a));
        server
            .condition(|state| state.channel(channel_id).is_none())
            .await;
    }
1840
1841    #[gpui::test]
1842    async fn test_chat_message_validation(mut cx_a: TestAppContext) {
1843        cx_a.foreground().forbid_parking();
1844
1845        let mut server = TestServer::start().await;
1846        let (client_a, user_store_a) = server.create_client(&mut cx_a, "user_a").await;
1847
1848        let db = &server.app_state.db;
1849        let org_id = db.create_org("Test Org", "test-org").await.unwrap();
1850        let channel_id = db.create_org_channel(org_id, "test-channel").await.unwrap();
1851        db.add_org_member(org_id, current_user_id(&user_store_a, &cx_a), false)
1852            .await
1853            .unwrap();
1854        db.add_channel_member(channel_id, current_user_id(&user_store_a, &cx_a), false)
1855            .await
1856            .unwrap();
1857
1858        let channels_a = cx_a.add_model(|cx| ChannelList::new(user_store_a, client_a, cx));
1859        channels_a
1860            .condition(&mut cx_a, |list, _| list.available_channels().is_some())
1861            .await;
1862        let channel_a = channels_a.update(&mut cx_a, |this, cx| {
1863            this.get_channel(channel_id.to_proto(), cx).unwrap()
1864        });
1865
1866        // Messages aren't allowed to be too long.
1867        channel_a
1868            .update(&mut cx_a, |channel, cx| {
1869                let long_body = "this is long.\n".repeat(1024);
1870                channel.send_message(long_body, cx).unwrap()
1871            })
1872            .await
1873            .unwrap_err();
1874
1875        // Messages aren't allowed to be blank.
1876        channel_a.update(&mut cx_a, |channel, cx| {
1877            channel.send_message(String::new(), cx).unwrap_err()
1878        });
1879
1880        // Leading and trailing whitespace are trimmed.
1881        channel_a
1882            .update(&mut cx_a, |channel, cx| {
1883                channel
1884                    .send_message("\n surrounded by whitespace  \n".to_string(), cx)
1885                    .unwrap()
1886            })
1887            .await
1888            .unwrap();
1889        assert_eq!(
1890            db.get_channel_messages(channel_id, 10, None)
1891                .await
1892                .unwrap()
1893                .iter()
1894                .map(|m| &m.body)
1895                .collect::<Vec<_>>(),
1896            &["surrounded by whitespace"]
1897        );
1898    }
1899
    // Verifies chat across a disconnection of client B:
    // - B's cached channel list and messages stay readable while it is offline;
    // - a message B sends while offline fails, yet appears (not pending) once B reconnects;
    // - B catches up on A's messages sent in the meantime;
    // - both clients can exchange messages normally afterwards.
    #[gpui::test]
    async fn test_chat_reconnection(mut cx_a: TestAppContext, mut cx_b: TestAppContext) {
        cx_a.foreground().forbid_parking();

        // Connect to a server as 2 clients.
        let mut server = TestServer::start().await;
        let (client_a, user_store_a) = server.create_client(&mut cx_a, "user_a").await;
        let (client_b, user_store_b) = server.create_client(&mut cx_b, "user_b").await;
        // Watch B's connection status so the test can wait for a failed reconnection below.
        let mut status_b = client_b.status();

        // Create an org that includes these 2 users.
        let db = &server.app_state.db;
        let org_id = db.create_org("Test Org", "test-org").await.unwrap();
        db.add_org_member(org_id, current_user_id(&user_store_a, &cx_a), false)
            .await
            .unwrap();
        db.add_org_member(org_id, current_user_id(&user_store_b, &cx_b), false)
            .await
            .unwrap();

        // Create a channel that includes all the users.
        let channel_id = db.create_org_channel(org_id, "test-channel").await.unwrap();
        db.add_channel_member(channel_id, current_user_id(&user_store_a, &cx_a), false)
            .await
            .unwrap();
        db.add_channel_member(channel_id, current_user_id(&user_store_b, &cx_b), false)
            .await
            .unwrap();
        // Seed one message from B directly in the database.
        // NOTE(review): the trailing `2` looks like a per-message nonce — confirm against
        // `Db::create_channel_message`.
        db.create_channel_message(
            channel_id,
            current_user_id(&user_store_b, &cx_b),
            "hello A, it's B.",
            OffsetDateTime::now_utc(),
            2,
        )
        .await
        .unwrap();

        // A loads its channel list, opens the channel and sees the seeded message.
        let channels_a = cx_a.add_model(|cx| ChannelList::new(user_store_a, client_a, cx));
        channels_a
            .condition(&mut cx_a, |list, _| list.available_channels().is_some())
            .await;

        channels_a.read_with(&cx_a, |list, _| {
            assert_eq!(
                list.available_channels().unwrap(),
                &[ChannelDetails {
                    id: channel_id.to_proto(),
                    name: "test-channel".to_string()
                }]
            )
        });
        let channel_a = channels_a.update(&mut cx_a, |this, cx| {
            this.get_channel(channel_id.to_proto(), cx).unwrap()
        });
        channel_a.read_with(&cx_a, |channel, _| assert!(channel.messages().is_empty()));
        channel_a
            .condition(&cx_a, |channel, _| {
                channel_messages(channel)
                    == [("user_b".to_string(), "hello A, it's B.".to_string(), false)]
            })
            .await;

        // B does the same; `user_store_b` is cloned because it is used again below.
        let channels_b = cx_b.add_model(|cx| ChannelList::new(user_store_b.clone(), client_b, cx));
        channels_b
            .condition(&mut cx_b, |list, _| list.available_channels().is_some())
            .await;
        channels_b.read_with(&cx_b, |list, _| {
            assert_eq!(
                list.available_channels().unwrap(),
                &[ChannelDetails {
                    id: channel_id.to_proto(),
                    name: "test-channel".to_string()
                }]
            )
        });

        let channel_b = channels_b.update(&mut cx_b, |this, cx| {
            this.get_channel(channel_id.to_proto(), cx).unwrap()
        });
        channel_b.read_with(&cx_b, |channel, _| assert!(channel.messages().is_empty()));
        channel_b
            .condition(&cx_b, |channel, _| {
                channel_messages(channel)
                    == [("user_b".to_string(), "hello A, it's B.".to_string(), false)]
            })
            .await;

        // Disconnect client B, ensuring we can still access its cached channel data.
        // Connections are forbidden first, so B's reconnection attempt fails; the loop
        // drains status updates until B reports that failure.
        server.forbid_connections();
        server.disconnect_client(current_user_id(&user_store_b, &cx_b));
        while !matches!(
            status_b.recv().await,
            Some(client::Status::ReconnectionError { .. })
        ) {}

        channels_b.read_with(&cx_b, |channels, _| {
            assert_eq!(
                channels.available_channels().unwrap(),
                [ChannelDetails {
                    id: channel_id.to_proto(),
                    name: "test-channel".to_string()
                }]
            )
        });
        channel_b.read_with(&cx_b, |channel, _| {
            assert_eq!(
                channel_messages(channel),
                [("user_b".to_string(), "hello A, it's B.".to_string(), false)]
            )
        });

        // Send a message from client B while it is disconnected.
        // It is shown locally as pending (`true`), but the send task itself fails.
        channel_b
            .update(&mut cx_b, |channel, cx| {
                let task = channel
                    .send_message("can you see this?".to_string(), cx)
                    .unwrap();
                assert_eq!(
                    channel_messages(channel),
                    &[
                        ("user_b".to_string(), "hello A, it's B.".to_string(), false),
                        ("user_b".to_string(), "can you see this?".to_string(), true)
                    ]
                );
                task
            })
            .await
            .unwrap_err();

        // Send a message from client A while B is disconnected.
        channel_a
            .update(&mut cx_a, |channel, cx| {
                channel
                    .send_message("oh, hi B.".to_string(), cx)
                    .unwrap()
                    .detach();
                let task = channel.send_message("sup".to_string(), cx).unwrap();
                assert_eq!(
                    channel_messages(channel),
                    &[
                        ("user_b".to_string(), "hello A, it's B.".to_string(), false),
                        ("user_a".to_string(), "oh, hi B.".to_string(), true),
                        ("user_a".to_string(), "sup".to_string(), true)
                    ]
                );
                task
            })
            .await
            .unwrap();

        // Give client B a chance to reconnect.
        // NOTE(review): assumes 10 seconds of fake time exceeds B's reconnection delay.
        server.allow_connections();
        cx_b.foreground().advance_clock(Duration::from_secs(10));

        // Verify that B sees the new messages upon reconnection, as well as the message client B
        // sent while offline.
        channel_b
            .condition(&cx_b, |channel, _| {
                channel_messages(channel)
                    == [
                        ("user_b".to_string(), "hello A, it's B.".to_string(), false),
                        ("user_a".to_string(), "oh, hi B.".to_string(), false),
                        ("user_a".to_string(), "sup".to_string(), false),
                        ("user_b".to_string(), "can you see this?".to_string(), false),
                    ]
            })
            .await;

        // Ensure client A and B can communicate normally after reconnection.
        channel_a
            .update(&mut cx_a, |channel, cx| {
                channel.send_message("you online?".to_string(), cx).unwrap()
            })
            .await
            .unwrap();
        channel_b
            .condition(&cx_b, |channel, _| {
                channel_messages(channel)
                    == [
                        ("user_b".to_string(), "hello A, it's B.".to_string(), false),
                        ("user_a".to_string(), "oh, hi B.".to_string(), false),
                        ("user_a".to_string(), "sup".to_string(), false),
                        ("user_b".to_string(), "can you see this?".to_string(), false),
                        ("user_a".to_string(), "you online?".to_string(), false),
                    ]
            })
            .await;

        channel_b
            .update(&mut cx_b, |channel, cx| {
                channel.send_message("yep".to_string(), cx).unwrap()
            })
            .await
            .unwrap();
        // A's view also contains B's message that was sent while B was offline.
        channel_a
            .condition(&cx_a, |channel, _| {
                channel_messages(channel)
                    == [
                        ("user_b".to_string(), "hello A, it's B.".to_string(), false),
                        ("user_a".to_string(), "oh, hi B.".to_string(), false),
                        ("user_a".to_string(), "sup".to_string(), false),
                        ("user_b".to_string(), "can you see this?".to_string(), false),
                        ("user_a".to_string(), "you online?".to_string(), false),
                        ("user_b".to_string(), "yep".to_string(), false),
                    ]
            })
            .await;
    }
2109
    // Checks the collaborator list every user sees as A opens a local worktree whose
    // `.zed.toml` names B and C as collaborators, shares it, lets B join as a guest,
    // and finally drops it.
    #[gpui::test]
    async fn test_collaborators(
        mut cx_a: TestAppContext,
        mut cx_b: TestAppContext,
        mut cx_c: TestAppContext,
    ) {
        cx_a.foreground().forbid_parking();
        let lang_registry = Arc::new(LanguageRegistry::new());

        // Connect to a server as 3 clients.
        // C never opens the worktree; it only observes collaborator updates.
        let mut server = TestServer::start().await;
        let (client_a, user_store_a) = server.create_client(&mut cx_a, "user_a").await;
        let (client_b, user_store_b) = server.create_client(&mut cx_b, "user_b").await;
        let (_client_c, user_store_c) = server.create_client(&mut cx_c, "user_c").await;

        let fs = Arc::new(FakeFs::new());

        // Share a worktree as client A.
        fs.insert_tree(
            "/a",
            json!({
                ".zed.toml": r#"collaborators = ["user_b", "user_c"]"#,
            }),
        )
        .await;

        let worktree_a = Worktree::open_local(
            client_a.clone(),
            "/a".as_ref(),
            fs.clone(),
            lang_registry.clone(),
            &mut cx_a.to_async(),
        )
        .await
        .unwrap();

        // Once the worktree is open, all three users see A with worktree "a" and no guests.
        user_store_a
            .condition(&cx_a, |user_store, _| {
                collaborators(user_store) == vec![("user_a", vec![("a", vec![])])]
            })
            .await;
        user_store_b
            .condition(&cx_b, |user_store, _| {
                collaborators(user_store) == vec![("user_a", vec![("a", vec![])])]
            })
            .await;
        user_store_c
            .condition(&cx_c, |user_store, _| {
                collaborators(user_store) == vec![("user_a", vec![("a", vec![])])]
            })
            .await;

        let worktree_id = worktree_a
            .update(&mut cx_a, |tree, cx| tree.as_local_mut().unwrap().share(cx))
            .await
            .unwrap();

        let _worktree_b = Worktree::open_remote(
            client_b.clone(),
            worktree_id,
            lang_registry.clone(),
            &mut cx_b.to_async(),
        )
        .await
        .unwrap();

        // After B joins the shared worktree, every user sees B listed as a guest of "a".
        user_store_a
            .condition(&cx_a, |user_store, _| {
                collaborators(user_store) == vec![("user_a", vec![("a", vec!["user_b"])])]
            })
            .await;
        user_store_b
            .condition(&cx_b, |user_store, _| {
                collaborators(user_store) == vec![("user_a", vec![("a", vec!["user_b"])])]
            })
            .await;
        user_store_c
            .condition(&cx_c, |user_store, _| {
                collaborators(user_store) == vec![("user_a", vec![("a", vec!["user_b"])])]
            })
            .await;

        // Dropping A's worktree empties everyone's collaborator list.
        cx_a.update(move |_| drop(worktree_a));
        user_store_a
            .condition(&cx_a, |user_store, _| collaborators(user_store) == vec![])
            .await;
        user_store_b
            .condition(&cx_b, |user_store, _| collaborators(user_store) == vec![])
            .await;
        user_store_c
            .condition(&cx_c, |user_store, _| collaborators(user_store) == vec![])
            .await;

        // Flattens the store's collaborators into
        // `(login, [(worktree root name, [guest logins])])` for easy comparison.
        fn collaborators(user_store: &UserStore) -> Vec<(&str, Vec<(&str, Vec<&str>)>)> {
            user_store
                .collaborators()
                .iter()
                .map(|collaborator| {
                    let worktrees = collaborator
                        .worktrees
                        .iter()
                        .map(|w| {
                            (
                                w.root_name.as_str(),
                                w.guests.iter().map(|p| p.github_login.as_str()).collect(),
                            )
                        })
                        .collect();
                    (collaborator.user.github_login.as_str(), worktrees)
                })
                .collect()
        }
    }
2223
    /// An in-process collaboration [`Server`] plus the hooks tests use to drive it.
    ///
    /// Fields are dropped in declaration order after `Drop::drop` runs, so `_test_db`
    /// is dropped last. NOTE(review): presumably deliberate so the database outlives
    /// everything that uses it — keep it last when adding fields.
    struct TestServer {
        // Peer shared with `server`; reset when the test server is dropped.
        peer: Arc<Peer>,
        app_state: Arc<AppState>,
        server: Arc<Server>,
        // Receiving end of the notification channel handed to `Server::new`;
        // `condition` waits on it between predicate checks.
        notifications: mpsc::Receiver<()>,
        // Per-user kill switch for that user's in-memory connection, used by
        // `disconnect_client`.
        connection_killers: Arc<Mutex<HashMap<UserId, watch::Sender<Option<()>>>>>,
        // While set, connection attempts from test clients fail.
        forbid_connections: Arc<AtomicBool>,
        // Owned here so the test database lives as long as the `TestServer`.
        _test_db: TestDb,
    }
2233
2234    impl TestServer {
2235        async fn start() -> Self {
2236            let test_db = TestDb::new();
2237            let app_state = Self::build_app_state(&test_db).await;
2238            let peer = Peer::new();
2239            let notifications = mpsc::channel(128);
2240            let server = Server::new(app_state.clone(), peer.clone(), Some(notifications.0));
2241            Self {
2242                peer,
2243                app_state,
2244                server,
2245                notifications: notifications.1,
2246                connection_killers: Default::default(),
2247                forbid_connections: Default::default(),
2248                _test_db: test_db,
2249            }
2250        }
2251
2252        async fn create_client(
2253            &mut self,
2254            cx: &mut TestAppContext,
2255            name: &str,
2256        ) -> (Arc<Client>, ModelHandle<UserStore>) {
2257            let user_id = self.app_state.db.create_user(name, false).await.unwrap();
2258            let client_name = name.to_string();
2259            let mut client = Client::new();
2260            let server = self.server.clone();
2261            let connection_killers = self.connection_killers.clone();
2262            let forbid_connections = self.forbid_connections.clone();
2263            Arc::get_mut(&mut client)
2264                .unwrap()
2265                .override_authenticate(move |cx| {
2266                    cx.spawn(|_| async move {
2267                        let access_token = "the-token".to_string();
2268                        Ok(Credentials {
2269                            user_id: user_id.0 as u64,
2270                            access_token,
2271                        })
2272                    })
2273                })
2274                .override_establish_connection(move |credentials, cx| {
2275                    assert_eq!(credentials.user_id, user_id.0 as u64);
2276                    assert_eq!(credentials.access_token, "the-token");
2277
2278                    let server = server.clone();
2279                    let connection_killers = connection_killers.clone();
2280                    let forbid_connections = forbid_connections.clone();
2281                    let client_name = client_name.clone();
2282                    cx.spawn(move |cx| async move {
2283                        if forbid_connections.load(SeqCst) {
2284                            Err(EstablishConnectionError::other(anyhow!(
2285                                "server is forbidding connections"
2286                            )))
2287                        } else {
2288                            let (client_conn, server_conn, kill_conn) = Connection::in_memory();
2289                            connection_killers.lock().insert(user_id, kill_conn);
2290                            cx.background()
2291                                .spawn(server.handle_connection(server_conn, client_name, user_id))
2292                                .detach();
2293                            Ok(client_conn)
2294                        }
2295                    })
2296                });
2297
2298            let http = FakeHttpClient::new(|_| async move { Ok(surf::http::Response::new(404)) });
2299            client
2300                .authenticate_and_connect(&cx.to_async())
2301                .await
2302                .unwrap();
2303
2304            let user_store = cx.add_model(|cx| UserStore::new(client.clone(), http, cx));
2305            let mut authed_user =
2306                user_store.read_with(cx, |user_store, _| user_store.watch_current_user());
2307            while authed_user.recv().await.unwrap().is_none() {}
2308
2309            (client, user_store)
2310        }
2311
2312        fn disconnect_client(&self, user_id: UserId) {
2313            if let Some(mut kill_conn) = self.connection_killers.lock().remove(&user_id) {
2314                let _ = kill_conn.try_send(Some(()));
2315            }
2316        }
2317
2318        fn forbid_connections(&self) {
2319            self.forbid_connections.store(true, SeqCst);
2320        }
2321
2322        fn allow_connections(&self) {
2323            self.forbid_connections.store(false, SeqCst);
2324        }
2325
2326        async fn build_app_state(test_db: &TestDb) -> Arc<AppState> {
2327            let mut config = Config::default();
2328            config.session_secret = "a".repeat(32);
2329            config.database_url = test_db.url.clone();
2330            let github_client = github::AppClient::test();
2331            Arc::new(AppState {
2332                db: test_db.db().clone(),
2333                handlebars: Default::default(),
2334                auth_client: auth::build_client("", ""),
2335                repo_client: github::RepoClient::test(&github_client),
2336                github_client,
2337                config,
2338            })
2339        }
2340
2341        async fn state<'a>(&'a self) -> RwLockReadGuard<'a, Store> {
2342            self.server.store.read()
2343        }
2344
2345        async fn condition<F>(&mut self, mut predicate: F)
2346        where
2347            F: FnMut(&Store) -> bool,
2348        {
2349            async_std::future::timeout(Duration::from_millis(500), async {
2350                while !(predicate)(&*self.server.store.read()) {
2351                    self.notifications.recv().await;
2352                }
2353            })
2354            .await
2355            .expect("condition timed out");
2356        }
2357    }
2358
2359    impl Drop for TestServer {
2360        fn drop(&mut self) {
2361            task::block_on(self.peer.reset());
2362        }
2363    }
2364
2365    fn current_user_id(user_store: &ModelHandle<UserStore>, cx: &TestAppContext) -> UserId {
2366        UserId::from_proto(
2367            user_store.read_with(cx, |user_store, _| user_store.current_user().unwrap().id),
2368        )
2369    }
2370
2371    fn channel_messages(channel: &Channel) -> Vec<(String, String, bool)> {
2372        channel
2373            .messages()
2374            .cursor::<()>()
2375            .map(|m| {
2376                (
2377                    m.sender.github_login.clone(),
2378                    m.body.clone(),
2379                    m.is_pending(),
2380                )
2381            })
2382            .collect()
2383    }
2384
    /// A stateless gpui view that renders an empty element.
    struct EmptyView;
2386
    // `EmptyView` declares no events, hence the unit event type.
    impl gpui::Entity for EmptyView {
        type Event = ();
    }
2390
2391    impl gpui::View for EmptyView {
2392        fn ui_name() -> &'static str {
2393            "empty view"
2394        }
2395
2396        fn render(&mut self, _: &mut gpui::RenderContext<Self>) -> gpui::ElementBox {
2397            gpui::Element::boxed(gpui::elements::Empty)
2398        }
2399    }
2400}