rpc.rs

   1mod connection_pool;
   2
   3use crate::{
   4    auth,
   5    db::{
   6        self, ChannelId, ChannelsForUser, Database, MessageId, ProjectId, RoomId, ServerId, User,
   7        UserId,
   8    },
   9    executor::Executor,
  10    AppState, Result,
  11};
  12use anyhow::anyhow;
  13use async_tungstenite::tungstenite::{
  14    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  15};
  16use axum::{
  17    body::Body,
  18    extract::{
  19        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  20        ConnectInfo, WebSocketUpgrade,
  21    },
  22    headers::{Header, HeaderName},
  23    http::StatusCode,
  24    middleware,
  25    response::IntoResponse,
  26    routing::get,
  27    Extension, Router, TypedHeader,
  28};
  29use collections::{HashMap, HashSet};
  30pub use connection_pool::ConnectionPool;
  31use futures::{
  32    channel::oneshot,
  33    future::{self, BoxFuture},
  34    stream::FuturesUnordered,
  35    FutureExt, SinkExt, StreamExt, TryStreamExt,
  36};
  37use lazy_static::lazy_static;
  38use prometheus::{register_int_gauge, IntGauge};
  39use rpc::{
  40    proto::{
  41        self, Ack, AddChannelBufferCollaborator, AnyTypedEnvelope, EntityMessage, EnvelopedMessage,
  42        LiveKitConnectionInfo, RequestMessage,
  43    },
  44    Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
  45};
  46use serde::{Serialize, Serializer};
  47use std::{
  48    any::TypeId,
  49    fmt,
  50    future::Future,
  51    marker::PhantomData,
  52    mem,
  53    net::SocketAddr,
  54    ops::{Deref, DerefMut},
  55    rc::Rc,
  56    sync::{
  57        atomic::{AtomicBool, Ordering::SeqCst},
  58        Arc,
  59    },
  60    time::{Duration, Instant},
  61};
  62use time::OffsetDateTime;
  63use tokio::sync::{watch, Semaphore};
  64use tower::ServiceBuilder;
  65use tracing::{info_span, instrument, Instrument};
  66
  67pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  68pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10);
  69
  70const MESSAGE_COUNT_PER_PAGE: usize = 100;
  71const MAX_MESSAGE_LEN: usize = 1024;
  72
  73lazy_static! {
  74    static ref METRIC_CONNECTIONS: IntGauge =
  75        register_int_gauge!("connections", "number of connections").unwrap();
  76    static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
  77        "shared_projects",
  78        "number of open projects with one or more guests"
  79    )
  80    .unwrap();
  81}
  82
  83type MessageHandler =
  84    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  85
  86struct Response<R> {
  87    peer: Arc<Peer>,
  88    receipt: Receipt<R>,
  89    responded: Arc<AtomicBool>,
  90}
  91
  92impl<R: RequestMessage> Response<R> {
  93    fn send(self, payload: R::Response) -> Result<()> {
  94        self.responded.store(true, SeqCst);
  95        self.peer.respond(self.receipt, payload)?;
  96        Ok(())
  97    }
  98}
  99
 100#[derive(Clone)]
 101struct Session {
 102    user_id: UserId,
 103    connection_id: ConnectionId,
 104    db: Arc<tokio::sync::Mutex<DbHandle>>,
 105    peer: Arc<Peer>,
 106    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 107    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
 108    executor: Executor,
 109}
 110
 111impl Session {
 112    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 113        #[cfg(test)]
 114        tokio::task::yield_now().await;
 115        let guard = self.db.lock().await;
 116        #[cfg(test)]
 117        tokio::task::yield_now().await;
 118        guard
 119    }
 120
 121    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 122        #[cfg(test)]
 123        tokio::task::yield_now().await;
 124        let guard = self.connection_pool.lock();
 125        ConnectionPoolGuard {
 126            guard,
 127            _not_send: PhantomData,
 128        }
 129    }
 130}
 131
 132impl fmt::Debug for Session {
 133    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 134        f.debug_struct("Session")
 135            .field("user_id", &self.user_id)
 136            .field("connection_id", &self.connection_id)
 137            .finish()
 138    }
 139}
 140
 141struct DbHandle(Arc<Database>);
 142
 143impl Deref for DbHandle {
 144    type Target = Database;
 145
 146    fn deref(&self) -> &Self::Target {
 147        self.0.as_ref()
 148    }
 149}
 150
 151pub struct Server {
 152    id: parking_lot::Mutex<ServerId>,
 153    peer: Arc<Peer>,
 154    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 155    app_state: Arc<AppState>,
 156    executor: Executor,
 157    handlers: HashMap<TypeId, MessageHandler>,
 158    teardown: watch::Sender<()>,
 159}
 160
 161pub(crate) struct ConnectionPoolGuard<'a> {
 162    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
 163    _not_send: PhantomData<Rc<()>>,
 164}
 165
 166#[derive(Serialize)]
 167pub struct ServerSnapshot<'a> {
 168    peer: &'a Peer,
 169    #[serde(serialize_with = "serialize_deref")]
 170    connection_pool: ConnectionPoolGuard<'a>,
 171}
 172
 173pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 174where
 175    S: Serializer,
 176    T: Deref<Target = U>,
 177    U: Serialize,
 178{
 179    Serialize::serialize(value.deref(), serializer)
 180}
 181
 182impl Server {
 183    pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
 184        let mut server = Self {
 185            id: parking_lot::Mutex::new(id),
 186            peer: Peer::new(id.0 as u32),
 187            app_state,
 188            executor,
 189            connection_pool: Default::default(),
 190            handlers: Default::default(),
 191            teardown: watch::channel(()).0,
 192        };
 193
 194        server
 195            .add_request_handler(ping)
 196            .add_request_handler(create_room)
 197            .add_request_handler(join_room)
 198            .add_request_handler(rejoin_room)
 199            .add_request_handler(leave_room)
 200            .add_request_handler(call)
 201            .add_request_handler(cancel_call)
 202            .add_message_handler(decline_call)
 203            .add_request_handler(update_participant_location)
 204            .add_request_handler(share_project)
 205            .add_message_handler(unshare_project)
 206            .add_request_handler(join_project)
 207            .add_message_handler(leave_project)
 208            .add_request_handler(update_project)
 209            .add_request_handler(update_worktree)
 210            .add_message_handler(start_language_server)
 211            .add_message_handler(update_language_server)
 212            .add_message_handler(update_diagnostic_summary)
 213            .add_message_handler(update_worktree_settings)
 214            .add_message_handler(refresh_inlay_hints)
 215            .add_request_handler(forward_project_request::<proto::GetHover>)
 216            .add_request_handler(forward_project_request::<proto::GetDefinition>)
 217            .add_request_handler(forward_project_request::<proto::GetTypeDefinition>)
 218            .add_request_handler(forward_project_request::<proto::GetReferences>)
 219            .add_request_handler(forward_project_request::<proto::SearchProject>)
 220            .add_request_handler(forward_project_request::<proto::GetDocumentHighlights>)
 221            .add_request_handler(forward_project_request::<proto::GetProjectSymbols>)
 222            .add_request_handler(forward_project_request::<proto::OpenBufferForSymbol>)
 223            .add_request_handler(forward_project_request::<proto::OpenBufferById>)
 224            .add_request_handler(forward_project_request::<proto::OpenBufferByPath>)
 225            .add_request_handler(forward_project_request::<proto::GetCompletions>)
 226            .add_request_handler(forward_project_request::<proto::ApplyCompletionAdditionalEdits>)
 227            .add_request_handler(forward_project_request::<proto::GetCodeActions>)
 228            .add_request_handler(forward_project_request::<proto::ApplyCodeAction>)
 229            .add_request_handler(forward_project_request::<proto::PrepareRename>)
 230            .add_request_handler(forward_project_request::<proto::PerformRename>)
 231            .add_request_handler(forward_project_request::<proto::ReloadBuffers>)
 232            .add_request_handler(forward_project_request::<proto::SynchronizeBuffers>)
 233            .add_request_handler(forward_project_request::<proto::FormatBuffers>)
 234            .add_request_handler(forward_project_request::<proto::CreateProjectEntry>)
 235            .add_request_handler(forward_project_request::<proto::RenameProjectEntry>)
 236            .add_request_handler(forward_project_request::<proto::CopyProjectEntry>)
 237            .add_request_handler(forward_project_request::<proto::DeleteProjectEntry>)
 238            .add_request_handler(forward_project_request::<proto::ExpandProjectEntry>)
 239            .add_request_handler(forward_project_request::<proto::OnTypeFormatting>)
 240            .add_request_handler(forward_project_request::<proto::InlayHints>)
 241            .add_message_handler(create_buffer_for_peer)
 242            .add_request_handler(update_buffer)
 243            .add_message_handler(update_buffer_file)
 244            .add_message_handler(buffer_reloaded)
 245            .add_message_handler(buffer_saved)
 246            .add_request_handler(forward_project_request::<proto::SaveBuffer>)
 247            .add_request_handler(get_users)
 248            .add_request_handler(fuzzy_search_users)
 249            .add_request_handler(request_contact)
 250            .add_request_handler(remove_contact)
 251            .add_request_handler(respond_to_contact_request)
 252            .add_request_handler(create_channel)
 253            .add_request_handler(remove_channel)
 254            .add_request_handler(invite_channel_member)
 255            .add_request_handler(remove_channel_member)
 256            .add_request_handler(set_channel_member_admin)
 257            .add_request_handler(rename_channel)
 258            .add_request_handler(join_channel_buffer)
 259            .add_request_handler(leave_channel_buffer)
 260            .add_message_handler(update_channel_buffer)
 261            .add_request_handler(rejoin_channel_buffers)
 262            .add_request_handler(get_channel_members)
 263            .add_request_handler(respond_to_channel_invite)
 264            .add_request_handler(join_channel)
 265            .add_request_handler(join_channel_chat)
 266            .add_message_handler(leave_channel_chat)
 267            .add_request_handler(send_channel_message)
 268            .add_request_handler(get_channel_messages)
 269            .add_request_handler(follow)
 270            .add_message_handler(unfollow)
 271            .add_message_handler(update_followers)
 272            .add_message_handler(update_diff_base)
 273            .add_request_handler(get_private_user_info);
 274
 275        Arc::new(server)
 276    }
 277
 278    pub async fn start(&self) -> Result<()> {
 279        let server_id = *self.id.lock();
 280        let app_state = self.app_state.clone();
 281        let peer = self.peer.clone();
 282        let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
 283        let pool = self.connection_pool.clone();
 284        let live_kit_client = self.app_state.live_kit_client.clone();
 285
 286        let span = info_span!("start server");
 287        self.executor.spawn_detached(
 288            async move {
 289                tracing::info!("waiting for cleanup timeout");
 290                timeout.await;
 291                tracing::info!("cleanup timeout expired, retrieving stale rooms");
 292                if let Some((room_ids, channel_ids)) = app_state
 293                    .db
 294                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
 295                    .await
 296                    .trace_err()
 297                {
 298                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
 299                    tracing::info!(
 300                        stale_channel_buffer_count = channel_ids.len(),
 301                        "retrieved stale channel buffers"
 302                    );
 303
 304                    for channel_id in channel_ids {
 305                        if let Some(refreshed_channel_buffer) = app_state
 306                            .db
 307                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
 308                            .await
 309                            .trace_err()
 310                        {
 311                            for connection_id in refreshed_channel_buffer.connection_ids {
 312                                for message in &refreshed_channel_buffer.removed_collaborators {
 313                                    peer.send(connection_id, message.clone()).trace_err();
 314                                }
 315                            }
 316                        }
 317                    }
 318
 319                    for room_id in room_ids {
 320                        let mut contacts_to_update = HashSet::default();
 321                        let mut canceled_calls_to_user_ids = Vec::new();
 322                        let mut live_kit_room = String::new();
 323                        let mut delete_live_kit_room = false;
 324
 325                        if let Some(mut refreshed_room) = app_state
 326                            .db
 327                            .clear_stale_room_participants(room_id, server_id)
 328                            .await
 329                            .trace_err()
 330                        {
 331                            tracing::info!(
 332                                room_id = room_id.0,
 333                                new_participant_count = refreshed_room.room.participants.len(),
 334                                "refreshed room"
 335                            );
 336                            room_updated(&refreshed_room.room, &peer);
 337                            if let Some(channel_id) = refreshed_room.channel_id {
 338                                channel_updated(
 339                                    channel_id,
 340                                    &refreshed_room.room,
 341                                    &refreshed_room.channel_members,
 342                                    &peer,
 343                                    &*pool.lock(),
 344                                );
 345                            }
 346                            contacts_to_update
 347                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
 348                            contacts_to_update
 349                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
 350                            canceled_calls_to_user_ids =
 351                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
 352                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
 353                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
 354                        }
 355
 356                        {
 357                            let pool = pool.lock();
 358                            for canceled_user_id in canceled_calls_to_user_ids {
 359                                for connection_id in pool.user_connection_ids(canceled_user_id) {
 360                                    peer.send(
 361                                        connection_id,
 362                                        proto::CallCanceled {
 363                                            room_id: room_id.to_proto(),
 364                                        },
 365                                    )
 366                                    .trace_err();
 367                                }
 368                            }
 369                        }
 370
 371                        for user_id in contacts_to_update {
 372                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
 373                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
 374                            if let Some((busy, contacts)) = busy.zip(contacts) {
 375                                let pool = pool.lock();
 376                                let updated_contact = contact_for_user(user_id, false, busy, &pool);
 377                                for contact in contacts {
 378                                    if let db::Contact::Accepted {
 379                                        user_id: contact_user_id,
 380                                        ..
 381                                    } = contact
 382                                    {
 383                                        for contact_conn_id in
 384                                            pool.user_connection_ids(contact_user_id)
 385                                        {
 386                                            peer.send(
 387                                                contact_conn_id,
 388                                                proto::UpdateContacts {
 389                                                    contacts: vec![updated_contact.clone()],
 390                                                    remove_contacts: Default::default(),
 391                                                    incoming_requests: Default::default(),
 392                                                    remove_incoming_requests: Default::default(),
 393                                                    outgoing_requests: Default::default(),
 394                                                    remove_outgoing_requests: Default::default(),
 395                                                },
 396                                            )
 397                                            .trace_err();
 398                                        }
 399                                    }
 400                                }
 401                            }
 402                        }
 403
 404                        if let Some(live_kit) = live_kit_client.as_ref() {
 405                            if delete_live_kit_room {
 406                                live_kit.delete_room(live_kit_room).await.trace_err();
 407                            }
 408                        }
 409                    }
 410                }
 411
 412                app_state
 413                    .db
 414                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
 415                    .await
 416                    .trace_err();
 417            }
 418            .instrument(span),
 419        );
 420        Ok(())
 421    }
 422
 423    pub fn teardown(&self) {
 424        self.peer.teardown();
 425        self.connection_pool.lock().reset();
 426        let _ = self.teardown.send(());
 427    }
 428
 429    #[cfg(test)]
 430    pub fn reset(&self, id: ServerId) {
 431        self.teardown();
 432        *self.id.lock() = id;
 433        self.peer.reset(id.0 as u32);
 434    }
 435
 436    #[cfg(test)]
 437    pub fn id(&self) -> ServerId {
 438        *self.id.lock()
 439    }
 440
 441    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 442    where
 443        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 444        Fut: 'static + Send + Future<Output = Result<()>>,
 445        M: EnvelopedMessage,
 446    {
 447        let prev_handler = self.handlers.insert(
 448            TypeId::of::<M>(),
 449            Box::new(move |envelope, session| {
 450                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 451                let span = info_span!(
 452                    "handle message",
 453                    payload_type = envelope.payload_type_name()
 454                );
 455                span.in_scope(|| {
 456                    tracing::info!(
 457                        payload_type = envelope.payload_type_name(),
 458                        "message received"
 459                    );
 460                });
 461                let start_time = Instant::now();
 462                let future = (handler)(*envelope, session);
 463                async move {
 464                    let result = future.await;
 465                    let duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 466                    match result {
 467                        Err(error) => {
 468                            tracing::error!(%error, ?duration_ms, "error handling message")
 469                        }
 470                        Ok(()) => tracing::info!(?duration_ms, "finished handling message"),
 471                    }
 472                }
 473                .instrument(span)
 474                .boxed()
 475            }),
 476        );
 477        if prev_handler.is_some() {
 478            panic!("registered a handler for the same message twice");
 479        }
 480        self
 481    }
 482
 483    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 484    where
 485        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 486        Fut: 'static + Send + Future<Output = Result<()>>,
 487        M: EnvelopedMessage,
 488    {
 489        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 490        self
 491    }
 492
 493    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 494    where
 495        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 496        Fut: Send + Future<Output = Result<()>>,
 497        M: RequestMessage,
 498    {
 499        let handler = Arc::new(handler);
 500        self.add_handler(move |envelope, session| {
 501            let receipt = envelope.receipt();
 502            let handler = handler.clone();
 503            async move {
 504                let peer = session.peer.clone();
 505                let responded = Arc::new(AtomicBool::default());
 506                let response = Response {
 507                    peer: peer.clone(),
 508                    responded: responded.clone(),
 509                    receipt,
 510                };
 511                match (handler)(envelope.payload, response, session).await {
 512                    Ok(()) => {
 513                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 514                            Ok(())
 515                        } else {
 516                            Err(anyhow!("handler did not send a response"))?
 517                        }
 518                    }
 519                    Err(error) => {
 520                        peer.respond_with_error(
 521                            receipt,
 522                            proto::Error {
 523                                message: error.to_string(),
 524                            },
 525                        )?;
 526                        Err(error)
 527                    }
 528                }
 529            }
 530        })
 531    }
 532
 533    pub fn handle_connection(
 534        self: &Arc<Self>,
 535        connection: Connection,
 536        address: String,
 537        user: User,
 538        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 539        executor: Executor,
 540    ) -> impl Future<Output = Result<()>> {
 541        let this = self.clone();
 542        let user_id = user.id;
 543        let login = user.github_login;
 544        let span = info_span!("handle connection", %user_id, %login, %address);
 545        let mut teardown = self.teardown.subscribe();
 546        async move {
 547            let (connection_id, handle_io, mut incoming_rx) = this
 548                .peer
 549                .add_connection(connection, {
 550                    let executor = executor.clone();
 551                    move |duration| executor.sleep(duration)
 552                });
 553
 554            tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
 555            this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
 556            tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");
 557
 558            if let Some(send_connection_id) = send_connection_id.take() {
 559                let _ = send_connection_id.send(connection_id);
 560            }
 561
 562            if !user.connected_once {
 563                this.peer.send(connection_id, proto::ShowContacts {})?;
 564                this.app_state.db.set_user_connected_once(user_id, true).await?;
 565            }
 566
 567            let (contacts, invite_code, channels_for_user, channel_invites) = future::try_join4(
 568                this.app_state.db.get_contacts(user_id),
 569                this.app_state.db.get_invite_code_for_user(user_id),
 570                this.app_state.db.get_channels_for_user(user_id),
 571                this.app_state.db.get_channel_invites_for_user(user_id)
 572            ).await?;
 573
 574            {
 575                let mut pool = this.connection_pool.lock();
 576                pool.add_connection(connection_id, user_id, user.admin);
 577                this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
 578                this.peer.send(connection_id, build_initial_channels_update(
 579                    channels_for_user,
 580                    channel_invites
 581                ))?;
 582
 583                if let Some((code, count)) = invite_code {
 584                    this.peer.send(connection_id, proto::UpdateInviteInfo {
 585                        url: format!("{}{}", this.app_state.config.invite_link_prefix, code),
 586                        count: count as u32,
 587                    })?;
 588                }
 589            }
 590
 591            if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
 592                this.peer.send(connection_id, incoming_call)?;
 593            }
 594
 595            let session = Session {
 596                user_id,
 597                connection_id,
 598                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
 599                peer: this.peer.clone(),
 600                connection_pool: this.connection_pool.clone(),
 601                live_kit_client: this.app_state.live_kit_client.clone(),
 602                executor: executor.clone(),
 603            };
 604            update_user_contacts(user_id, &session).await?;
 605
 606            let handle_io = handle_io.fuse();
 607            futures::pin_mut!(handle_io);
 608
 609            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
 610            // This prevents deadlocks when e.g., client A performs a request to client B and
 611            // client B performs a request to client A. If both clients stop processing further
 612            // messages until their respective request completes, they won't have a chance to
 613            // respond to the other client's request and cause a deadlock.
 614            //
 615            // This arrangement ensures we will attempt to process earlier messages first, but fall
 616            // back to processing messages arrived later in the spirit of making progress.
 617            let mut foreground_message_handlers = FuturesUnordered::new();
 618            let concurrent_handlers = Arc::new(Semaphore::new(256));
 619            loop {
 620                let next_message = async {
 621                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
 622                    let message = incoming_rx.next().await;
 623                    (permit, message)
 624                }.fuse();
 625                futures::pin_mut!(next_message);
 626                futures::select_biased! {
 627                    _ = teardown.changed().fuse() => return Ok(()),
 628                    result = handle_io => {
 629                        if let Err(error) = result {
 630                            tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
 631                        }
 632                        break;
 633                    }
 634                    _ = foreground_message_handlers.next() => {}
 635                    next_message = next_message => {
 636                        let (permit, message) = next_message;
 637                        if let Some(message) = message {
 638                            let type_name = message.payload_type_name();
 639                            let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
 640                            let span_enter = span.enter();
 641                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
 642                                let is_background = message.is_background();
 643                                let handle_message = (handler)(message, session.clone());
 644                                drop(span_enter);
 645
 646                                let handle_message = async move {
 647                                    handle_message.await;
 648                                    drop(permit);
 649                                }.instrument(span);
 650                                if is_background {
 651                                    executor.spawn_detached(handle_message);
 652                                } else {
 653                                    foreground_message_handlers.push(handle_message);
 654                                }
 655                            } else {
 656                                tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
 657                            }
 658                        } else {
 659                            tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
 660                            break;
 661                        }
 662                    }
 663                }
 664            }
 665
 666            drop(foreground_message_handlers);
 667            tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
 668            if let Err(error) = connection_lost(session, teardown, executor).await {
 669                tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
 670            }
 671
 672            Ok(())
 673        }.instrument(span)
 674    }
 675
 676    pub async fn invite_code_redeemed(
 677        self: &Arc<Self>,
 678        inviter_id: UserId,
 679        invitee_id: UserId,
 680    ) -> Result<()> {
 681        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 682            if let Some(code) = &user.invite_code {
 683                let pool = self.connection_pool.lock();
 684                let invitee_contact = contact_for_user(invitee_id, true, false, &pool);
 685                for connection_id in pool.user_connection_ids(inviter_id) {
 686                    self.peer.send(
 687                        connection_id,
 688                        proto::UpdateContacts {
 689                            contacts: vec![invitee_contact.clone()],
 690                            ..Default::default()
 691                        },
 692                    )?;
 693                    self.peer.send(
 694                        connection_id,
 695                        proto::UpdateInviteInfo {
 696                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 697                            count: user.invite_count as u32,
 698                        },
 699                    )?;
 700                }
 701            }
 702        }
 703        Ok(())
 704    }
 705
 706    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 707        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 708            if let Some(invite_code) = &user.invite_code {
 709                let pool = self.connection_pool.lock();
 710                for connection_id in pool.user_connection_ids(user_id) {
 711                    self.peer.send(
 712                        connection_id,
 713                        proto::UpdateInviteInfo {
 714                            url: format!(
 715                                "{}{}",
 716                                self.app_state.config.invite_link_prefix, invite_code
 717                            ),
 718                            count: user.invite_count as u32,
 719                        },
 720                    )?;
 721                }
 722            }
 723        }
 724        Ok(())
 725    }
 726
 727    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
 728        ServerSnapshot {
 729            connection_pool: ConnectionPoolGuard {
 730                guard: self.connection_pool.lock(),
 731                _not_send: PhantomData,
 732            },
 733            peer: &self.peer,
 734        }
 735    }
 736}
 737
 738impl<'a> Deref for ConnectionPoolGuard<'a> {
 739    type Target = ConnectionPool;
 740
 741    fn deref(&self) -> &Self::Target {
 742        &*self.guard
 743    }
 744}
 745
 746impl<'a> DerefMut for ConnectionPoolGuard<'a> {
 747    fn deref_mut(&mut self) -> &mut Self::Target {
 748        &mut *self.guard
 749    }
 750}
 751
 752impl<'a> Drop for ConnectionPoolGuard<'a> {
 753    fn drop(&mut self) {
 754        #[cfg(test)]
 755        self.check_invariants();
 756    }
 757}
 758
 759fn broadcast<F>(
 760    sender_id: Option<ConnectionId>,
 761    receiver_ids: impl IntoIterator<Item = ConnectionId>,
 762    mut f: F,
 763) where
 764    F: FnMut(ConnectionId) -> anyhow::Result<()>,
 765{
 766    for receiver_id in receiver_ids {
 767        if Some(receiver_id) != sender_id {
 768            if let Err(error) = f(receiver_id) {
 769                tracing::error!("failed to send to {:?} {}", receiver_id, error);
 770            }
 771        }
 772    }
 773}
 774
lazy_static! {
    /// Name of the HTTP header in which clients send their RPC protocol
    /// version; decoded by the `ProtocolVersion` typed header.
    static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
}
 778
 779pub struct ProtocolVersion(u32);
 780
 781impl Header for ProtocolVersion {
 782    fn name() -> &'static HeaderName {
 783        &ZED_PROTOCOL_VERSION
 784    }
 785
 786    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
 787    where
 788        Self: Sized,
 789        I: Iterator<Item = &'i axum::http::HeaderValue>,
 790    {
 791        let version = values
 792            .next()
 793            .ok_or_else(axum::headers::Error::invalid)?
 794            .to_str()
 795            .map_err(|_| axum::headers::Error::invalid())?
 796            .parse()
 797            .map_err(|_| axum::headers::Error::invalid())?;
 798        Ok(Self(version))
 799    }
 800
 801    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
 802        values.extend([self.0.to_string().parse().unwrap()]);
 803    }
 804}
 805
/// Builds the router for the RPC server's HTTP endpoints.
///
/// * `/rpc` — websocket endpoint served by `handle_websocket_request`; wrapped
///   in the `AppState` extension and the `auth::validate_header` middleware.
/// * `/metrics` — Prometheus text metrics served by `handle_metrics`.
///
/// Both routes receive the `Arc<Server>` extension.
pub fn routes(server: Arc<Server>) -> Router<Body> {
    Router::new()
        .route("/rpc", get(handle_websocket_request))
        // `Router::layer` only wraps routes registered before it, so this
        // auth/app-state stack applies to `/rpc` but not to `/metrics`.
        .layer(
            ServiceBuilder::new()
                .layer(Extension(server.app_state.clone()))
                .layer(middleware::from_fn(auth::validate_header)),
        )
        .route("/metrics", get(handle_metrics))
        // Registered last, so it wraps both routes.
        .layer(Extension(server))
}
 817
/// Handles `GET /rpc`: upgrades the request to a websocket and serves an RPC
/// connection for `user` over it.
///
/// Requests whose `x-zed-protocol-version` header differs from
/// `rpc::PROTOCOL_VERSION` are answered with `426 Upgrade Required` and are not
/// upgraded. The `User` extension is presumably inserted by the
/// `auth::validate_header` middleware (verify in `auth`).
///
/// After the upgrade, the socket is wrapped in a `Connection` and passed to
/// `Server::handle_connection`. Any error that returns is handed to
/// `log_err` rather than propagated.
pub async fn handle_websocket_request(
    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
    Extension(server): Extension<Arc<Server>>,
    Extension(user): Extension<User>,
    ws: WebSocketUpgrade,
) -> axum::response::Response {
    if protocol_version != rpc::PROTOCOL_VERSION {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "client must be upgraded".to_string(),
        )
            .into_response();
    }
    let socket_address = socket_address.to_string();
    ws.on_upgrade(move |socket| {
        use util::ResultExt;
        // Convert incoming axum messages to tungstenite messages and outgoing
        // tungstenite messages back to axum ones, so the socket can back a `Connection`.
        let socket = socket
            .map_ok(to_tungstenite_message)
            .err_into()
            .with(|message| async move { Ok(to_axum_message(message)) });
        let connection = Connection::new(Box::pin(socket));
        async move {
            server
                .handle_connection(connection, socket_address, user, None, Executor::Production)
                .await
                .log_err();
        }
    })
}
 848
 849pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
 850    let connections = server
 851        .connection_pool
 852        .lock()
 853        .connections()
 854        .filter(|connection| !connection.admin)
 855        .count();
 856
 857    METRIC_CONNECTIONS.set(connections as _);
 858
 859    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
 860    METRIC_SHARED_PROJECTS.set(shared_projects as _);
 861
 862    let encoder = prometheus::TextEncoder::new();
 863    let metric_families = prometheus::gather();
 864    let encoded_metrics = encoder
 865        .encode_to_string(&metric_families)
 866        .map_err(|err| anyhow!("{}", err))?;
 867    Ok(encoded_metrics)
 868}
 869
 870#[instrument(err, skip(executor))]
 871async fn connection_lost(
 872    session: Session,
 873    mut teardown: watch::Receiver<()>,
 874    executor: Executor,
 875) -> Result<()> {
 876    session.peer.disconnect(session.connection_id);
 877    session
 878        .connection_pool()
 879        .await
 880        .remove_connection(session.connection_id)?;
 881
 882    session
 883        .db()
 884        .await
 885        .connection_lost(session.connection_id)
 886        .await
 887        .trace_err();
 888
 889    futures::select_biased! {
 890        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
 891            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id, session.connection_id);
 892            leave_room_for_session(&session).await.trace_err();
 893            leave_channel_buffers_for_session(&session)
 894                .await
 895                .trace_err();
 896
 897            if !session
 898                .connection_pool()
 899                .await
 900                .is_user_online(session.user_id)
 901            {
 902                let db = session.db().await;
 903                if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
 904                    room_updated(&room, &session.peer);
 905                }
 906            }
 907
 908            update_user_contacts(session.user_id, &session).await?;
 909        }
 910        _ = teardown.changed().fuse() => {}
 911    }
 912
 913    Ok(())
 914}
 915
 916async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
 917    response.send(proto::Ack {})?;
 918    Ok(())
 919}
 920
 921async fn create_room(
 922    _request: proto::CreateRoom,
 923    response: Response<proto::CreateRoom>,
 924    session: Session,
 925) -> Result<()> {
 926    let live_kit_room = nanoid::nanoid!(30);
 927
 928    let live_kit_connection_info = {
 929        let live_kit_room = live_kit_room.clone();
 930        let live_kit = session.live_kit_client.as_ref();
 931
 932        util::async_iife!({
 933            let live_kit = live_kit?;
 934
 935            live_kit
 936                .create_room(live_kit_room.clone())
 937                .await
 938                .trace_err()?;
 939
 940            let token = live_kit
 941                .room_token(&live_kit_room, &session.user_id.to_string())
 942                .trace_err()?;
 943
 944            Some(proto::LiveKitConnectionInfo {
 945                server_url: live_kit.url().into(),
 946                token,
 947            })
 948        })
 949    }
 950    .await;
 951
 952    let room = session
 953        .db()
 954        .await
 955        .create_room(session.user_id, session.connection_id, &live_kit_room)
 956        .await?;
 957
 958    response.send(proto::CreateRoomResponse {
 959        room: Some(room.clone()),
 960        live_kit_connection_info,
 961    })?;
 962
 963    update_user_contacts(session.user_id, &session).await?;
 964    Ok(())
 965}
 966
 967async fn join_room(
 968    request: proto::JoinRoom,
 969    response: Response<proto::JoinRoom>,
 970    session: Session,
 971) -> Result<()> {
 972    let room_id = RoomId::from_proto(request.id);
 973    let joined_room = {
 974        let room = session
 975            .db()
 976            .await
 977            .join_room(room_id, session.user_id, session.connection_id)
 978            .await?;
 979        room_updated(&room.room, &session.peer);
 980        room.into_inner()
 981    };
 982
 983    if let Some(channel_id) = joined_room.channel_id {
 984        channel_updated(
 985            channel_id,
 986            &joined_room.room,
 987            &joined_room.channel_members,
 988            &session.peer,
 989            &*session.connection_pool().await,
 990        )
 991    }
 992
 993    for connection_id in session
 994        .connection_pool()
 995        .await
 996        .user_connection_ids(session.user_id)
 997    {
 998        session
 999            .peer
1000            .send(
1001                connection_id,
1002                proto::CallCanceled {
1003                    room_id: room_id.to_proto(),
1004                },
1005            )
1006            .trace_err();
1007    }
1008
1009    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1010        if let Some(token) = live_kit
1011            .room_token(
1012                &joined_room.room.live_kit_room,
1013                &session.user_id.to_string(),
1014            )
1015            .trace_err()
1016        {
1017            Some(proto::LiveKitConnectionInfo {
1018                server_url: live_kit.url().into(),
1019                token,
1020            })
1021        } else {
1022            None
1023        }
1024    } else {
1025        None
1026    };
1027
1028    response.send(proto::JoinRoomResponse {
1029        room: Some(joined_room.room),
1030        channel_id: joined_room.channel_id.map(|id| id.to_proto()),
1031        live_kit_connection_info,
1032    })?;
1033
1034    update_user_contacts(session.user_id, &session).await?;
1035    Ok(())
1036}
1037
1038async fn rejoin_room(
1039    request: proto::RejoinRoom,
1040    response: Response<proto::RejoinRoom>,
1041    session: Session,
1042) -> Result<()> {
1043    let room;
1044    let channel_id;
1045    let channel_members;
1046    {
1047        let mut rejoined_room = session
1048            .db()
1049            .await
1050            .rejoin_room(request, session.user_id, session.connection_id)
1051            .await?;
1052
1053        response.send(proto::RejoinRoomResponse {
1054            room: Some(rejoined_room.room.clone()),
1055            reshared_projects: rejoined_room
1056                .reshared_projects
1057                .iter()
1058                .map(|project| proto::ResharedProject {
1059                    id: project.id.to_proto(),
1060                    collaborators: project
1061                        .collaborators
1062                        .iter()
1063                        .map(|collaborator| collaborator.to_proto())
1064                        .collect(),
1065                })
1066                .collect(),
1067            rejoined_projects: rejoined_room
1068                .rejoined_projects
1069                .iter()
1070                .map(|rejoined_project| proto::RejoinedProject {
1071                    id: rejoined_project.id.to_proto(),
1072                    worktrees: rejoined_project
1073                        .worktrees
1074                        .iter()
1075                        .map(|worktree| proto::WorktreeMetadata {
1076                            id: worktree.id,
1077                            root_name: worktree.root_name.clone(),
1078                            visible: worktree.visible,
1079                            abs_path: worktree.abs_path.clone(),
1080                        })
1081                        .collect(),
1082                    collaborators: rejoined_project
1083                        .collaborators
1084                        .iter()
1085                        .map(|collaborator| collaborator.to_proto())
1086                        .collect(),
1087                    language_servers: rejoined_project.language_servers.clone(),
1088                })
1089                .collect(),
1090        })?;
1091        room_updated(&rejoined_room.room, &session.peer);
1092
1093        for project in &rejoined_room.reshared_projects {
1094            for collaborator in &project.collaborators {
1095                session
1096                    .peer
1097                    .send(
1098                        collaborator.connection_id,
1099                        proto::UpdateProjectCollaborator {
1100                            project_id: project.id.to_proto(),
1101                            old_peer_id: Some(project.old_connection_id.into()),
1102                            new_peer_id: Some(session.connection_id.into()),
1103                        },
1104                    )
1105                    .trace_err();
1106            }
1107
1108            broadcast(
1109                Some(session.connection_id),
1110                project
1111                    .collaborators
1112                    .iter()
1113                    .map(|collaborator| collaborator.connection_id),
1114                |connection_id| {
1115                    session.peer.forward_send(
1116                        session.connection_id,
1117                        connection_id,
1118                        proto::UpdateProject {
1119                            project_id: project.id.to_proto(),
1120                            worktrees: project.worktrees.clone(),
1121                        },
1122                    )
1123                },
1124            );
1125        }
1126
1127        for project in &rejoined_room.rejoined_projects {
1128            for collaborator in &project.collaborators {
1129                session
1130                    .peer
1131                    .send(
1132                        collaborator.connection_id,
1133                        proto::UpdateProjectCollaborator {
1134                            project_id: project.id.to_proto(),
1135                            old_peer_id: Some(project.old_connection_id.into()),
1136                            new_peer_id: Some(session.connection_id.into()),
1137                        },
1138                    )
1139                    .trace_err();
1140            }
1141        }
1142
1143        for project in &mut rejoined_room.rejoined_projects {
1144            for worktree in mem::take(&mut project.worktrees) {
1145                #[cfg(any(test, feature = "test-support"))]
1146                const MAX_CHUNK_SIZE: usize = 2;
1147                #[cfg(not(any(test, feature = "test-support")))]
1148                const MAX_CHUNK_SIZE: usize = 256;
1149
1150                // Stream this worktree's entries.
1151                let message = proto::UpdateWorktree {
1152                    project_id: project.id.to_proto(),
1153                    worktree_id: worktree.id,
1154                    abs_path: worktree.abs_path.clone(),
1155                    root_name: worktree.root_name,
1156                    updated_entries: worktree.updated_entries,
1157                    removed_entries: worktree.removed_entries,
1158                    scan_id: worktree.scan_id,
1159                    is_last_update: worktree.completed_scan_id == worktree.scan_id,
1160                    updated_repositories: worktree.updated_repositories,
1161                    removed_repositories: worktree.removed_repositories,
1162                };
1163                for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1164                    session.peer.send(session.connection_id, update.clone())?;
1165                }
1166
1167                // Stream this worktree's diagnostics.
1168                for summary in worktree.diagnostic_summaries {
1169                    session.peer.send(
1170                        session.connection_id,
1171                        proto::UpdateDiagnosticSummary {
1172                            project_id: project.id.to_proto(),
1173                            worktree_id: worktree.id,
1174                            summary: Some(summary),
1175                        },
1176                    )?;
1177                }
1178
1179                for settings_file in worktree.settings_files {
1180                    session.peer.send(
1181                        session.connection_id,
1182                        proto::UpdateWorktreeSettings {
1183                            project_id: project.id.to_proto(),
1184                            worktree_id: worktree.id,
1185                            path: settings_file.path,
1186                            content: Some(settings_file.content),
1187                        },
1188                    )?;
1189                }
1190            }
1191
1192            for language_server in &project.language_servers {
1193                session.peer.send(
1194                    session.connection_id,
1195                    proto::UpdateLanguageServer {
1196                        project_id: project.id.to_proto(),
1197                        language_server_id: language_server.id,
1198                        variant: Some(
1199                            proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1200                                proto::LspDiskBasedDiagnosticsUpdated {},
1201                            ),
1202                        ),
1203                    },
1204                )?;
1205            }
1206        }
1207
1208        let rejoined_room = rejoined_room.into_inner();
1209
1210        room = rejoined_room.room;
1211        channel_id = rejoined_room.channel_id;
1212        channel_members = rejoined_room.channel_members;
1213    }
1214
1215    if let Some(channel_id) = channel_id {
1216        channel_updated(
1217            channel_id,
1218            &room,
1219            &channel_members,
1220            &session.peer,
1221            &*session.connection_pool().await,
1222        );
1223    }
1224
1225    update_user_contacts(session.user_id, &session).await?;
1226    Ok(())
1227}
1228
1229async fn leave_room(
1230    _: proto::LeaveRoom,
1231    response: Response<proto::LeaveRoom>,
1232    session: Session,
1233) -> Result<()> {
1234    leave_room_for_session(&session).await?;
1235    response.send(proto::Ack {})?;
1236    Ok(())
1237}
1238
1239async fn call(
1240    request: proto::Call,
1241    response: Response<proto::Call>,
1242    session: Session,
1243) -> Result<()> {
1244    let room_id = RoomId::from_proto(request.room_id);
1245    let calling_user_id = session.user_id;
1246    let calling_connection_id = session.connection_id;
1247    let called_user_id = UserId::from_proto(request.called_user_id);
1248    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1249    if !session
1250        .db()
1251        .await
1252        .has_contact(calling_user_id, called_user_id)
1253        .await?
1254    {
1255        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1256    }
1257
1258    let incoming_call = {
1259        let (room, incoming_call) = &mut *session
1260            .db()
1261            .await
1262            .call(
1263                room_id,
1264                calling_user_id,
1265                calling_connection_id,
1266                called_user_id,
1267                initial_project_id,
1268            )
1269            .await?;
1270        room_updated(&room, &session.peer);
1271        mem::take(incoming_call)
1272    };
1273    update_user_contacts(called_user_id, &session).await?;
1274
1275    let mut calls = session
1276        .connection_pool()
1277        .await
1278        .user_connection_ids(called_user_id)
1279        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1280        .collect::<FuturesUnordered<_>>();
1281
1282    while let Some(call_response) = calls.next().await {
1283        match call_response.as_ref() {
1284            Ok(_) => {
1285                response.send(proto::Ack {})?;
1286                return Ok(());
1287            }
1288            Err(_) => {
1289                call_response.trace_err();
1290            }
1291        }
1292    }
1293
1294    {
1295        let room = session
1296            .db()
1297            .await
1298            .call_failed(room_id, called_user_id)
1299            .await?;
1300        room_updated(&room, &session.peer);
1301    }
1302    update_user_contacts(called_user_id, &session).await?;
1303
1304    Err(anyhow!("failed to ring user"))?
1305}
1306
1307async fn cancel_call(
1308    request: proto::CancelCall,
1309    response: Response<proto::CancelCall>,
1310    session: Session,
1311) -> Result<()> {
1312    let called_user_id = UserId::from_proto(request.called_user_id);
1313    let room_id = RoomId::from_proto(request.room_id);
1314    {
1315        let room = session
1316            .db()
1317            .await
1318            .cancel_call(room_id, session.connection_id, called_user_id)
1319            .await?;
1320        room_updated(&room, &session.peer);
1321    }
1322
1323    for connection_id in session
1324        .connection_pool()
1325        .await
1326        .user_connection_ids(called_user_id)
1327    {
1328        session
1329            .peer
1330            .send(
1331                connection_id,
1332                proto::CallCanceled {
1333                    room_id: room_id.to_proto(),
1334                },
1335            )
1336            .trace_err();
1337    }
1338    response.send(proto::Ack {})?;
1339
1340    update_user_contacts(called_user_id, &session).await?;
1341    Ok(())
1342}
1343
1344async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1345    let room_id = RoomId::from_proto(message.room_id);
1346    {
1347        let room = session
1348            .db()
1349            .await
1350            .decline_call(Some(room_id), session.user_id)
1351            .await?
1352            .ok_or_else(|| anyhow!("failed to decline call"))?;
1353        room_updated(&room, &session.peer);
1354    }
1355
1356    for connection_id in session
1357        .connection_pool()
1358        .await
1359        .user_connection_ids(session.user_id)
1360    {
1361        session
1362            .peer
1363            .send(
1364                connection_id,
1365                proto::CallCanceled {
1366                    room_id: room_id.to_proto(),
1367                },
1368            )
1369            .trace_err();
1370    }
1371    update_user_contacts(session.user_id, &session).await?;
1372    Ok(())
1373}
1374
1375async fn update_participant_location(
1376    request: proto::UpdateParticipantLocation,
1377    response: Response<proto::UpdateParticipantLocation>,
1378    session: Session,
1379) -> Result<()> {
1380    let room_id = RoomId::from_proto(request.room_id);
1381    let location = request
1382        .location
1383        .ok_or_else(|| anyhow!("invalid location"))?;
1384
1385    let db = session.db().await;
1386    let room = db
1387        .update_room_participant_location(room_id, session.connection_id, location)
1388        .await?;
1389
1390    room_updated(&room, &session.peer);
1391    response.send(proto::Ack {})?;
1392    Ok(())
1393}
1394
1395async fn share_project(
1396    request: proto::ShareProject,
1397    response: Response<proto::ShareProject>,
1398    session: Session,
1399) -> Result<()> {
1400    let (project_id, room) = &*session
1401        .db()
1402        .await
1403        .share_project(
1404            RoomId::from_proto(request.room_id),
1405            session.connection_id,
1406            &request.worktrees,
1407        )
1408        .await?;
1409    response.send(proto::ShareProjectResponse {
1410        project_id: project_id.to_proto(),
1411    })?;
1412    room_updated(&room, &session.peer);
1413
1414    Ok(())
1415}
1416
1417async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1418    let project_id = ProjectId::from_proto(message.project_id);
1419
1420    let (room, guest_connection_ids) = &*session
1421        .db()
1422        .await
1423        .unshare_project(project_id, session.connection_id)
1424        .await?;
1425
1426    broadcast(
1427        Some(session.connection_id),
1428        guest_connection_ids.iter().copied(),
1429        |conn_id| session.peer.send(conn_id, message.clone()),
1430    );
1431    room_updated(&room, &session.peer);
1432
1433    Ok(())
1434}
1435
1436async fn join_project(
1437    request: proto::JoinProject,
1438    response: Response<proto::JoinProject>,
1439    session: Session,
1440) -> Result<()> {
1441    let project_id = ProjectId::from_proto(request.project_id);
1442    let guest_user_id = session.user_id;
1443
1444    tracing::info!(%project_id, "join project");
1445
1446    let (project, replica_id) = &mut *session
1447        .db()
1448        .await
1449        .join_project(project_id, session.connection_id)
1450        .await?;
1451
1452    let collaborators = project
1453        .collaborators
1454        .iter()
1455        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1456        .map(|collaborator| collaborator.to_proto())
1457        .collect::<Vec<_>>();
1458
1459    let worktrees = project
1460        .worktrees
1461        .iter()
1462        .map(|(id, worktree)| proto::WorktreeMetadata {
1463            id: *id,
1464            root_name: worktree.root_name.clone(),
1465            visible: worktree.visible,
1466            abs_path: worktree.abs_path.clone(),
1467        })
1468        .collect::<Vec<_>>();
1469
1470    for collaborator in &collaborators {
1471        session
1472            .peer
1473            .send(
1474                collaborator.peer_id.unwrap().into(),
1475                proto::AddProjectCollaborator {
1476                    project_id: project_id.to_proto(),
1477                    collaborator: Some(proto::Collaborator {
1478                        peer_id: Some(session.connection_id.into()),
1479                        replica_id: replica_id.0 as u32,
1480                        user_id: guest_user_id.to_proto(),
1481                    }),
1482                },
1483            )
1484            .trace_err();
1485    }
1486
1487    // First, we send the metadata associated with each worktree.
1488    response.send(proto::JoinProjectResponse {
1489        worktrees: worktrees.clone(),
1490        replica_id: replica_id.0 as u32,
1491        collaborators: collaborators.clone(),
1492        language_servers: project.language_servers.clone(),
1493    })?;
1494
1495    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1496        #[cfg(any(test, feature = "test-support"))]
1497        const MAX_CHUNK_SIZE: usize = 2;
1498        #[cfg(not(any(test, feature = "test-support")))]
1499        const MAX_CHUNK_SIZE: usize = 256;
1500
1501        // Stream this worktree's entries.
1502        let message = proto::UpdateWorktree {
1503            project_id: project_id.to_proto(),
1504            worktree_id,
1505            abs_path: worktree.abs_path.clone(),
1506            root_name: worktree.root_name,
1507            updated_entries: worktree.entries,
1508            removed_entries: Default::default(),
1509            scan_id: worktree.scan_id,
1510            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1511            updated_repositories: worktree.repository_entries.into_values().collect(),
1512            removed_repositories: Default::default(),
1513        };
1514        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1515            session.peer.send(session.connection_id, update.clone())?;
1516        }
1517
1518        // Stream this worktree's diagnostics.
1519        for summary in worktree.diagnostic_summaries {
1520            session.peer.send(
1521                session.connection_id,
1522                proto::UpdateDiagnosticSummary {
1523                    project_id: project_id.to_proto(),
1524                    worktree_id: worktree.id,
1525                    summary: Some(summary),
1526                },
1527            )?;
1528        }
1529
1530        for settings_file in worktree.settings_files {
1531            session.peer.send(
1532                session.connection_id,
1533                proto::UpdateWorktreeSettings {
1534                    project_id: project_id.to_proto(),
1535                    worktree_id: worktree.id,
1536                    path: settings_file.path,
1537                    content: Some(settings_file.content),
1538                },
1539            )?;
1540        }
1541    }
1542
1543    for language_server in &project.language_servers {
1544        session.peer.send(
1545            session.connection_id,
1546            proto::UpdateLanguageServer {
1547                project_id: project_id.to_proto(),
1548                language_server_id: language_server.id,
1549                variant: Some(
1550                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1551                        proto::LspDiskBasedDiagnosticsUpdated {},
1552                    ),
1553                ),
1554            },
1555        )?;
1556    }
1557
1558    Ok(())
1559}
1560
1561async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1562    let sender_id = session.connection_id;
1563    let project_id = ProjectId::from_proto(request.project_id);
1564
1565    let (room, project) = &*session
1566        .db()
1567        .await
1568        .leave_project(project_id, sender_id)
1569        .await?;
1570    tracing::info!(
1571        %project_id,
1572        host_user_id = %project.host_user_id,
1573        host_connection_id = %project.host_connection_id,
1574        "leave project"
1575    );
1576
1577    project_left(&project, &session);
1578    room_updated(&room, &session.peer);
1579
1580    Ok(())
1581}
1582
1583async fn update_project(
1584    request: proto::UpdateProject,
1585    response: Response<proto::UpdateProject>,
1586    session: Session,
1587) -> Result<()> {
1588    let project_id = ProjectId::from_proto(request.project_id);
1589    let (room, guest_connection_ids) = &*session
1590        .db()
1591        .await
1592        .update_project(project_id, session.connection_id, &request.worktrees)
1593        .await?;
1594    broadcast(
1595        Some(session.connection_id),
1596        guest_connection_ids.iter().copied(),
1597        |connection_id| {
1598            session
1599                .peer
1600                .forward_send(session.connection_id, connection_id, request.clone())
1601        },
1602    );
1603    room_updated(&room, &session.peer);
1604    response.send(proto::Ack {})?;
1605
1606    Ok(())
1607}
1608
1609async fn update_worktree(
1610    request: proto::UpdateWorktree,
1611    response: Response<proto::UpdateWorktree>,
1612    session: Session,
1613) -> Result<()> {
1614    let guest_connection_ids = session
1615        .db()
1616        .await
1617        .update_worktree(&request, session.connection_id)
1618        .await?;
1619
1620    broadcast(
1621        Some(session.connection_id),
1622        guest_connection_ids.iter().copied(),
1623        |connection_id| {
1624            session
1625                .peer
1626                .forward_send(session.connection_id, connection_id, request.clone())
1627        },
1628    );
1629    response.send(proto::Ack {})?;
1630    Ok(())
1631}
1632
1633async fn update_diagnostic_summary(
1634    message: proto::UpdateDiagnosticSummary,
1635    session: Session,
1636) -> Result<()> {
1637    let guest_connection_ids = session
1638        .db()
1639        .await
1640        .update_diagnostic_summary(&message, session.connection_id)
1641        .await?;
1642
1643    broadcast(
1644        Some(session.connection_id),
1645        guest_connection_ids.iter().copied(),
1646        |connection_id| {
1647            session
1648                .peer
1649                .forward_send(session.connection_id, connection_id, message.clone())
1650        },
1651    );
1652
1653    Ok(())
1654}
1655
1656async fn update_worktree_settings(
1657    message: proto::UpdateWorktreeSettings,
1658    session: Session,
1659) -> Result<()> {
1660    let guest_connection_ids = session
1661        .db()
1662        .await
1663        .update_worktree_settings(&message, session.connection_id)
1664        .await?;
1665
1666    broadcast(
1667        Some(session.connection_id),
1668        guest_connection_ids.iter().copied(),
1669        |connection_id| {
1670            session
1671                .peer
1672                .forward_send(session.connection_id, connection_id, message.clone())
1673        },
1674    );
1675
1676    Ok(())
1677}
1678
1679async fn refresh_inlay_hints(request: proto::RefreshInlayHints, session: Session) -> Result<()> {
1680    broadcast_project_message(request.project_id, request, session).await
1681}
1682
1683async fn start_language_server(
1684    request: proto::StartLanguageServer,
1685    session: Session,
1686) -> Result<()> {
1687    let guest_connection_ids = session
1688        .db()
1689        .await
1690        .start_language_server(&request, session.connection_id)
1691        .await?;
1692
1693    broadcast(
1694        Some(session.connection_id),
1695        guest_connection_ids.iter().copied(),
1696        |connection_id| {
1697            session
1698                .peer
1699                .forward_send(session.connection_id, connection_id, request.clone())
1700        },
1701    );
1702    Ok(())
1703}
1704
1705async fn update_language_server(
1706    request: proto::UpdateLanguageServer,
1707    session: Session,
1708) -> Result<()> {
1709    session.executor.record_backtrace();
1710    let project_id = ProjectId::from_proto(request.project_id);
1711    let project_connection_ids = session
1712        .db()
1713        .await
1714        .project_connection_ids(project_id, session.connection_id)
1715        .await?;
1716    broadcast(
1717        Some(session.connection_id),
1718        project_connection_ids.iter().copied(),
1719        |connection_id| {
1720            session
1721                .peer
1722                .forward_send(session.connection_id, connection_id, request.clone())
1723        },
1724    );
1725    Ok(())
1726}
1727
1728async fn forward_project_request<T>(
1729    request: T,
1730    response: Response<T>,
1731    session: Session,
1732) -> Result<()>
1733where
1734    T: EntityMessage + RequestMessage,
1735{
1736    session.executor.record_backtrace();
1737    let project_id = ProjectId::from_proto(request.remote_entity_id());
1738    let host_connection_id = {
1739        let collaborators = session
1740            .db()
1741            .await
1742            .project_collaborators(project_id, session.connection_id)
1743            .await?;
1744        collaborators
1745            .iter()
1746            .find(|collaborator| collaborator.is_host)
1747            .ok_or_else(|| anyhow!("host not found"))?
1748            .connection_id
1749    };
1750
1751    let payload = session
1752        .peer
1753        .forward_request(session.connection_id, host_connection_id, request)
1754        .await?;
1755
1756    response.send(payload)?;
1757    Ok(())
1758}
1759
1760async fn create_buffer_for_peer(
1761    request: proto::CreateBufferForPeer,
1762    session: Session,
1763) -> Result<()> {
1764    session.executor.record_backtrace();
1765    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1766    session
1767        .peer
1768        .forward_send(session.connection_id, peer_id.into(), request)?;
1769    Ok(())
1770}
1771
1772async fn update_buffer(
1773    request: proto::UpdateBuffer,
1774    response: Response<proto::UpdateBuffer>,
1775    session: Session,
1776) -> Result<()> {
1777    session.executor.record_backtrace();
1778    let project_id = ProjectId::from_proto(request.project_id);
1779    let mut guest_connection_ids;
1780    let mut host_connection_id = None;
1781    {
1782        let collaborators = session
1783            .db()
1784            .await
1785            .project_collaborators(project_id, session.connection_id)
1786            .await?;
1787        guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
1788        for collaborator in collaborators.iter() {
1789            if collaborator.is_host {
1790                host_connection_id = Some(collaborator.connection_id);
1791            } else {
1792                guest_connection_ids.push(collaborator.connection_id);
1793            }
1794        }
1795    }
1796    let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
1797
1798    session.executor.record_backtrace();
1799    broadcast(
1800        Some(session.connection_id),
1801        guest_connection_ids,
1802        |connection_id| {
1803            session
1804                .peer
1805                .forward_send(session.connection_id, connection_id, request.clone())
1806        },
1807    );
1808    if host_connection_id != session.connection_id {
1809        session
1810            .peer
1811            .forward_request(session.connection_id, host_connection_id, request.clone())
1812            .await?;
1813    }
1814
1815    response.send(proto::Ack {})?;
1816    Ok(())
1817}
1818
1819async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session) -> Result<()> {
1820    let project_id = ProjectId::from_proto(request.project_id);
1821    let project_connection_ids = session
1822        .db()
1823        .await
1824        .project_connection_ids(project_id, session.connection_id)
1825        .await?;
1826
1827    broadcast(
1828        Some(session.connection_id),
1829        project_connection_ids.iter().copied(),
1830        |connection_id| {
1831            session
1832                .peer
1833                .forward_send(session.connection_id, connection_id, request.clone())
1834        },
1835    );
1836    Ok(())
1837}
1838
1839async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Result<()> {
1840    let project_id = ProjectId::from_proto(request.project_id);
1841    let project_connection_ids = session
1842        .db()
1843        .await
1844        .project_connection_ids(project_id, session.connection_id)
1845        .await?;
1846    broadcast(
1847        Some(session.connection_id),
1848        project_connection_ids.iter().copied(),
1849        |connection_id| {
1850            session
1851                .peer
1852                .forward_send(session.connection_id, connection_id, request.clone())
1853        },
1854    );
1855    Ok(())
1856}
1857
1858async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<()> {
1859    broadcast_project_message(request.project_id, request, session).await
1860}
1861
1862async fn broadcast_project_message<T: EnvelopedMessage>(
1863    project_id: u64,
1864    request: T,
1865    session: Session,
1866) -> Result<()> {
1867    let project_id = ProjectId::from_proto(project_id);
1868    let project_connection_ids = session
1869        .db()
1870        .await
1871        .project_connection_ids(project_id, session.connection_id)
1872        .await?;
1873    broadcast(
1874        Some(session.connection_id),
1875        project_connection_ids.iter().copied(),
1876        |connection_id| {
1877            session
1878                .peer
1879                .forward_send(session.connection_id, connection_id, request.clone())
1880        },
1881    );
1882    Ok(())
1883}
1884
1885async fn follow(
1886    request: proto::Follow,
1887    response: Response<proto::Follow>,
1888    session: Session,
1889) -> Result<()> {
1890    let project_id = ProjectId::from_proto(request.project_id);
1891    let leader_id = request
1892        .leader_id
1893        .ok_or_else(|| anyhow!("invalid leader id"))?
1894        .into();
1895    let follower_id = session.connection_id;
1896
1897    {
1898        let project_connection_ids = session
1899            .db()
1900            .await
1901            .project_connection_ids(project_id, session.connection_id)
1902            .await?;
1903
1904        if !project_connection_ids.contains(&leader_id) {
1905            Err(anyhow!("no such peer"))?;
1906        }
1907    }
1908
1909    let mut response_payload = session
1910        .peer
1911        .forward_request(session.connection_id, leader_id, request)
1912        .await?;
1913    response_payload
1914        .views
1915        .retain(|view| view.leader_id != Some(follower_id.into()));
1916    response.send(response_payload)?;
1917
1918    let room = session
1919        .db()
1920        .await
1921        .follow(project_id, leader_id, follower_id)
1922        .await?;
1923    room_updated(&room, &session.peer);
1924
1925    Ok(())
1926}
1927
1928async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
1929    let project_id = ProjectId::from_proto(request.project_id);
1930    let leader_id = request
1931        .leader_id
1932        .ok_or_else(|| anyhow!("invalid leader id"))?
1933        .into();
1934    let follower_id = session.connection_id;
1935
1936    if !session
1937        .db()
1938        .await
1939        .project_connection_ids(project_id, session.connection_id)
1940        .await?
1941        .contains(&leader_id)
1942    {
1943        Err(anyhow!("no such peer"))?;
1944    }
1945
1946    session
1947        .peer
1948        .forward_send(session.connection_id, leader_id, request)?;
1949
1950    let room = session
1951        .db()
1952        .await
1953        .unfollow(project_id, leader_id, follower_id)
1954        .await?;
1955    room_updated(&room, &session.peer);
1956
1957    Ok(())
1958}
1959
1960async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
1961    let project_id = ProjectId::from_proto(request.project_id);
1962    let project_connection_ids = session
1963        .db
1964        .lock()
1965        .await
1966        .project_connection_ids(project_id, session.connection_id)
1967        .await?;
1968
1969    let leader_id = request.variant.as_ref().and_then(|variant| match variant {
1970        proto::update_followers::Variant::CreateView(payload) => payload.leader_id,
1971        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1972        proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id,
1973    });
1974    for follower_peer_id in request.follower_ids.iter().copied() {
1975        let follower_connection_id = follower_peer_id.into();
1976        if project_connection_ids.contains(&follower_connection_id)
1977            && Some(follower_peer_id) != leader_id
1978        {
1979            session.peer.forward_send(
1980                session.connection_id,
1981                follower_connection_id,
1982                request.clone(),
1983            )?;
1984        }
1985    }
1986    Ok(())
1987}
1988
1989async fn get_users(
1990    request: proto::GetUsers,
1991    response: Response<proto::GetUsers>,
1992    session: Session,
1993) -> Result<()> {
1994    let user_ids = request
1995        .user_ids
1996        .into_iter()
1997        .map(UserId::from_proto)
1998        .collect();
1999    let users = session
2000        .db()
2001        .await
2002        .get_users_by_ids(user_ids)
2003        .await?
2004        .into_iter()
2005        .map(|user| proto::User {
2006            id: user.id.to_proto(),
2007            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2008            github_login: user.github_login,
2009        })
2010        .collect();
2011    response.send(proto::UsersResponse { users })?;
2012    Ok(())
2013}
2014
2015async fn fuzzy_search_users(
2016    request: proto::FuzzySearchUsers,
2017    response: Response<proto::FuzzySearchUsers>,
2018    session: Session,
2019) -> Result<()> {
2020    let query = request.query;
2021    let users = match query.len() {
2022        0 => vec![],
2023        1 | 2 => session
2024            .db()
2025            .await
2026            .get_user_by_github_login(&query)
2027            .await?
2028            .into_iter()
2029            .collect(),
2030        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2031    };
2032    let users = users
2033        .into_iter()
2034        .filter(|user| user.id != session.user_id)
2035        .map(|user| proto::User {
2036            id: user.id.to_proto(),
2037            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2038            github_login: user.github_login,
2039        })
2040        .collect();
2041    response.send(proto::UsersResponse { users })?;
2042    Ok(())
2043}
2044
2045async fn request_contact(
2046    request: proto::RequestContact,
2047    response: Response<proto::RequestContact>,
2048    session: Session,
2049) -> Result<()> {
2050    let requester_id = session.user_id;
2051    let responder_id = UserId::from_proto(request.responder_id);
2052    if requester_id == responder_id {
2053        return Err(anyhow!("cannot add yourself as a contact"))?;
2054    }
2055
2056    session
2057        .db()
2058        .await
2059        .send_contact_request(requester_id, responder_id)
2060        .await?;
2061
2062    // Update outgoing contact requests of requester
2063    let mut update = proto::UpdateContacts::default();
2064    update.outgoing_requests.push(responder_id.to_proto());
2065    for connection_id in session
2066        .connection_pool()
2067        .await
2068        .user_connection_ids(requester_id)
2069    {
2070        session.peer.send(connection_id, update.clone())?;
2071    }
2072
2073    // Update incoming contact requests of responder
2074    let mut update = proto::UpdateContacts::default();
2075    update
2076        .incoming_requests
2077        .push(proto::IncomingContactRequest {
2078            requester_id: requester_id.to_proto(),
2079            should_notify: true,
2080        });
2081    for connection_id in session
2082        .connection_pool()
2083        .await
2084        .user_connection_ids(responder_id)
2085    {
2086        session.peer.send(connection_id, update.clone())?;
2087    }
2088
2089    response.send(proto::Ack {})?;
2090    Ok(())
2091}
2092
2093async fn respond_to_contact_request(
2094    request: proto::RespondToContactRequest,
2095    response: Response<proto::RespondToContactRequest>,
2096    session: Session,
2097) -> Result<()> {
2098    let responder_id = session.user_id;
2099    let requester_id = UserId::from_proto(request.requester_id);
2100    let db = session.db().await;
2101    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2102        db.dismiss_contact_notification(responder_id, requester_id)
2103            .await?;
2104    } else {
2105        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2106
2107        db.respond_to_contact_request(responder_id, requester_id, accept)
2108            .await?;
2109        let requester_busy = db.is_user_busy(requester_id).await?;
2110        let responder_busy = db.is_user_busy(responder_id).await?;
2111
2112        let pool = session.connection_pool().await;
2113        // Update responder with new contact
2114        let mut update = proto::UpdateContacts::default();
2115        if accept {
2116            update
2117                .contacts
2118                .push(contact_for_user(requester_id, false, requester_busy, &pool));
2119        }
2120        update
2121            .remove_incoming_requests
2122            .push(requester_id.to_proto());
2123        for connection_id in pool.user_connection_ids(responder_id) {
2124            session.peer.send(connection_id, update.clone())?;
2125        }
2126
2127        // Update requester with new contact
2128        let mut update = proto::UpdateContacts::default();
2129        if accept {
2130            update
2131                .contacts
2132                .push(contact_for_user(responder_id, true, responder_busy, &pool));
2133        }
2134        update
2135            .remove_outgoing_requests
2136            .push(responder_id.to_proto());
2137        for connection_id in pool.user_connection_ids(requester_id) {
2138            session.peer.send(connection_id, update.clone())?;
2139        }
2140    }
2141
2142    response.send(proto::Ack {})?;
2143    Ok(())
2144}
2145
2146async fn remove_contact(
2147    request: proto::RemoveContact,
2148    response: Response<proto::RemoveContact>,
2149    session: Session,
2150) -> Result<()> {
2151    let requester_id = session.user_id;
2152    let responder_id = UserId::from_proto(request.user_id);
2153    let db = session.db().await;
2154    let contact_accepted = db.remove_contact(requester_id, responder_id).await?;
2155
2156    let pool = session.connection_pool().await;
2157    // Update outgoing contact requests of requester
2158    let mut update = proto::UpdateContacts::default();
2159    if contact_accepted {
2160        update.remove_contacts.push(responder_id.to_proto());
2161    } else {
2162        update
2163            .remove_outgoing_requests
2164            .push(responder_id.to_proto());
2165    }
2166    for connection_id in pool.user_connection_ids(requester_id) {
2167        session.peer.send(connection_id, update.clone())?;
2168    }
2169
2170    // Update incoming contact requests of responder
2171    let mut update = proto::UpdateContacts::default();
2172    if contact_accepted {
2173        update.remove_contacts.push(requester_id.to_proto());
2174    } else {
2175        update
2176            .remove_incoming_requests
2177            .push(requester_id.to_proto());
2178    }
2179    for connection_id in pool.user_connection_ids(responder_id) {
2180        session.peer.send(connection_id, update.clone())?;
2181    }
2182
2183    response.send(proto::Ack {})?;
2184    Ok(())
2185}
2186
2187async fn create_channel(
2188    request: proto::CreateChannel,
2189    response: Response<proto::CreateChannel>,
2190    session: Session,
2191) -> Result<()> {
2192    let db = session.db().await;
2193    let live_kit_room = format!("channel-{}", nanoid::nanoid!(30));
2194
2195    if let Some(live_kit) = session.live_kit_client.as_ref() {
2196        live_kit.create_room(live_kit_room.clone()).await?;
2197    }
2198
2199    let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2200    let id = db
2201        .create_channel(&request.name, parent_id, &live_kit_room, session.user_id)
2202        .await?;
2203
2204    let channel = proto::Channel {
2205        id: id.to_proto(),
2206        name: request.name,
2207        parent_id: request.parent_id,
2208    };
2209
2210    response.send(proto::ChannelResponse {
2211        channel: Some(channel.clone()),
2212    })?;
2213
2214    let mut update = proto::UpdateChannels::default();
2215    update.channels.push(channel);
2216
2217    let user_ids_to_notify = if let Some(parent_id) = parent_id {
2218        db.get_channel_members(parent_id).await?
2219    } else {
2220        vec![session.user_id]
2221    };
2222
2223    let connection_pool = session.connection_pool().await;
2224    for user_id in user_ids_to_notify {
2225        for connection_id in connection_pool.user_connection_ids(user_id) {
2226            let mut update = update.clone();
2227            if user_id == session.user_id {
2228                update.channel_permissions.push(proto::ChannelPermission {
2229                    channel_id: id.to_proto(),
2230                    is_admin: true,
2231                });
2232            }
2233            session.peer.send(connection_id, update)?;
2234        }
2235    }
2236
2237    Ok(())
2238}
2239
2240async fn remove_channel(
2241    request: proto::RemoveChannel,
2242    response: Response<proto::RemoveChannel>,
2243    session: Session,
2244) -> Result<()> {
2245    let db = session.db().await;
2246
2247    let channel_id = request.channel_id;
2248    let (removed_channels, member_ids) = db
2249        .remove_channel(ChannelId::from_proto(channel_id), session.user_id)
2250        .await?;
2251    response.send(proto::Ack {})?;
2252
2253    // Notify members of removed channels
2254    let mut update = proto::UpdateChannels::default();
2255    update
2256        .remove_channels
2257        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2258
2259    let connection_pool = session.connection_pool().await;
2260    for member_id in member_ids {
2261        for connection_id in connection_pool.user_connection_ids(member_id) {
2262            session.peer.send(connection_id, update.clone())?;
2263        }
2264    }
2265
2266    Ok(())
2267}
2268
2269async fn invite_channel_member(
2270    request: proto::InviteChannelMember,
2271    response: Response<proto::InviteChannelMember>,
2272    session: Session,
2273) -> Result<()> {
2274    let db = session.db().await;
2275    let channel_id = ChannelId::from_proto(request.channel_id);
2276    let invitee_id = UserId::from_proto(request.user_id);
2277    db.invite_channel_member(channel_id, invitee_id, session.user_id, request.admin)
2278        .await?;
2279
2280    let (channel, _) = db
2281        .get_channel(channel_id, session.user_id)
2282        .await?
2283        .ok_or_else(|| anyhow!("channel not found"))?;
2284
2285    let mut update = proto::UpdateChannels::default();
2286    update.channel_invitations.push(proto::Channel {
2287        id: channel.id.to_proto(),
2288        name: channel.name,
2289        parent_id: None,
2290    });
2291    for connection_id in session
2292        .connection_pool()
2293        .await
2294        .user_connection_ids(invitee_id)
2295    {
2296        session.peer.send(connection_id, update.clone())?;
2297    }
2298
2299    response.send(proto::Ack {})?;
2300    Ok(())
2301}
2302
2303async fn remove_channel_member(
2304    request: proto::RemoveChannelMember,
2305    response: Response<proto::RemoveChannelMember>,
2306    session: Session,
2307) -> Result<()> {
2308    let db = session.db().await;
2309    let channel_id = ChannelId::from_proto(request.channel_id);
2310    let member_id = UserId::from_proto(request.user_id);
2311
2312    db.remove_channel_member(channel_id, member_id, session.user_id)
2313        .await?;
2314
2315    let mut update = proto::UpdateChannels::default();
2316    update.remove_channels.push(channel_id.to_proto());
2317
2318    for connection_id in session
2319        .connection_pool()
2320        .await
2321        .user_connection_ids(member_id)
2322    {
2323        session.peer.send(connection_id, update.clone())?;
2324    }
2325
2326    response.send(proto::Ack {})?;
2327    Ok(())
2328}
2329
2330async fn set_channel_member_admin(
2331    request: proto::SetChannelMemberAdmin,
2332    response: Response<proto::SetChannelMemberAdmin>,
2333    session: Session,
2334) -> Result<()> {
2335    let db = session.db().await;
2336    let channel_id = ChannelId::from_proto(request.channel_id);
2337    let member_id = UserId::from_proto(request.user_id);
2338    db.set_channel_member_admin(channel_id, session.user_id, member_id, request.admin)
2339        .await?;
2340
2341    let (channel, has_accepted) = db
2342        .get_channel(channel_id, member_id)
2343        .await?
2344        .ok_or_else(|| anyhow!("channel not found"))?;
2345
2346    let mut update = proto::UpdateChannels::default();
2347    if has_accepted {
2348        update.channel_permissions.push(proto::ChannelPermission {
2349            channel_id: channel.id.to_proto(),
2350            is_admin: request.admin,
2351        });
2352    }
2353
2354    for connection_id in session
2355        .connection_pool()
2356        .await
2357        .user_connection_ids(member_id)
2358    {
2359        session.peer.send(connection_id, update.clone())?;
2360    }
2361
2362    response.send(proto::Ack {})?;
2363    Ok(())
2364}
2365
2366async fn rename_channel(
2367    request: proto::RenameChannel,
2368    response: Response<proto::RenameChannel>,
2369    session: Session,
2370) -> Result<()> {
2371    let db = session.db().await;
2372    let channel_id = ChannelId::from_proto(request.channel_id);
2373    let new_name = db
2374        .rename_channel(channel_id, session.user_id, &request.name)
2375        .await?;
2376
2377    let channel = proto::Channel {
2378        id: request.channel_id,
2379        name: new_name,
2380        parent_id: None,
2381    };
2382    response.send(proto::ChannelResponse {
2383        channel: Some(channel.clone()),
2384    })?;
2385    let mut update = proto::UpdateChannels::default();
2386    update.channels.push(channel);
2387
2388    let member_ids = db.get_channel_members(channel_id).await?;
2389
2390    let connection_pool = session.connection_pool().await;
2391    for member_id in member_ids {
2392        for connection_id in connection_pool.user_connection_ids(member_id) {
2393            session.peer.send(connection_id, update.clone())?;
2394        }
2395    }
2396
2397    Ok(())
2398}
2399
2400async fn get_channel_members(
2401    request: proto::GetChannelMembers,
2402    response: Response<proto::GetChannelMembers>,
2403    session: Session,
2404) -> Result<()> {
2405    let db = session.db().await;
2406    let channel_id = ChannelId::from_proto(request.channel_id);
2407    let members = db
2408        .get_channel_member_details(channel_id, session.user_id)
2409        .await?;
2410    response.send(proto::GetChannelMembersResponse { members })?;
2411    Ok(())
2412}
2413
2414async fn respond_to_channel_invite(
2415    request: proto::RespondToChannelInvite,
2416    response: Response<proto::RespondToChannelInvite>,
2417    session: Session,
2418) -> Result<()> {
2419    let db = session.db().await;
2420    let channel_id = ChannelId::from_proto(request.channel_id);
2421    db.respond_to_channel_invite(channel_id, session.user_id, request.accept)
2422        .await?;
2423
2424    let mut update = proto::UpdateChannels::default();
2425    update
2426        .remove_channel_invitations
2427        .push(channel_id.to_proto());
2428    if request.accept {
2429        let result = db.get_channels_for_user(session.user_id).await?;
2430        update
2431            .channels
2432            .extend(result.channels.into_iter().map(|channel| proto::Channel {
2433                id: channel.id.to_proto(),
2434                name: channel.name,
2435                parent_id: channel.parent_id.map(ChannelId::to_proto),
2436            }));
2437        update
2438            .channel_participants
2439            .extend(
2440                result
2441                    .channel_participants
2442                    .into_iter()
2443                    .map(|(channel_id, user_ids)| proto::ChannelParticipants {
2444                        channel_id: channel_id.to_proto(),
2445                        participant_user_ids: user_ids.into_iter().map(UserId::to_proto).collect(),
2446                    }),
2447            );
2448        update
2449            .channel_permissions
2450            .extend(
2451                result
2452                    .channels_with_admin_privileges
2453                    .into_iter()
2454                    .map(|channel_id| proto::ChannelPermission {
2455                        channel_id: channel_id.to_proto(),
2456                        is_admin: true,
2457                    }),
2458            );
2459    }
2460    session.peer.send(session.connection_id, update)?;
2461    response.send(proto::Ack {})?;
2462
2463    Ok(())
2464}
2465
/// Joins the current connection to the room that `room_id_for_channel` returns
/// for `request.channel_id`.
///
/// `leave_room_for_session` is called first (presumably leaving any room the
/// session is already in). The join response carries the room, its channel id
/// and, when a LiveKit client is configured and a token can be minted, LiveKit
/// connection info. Afterwards `channel_updated` is broadcast and the user's
/// contacts are refreshed via `update_user_contacts`.
async fn join_channel(
    request: proto::JoinChannel,
    response: Response<proto::JoinChannel>,
    session: Session,
) -> Result<()> {
    let channel_id = ChannelId::from_proto(request.channel_id);

    // This block bounds the lifetime of the database handle `db` and of the value
    // returned by `join_room`. `into_inner()` consumes that value, and only its
    // plain contents leave the block. `db` is dropped when the block ends, before
    // `connection_pool()` and `update_user_contacts` run below.
    let joined_room = {
        leave_room_for_session(&session).await?;
        let db = session.db().await;

        let room_id = db.room_id_for_channel(channel_id).await?;

        let joined_room = db
            .join_room(room_id, session.user_id, session.connection_id)
            .await?;

        // A token failure goes through `trace_err()` and turns into `None` here,
        // so the join still succeeds, just without LiveKit connection info.
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            let token = live_kit
                .room_token(
                    &joined_room.room.live_kit_room,
                    &session.user_id.to_string(),
                )
                .trace_err()?;

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
            })
        });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room.channel_id.map(|id| id.to_proto()),
            live_kit_connection_info,
        })?;

        // Broadcast the new room state while the value returned by `join_room` is
        // still held.
        room_updated(&joined_room.room, &session.peer);

        joined_room.into_inner()
    };

    channel_updated(
        channel_id,
        &joined_room.room,
        &joined_room.channel_members,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id, &session).await?;

    Ok(())
}
2520
2521async fn join_channel_buffer(
2522    request: proto::JoinChannelBuffer,
2523    response: Response<proto::JoinChannelBuffer>,
2524    session: Session,
2525) -> Result<()> {
2526    let db = session.db().await;
2527    let channel_id = ChannelId::from_proto(request.channel_id);
2528
2529    let open_response = db
2530        .join_channel_buffer(channel_id, session.user_id, session.connection_id)
2531        .await?;
2532
2533    let replica_id = open_response.replica_id;
2534    let collaborators = open_response.collaborators.clone();
2535
2536    response.send(open_response)?;
2537
2538    let update = AddChannelBufferCollaborator {
2539        channel_id: channel_id.to_proto(),
2540        collaborator: Some(proto::Collaborator {
2541            user_id: session.user_id.to_proto(),
2542            peer_id: Some(session.connection_id.into()),
2543            replica_id,
2544        }),
2545    };
2546    channel_buffer_updated(
2547        session.connection_id,
2548        collaborators
2549            .iter()
2550            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
2551        &update,
2552        &session.peer,
2553    );
2554
2555    Ok(())
2556}
2557
2558async fn update_channel_buffer(
2559    request: proto::UpdateChannelBuffer,
2560    session: Session,
2561) -> Result<()> {
2562    let db = session.db().await;
2563    let channel_id = ChannelId::from_proto(request.channel_id);
2564
2565    let collaborators = db
2566        .update_channel_buffer(channel_id, session.user_id, &request.operations)
2567        .await?;
2568
2569    channel_buffer_updated(
2570        session.connection_id,
2571        collaborators,
2572        &proto::UpdateChannelBuffer {
2573            channel_id: channel_id.to_proto(),
2574            operations: request.operations,
2575        },
2576        &session.peer,
2577    );
2578    Ok(())
2579}
2580
2581async fn rejoin_channel_buffers(
2582    request: proto::RejoinChannelBuffers,
2583    response: Response<proto::RejoinChannelBuffers>,
2584    session: Session,
2585) -> Result<()> {
2586    let db = session.db().await;
2587    let buffers = db
2588        .rejoin_channel_buffers(&request.buffers, session.user_id, session.connection_id)
2589        .await?;
2590
2591    for buffer in &buffers {
2592        let collaborators_to_notify = buffer
2593            .buffer
2594            .collaborators
2595            .iter()
2596            .filter_map(|c| Some(c.peer_id?.into()));
2597        channel_buffer_updated(
2598            session.connection_id,
2599            collaborators_to_notify,
2600            &proto::UpdateChannelBufferCollaborator {
2601                channel_id: buffer.buffer.channel_id,
2602                old_peer_id: Some(buffer.old_connection_id.into()),
2603                new_peer_id: Some(session.connection_id.into()),
2604            },
2605            &session.peer,
2606        );
2607    }
2608
2609    response.send(proto::RejoinChannelBuffersResponse {
2610        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
2611    })?;
2612
2613    Ok(())
2614}
2615
2616async fn leave_channel_buffer(
2617    request: proto::LeaveChannelBuffer,
2618    response: Response<proto::LeaveChannelBuffer>,
2619    session: Session,
2620) -> Result<()> {
2621    let db = session.db().await;
2622    let channel_id = ChannelId::from_proto(request.channel_id);
2623
2624    let collaborators_to_notify = db
2625        .leave_channel_buffer(channel_id, session.connection_id)
2626        .await?;
2627
2628    response.send(Ack {})?;
2629
2630    channel_buffer_updated(
2631        session.connection_id,
2632        collaborators_to_notify,
2633        &proto::RemoveChannelBufferCollaborator {
2634            channel_id: channel_id.to_proto(),
2635            peer_id: Some(session.connection_id.into()),
2636        },
2637        &session.peer,
2638    );
2639
2640    Ok(())
2641}
2642
2643fn channel_buffer_updated<T: EnvelopedMessage>(
2644    sender_id: ConnectionId,
2645    collaborators: impl IntoIterator<Item = ConnectionId>,
2646    message: &T,
2647    peer: &Peer,
2648) {
2649    broadcast(Some(sender_id), collaborators.into_iter(), |peer_id| {
2650        peer.send(peer_id.into(), message.clone())
2651    });
2652}
2653
2654async fn send_channel_message(
2655    request: proto::SendChannelMessage,
2656    response: Response<proto::SendChannelMessage>,
2657    session: Session,
2658) -> Result<()> {
2659    // Validate the message body.
2660    let body = request.body.trim().to_string();
2661    if body.len() > MAX_MESSAGE_LEN {
2662        return Err(anyhow!("message is too long"))?;
2663    }
2664    if body.is_empty() {
2665        return Err(anyhow!("message can't be blank"))?;
2666    }
2667
2668    let timestamp = OffsetDateTime::now_utc();
2669    let nonce = request
2670        .nonce
2671        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
2672
2673    let channel_id = ChannelId::from_proto(request.channel_id);
2674    let (message_id, connection_ids) = session
2675        .db()
2676        .await
2677        .create_channel_message(
2678            channel_id,
2679            session.user_id,
2680            &body,
2681            timestamp,
2682            nonce.clone().into(),
2683        )
2684        .await?;
2685    let message = proto::ChannelMessage {
2686        sender_id: session.user_id.to_proto(),
2687        id: message_id.to_proto(),
2688        body,
2689        timestamp: timestamp.unix_timestamp() as u64,
2690        nonce: Some(nonce),
2691    };
2692    broadcast(Some(session.connection_id), connection_ids, |connection| {
2693        session.peer.send(
2694            connection,
2695            proto::ChannelMessageSent {
2696                channel_id: channel_id.to_proto(),
2697                message: Some(message.clone()),
2698            },
2699        )
2700    });
2701    response.send(proto::SendChannelMessageResponse {
2702        message: Some(message),
2703    })?;
2704    Ok(())
2705}
2706
2707async fn join_channel_chat(
2708    request: proto::JoinChannelChat,
2709    response: Response<proto::JoinChannelChat>,
2710    session: Session,
2711) -> Result<()> {
2712    let channel_id = ChannelId::from_proto(request.channel_id);
2713
2714    let db = session.db().await;
2715    db.join_channel_chat(channel_id, session.connection_id, session.user_id)
2716        .await?;
2717    let messages = db
2718        .get_channel_messages(channel_id, session.user_id, MESSAGE_COUNT_PER_PAGE, None)
2719        .await?;
2720    response.send(proto::JoinChannelChatResponse {
2721        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
2722        messages,
2723    })?;
2724    Ok(())
2725}
2726
2727async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
2728    let channel_id = ChannelId::from_proto(request.channel_id);
2729    session
2730        .db()
2731        .await
2732        .leave_channel_chat(channel_id, session.connection_id, session.user_id)
2733        .await?;
2734    Ok(())
2735}
2736
2737async fn get_channel_messages(
2738    request: proto::GetChannelMessages,
2739    response: Response<proto::GetChannelMessages>,
2740    session: Session,
2741) -> Result<()> {
2742    let channel_id = ChannelId::from_proto(request.channel_id);
2743    let messages = session
2744        .db()
2745        .await
2746        .get_channel_messages(
2747            channel_id,
2748            session.user_id,
2749            MESSAGE_COUNT_PER_PAGE,
2750            Some(MessageId::from_proto(request.before_message_id)),
2751        )
2752        .await?;
2753    response.send(proto::GetChannelMessagesResponse {
2754        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
2755        messages,
2756    })?;
2757    Ok(())
2758}
2759
2760async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> {
2761    let project_id = ProjectId::from_proto(request.project_id);
2762    let project_connection_ids = session
2763        .db()
2764        .await
2765        .project_connection_ids(project_id, session.connection_id)
2766        .await?;
2767    broadcast(
2768        Some(session.connection_id),
2769        project_connection_ids.iter().copied(),
2770        |connection_id| {
2771            session
2772                .peer
2773                .forward_send(session.connection_id, connection_id, request.clone())
2774        },
2775    );
2776    Ok(())
2777}
2778
2779async fn get_private_user_info(
2780    _request: proto::GetPrivateUserInfo,
2781    response: Response<proto::GetPrivateUserInfo>,
2782    session: Session,
2783) -> Result<()> {
2784    let db = session.db().await;
2785
2786    let metrics_id = db.get_user_metrics_id(session.user_id).await?;
2787    let user = db
2788        .get_user_by_id(session.user_id)
2789        .await?
2790        .ok_or_else(|| anyhow!("user not found"))?;
2791    let flags = db.get_user_flags(session.user_id).await?;
2792
2793    response.send(proto::GetPrivateUserInfoResponse {
2794        metrics_id,
2795        staff: user.admin,
2796        flags,
2797    })?;
2798    Ok(())
2799}
2800
2801fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
2802    match message {
2803        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
2804        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
2805        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
2806        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
2807        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
2808            code: frame.code.into(),
2809            reason: frame.reason,
2810        })),
2811    }
2812}
2813
2814fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
2815    match message {
2816        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
2817        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
2818        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
2819        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
2820        AxumMessage::Close(frame) => {
2821            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
2822                code: frame.code.into(),
2823                reason: frame.reason,
2824            }))
2825        }
2826    }
2827}
2828
2829fn build_initial_channels_update(
2830    channels: ChannelsForUser,
2831    channel_invites: Vec<db::Channel>,
2832) -> proto::UpdateChannels {
2833    let mut update = proto::UpdateChannels::default();
2834
2835    for channel in channels.channels {
2836        update.channels.push(proto::Channel {
2837            id: channel.id.to_proto(),
2838            name: channel.name,
2839            parent_id: channel.parent_id.map(|id| id.to_proto()),
2840        });
2841    }
2842
2843    for (channel_id, participants) in channels.channel_participants {
2844        update
2845            .channel_participants
2846            .push(proto::ChannelParticipants {
2847                channel_id: channel_id.to_proto(),
2848                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
2849            });
2850    }
2851
2852    update
2853        .channel_permissions
2854        .extend(
2855            channels
2856                .channels_with_admin_privileges
2857                .into_iter()
2858                .map(|id| proto::ChannelPermission {
2859                    channel_id: id.to_proto(),
2860                    is_admin: true,
2861                }),
2862        );
2863
2864    for channel in channel_invites {
2865        update.channel_invitations.push(proto::Channel {
2866            id: channel.id.to_proto(),
2867            name: channel.name,
2868            parent_id: None,
2869        });
2870    }
2871
2872    update
2873}
2874
2875fn build_initial_contacts_update(
2876    contacts: Vec<db::Contact>,
2877    pool: &ConnectionPool,
2878) -> proto::UpdateContacts {
2879    let mut update = proto::UpdateContacts::default();
2880
2881    for contact in contacts {
2882        match contact {
2883            db::Contact::Accepted {
2884                user_id,
2885                should_notify,
2886                busy,
2887            } => {
2888                update
2889                    .contacts
2890                    .push(contact_for_user(user_id, should_notify, busy, &pool));
2891            }
2892            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
2893            db::Contact::Incoming {
2894                user_id,
2895                should_notify,
2896            } => update
2897                .incoming_requests
2898                .push(proto::IncomingContactRequest {
2899                    requester_id: user_id.to_proto(),
2900                    should_notify,
2901                }),
2902        }
2903    }
2904
2905    update
2906}
2907
2908fn contact_for_user(
2909    user_id: UserId,
2910    should_notify: bool,
2911    busy: bool,
2912    pool: &ConnectionPool,
2913) -> proto::Contact {
2914    proto::Contact {
2915        user_id: user_id.to_proto(),
2916        online: pool.is_user_online(user_id),
2917        busy,
2918        should_notify,
2919    }
2920}
2921
2922fn room_updated(room: &proto::Room, peer: &Peer) {
2923    broadcast(
2924        None,
2925        room.participants
2926            .iter()
2927            .filter_map(|participant| Some(participant.peer_id?.into())),
2928        |peer_id| {
2929            peer.send(
2930                peer_id.into(),
2931                proto::RoomUpdated {
2932                    room: Some(room.clone()),
2933                },
2934            )
2935        },
2936    );
2937}
2938
2939fn channel_updated(
2940    channel_id: ChannelId,
2941    room: &proto::Room,
2942    channel_members: &[UserId],
2943    peer: &Peer,
2944    pool: &ConnectionPool,
2945) {
2946    let participants = room
2947        .participants
2948        .iter()
2949        .map(|p| p.user_id)
2950        .collect::<Vec<_>>();
2951
2952    broadcast(
2953        None,
2954        channel_members
2955            .iter()
2956            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2957        |peer_id| {
2958            peer.send(
2959                peer_id.into(),
2960                proto::UpdateChannels {
2961                    channel_participants: vec![proto::ChannelParticipants {
2962                        channel_id: channel_id.to_proto(),
2963                        participant_user_ids: participants.clone(),
2964                    }],
2965                    ..Default::default()
2966                },
2967            )
2968        },
2969    );
2970}
2971
2972async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
2973    let db = session.db().await;
2974
2975    let contacts = db.get_contacts(user_id).await?;
2976    let busy = db.is_user_busy(user_id).await?;
2977
2978    let pool = session.connection_pool().await;
2979    let updated_contact = contact_for_user(user_id, false, busy, &pool);
2980    for contact in contacts {
2981        if let db::Contact::Accepted {
2982            user_id: contact_user_id,
2983            ..
2984        } = contact
2985        {
2986            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
2987                session
2988                    .peer
2989                    .send(
2990                        contact_conn_id,
2991                        proto::UpdateContacts {
2992                            contacts: vec![updated_contact.clone()],
2993                            remove_contacts: Default::default(),
2994                            incoming_requests: Default::default(),
2995                            remove_incoming_requests: Default::default(),
2996                            outgoing_requests: Default::default(),
2997                            remove_outgoing_requests: Default::default(),
2998                        },
2999                    )
3000                    .trace_err();
3001            }
3002        }
3003    }
3004    Ok(())
3005}
3006
/// Removes this session's connection from its room, if it is in one, and
/// propagates the consequences:
/// - notifies the connections of each project left (`project_left`);
/// - broadcasts the updated room (and, if the room belongs to a channel, the
///   updated channel participants);
/// - sends `CallCanceled` to the connections of users whose pending calls were
///   canceled;
/// - refreshes contact status for the leaving user and those canceled callees;
/// - removes the user from the LiveKit room, and deletes that room when the
///   database reports the room as deleted.
///
/// Returns `Ok(())` without doing anything if `leave_room` yields `None`.
async fn leave_room_for_session(session: &Session) -> Result<()> {
    // The leaving user plus every user whose call gets canceled below.
    let mut contacts_to_update = HashSet::default();

    // Declared up front and initialized inside the `if let` below, so the value
    // returned by `leave_room` (and the `session.db()` handle in the scrutinee)
    // go out of scope when that block ends, before any later `.await`.
    // NOTE(review): presumably both are lock guards — confirm before
    // restructuring this scope.
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    let room;
    let channel_members;
    let channel_id;

    if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
        contacts_to_update.insert(session.user_id);

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // Taken before `room` is moved out, so the `room` broadcast below
        // carries the default (empty) `live_kit_room` value.
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        delete_live_kit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel_members = mem::take(&mut left_room.channel_members);
        channel_id = left_room.channel_id;

        room_updated(&room, &session.peer);
    } else {
        // The connection was not in a room; nothing to clean up.
        return Ok(());
    }

    if let Some(channel_id) = channel_id {
        channel_updated(
            channel_id,
            &room,
            &channel_members,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so the connection-pool handle is released before
    // `update_user_contacts`, which acquires `session.connection_pool()` itself.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, &session).await?;
    }

    // LiveKit cleanup is best-effort: failures are logged via `trace_err`.
    if let Some(live_kit) = session.live_kit_client.as_ref() {
        live_kit
            .remove_participant(live_kit_room.clone(), session.user_id.to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
3083
3084async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
3085    let left_channel_buffers = session
3086        .db()
3087        .await
3088        .leave_channel_buffers(session.connection_id)
3089        .await?;
3090
3091    for (channel_id, connections) in left_channel_buffers {
3092        channel_buffer_updated(
3093            session.connection_id,
3094            connections,
3095            &proto::RemoveChannelBufferCollaborator {
3096                channel_id: channel_id.to_proto(),
3097                peer_id: Some(session.connection_id.into()),
3098            },
3099            &session.peer,
3100        );
3101    }
3102
3103    Ok(())
3104}
3105
3106fn project_left(project: &db::LeftProject, session: &Session) {
3107    for connection_id in &project.connection_ids {
3108        if project.host_user_id == session.user_id {
3109            session
3110                .peer
3111                .send(
3112                    *connection_id,
3113                    proto::UnshareProject {
3114                        project_id: project.id.to_proto(),
3115                    },
3116                )
3117                .trace_err();
3118        } else {
3119            session
3120                .peer
3121                .send(
3122                    *connection_id,
3123                    proto::RemoveProjectCollaborator {
3124                        project_id: project.id.to_proto(),
3125                        peer_id: Some(session.connection_id.into()),
3126                    },
3127                )
3128                .trace_err();
3129        }
3130    }
3131}
3132
/// Extension trait for turning a fallible result into an `Option` without
/// propagating the error. Use it where a failure should not abort the caller.
pub trait ResultExt {
    /// The success type carried by the implementing result.
    type Ok;

    /// Returns `Some` with the success value, or `None` on failure. The
    /// `Result` implementation in this file logs the error at error level.
    fn trace_err(self) -> Option<Self::Ok>;
}
3138
3139impl<T, E> ResultExt for Result<T, E>
3140where
3141    E: std::fmt::Debug,
3142{
3143    type Ok = T;
3144
3145    fn trace_err(self) -> Option<T> {
3146        match self {
3147            Ok(value) => Some(value),
3148            Err(error) => {
3149                tracing::error!("{:?}", error);
3150                None
3151            }
3152        }
3153    }
3154}